Skip to content

Instantly share code, notes, and snippets.

@knyazer
Last active May 12, 2026 15:08
Show Gist options
  • Select an option

  • Save knyazer/73ebb649e9b917ab4eb355ef86f6f0f3 to your computer and use it in GitHub Desktop.

Select an option

Save knyazer/73ebb649e9b917ab4eb355ef86f6f0f3 to your computer and use it in GitHub Desktop.
Interface for a distillation dataset
class PromptStore:
    """Handle to a prompt dataset that lives under ``path``.

    Design intent: different dataset mixes inherit from this class,
    implement their own ``init`` and choose their own ``name``.

    NOTE(review): the design notes call this "the abstract class", but it
    is not declared abstract (no ``ABC`` base / ``@abstractmethod``) --
    confirm whether that is intended.
    """

    path: Path  # location the dataset is dumped into / reused from
    name: str  # dataset name
    hash: str  # block-wise hash of the mixture (exact scheme not settled yet)

    @classmethod
    def init(cls, path: Path):
        """Create the store at ``path`` or reuse what is already there.

        Not implemented yet. Planned behaviour:

        * either load the dataset and dump it into ``path`` with some
          mixing,
        * or skip loading and reuse the contents of ``path``, with a
          heuristic check that the metadata matches the remote, i.e. that
          the dataset would be restored exactly if it were re-downloaded.
        """
        ...
from enum import Enum  # NOTE: move to the module's import block once one exists


class Device(str, Enum):
    """Placement targets accepted by the stores' ``to`` method.

    Mixing in ``str`` keeps the plain-string spellings used in the design
    notes working: ``Device('cpu') is Device.CPU`` and
    ``Device.CPU == 'cpu'``.
    """

    CPU = 'cpu'
    GPU = 'gpu'
    DISK = 'disk'
class RaggedStore:
path: Path
lengths: Int[Array, "num"]
offsets: Int[Array, "num_plus_one"]
def to(self, device: Device):
... # move the data from numpy on cpu to device and vice versa (generally you would keep the dataset on cpu)
# idk how to handle 'disk' yet but probably we can with some block-wise storage format
# so we basically start with 'disk' format where nothing is stored in any RAM, then when we wish to load it
# we call .to('cpu')
# then we do .batched() on the top level and .to('gpu')
def shuffle(self, *, indices: Int[Array, ""], seed: int) -> Self:
... # shuffle makes some operations disallowed, e.g. .shuffle() -> .to('disk') should fail (for now)
def select(self, index: list[int]):
...
class ResponseFeatures(ABC): # handle this via our fancy loading, such that we can trivially support different feature formats
    """Features stored for response tokens.

    The leading ``...`` axes are shared by all fields; which axes they are
    depends on where the features are held (flat in a store, or per batch).
    """

    # Optional per-layer activations; trailing axes are (num_layers, activation).
    layerwise_activations: Float[Array, "... num_layers activation"] | None
    # Logit values of the top-k entries; trailing axis is top_k.
    top_k_logits: Float[Array, "... top_k"]
    # Token ids for the top-k entries; same shape as top_k_logits.
    top_k_token_ids: Int[Array, "... top_k"]
    # NOTE(review): presumably the logsumexp over the full vocabulary, so the
    # top-k logits can be turned into probabilities -- confirm.
    logsumexp: Float[Array, "..."]
class ResponseFeatureStore(RaggedStore):
    """Ragged store holding the ``ResponseFeatures`` of response tokens."""

    # Rows correspond to token_store.tokens[token_store.response_mask] in flattened order.
    features: ResponseFeatures
class TokenStore(RaggedStore):
    """Ragged store holding all tokens as one flat array plus a response mask."""

    # Flat token array of length total_tokens.
    tokens: Int[Array, "total_tokens"]
    # Same length as `tokens`; used elsewhere in this file to pick out the
    # response tokens (`tokens[response_mask]`).
    response_mask: Bool[Array, "total_tokens"]
class UnrollStore:
    """Pairs a ``TokenStore`` with optional per-response-token features.

    NOTE(review): ``__post_init__`` is only invoked automatically for
    dataclass-style classes; no such decorator/base is visible here --
    confirm how this class is constructed.
    """

    token_store: TokenStore
    response_features: ResponseFeatureStore | None

    def __post_init__(self) -> None:
        """Validate that features and response tokens line up.

        Planned check: ``len(response_features) ==
        token_store.response_mask.sum()``.

        NOTE(review): the check presumably has to be skipped when
        ``response_features`` is ``None`` -- confirm.
        """
        ... # validate that len(response_features) == token_store.response_mask.sum()

    def select(self, index: list[int]) -> Self:
        """Return a store restricted to the rows in ``index``.

        Planned: slice tokens by token offsets, and features by response
        offsets derived from ``response_mask``.
        """
        ... # slice tokens by token offsets and features by response offsets derived from response_mask

    def batched(self, *, is_finite: bool = False) -> Iterator[UnrollBatch]:
        """Iterate over the store as ``UnrollBatch`` objects.

        NOTE(review): presumably ``is_finite=True`` stops once the data is
        exhausted and the default loops forever -- confirm.
        """
        ...
class UnrollBatch:
    """One batch of token sequences, laid out as (batch, seq) arrays."""

    # Token ids, shape (batch, seq).
    tokens: Int[Array, "batch seq"]
    # Boolean mask with the same (batch, seq) shape as `tokens`.
    response_mask: Bool[Array, "batch seq"]
    # One length per batch row.
    # NOTE(review): presumably the unpadded length of each row -- confirm.
    lengths: Int[Array, "batch"]
    # Optional response features for this batch.
    features: ResponseFeatures | None
class DatasetMetadata:
    """Metadata attached to a distillation dataset."""

    # NOTE(review): presumably the model that produced the responses/features
    # -- confirm.
    model_name: str
class DistillationDataset:
    """Top-level bundle: metadata, prompt store, unroll store and tokenizer."""

    metadata: DatasetMetadata
    prompt_store: PromptStore  # transparent wrapper that just points to the data
    unroll_store: UnrollStore
    tokenizer: Tokenizer  # a tokenizer (MessageProcessor) stored for completeness

    def batched(self, *, is_finite: bool = False) -> Iterator[UnrollBatch]:
        """Generate iid batches, either until exhausted or forever.

        NOTE(review): presumably ``is_finite=True`` selects the "until
        exhausted" mode and the default runs forever -- confirm.
        """
        ...

    def shuffle(self):
        """Shuffle the dataset (behaviour not specified yet).

        ``self`` was missing from the signature, so calling the method on
        an instance raised ``TypeError``.

        NOTE(review): ``RaggedStore.shuffle`` takes keyword-only
        ``indices``/``seed`` and returns ``Self``; decide whether this
        method should mirror that signature.
        """
        ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment