Last active
May 12, 2026 15:08
-
-
Save knyazer/73ebb649e9b917ab4eb355ef86f6f0f3 to your computer and use it in GitHub Desktop.
interface for distillation dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PromptStore:
    """Abstract handle to a prompt dataset materialised under ``path``.

    Different dataset mixes inherit from this class, implement their own
    :meth:`init` and choose their own ``name``.
    """

    path: Path
    name: str  # dataset name
    hash: str  # block-wise hash of the mixture (exact scheme not settled yet)

    @classmethod
    def init(cls, path: Path):
        """Create the store rooted at ``path``.

        Intended behaviour for concrete subclasses, per the design notes:

        * load the dataset and dump it into ``path`` with some mixing, or
        * skip loading and reuse what is already at ``path``, after a
          heuristic check that the metadata matches the remote, i.e. that
          the dataset would be restored exactly if it were re-downloaded.

        Raises:
            NotImplementedError: always, in this abstract base class.
        """
        # The base class is abstract; failing loudly beats silently
        # returning None from an alternate constructor.
        raise NotImplementedError(
            f"{cls.__name__} must implement its own init()"
        )
from enum import Enum


class Device(str, Enum):
    """Where a store's data currently lives.

    The ``str`` mixin keeps members equal to their plain-string values, so
    both ``Device.CPU`` and ``'cpu'`` compare equal.
    """

    CPU = 'cpu'
    GPU = 'gpu'
    DISK = 'disk'
class RaggedStore:
    """Base for stores of variable-length rows described by lengths/offsets."""

    path: Path
    lengths: Int[Array, "num"]
    offsets: Int[Array, "num_plus_one"]

    def to(self, device: Device):
        """Move the data between numpy-on-CPU and ``device`` (either direction).

        Design notes:

        * Generally the dataset is kept on CPU.
        * Handling of 'disk' is still open; probably doable with some
          block-wise storage format. The idea is to start in the 'disk'
          state, where nothing is held in RAM, and call ``.to('cpu')`` when
          the data should be loaded.
        * After that, ``.batched()`` is applied at the top level, followed
          by ``.to('gpu')``.
        """
        # NOTE(review): the return type is not specified in the sketch
        # (in-place move vs. returning a new store) -- confirm.
        ...

    # NOTE(review): ``indices`` is annotated with an empty (scalar) shape
    # string, and both ``indices`` and ``seed`` are required -- confirm the
    # intended shape and whether the two are alternatives.
    def shuffle(self, *, indices: Int[Array, ""], seed: int) -> Self:
        """Return a shuffled store.

        Shuffling makes some later operations disallowed: for example
        ``.shuffle()`` followed by ``.to('disk')`` should fail (for now).
        """
        ...

    def select(self, index: list[int]) -> Self:
        """Select the rows listed in ``index``."""
        ...
class ResponseFeatures(ABC):
    """Per-response-token features.

    To be handled via the project's loading machinery so that different
    feature formats can be supported trivially.
    """

    # Optional: may be absent (None).
    layerwise_activations: Float[Array, "... num_layers activation"] | None
    top_k_logits: Float[Array, "... top_k"]
    top_k_token_ids: Int[Array, "... top_k"]
    logsumexp: Float[Array, "..."]
class ResponseFeatureStore(RaggedStore):
    """Ragged store holding the features of response tokens.

    Rows correspond to ``token_store.tokens[token_store.response_mask]``
    in flattened order.
    """

    features: ResponseFeatures
class TokenStore(RaggedStore):
    """Ragged store of token ids plus a mask marking response tokens."""

    tokens: Int[Array, "total_tokens"]
    # Same length as ``tokens``; used to pick out the response tokens.
    response_mask: Bool[Array, "total_tokens"]
class UnrollStore:
    """Pairs a token store with the (optional) features of its response tokens."""

    token_store: TokenStore
    response_features: ResponseFeatureStore | None

    def __post_init__(self):
        """Validate that ``len(response_features) == token_store.response_mask.sum()``."""
        # NOTE(review): ``__post_init__`` only runs automatically under a
        # dataclass-like mechanism, which the sketch does not show -- confirm.
        ...

    def select(self, index: list[int]) -> Self:
        """Select the rows listed in ``index``.

        Tokens are sliced by token offsets; features are sliced by response
        offsets derived from ``response_mask``.
        """
        ...

    # ``UnrollBatch`` is defined further down the file, so the annotation is
    # a quoted forward reference; a bare name would fail when this method is
    # defined.
    def batched(self, *, is_finite: bool = False) -> Iterator["UnrollBatch"]:
        """Iterate over the store as ``UnrollBatch`` objects.

        NOTE(review): ``is_finite`` presumably switches between a single
        pass and endless iteration -- confirm.
        """
        ...
class UnrollBatch:
    """One batch of unrolls, with a ``batch`` and a ``seq`` axis."""

    tokens: Int[Array, "batch seq"]
    response_mask: Bool[Array, "batch seq"]
    lengths: Int[Array, "batch"]
    # None when the batch carries no response features.
    features: ResponseFeatures | None
class DatasetMetadata:
    """Descriptive metadata attached to a distillation dataset."""

    model_name: str
class DistillationDataset:
    """Top-level distillation dataset: metadata, prompts, unrolls and tokenizer."""

    metadata: DatasetMetadata
    prompt_store: PromptStore  # transparent wrapper that just points to the data
    unroll_store: UnrollStore
    tokenizer: Tokenizer  # a tokenizer (MessageProcessor) stored for completeness

    def batched(self, *, is_finite: bool = False) -> Iterator[UnrollBatch]:
        """Generate samples, iid, until exhausted or forever.

        NOTE(review): ``is_finite`` presumably selects "until exhausted"
        (True) versus "forever" (False) -- confirm.
        """
        ...

    # ``self`` was missing from the sketch, which made this uncallable on an
    # instance.
    def shuffle(self):
        """Shuffle the dataset (behaviour not specified in the sketch yet)."""
        ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment