Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save shunting314/ba7a75d526fddece2ed7471ae1b941ed to your computer and use it in GitHub Desktop.

Select an option

Save shunting314/ba7a75d526fddece2ed7471ae1b941ed to your computer and use it in GitHub Desktop.
diff --git a/fbcode/helion/helion/__init__.py b/fbcode/helion/helion/__init__.py
--- a/fbcode/helion/helion/__init__.py
+++ b/fbcode/helion/helion/__init__.py
@@ -34,3 +34,6 @@
from ._compiler._dynamo.variables import register_dynamo_variable # noqa: E402
register_dynamo_variable()
+
+import torch
+torch.cuda.memory._record_memory_history(max_entries=100000)
diff --git a/fbcode/helion/helion/autotuner/benchmark_provider.py b/fbcode/helion/helion/autotuner/benchmark_provider.py
--- a/fbcode/helion/helion/autotuner/benchmark_provider.py
+++ b/fbcode/helion/helion/autotuner/benchmark_provider.py
@@ -893,6 +893,8 @@
prefix=f"Generated Triton code for {decorator}:",
)
self.kernel.maybe_log_repro(self.log.error, self.args, config)
+ if "OutOfMemoryError" in type(e).__qualname__:
+ breakpoint()
raise exc.TritonError(
error=f"{type(e).__qualname__}: {e}",
decorator=decorator,
diff --git a/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py b/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py
--- a/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py
+++ b/fbcode/pytorch/tritonbench/tritonbench/components/do_bench/run.py
@@ -730,6 +730,7 @@
)
)
else:
+ latency_measure_mode = "inductor_benchmarker"
bench_fn = (
partial(_do_bench_profiler, skip_cache_clearing=skip_cache_clearing)
if latency_measure_mode == "profiler"
diff --git a/fbcode/repro_bwd_memleak.py b/fbcode/repro_bwd_memleak.py
new file mode 100644
--- /dev/null
+++ b/fbcode/repro_bwd_memleak.py
@@ -0,0 +1,120 @@
+"""
+Repro for tritonbench backward benchmark memory leak.
+
+Root cause: get_bwd_fn creates closures with retain_graph=True. The retained
+autograd graph creates a C++<->Python reference cycle that Python's refcount
+can't break. Without gc.collect(), these closures (and their large tensors)
+accumulate across input iterations.
+
+Usage:
+ python repro_bwd_memleak.py # shows the leak
+ python repro_bwd_memleak.py --fix # shows gc.collect() fixes it
+"""
+
+import argparse
+import gc
+
+import torch
+
+
def make_bwd_fn(A, B):
    """Return a backward-only benchmark closure for ``A @ B``.

    Mirrors the shape of tritonbench's ``get_bwd_fn``. The forward output and
    the upstream gradient fed to ``backward`` are built lazily on the first
    call and cached. Every call, including the first, clears stale ``.grad``
    values and runs backward again on the cached output.

    NOTE(review): nothing created here holds a reference back to the returned
    closure (it captures ``A``/``B`` via ``fwd_fn``, the grad list, and the
    cached output/gradient). Whether ``retain_graph=True`` alone produces the
    reference cycle described in the module docstring should be confirmed,
    e.g. with ``gc.get_referrers``.

    Args:
        A: Left matmul operand.
        B: Right matmul operand.

    Returns:
        A zero-argument callable. Each call returns the list of those of
        ``A``/``B`` that have ``requires_grad`` set (always the same list
        object).
    """
    needs_grad = [tensor for tensor in (A, B) if tensor.requires_grad]
    out = None  # cached forward output, created on the first call
    grad_out = None  # cached upstream gradient, same shape as ``out``

    def fwd_fn():
        return A @ B

    def bwd_fn():
        nonlocal out, grad_out
        # Clear gradients from the previous call so backward writes fresh
        # values instead of accumulating into the old ones.
        for tensor in needs_grad:
            if tensor.grad is not None:
                tensor.grad = None
        if out is None:
            out = fwd_fn()
            grad_out = 0.1 * torch.randn_like(out)
        # retain_graph=True keeps the autograd graph (and the tensors it saved)
        # alive so backward can be replayed on the same ``out`` every call.
        out.backward(grad_out, retain_graph=True)
        return needs_grad

    return bwd_fn
+
+
def simulate_benchmark_run(
    n_input_ids: int,
    use_fix: bool,
    *,
    shape: tuple[int, int, int] = (8192, 4096, 6144),
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cuda",
    n_iters: int = 3,
) -> list[float]:
    """Replay tritonbench's per-input backward-benchmark loop and record memory.

    For each input id, fresh operands are allocated. A "baseline" and a
    "variant" backward closure are built with ``make_bwd_fn`` and each is
    called ``n_iters`` times. The loop's own names for the operands are then
    deleted.

    Args:
        n_input_ids: Number of simulated input ids (loop iterations).
        use_fix: If true, call ``gc.collect()`` and ``torch.cuda.empty_cache()``
            at the end of every iteration.
        shape: ``(M, K, N)`` for the ``(M, K) @ (K, N)`` matmul. The default
            mirrors a typical LLM GEMM.
        dtype: Operand dtype.
        device: Device on which the operands are allocated.
        n_iters: Calls per closure per input id (warmup/benchmark stand-in).

    Returns:
        One ``torch.cuda.memory_allocated()`` reading per input id, in GB
        (1e9 bytes), taken at the end of that iteration.
    """
    M, K, N = shape
    mem_snapshots = []

    for input_id in range(n_input_ids):
        # Each input_id gets fresh tensors (like tritonbench's get_example_inputs)
        A = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)
        B = torch.randn(K, N, dtype=dtype, device=device, requires_grad=True)

        # Reset baseline (mirrors tritonbench line 1116): drops this frame's
        # reference to the previous iteration's baseline closure.
        baseline_fn = None

        # Warmup / benchmark baseline
        baseline_fn = make_bwd_fn(A, B)
        for _ in range(n_iters):
            baseline_fn()

        # Benchmark a non-baseline variant (same op, different impl in real code)
        variant_fn = make_bwd_fn(A, B)
        for _ in range(n_iters):
            variant_fn()

        # Mirrors tritonbench line 1244. The closures above still capture A
        # and B, so this only removes the loop's own names for them.
        del A, B

        if use_fix:
            gc.collect()
            torch.cuda.empty_cache()

        # baseline_fn / variant_fn are still bound here, so every reading
        # includes this iteration's live working set. Growth *between*
        # readings is what indicates a leak.
        allocated_gb = torch.cuda.memory_allocated() / 1e9
        mem_snapshots.append(allocated_gb)
        if input_id % 10 == 0 or input_id == n_input_ids - 1:
            print(f" input_id={input_id:3d} allocated={allocated_gb:.2f} GB")

    return mem_snapshots
+
+
def main():
    """Parse CLI flags, run the simulation once, and print a memory summary.

    Prints a message and returns early when CUDA is unavailable. ``--n`` must
    be at least 1, because growth is measured between the first and last
    snapshot.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix", action="store_true", help="apply gc.collect() fix")
    parser.add_argument("--n", type=int, default=50, help="number of input IDs")
    args = parser.parse_args()
    if args.n < 1:
        # With no iterations there are no snapshots, and snapshots[0] below
        # would raise IndexError.
        parser.error("--n must be at least 1")

    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    torch.cuda.reset_peak_memory_stats()
    label = "WITH fix (gc.collect)" if args.fix else "WITHOUT fix (leaking)"
    print(f"\n=== {label} ===")
    print(f"Running {args.n} input IDs...\n")

    snapshots = simulate_benchmark_run(args.n, use_fix=args.fix)

    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    # Read after simulate_benchmark_run has returned, i.e. once that frame's
    # locals (operands and closures) no longer reference anything.
    final_gb = torch.cuda.memory_allocated() / 1e9
    growth_gb = snapshots[-1] - snapshots[0]

    print("\nResults:")
    print(f" Initial allocated : {snapshots[0]:.2f} GB")
    print(f" Final allocated : {snapshots[-1]:.2f} GB")
    print(f" After run returns : {final_gb:.2f} GB")
    print(f" Growth : {growth_gb:.2f} GB")
    print(f" Peak : {peak_gb:.2f} GB")

    if not args.fix and growth_gb > 1.0:
        print(f"\nLeak confirmed: {growth_gb:.1f} GB accumulated across {args.n} iterations.")
        print("Each iteration leaves behind ~2 bwd closures with retained autograd graphs.")
        print("Fix: add gc.collect() after each _do_bench call, or run with --fix.")
    elif args.fix and growth_gb < 0.5:
        print(f"\nFix works: memory stays flat (growth={growth_gb:.2f} GB).")
    elif args.fix:
        # The fix was applied but memory still grew: the gc hypothesis is not
        # sufficient on its own.
        print(f"\nFix did not hold: memory still grew by {growth_gb:.2f} GB.")
    else:
        # No fix and no significant growth: the leak did not reproduce.
        print(f"\nLeak not reproduced: growth={growth_gb:.2f} GB (threshold 1.0 GB).")


if __name__ == "__main__":
    main()
diff --git a/fbcode/snapshot-with-stack.pkl b/fbcode/snapshot-with-stack.pkl
new file mode 100644
Binary file fbcode/snapshot-with-stack.pkl has changed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment