Last active
May 17, 2026 04:41
-
-
Save scottjmaddox/a0254ed46dff9e24adf10711b8dae611 to your computer and use it in GitHub Desktop.
Train a multiscreen model on the fineweb dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Train a Multiscreen model [1] using the Aurora optimizer [2] on the FineWeb dataset [3] | |
| tokenized using the GPT-2 tokenizer by Keller Jordan for the modded-nanogpt speedrun [4]. | |
| Use the `data/cached_fineweb10B.py` script from [4] to download the pre-tokenized dataset | |
| to `datasets/fineweb10B/` before running. | |
| References: | |
| 1. Screening is Enough. https://arxiv.org/abs/2604.01178 | |
| 2. Aurora: A Leverage-Aware Optimizer for Rectangular Matrices. https://blog.tilderesearch.com/blog/aurora | |
| 3. The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale. https://arxiv.org/abs/2406.17557 | |
| 4. Modded-NanoGPT. https://github.com/kellerjordan/modded-nanogpt | |
| """ | |
| import contextlib | |
| import copy | |
| import gc | |
| import glob | |
| import json | |
| import math | |
| import os | |
| import shutil | |
| import sys | |
| import threading | |
| import time | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Callable, Optional | |
| hparams = {} | |
| def env(name: str, default, dtype: type | None = None, is_hparam: bool = True): | |
| val = os.environ.get(name) | |
| if val is None: | |
| val = default | |
| else: | |
| if dtype is None: | |
| dtype = type(default) | |
| if dtype is bool: | |
| val = val.lower() in ("1", "true", "yes") | |
| elif dtype is int: | |
| val = int(eval(val)) | |
| elif dtype is float: | |
| val = float(eval(val)) | |
| else: | |
| val = dtype(val) | |
| if is_hparam: | |
| hparams[name] = val | |
| return val | |
# Weights & Biases watch() settings.
WANDB_WATCH = env("WANDB_WATCH", False, dtype=bool)
WANDB_WATCH_LOG = env("WANDB_WATCH_LOG", "all", dtype=str)
WANDB_WATCH_LOG_FREQ = env("WANDB_WATCH_LOG_FREQ", 1000, dtype=int)

# Data: shard glob patterns are resolved relative to DATA_PATH. Paths are
# not recorded in `hparams` (is_hparam=False).
DATA_PATH = env("DATA_PATH", ".", dtype=str, is_hparam=False)
TRAIN_FILES = os.path.join(
    DATA_PATH,
    env("TRAIN_FILES", "datasets/fineweb10B/fineweb_train_*.bin", dtype=str, is_hparam=False),
)
VAL_FILES = os.path.join(
    DATA_PATH,
    env("VAL_FILES", "datasets/fineweb10B/fineweb_val_*.bin", dtype=str, is_hparam=False),
)
VAL_TOKENS = env("VAL_TOKENS", 2**24, dtype=int)

# Model architecture: width and head count default to functions of DEPTH.
# KEY_DIM / VALUE_DIM are additionally constrained by the Triton kernel
# launch config (multiples of 16, at most 128).
DEPTH = env("DEPTH", 16, dtype=int)
MODEL_DIM = env("MODEL_DIM", DEPTH * DEPTH, dtype=int)
NUM_HEADS = env("NUM_HEADS", DEPTH, dtype=int)
KEY_DIM = env("KEY_DIM", 16, dtype=int)
VALUE_DIM = env("VALUE_DIM", 64, dtype=int)
WINDOW_THRESHOLD = env("WINDOW_THRESHOLD", 256.0, dtype=float)
VOCAB_SIZE = env("VOCAB_SIZE", 50_304, dtype=int)
# Training
TRAIN_STEPS = env("TRAIN_STEPS", 2**11, int)
TRAIN_MAX_SEQ_LEN = env("TRAIN_MAX_SEQ_LEN", 2**12, int)
# Presumably tokens per optimizer step: it is split across GRAD_ACCUM_STEPS
# micro-steps and each micro-step share is compared to TRAIN_MAX_SEQ_LEN below.
TRAIN_BATCH_SIZE = env("TRAIN_BATCH_SIZE", 2**22, int)
# Default doubles every 8 units of DEPTH (2**5 when DEPTH < 8).
GRAD_ACCUM_STEPS = env("GRAD_ACCUM_STEPS", 2 ** (5 + DEPTH // 8), int)
# Explicit raises instead of `assert`: asserts are stripped under `python -O`,
# and a zero GRAD_ACCUM_STEPS would otherwise surface as a ZeroDivisionError.
if GRAD_ACCUM_STEPS <= 0:
    raise ValueError(f"GRAD_ACCUM_STEPS must be positive, got {GRAD_ACCUM_STEPS}")
if TRAIN_BATCH_SIZE % GRAD_ACCUM_STEPS != 0:
    raise ValueError(
        f"TRAIN_BATCH_SIZE ({TRAIN_BATCH_SIZE}) must be divisible by GRAD_ACCUM_STEPS ({GRAD_ACCUM_STEPS})"
    )
if TRAIN_BATCH_SIZE // GRAD_ACCUM_STEPS < TRAIN_MAX_SEQ_LEN:
    raise ValueError(
        f"TRAIN_BATCH_SIZE // GRAD_ACCUM_STEPS ({TRAIN_BATCH_SIZE // GRAD_ACCUM_STEPS}) "
        f"must be at least TRAIN_MAX_SEQ_LEN ({TRAIN_MAX_SEQ_LEN})"
    )
# Validation
VAL_BATCH_SIZE = env("VAL_BATCH_SIZE", 2**22, dtype=int)
VAL_PERIOD = env("VAL_PERIOD", 20, dtype=int)

# Checkpointing (off unless SAVE_CHECKPOINT is set)
SAVE_CHECKPOINT = env("SAVE_CHECKPOINT", False, dtype=bool)
CHECKPOINT_PERIOD = env("CHECKPOINT_PERIOD", 40, dtype=int)

# Compilation and kernel configuration. The MULTISCREEN_* values become the
# defaults of the varlen multiscreen Triton tiling/launch configs.
COMPILE_MODEL = env("COMPILE_MODEL", True, dtype=bool)
KERNEL_WARMUP_TRAIN_STEPS = env("KERNEL_WARMUP_TRAIN_STEPS", 10, dtype=int)
KERNEL_WARMUP_VAL_STEPS = env("KERNEL_WARMUP_VAL_STEPS", 10, dtype=int)
MULTISCREEN_BLOCK_M = env("MULTISCREEN_BLOCK_M", 32, dtype=int)
MULTISCREEN_BLOCK_N = env("MULTISCREEN_BLOCK_N", 32, dtype=int)
MULTISCREEN_NUM_WARPS = env("MULTISCREEN_NUM_WARPS", 4, dtype=int)
MULTISCREEN_NUM_STAGES = env("MULTISCREEN_NUM_STAGES", 2, dtype=int)
# Optimizers
WARMUP_STEPS = env("WARMUP_STEPS", 2**10, dtype=int)
COOLDOWN_FRAC = env("COOLDOWN_FRAC", 0.0, dtype=float)

# "Matrix" optimizer: either Adam or Aurora. Learning rates are configured
# as base-2 logarithms and exponentiated here.
MATRIX_OPTIMIZER = env("MATRIX_OPTIMIZER", "Aurora", dtype=str)
MATRIX_LOG2_LR = env("MATRIX_LOG2_LR", -3, dtype=float)
MATRIX_LR = 2**MATRIX_LOG2_LR
if MATRIX_OPTIMIZER not in ("Adam", "Aurora"):
    raise ValueError(f"invalid MATRIX_OPTIMIZER: {MATRIX_OPTIMIZER}")
if MATRIX_OPTIMIZER == "Adam":
    MATRIX_BETA1 = env("MATRIX_BETA1", 0.9, dtype=float)
    MATRIX_BETA2 = env("MATRIX_BETA2", 0.95, dtype=float)
else:  # Aurora
    MATRIX_MOMENTUM = env("MATRIX_MOMENTUM", 0.95, dtype=float)
MATRIX_EPS = env("MATRIX_EPS", 1e-8, dtype=float)

# Optimizer settings for all remaining ("other") parameters.
OTHER_LOG2_LR = env("OTHER_LOG2_LR", -4, dtype=float)
OTHER_LR = 2**OTHER_LOG2_LR
OTHER_BETA1 = env("OTHER_BETA1", 0.9, dtype=float)
OTHER_BETA2 = env("OTHER_BETA2", 0.95, dtype=float)
OTHER_EPS = env("OTHER_EPS", 1e-8, dtype=float)
# Profiling (disabled unless PROFILE_ENABLED is set)
PROFILE_ENABLED = env("PROFILE_ENABLED", False, dtype=bool)
# What to record.
PROFILE_CPU = env("PROFILE_CPU", True, dtype=bool)
PROFILE_CUDA = env("PROFILE_CUDA", True, dtype=bool)
PROFILE_RECORD_SHAPES = env("PROFILE_RECORD_SHAPES", True, dtype=bool)
PROFILE_PROFILE_MEMORY = env("PROFILE_PROFILE_MEMORY", True, dtype=bool)
PROFILE_WITH_STACK = env("PROFILE_WITH_STACK", False, dtype=bool)
PROFILE_WITH_FLOPS = env("PROFILE_WITH_FLOPS", True, dtype=bool)
# Step schedule.
PROFILE_WAIT_STEPS = env("PROFILE_WAIT_STEPS", 0, dtype=int)
PROFILE_WARMUP_STEPS = env("PROFILE_WARMUP_STEPS", 10, dtype=int)
PROFILE_ACTIVE_STEPS = env("PROFILE_ACTIVE_STEPS", 10, dtype=int)
PROFILE_REPEAT = env("PROFILE_REPEAT", 1, dtype=int)
# Example profiling usage:
#
# VAL_PERIOD=0 \
# SAVE_CHECKPOINT=0 \
# TRAIN_STEPS=111 \
# PROFILE_ENABLED=1 \
# PROFILE_WARMUP_STEPS=100 \
# PROFILE_ACTIVE_STEPS=10 \
# uv run train_multiscreen_fineweb.py
# Weights & Biases
# Environment variables read by the wandb client:
# WANDB_ERROR_REPORTING: Set this to false to prevent wandb from logging fatal errors to its error tracking system.
# WANDB_DISABLE_GIT: Set to true to prevent wandb from probing for a git repository and capturing the latest commit / diff.
# WANDB_SAVE_CODE: Set to true to ensure the main script is saved.
# WANDB_PROJECT: Sets the default project name for your runs.
# WANDB_RUN_GROUP: Specify the experiment name to automatically group runs together.
# WANDB_NAME: Sets a custom display name for the run.
# WANDB_NOTES: Adds a detailed description or "commit message" for the run.
#
# The values below are only defaults. setdefault() keeps any value already
# exported in the environment, so e.g. `WANDB_PROJECT=foo uv run ...` works
# like every other env()-configured setting in this file.
# To log offline, export WANDB_MODE=offline.
os.environ.setdefault("WANDB_ERROR_REPORTING", "false")
os.environ.setdefault("WANDB_DISABLE_GIT", "true")
os.environ.setdefault("WANDB_SAVE_CODE", "true")
os.environ.setdefault("WANDB_PROJECT", "multiscreen-fineweb")
# setdefault returns the effective value, so the run name below always
# matches the group the run is actually logged under.
WANDB_RUN_GROUP = os.environ.setdefault("WANDB_RUN_GROUP", "MS v4")
os.environ.setdefault(
    "WANDB_NAME",
    f"{WANDB_RUN_GROUP} d{DEPTH} {MATRIX_OPTIMIZER}+embed LOG2_LRs: ({MATRIX_LOG2_LR}, {OTHER_LOG2_LR})",
)
| ####################################################################################################################### | |
# Allocator config is set before `import torch`, presumably so it is in place
# before any CUDA allocation happens.
# NOTE(review): older PyTorch releases read PYTORCH_CUDA_ALLOC_CONF instead of
# PYTORCH_ALLOC_CONF — confirm against the installed version.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist
import torch.nn.functional as F
import triton
import triton.language as tl
from torch import Tensor, nn
# Allow TF32 math for cuDNN and CUDA matmuls (trades matmul precision for speed).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
import wandb
| ####################################################################################################################### | |
| # varlen_multiscreen triton kernel by @scottjmaddox and GPT-5.4 / 5.5 | |
def panic(msg):
    """Raise a plain ``Exception`` carrying ``msg``.

    Being a function call, this can be used inside expressions, e.g. as the
    ``else`` arm of a conditional expression during config validation.
    """
    error = Exception(msg)
    raise error
@dataclass(frozen=True)
class VarlenMultiscreenTilingConfig:
    """Tile sizes shared by the metadata builder and the Triton kernels (BLOCK_M / BLOCK_N)."""

    # Query rows handled by one kernel program.
    block_m: int = MULTISCREEN_BLOCK_M
    # Key/value rows processed per inner-loop iteration.
    block_n: int = MULTISCREEN_BLOCK_N


@dataclass(frozen=True)
class VarlenMultiscreenLaunchConfig:
    """Triton launch parameters for the varlen multiscreen kernels."""

    # KEY_DIM / VALUE_DIM must be multiples of 16 and at most 128; otherwise
    # panic() raises while this class body executes, i.e. at import time.
    block_dk: int = (
        KEY_DIM
        if KEY_DIM <= 128 and KEY_DIM % 16 == 0
        else panic(f"bad KEY_DIM: {KEY_DIM}")
    )
    block_dv: int = (
        VALUE_DIM
        if VALUE_DIM <= 128 and VALUE_DIM % 16 == 0
        else panic(f"bad VALUE_DIM: {VALUE_DIM}")
    )
    num_warps: int = MULTISCREEN_NUM_WARPS
    num_stages: int = MULTISCREEN_NUM_STAGES


@dataclass(frozen=True)
class VarlenMultiscreenKernelConfig:
    """Bundle of tiling + launch settings used by the varlen multiscreen custom ops."""

    tiling: VarlenMultiscreenTilingConfig = field(default_factory=VarlenMultiscreenTilingConfig)
    launch: VarlenMultiscreenLaunchConfig = field(default_factory=VarlenMultiscreenLaunchConfig)


# Shared by both the forward and the backward custom op.
DEFAULT_VARLEN_MULTISCREEN_KERNEL_CONFIG = VarlenMultiscreenKernelConfig()
@dataclass(frozen=True)
class VarlenMultiscreenMeta:
    """Per-call launch metadata built by ``build_varlen_multiscreen_meta`` and consumed by the kernel launchers."""

    tiling_cfg: VarlenMultiscreenTilingConfig  # tile sizes the metadata was built with
    w: Tensor  # per-head window parameter, passed through to the kernels
    r: Tensor  # per-head r parameter, passed through to the kernels
    q_seg_starts: Tensor  # cu_seqlens_q[:-1]: first query row of each segment
    q_seg_ends: Tensor  # cu_seqlens_q[1:]: one past the last query row of each segment
    kv_seg_starts: Tensor  # cu_seqlens_kv[:-1]
    kv_seg_ends: Tensor  # cu_seqlens_kv[1:]
    tile_segment_ids: Tensor  # segment index of every query tile
    tile_q_starts: Tensor  # first (global) query row of every query tile
    tile_kv_start_idxs: Tensor  # (num_heads, num_q_tiles): first kv tile to visit, relative to the segment's kv start
    tile_kv_num_blocks: Tensor  # (num_heads, num_q_tiles): number of kv tiles to visit (0 if none)
    num_segments: int
    num_q_tiles: int
def _check_varlen_multiscreen_inputs(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> None:
    """Validate shapes and devices of the varlen multiscreen op inputs.

    Layout enforced below: ``q`` is (H, T_q, D_k), ``k`` is (H, T_kv, D_k),
    ``v`` is (H, T_kv, D_v); ``q_pos`` / ``kv_pos`` hold one entry per query /
    key token; ``cu_seqlens_q`` / ``cu_seqlens_kv`` are 1-D with equal length;
    ``w`` and ``r`` hold one value per head. Every tensor must be on ``q``'s
    device, because the Triton kernels dereference all of them as raw pointers.

    Raises:
        AssertionError: on any shape or device mismatch.
    """
    assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
    assert q.shape[0] == k.shape[0] == v.shape[0]
    assert q.shape[2] == k.shape[2]
    assert k.shape[1] == v.shape[1]
    assert q.device == k.device == v.device == q_pos.device == kv_pos.device == cu_seqlens_q.device == cu_seqlens_kv.device
    # w and r are loaded by the kernels via tl.load(w_ptr + hid) / tl.load(r_ptr + hid),
    # so they must live on the same device as the activations too.
    assert w.device == q.device and r.device == q.device
    assert q_pos.ndim == kv_pos.ndim == 1
    assert cu_seqlens_q.ndim == cu_seqlens_kv.ndim == 1
    assert q_pos.shape[0] == q.shape[1]
    assert kv_pos.shape[0] == k.shape[1]
    assert cu_seqlens_q.shape == cu_seqlens_kv.shape
    assert w.ndim == 1 and r.ndim == 1
    assert w.shape[0] == q.shape[0]
    assert r.shape[0] == q.shape[0]
def _check_varlen_multiscreen_meta(
    meta: VarlenMultiscreenMeta,
    *,
    device: torch.device,
    num_heads: int,
) -> None:
    """Sanity-check metadata produced by ``build_varlen_multiscreen_meta``.

    Segment tensors must be 1-D and share a shape; per-tile tensors must be
    1-D and share a shape; per-(head, tile) tensors must be 2-D of shape
    ``(num_heads, number_of_tiles)``. Everything must live on ``device``.

    Raises:
        AssertionError: if any of these invariants is violated.
    """
    segment_tensors = (meta.q_seg_starts, meta.q_seg_ends, meta.kv_seg_starts, meta.kv_seg_ends)
    tile_tensors = (meta.tile_segment_ids, meta.tile_q_starts)
    head_tile_tensors = (meta.tile_kv_start_idxs, meta.tile_kv_num_blocks)

    assert all(t.device == device for t in segment_tensors)
    assert all(t.device == device for t in tile_tensors + head_tile_tensors)

    assert all(t.ndim == 1 for t in segment_tensors)
    assert all(t.ndim == 1 for t in tile_tensors)
    assert all(t.ndim == 2 for t in head_tile_tensors)

    assert all(t.shape == segment_tensors[0].shape for t in segment_tensors[1:])
    assert tile_tensors[1].shape == tile_tensors[0].shape
    assert head_tile_tensors[1].shape == head_tile_tensors[0].shape

    heads_dim, tiles_dim = meta.tile_kv_start_idxs.shape[0], meta.tile_kv_start_idxs.shape[1]
    assert heads_dim == num_heads
    assert tiles_dim == meta.tile_segment_ids.numel()
@torch.library.custom_op("multiscreen::varlen_multiscreen", mutates_args=())
def triton_varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> Tensor:
    """Forward varlen multiscreen custom op.

    Validates the inputs, builds per-tile metadata with the default kernel
    config, and launches the Triton forward kernel into a freshly allocated
    output of shape ``(q.shape[0], q.shape[1], v.shape[2])`` with ``v``'s
    dtype and device.
    """
    _check_varlen_multiscreen_inputs(q, k, v, q_pos, kv_pos, cu_seqlens_q, cu_seqlens_kv, w, r)
    kernel_cfg = DEFAULT_VARLEN_MULTISCREEN_KERNEL_CONFIG
    num_heads = q.shape[0]

    meta = build_varlen_multiscreen_meta(cu_seqlens_q, cu_seqlens_kv, w, r, tiling_cfg=kernel_cfg.tiling)
    _check_varlen_multiscreen_meta(meta, device=q.device, num_heads=num_heads)

    out_shape = (num_heads, q.shape[1], v.shape[2])
    result = torch.empty(out_shape, device=v.device, dtype=v.dtype)
    _launch_triton_varlen_multiscreen(q, k, v, q_pos, kv_pos, result, meta, launch_cfg=kernel_cfg.launch)
    return result
@torch.library.custom_op("multiscreen::varlen_multiscreen_backward", mutates_args=())
def triton_varlen_multiscreen_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
    dout: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Backward varlen multiscreen custom op.

    ``dout`` must have the forward output's shape
    ``(q.shape[0], q.shape[1], v.shape[2])``. Returns ``(dq, dk, dv, dw, dr)``
    in the dtypes of ``(q, k, v, w, r)``.
    """
    _check_varlen_multiscreen_inputs(q, k, v, q_pos, kv_pos, cu_seqlens_q, cu_seqlens_kv, w, r)
    cfg = DEFAULT_VARLEN_MULTISCREEN_KERNEL_CONFIG
    meta = build_varlen_multiscreen_meta(cu_seqlens_q, cu_seqlens_kv, w, r, tiling_cfg=cfg.tiling)
    _check_varlen_multiscreen_meta(meta, device=q.device, num_heads=q.shape[0])
    assert dout.shape == (q.shape[0], q.shape[1], v.shape[2])
    # dq is written exactly once per query tile with a plain tl.store, so it
    # can use q's dtype directly.
    dq = torch.zeros_like(q)
    # dk, dv, dw and dr are accumulated with tl.atomic_add from every query
    # tile that touches the same key rows / head. Accumulate all of them in
    # float32 (dw/dr already were) so rounding error does not compound across
    # those adds in a low-precision dtype; cast back once at the end.
    dk = torch.zeros_like(k, dtype=torch.float32)
    dv = torch.zeros_like(v, dtype=torch.float32)
    dw = torch.zeros_like(w, dtype=torch.float32)
    dr = torch.zeros_like(r, dtype=torch.float32)
    _launch_triton_varlen_multiscreen_backward(
        q, k, v, q_pos, kv_pos, dout, dq, dk, dv, dw, dr, meta, launch_cfg=cfg.launch
    )
    return dq, dk.to(k.dtype), dv.to(v.dtype), dw.to(w.dtype), dr.to(r.dtype)
def build_varlen_multiscreen_meta(
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
    *,
    tiling_cfg: VarlenMultiscreenTilingConfig,
) -> VarlenMultiscreenMeta:
    """Precompute per-tile launch metadata for the varlen multiscreen Triton kernels.

    The query rows of every segment are split into tiles of ``tiling_cfg.block_m``
    rows. For every (head, query tile) pair this computes which key/value tiles
    (of ``tiling_cfg.block_n`` rows, indexed relative to the segment's kv start)
    can contain keys inside that head's causal window, so the kernels can skip
    all others. Heads whose ``w`` is infinite are treated as full-prefix heads
    and visit every kv tile up to the diagonal.

    Args:
        cu_seqlens_q: 1-D cumulative query segment boundaries
            (``[:-1]`` are segment starts, ``[1:]`` are segment ends).
        cu_seqlens_kv: same for keys/values; must describe the same number of segments.
        w: per-head window parameter (read detached as float32 here).
        r: per-head r parameter; only stored in the returned metadata.
        tiling_cfg: tile sizes.

    Returns:
        A ``VarlenMultiscreenMeta`` whose per-(head, tile) tensors have shape
        ``(num_heads, num_q_tiles)``.

    NOTE(review): the kv window bounds are derived from row indices within the
    segment (q_min / q_max below), while the kernels mask on q_pos / kv_pos.
    This assumes positions are 0..len-1 from each segment start — confirm
    against the caller.
    NOTE(review): repeat_interleave without output_size must know the total
    tile count on the host (a device sync per the PyTorch docs).
    """
    q_seg_starts = cu_seqlens_q[:-1].contiguous()
    q_seg_ends = cu_seqlens_q[1:].contiguous()
    kv_seg_starts = cu_seqlens_kv[:-1].contiguous()
    kv_seg_ends = cu_seqlens_kv[1:].contiguous()
    num_segments = q_seg_starts.numel()
    assert num_segments == kv_seg_starts.numel()
    # Window sizes are only used to size the tile ranges: no autograd needed.
    w_runtime = w.detach().float()
    full_prefix_heads = torch.isinf(w_runtime)
    finite_w = torch.where(full_prefix_heads, torch.zeros_like(w_runtime), w_runtime)
    # Number of kv tiles spanned by each head's window, rounded up.
    head_window_tile_counts = torch.ceil(finite_w / tiling_cfg.block_n).to(torch.int32)
    kv_seg_lens = kv_seg_ends - kv_seg_starts
    # Full-prefix heads get the tile count of the longest kv segment
    # (ceil division); zero when there are no segments.
    max_kv_tile_count = (
        torch.div(kv_seg_lens.max() + (tiling_cfg.block_n - 1), tiling_cfg.block_n, rounding_mode="floor")
        if kv_seg_lens.numel() > 0
        else head_window_tile_counts.new_zeros(())
    )
    head_window_tile_counts = torch.where(
        full_prefix_heads,
        max_kv_tile_count.to(head_window_tile_counts.dtype),
        head_window_tile_counts,
    )
    # Query tiling: ceil(segment length / block_m) tiles per segment.
    q_seg_lens = q_seg_ends - q_seg_starts
    num_q_tiles_per_seg = torch.div(q_seg_lens + (tiling_cfg.block_m - 1), tiling_cfg.block_m, rounding_mode="floor")
    # Segment id of every tile, e.g. tiles-per-seg [2, 1] -> [0, 0, 1].
    tile_segment_ids = torch.repeat_interleave(
        torch.arange(num_segments, device=cu_seqlens_q.device, dtype=torch.int32),
        num_q_tiles_per_seg.to(torch.int64),
    )
    # Exclusive prefix sum: index of each segment's first tile.
    seg_tile_prefix = torch.cumsum(num_q_tiles_per_seg, dim=0) - num_q_tiles_per_seg
    tile_prefix = seg_tile_prefix[tile_segment_ids.to(torch.int64)]
    # Tile index within its own segment.
    local_tile_idx = torch.arange(tile_segment_ids.shape[0], device=cu_seqlens_q.device, dtype=torch.int32) - tile_prefix
    # Global first query row of each tile.
    tile_q_starts = (q_seg_starts[tile_segment_ids.to(torch.int64)] + local_tile_idx * tiling_cfg.block_m).to(torch.int32)
    idx = tile_segment_ids.to(torch.int64)
    q_tile_ends = torch.minimum(tile_q_starts + tiling_cfg.block_m, q_seg_ends[idx])
    # First / last query row of each tile, relative to its segment start.
    q_min = tile_q_starts - q_seg_starts[idx]
    q_max = q_tile_ends - q_seg_starts[idx] - 1
    kv_seg_lens = kv_seg_ends[idx] - kv_seg_starts[idx]
    # Window rounded up to whole kv tiles. This is a conservative superset;
    # the kernels apply the exact window mask per element.
    head_window_tokens = head_window_tile_counts.unsqueeze(1) * tiling_cfg.block_n
    # (num_heads, num_q_tiles) range of kv rows (segment-relative) that any
    # query in the tile may attend to: from the window start of the first
    # query up to the last query (causal), clipped to the kv segment.
    kv_min_pos = torch.clamp(q_min.unsqueeze(0) - head_window_tokens + 1, min=0)
    kv_max_pos = torch.minimum(kv_seg_lens.unsqueeze(0) - 1, q_max.unsqueeze(0))
    valid = kv_max_pos >= kv_min_pos
    # Convert the row range to a [start, end) range of kv tile indices;
    # empty ranges become start=0, count=0.
    tile_kv_start_idxs = torch.where(
        valid,
        torch.div(kv_min_pos, tiling_cfg.block_n, rounding_mode="floor"),
        torch.zeros_like(kv_min_pos),
    ).to(torch.int32).contiguous()
    tile_kv_end_idxs = torch.where(
        valid,
        torch.div(kv_max_pos, tiling_cfg.block_n, rounding_mode="floor") + 1,
        torch.zeros_like(kv_max_pos),
    )
    tile_kv_num_blocks = (tile_kv_end_idxs - tile_kv_start_idxs).to(torch.int32).contiguous()
    return VarlenMultiscreenMeta(
        tiling_cfg=tiling_cfg,
        w=w,
        r=r,
        q_seg_starts=q_seg_starts,
        q_seg_ends=q_seg_ends,
        kv_seg_starts=kv_seg_starts,
        kv_seg_ends=kv_seg_ends,
        tile_segment_ids=tile_segment_ids,
        tile_q_starts=tile_q_starts,
        tile_kv_start_idxs=tile_kv_start_idxs,
        tile_kv_num_blocks=tile_kv_num_blocks,
        num_segments=num_segments,
        num_q_tiles=tile_segment_ids.numel(),
    )
# Forward kernel of the varlen multiscreen op.
# One program handles one query tile (grid axis 0 -> pid) of one head
# (grid axis 1 -> hid), as launched by _launch_triton_varlen_multiscreen.
# For query row i and key row j of the same segment it computes
#     sim   = q_i . k_j
#     rel   = kv_pos_j - q_pos_i
#     alpha = relu(1 - r * (1 - sim))**2 * 0.5 * (cos(pi * rel / w) + 1)
# keeping alpha only where -w < rel <= 0 (causal window) and zero elsewhere,
# and writes out_i = sum_j alpha_ij * v_j.
# Only the kv tiles listed in tile_kv_start_idxs / tile_kv_num_blocks are
# visited. All math is done in float32; the result is cast to out's dtype on
# store. num_q_tiles and num_segments are accepted but unused in the body.
@triton.jit
def _varlen_multiscreen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    q_pos_ptr,
    kv_pos_ptr,
    q_seg_starts_ptr,
    q_seg_ends_ptr,
    kv_seg_starts_ptr,
    kv_seg_ends_ptr,
    tile_segment_ids_ptr,
    tile_q_starts_ptr,
    tile_kv_start_idxs_ptr,
    tile_kv_num_blocks_ptr,
    w_ptr,
    r_ptr,
    out_ptr,
    num_q_tiles,
    tile_kv_start_idxs_stride_h,
    tile_kv_start_idxs_stride_t,
    tile_kv_num_blocks_stride_h,
    tile_kv_num_blocks_stride_t,
    q_stride_h,
    q_stride_t,
    q_stride_d,
    k_stride_h,
    k_stride_t,
    k_stride_d,
    v_stride_h,
    v_stride_t,
    v_stride_d,
    out_stride_h,
    out_stride_t,
    out_stride_d,
    num_segments,
    num_heads,
    d_k,
    d_v,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    PI = 3.141592653589793
    pid = tl.program_id(axis=0)  # query tile index
    hid = tl.program_id(axis=1)  # head index
    # The launchers use num_heads programs on axis 1, so this guard should never trigger.
    if hid >= num_heads:
        return
    # Per-head scalar parameters.
    w = tl.load(w_ptr + hid).to(tl.float32)
    r = tl.load(r_ptr + hid).to(tl.float32)
    # Tile -> segment bounds.
    segment_id = tl.load(tile_segment_ids_ptr + pid)
    q_start = tl.load(tile_q_starts_ptr + pid)
    q_end = tl.load(q_seg_ends_ptr + segment_id)
    kv_start_base = tl.load(kv_seg_starts_ptr + segment_id)
    kv_end = tl.load(kv_seg_ends_ptr + segment_id)
    # Range of kv tiles (relative to the segment's kv start) this (head, tile) must visit.
    kv_tile_start_idx = tl.load(tile_kv_start_idxs_ptr + hid * tile_kv_start_idxs_stride_h + pid * tile_kv_start_idxs_stride_t)
    active_kv_tile_count = tl.load(tile_kv_num_blocks_ptr + hid * tile_kv_num_blocks_stride_h + pid * tile_kv_num_blocks_stride_t)
    offs_m = q_start + tl.arange(0, BLOCK_M)
    offs_dk = tl.arange(0, BLOCK_DK)
    offs_dv = tl.arange(0, BLOCK_DV)
    # Masks for rows past the segment end and for head-dim padding up to BLOCK_DK / BLOCK_DV.
    q_mask = offs_m < q_end
    dk_mask = offs_dk < d_k
    dv_mask = offs_dv < d_v
    q_ptrs = q_ptr + hid * q_stride_h + offs_m[:, None] * q_stride_t + offs_dk[None, :] * q_stride_d
    q_block = tl.load(q_ptrs, mask=q_mask[:, None] & dk_mask[None, :], other=0.0).to(tl.float32)
    q_pos = tl.load(q_pos_ptr + offs_m, mask=q_mask, other=0).to(tl.float32)
    acc = tl.zeros((BLOCK_M, BLOCK_DV), dtype=tl.float32)
    kv_tile_offset = 0
    while kv_tile_offset < active_kv_tile_count:
        kv_tile_idx = kv_tile_start_idx + kv_tile_offset
        kv_start = kv_start_base + kv_tile_idx * BLOCK_N
        offs_n = kv_start + tl.arange(0, BLOCK_N)
        kv_mask = offs_n < kv_end
        # K is loaded transposed, (BLOCK_DK, BLOCK_N), so sim = q_block @ k_block is (BLOCK_M, BLOCK_N).
        k_ptrs = k_ptr + hid * k_stride_h + offs_n[None, :] * k_stride_t + offs_dk[:, None] * k_stride_d
        k_block = tl.load(k_ptrs, mask=dk_mask[:, None] & kv_mask[None, :], other=0.0).to(tl.float32)
        sim = tl.dot(q_block, k_block)
        kv_pos = tl.load(kv_pos_ptr + offs_n, mask=kv_mask, other=0).to(tl.float32)
        rel = kv_pos[None, :] - q_pos[:, None]
        causal_mask = rel <= 0.0
        window_mask = causal_mask & (rel > -w)
        # Content term: relu(1 - r * (1 - sim)) squared.
        alpha = tl.maximum(1.0 - r * (1.0 - sim), 0.0)
        alpha = alpha * alpha
        # Cosine taper over the window: 1 at rel == 0, 0 at rel == -w.
        cosine_mask = 0.5 * (tl.cos(PI * rel / w) + 1.0)
        alpha = tl.where(window_mask & q_mask[:, None] & kv_mask[None, :], alpha * cosine_mask, 0.0)
        v_ptrs = v_ptr + hid * v_stride_h + offs_n[:, None] * v_stride_t + offs_dv[None, :] * v_stride_d
        v_block = tl.load(v_ptrs, mask=kv_mask[:, None] & dv_mask[None, :], other=0.0).to(tl.float32)
        acc += tl.dot(alpha, v_block)
        kv_tile_offset += 1
    out_ptrs = out_ptr + hid * out_stride_h + offs_m[:, None] * out_stride_t + offs_dv[None, :] * out_stride_d
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=q_mask[:, None] & dv_mask[None, :])
# Backward kernel of the varlen multiscreen op. Grid and tiling match the
# forward kernel: one program per (query tile pid, head hid).
# With base = 1 - r * (1 - sim), content = relu(base)**2,
# cosine_mask = 0.5 * (cos(pi * rel / w) + 1) and alpha = content * cosine_mask
# (restricted to -w < rel <= 0), it computes:
#     grad_alpha = dout @ v^T
#     dsim       = grad_alpha * cosine_mask * 2 * relu(base) * r      (where base > 0)
#     dq         = dsim @ k         (stored: each query tile is owned by exactly one program)
#     dk        += dsim^T @ q       (atomic: several query tiles hit the same key rows)
#     dv        += alpha^T @ dout   (atomic, same reason)
#     dw        += sum(grad_alpha * content * d(cosine_mask)/dw)
#     dr        += sum(grad_alpha * cosine_mask * 2 * relu(base) * (sim - 1))
# The hard window cutoff (rel > -w) is not differentiated; dw flows only
# through the cosine taper, which is 0 at the cutoff.
# With w == inf the phase and d(cosine_mask)/dw both evaluate to 0.
# num_q_tiles and num_segments are accepted but unused in the body.
@triton.jit
def _varlen_multiscreen_bwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    q_pos_ptr,
    kv_pos_ptr,
    q_seg_starts_ptr,
    q_seg_ends_ptr,
    kv_seg_starts_ptr,
    kv_seg_ends_ptr,
    tile_segment_ids_ptr,
    tile_q_starts_ptr,
    tile_kv_start_idxs_ptr,
    tile_kv_num_blocks_ptr,
    w_ptr,
    r_ptr,
    dout_ptr,
    dq_ptr,
    dk_ptr,
    dv_ptr,
    dw_ptr,
    dr_ptr,
    num_q_tiles,
    tile_kv_start_idxs_stride_h,
    tile_kv_start_idxs_stride_t,
    tile_kv_num_blocks_stride_h,
    tile_kv_num_blocks_stride_t,
    q_stride_h,
    q_stride_t,
    q_stride_d,
    k_stride_h,
    k_stride_t,
    k_stride_d,
    v_stride_h,
    v_stride_t,
    v_stride_d,
    dout_stride_h,
    dout_stride_t,
    dout_stride_d,
    dq_stride_h,
    dq_stride_t,
    dq_stride_d,
    dk_stride_h,
    dk_stride_t,
    dk_stride_d,
    dv_stride_h,
    dv_stride_t,
    dv_stride_d,
    num_segments,
    num_heads,
    d_k,
    d_v,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    PI = 3.141592653589793
    pid = tl.program_id(axis=0)  # query tile index
    hid = tl.program_id(axis=1)  # head index
    # The launchers use num_heads programs on axis 1, so this guard should never trigger.
    if hid >= num_heads:
        return
    w = tl.load(w_ptr + hid).to(tl.float32)
    r = tl.load(r_ptr + hid).to(tl.float32)
    segment_id = tl.load(tile_segment_ids_ptr + pid)
    q_start = tl.load(tile_q_starts_ptr + pid)
    q_end = tl.load(q_seg_ends_ptr + segment_id)
    kv_start_base = tl.load(kv_seg_starts_ptr + segment_id)
    kv_end = tl.load(kv_seg_ends_ptr + segment_id)
    kv_tile_start_idx = tl.load(tile_kv_start_idxs_ptr + hid * tile_kv_start_idxs_stride_h + pid * tile_kv_start_idxs_stride_t)
    active_kv_tile_count = tl.load(tile_kv_num_blocks_ptr + hid * tile_kv_num_blocks_stride_h + pid * tile_kv_num_blocks_stride_t)
    offs_m = q_start + tl.arange(0, BLOCK_M)
    offs_n_base = tl.arange(0, BLOCK_N)
    offs_dk = tl.arange(0, BLOCK_DK)
    offs_dv = tl.arange(0, BLOCK_DV)
    q_mask = offs_m < q_end
    dk_mask = offs_dk < d_k
    dv_mask = offs_dv < d_v
    # The query tile, its positions and its output gradient stay resident for the whole kv loop.
    q_ptrs = q_ptr + hid * q_stride_h + offs_m[:, None] * q_stride_t + offs_dk[None, :] * q_stride_d
    q_block = tl.load(q_ptrs, mask=q_mask[:, None] & dk_mask[None, :], other=0.0).to(tl.float32)
    q_pos = tl.load(q_pos_ptr + offs_m, mask=q_mask, other=0).to(tl.float32)
    dout_ptrs = dout_ptr + hid * dout_stride_h + offs_m[:, None] * dout_stride_t + offs_dv[None, :] * dout_stride_d
    dout_block = tl.load(dout_ptrs, mask=q_mask[:, None] & dv_mask[None, :], other=0.0).to(tl.float32)
    dq_acc = tl.zeros((BLOCK_M, BLOCK_DK), dtype=tl.float32)
    # Per-program partial sums of dw / dr, atomically added once after the loop.
    dw_acc = 0.0
    dr_acc = 0.0
    kv_tile_offset = 0
    while kv_tile_offset < active_kv_tile_count:
        kv_tile_idx = kv_tile_start_idx + kv_tile_offset
        kv_start = kv_start_base + kv_tile_idx * BLOCK_N
        offs_n = kv_start + offs_n_base
        kv_mask = offs_n < kv_end
        # K is loaded transposed, (BLOCK_DK, BLOCK_N), as in the forward kernel.
        k_ptrs = k_ptr + hid * k_stride_h + offs_n[None, :] * k_stride_t + offs_dk[:, None] * k_stride_d
        k_block = tl.load(k_ptrs, mask=dk_mask[:, None] & kv_mask[None, :], other=0.0).to(tl.float32)
        v_ptrs = v_ptr + hid * v_stride_h + offs_n[:, None] * v_stride_t + offs_dv[None, :] * v_stride_d
        v_block = tl.load(v_ptrs, mask=kv_mask[:, None] & dv_mask[None, :], other=0.0).to(tl.float32)
        kv_pos = tl.load(kv_pos_ptr + offs_n, mask=kv_mask, other=0).to(tl.float32)
        # Recompute the forward quantities for this tile.
        sim = tl.dot(q_block, k_block)
        rel = kv_pos[None, :] - q_pos[:, None]
        causal_mask = rel <= 0.0
        window_mask = causal_mask & (rel > -w)
        base = 1.0 - r * (1.0 - sim)
        relu_base = tl.maximum(base, 0.0)
        content = relu_base * relu_base
        cosine_phase = PI * rel / w
        cosine_mask = 0.5 * (tl.cos(cosine_phase) + 1.0)
        active_mask = window_mask & q_mask[:, None] & kv_mask[None, :]
        alpha = tl.where(active_mask, content * cosine_mask, 0.0)
        # d loss / d alpha, then through the cosine taper to the content term.
        grad_alpha = tl.dot(dout_block, tl.trans(v_block))
        grad_content = tl.where(active_mask, grad_alpha * cosine_mask, 0.0)
        # relu'(base) is 0 where base <= 0.
        positive_mask = active_mask & (base > 0.0)
        # d content / d sim = 2 * relu(base) * r.
        dsim = tl.where(positive_mask, grad_content * (2.0 * relu_base * r), 0.0)
        dq_acc += tl.dot(dsim, tl.trans(k_block))
        dk_part = tl.dot(tl.trans(dsim), q_block)
        dk_ptrs = dk_ptr + hid * dk_stride_h + offs_n[:, None] * dk_stride_t + offs_dk[None, :] * dk_stride_d
        tl.atomic_add(dk_ptrs, dk_part, mask=kv_mask[:, None] & dk_mask[None, :])
        dv_part = tl.dot(tl.trans(alpha), dout_block)
        dv_ptrs = dv_ptr + hid * dv_stride_h + offs_n[:, None] * dv_stride_t + offs_dv[None, :] * dv_stride_d
        tl.atomic_add(dv_ptrs, dv_part, mask=kv_mask[:, None] & dv_mask[None, :])
        # d cosine_mask / d w = 0.5 * sin(pi * rel / w) * pi * rel / w**2.
        dcos_dw = 0.5 * tl.sin(cosine_phase) * (PI * rel / (w * w))
        dw_mat = tl.where(active_mask, grad_alpha * content * dcos_dw, 0.0)
        # d content / d r = 2 * relu(base) * (sim - 1).
        dr_mat = tl.where(positive_mask, grad_content * (2.0 * relu_base * (sim - 1.0)), 0.0)
        dw_acc += tl.sum(tl.sum(dw_mat, axis=1), axis=0)
        dr_acc += tl.sum(tl.sum(dr_mat, axis=1), axis=0)
        kv_tile_offset += 1
    dq_ptrs = dq_ptr + hid * dq_stride_h + offs_m[:, None] * dq_stride_t + offs_dk[None, :] * dq_stride_d
    tl.store(dq_ptrs, dq_acc.to(dq_ptr.dtype.element_ty), mask=q_mask[:, None] & dk_mask[None, :])
    # Every query tile of this head contributes to the same dw / dr element.
    tl.atomic_add(dw_ptr + hid, dw_acc)
    tl.atomic_add(dr_ptr + hid, dr_acc)
def _check_launch_head_dims(q: Tensor, v: Tensor, launch_cfg: VarlenMultiscreenLaunchConfig) -> None:
    """Reject key/value head dims larger than the Triton block sizes in ``launch_cfg``.

    Raises:
        NotImplementedError: if ``q.shape[2] > launch_cfg.block_dk`` or
            ``v.shape[2] > launch_cfg.block_dv``.
    """
    if q.shape[2] > launch_cfg.block_dk:
        raise NotImplementedError(f"d_k={q.shape[2]} exceeds block_dk={launch_cfg.block_dk}")
    if v.shape[2] > launch_cfg.block_dv:
        raise NotImplementedError(f"d_v={v.shape[2]} exceeds block_dv={launch_cfg.block_dv}")


def _launch_triton_varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    out: Tensor,
    meta: VarlenMultiscreenMeta,
    launch_cfg: VarlenMultiscreenLaunchConfig,
) -> None:
    """Launch the forward kernel, writing into ``out``: one program per (query tile, head)."""
    _check_launch_head_dims(q, v, launch_cfg)
    num_heads = q.shape[0]
    tiling = meta.tiling_cfg
    grid = (meta.tile_segment_ids.numel(), num_heads)
    _varlen_multiscreen_fwd_kernel[grid](
        # activations and positions
        q, k, v, q_pos, kv_pos,
        # segment boundaries
        meta.q_seg_starts, meta.q_seg_ends, meta.kv_seg_starts, meta.kv_seg_ends,
        # per-tile metadata
        meta.tile_segment_ids, meta.tile_q_starts, meta.tile_kv_start_idxs, meta.tile_kv_num_blocks,
        # per-head parameters
        meta.w, meta.r,
        # output
        out,
        meta.num_q_tiles,
        # metadata strides
        meta.tile_kv_start_idxs.stride(0), meta.tile_kv_start_idxs.stride(1),
        meta.tile_kv_num_blocks.stride(0), meta.tile_kv_num_blocks.stride(1),
        # tensor strides (head, token, dim)
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        # sizes
        meta.num_segments, num_heads, q.shape[2], v.shape[2],
        BLOCK_M=tiling.block_m,
        BLOCK_N=tiling.block_n,
        BLOCK_DK=launch_cfg.block_dk,
        BLOCK_DV=launch_cfg.block_dv,
        num_warps=launch_cfg.num_warps,
        num_stages=launch_cfg.num_stages,
    )


def _launch_triton_varlen_multiscreen_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    dout: Tensor,
    dq: Tensor,
    dk: Tensor,
    dv: Tensor,
    dw: Tensor,
    dr: Tensor,
    meta: VarlenMultiscreenMeta,
    launch_cfg: VarlenMultiscreenLaunchConfig,
) -> None:
    """Launch the backward kernel. ``dq`` is stored; ``dk``, ``dv``, ``dw`` and ``dr`` are atomically accumulated into."""
    _check_launch_head_dims(q, v, launch_cfg)
    num_heads = q.shape[0]
    tiling = meta.tiling_cfg
    grid = (meta.tile_segment_ids.numel(), num_heads)
    _varlen_multiscreen_bwd_kernel[grid](
        # activations and positions
        q, k, v, q_pos, kv_pos,
        # segment boundaries
        meta.q_seg_starts, meta.q_seg_ends, meta.kv_seg_starts, meta.kv_seg_ends,
        # per-tile metadata
        meta.tile_segment_ids, meta.tile_q_starts, meta.tile_kv_start_idxs, meta.tile_kv_num_blocks,
        # per-head parameters
        meta.w, meta.r,
        # incoming gradient and gradient outputs
        dout, dq, dk, dv, dw, dr,
        meta.num_q_tiles,
        # metadata strides
        meta.tile_kv_start_idxs.stride(0), meta.tile_kv_start_idxs.stride(1),
        meta.tile_kv_num_blocks.stride(0), meta.tile_kv_num_blocks.stride(1),
        # tensor strides (head, token, dim)
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        dout.stride(0), dout.stride(1), dout.stride(2),
        dq.stride(0), dq.stride(1), dq.stride(2),
        dk.stride(0), dk.stride(1), dk.stride(2),
        dv.stride(0), dv.stride(1), dv.stride(2),
        # sizes
        meta.num_segments, num_heads, q.shape[2], v.shape[2],
        BLOCK_M=tiling.block_m,
        BLOCK_N=tiling.block_n,
        BLOCK_DK=launch_cfg.block_dk,
        BLOCK_DV=launch_cfg.block_dv,
        num_warps=launch_cfg.num_warps,
        num_stages=launch_cfg.num_stages,
    )
@triton_varlen_multiscreen.register_fake
def _fake_triton_varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> Tensor:
    """Metadata-only (fake/meta) implementation of the forward op, used for tracing.

    Must match the real op's output exactly. The real op allocates
    ``torch.empty((H, T_q, D_v), device=v.device, dtype=v.dtype)``, so the
    fake derives dtype and device from ``v``, not ``q``.
    """
    return v.new_empty((q.shape[0], q.shape[1], v.shape[2]))
@triton_varlen_multiscreen_backward.register_fake
def _fake_triton_varlen_multiscreen_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
    dout: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Metadata-only (fake/meta) implementation of the backward op, used for tracing.

    The real op allocates its gradients with ``torch.zeros_like`` and casts
    them back to the input dtypes. Both steps preserve each input's shape,
    dtype, device and strides. ``torch.empty_like`` reports exactly that
    metadata; ``new_empty(shape)`` would always be contiguous.
    """
    return (
        torch.empty_like(q),
        torch.empty_like(k),
        torch.empty_like(v),
        torch.empty_like(w),
        torch.empty_like(r),
    )
def _varlen_multiscreen_setup_context(ctx, inputs, output):
    """Autograd setup hook: save all nine forward inputs for the backward pass (``output`` is unused)."""
    (
        query,
        key,
        value,
        query_positions,
        key_positions,
        cu_q,
        cu_kv,
        w_in,
        r_in,
    ) = inputs
    ctx.save_for_backward(query, key, value, query_positions, key_positions, cu_q, cu_kv, w_in, r_in)
def _varlen_multiscreen_backward(ctx, grad_out: Tensor):
    """Autograd backward for ``triton_varlen_multiscreen``.

    Returns one gradient per forward input, in order. q_pos, kv_pos,
    cu_seqlens_q and cu_seqlens_kv receive ``None`` (no gradient).
    """
    (q, k, v, q_pos, kv_pos, cu_q, cu_kv, w, r) = ctx.saved_tensors
    grads = triton_varlen_multiscreen_backward(q, k, v, q_pos, kv_pos, cu_q, cu_kv, w, r, grad_out)
    grad_q, grad_k, grad_v, grad_w, grad_r = grads
    return grad_q, grad_k, grad_v, None, None, None, None, grad_w, grad_r


# Hook the saved-input setup and the Triton backward into autograd.
triton_varlen_multiscreen.register_autograd(
    _varlen_multiscreen_backward,
    setup_context=_varlen_multiscreen_setup_context,
)
def varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> Tensor:
    """Public entry point: the varlen multiscreen op (registered custom op with autograd support)."""
    op_args = (q, k, v, q_pos, kv_pos, cu_seqlens_q, cu_seqlens_kv, w, r)
    return triton_varlen_multiscreen(*op_args)
| def _load_data_shard(file: Path): | |
| header = torch.from_file(str(file), False, 256, dtype=torch.int32) | |
| assert header[0] == 20240520, "magic number mismatch in the data .bin file" | |
| assert header[1] == 1, "unsupported version" | |
| num_tokens = int(header[2]) | |
| with file.open("rb", buffering=0) as f: | |
| tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) | |
| f.seek(256 * 4) | |
| nbytes = f.readinto(tokens.numpy()) | |
| assert nbytes == 2 * num_tokens, "number of tokens read does not match header" | |
| return tokens | |
| ####################################################################################################################### | |
| # multiscreen model implementation by @scottjmaddox and GPT-5.4 / 5.5 | |
def maybe_torch_compile(**compile_kwargs):
    """Build a decorator that applies ``torch.compile(fn, **compile_kwargs)`` only when COMPILE_MODEL is true.

    COMPILE_MODEL is read when the returned decorator is applied to a
    function. With compilation disabled the function is returned unchanged.
    """
    def decorator(fn):
        if not COMPILE_MODEL:
            return fn
        return torch.compile(fn, **compile_kwargs)

    return decorator
def row_normalize(x: Tensor, eps: float = 1e-8):
    """Scale each last-dimension row of ``x`` to unit L2 norm.

    The norm is clamped below at ``eps``, so all-zero rows stay zero instead of
    producing NaNs.
    """
    denom = torch.clamp(x.norm(dim=-1, keepdim=True), min=eps)
    return torch.div(x, denom)
def tanh_norm(x: Tensor, eps: float = 1e-8):
    """Rescale each last-dimension row so that its L2 norm becomes ``tanh(norm)``.

    Row directions are preserved. The norm is clamped below at ``eps``, so
    all-zero rows stay zero.
    """
    magnitude = torch.clamp(x.norm(dim=-1, keepdim=True), min=eps)
    squash = torch.tanh(magnitude) / magnitude
    return x * squash
def apply_mipe(z: Tensor, w: Tensor, pos_ids: Tensor, window_threshold: float) -> Tensor:
    """Rotate the first two channels of every head by a position-dependent angle.

    Args:
        z: per-head features. Indexed as (heads, tokens, channels) via the views below.
        w: per-head window sizes (``heads`` entries). Infinite windows get no rotation.
        pos_ids: position of each token within its sequence (``tokens`` entries).
        window_threshold: windows at or above this value get no rotation.

    Returns:
        Tensor shaped like ``z``. Channels 0 and 1 are rotated by
        ``pi * pos * gamma(w) / w``; all other channels pass through unchanged.
    """
    heads, tokens = z.shape[0], z.shape[1]
    positions = pos_ids.to(device=z.device, dtype=z.dtype).view(1, tokens, 1)
    windows = w.to(device=z.device, dtype=z.dtype).view(heads, 1, 1)
    zero = torch.zeros_like(windows)
    # Raised-cosine taper: 1 at w=0, falling to 0 at w=window_threshold and staying 0 beyond.
    taper = torch.where(
        windows < window_threshold,
        0.5 * (torch.cos(windows * (math.pi / window_threshold)) + 1.0),
        zero,
    )
    # 1/w, with infinite windows mapped explicitly to 0.
    inv_window = torch.where(torch.isinf(windows), zero, torch.reciprocal(windows))
    angle = math.pi * positions * taper * inv_window
    cos_a, sin_a = torch.cos(angle), torch.sin(angle)
    first, second, untouched = z[..., 0:1], z[..., 1:2], z[..., 2:]
    return torch.cat(
        (first * cos_a - second * sin_a, first * sin_a + second * cos_a, untouched),
        dim=-1,
    )
def project_heads(x: Tensor, w: Tensor) -> Tensor:
    """Apply all per-head projections with a single matmul.

    Args:
        x: (tokens, in_dim) activations.
        w: (heads, out_dim, in_dim) stacked per-head weight matrices.

    Returns:
        Contiguous (heads, tokens, out_dim) tensor.
    """
    heads, out_dim = w.shape[0], w.shape[1]
    stacked = w.reshape(-1, w.shape[-1])  # (heads * out_dim, in_dim)
    flat = F.linear(x, stacked)           # (tokens, heads * out_dim)
    return flat.view(x.shape[0], heads, out_dim).transpose(0, 1).contiguous()
class MultiscreenLayer(nn.Module):
    """One multiscreen block.

    The block computes per-head query/key/value/gate projections, positional
    rotation, screening via ``varlen_multiscreen``, a gated output projection,
    and a sum over heads.

    Projection weights are stacked per head and stored in bfloat16 as
    (num_heads, out_dim, in_dim). Per-head scalars are kept in unconstrained
    form:
      - ``s_w`` gives the window size ``exp(s_w) + 1``.
      - ``s_r`` gives ``sigmoid(s_r)``.
      - ``s_o`` gives the output scale ``exp(s_o)``.
    """

    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        key_dim: int,
        value_dim: int,
        depth: int,
        window_threshold: float,
        train_max_seq_len: int,
    ):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.window_threshold = window_threshold
        self.train_max_seq_len = train_max_seq_len

        def _rand_param(rows: int, cols: int, scale: float) -> nn.Parameter:
            # Each call draws from the global RNG, so creation order must stay fixed.
            return nn.Parameter((torch.randn(num_heads, rows, cols) * scale).bfloat16())

        self.w_q = _rand_param(key_dim, model_dim, 0.1 / math.sqrt(key_dim))
        self.w_k = _rand_param(key_dim, model_dim, 0.1 / math.sqrt(key_dim))
        self.w_v = _rand_param(value_dim, model_dim, 0.1 / math.sqrt(value_dim))
        self.w_g = _rand_param(value_dim, model_dim, 0.1)
        self.w_o = _rand_param(model_dim, value_dim, 0.1 / math.sqrt(model_dim))
        # Initial windows span exp(0) + 1 = 2 up to window_threshold + 1 across heads.
        self.s_w = nn.Parameter(torch.linspace(0, math.log(window_threshold), steps=num_heads))
        self.s_r = nn.Parameter(torch.zeros(num_heads))
        # Output scale starts at 1 / sqrt(num_heads * depth).
        self.s_o = nn.Parameter(torch.full((num_heads,), math.log(1.0 / math.sqrt(num_heads * depth))))

    def w(self) -> Tensor:
        """Per-head window sizes.

        In eval mode, windows longer than ``train_max_seq_len`` are replaced by
        ``inf`` (unbounded).
        """
        window = torch.exp(self.s_w) + 1.0
        if not self.training:
            beyond_train = window.detach().float() > self.train_max_seq_len
            window = torch.where(beyond_train, torch.full_like(window, float("inf")), window)
        return window

    def r(self) -> Tensor:
        """Per-head ``r`` parameter, squashed into (0, 1)."""
        return self.s_r.sigmoid()

    def forward(
        self,
        x: Tensor,
        pos_ids: Tensor,
        cu_seqlens: Tensor,
    ) -> Tensor:
        """Return this layer's (tokens, model_dim) contribution to the residual stream."""
        window = self.w()
        r_gate = self.r()
        query = row_normalize(project_heads(x, self.w_q))
        key = row_normalize(project_heads(x, self.w_k))
        value = row_normalize(project_heads(x, self.w_v))
        gate = project_heads(x, self.w_g)
        query = apply_mipe(query, window, pos_ids, self.window_threshold)
        key = apply_mipe(key, window, pos_ids, self.window_threshold)
        screened = varlen_multiscreen(query, key, value, pos_ids, pos_ids, cu_seqlens, cu_seqlens, window, r_gate)
        gated = tanh_norm(screened) * torch.tanh(F.silu(gate))
        per_head = torch.bmm(gated, self.w_o.mT)
        head_scale = torch.exp(self.s_o).to(per_head.dtype).view(self.s_o.shape[0], 1, 1)
        return (per_head * head_scale).sum(dim=0)
@dataclass
class ModelConfig:
    """Architecture hyperparameters for MultiscreenLM.

    Serialized with ``asdict`` into checkpoints (via ``get_extra_state``) and
    into ``model_config.json``.
    """
    vocab_size: int  # rows of the (tied) embedding matrix
    model_dim: int  # residual-stream width
    depth: int  # number of MultiscreenLayer blocks
    num_heads: int  # heads per layer
    key_dim: int  # per-head query/key width
    value_dim: int  # per-head value/gate width
    window_threshold: float  # windows at or above this get no positional rotation (apply_mipe)
    train_max_seq_len: int  # in eval mode, windows longer than this become unbounded (inf)
class MultiscreenLM(nn.Module):
    """Multiscreen language model.

    The model is a tied, row-normalized embedding followed by a residual stack
    of MultiscreenLayers. The same normalized embedding matrix produces the
    logits (``x @ embed.mT``). Learned log-scales scale the input embeddings
    (``s_e``) and the logits (``s_f``).
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        # bfloat16 embedding table. Rows are L2-normalized on every forward pass.
        self.embed = nn.Parameter((torch.randn(cfg.vocab_size, cfg.model_dim) * (0.1 / math.sqrt(self.cfg.model_dim))).bfloat16())
        self.layers = nn.ModuleList([
            MultiscreenLayer(
                model_dim=cfg.model_dim,
                num_heads=cfg.num_heads,
                key_dim=cfg.key_dim,
                value_dim=cfg.value_dim,
                depth=cfg.depth,
                window_threshold=cfg.window_threshold,
                train_max_seq_len=cfg.train_max_seq_len,
            )
            for _ in range(cfg.depth)
        ])
        # Log input-embedding scale: exp(0) = 1 at init.
        self.s_e = nn.Parameter(torch.zeros(()))
        # Log logit scale: sqrt(model_dim) at init.
        self.s_f = nn.Parameter(torch.tensor(math.log(math.sqrt(cfg.model_dim))))

    def get_extra_state(self):
        # Included in state_dict(), so checkpoints carry the architecture config.
        return {"model_config": asdict(self.cfg)}

    def set_extra_state(self, state):
        # Restores self.cfg from a checkpoint. NOTE: this only replaces the config
        # object; it does not rebuild any modules.
        if not state:
            return
        model_config = state.get("model_config")
        if model_config is not None:
            self.cfg = ModelConfig(**model_config)

    def window_stats(self) -> dict[str, float]:
        """Return min/median/mean/max of the per-head window sizes across all layers.

        The values are used for logging. In eval mode some windows may be
        ``inf`` (see ``MultiscreenLayer.w``).
        """
        with torch.no_grad():
            windows = [layer.w().detach().float() for layer in self.layers]
            all_windows = torch.cat(windows)
            return {
                "window_min": float(all_windows.min().item()),
                "window_median": float(all_windows.median().item()),
                "window_mean": float(all_windows.mean().item()),
                "window_max": float(all_windows.max().item()),
            }

    @maybe_torch_compile(dynamic=True, fullgraph=True)
    def forward(
        self,
        input_ids: Tensor,
        target_ids: Tensor,
        pos_ids: Tensor,
        cu_seqlens: Tensor,
    ) -> tuple[Tensor, int]:
        """Compute the summed next-token cross-entropy over a packed batch.

        Args:
            input_ids: flat token ids of all packed sequences.
            target_ids: next-token targets aligned with ``input_ids``.
            pos_ids: position of each token within its own sequence.
            cu_seqlens: cumulative sequence boundaries of the packed batch.

        Returns:
            ``(loss_sum, token_count)``. The loss is a sum (not a mean), so the
            caller can normalize across micro-batches.
        """
        embed = row_normalize(self.embed)
        x = F.embedding(input_ids, embed)
        x = x * torch.exp(self.s_e).to(x.dtype)
        for layer in self.layers:
            x = x + layer(x, pos_ids, cu_seqlens)
        # Tied output projection with a learned logit scale.
        logits = x @ embed.mT * torch.exp(self.s_f).to(x.dtype)
        loss_sum = F.cross_entropy(logits.float(), target_ids.view(-1), reduction="sum")
        token_count = target_ids.numel()
        return loss_sum, token_count
| ####################################################################################################################### | |
| # data loader adapted from https://github.com/kellerjordan/modded-nanogpt | |
# Token id that marks document starts in the shards. 50256 is GPT-2's
# <|endoftext|> token; per the module docstring, the shards are GPT-2-tokenized.
BOS_ID = 50256
class BOSFinder:
    """Plan BOS-aligned document spans within one token shard, split across ranks.

    ``bos_idx`` holds the positions of every ``BOS_ID`` token in ``tokens``.
    With ``quickload=True``, only the first 4M tokens are indexed up front so
    batching can start immediately, and the full index is built on a background
    thread. The full index is swapped in:
      - eagerly, before the 5th batch;
      - on demand, whenever batching reaches the end of the prefix index.

    The on-demand swap avoids two problems:
      - skipping the rest of the shard with a premature ``StopIteration``;
      - treating the last prefix BOS as running to the shard end.
    """

    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
        self.tokens = tokens
        self.size = tokens.numel()
        self.quickload = quickload
        # Always defined so `next_batch` can tell whether a background scan is pending.
        self.thread = None
        if quickload:
            self.bos_idx = (tokens[:4_000_000] == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
            self.ready = threading.Event()
            self.start()
        else:
            self.bos_idx = (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        self.i = 0  # index into bos_idx of the next document to hand out
        self.world_size = world_size
        self.batch_iter = 0

    def _load(self):
        """Background thread body: index every BOS in the shard."""
        try:
            self.bos_idx_async = (self.tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
        finally:
            # Always signal, so a failed scan cannot leave `get` waiting forever.
            self.ready.set()

    def start(self):
        """Launch the background full-shard BOS scan."""
        self.ready.clear()
        self.thread = threading.Thread(target=self._load)
        self.thread.start()

    def get(self):
        """Wait for a pending background scan, if any, and swap in the full index.

        Idempotent. Because the prefix index is a prefix of the full index,
        ``self.i`` stays valid across the swap.
        """
        if self.thread:
            self.ready.wait()
            self.thread.join()
            self.bos_idx = self.bos_idx_async
            self.thread = None

    def next_batch(self, num_tokens_local: int, max_seq_len: int):
        """Return per-rank ``(starts, ends)`` lists of token spans for the next batch.

        Each rank receives spans totalling exactly ``num_tokens_local + 1``
        tokens (the extra token provides the final target). Every span starts
        at a BOS and is cut at the next BOS, at ``max_seq_len``, or at the
        remaining budget.

        Raises:
            StopIteration: when the shard runs out of BOS tokens. Progress
                (``self.i``) is not committed in that case.
        """
        if self.quickload and self.batch_iter == 5:
            self.get()
        starts = [[] for _ in range(self.world_size)]
        ends = [[] for _ in range(self.world_size)]
        idx = self.i
        for r in range(self.world_size):
            cur_len = 0
            while cur_len <= num_tokens_local:
                # The prefix index cannot tell where its last document ends, nor
                # whether more documents follow, so finish the full scan first.
                if self.thread is not None and idx + 1 >= len(self.bos_idx):
                    self.get()
                n = len(self.bos_idx)
                if idx >= n:
                    raise StopIteration("Insufficient BOS ahead; hit tail of shard.")
                cur = self.bos_idx[idx]
                end = min(
                    self.bos_idx[idx + 1] if idx + 1 < n else self.size,
                    cur + max_seq_len,
                    cur + num_tokens_local - cur_len + 1,
                )
                starts[r].append(cur)
                ends[r].append(end)
                cur_len += end - cur
                idx += 1
            assert cur_len == num_tokens_local + 1
        self.i = idx
        self.batch_iter += 1
        return starts, ends
| class DataPreloader: | |
| def __init__(self, file_iter, world_size: int = 1): | |
| self.file_iter = file_iter | |
| self.world_size = world_size | |
| self.thread = None | |
| self.data = None | |
| self.ready = threading.Event() | |
| def _load(self): | |
| tokens = _load_data_shard(next(self.file_iter)) | |
| self.data = (tokens, BOSFinder(tokens, self.world_size)) | |
| self.ready.set() | |
| def start(self): | |
| self.ready.clear() | |
| self.thread = threading.Thread(target=self._load) | |
| self.thread.start() | |
| def get(self): | |
| if self.thread: | |
| self.ready.wait() | |
| self.thread.join() | |
| return self.data | |
def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
    """Yield packed micro-batches for this rank, read from token shards matching ``filename_pattern``.

    Args:
        filename_pattern: glob for ``.bin`` shards. Files are consumed in sorted order.
        num_tokens: global tokens per optimizer step. Each yielded micro-batch
            covers ``num_tokens // grad_accum_steps`` tokens across all ranks.
        max_seq_len: longest span taken from one document (BOS-aligned mode only).
        grad_accum_steps: number of micro-batches per optimizer step.
        align_to_bos: if True, every sequence starts at a BOS token and shards are
            preloaded in the background. If False, contiguous windows are read
            and split at BOS tokens.

    Yields:
        ``(input_ids int32, target_ids int64, pos_ids, cu_seqlens int32)``, all
        copied to CUDA. A generator ``.send((num_tokens, max_seq_len,
        grad_accum_steps))`` changes the batch shape from the next batch onward.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
    num_tokens = num_tokens // grad_accum_steps  # tokens per micro-batch, all ranks combined
    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
    file_iter = iter(files)
    tokens = _load_data_shard(next(file_iter))
    if align_to_bos:
        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
        # Start loading the following shard while this one is consumed.
        preloader = DataPreloader(file_iter, world_size)
        preloader.start()
    else:
        pos = 0
    while True:
        num_tokens_local = num_tokens // world_size
        if align_to_bos:
            try:
                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
                start_idxs = torch.tensor(seq_starts[rank])
                end_idxs = torch.tensor(seq_ends[rank])
            except StopIteration:
                # Current shard exhausted: switch to the preloaded one and prefetch the next.
                tokens, finder = preloader.get()
                preloader.start()
                continue
            # next_batch hands each rank num_tokens_local + 1 tokens; inputs and
            # targets are the usual one-token shift of that buffer.
            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
            _input_ids = buf[:-1]
            _target_ids = buf[1:]
            # The dropped final token shortens the last sequence by one for boundary purposes.
            end_idxs[-1] -= 1
            seqends = (end_idxs - start_idxs).cumsum(0)
        else:
            # NOTE(review): when the last file is exhausted, next(file_iter) raises
            # StopIteration inside this generator, which Python (PEP 479) surfaces
            # as RuntimeError.
            if pos + num_tokens + 1 >= len(tokens):
                tokens, pos = _load_data_shard(next(file_iter)), 0
            pos_local = pos + rank * num_tokens_local
            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
            _input_ids = buf[:-1].view(num_tokens_local)
            _target_ids = buf[1:].view(num_tokens_local)
            # Every BOS (except one at offset 0) starts a new sequence; the window end closes the last one.
            bos_positions = torch.nonzero(_input_ids == BOS_ID)[:, 0]
            seqends = bos_positions[bos_positions > 0]
            if seqends.numel() == 0 or seqends[-1] != _input_ids.numel():
                seqends = torch.cat([seqends, seqends.new_full((1,), _input_ids.numel())])
            pos += num_tokens
        assert seqends.numel() >= 1
        assert (seqends > 0).all()
        assert seqends[-1] == _input_ids.numel()
        # cu_seqlens = [0, end_1, ..., total]: cumulative sequence boundaries of the packed batch.
        _cu_seqlens = torch.cat([seqends.new_zeros(1), seqends])
        _input_ids = _input_ids.to(dtype=torch.int32)
        _target_ids = _target_ids.to(dtype=torch.int64)
        _cu_seqlens = _cu_seqlens.to(torch.int32)
        _pos_ids = build_pos_ids_from_cu_seqlens(_cu_seqlens)
        _input_ids = _input_ids.to(device="cuda", non_blocking=True)
        _target_ids = _target_ids.to(device="cuda", non_blocking=True)
        _pos_ids = _pos_ids.to(device="cuda", non_blocking=True)
        _cu_seqlens = _cu_seqlens.to(device="cuda", non_blocking=True)
        # Batch lengths vary (BOS packing), so hint torch.compile to treat dim 0 as dynamic.
        torch._dynamo.maybe_mark_dynamic(_input_ids, 0)
        torch._dynamo.maybe_mark_dynamic(_target_ids, 0)
        torch._dynamo.maybe_mark_dynamic(_pos_ids, 0)
        torch._dynamo.maybe_mark_dynamic(_cu_seqlens, 0)
        new_params = yield (_input_ids, _target_ids, _pos_ids, _cu_seqlens)
        if new_params is not None:
            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0
            num_tokens = new_num_tokens // new_grad_accum_steps
            max_seq_len = new_max_seq_len
def build_pos_ids_from_cu_seqlens(cu_seqlens: Tensor) -> Tensor:
    """Return each token's 0-based position within its own packed sequence.

    ``cu_seqlens`` holds cumulative sequence boundaries ``[0, e1, e2, ..., total]``,
    as built by the data generator. The result has ``total`` entries and restarts
    from 0 at every boundary.
    """
    lengths = torch.diff(cu_seqlens)
    first_index = cu_seqlens[:-1].repeat_interleave(lengths)
    global_index = torch.arange(cu_seqlens[-1], device=cu_seqlens.device, dtype=torch.int32)
    return global_index - first_index
| ####################################################################################################################### | |
| # Aurora optimizer, adapted from https://github.com/tilde-research/aurora-release | |
| @torch.no_grad() | |
| def _polar(G: Tensor) -> Tensor: | |
| """Polar factor via 12-step simple-quintic Newton-Schulz. | |
| Args: | |
| G: input matrix of shape [..., m, n]. | |
| Returns: | |
| polar(G) of the same shape, in bfloat16. All non-zero singular values | |
| of G are mapped to 1. | |
| """ | |
| if G.ndim < 2: | |
| raise ValueError(f"polar expects tensors with ndim >= 2, got ndim={G.ndim}") | |
| X = G.bfloat16() | |
| transposed = G.size(-2) > G.size(-1) | |
| if transposed: | |
| X = X.mT | |
| # Ensure spectral norm <= 1 so the iteration converges to polar. | |
| X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) | |
| # Simple-quintic coefficients: | |
| # p(σ) = aσ + bσ³ + cσ⁵ with σ=1 super-attracting. | |
| a, b, c = 2.0, -1.5, 0.5 | |
| for _ in range(12): | |
| A = X @ X.mT | |
| B = b * A + c * (A @ A) | |
| X = a * X + B @ X | |
| if transposed: | |
| X = X.mT | |
| return X | |
# Compiled variant of `_polar`. dynamic=True asks torch.compile to treat input
# shapes symbolically; compilation itself happens lazily on first call.
_polar_compiled = torch.compile(_polar, dynamic=True, fullgraph=True)


def polar(G: Tensor, *, compile: bool = True) -> Tensor:
    """Return the polar factor of ``G`` (see ``_polar``), via the compiled kernel unless ``compile=False``."""
    impl = _polar_compiled if compile else _polar
    return impl(G)
class Aurora(torch.optim.Optimizer):
    """Aurora optimizer for matrix-shaped parameters.

    Adapted from tilde-research/aurora-release. Each step does the following:
      1. Update a momentum buffer (optionally with Nesterov look-ahead).
      2. Map the update to an approximately semi-orthogonal direction with
         ``_aurora_direction``. This is a polar factor; non-square matrices
         also get iterative row-norm preconditioning.
      3. Rescale by ``sqrt(max(1, rows / cols))``.
      4. Apply decoupled weight decay.
      5. Take a step of size ``lr``.

    Parameters must have ndim >= 2. The last two dimensions form the matrix and
    any leading dimensions are treated as a batch.
    """

    def __init__(
        self,
        params,
        lr: float = 0.05,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        preconditioner_steps: int = 2,
        preconditioner_beta: float = 0.5,
        eps: float = 1e-7,
        compile: bool = True,
    ):
        self._validate_hparams(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            preconditioner_steps=preconditioner_steps,
            preconditioner_beta=preconditioner_beta,
            eps=eps,
            compile=compile,
        )
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            preconditioner_steps=preconditioner_steps,
            preconditioner_beta=preconditioner_beta,
            eps=eps,
            compile=compile,
        )
        super().__init__(params, defaults)

    @staticmethod
    def _validate_hparams(
        *,
        lr: float,
        momentum: float,
        weight_decay: float,
        nesterov: bool,
        preconditioner_steps: int,
        preconditioner_beta: float,
        eps: float,
        compile: bool,
    ) -> None:
        """Raise ValueError if any hyperparameter is out of range or has the wrong type.

        Called at construction and again on every ``step``, because
        param-group values may be edited in between (e.g. by LR schedules).
        """
        if lr <= 0.0:
            raise ValueError(f"lr must be positive, got {lr}")
        if not (0.0 < momentum < 1.0):
            raise ValueError(f"momentum must be in (0, 1), got {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"weight_decay must be non-negative, got {weight_decay}")
        if not isinstance(nesterov, bool):
            raise ValueError(f"nesterov must be bool, got {type(nesterov).__name__}")
        if not isinstance(preconditioner_steps, int) or preconditioner_steps < 1:
            raise ValueError(
                "preconditioner_steps must be an int >= 1, "
                f"got {preconditioner_steps}"
            )
        if preconditioner_beta <= 0.0:
            raise ValueError(
                f"preconditioner_beta must be positive, got {preconditioner_beta}"
            )
        if eps <= 0.0:
            raise ValueError(f"eps must be positive, got {eps}")
        if not isinstance(compile, bool):
            raise ValueError(f"compile must be bool, got {type(compile).__name__}")

    @staticmethod
    @torch.no_grad()
    def _aurora_direction(
        update: Tensor,
        *,
        preconditioner_steps: int,
        preconditioner_beta: float,
        eps: float,
        compile: bool,
    ) -> Tensor:
        """Turn a raw update into Aurora's step direction.

        Square matrices use the plain polar factor. Rectangular matrices are
        handled as tall (wide ones are transposed and transposed back), and
        their rows are rescaled by a diagonal ``D`` before the polar factor.
        ``D`` is refined ``preconditioner_steps - 1`` times to push each row's
        squared norm in ``polar(D * G)`` toward the mean value ``n / m``.
        """
        m, n = update.size(-2), update.size(-1)
        if m == n:
            return polar(update, compile=compile)
        # For wide matrices, transpose to tall, apply, transpose back.
        # polar(G * D) = polar(D * G^T)^T
        transposed = m < n
        if transposed:
            update = update.mT
            m, n = n, m
        G32 = update.to(torch.float32)
        # An m x n (m > n) matrix with orthonormal columns has squared row norms
        # summing to n, so n / m is the per-row mean.
        target_row_sq = n / m
        # Start from inverse row norms, so every row of D * G has unit norm.
        row_norm = G32.norm(dim=-1, keepdim=True).clamp(min=eps)
        D = 1.0 / row_norm
        for step_idx in range(preconditioner_steps):
            U = polar(D * G32, compile=compile)
            if step_idx < preconditioner_steps - 1:
                row_sq = (
                    U.to(torch.float32)
                    .pow(2)
                    .sum(dim=-1, keepdim=True)
                    .clamp(min=eps * eps)
                )
                # Multiplicative correction: boost rows that came out short, damp long ones.
                D = D * (target_row_sq / row_sq).pow(preconditioner_beta)
        return U.mT if transposed else U

    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], Tensor]] = None):
        """Perform one optimization step over every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The closure's loss, or None if no closure was given.

        Raises:
            ValueError: on invalid group hyperparameters or a gradient/parameter shape mismatch.
            RuntimeError: on sparse gradients or a corrupted momentum buffer shape.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            weight_decay = group["weight_decay"]
            nesterov = group["nesterov"]
            preconditioner_steps = group["preconditioner_steps"]
            preconditioner_beta = group["preconditioner_beta"]
            eps = group["eps"]
            compile = group["compile"]
            self._validate_hparams(
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay,
                nesterov=nesterov,
                preconditioner_steps=preconditioner_steps,
                preconditioner_beta=preconditioner_beta,
                eps=eps,
                compile=compile,
            )
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.detach()
                if grad.is_sparse:
                    raise RuntimeError("Aurora does not support sparse gradients.")
                if grad.shape != p.shape:
                    raise ValueError(
                        f"grad shape {tuple(grad.shape)} must match parameter shape "
                        f"{tuple(p.shape)}"
                    )
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["momentum_buffer"] = torch.zeros_like(
                        p,
                        memory_format=torch.preserve_format,
                    )
                state["step"] += 1
                buf = state["momentum_buffer"]
                if buf.shape != p.shape:
                    raise RuntimeError(
                        f"momentum_buffer shape {tuple(buf.shape)} must match "
                        f"parameter shape {tuple(p.shape)}"
                    )
                # E.g. after load_state_dict, the buffer may live on another device or dtype.
                if buf.device != p.device or buf.dtype != p.dtype:
                    buf = state["momentum_buffer"] = buf.to(
                        device=p.device,
                        dtype=p.dtype,
                    )
                # EMA of gradients: buf = momentum * buf + (1 - momentum) * grad.
                buf.lerp_(grad, 1.0 - momentum)
                if nesterov:
                    # Nesterov look-ahead: (1 - momentum) * grad + momentum * buf.
                    update = torch.lerp(grad, buf, momentum)
                else:
                    update = buf.clone()
                update = self._aurora_direction(
                    update,
                    preconditioner_steps=preconditioner_steps,
                    preconditioner_beta=preconditioner_beta,
                    eps=eps,
                    compile=compile,
                )
                # Tall matrices get a sqrt(rows / cols) boost; wide and square ones are unscaled.
                update.mul_(max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                if weight_decay != 0.0:
                    # Decoupled (AdamW-style) weight decay, scaled by lr.
                    p.mul_(1.0 - lr * weight_decay)
                p.add_(update, alpha=-lr)
        return loss
| ####################################################################################################################### | |
| # main training loop | |
# Process topology, read from environment variables set by the launcher
# (presumably torchrun -- confirm). Distributed mode is inferred from the
# presence of RANK.
IS_DIST = "RANK" in os.environ
RANK = int(os.environ.get("RANK", 0))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))  # selects this process's CUDA device in main()
IS_MASTER_PROCESS = RANK == 0  # gates console/file logging, wandb.init and checkpoint saving
# Set by main() on the master process only: the wandb run directory and the open log file.
RUN_DIR = None
LOG_FP = None
def log(s: str):
    """Print ``s`` to stdout and, if a log file is open, append it there too.

    Only the master process emits anything; on other ranks this is a no-op.
    """
    if not IS_MASTER_PROCESS:
        return
    print(s)
    if LOG_FP is None:
        return
    print(s, file=LOG_FP, flush=True)
def maybe_create_profiler():
    """Return a ``torch.profiler.profile`` context for this run, or a no-op context.

    Profiling is active only when ``PROFILE_ENABLED`` is set, and only on the
    master process. Other ranks never receive a ``RUN_DIR`` (it comes from the
    wandb run created on rank 0), so previously they crashed here whenever
    profiling was enabled in a multi-GPU run. They now get a ``nullcontext``,
    whose ``as`` target is None.

    Traces are written by ``tensorboard_trace_handler`` into ``RUN_DIR / "profiles"``.

    Raises:
        RuntimeError: if profiling is requested on the master before RUN_DIR is set.
        ValueError: if both CPU and CUDA activities are disabled.
    """
    if not PROFILE_ENABLED or not IS_MASTER_PROCESS:
        return contextlib.nullcontext()
    if RUN_DIR is None:
        raise RuntimeError("Profiler requires RUN_DIR to be initialized")
    log(
        "torch.profiler enabled "
        f"(wait={PROFILE_WAIT_STEPS}, warmup={PROFILE_WARMUP_STEPS}, active={PROFILE_ACTIVE_STEPS}, repeat={PROFILE_REPEAT}, "
        f"cpu={PROFILE_CPU}, cuda={PROFILE_CUDA})"
    )
    activities = []
    if PROFILE_CPU:
        activities.append(torch.profiler.ProfilerActivity.CPU)
    if PROFILE_CUDA:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    if not activities:
        raise ValueError("At least one profiler activity must be enabled")
    profile_dir = RUN_DIR / "profiles"
    profile_dir.mkdir(parents=True, exist_ok=True)
    log(f"{profile_dir=}")
    trace_handler = torch.profiler.tensorboard_trace_handler(str(profile_dir))
    return torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(
            wait=PROFILE_WAIT_STEPS,
            warmup=PROFILE_WARMUP_STEPS,
            active=PROFILE_ACTIVE_STEPS,
            repeat=PROFILE_REPEAT,
        ),
        on_trace_ready=trace_handler,
        record_shapes=PROFILE_RECORD_SHAPES,
        profile_memory=PROFILE_PROFILE_MEMORY,
        with_stack=PROFILE_WITH_STACK,
        with_flops=PROFILE_WITH_FLOPS,
    )
def _all_reduce_grads(model: nn.Module) -> None:
    """Average gradients across ranks so every replica applies the same update.

    No DDP wrapper is used, so this has to be done by hand; without it the
    replicas silently diverge after the initial weight broadcast. No-op when
    torch.distributed is not initialized.
    """
    if not dist.is_initialized():
        return
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)


def _kernel_warmup(model: nn.Module, matrix_optimizer, other_optimizer) -> None:
    """Run throwaway train/val steps to trigger compilation and kernel autotuning.

    Model and optimizer state are snapshotted first and restored afterwards, so
    the timed run starts from the same state as if no warmup had happened.
    """
    initial_state = dict(
        model=copy.deepcopy(model.state_dict()),
        matrix_optimizer=copy.deepcopy(matrix_optimizer.state_dict()),
        other_optimizer=copy.deepcopy(other_optimizer.state_dict()),
    )
    log(
        "starting kernel warmup "
        f"(train_steps={KERNEL_WARMUP_TRAIN_STEPS}, val_steps={KERNEL_WARMUP_VAL_STEPS})"
    )
    train_loader = distributed_data_generator(TRAIN_FILES, TRAIN_BATCH_SIZE, TRAIN_MAX_SEQ_LEN, grad_accum_steps=GRAD_ACCUM_STEPS)
    model.train()
    for step in range(KERNEL_WARMUP_TRAIN_STEPS):
        print(f"Kernel Warmup Train {step + 1}/{KERNEL_WARMUP_TRAIN_STEPS}", end="\r\x1b[2K")
        loss_sum, token_count = model(*next(train_loader))
        assert token_count > 0
        (loss_sum / token_count).backward()
        _all_reduce_grads(model)  # also warms up the collective path
        matrix_optimizer.step()
        other_optimizer.step()
        matrix_optimizer.zero_grad(set_to_none=True)
        other_optimizer.zero_grad(set_to_none=True)
    del train_loader
    model.zero_grad(set_to_none=True)
    model.eval()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    val_loader = distributed_data_generator(VAL_FILES, VAL_BATCH_SIZE, -1, grad_accum_steps=GRAD_ACCUM_STEPS, align_to_bos=False)
    with torch.no_grad():
        for val_step in range(KERNEL_WARMUP_VAL_STEPS):
            print(f"Kernel Warmup Val {val_step + 1}/{KERNEL_WARMUP_VAL_STEPS}", end="\r\x1b[2K")
            model(*next(val_loader))
    del val_loader
    model.train()
    model.load_state_dict(initial_state["model"])
    matrix_optimizer.load_state_dict(initial_state["matrix_optimizer"])
    other_optimizer.load_state_dict(initial_state["other_optimizer"])
    del initial_state
    gc.collect()
    torch.cuda.synchronize()


def main():
    """Train a MultiscreenLM on CUDA, optionally data-parallel across ranks.

    Configuration comes from the module-level hyperparameter globals. Only the
    master process (rank 0) creates the wandb run and run directory, logs to
    wandb, and saves checkpoints. Gradients are averaged across ranks before
    every optimizer step.
    """
    global RUN_DIR, LOG_FP
    assert torch.cuda.is_available()
    torch.cuda.set_device(LOCAL_RANK)
    if IS_DIST:
        dist.init_process_group(backend="nccl", device_id=LOCAL_RANK)
        dist.barrier()
    # Every rank consumes an equal slice of each global batch.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if IS_MASTER_PROCESS:
        wandb_run = wandb.init(config=hparams)
        wandb.define_metric("step")
        wandb.define_metric("*", step_metric="step")
        RUN_DIR = Path(wandb_run.dir)
        # Keep an exact copy of this script next to the run's outputs.
        script_path = Path(__file__)
        shutil.copy(script_path, RUN_DIR / script_path.name)
        LOG_FP = (RUN_DIR / "log.txt").open("w")
    log(f"{RUN_DIR=}")
    log(f"Running Python {sys.version}")
    log(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
    log("=" * 100)
    with maybe_create_profiler() as profiler:
        cfg = ModelConfig(
            vocab_size=VOCAB_SIZE,
            model_dim=MODEL_DIM,
            depth=DEPTH,
            num_heads=NUM_HEADS,
            key_dim=KEY_DIM,
            value_dim=VALUE_DIM,
            window_threshold=WINDOW_THRESHOLD,
            train_max_seq_len=TRAIN_MAX_SEQ_LEN,
        )
        if IS_MASTER_PROCESS:
            (RUN_DIR / "model_config.json").write_text(json.dumps(asdict(cfg), indent=2))
        model = MultiscreenLM(cfg).cuda()
        if dist.is_initialized():
            # Start every replica from rank 0's weights.
            for param in model.parameters():
                dist.broadcast(param.detach(), 0)
        total_param_count = sum(p.numel() for p in model.parameters())
        embed_param_count = model.embed.numel()
        non_embed_param_count = total_param_count - embed_param_count
        log(f"Model has {total_param_count:,} parameters ({total_param_count / 1e6:.2f}M).")
        if IS_MASTER_PROCESS:
            extra_hparams = {
                "total_params": total_param_count,
                "embed_params": embed_param_count,
                "non_embed_params": non_embed_param_count,
            }
            wandb.config.update(extra_hparams)
        # Matrix-shaped weights (w_* projections and the embedding) go to the matrix
        # optimizer; scalar/vector parameters (s_*) go to Adam.
        matrix_param_name_suffixes = set()
        matrix_params = []
        other_param_name_suffixes = set()
        other_params = []
        for name, param in model.named_parameters():
            name_suffix = name.split('.')[-1]
            if name_suffix.startswith("w_") or name_suffix == "embed":
                matrix_param_name_suffixes.add(name_suffix)
                matrix_params.append(param)
            else:
                other_param_name_suffixes.add(name_suffix)
                other_params.append(param)
        log(f"{matrix_param_name_suffixes=}")
        log(f"{other_param_name_suffixes=}")
        if MATRIX_OPTIMIZER == "Adam":
            matrix_optimizer = torch.optim.Adam(
                matrix_params,
                lr=MATRIX_LR,
                betas=(MATRIX_BETA1, MATRIX_BETA2),
                eps=MATRIX_EPS,
                weight_decay=0,
            )
        elif MATRIX_OPTIMIZER == "Aurora":
            matrix_optimizer = Aurora(
                matrix_params,
                lr=MATRIX_LR,
                momentum=MATRIX_MOMENTUM,
                eps=MATRIX_EPS,
            )
        else:
            raise ValueError(f"invalid MATRIX_OPTIMIZER: {MATRIX_OPTIMIZER}")
        other_optimizer = torch.optim.Adam(
            other_params,
            lr=OTHER_LR,
            betas=(OTHER_BETA1, OTHER_BETA2),
            eps=OTHER_EPS,
        )

        def get_lr_scale(step: int):
            """Linear warmup, then constant, then linear cooldown to 10% over the final COOLDOWN_FRAC."""
            if step < WARMUP_STEPS:
                return (step + 1) / WARMUP_STEPS
            x = step / max(TRAIN_STEPS, 1)
            if x >= 1.0 - COOLDOWN_FRAC:
                w = (1.0 - x) / COOLDOWN_FRAC
                return w + (1.0 - w) * 0.1
            return 1.0

        if IS_MASTER_PROCESS and WANDB_WATCH:
            wandb.watch(model, log=WANDB_WATCH_LOG, log_freq=WANDB_WATCH_LOG_FREQ)
        if KERNEL_WARMUP_TRAIN_STEPS > 0 or KERNEL_WARMUP_VAL_STEPS > 0:
            _kernel_warmup(model, matrix_optimizer, other_optimizer)
        log("kernel warmup complete; timed training starts now")
        model.train()
        train_loader = distributed_data_generator(TRAIN_FILES, TRAIN_BATCH_SIZE, TRAIN_MAX_SEQ_LEN, grad_accum_steps=GRAD_ACCUM_STEPS)
        training_time = 0.0
        total_tokens_seen = 0  # global count: local tokens times world_size
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        train_batch = next(train_loader)
        for step in range(TRAIN_STEPS + 1):
            last_step = step == TRAIN_STEPS
            if last_step or (VAL_PERIOD > 0 and step % VAL_PERIOD == 0):
                # Pause the training clock while validating.
                torch.cuda.synchronize()
                training_time += time.perf_counter() - t0
                eval_t0 = time.perf_counter()
                model.eval()
                assert VAL_TOKENS % VAL_BATCH_SIZE == 0
                val_steps = max(1, GRAD_ACCUM_STEPS * VAL_TOKENS // VAL_BATCH_SIZE)
                val_loader = distributed_data_generator(VAL_FILES, VAL_BATCH_SIZE, -1, grad_accum_steps=GRAD_ACCUM_STEPS, align_to_bos=False)
                val_loss_sum = torch.zeros((), device="cuda")
                val_token_count = torch.zeros((), device="cuda")
                with torch.no_grad():
                    for val_step in range(val_steps):
                        print(f"Validation {val_step + 1}/{val_steps}", end="\r\x1b[2K")
                        loss_sum, token_count = model(*next(val_loader))
                        val_loss_sum += loss_sum
                        val_token_count += token_count
                assert val_token_count > 0
                val_loss = val_loss_sum / val_token_count
                if dist.is_initialized():
                    dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
                step_avg = training_time / max(step, 1)
                eval_time = time.perf_counter() - eval_t0
                model.train()
                # wandb.init ran only on the master; wandb.log on other ranks would error.
                if IS_MASTER_PROCESS:
                    wandb.log({
                        "step": step,
                        "train_time": training_time,
                        "val_loss": float(val_loss.item()),
                        "eval_time": eval_time,
                        "total_tokens": int(total_tokens_seen),
                        **model.window_stats(),
                    })
                log(f"step:{step}/{TRAIN_STEPS} val_loss:{val_loss.item():.4f} tks:{total_tokens_seen:.3e} train_time:{training_time:.0f}s step_avg:{step_avg:.2f}s eval_time:{eval_time:.2f}s")
                del val_loader, val_loss_sum, val_token_count
                torch.cuda.synchronize()
                t0 = time.perf_counter()
            is_checkpoint_step = CHECKPOINT_PERIOD > 0 and step % CHECKPOINT_PERIOD == 0
            if (last_step or is_checkpoint_step) and SAVE_CHECKPOINT:
                # Checkpoint time is excluded from the training clock too.
                torch.cuda.synchronize()
                training_time += time.perf_counter() - t0
                model.eval()
                if IS_MASTER_PROCESS:
                    torch.save(model.state_dict(), RUN_DIR / f"model-{step:06d}.pt")
                    torch.save(matrix_optimizer.state_dict(), RUN_DIR / f"matrix_optimizer-{step:06d}.pt")
                    torch.save(other_optimizer.state_dict(), RUN_DIR / f"other_optimizer-{step:06d}.pt")
                model.train()
                torch.cuda.synchronize()
                t0 = time.perf_counter()
            if last_step:
                break
            step_t0 = time.perf_counter()
            lr_scale = get_lr_scale(step)
            for param_group in matrix_optimizer.param_groups:
                param_group["lr"] = MATRIX_LR * lr_scale
            for param_group in other_optimizer.param_groups:
                param_group["lr"] = OTHER_LR * lr_scale
            train_loss_sum = torch.zeros((), device="cuda")
            train_token_count = 0
            for _ in range(GRAD_ACCUM_STEPS):
                loss_sum, token_count = model(*train_batch)
                # Fetch the next micro-batch while this one's kernels are queued.
                train_batch = next(train_loader)
                assert token_count > 0
                loss = loss_sum / token_count / GRAD_ACCUM_STEPS
                assert torch.isfinite(loss), "loss is NaN or Inf"
                loss.backward()
                train_loss_sum += loss_sum.detach()
                train_token_count += token_count
                total_tokens_seen += token_count * world_size
            _all_reduce_grads(model)
            matrix_optimizer.step()
            other_optimizer.step()
            matrix_optimizer.zero_grad(set_to_none=True)
            other_optimizer.zero_grad(set_to_none=True)
            # Per-token training loss of this rank's data (not reduced across ranks).
            train_loss = float((train_loss_sum / max(train_token_count, 1)).item())
            approx_training_time = training_time + time.perf_counter() - t0
            step_avg = approx_training_time / (step + 1)
            step_time = time.perf_counter() - step_t0
            if IS_MASTER_PROCESS:
                wandb.log({
                    "step": step + 1,
                    "train_time": approx_training_time,
                    "lr_scale": lr_scale,
                    "train_loss": train_loss,
                    "step_time": step_time,
                    "avg_step_time": step_avg,
                    "total_tokens": int(total_tokens_seen),
                    **model.window_stats(),
                })
            log(f"step:{step + 1}/{TRAIN_STEPS} train_loss:{train_loss:.4f} tks:{total_tokens_seen:.3e} train_time:{approx_training_time:.0f}s step_avg:{step_avg:.2f}s step_time:{step_time:.2f}s")
            # nullcontext yields None when profiling is disabled.
            if profiler is not None:
                profiler.step()
    log(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB")
    if IS_MASTER_PROCESS:
        wandb.finish()
    if LOG_FP is not None:
        LOG_FP.close()
        LOG_FP = None
    if dist.is_initialized():
        dist.destroy_process_group()
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
Author
Author
Revision 2 fixes torch.exp(self.s_r) + 1 -> torch.sigmoid(self.s_r).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
See the initial Aurora results on a 16-layer, 28M-parameter Multiscreen model here: https://x.com/scottjmaddox/status/2053839590922899652?s=20