Created
May 1, 2026 17:19
-
-
Save shunting314/ba7a75d526fddece2ed7471ae1b941ed to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/fbcode/helion/helion/__init__.py b/fbcode/helion/helion/__init__.py | |
| --- a/fbcode/helion/helion/__init__.py | |
| +++ b/fbcode/helion/helion/__init__.py | |
| @@ -34,3 +34,6 @@ | |
| from ._compiler._dynamo.variables import register_dynamo_variable # noqa: E402 | |
| register_dynamo_variable() | |
| + | |
| +import torch | |
| +torch.cuda.memory._record_memory_history(max_entries=100000) | |
| diff --git a/fbcode/helion/helion/autotuner/benchmark_provider.py b/fbcode/helion/helion/autotuner/benchmark_provider.py | |
| --- a/fbcode/helion/helion/autotuner/benchmark_provider.py | |
| +++ b/fbcode/helion/helion/autotuner/benchmark_provider.py | |
| @@ -893,6 +893,8 @@ | |
| prefix=f"Generated Triton code for {decorator}:", | |
| ) | |
| self.kernel.maybe_log_repro(self.log.error, self.args, config) | |
| + if "OutOfMemoryError" in type(e).__qualname__: | |
| + breakpoint() | |
| raise exc.TritonError( | |
| error=f"{type(e).__qualname__}: {e}", | |
| decorator=decorator, | |
| diff --git a/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py b/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py | |
| --- a/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py | |
| +++ b/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py | |
| @@ -730,6 +730,7 @@ | |
| ) | |
| ) | |
| else: | |
| + latency_measure_mode = "inductor_benchmarker" | |
| bench_fn = ( | |
| partial(_do_bench_profiler, skip_cache_clearing=skip_cache_clearing) | |
| if latency_measure_mode == "profiler" | |
| diff --git a/fbcode/repro_bwd_memleak.py b/fbcode/repro_bwd_memleak.py | |
| new file mode 100644 | |
| --- /dev/null | |
| +++ b/fbcode/repro_bwd_memleak.py | |
| @@ -0,0 +1,161 @@ | |
| +""" | |
| +Repro for tritonbench backward benchmark memory leak. | |
| + | |
| +Root cause: get_bwd_fn creates closures with retain_graph=True. The retained | |
| +autograd graph creates a C++<->Python reference cycle that reference counting | |
| +alone cannot free. Without gc.collect(), these closures (and their large | |
| +tensors) accumulate across input iterations. | |
| + | |
| +Usage: | |
| +    python repro_bwd_memleak.py          # shows the leak | |
| +    python repro_bwd_memleak.py --fix    # shows gc.collect() fixes it | |
| +""" | |
| + | |
| +import argparse | |
| +import gc | |
| + | |
| +import torch | |
| + | |
| + | |
| +def make_bwd_fn(A, B): | |
| +    """Mirrors tritonbench's get_bwd_fn pattern. | |
| + | |
| +    Builds a zero-argument closure that runs the backward pass of ``A @ B``. | |
| +    The forward output and a random upstream gradient are created on the | |
| +    first call and cached in ``state``, so later calls reuse them. | |
| + | |
| +    Args: | |
| +        A: left operand of the matmul. | |
| +        B: right operand of the matmul. | |
| + | |
| +    Returns: | |
| +        The closure ``bwd_fn``. Each call returns the list of inputs that | |
| +        have ``requires_grad`` set. | |
| +    """ | |
| +    grad_tensors = [t for t in [A, B] if t.requires_grad] | |
| +    # Forward output ("y") and upstream gradient ("dy"), filled in lazily. | |
| +    state = {"y": None, "dy": None} | |
| + | |
| +    def fwd_fn(): | |
| +        return A @ B | |
| + | |
| +    def bwd_fn(): | |
| +        # Drop any gradient left over from a previous call. | |
| +        for t in grad_tensors: | |
| +            if t.grad is not None: | |
| +                t.grad = None | |
| +        if state["y"] is None: | |
| +            state["y"] = fwd_fn() | |
| +            state["dy"] = 0.1 * torch.randn_like(state["y"]) | |
| +        # retain_graph=True is needed to call backward multiple times, | |
| +        # but it keeps the autograd graph alive, creating a C++<->Python | |
| +        # reference cycle that only gc.collect() can break. | |
| +        state["y"].backward(state["dy"], retain_graph=True) | |
| +        return grad_tensors | |
| + | |
| +    return bwd_fn | |
| + | |
| + | |
| +def simulate_benchmark_run(n_input_ids: int, use_fix: bool): | |
| +    """Run the per-input benchmark loop and sample allocated CUDA memory. | |
| + | |
| +    Each iteration allocates fresh bf16 operands on the GPU, builds a baseline | |
| +    and a variant backward closure with make_bwd_fn, calls each three times, | |
| +    then deletes the operands. | |
| + | |
| +    Args: | |
| +        n_input_ids: number of iterations (simulated input IDs) to run. | |
| +        use_fix: if true, call gc.collect() and torch.cuda.empty_cache() | |
| +            at the end of every iteration. | |
| + | |
| +    Returns: | |
| +        One torch.cuda.memory_allocated() reading per iteration, in GB | |
| +        (bytes / 1e9). | |
| +    """ | |
| +    # Shape mirrors a typical LLM GEMM: (8192, 4096) @ (4096, 6144) bf16 | |
| +    M, K, N = 8192, 4096, 6144 | |
| + | |
| +    baseline_fn = None | |
| +    mem_snapshots = [] | |
| + | |
| +    for input_id in range(n_input_ids): | |
| +        # Each input_id gets fresh tensors (like tritonbench's get_example_inputs) | |
| +        A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) | |
| +        B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True) | |
| + | |
| +        # Reset baseline (mirrors tritonbench line 1116) | |
| +        baseline_fn = None | |
| + | |
| +        # Warmup / benchmark baseline | |
| +        baseline_fn = make_bwd_fn(A, B) | |
| +        for _ in range(3): | |
| +            baseline_fn() | |
| + | |
| +        # Benchmark a non-baseline variant (same op, different impl in real code) | |
| +        variant_fn = make_bwd_fn(A, B) | |
| +        for _ in range(3): | |
| +            variant_fn() | |
| + | |
| +        # Mirrors tritonbench line 1244. This only unbinds the local names; | |
| +        # the closures built above still reference both tensors. | |
| +        del A, B | |
| + | |
| +        if use_fix: | |
| +            gc.collect() | |
| +            torch.cuda.empty_cache() | |
| + | |
| +        allocated_gb = torch.cuda.memory_allocated() / 1e9 | |
| +        mem_snapshots.append(allocated_gb) | |
| +        # Print every 10th iteration and the last one. | |
| +        if input_id % 10 == 0 or input_id == n_input_ids - 1: | |
| +            print(f" input_id={input_id:3d} allocated={allocated_gb:.2f} GB") | |
| + | |
| +    return mem_snapshots | |
| + | |
| + | |
| +def main(): | |
| +    """Parse CLI flags, run the simulated benchmark, and print a memory report. | |
| + | |
| +    Exits with a usage error when --n is not positive, and returns early with | |
| +    a message when CUDA is unavailable. | |
| +    """ | |
| +    parser = argparse.ArgumentParser() | |
| +    parser.add_argument("--fix", action="store_true", help="apply gc.collect() fix") | |
| +    parser.add_argument("--n", type=int, default=50, help="number of input IDs") | |
| +    args = parser.parse_args() | |
| + | |
| +    # With n < 1 no snapshot is recorded and the report below would fail | |
| +    # with an IndexError on snapshots[0]. | |
| +    if args.n < 1: | |
| +        parser.error("--n must be a positive integer") | |
| + | |
| +    if not torch.cuda.is_available(): | |
| +        print("CUDA not available") | |
| +        return | |
| + | |
| +    torch.cuda.reset_peak_memory_stats() | |
| +    label = "WITH fix (gc.collect)" if args.fix else "WITHOUT fix (leaking)" | |
| +    print(f"\n=== {label} ===") | |
| +    print(f"Running {args.n} input IDs...\n") | |
| + | |
| +    snapshots = simulate_benchmark_run(args.n, use_fix=args.fix) | |
| + | |
| +    peak_gb = torch.cuda.max_memory_allocated() / 1e9 | |
| +    growth_gb = snapshots[-1] - snapshots[0] | |
| + | |
| +    print("\nResults:") | |
| +    print(f" Initial allocated : {snapshots[0]:.2f} GB") | |
| +    print(f" Final allocated : {snapshots[-1]:.2f} GB") | |
| +    print(f" Growth : {growth_gb:.2f} GB") | |
| +    print(f" Peak : {peak_gb:.2f} GB") | |
| + | |
| +    if not args.fix and growth_gb > 1.0: | |
| +        print(f"\nLeak confirmed: {growth_gb:.1f} GB accumulated across {args.n} iterations.") | |
| +        print("Each iteration leaves behind ~2 bwd closures with retained autograd graphs.") | |
| +        print("Fix: add gc.collect() after each _do_bench call, or run with --fix.") | |
| +    elif args.fix and growth_gb < 0.5: | |
| +        print(f"\nFix works: memory stays flat (growth={growth_gb:.2f} GB).") | |
| + | |
| + | |
| +if __name__ == "__main__": | |
| +    main() | |
| diff --git a/fbcode/snapshot-with-stack.pkl b/fbcode/snapshot-with-stack.pkl | |
| new file mode 100644 | |
| Binary file fbcode/snapshot-with-stack.pkl has changed |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment