Skip to content

Instantly share code, notes, and snippets.

@scottjmaddox
Last active May 17, 2026 04:41
Show Gist options
  • Select an option

  • Save scottjmaddox/a0254ed46dff9e24adf10711b8dae611 to your computer and use it in GitHub Desktop.

Select an option

Save scottjmaddox/a0254ed46dff9e24adf10711b8dae611 to your computer and use it in GitHub Desktop.
Train a multiscreen model on the fineweb dataset
"""
Train a Multiscreen model [1] using the Aurora optimizer [2] on the FineWeb dataset [3]
tokenized using the GPT-2 tokenizer by Keller Jordan for the modded-nanogpt speedrun [4].
Use the `data/cached_fineweb10B.py` script from [4] to download the pre-tokenized dataset
to `datasets/fineweb10B/` before running.
References:
1. Screening is Enough. https://arxiv.org/abs/2604.01178
2. Aurora: A Leverage-Aware Optimizer for Rectangular Matrices. https://blog.tilderesearch.com/blog/aurora
3. The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale. https://arxiv.org/abs/2406.17557
4. Modded-NanoGPT. https://github.com/kellerjordan/modded-nanogpt
"""
import contextlib
import copy
import gc
import glob
import json
import math
import os
import shutil
import sys
import threading
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Callable, Optional
hparams = {}
def env(name: str, default, dtype: type | None = None, is_hparam: bool = True):
val = os.environ.get(name)
if val is None:
val = default
else:
if dtype is None:
dtype = type(default)
if dtype is bool:
val = val.lower() in ("1", "true", "yes")
elif dtype is int:
val = int(eval(val))
elif dtype is float:
val = float(eval(val))
else:
val = dtype(val)
if is_hparam:
hparams[name] = val
return val
# Weights & Biases model watching; consumed outside this chunk (presumably passed
# to wandb.watch -- confirm).
WANDB_WATCH = env("WANDB_WATCH", False, bool)
WANDB_WATCH_LOG = env("WANDB_WATCH_LOG", "all", str)
WANDB_WATCH_LOG_FREQ = env("WANDB_WATCH_LOG_FREQ", 1000, int)
# Data
# Shard glob patterns are joined onto DATA_PATH (default: the current directory).
# The path settings are not recorded in hparams (is_hparam=False).
DATA_PATH = env("DATA_PATH", ".", str, is_hparam=False)
TRAIN_FILES = os.path.join(DATA_PATH, env("TRAIN_FILES", "datasets/fineweb10B/fineweb_train_*.bin", str, is_hparam=False))
VAL_FILES = os.path.join(DATA_PATH, env("VAL_FILES", "datasets/fineweb10B/fineweb_val_*.bin", str, is_hparam=False))
VAL_TOKENS = env("VAL_TOKENS", 2**24, int)
# Model architecture
# By default width and head count scale with DEPTH: MODEL_DIM = DEPTH**2, NUM_HEADS = DEPTH.
DEPTH = env("DEPTH", 16, int)
MODEL_DIM = env("MODEL_DIM", DEPTH**2, int)
NUM_HEADS = env("NUM_HEADS", DEPTH, int)
# KEY_DIM / VALUE_DIM also size the Triton kernels' block_dk / block_dv
# (see VarlenMultiscreenLaunchConfig); presumably the per-head q/k and v widths -- confirm.
KEY_DIM = env("KEY_DIM", 16, int)
VALUE_DIM = env("VALUE_DIM", 64, int)
WINDOW_THRESHOLD = env("WINDOW_THRESHOLD", 256.0, float)
# 50_304 = 786 * 64; presumably the GPT-2 vocabulary padded to a multiple of 64 -- confirm.
VOCAB_SIZE = env("VOCAB_SIZE", 50_304, int)
# Training
TRAIN_STEPS = env("TRAIN_STEPS", 2**11, int)
TRAIN_MAX_SEQ_LEN = env("TRAIN_MAX_SEQ_LEN", 2**12, int)
TRAIN_BATCH_SIZE = env("TRAIN_BATCH_SIZE", 2**22, int)
GRAD_ACCUM_STEPS = env("GRAD_ACCUM_STEPS", 2 ** (5 + DEPTH // 8), int)
# Validate with explicit raises rather than `assert`, which `python -O` strips.
# TRAIN_BATCH_SIZE must split into GRAD_ACCUM_STEPS equal parts, and each part
# must hold at least TRAIN_MAX_SEQ_LEN tokens.
if GRAD_ACCUM_STEPS <= 0 or TRAIN_BATCH_SIZE % GRAD_ACCUM_STEPS != 0:
    raise ValueError(
        f"TRAIN_BATCH_SIZE ({TRAIN_BATCH_SIZE}) must be divisible by a positive "
        f"GRAD_ACCUM_STEPS ({GRAD_ACCUM_STEPS})"
    )
if TRAIN_BATCH_SIZE // GRAD_ACCUM_STEPS < TRAIN_MAX_SEQ_LEN:
    raise ValueError(
        f"TRAIN_BATCH_SIZE // GRAD_ACCUM_STEPS ({TRAIN_BATCH_SIZE // GRAD_ACCUM_STEPS}) "
        f"must be >= TRAIN_MAX_SEQ_LEN ({TRAIN_MAX_SEQ_LEN})"
    )
# Validation
VAL_BATCH_SIZE = env("VAL_BATCH_SIZE", 2**22, int)
# Presumably measured in training steps -- confirm in the training loop.
VAL_PERIOD = env("VAL_PERIOD", 20, int)
# Checkpointing
SAVE_CHECKPOINT = env("SAVE_CHECKPOINT", False, bool)
CHECKPOINT_PERIOD = env("CHECKPOINT_PERIOD", 40, int)
# Compilation and kernel configuration
# COMPILE_MODEL toggles torch.compile in maybe_torch_compile. The MULTISCREEN_* values
# are the defaults for the Triton kernels' tile sizes (VarlenMultiscreenTilingConfig)
# and launch parameters (VarlenMultiscreenLaunchConfig).
COMPILE_MODEL = env("COMPILE_MODEL", True, bool)
KERNEL_WARMUP_TRAIN_STEPS = env("KERNEL_WARMUP_TRAIN_STEPS", 10, int)
KERNEL_WARMUP_VAL_STEPS = env("KERNEL_WARMUP_VAL_STEPS", 10, int)
MULTISCREEN_BLOCK_M = env("MULTISCREEN_BLOCK_M", 32, int)
MULTISCREEN_BLOCK_N = env("MULTISCREEN_BLOCK_N", 32, int)
MULTISCREEN_NUM_WARPS = env("MULTISCREEN_NUM_WARPS", 4, int)
MULTISCREEN_NUM_STAGES = env("MULTISCREEN_NUM_STAGES", 2, int)
# Optimizers
WARMUP_STEPS = env("WARMUP_STEPS", 2**10, int)
COOLDOWN_FRAC = env("COOLDOWN_FRAC", 0.0, float)
# Optimizer for the "matrix" parameter group (presumably the 2-D+ weights -- confirm);
# only "Adam" and "Aurora" are accepted, anything else raises below.
MATRIX_OPTIMIZER = env("MATRIX_OPTIMIZER", "Aurora", str)
# Learning rates are configured as base-2 logarithms.
# NOTE(review): env() returns the default unconverted when the variable is unset, so
# MATRIX_LOG2_LR / OTHER_LOG2_LR are the ints -3 / -4 by default but floats (e.g. -3.0)
# when set via the environment; this changes how they print in WANDB_NAME.
MATRIX_LOG2_LR = env("MATRIX_LOG2_LR", -3, float)
MATRIX_LR = 2**MATRIX_LOG2_LR
# Only the hyperparameters of the selected optimizer are defined (and logged):
# MATRIX_BETA1/2 exist only for Adam, MATRIX_MOMENTUM only for Aurora.
if MATRIX_OPTIMIZER == "Adam":
    MATRIX_BETA1 = env("MATRIX_BETA1", 0.9, float)
    MATRIX_BETA2 = env("MATRIX_BETA2", 0.95, float)
elif MATRIX_OPTIMIZER == "Aurora":
    MATRIX_MOMENTUM = env("MATRIX_MOMENTUM", 0.95, float)
else:
    raise ValueError(f"invalid MATRIX_OPTIMIZER: {MATRIX_OPTIMIZER}")
MATRIX_EPS = env("MATRIX_EPS", 1e-8, float)
# Settings for the remaining ("other") parameter group; the beta/eps names suggest an Adam-style optimizer.
OTHER_LOG2_LR = env("OTHER_LOG2_LR", -4, float)
OTHER_LR = 2**OTHER_LOG2_LR
OTHER_BETA1 = env("OTHER_BETA1", 0.9, float)
OTHER_BETA2 = env("OTHER_BETA2", 0.95, float)
OTHER_EPS = env("OTHER_EPS", 1e-8, float)
# Profiling
# The names mirror torch.profiler options (record_shapes, profile_memory, with_stack,
# with_flops) and schedule(wait, warmup, active, repeat); presumably forwarded by the
# training loop outside this chunk -- confirm.
PROFILE_ENABLED = env("PROFILE_ENABLED", False, bool)
PROFILE_CPU = env("PROFILE_CPU", True, bool)
PROFILE_CUDA = env("PROFILE_CUDA", True, bool)
PROFILE_RECORD_SHAPES = env("PROFILE_RECORD_SHAPES", True, bool)
PROFILE_PROFILE_MEMORY = env("PROFILE_PROFILE_MEMORY", True, bool)
PROFILE_WITH_STACK = env("PROFILE_WITH_STACK", False, bool)
PROFILE_WITH_FLOPS = env("PROFILE_WITH_FLOPS", True, bool)
PROFILE_WAIT_STEPS = env("PROFILE_WAIT_STEPS", 0, int)
PROFILE_WARMUP_STEPS = env("PROFILE_WARMUP_STEPS", 10, int)
PROFILE_ACTIVE_STEPS = env("PROFILE_ACTIVE_STEPS", 10, int)
PROFILE_REPEAT = env("PROFILE_REPEAT", 1, int)
# Example profiling usage:
#
# VAL_PERIOD=0 \
# SAVE_CHECKPOINT=0 \
# TRAIN_STEPS=111 \
# PROFILE_ENABLED=1 \
# PROFILE_WARMUP_STEPS=100 \
# PROFILE_ACTIVE_STEPS=10 \
# uv run train_multiscreen_fineweb.py
# Weights & Biases
# WANDB_ERROR_REPORTING: Set this to false to prevent wandb from logging fatal errors to its error tracking system.
# WANDB_DISABLE_GIT: Set to true to prevent wandb from probing for a git repository and capturing the latest commit / diff.
# WANDB_SAVE_CODE: Set to true to ensure the main script is saved.
# WANDB_PROJECT: Sets the default project name for your runs.
# WANDB_RUN_GROUP: Specify the experiment name to automatically group runs together.
# WANDB_NAME: Sets a custom display name for the run.
# WANDB_NOTES: Adds a detailed description or "commit message" for the run.
# NOTE: the assignments below unconditionally overwrite any of these variables that
# are already set in the environment. WANDB_NOTES is not set here.
# Uncomment to run wandb in offline mode:
# os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_ERROR_REPORTING"] = "false"
os.environ["WANDB_DISABLE_GIT"] = "true"
os.environ["WANDB_SAVE_CODE"] = "true"
os.environ["WANDB_PROJECT"] = "multiscreen-fineweb"
WANDB_RUN_GROUP = "MS v4"
os.environ["WANDB_RUN_GROUP"] = WANDB_RUN_GROUP
# Run name encodes group, depth, matrix optimizer and both log2 learning rates.
os.environ["WANDB_NAME"] = f"{WANDB_RUN_GROUP} d{DEPTH} {MATRIX_OPTIMIZER}+embed LOG2_LRs: ({MATRIX_LOG2_LR}, {OTHER_LOG2_LR})"
#######################################################################################################################
# Allocator config is set before torch is imported so it is in place before any CUDA allocation.
# NOTE(review): older PyTorch releases only read PYTORCH_CUDA_ALLOC_CONF -- confirm the
# installed version honors PYTORCH_ALLOC_CONF.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.distributed as dist
import torch.nn.functional as F
import triton
import triton.language as tl
from torch import Tensor, nn
# Allow TF32 for cuDNN and CUDA matmuls (faster, reduced-precision float32 math).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
import wandb
#######################################################################################################################
# varlen_multiscreen triton kernel by @scottjmaddox and GPT-5.4 / 5.5
def panic(msg):
    """Raise ``ValueError(msg)``.

    Being a function, it can be used where only an expression is allowed (e.g. a
    dataclass field default: ``x if ok else panic("bad x")``). ValueError is a
    subclass of Exception, so existing ``except Exception`` handlers still catch it,
    while callers can now catch the specific configuration error.
    """
    raise ValueError(msg)
@dataclass(frozen=True)
class VarlenMultiscreenTilingConfig:
    """Query/KV tile sizes (in tokens) used by the varlen multiscreen kernels.

    Both sizes become ``tl.arange`` extents (which must be powers of two) and
    ``tl.dot`` operand dimensions (which must be at least 16). They are validated
    here so a bad MULTISCREEN_BLOCK_M/N fails immediately with a clear message
    instead of inside Triton compilation.

    Raises:
        ValueError: if either size is not a power of two >= 16.
    """
    block_m: int = MULTISCREEN_BLOCK_M  # query rows handled by one kernel program
    block_n: int = MULTISCREEN_BLOCK_N  # key/value rows processed per inner-loop step

    def __post_init__(self):
        for field_name, size in (("block_m", self.block_m), ("block_n", self.block_n)):
            if size < 16 or size & (size - 1):
                raise ValueError(f"{field_name} must be a power of two >= 16, got {size}")
def _kernel_block_dim(name: str, dim: int) -> int:
    """Return the Triton block width for a head dimension of size ``dim``.

    ``tl.arange`` needs a power-of-two extent and ``tl.dot`` needs operands of at
    least 16, so ``dim`` is rounded up to a power of two with a floor of 16. The
    kernels mask the padded columns (``offs_d < d_k`` / ``offs_d < d_v``), so any
    ``1 <= dim <= 128`` works. Previously accepted sizes such as 48 or 96 used to
    reach ``tl.arange`` unpadded and fail to compile.
    """
    if dim < 1 or dim > 128:
        panic(f"bad {name}: {dim}")
    return max(16, 1 << (dim - 1).bit_length())


@dataclass(frozen=True)
class VarlenMultiscreenLaunchConfig:
    """Compile-time block widths and launch parameters for the Triton kernels."""
    block_dk: int = _kernel_block_dim("KEY_DIM", KEY_DIM)  # BLOCK_DK (>= KEY_DIM)
    block_dv: int = _kernel_block_dim("VALUE_DIM", VALUE_DIM)  # BLOCK_DV (>= VALUE_DIM)
    num_warps: int = MULTISCREEN_NUM_WARPS
    num_stages: int = MULTISCREEN_NUM_STAGES
@dataclass(frozen=True)
class VarlenMultiscreenKernelConfig:
    """Tiling settings (used to build the metadata) plus Triton launch settings."""
    # default_factory creates a fresh sub-config with its class defaults per instance.
    tiling: VarlenMultiscreenTilingConfig = field(default_factory=VarlenMultiscreenTilingConfig)
    launch: VarlenMultiscreenLaunchConfig = field(default_factory=VarlenMultiscreenLaunchConfig)


# Configuration used by the multiscreen custom ops (forward and backward).
DEFAULT_VARLEN_MULTISCREEN_KERNEL_CONFIG = VarlenMultiscreenKernelConfig()
@dataclass(frozen=True)
class VarlenMultiscreenMeta:
    """Host-precomputed index tables produced by build_varlen_multiscreen_meta."""
    tiling_cfg: VarlenMultiscreenTilingConfig  # tile sizes the tables were built for
    w: Tensor  # per-head window sizes, as passed to the builder
    r: Tensor  # per-head r values, as passed to the builder
    q_seg_starts: Tensor  # cu_seqlens_q[:-1]
    q_seg_ends: Tensor  # cu_seqlens_q[1:]
    kv_seg_starts: Tensor  # cu_seqlens_kv[:-1]
    kv_seg_ends: Tensor  # cu_seqlens_kv[1:]
    tile_segment_ids: Tensor  # (num_q_tiles,) int32 segment index of each q tile
    tile_q_starts: Tensor  # (num_q_tiles,) int32 first q row of each tile
    tile_kv_start_idxs: Tensor  # (num_heads, num_q_tiles) int32 first KV tile, relative to the kv segment start
    tile_kv_num_blocks: Tensor  # (num_heads, num_q_tiles) int32 number of KV tiles to visit (0 if none)
    num_segments: int
    num_q_tiles: int
def _check_varlen_multiscreen_inputs(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> None:
    """Assert that the varlen multiscreen operands are mutually consistent.

    Enforced layout: q is (H, Tq, Dk), k is (H, Tkv, Dk), v is (H, Tkv, Dv);
    q_pos / kv_pos are 1-D with Tq / Tkv entries; cu_seqlens_q / cu_seqlens_kv
    are 1-D and equally long; w and r are 1-D with H entries. Every tensor must
    be on q's device, because all of them (or slices of them) are handed to the
    same Triton kernel launch.

    Raises:
        AssertionError: if any check fails.
    """
    assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
    assert q.shape[0] == k.shape[0] == v.shape[0]
    assert q.shape[2] == k.shape[2]
    assert k.shape[1] == v.shape[1]
    assert q.device == k.device == v.device == q_pos.device == kv_pos.device == cu_seqlens_q.device == cu_seqlens_kv.device
    # The kernels dereference w_ptr / r_ptr directly, so these must be co-located too.
    assert w.device == q.device and r.device == q.device
    assert q_pos.ndim == kv_pos.ndim == 1
    assert cu_seqlens_q.ndim == cu_seqlens_kv.ndim == 1
    assert q_pos.shape[0] == q.shape[1]
    assert kv_pos.shape[0] == k.shape[1]
    assert cu_seqlens_q.shape == cu_seqlens_kv.shape
    assert w.ndim == 1 and r.ndim == 1
    assert w.shape[0] == q.shape[0]
    assert r.shape[0] == q.shape[0]
def _check_varlen_multiscreen_meta(
    meta: VarlenMultiscreenMeta,
    *,
    device: torch.device,
    num_heads: int,
) -> None:
    """Assert that ``meta`` is structurally valid for a kernel launch.

    Requirements: every index table lives on ``device``; the four per-segment
    boundary tables are 1-D with identical shapes; the two per-tile tables are
    1-D with identical shapes; the two per-(head, tile) KV tables are 2-D with
    identical shapes of (num_heads, number of q tiles).

    Raises:
        AssertionError: if any requirement is violated.
    """
    segment_tables = (meta.q_seg_starts, meta.q_seg_ends, meta.kv_seg_starts, meta.kv_seg_ends)
    tile_tables = (meta.tile_segment_ids, meta.tile_q_starts)
    head_tile_tables = (meta.tile_kv_start_idxs, meta.tile_kv_num_blocks)

    # Placement.
    assert all(table.device == device for table in segment_tables)
    assert all(table.device == device for table in tile_tables + head_tile_tables)
    # Rank.
    assert all(table.ndim == 1 for table in segment_tables)
    assert all(table.ndim == 1 for table in tile_tables)
    assert all(table.ndim == 2 for table in head_tile_tables)
    # Shapes agree within each group.
    assert all(table.shape == segment_tables[0].shape for table in segment_tables[1:])
    assert tile_tables[0].shape == tile_tables[1].shape
    assert head_tile_tables[0].shape == head_tile_tables[1].shape
    # KV tables are indexed [head, q_tile].
    kv_starts = meta.tile_kv_start_idxs
    assert kv_starts.shape[0] == num_heads
    assert kv_starts.shape[1] == meta.tile_segment_ids.numel()
@torch.library.custom_op("multiscreen::varlen_multiscreen", mutates_args=())
def triton_varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> Tensor:
    """Forward custom op: run the varlen multiscreen Triton kernel.

    Validates the operands, builds the per-(head, q-tile) KV tile tables with the
    default kernel config, and returns a new (H, Tq, Dv) tensor with v's dtype
    and device.
    """
    _check_varlen_multiscreen_inputs(q, k, v, q_pos, kv_pos, cu_seqlens_q, cu_seqlens_kv, w, r)
    kernel_cfg = DEFAULT_VARLEN_MULTISCREEN_KERNEL_CONFIG
    num_heads, num_q_tokens = q.shape[0], q.shape[1]
    value_dim = v.shape[2]

    meta = build_varlen_multiscreen_meta(cu_seqlens_q, cu_seqlens_kv, w, r, tiling_cfg=kernel_cfg.tiling)
    _check_varlen_multiscreen_meta(meta, device=q.device, num_heads=num_heads)

    result = v.new_empty((num_heads, num_q_tokens, value_dim))
    _launch_triton_varlen_multiscreen(q, k, v, q_pos, kv_pos, result, meta, launch_cfg=kernel_cfg.launch)
    return result
@torch.library.custom_op("multiscreen::varlen_multiscreen_backward", mutates_args=())
def triton_varlen_multiscreen_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
    dout: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Backward custom op: gradients of triton_varlen_multiscreen w.r.t. q, k, v, w, r.

    ``dout`` must have the forward output's shape (H, Tq, Dv). Returns
    (dq, dk, dv, dw, dr), each in the dtype of the corresponding input.

    dk, dv, dw and dr are accumulated by the kernel with ``tl.atomic_add`` from
    every q tile that touches a given KV row / head. They are therefore
    accumulated in float32 buffers and only cast back to the input dtype at the
    end; accumulating dk/dv directly in a low-precision dtype such as bfloat16
    would round after every addition. dq is written with a single ``tl.store``
    by the program that owns the q tile, so it is allocated in q's dtype.
    """
    _check_varlen_multiscreen_inputs(q, k, v, q_pos, kv_pos, cu_seqlens_q, cu_seqlens_kv, w, r)
    cfg = DEFAULT_VARLEN_MULTISCREEN_KERNEL_CONFIG
    meta = build_varlen_multiscreen_meta(cu_seqlens_q, cu_seqlens_kv, w, r, tiling_cfg=cfg.tiling)
    _check_varlen_multiscreen_meta(meta, device=q.device, num_heads=q.shape[0])
    assert dout.shape == (q.shape[0], q.shape[1], v.shape[2])
    dq = torch.zeros_like(q)
    dk = torch.zeros_like(k, dtype=torch.float32)
    dv = torch.zeros_like(v, dtype=torch.float32)
    dw = torch.zeros_like(w, dtype=torch.float32)
    dr = torch.zeros_like(r, dtype=torch.float32)
    _launch_triton_varlen_multiscreen_backward(
        q, k, v, q_pos, kv_pos, dout, dq, dk, dv, dw, dr, meta, launch_cfg=cfg.launch
    )
    return dq, dk.to(k.dtype), dv.to(v.dtype), dw.to(w.dtype), dr.to(r.dtype)
def build_varlen_multiscreen_meta(
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
    *,
    tiling_cfg: VarlenMultiscreenTilingConfig,
) -> VarlenMultiscreenMeta:
    """Precompute the per-(head, q-tile) KV tile ranges used by the Triton kernels.

    Segment s covers q rows ``cu_seqlens_q[s]:cu_seqlens_q[s+1]`` and kv rows
    ``cu_seqlens_kv[s]:cu_seqlens_kv[s+1]``. Each q segment is cut into tiles of
    ``tiling_cfg.block_m`` rows. For every head h and q tile, the KV range covers
    the causal window of width ``w[h]`` (``w == inf`` means the whole causal
    prefix), expressed in tiles of ``tiling_cfg.block_n`` rows counted from the kv
    segment start. Ranges are rounded out to whole tiles; the kernels apply the
    exact masks.

    NOTE(review): ranges are computed from row offsets within each segment, while
    the kernels mask using q_pos / kv_pos, so positions are assumed to equal those
    offsets -- confirm at call sites.

    Returns:
        VarlenMultiscreenMeta whose tile_kv_* tables have shape (num_heads, num_q_tiles).
    """
    q_seg_starts = cu_seqlens_q[:-1].contiguous()
    q_seg_ends = cu_seqlens_q[1:].contiguous()
    kv_seg_starts = cu_seqlens_kv[:-1].contiguous()
    kv_seg_ends = cu_seqlens_kv[1:].contiguous()
    num_segments = q_seg_starts.numel()
    assert num_segments == kv_seg_starts.numel()
    block_m = tiling_cfg.block_m
    block_n = tiling_cfg.block_n

    # Per-head window size in KV tiles, computed in int64 and clamped at 2**31 tiles.
    # A window of 2**31 tiles (>= 2**35 tokens) already reaches every int32-indexable
    # position, so the clamp does not change any range, and it prevents the int32
    # overflow in `tile_count * block_n` that made very wide finite windows silently
    # attend to nothing. Full-prefix (inf) heads use the same sentinel, so their range
    # always starts at 0, even when a q segment is longer than its kv segment.
    max_window_tiles = float(2**31)
    w_runtime = w.detach().float()
    full_prefix_heads = torch.isinf(w_runtime)
    finite_w = torch.where(full_prefix_heads, torch.zeros_like(w_runtime), w_runtime)
    head_window_tile_counts = torch.where(
        full_prefix_heads,
        torch.full_like(w_runtime, max_window_tiles),
        torch.ceil(finite_w / block_n).clamp(max=max_window_tiles),
    ).to(torch.int64)

    # Enumerate q tiles: segment s contributes ceil(len_s / block_m) consecutive tiles.
    q_seg_lens = q_seg_ends - q_seg_starts
    num_q_tiles_per_seg = torch.div(q_seg_lens + (block_m - 1), block_m, rounding_mode="floor")
    tile_segment_ids = torch.repeat_interleave(
        torch.arange(num_segments, device=cu_seqlens_q.device, dtype=torch.int32),
        num_q_tiles_per_seg.to(torch.int64),
    )
    idx = tile_segment_ids.to(torch.int64)
    seg_tile_prefix = torch.cumsum(num_q_tiles_per_seg, dim=0) - num_q_tiles_per_seg
    local_tile_idx = torch.arange(idx.shape[0], device=cu_seqlens_q.device, dtype=torch.int32) - seg_tile_prefix[idx]
    tile_q_starts = (q_seg_starts[idx] + local_tile_idx * block_m).to(torch.int32)

    # Per-tile q row range and kv segment length, as offsets within the segment.
    q_tile_ends = torch.minimum(tile_q_starts + block_m, q_seg_ends[idx])
    q_min = tile_q_starts - q_seg_starts[idx]
    q_max = q_tile_ends - q_seg_starts[idx] - 1
    tile_kv_seg_lens = kv_seg_ends[idx] - kv_seg_starts[idx]

    # (num_heads, num_q_tiles): causal window [q_min - W + 1, q_max], clipped to the kv segment.
    head_window_tokens = head_window_tile_counts.unsqueeze(1) * block_n
    kv_min_pos = torch.clamp(q_min.unsqueeze(0) - head_window_tokens + 1, min=0)
    kv_max_pos = torch.minimum(tile_kv_seg_lens.unsqueeze(0) - 1, q_max.unsqueeze(0))
    valid = kv_max_pos >= kv_min_pos
    tile_kv_start_idxs = torch.where(
        valid,
        torch.div(kv_min_pos, block_n, rounding_mode="floor"),
        torch.zeros_like(kv_min_pos),
    ).to(torch.int32).contiguous()
    tile_kv_end_idxs = torch.where(
        valid,
        torch.div(kv_max_pos, block_n, rounding_mode="floor") + 1,
        torch.zeros_like(kv_max_pos),
    )
    tile_kv_num_blocks = (tile_kv_end_idxs - tile_kv_start_idxs).to(torch.int32).contiguous()
    return VarlenMultiscreenMeta(
        tiling_cfg=tiling_cfg,
        w=w,
        r=r,
        q_seg_starts=q_seg_starts,
        q_seg_ends=q_seg_ends,
        kv_seg_starts=kv_seg_starts,
        kv_seg_ends=kv_seg_ends,
        tile_segment_ids=tile_segment_ids,
        tile_q_starts=tile_q_starts,
        tile_kv_start_idxs=tile_kv_start_idxs,
        tile_kv_num_blocks=tile_kv_num_blocks,
        num_segments=num_segments,
        num_q_tiles=tile_segment_ids.numel(),
    )
@triton.jit
def _varlen_multiscreen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    q_pos_ptr,
    kv_pos_ptr,
    q_seg_starts_ptr,
    q_seg_ends_ptr,
    kv_seg_starts_ptr,
    kv_seg_ends_ptr,
    tile_segment_ids_ptr,
    tile_q_starts_ptr,
    tile_kv_start_idxs_ptr,
    tile_kv_num_blocks_ptr,
    w_ptr,
    r_ptr,
    out_ptr,
    num_q_tiles,
    tile_kv_start_idxs_stride_h,
    tile_kv_start_idxs_stride_t,
    tile_kv_num_blocks_stride_h,
    tile_kv_num_blocks_stride_t,
    q_stride_h,
    q_stride_t,
    q_stride_d,
    k_stride_h,
    k_stride_t,
    k_stride_d,
    v_stride_h,
    v_stride_t,
    v_stride_d,
    out_stride_h,
    out_stride_t,
    out_stride_d,
    num_segments,
    num_heads,
    d_k,
    d_v,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    # Forward multiscreen kernel. program_id(0) selects a q tile, program_id(1) a head.
    # For query row i and key row j of the same segment it computes
    #   alpha_ij = max(1 - r * (1 - <q_i, k_j>), 0)^2 * 0.5 * (cos(pi * rel / w) + 1),
    #   rel = kv_pos_j - q_pos_i, kept only where -w < rel <= 0 (causal window),
    # and writes out_i = sum_j alpha_ij * v_j (no softmax / normalization).
    # NOTE(review): the KV tile range comes from host-side segment offsets, while masking
    # uses q_pos / kv_pos; positions are assumed to equal those offsets -- confirm.
    PI = 3.141592653589793
    pid = tl.program_id(axis=0)
    hid = tl.program_id(axis=1)
    if hid >= num_heads:
        return
    # Per-head scalars.
    w = tl.load(w_ptr + hid).to(tl.float32)
    r = tl.load(r_ptr + hid).to(tl.float32)
    # Tile -> segment bookkeeping precomputed by build_varlen_multiscreen_meta.
    segment_id = tl.load(tile_segment_ids_ptr + pid)
    q_start = tl.load(tile_q_starts_ptr + pid)
    q_end = tl.load(q_seg_ends_ptr + segment_id)
    kv_start_base = tl.load(kv_seg_starts_ptr + segment_id)
    kv_end = tl.load(kv_seg_ends_ptr + segment_id)
    # KV tiles to visit for this (head, q tile), relative to the kv segment start.
    kv_tile_start_idx = tl.load(tile_kv_start_idxs_ptr + hid * tile_kv_start_idxs_stride_h + pid * tile_kv_start_idxs_stride_t)
    active_kv_tile_count = tl.load(tile_kv_num_blocks_ptr + hid * tile_kv_num_blocks_stride_h + pid * tile_kv_num_blocks_stride_t)
    offs_m = q_start + tl.arange(0, BLOCK_M)
    offs_dk = tl.arange(0, BLOCK_DK)
    offs_dv = tl.arange(0, BLOCK_DV)
    # Masks for rows past the segment end and for padded head-dim columns.
    q_mask = offs_m < q_end
    dk_mask = offs_dk < d_k
    dv_mask = offs_dv < d_v
    # The q tile stays in registers for the whole KV loop; masked entries load as 0.
    q_ptrs = q_ptr + hid * q_stride_h + offs_m[:, None] * q_stride_t + offs_dk[None, :] * q_stride_d
    q_block = tl.load(q_ptrs, mask=q_mask[:, None] & dk_mask[None, :], other=0.0).to(tl.float32)
    q_pos = tl.load(q_pos_ptr + offs_m, mask=q_mask, other=0).to(tl.float32)
    acc = tl.zeros((BLOCK_M, BLOCK_DV), dtype=tl.float32)
    kv_tile_offset = 0
    while kv_tile_offset < active_kv_tile_count:
        kv_tile_idx = kv_tile_start_idx + kv_tile_offset
        kv_start = kv_start_base + kv_tile_idx * BLOCK_N
        offs_n = kv_start + tl.arange(0, BLOCK_N)
        kv_mask = offs_n < kv_end
        # K is loaded transposed, (BLOCK_DK, BLOCK_N), so tl.dot yields (BLOCK_M, BLOCK_N) similarities.
        k_ptrs = k_ptr + hid * k_stride_h + offs_n[None, :] * k_stride_t + offs_dk[:, None] * k_stride_d
        k_block = tl.load(k_ptrs, mask=dk_mask[:, None] & kv_mask[None, :], other=0.0).to(tl.float32)
        sim = tl.dot(q_block, k_block)
        kv_pos = tl.load(kv_pos_ptr + offs_n, mask=kv_mask, other=0).to(tl.float32)
        rel = kv_pos[None, :] - q_pos[:, None]
        causal_mask = rel <= 0.0
        # Hard window cut; the cosine taper below is already 0 at rel = -w, so the cut is continuous.
        window_mask = causal_mask & (rel > -w)
        # Content term: squared ReLU of 1 - r * (1 - sim).
        alpha = tl.maximum(1.0 - r * (1.0 - sim), 0.0)
        alpha = alpha * alpha
        cosine_mask = 0.5 * (tl.cos(PI * rel / w) + 1.0)
        alpha = tl.where(window_mask & q_mask[:, None] & kv_mask[None, :], alpha * cosine_mask, 0.0)
        v_ptrs = v_ptr + hid * v_stride_h + offs_n[:, None] * v_stride_t + offs_dv[None, :] * v_stride_d
        v_block = tl.load(v_ptrs, mask=kv_mask[:, None] & dv_mask[None, :], other=0.0).to(tl.float32)
        acc += tl.dot(alpha, v_block)
        kv_tile_offset += 1
    # Each (head, q row) is written by exactly this program; cast to the output dtype.
    out_ptrs = out_ptr + hid * out_stride_h + offs_m[:, None] * out_stride_t + offs_dv[None, :] * out_stride_d
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=q_mask[:, None] & dv_mask[None, :])
@triton.jit
def _varlen_multiscreen_bwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    q_pos_ptr,
    kv_pos_ptr,
    q_seg_starts_ptr,
    q_seg_ends_ptr,
    kv_seg_starts_ptr,
    kv_seg_ends_ptr,
    tile_segment_ids_ptr,
    tile_q_starts_ptr,
    tile_kv_start_idxs_ptr,
    tile_kv_num_blocks_ptr,
    w_ptr,
    r_ptr,
    dout_ptr,
    dq_ptr,
    dk_ptr,
    dv_ptr,
    dw_ptr,
    dr_ptr,
    num_q_tiles,
    tile_kv_start_idxs_stride_h,
    tile_kv_start_idxs_stride_t,
    tile_kv_num_blocks_stride_h,
    tile_kv_num_blocks_stride_t,
    q_stride_h,
    q_stride_t,
    q_stride_d,
    k_stride_h,
    k_stride_t,
    k_stride_d,
    v_stride_h,
    v_stride_t,
    v_stride_d,
    dout_stride_h,
    dout_stride_t,
    dout_stride_d,
    dq_stride_h,
    dq_stride_t,
    dq_stride_d,
    dk_stride_h,
    dk_stride_t,
    dk_stride_d,
    dv_stride_h,
    dv_stride_t,
    dv_stride_d,
    num_segments,
    num_heads,
    d_k,
    d_v,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DK: tl.constexpr,
    BLOCK_DV: tl.constexpr,
):
    # Backward of _varlen_multiscreen_fwd_kernel; same (q tile, head) program grid.
    # The forward alpha is recomputed per KV tile, then gradients are propagated:
    #   grad_alpha = dout @ v^T
    #   alpha = content * cosine_mask, content = relu(base)^2, base = 1 - r * (1 - sim)
    #   d base / d sim = r, d base / d r = sim - 1
    #   d cosine_mask / d w = 0.5 * sin(pi * rel / w) * pi * rel / w^2
    # dq is accumulated locally and stored once; dk / dv are atomically added per KV tile
    # (several q tiles can touch the same KV rows); dw / dr are reduced to one scalar per
    # program and atomically added into the per-head slot.
    PI = 3.141592653589793
    pid = tl.program_id(axis=0)
    hid = tl.program_id(axis=1)
    if hid >= num_heads:
        return
    w = tl.load(w_ptr + hid).to(tl.float32)
    r = tl.load(r_ptr + hid).to(tl.float32)
    segment_id = tl.load(tile_segment_ids_ptr + pid)
    q_start = tl.load(tile_q_starts_ptr + pid)
    q_end = tl.load(q_seg_ends_ptr + segment_id)
    kv_start_base = tl.load(kv_seg_starts_ptr + segment_id)
    kv_end = tl.load(kv_seg_ends_ptr + segment_id)
    kv_tile_start_idx = tl.load(tile_kv_start_idxs_ptr + hid * tile_kv_start_idxs_stride_h + pid * tile_kv_start_idxs_stride_t)
    active_kv_tile_count = tl.load(tile_kv_num_blocks_ptr + hid * tile_kv_num_blocks_stride_h + pid * tile_kv_num_blocks_stride_t)
    offs_m = q_start + tl.arange(0, BLOCK_M)
    offs_n_base = tl.arange(0, BLOCK_N)
    offs_dk = tl.arange(0, BLOCK_DK)
    offs_dv = tl.arange(0, BLOCK_DV)
    q_mask = offs_m < q_end
    dk_mask = offs_dk < d_k
    dv_mask = offs_dv < d_v
    # q tile and its upstream gradient stay resident for the whole KV loop.
    q_ptrs = q_ptr + hid * q_stride_h + offs_m[:, None] * q_stride_t + offs_dk[None, :] * q_stride_d
    q_block = tl.load(q_ptrs, mask=q_mask[:, None] & dk_mask[None, :], other=0.0).to(tl.float32)
    q_pos = tl.load(q_pos_ptr + offs_m, mask=q_mask, other=0).to(tl.float32)
    dout_ptrs = dout_ptr + hid * dout_stride_h + offs_m[:, None] * dout_stride_t + offs_dv[None, :] * dout_stride_d
    dout_block = tl.load(dout_ptrs, mask=q_mask[:, None] & dv_mask[None, :], other=0.0).to(tl.float32)
    dq_acc = tl.zeros((BLOCK_M, BLOCK_DK), dtype=tl.float32)
    dw_acc = 0.0
    dr_acc = 0.0
    kv_tile_offset = 0
    while kv_tile_offset < active_kv_tile_count:
        kv_tile_idx = kv_tile_start_idx + kv_tile_offset
        kv_start = kv_start_base + kv_tile_idx * BLOCK_N
        offs_n = kv_start + offs_n_base
        kv_mask = offs_n < kv_end
        # K loaded transposed as (BLOCK_DK, BLOCK_N), as in the forward kernel.
        k_ptrs = k_ptr + hid * k_stride_h + offs_n[None, :] * k_stride_t + offs_dk[:, None] * k_stride_d
        k_block = tl.load(k_ptrs, mask=dk_mask[:, None] & kv_mask[None, :], other=0.0).to(tl.float32)
        v_ptrs = v_ptr + hid * v_stride_h + offs_n[:, None] * v_stride_t + offs_dv[None, :] * v_stride_d
        v_block = tl.load(v_ptrs, mask=kv_mask[:, None] & dv_mask[None, :], other=0.0).to(tl.float32)
        kv_pos = tl.load(kv_pos_ptr + offs_n, mask=kv_mask, other=0).to(tl.float32)
        # Recompute the forward quantities for this tile.
        sim = tl.dot(q_block, k_block)
        rel = kv_pos[None, :] - q_pos[:, None]
        causal_mask = rel <= 0.0
        window_mask = causal_mask & (rel > -w)
        base = 1.0 - r * (1.0 - sim)
        relu_base = tl.maximum(base, 0.0)
        content = relu_base * relu_base
        cosine_phase = PI * rel / w
        cosine_mask = 0.5 * (tl.cos(cosine_phase) + 1.0)
        active_mask = window_mask & q_mask[:, None] & kv_mask[None, :]
        alpha = tl.where(active_mask, content * cosine_mask, 0.0)
        # d out / d alpha: (BLOCK_M, BLOCK_DV) @ (BLOCK_DV, BLOCK_N).
        grad_alpha = tl.dot(dout_block, tl.trans(v_block))
        grad_content = tl.where(active_mask, grad_alpha * cosine_mask, 0.0)
        # The ReLU passes gradient only where base > 0.
        positive_mask = active_mask & (base > 0.0)
        dsim = tl.where(positive_mask, grad_content * (2.0 * relu_base * r), 0.0)
        # sim = q @ k^T, so dq += dsim @ k and dk += dsim^T @ q.
        dq_acc += tl.dot(dsim, tl.trans(k_block))
        dk_part = tl.dot(tl.trans(dsim), q_block)
        dk_ptrs = dk_ptr + hid * dk_stride_h + offs_n[:, None] * dk_stride_t + offs_dk[None, :] * dk_stride_d
        tl.atomic_add(dk_ptrs, dk_part, mask=kv_mask[:, None] & dk_mask[None, :])
        # out = alpha @ v, so dv += alpha^T @ dout.
        dv_part = tl.dot(tl.trans(alpha), dout_block)
        dv_ptrs = dv_ptr + hid * dv_stride_h + offs_n[:, None] * dv_stride_t + offs_dv[None, :] * dv_stride_d
        tl.atomic_add(dv_ptrs, dv_part, mask=kv_mask[:, None] & dv_mask[None, :])
        # Per-head scalar gradients (the hard window edge contributes nothing since the taper is 0 there).
        dcos_dw = 0.5 * tl.sin(cosine_phase) * (PI * rel / (w * w))
        dw_mat = tl.where(active_mask, grad_alpha * content * dcos_dw, 0.0)
        dr_mat = tl.where(positive_mask, grad_content * (2.0 * relu_base * (sim - 1.0)), 0.0)
        dw_acc += tl.sum(tl.sum(dw_mat, axis=1), axis=0)
        dr_acc += tl.sum(tl.sum(dr_mat, axis=1), axis=0)
        kv_tile_offset += 1
    # This program owns these q rows for this head, so a plain store suffices for dq.
    dq_ptrs = dq_ptr + hid * dq_stride_h + offs_m[:, None] * dq_stride_t + offs_dk[None, :] * dq_stride_d
    tl.store(dq_ptrs, dq_acc.to(dq_ptr.dtype.element_ty), mask=q_mask[:, None] & dk_mask[None, :])
    tl.atomic_add(dw_ptr + hid, dw_acc)
    tl.atomic_add(dr_ptr + hid, dr_acc)
def _launch_triton_varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    out: Tensor,
    meta: VarlenMultiscreenMeta,
    launch_cfg: VarlenMultiscreenLaunchConfig,
) -> None:
    """Launch the forward Triton kernel, writing into the preallocated ``out``.

    One program per (q tile, head). Raises NotImplementedError when a head
    dimension is wider than the compiled block width.
    """
    if q.shape[2] > launch_cfg.block_dk:
        raise NotImplementedError(f"d_k={q.shape[2]} exceeds block_dk={launch_cfg.block_dk}")
    if v.shape[2] > launch_cfg.block_dv:
        raise NotImplementedError(f"d_v={v.shape[2]} exceeds block_dv={launch_cfg.block_dv}")

    num_heads = q.shape[0]
    grid = (meta.tile_segment_ids.numel(), num_heads)
    # Positional argument groups, in the kernel's parameter order.
    segment_bounds = (meta.q_seg_starts, meta.q_seg_ends, meta.kv_seg_starts, meta.kv_seg_ends)
    tile_tables = (meta.tile_segment_ids, meta.tile_q_starts, meta.tile_kv_start_idxs, meta.tile_kv_num_blocks)
    table_strides = (*meta.tile_kv_start_idxs.stride(), *meta.tile_kv_num_blocks.stride())
    tensor_strides = (*q.stride(), *k.stride(), *v.stride(), *out.stride())

    _varlen_multiscreen_fwd_kernel[grid](
        q, k, v, q_pos, kv_pos,
        *segment_bounds,
        *tile_tables,
        meta.w, meta.r, out,
        meta.num_q_tiles,
        *table_strides,
        *tensor_strides,
        meta.num_segments, num_heads, q.shape[2], v.shape[2],
        BLOCK_M=meta.tiling_cfg.block_m,
        BLOCK_N=meta.tiling_cfg.block_n,
        BLOCK_DK=launch_cfg.block_dk,
        BLOCK_DV=launch_cfg.block_dv,
        num_warps=launch_cfg.num_warps,
        num_stages=launch_cfg.num_stages,
    )
def _launch_triton_varlen_multiscreen_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    dout: Tensor,
    dq: Tensor,
    dk: Tensor,
    dv: Tensor,
    dw: Tensor,
    dr: Tensor,
    meta: VarlenMultiscreenMeta,
    launch_cfg: VarlenMultiscreenLaunchConfig,
) -> None:
    """Launch the backward Triton kernel, accumulating into dq, dk, dv, dw and dr.

    The gradient buffers are expected to be zero-initialized by the caller (dk, dv,
    dw and dr are updated with atomic adds). One program per (q tile, head).
    Raises NotImplementedError when a head dimension is wider than the compiled
    block width.
    """
    if q.shape[2] > launch_cfg.block_dk:
        raise NotImplementedError(f"d_k={q.shape[2]} exceeds block_dk={launch_cfg.block_dk}")
    if v.shape[2] > launch_cfg.block_dv:
        raise NotImplementedError(f"d_v={v.shape[2]} exceeds block_dv={launch_cfg.block_dv}")

    num_heads = q.shape[0]
    grid = (meta.tile_segment_ids.numel(), num_heads)
    # Positional argument groups, in the kernel's parameter order.
    segment_bounds = (meta.q_seg_starts, meta.q_seg_ends, meta.kv_seg_starts, meta.kv_seg_ends)
    tile_tables = (meta.tile_segment_ids, meta.tile_q_starts, meta.tile_kv_start_idxs, meta.tile_kv_num_blocks)
    grad_buffers = (dout, dq, dk, dv, dw, dr)
    table_strides = (*meta.tile_kv_start_idxs.stride(), *meta.tile_kv_num_blocks.stride())
    tensor_strides = (
        *q.stride(), *k.stride(), *v.stride(),
        *dout.stride(), *dq.stride(), *dk.stride(), *dv.stride(),
    )

    _varlen_multiscreen_bwd_kernel[grid](
        q, k, v, q_pos, kv_pos,
        *segment_bounds,
        *tile_tables,
        meta.w, meta.r,
        *grad_buffers,
        meta.num_q_tiles,
        *table_strides,
        *tensor_strides,
        meta.num_segments, num_heads, q.shape[2], v.shape[2],
        BLOCK_M=meta.tiling_cfg.block_m,
        BLOCK_N=meta.tiling_cfg.block_n,
        BLOCK_DK=launch_cfg.block_dk,
        BLOCK_DV=launch_cfg.block_dv,
        num_warps=launch_cfg.num_warps,
        num_stages=launch_cfg.num_stages,
    )
@triton_varlen_multiscreen.register_fake
def _fake_triton_varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> Tensor:
    """Fake (metadata-only) implementation used while tracing / compiling.

    Must report the same output metadata as the real op. The real op allocates
    its (H, Tq, Dv) output with v's dtype and device, so this fake allocates
    from ``v`` as well. Allocating from ``q`` disagreed whenever q and v differ
    in dtype or device.
    """
    return v.new_empty((q.shape[0], q.shape[1], v.shape[2]))
@triton_varlen_multiscreen_backward.register_fake
def _fake_triton_varlen_multiscreen_backward(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
    dout: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Fake (metadata-only) implementation of the backward op.

    The real op derives each gradient from ``zeros_like`` of the matching input
    and returns it in that input's dtype. ``empty_like`` mirrors this, including
    the input's strides; ``new_empty`` would always report contiguous outputs,
    which is wrong for non-contiguous inputs.
    """
    return (
        torch.empty_like(q),
        torch.empty_like(k),
        torch.empty_like(v),
        torch.empty_like(w),
        torch.empty_like(r),
    )
def _varlen_multiscreen_setup_context(ctx, inputs, output):
    """Autograd setup: save all nine forward inputs on ``ctx`` (``output`` is not needed)."""
    (query, key, value, query_pos, key_pos, cu_q, cu_kv, w_heads, r_heads) = inputs
    ctx.save_for_backward(query, key, value, query_pos, key_pos, cu_q, cu_kv, w_heads, r_heads)
def _varlen_multiscreen_backward(ctx, grad_out: Tensor):
    """Autograd backward for triton_varlen_multiscreen.

    Returns one gradient per forward input: q, k, v, w and r receive gradients;
    q_pos, kv_pos, cu_seqlens_q and cu_seqlens_kv receive None.
    """
    saved = ctx.saved_tensors
    grad_q, grad_k, grad_v, grad_w, grad_r = triton_varlen_multiscreen_backward(*saved, grad_out)
    no_grad = None
    return grad_q, grad_k, grad_v, no_grad, no_grad, no_grad, no_grad, grad_w, grad_r


# Attach the autograd formula to the forward custom op.
triton_varlen_multiscreen.register_autograd(
    _varlen_multiscreen_backward,
    setup_context=_varlen_multiscreen_setup_context,
)
def varlen_multiscreen(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    q_pos: Tensor,
    kv_pos: Tensor,
    cu_seqlens_q: Tensor,
    cu_seqlens_kv: Tensor,
    w: Tensor,
    r: Tensor,
) -> Tensor:
    """Differentiable varlen multiscreen (thin wrapper over the custom op).

    Gradients flow to q, k, v, w and r through the registered autograd formula.
    """
    op_inputs = (q, k, v, q_pos, kv_pos, cu_seqlens_q, cu_seqlens_kv, w, r)
    return triton_varlen_multiscreen(*op_inputs)
def _load_data_shard(file: Path):
header = torch.from_file(str(file), False, 256, dtype=torch.int32)
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
num_tokens = int(header[2])
with file.open("rb", buffering=0) as f:
tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
f.seek(256 * 4)
nbytes = f.readinto(tokens.numpy())
assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
return tokens
#######################################################################################################################
# multiscreen model implementation by @scottjmaddox and GPT-5.4 / 5.5
def maybe_torch_compile(**compile_kwargs):
    """Return a decorator that applies ``torch.compile(fn, **compile_kwargs)`` when
    COMPILE_MODEL is truthy and otherwise hands ``fn`` back untouched.

    COMPILE_MODEL is read when the decorator is applied to a function.
    """
    def decorator(fn):
        return torch.compile(fn, **compile_kwargs) if COMPILE_MODEL else fn
    return decorator
def row_normalize(x: Tensor, eps: float = 1e-8):
    """Scale every vector along the last dimension to unit L2 norm.

    Norms are clamped below at ``eps``, so all-zero rows stay zero instead of
    becoming NaN.
    """
    denom = x.norm(dim=-1, keepdim=True).clamp_min(eps)
    return x / denom
def tanh_norm(x: Tensor, eps: float = 1e-8):
    """Softly bound each last-dim vector's L2 norm: keep its direction, map norm n to tanh(n).

    Norms are clamped below at ``eps``; zero vectors map to zero.
    """
    magnitude = x.norm(dim=-1, keepdim=True).clamp_min(eps)
    scale = torch.tanh(magnitude) / magnitude
    return x * scale
def apply_mipe(z: Tensor, w: Tensor, pos_ids: Tensor, window_threshold: float) -> Tensor:
    """Rotate the first two channels of ``z`` by a window-dependent positional angle.

    For head h with window w_h, the token at position p is rotated by
    ``phi = pi * p * gamma(w_h) / w_h``, where
    ``gamma(w) = 0.5 * (cos(pi * w / window_threshold) + 1)`` for
    ``w < window_threshold`` and 0 otherwise; infinite windows get ``phi = 0``.
    Channels from index 2 onward pass through unchanged.

    Args:
        z: (num_heads, seq_len, dim) tensor; channels 0 and 1 are rotated.
        w: per-head window sizes, viewable as (num_heads, 1, 1).
        pos_ids: per-token positions, viewable as (1, seq_len, 1).
        window_threshold: windows at or above this get no rotation.

    Returns:
        Tensor with the same shape and dtype as ``z``.
    """
    # The angle is computed in float32 (float64 if z is float64), never in z's dtype:
    # positions reach the thousands and bfloat16 represents integers exactly only up
    # to 256, so a bf16 phase would be scrambled for all later positions.
    compute_dtype = torch.float64 if z.dtype == torch.float64 else torch.float32
    pos = pos_ids.to(device=z.device, dtype=compute_dtype).view(1, z.shape[1], 1)
    w = w.to(device=z.device, dtype=compute_dtype).view(z.shape[0], 1, 1)
    gamma = torch.where(
        w < window_threshold,
        0.5 * (torch.cos(w * (math.pi / window_threshold)) + 1.0),
        torch.zeros_like(w),
    )
    inv_w = torch.where(torch.isinf(w), torch.zeros_like(w), torch.reciprocal(w))
    phi = math.pi * pos * gamma * inv_w
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)
    z0 = z[..., 0:1].to(compute_dtype)
    z1 = z[..., 1:2].to(compute_dtype)
    rotated0 = (z0 * cos_phi - z1 * sin_phi).to(z.dtype)
    rotated1 = (z0 * sin_phi + z1 * cos_phi).to(z.dtype)
    return torch.cat((rotated0, rotated1, z[..., 2:]), dim=-1)
def project_heads(x: Tensor, w: Tensor) -> Tensor:
    """Apply per-head linear maps in a single matmul.

    ``w`` is (num_heads, out_dim, in_dim); ``x.shape[0]`` is treated as the token
    axis. Returns a contiguous (num_heads, x.shape[0], out_dim) tensor with
    ``out[h, t] = w[h] @ x[t]``.
    """
    num_heads, out_dim = w.shape[0], w.shape[1]
    stacked = w.reshape(-1, w.shape[-1])
    flat = F.linear(x, stacked)
    per_token = flat.view(x.shape[0], num_heads, out_dim)
    return per_token.transpose(0, 1).contiguous()
class MultiscreenLayer(nn.Module):
    """One multiscreen layer over packed variable-length sequences.

    Each head projects the input to q, k (key_dim) and v, g (value_dim). q, k and v
    are L2-normalized, q and k get MiPE rotations, and varlen_multiscreen mixes v over
    the causal window. The result is norm-squashed (tanh_norm), gated by
    tanh(silu(g)), projected back to model_dim, scaled per head by exp(s_o), and
    summed over heads.
    """
    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        key_dim: int,
        value_dim: int,
        depth: int,
        window_threshold: float,
        train_max_seq_len: int,
    ):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.window_threshold = window_threshold
        self.train_max_seq_len = train_max_seq_len
        # Per-head projection weights in bfloat16, laid out as (num_heads, out_features, in_features).
        self.w_q = nn.Parameter((torch.randn(num_heads, key_dim, model_dim) * (0.1 / math.sqrt(key_dim))).bfloat16())
        self.w_k = nn.Parameter((torch.randn(num_heads, key_dim, model_dim) * (0.1 / math.sqrt(key_dim))).bfloat16())
        self.w_v = nn.Parameter((torch.randn(num_heads, value_dim, model_dim) * (0.1 / math.sqrt(value_dim))).bfloat16())
        self.w_g = nn.Parameter((torch.randn(num_heads, value_dim, model_dim) * 0.1).bfloat16())
        self.w_o = nn.Parameter((torch.randn(num_heads, model_dim, value_dim) * (0.1 / math.sqrt(model_dim))).bfloat16())
        # Per-head scalars (default float dtype). s_w spans [0, log(window_threshold)],
        # so initial windows w = exp(s_w) + 1 run from 2 to window_threshold + 1.
        self.s_w = nn.Parameter(torch.linspace(0, math.log(window_threshold), steps=num_heads))
        # r = sigmoid(s_r) starts at 0.5 for every head.
        self.s_r = nn.Parameter(torch.zeros(num_heads))
        # Output scale exp(s_o) starts at 1 / sqrt(num_heads * depth).
        self.s_o = nn.Parameter(torch.full((num_heads,), math.log(1.0 / math.sqrt(num_heads * depth))))
    def w(self) -> Tensor:
        """Return per-head window sizes w = exp(s_w) + 1 (always > 1).

        When not training, heads whose window exceeds train_max_seq_len are replaced
        by +inf, which the kernel metadata treats as the full causal prefix.
        """
        w = torch.exp(self.s_w) + 1.0
        if self.training:
            return w
        use_full_prefix = w.detach().float() > self.train_max_seq_len
        return torch.where(use_full_prefix, torch.full_like(w, float("inf")), w)
    def r(self) -> Tensor:
        """Return per-head r = sigmoid(s_r), in (0, 1)."""
        return torch.sigmoid(self.s_r)
    def forward(
        self,
        x: Tensor,
        pos_ids: Tensor,
        cu_seqlens: Tensor,
    ) -> Tensor:
        """Compute this layer's residual update for packed tokens.

        Args:
            x: packed activations; project_heads treats x.shape[0] as the token axis
                (presumably (num_tokens, model_dim) -- confirm).
            pos_ids: per-token positions, used for MiPE and as both q_pos and kv_pos.
            cu_seqlens: segment boundaries, shared by queries and keys/values.

        Returns:
            The per-head (tokens, model_dim) outputs summed over heads.
        """
        w = self.w()
        r = self.r()
        # Each projection is (num_heads, tokens, dim).
        q = project_heads(x, self.w_q)
        k = project_heads(x, self.w_k)
        v = project_heads(x, self.w_v)
        g = project_heads(x, self.w_g)
        q = row_normalize(q)
        k = row_normalize(k)
        v = row_normalize(v)
        q = apply_mipe(q, w, pos_ids, self.window_threshold)
        k = apply_mipe(k, w, pos_ids, self.window_threshold)
        h_screen = varlen_multiscreen(q, k, v, pos_ids, pos_ids, cu_seqlens, cu_seqlens, w, r)
        h_out = tanh_norm(h_screen) * torch.tanh(F.silu(g))
        # (H, T, value_dim) @ (H, value_dim, model_dim) -> (H, T, model_dim).
        outputs = torch.bmm(h_out, self.w_o.mT)
        outputs = outputs * torch.exp(self.s_o).to(outputs.dtype).view(self.s_o.shape[0], 1, 1)
        return outputs.sum(dim=0)
@dataclass
class ModelConfig:
    """Architecture hyperparameters for MultiscreenLM (stored in checkpoints via get_extra_state)."""
    vocab_size: int
    model_dim: int
    depth: int  # number of MultiscreenLayers
    num_heads: int
    key_dim: int
    value_dim: int
    window_threshold: float
    train_max_seq_len: int  # windows wider than this become inf outside training
class MultiscreenLM(nn.Module):
    """Multiscreen language model over packed token sequences.

    A single bfloat16 embedding matrix, row-normalized on every forward, is used for
    the input lookup and (transposed) for the output logits. In between, cfg.depth
    MultiscreenLayers are applied residually. exp(s_e) scales the input embeddings
    (initially 1) and exp(s_f) scales the logits (initially sqrt(model_dim)).
    """
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Parameter((torch.randn(cfg.vocab_size, cfg.model_dim) * (0.1 / math.sqrt(self.cfg.model_dim))).bfloat16())
        self.layers = nn.ModuleList([
            MultiscreenLayer(
                model_dim=cfg.model_dim,
                num_heads=cfg.num_heads,
                key_dim=cfg.key_dim,
                value_dim=cfg.value_dim,
                depth=cfg.depth,
                window_threshold=cfg.window_threshold,
                train_max_seq_len=cfg.train_max_seq_len,
            )
            for _ in range(cfg.depth)
        ])
        self.s_e = nn.Parameter(torch.zeros(()))
        self.s_f = nn.Parameter(torch.tensor(math.log(math.sqrt(cfg.model_dim))))
    def get_extra_state(self):
        """Include the model config in state_dict so checkpoints carry their architecture."""
        return {"model_config": asdict(self.cfg)}
    def set_extra_state(self, state):
        """Restore self.cfg from checkpoint extra state (no-op when empty or missing).

        Only the config object is replaced; existing layers and parameters are not rebuilt.
        """
        if not state:
            return
        model_config = state.get("model_config")
        if model_config is not None:
            self.cfg = ModelConfig(**model_config)
    def window_stats(self) -> dict[str, float]:
        """Return min/median/mean/max of every head's current window across all layers.

        Uses each layer's w(), so outside training mode windows past
        train_max_seq_len are reported as inf.
        """
        with torch.no_grad():
            windows = [layer.w().detach().float() for layer in self.layers]
            all_windows = torch.cat(windows)
            return {
                "window_min": float(all_windows.min().item()),
                "window_median": float(all_windows.median().item()),
                "window_mean": float(all_windows.mean().item()),
                "window_max": float(all_windows.max().item()),
            }
    @maybe_torch_compile(dynamic=True, fullgraph=True)
    def forward(
        self,
        input_ids: Tensor,
        target_ids: Tensor,
        pos_ids: Tensor,
        cu_seqlens: Tensor,
    ) -> tuple[Tensor, int]:
        """Return (summed cross-entropy loss, number of target tokens).

        The loss uses reduction="sum"; callers divide by the token count to get a mean.
        NOTE(review): F.cross_entropy skips targets equal to -100 by default, but
        token_count counts every target; confirm targets never use -100.
        """
        embed = row_normalize(self.embed)
        x = F.embedding(input_ids, embed)
        x = x * torch.exp(self.s_e).to(x.dtype)
        for layer in self.layers:
            x = x + layer(x, pos_ids, cu_seqlens)
        # Tied output head: logits against the same normalized embedding matrix.
        logits = x @ embed.mT * torch.exp(self.s_f).to(x.dtype)
        loss_sum = F.cross_entropy(logits.float(), target_ids.view(-1), reduction="sum")
        token_count = target_ids.numel()
        return loss_sum, token_count
#######################################################################################################################
# data loader adapted from https://github.com/kellerjordan/modded-nanogpt
BOS_ID = 50256
class BOSFinder:
    """Index the BOS tokens of one token shard and cut BOS-aligned batches from it.

    With ``quickload=True``, only the first ``QUICKLOAD_TOKENS`` tokens are indexed
    synchronously, so the first batches can start immediately. The full shard is
    indexed on a background thread and swapped in by ``get()``. The swap happens
    on the 5th batch, or earlier if the partial index is about to run out. This
    way a partial index never ends the shard prematurely, and never lets a
    sequence run past an un-indexed document boundary.
    """
    # Number of leading tokens scanned synchronously when quickload=True.
    QUICKLOAD_TOKENS = 4_000_000
    def __init__(self, tokens: Tensor, world_size: int = 1, quickload: bool = False):
        self.tokens = tokens
        self.size = tokens.numel()
        self.quickload = quickload
        # Always defined so get() / next_batch() are safe whether or not quickload is used.
        self.thread = None
        if quickload:
            self.bos_idx = self._find_bos(tokens[:self.QUICKLOAD_TOKENS])
            self.ready = threading.Event()
            self.start()
        else:
            self.bos_idx = self._find_bos(tokens)
        self.i = 0
        self.world_size = world_size
        self.batch_iter = 0
    @staticmethod
    def _find_bos(tokens: Tensor):
        """Return the ascending positions of BOS_ID in *tokens* as an int64 numpy array."""
        return (tokens == BOS_ID).nonzero(as_tuple=True)[0].to(torch.int64).cpu().numpy()
    def _load(self):
        # Background worker: index the whole shard. `ready` is always set, even on
        # failure, so get() cannot block forever.
        try:
            self.bos_idx_async = self._find_bos(self.tokens)
        finally:
            self.ready.set()
    def start(self):
        self.ready.clear()
        self.thread = threading.Thread(target=self._load)
        self.thread.start()
    def get(self):
        """Wait for the background scan and switch to the full BOS index.

        This is a no-op when no scan is pending (quickload=False, or already merged).

        Raises:
            RuntimeError: the background scan failed.
        """
        if self.thread is None:
            return
        self.ready.wait()
        self.thread.join()
        self.thread = None
        full_idx = getattr(self, "bos_idx_async", None)
        if full_idx is None:
            raise RuntimeError("background BOS scan failed; see the thread traceback above")
        # The partial index is exactly a prefix of the full one, so self.i stays valid.
        self.bos_idx = full_idx
    def next_batch(self, num_tokens_local: int, max_seq_len: int):
        """Cut the next BOS-aligned batch for every rank.

        Each rank receives document spans ``[start, end)`` that each begin at a BOS
        token. A span is capped at ``max_seq_len`` tokens and at the next BOS. The
        spans for one rank total exactly ``num_tokens_local + 1`` tokens (the extra
        token supplies the final target). The remainder of a truncated document
        is skipped.

        Returns:
            ``(starts, ends)``: one list of span starts and one list of span ends per rank.

        Raises:
            StopIteration: the shard has no more BOS tokens to start a span from.
        """
        if self.quickload and self.batch_iter == 5:
            self.get()
        n = len(self.bos_idx)
        starts = [[] for _ in range(self.world_size)]
        ends = [[] for _ in range(self.world_size)]
        idx = self.i
        for r in range(self.world_size):
            cur_len = 0
            while cur_len <= num_tokens_local:
                if self.thread is not None and idx + 1 >= n:
                    # The partial quickload index is about to end: its last entry
                    # would otherwise be bounded by the shard end rather than the
                    # next (not yet indexed) BOS. Merge the full index first.
                    self.get()
                    n = len(self.bos_idx)
                if idx >= n:
                    raise StopIteration("Insufficient BOS ahead; hit tail of shard.")
                cur = self.bos_idx[idx]
                next_bos = self.bos_idx[idx + 1] if idx + 1 < n else self.size
                end = min(
                    next_bos,
                    cur + max_seq_len,
                    cur + num_tokens_local - cur_len + 1,
                )
                starts[r].append(cur)
                ends[r].append(end)
                cur_len += end - cur
                idx += 1
            assert cur_len == num_tokens_local + 1
        self.i = idx
        self.batch_iter += 1
        return starts, ends
class DataPreloader:
    """Load the next data shard and build its BOSFinder on a background thread.

    Call ``start()`` to begin loading and ``get()`` to wait for the result,
    ``(tokens, BOSFinder)``. Any exception raised while loading is captured on
    the worker thread and re-raised from ``get()``; this includes
    ``StopIteration`` once ``file_iter`` is exhausted. Without this, the caller
    would block forever on an event that is never set.
    """
    def __init__(self, file_iter, world_size: int = 1):
        self.file_iter = file_iter
        self.world_size = world_size
        self.thread = None
        self.data = None
        self.error = None
        self.ready = threading.Event()
    def _load(self):
        try:
            # Advance the iterator first, so running out of shards surfaces as StopIteration.
            path = next(self.file_iter)
            tokens = _load_data_shard(path)
            self.data = (tokens, BOSFinder(tokens, self.world_size))
        except Exception as exc:
            self.error = exc
        finally:
            # Always signal completion so get() never waits forever.
            self.ready.set()
    def start(self):
        self.ready.clear()
        self.error = None
        self.thread = threading.Thread(target=self._load)
        self.thread.start()
    def get(self):
        """Block until the current load finishes, then return ``self.data``.

        Raises:
            Exception: whatever the background load raised (e.g. StopIteration
                when there are no more shard files).
        """
        if self.thread:
            self.ready.wait()
            self.thread.join()
        if self.error is not None:
            raise self.error
        return self.data
def distributed_data_generator(filename_pattern: str, num_tokens: int, max_seq_len: int, grad_accum_steps: int = 1, align_to_bos: bool = True):
    """Yield this rank's packed micro-batches from the shards matching *filename_pattern*.

    Each yield is ``(input_ids, target_ids, pos_ids, cu_seqlens)``, all moved to
    CUDA:
      - ``input_ids``: int32.
      - ``target_ids``: int64; the inputs shifted by one token.
      - ``pos_ids``: int32; the position of each token within its own sequence.
      - ``cu_seqlens``: int32; cumulative sequence boundaries, starting at 0.
    Each rank receives ``num_tokens // grad_accum_steps // world_size`` tokens per yield.

    Args:
        filename_pattern: glob for shard files; shards are consumed in sorted order.
        num_tokens: global tokens per optimizer step (all ranks, all accumulation steps).
        max_seq_len: cap on each sequence's length; only used when ``align_to_bos``.
        grad_accum_steps: number of micro-batches per optimizer step.
        align_to_bos: if True, every sequence starts at a BOS token (via BOSFinder)
            and the next shard is prefetched on a background thread. If False,
            each rank takes a contiguous window and sequences are split at the
            BOS tokens inside it.

    The generator accepts ``send((num_tokens, max_seq_len, grad_accum_steps))`` to
    change the batch shape for later yields. In the non-aligned path,
    ``next(file_iter)`` raising StopIteration inside the generator is surfaced by
    Python as RuntimeError once the shard list is exhausted.
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # num_tokens must split evenly across ranks AND accumulation steps.
    assert num_tokens % (world_size * grad_accum_steps) == 0, "Batch size must be divisible by world size"
    num_tokens = num_tokens // grad_accum_steps
    files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {filename_pattern}")
    file_iter = iter(files)
    tokens = _load_data_shard(next(file_iter))
    if align_to_bos:
        # The first shard uses quickload so training can start before the full BOS scan finishes.
        finder = BOSFinder(tokens, world_size=world_size, quickload=True)
        preloader = DataPreloader(file_iter, world_size)
        preloader.start()
    else:
        pos = 0
    while True:
        num_tokens_local = num_tokens // world_size
        if align_to_bos:
            try:
                seq_starts, seq_ends = finder.next_batch(num_tokens_local, max_seq_len)
                start_idxs = torch.tensor(seq_starts[rank])
                end_idxs = torch.tensor(seq_ends[rank])
            except StopIteration:
                # Current shard exhausted: switch to the prefetched shard and start prefetching the next one.
                tokens, finder = preloader.get()
                preloader.start()
                continue
            buf = torch.cat([tokens[i:j] for i, j in zip(start_idxs, end_idxs)])
            _input_ids = buf[:-1]
            _target_ids = buf[1:]
            # buf holds num_tokens_local + 1 tokens (the extra one only feeds the last
            # target), so the last sequence is one token shorter on the input side.
            end_idxs[-1] -= 1
            seqends = (end_idxs - start_idxs).cumsum(0)
        else:
            if pos + num_tokens + 1 >= len(tokens):
                tokens, pos = _load_data_shard(next(file_iter)), 0
            # Ranks read adjacent, non-overlapping windows of the same global slice.
            pos_local = pos + rank * num_tokens_local
            buf = tokens[pos_local: pos_local + num_tokens_local + 1]
            _input_ids = buf[:-1].view(num_tokens_local)
            _target_ids = buf[1:].view(num_tokens_local)
            # Every BOS (other than at position 0) closes the previous sequence;
            # the window end always closes the last one.
            bos_positions = torch.nonzero(_input_ids == BOS_ID)[:, 0]
            seqends = bos_positions[bos_positions > 0]
            if seqends.numel() == 0 or seqends[-1] != _input_ids.numel():
                seqends = torch.cat([seqends, seqends.new_full((1,), _input_ids.numel())])
            pos += num_tokens
        assert seqends.numel() >= 1
        assert (seqends > 0).all()
        assert seqends[-1] == _input_ids.numel()
        _cu_seqlens = torch.cat([seqends.new_zeros(1), seqends])
        _input_ids = _input_ids.to(dtype=torch.int32)
        _target_ids = _target_ids.to(dtype=torch.int64)
        _cu_seqlens = _cu_seqlens.to(torch.int32)
        _pos_ids = build_pos_ids_from_cu_seqlens(_cu_seqlens)
        _input_ids = _input_ids.to(device="cuda", non_blocking=True)
        _target_ids = _target_ids.to(device="cuda", non_blocking=True)
        _pos_ids = _pos_ids.to(device="cuda", non_blocking=True)
        _cu_seqlens = _cu_seqlens.to(device="cuda", non_blocking=True)
        # Token/sequence counts vary per batch; hint torch.compile to treat dim 0 as dynamic.
        torch._dynamo.maybe_mark_dynamic(_input_ids, 0)
        torch._dynamo.maybe_mark_dynamic(_target_ids, 0)
        torch._dynamo.maybe_mark_dynamic(_pos_ids, 0)
        torch._dynamo.maybe_mark_dynamic(_cu_seqlens, 0)
        new_params = yield (_input_ids, _target_ids, _pos_ids, _cu_seqlens)
        if new_params is not None:
            new_num_tokens, new_max_seq_len, new_grad_accum_steps = new_params
            assert new_num_tokens % (world_size * new_grad_accum_steps) == 0
            num_tokens = new_num_tokens // new_grad_accum_steps
            max_seq_len = new_max_seq_len
def build_pos_ids_from_cu_seqlens(cu_seqlens: Tensor) -> Tensor:
    """Return each token's position within its own sequence.

    Example: ``cu_seqlens = [0, 3, 5]`` gives ``[0, 1, 2, 0, 1]``. The result is
    int32 and lives on ``cu_seqlens.device``.
    """
    starts, ends = cu_seqlens[:-1], cu_seqlens[1:]
    # For every token, the global offset at which its sequence begins.
    offsets = starts.repeat_interleave(ends - starts)
    positions = torch.arange(cu_seqlens[-1], device=cu_seqlens.device, dtype=torch.int32)
    return positions - offsets
#######################################################################################################################
# Aurora optimizer, adapted from https://github.com/tilde-research/aurora-release
@torch.no_grad()
def _polar(G: Tensor) -> Tensor:
"""Polar factor via 12-step simple-quintic Newton-Schulz.
Args:
G: input matrix of shape [..., m, n].
Returns:
polar(G) of the same shape, in bfloat16. All non-zero singular values
of G are mapped to 1.
"""
if G.ndim < 2:
raise ValueError(f"polar expects tensors with ndim >= 2, got ndim={G.ndim}")
X = G.bfloat16()
transposed = G.size(-2) > G.size(-1)
if transposed:
X = X.mT
# Ensure spectral norm <= 1 so the iteration converges to polar.
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Simple-quintic coefficients:
# p(σ) = aσ + bσ³ + cσ⁵ with σ=1 super-attracting.
a, b, c = 2.0, -1.5, 0.5
for _ in range(12):
A = X @ X.mT
B = b * A + c * (A @ A)
X = a * X + B @ X
if transposed:
X = X.mT
return X
_polar_compiled = torch.compile(_polar, dynamic=True, fullgraph=True)
def polar(G: Tensor, *, compile: bool = True) -> Tensor:
return _polar_compiled(G) if compile else _polar(G)
class Aurora(torch.optim.Optimizer):
    """Momentum optimizer whose update direction is a (preconditioned) polar factor.

    For every parameter, each step:
      1. Updates an EMA momentum buffer and forms the update (Nesterov-style if
         ``nesterov``).
      2. Replaces the update with its polar factor ``U`` (see ``polar``). For
         non-square matrices, the update is first oriented tall (m > n). Its
         rows are rescaled by a diagonal preconditioner, refined over
         ``preconditioner_steps`` rounds, so that the squared row norms of ``U``
         move toward ``n / m``.
      3. Scales ``U`` by ``sqrt(max(1, rows / cols))`` of the parameter.
      4. Applies the update with decoupled weight decay.

    Args:
        params: parameters or parameter groups; every parameter must be at least 2-D.
        lr: learning rate (> 0).
        momentum: EMA coefficient in (0, 1).
        weight_decay: decoupled weight decay (>= 0).
        nesterov: use the Nesterov-style look-ahead update.
        preconditioner_steps: number of polar evaluations for non-square updates (>= 1).
        preconditioner_beta: exponent applied to each row-norm correction (> 0).
        eps: floor for row norms (and eps**2 for squared row norms).
        compile: use the torch.compile'd polar kernel.
    """
    def __init__(
        self,
        params,
        lr: float = 0.05,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        preconditioner_steps: int = 2,
        preconditioner_beta: float = 0.5,
        eps: float = 1e-7,
        compile: bool = True,
    ):
        settings = {
            "lr": lr,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
            "preconditioner_steps": preconditioner_steps,
            "preconditioner_beta": preconditioner_beta,
            "eps": eps,
            "compile": compile,
        }
        # Fail fast, before the base class builds param_groups.
        self._validate_hparams(**settings)
        super().__init__(params, settings)
    @staticmethod
    def _validate_hparams(
        *,
        lr: float,
        momentum: float,
        weight_decay: float,
        nesterov: bool,
        preconditioner_steps: int,
        preconditioner_beta: float,
        eps: float,
        compile: bool,
    ) -> None:
        """Raise ValueError for the first out-of-range or mistyped hyperparameter."""
        if lr <= 0.0:
            raise ValueError(f"lr must be positive, got {lr}")
        momentum_ok = 0.0 < momentum < 1.0
        if not momentum_ok:
            raise ValueError(f"momentum must be in (0, 1), got {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"weight_decay must be non-negative, got {weight_decay}")
        if not isinstance(nesterov, bool):
            raise ValueError(f"nesterov must be bool, got {type(nesterov).__name__}")
        steps_ok = isinstance(preconditioner_steps, int) and preconditioner_steps >= 1
        if not steps_ok:
            raise ValueError(f"preconditioner_steps must be an int >= 1, got {preconditioner_steps}")
        if preconditioner_beta <= 0.0:
            raise ValueError(f"preconditioner_beta must be positive, got {preconditioner_beta}")
        if eps <= 0.0:
            raise ValueError(f"eps must be positive, got {eps}")
        if not isinstance(compile, bool):
            raise ValueError(f"compile must be bool, got {type(compile).__name__}")
    @staticmethod
    @torch.no_grad()
    def _aurora_direction(
        update: Tensor,
        *,
        preconditioner_steps: int,
        preconditioner_beta: float,
        eps: float,
        compile: bool,
    ) -> Tensor:
        """Return the (row-preconditioned) polar factor of *update*, same shape."""
        rows, cols = update.size(-2), update.size(-1)
        if rows == cols:
            return polar(update, compile=compile)
        # Work on the tall orientation; for a wide G, polar(G D) = polar(D G^T)^T.
        flip = rows < cols
        tall = update.mT if flip else update
        if flip:
            rows, cols = cols, rows
        G32 = tall.to(torch.float32)
        # A tall matrix with orthonormal columns has average squared row norm cols / rows.
        target_row_sq = cols / rows
        scale = 1.0 / G32.norm(dim=-1, keepdim=True).clamp(min=eps)
        for step_idx in range(preconditioner_steps):
            U = polar(scale * G32, compile=compile)
            if step_idx == preconditioner_steps - 1:
                break
            row_sq = U.to(torch.float32).pow(2).sum(dim=-1, keepdim=True).clamp(min=eps * eps)
            scale = scale * (target_row_sq / row_sq).pow(preconditioner_beta)
        return U.mT if flip else U
    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], Tensor]] = None):
        """Apply one Aurora update to every parameter that has a gradient.

        Returns:
            The closure's loss if a closure was given, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        hparam_keys = (
            "lr",
            "momentum",
            "weight_decay",
            "nesterov",
            "preconditioner_steps",
            "preconditioner_beta",
            "eps",
            "compile",
        )
        for group in self.param_groups:
            hp = {key: group[key] for key in hparam_keys}
            # Re-validated every step because schedulers may mutate param_groups.
            self._validate_hparams(**hp)
            lr = hp["lr"]
            momentum = hp["momentum"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.detach()
                if grad.is_sparse:
                    raise RuntimeError("Aurora does not support sparse gradients.")
                if grad.shape != p.shape:
                    raise ValueError(
                        f"grad shape {tuple(grad.shape)} must match parameter shape {tuple(p.shape)}"
                    )
                state = self.state[p]
                if not state:
                    state["step"] = 0
                    state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                state["step"] += 1
                buf = state["momentum_buffer"]
                if buf.shape != p.shape:
                    raise RuntimeError(
                        f"momentum_buffer shape {tuple(buf.shape)} must match parameter shape {tuple(p.shape)}"
                    )
                if buf.device != p.device or buf.dtype != p.dtype:
                    # e.g. after loading a state dict saved on another device/dtype.
                    buf = buf.to(device=p.device, dtype=p.dtype)
                    state["momentum_buffer"] = buf
                buf.lerp_(grad, 1.0 - momentum)
                update = torch.lerp(grad, buf, momentum) if hp["nesterov"] else buf.clone()
                update = self._aurora_direction(
                    update,
                    preconditioner_steps=hp["preconditioner_steps"],
                    preconditioner_beta=hp["preconditioner_beta"],
                    eps=hp["eps"],
                    compile=hp["compile"],
                )
                update.mul_(max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                if hp["weight_decay"] != 0.0:
                    p.mul_(1.0 - lr * hp["weight_decay"])
                p.add_(update, alpha=-lr)
        return loss
#######################################################################################################################
# main training loop
# Process placement comes from the RANK / LOCAL_RANK environment variables
# (as set by launchers such as torchrun); without RANK the run is single-process.
IS_DIST = "RANK" in os.environ
RANK = int(os.getenv("RANK", "0"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
IS_MASTER_PROCESS = (RANK == 0)
# Both are populated by main() on the master process only.
RUN_DIR: Optional[Path] = None
LOG_FP = None
def log(s: str):
    """Print *s* on the master process and append it to LOG_FP when a log file is open."""
    if not IS_MASTER_PROCESS:
        return
    print(s)
    if LOG_FP is not None:
        print(s, file=LOG_FP, flush=True)
def maybe_create_profiler():
    """Return a torch.profiler context for this process, or a no-op context.

    Returns ``contextlib.nullcontext()`` (which yields None from ``with``) when
    PROFILE_ENABLED is false, or when this is not the master process. Only the
    master process profiles: traces are written under RUN_DIR, and main() sets
    RUN_DIR on the master process only. Other ranks would otherwise always hit
    the RuntimeError below.

    Raises:
        RuntimeError: profiling is enabled but RUN_DIR has not been initialized.
        ValueError: neither PROFILE_CPU nor PROFILE_CUDA is enabled.
    """
    if not PROFILE_ENABLED or not IS_MASTER_PROCESS:
        return contextlib.nullcontext()
    if RUN_DIR is None:
        raise RuntimeError("Profiler requires RUN_DIR to be initialized")
    log(
        "torch.profiler enabled "
        f"(wait={PROFILE_WAIT_STEPS}, warmup={PROFILE_WARMUP_STEPS}, active={PROFILE_ACTIVE_STEPS}, repeat={PROFILE_REPEAT}, "
        f"cpu={PROFILE_CPU}, cuda={PROFILE_CUDA})"
    )
    activities = []
    if PROFILE_CPU:
        activities.append(torch.profiler.ProfilerActivity.CPU)
    if PROFILE_CUDA:
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    if not activities:
        raise ValueError("At least one profiler activity must be enabled")
    profile_dir = RUN_DIR / "profiles"
    profile_dir.mkdir(parents=True, exist_ok=True)
    log(f"{profile_dir=}")
    trace_handler = torch.profiler.tensorboard_trace_handler(str(profile_dir))
    return torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(
            wait=PROFILE_WAIT_STEPS,
            warmup=PROFILE_WARMUP_STEPS,
            active=PROFILE_ACTIVE_STEPS,
            repeat=PROFILE_REPEAT,
        ),
        on_trace_ready=trace_handler,
        record_shapes=PROFILE_RECORD_SHAPES,
        profile_memory=PROFILE_PROFILE_MEMORY,
        with_stack=PROFILE_WITH_STACK,
        with_flops=PROFILE_WITH_FLOPS,
    )
def _all_reduce_grads(model: nn.Module) -> None:
    """Average every parameter gradient across ranks; no-op when not distributed.

    The model is not wrapped in DistributedDataParallel, and neither optimizer
    communicates. Without this call, each rank would apply its own local
    gradient and the replicas would drift apart after the initial broadcast.
    Parameters are reduced in ``model.parameters()`` order, which is identical
    on every rank.
    """
    if not dist.is_initialized():
        return
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
def main():
    """Run the full training job: setup, optional kernel warmup, then timed training.

    Distributed notes:
      - Gradients are averaged explicitly with ``_all_reduce_grads`` after each
        accumulation cycle, because the model is not DDP-wrapped.
      - W&B is initialized on the master process only, so every W&B call is
        guarded by ``IS_MASTER_PROCESS``.
      - ``total_tokens`` counts tokens across all ranks. ``train_loss`` is
        computed from this rank's micro-batches only.
    """
    global RUN_DIR, LOG_FP
    assert torch.cuda.is_available()
    torch.cuda.set_device(LOCAL_RANK)
    if IS_DIST:
        dist.init_process_group(backend="nccl", device_id=LOCAL_RANK)
        dist.barrier()
    # Each rank consumes an equal, disjoint slice of every global batch.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if IS_MASTER_PROCESS:
        wandb_run = wandb.init(config=hparams)
        wandb.define_metric("step")
        wandb.define_metric("*", step_metric="step")
        RUN_DIR = Path(wandb_run.dir)
        script_path = Path(__file__)
        shutil.copy(script_path, RUN_DIR / script_path.name)
        LOG_FP = (RUN_DIR / "log.txt").open("w")
    log(f"{RUN_DIR=}")
    log(f"Running Python {sys.version}")
    log(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
    log("=" * 100)
    with maybe_create_profiler() as profiler:
        cfg = ModelConfig(
            vocab_size=VOCAB_SIZE,
            model_dim=MODEL_DIM,
            depth=DEPTH,
            num_heads=NUM_HEADS,
            key_dim=KEY_DIM,
            value_dim=VALUE_DIM,
            window_threshold=WINDOW_THRESHOLD,
            train_max_seq_len=TRAIN_MAX_SEQ_LEN,
        )
        if IS_MASTER_PROCESS:
            (RUN_DIR / "model_config.json").write_text(json.dumps(asdict(cfg), indent=2))
        model = MultiscreenLM(cfg).cuda()
        if dist.is_initialized():
            # Start every replica from rank 0's weights.
            for param in model.parameters():
                dist.broadcast(param.detach(), 0)
        total_param_count = sum(p.numel() for p in model.parameters())
        embed_param_count = model.embed.numel()
        non_embed_param_count = total_param_count - embed_param_count
        log(f"Model has {total_param_count:,} parameters ({total_param_count / 1e6:.2f}M).")
        if IS_MASTER_PROCESS:
            extra_hparams = {
                "total_params": total_param_count,
                "embed_params": embed_param_count,
                "non_embed_params": non_embed_param_count,
            }
            wandb.config.update(extra_hparams)
        # w_* projections and the tied embedding go to MATRIX_OPTIMIZER; everything
        # else (e.g. the s_* scale parameters) goes to Adam.
        matrix_param_name_suffixes = set()
        matrix_params = []
        other_param_name_suffixes = set()
        other_params = []
        for name, param in model.named_parameters():
            name_suffix = name.split('.')[-1]
            if name_suffix.startswith("w_") or name_suffix == "embed":
                matrix_param_name_suffixes.add(name_suffix)
                matrix_params.append(param)
            else:
                other_param_name_suffixes.add(name_suffix)
                other_params.append(param)
        log(f"{matrix_param_name_suffixes=}")
        log(f"{other_param_name_suffixes=}")
        if MATRIX_OPTIMIZER == "Adam":
            matrix_optimizer = torch.optim.Adam(
                matrix_params,
                lr=MATRIX_LR,
                betas=(MATRIX_BETA1, MATRIX_BETA2),
                eps=MATRIX_EPS,
                weight_decay=0,
            )
        elif MATRIX_OPTIMIZER == "Aurora":
            matrix_optimizer = Aurora(
                matrix_params,
                lr=MATRIX_LR,
                momentum=MATRIX_MOMENTUM,
                eps=MATRIX_EPS,
            )
        else:
            raise ValueError(f"invalid MATRIX_OPTIMIZER: {MATRIX_OPTIMIZER}")
        other_optimizer = torch.optim.Adam(
            other_params,
            lr=OTHER_LR,
            betas=(OTHER_BETA1, OTHER_BETA2),
            eps=OTHER_EPS,
        )
        def get_lr_scale(step: int):
            # Linear warmup, constant plateau, then linear cooldown to 10% of the base LR.
            if step < WARMUP_STEPS:
                return (step + 1) / WARMUP_STEPS
            x = step / max(TRAIN_STEPS, 1)
            if x >= 1.0 - COOLDOWN_FRAC:
                w = (1.0 - x) / COOLDOWN_FRAC
                return w + (1.0 - w) * 0.1
            return 1.0
        if IS_MASTER_PROCESS and WANDB_WATCH:
            # Registered before any compiled forward runs, independent of whether kernel warmup is enabled.
            wandb.watch(model, log=WANDB_WATCH_LOG, log_freq=WANDB_WATCH_LOG_FREQ)
        if KERNEL_WARMUP_TRAIN_STEPS > 0 or KERNEL_WARMUP_VAL_STEPS > 0:
            # Snapshot weights and optimizer state; they are restored after warmup,
            # so warmup updates do not leak into the timed run.
            initial_state = dict(
                model=copy.deepcopy(model.state_dict()),
                matrix_optimizer=copy.deepcopy(matrix_optimizer.state_dict()),
                other_optimizer=copy.deepcopy(other_optimizer.state_dict()),
            )
            log(
                "starting kernel warmup "
                f"(train_steps={KERNEL_WARMUP_TRAIN_STEPS}, val_steps={KERNEL_WARMUP_VAL_STEPS})"
            )
            train_loader = distributed_data_generator(TRAIN_FILES, TRAIN_BATCH_SIZE, TRAIN_MAX_SEQ_LEN, grad_accum_steps=GRAD_ACCUM_STEPS)
            model.train()
            for step in range(KERNEL_WARMUP_TRAIN_STEPS):
                print(f"Kernel Warmup Train {step + 1}/{KERNEL_WARMUP_TRAIN_STEPS}", end="\r\x1b[2K")
                loss_sum, token_count = model(*next(train_loader))
                assert token_count > 0
                (loss_sum / token_count).backward()
                # Also warms up the gradient collectives used by the timed run.
                _all_reduce_grads(model)
                matrix_optimizer.step()
                other_optimizer.step()
                matrix_optimizer.zero_grad(set_to_none=True)
                other_optimizer.zero_grad(set_to_none=True)
            del train_loader
            model.zero_grad(set_to_none=True)
            model.eval()
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            val_loader = distributed_data_generator(VAL_FILES, VAL_BATCH_SIZE, -1, grad_accum_steps=GRAD_ACCUM_STEPS, align_to_bos=False)
            with torch.no_grad():
                for val_step in range(KERNEL_WARMUP_VAL_STEPS):
                    print(f"Kernel Warmup Val {val_step + 1}/{KERNEL_WARMUP_VAL_STEPS}", end="\r\x1b[2K")
                    model(*next(val_loader))
            del val_loader
            model.train()
            model.load_state_dict(initial_state["model"])
            matrix_optimizer.load_state_dict(initial_state["matrix_optimizer"])
            other_optimizer.load_state_dict(initial_state["other_optimizer"])
            del initial_state
            gc.collect()
            torch.cuda.synchronize()
            log("kernel warmup complete; timed training starts now")
        train_loader = distributed_data_generator(TRAIN_FILES, TRAIN_BATCH_SIZE, TRAIN_MAX_SEQ_LEN, grad_accum_steps=GRAD_ACCUM_STEPS)
        training_time = 0.0
        total_tokens_seen = 0
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        # Prefetch one batch ahead so the loader overlaps with the forward pass.
        train_batch = next(train_loader)
        for step in range(TRAIN_STEPS + 1):
            last_step = step == TRAIN_STEPS
            if last_step or (VAL_PERIOD > 0 and step % VAL_PERIOD == 0):
                # Pause the training clock while evaluating.
                torch.cuda.synchronize()
                training_time += time.perf_counter() - t0
                eval_t0 = time.perf_counter()
                model.eval()
                assert VAL_TOKENS % VAL_BATCH_SIZE == 0
                val_steps = max(1, GRAD_ACCUM_STEPS * VAL_TOKENS // VAL_BATCH_SIZE)
                val_loader = distributed_data_generator(VAL_FILES, VAL_BATCH_SIZE, -1, grad_accum_steps=GRAD_ACCUM_STEPS, align_to_bos=False)
                val_loss_sum = torch.zeros((), device="cuda")
                val_token_count = torch.zeros((), device="cuda")
                with torch.no_grad():
                    for val_step in range(val_steps):
                        print(f"Validation {val_step + 1}/{val_steps}", end="\r\x1b[2K")
                        loss_sum, token_count = model(*next(val_loader))
                        val_loss_sum += loss_sum
                        val_token_count += token_count
                assert val_token_count > 0
                val_loss = val_loss_sum / val_token_count
                if dist.is_initialized():
                    dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
                step_avg = training_time / max(step, 1)
                eval_time = time.perf_counter() - eval_t0
                model.train()
                if IS_MASTER_PROCESS:
                    wandb.log({
                        "step": step,
                        "train_time": training_time,
                        "val_loss": float(val_loss.item()),
                        "eval_time": eval_time,
                        "total_tokens": int(total_tokens_seen),
                        **model.window_stats(),
                    })
                log(f"step:{step}/{TRAIN_STEPS} val_loss:{val_loss.item():.4f} tks:{total_tokens_seen:.3e} train_time:{training_time:.0f}s step_avg:{step_avg:.2f}s eval_time:{eval_time:.2f}s")
                del val_loader, val_loss_sum, val_token_count
                torch.cuda.synchronize()
                t0 = time.perf_counter()
            is_checkpoint_step = CHECKPOINT_PERIOD > 0 and step % CHECKPOINT_PERIOD == 0
            if (last_step or is_checkpoint_step) and SAVE_CHECKPOINT:
                torch.cuda.synchronize()
                training_time += time.perf_counter() - t0
                model.eval()
                if IS_MASTER_PROCESS:
                    # Replicas are kept identical by gradient averaging, so rank 0's copy is representative.
                    torch.save(model.state_dict(), RUN_DIR / f"model-{step:06d}.pt")
                    torch.save(matrix_optimizer.state_dict(), RUN_DIR / f"matrix_optimizer-{step:06d}.pt")
                    torch.save(other_optimizer.state_dict(), RUN_DIR / f"other_optimizer-{step:06d}.pt")
                model.train()
                torch.cuda.synchronize()
                t0 = time.perf_counter()
            if last_step:
                break
            step_t0 = time.perf_counter()
            lr_scale = get_lr_scale(step)
            for param_group in matrix_optimizer.param_groups:
                param_group["lr"] = MATRIX_LR * lr_scale
            for param_group in other_optimizer.param_groups:
                param_group["lr"] = OTHER_LR * lr_scale
            train_loss_sum = torch.zeros((), device="cuda")
            train_token_count = 0
            for grad_idx in range(GRAD_ACCUM_STEPS):
                loss_sum, token_count = model(*train_batch)
                train_batch = next(train_loader)
                assert token_count > 0
                loss = loss_sum / token_count / GRAD_ACCUM_STEPS
                assert torch.isfinite(loss), "loss is NaN or Inf"
                loss.backward()
                train_loss_sum += loss_sum.detach()
                train_token_count += token_count
                # Every rank processes the same number of tokens per micro-batch.
                total_tokens_seen += token_count * world_size
            _all_reduce_grads(model)
            matrix_optimizer.step()
            other_optimizer.step()
            matrix_optimizer.zero_grad(set_to_none=True)
            other_optimizer.zero_grad(set_to_none=True)
            # Rank-local training loss (not all-reduced, to avoid an extra collective per step).
            train_loss = float((train_loss_sum / max(train_token_count, 1)).item())
            approx_training_time = training_time + time.perf_counter() - t0
            step_avg = approx_training_time / (step + 1)
            step_time = time.perf_counter() - step_t0
            if IS_MASTER_PROCESS:
                wandb.log({
                    "step": step + 1,
                    "train_time": approx_training_time,
                    "lr_scale": lr_scale,
                    "train_loss": train_loss,
                    "step_time": step_time,
                    "avg_step_time": step_avg,
                    "total_tokens": int(total_tokens_seen),
                    **model.window_stats(),
                })
            log(f"step:{step + 1}/{TRAIN_STEPS} train_loss:{train_loss:.4f} tks:{total_tokens_seen:.3e} train_time:{approx_training_time:.0f}s step_avg:{step_avg:.2f}s step_time:{step_time:.2f}s")
            if profiler is not None:
                profiler.step()
    log(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB")
    if IS_MASTER_PROCESS:
        wandb.finish()
    if LOG_FP is not None:
        LOG_FP.close()
        LOG_FP = None
    if dist.is_initialized():
        dist.destroy_process_group()
if __name__ == "__main__":
    # Run training only when executed as a script, not when this module is imported.
    main()
@scottjmaddox
Copy link
Copy Markdown
Author

See the initial Aurora results on a 16-layer, 28M-parameter Multiscreen model here: https://x.com/scottjmaddox/status/2053839590922899652?s=20

@scottjmaddox
Copy link
Copy Markdown
Author

Revision 2 fixes the computation of `r`: `torch.exp(self.s_r) + 1` is replaced with `torch.sigmoid(self.s_r)`.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment