Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created April 30, 2026 17:29
Show Gist options
  • Select an option

  • Save shunting314/cf5a06a3b92c9b629ce885b491571e7a to your computer and use it in GitHub Desktop.

Select an option

Save shunting314/cf5a06a3b92c9b629ce885b491571e7a to your computer and use it in GitHub Desktop.
--- /tmp/inductor.py 2026-04-30 10:22:24.247513427 -0700
+++ /tmp/helion.py 2026-04-30 10:22:40.475653971 -0700
@@ -1,8 +1,8 @@
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ads: "f32[4096, 523, 128][66944, 128, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_x_user: "f32[4096, 759, 128][97152, 128, 1]cuda:0", L_x_ads_to_user_map: "i64[4096][1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_: "f32[32, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_: "f32[16, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[512, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[256, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[512, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[256, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_: "f32[41024, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_: "f32[20512, 256][256, 1]cuda:0", L_dense_proj_user: "f32[4096, 4096][4096, 1]cuda:0", L_dense_proj_ads: "f32[4096, 4096][4096, 1]cuda:0", L_self_modules_ln1_modules_ln_user_parameters_weight_: "f32[16240][1]cuda:0", L_self_modules_ln1_modules_ln_ads_parameters_weight_: "f32[24608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[9216, 16240][16240, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[4608, 24608][24608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[4608, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[2304, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_: "f32[9216, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_: "f32[4608, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_p
+ def forward(self, L_x_ads: "f32[4096, 523, 128][66944, 128, 1]cuda:0", L_x_user: "f32[4096, 759, 128][97152, 128, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_x_ads_to_user_map: "i64[4096][1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_: "f32[32, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_: "f32[16, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[512, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[256, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[512, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[256, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_: "f32[41024, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_: "f32[20512, 256][256, 1]cuda:0", L_dense_proj_user: "f32[4096, 4096][4096, 1]cuda:0", L_dense_proj_ads: "f32[4096, 4096][4096, 1]cuda:0", L_self_modules_ln1_modules_ln_user_parameters_weight_: "f32[16240][1]cuda:0", L_self_modules_ln1_modules_ln_ads_parameters_weight_: "f32[24608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[9216, 16240][16240, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[4608, 24608][24608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[4608, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[2304, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_: "f32[9216, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_: "f32[4608, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_p
l_x_ads = L_x_ads
- l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_
l_x_user = L_x_user
+ l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_
l_x_ads_to_user_map = L_x_ads_to_user_map
l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_
l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_
@@ -50,430 +50,430 @@
l_self_modules_ln2_modules_ln_user_parameters_weight_ = L_self_modules_ln2_modules_ln_user_parameters_weight_
l_self_modules_ln2_modules_ln_ads_parameters_weight_ = L_self_modules_ln2_modules_ln_ads_parameters_weight_
- matmul: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ @ l_x_user; l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = None
+ helion_transposed_matmul_default: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(l_x_user, l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = None
- res_user: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul[(slice(None, None, None), slice(None, 96, None))]
+ res_user: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default[(slice(None, None, None), slice(None, 96, None))]
- matmul_1: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ = None
+ matmul: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ = None
- getitem_1: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul[(slice(None, None, None), slice(96, None, None))]; matmul = None
+ getitem_1: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default[(slice(None, None, None), slice(96, None, None))]; helion_transposed_matmul_default = None
- index_select: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_1, dim = 0, index = l_x_ads_to_user_map); getitem_1 = None
+ index_select: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_1, dim = 0, index = l_x_ads_to_user_map); getitem_1 = None
- res_ads: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_1 + index_select; matmul_1 = index_select = None
+ res_ads: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = matmul + index_select; matmul = index_select = None
- matmul_2: "bf16[4096, 32, 128][4096, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ @ l_x_user; l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ = None
+ helion_transposed_matmul_default_1: "f32[4096, 32, 128][4096, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(l_x_user, l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ = None
- res_user_1: "bf16[4096, 16, 128][4096, 128, 1]cuda:0" = matmul_2[(slice(None, None, None), slice(None, 16, None))]
+ res_user_1: "f32[4096, 16, 128][4096, 128, 1]cuda:0" = helion_transposed_matmul_default_1[(slice(None, None, None), slice(None, 16, None))]
- matmul_3: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ = None
+ matmul_1: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ = None
- getitem_3: "bf16[4096, 16, 128][4096, 128, 1]cuda:0" = matmul_2[(slice(None, None, None), slice(16, None, None))]; matmul_2 = None
+ getitem_3: "f32[4096, 16, 128][4096, 128, 1]cuda:0" = helion_transposed_matmul_default_1[(slice(None, None, None), slice(16, None, None))]; helion_transposed_matmul_default_1 = None
- index_select_1: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = torch.index_select(getitem_3, dim = 0, index = l_x_ads_to_user_map); getitem_3 = None
+ index_select_1: "f32[4096, 16, 128][2048, 128, 1]cuda:0" = torch.index_select(getitem_3, dim = 0, index = l_x_ads_to_user_map); getitem_3 = None
- res_ads_1: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = matmul_3 + index_select_1; matmul_3 = index_select_1 = None
+ res_ads_1: "f32[4096, 16, 128][2048, 128, 1]cuda:0" = matmul_1 + index_select_1; matmul_1 = index_select_1 = None
- matmul_4: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_ @ l_x_user; l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_ = None
+ helion_transposed_matmul_default_2: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(l_x_user, l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_ = None
- res_user_2: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_4[(slice(None, None, None), slice(None, 96, None))]
+ res_user_2: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_2[(slice(None, None, None), slice(None, 96, None))]
- matmul_5: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ = None
+ matmul_2: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ = None
- getitem_5: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_4[(slice(None, None, None), slice(96, None, None))]; matmul_4 = None
+ getitem_5: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_2[(slice(None, None, None), slice(96, None, None))]; helion_transposed_matmul_default_2 = None
- index_select_2: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_5, dim = 0, index = l_x_ads_to_user_map); getitem_5 = None
+ index_select_2: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_5, dim = 0, index = l_x_ads_to_user_map); getitem_5 = None
- res_ads_2: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_5 + index_select_2; matmul_5 = index_select_2 = None
+ res_ads_2: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_2 + index_select_2; matmul_2 = index_select_2 = None
- view: "bf16[4096, 2048][4096, 1]cuda:0" = res_user_1.view(4096, -1); res_user_1 = None
+ view: "f32[4096, 2048][4096, 1]cuda:0" = res_user_1.view(4096, -1); res_user_1 = None
- view_1: "bf16[4096, 2048][2048, 1]cuda:0" = res_ads_1.view(4096, -1); res_ads_1 = None
+ view_1: "f32[4096, 2048][2048, 1]cuda:0" = res_ads_1.view(4096, -1); res_ads_1 = None
linear: "bf16[4096, 512][512, 1]cuda:0" = torch._C._nn.linear(view, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_, None); view = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_ = None
res_user_3: "bf16[4096, 256][512, 1]cuda:0" = linear[(slice(None, None, None), slice(None, 256, None))]
linear_1: "bf16[4096, 256][256, 1]cuda:0" = torch._C._nn.linear(view_1, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_, None); view_1 = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_ = None
getitem_7: "bf16[4096, 256][512, 1]cuda:0" = linear[(slice(None, None, None), slice(256, None, None))]; linear = None
index_select_3: "bf16[4096, 256][256, 1]cuda:0" = torch.index_select(getitem_7, dim = 0, index = l_x_ads_to_user_map); getitem_7 = None
res_ads_3: "bf16[4096, 256][256, 1]cuda:0" = linear_1 + index_select_3; linear_1 = index_select_3 = None
output: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_user_3, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_3 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_1: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_ads_3, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_3 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output)
mul: "f32[4096, 256][256, 1]cuda:0" = output * sigmoid; output = sigmoid = None
sigmoid_1: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output_1)
mul_1: "f32[4096, 256][256, 1]cuda:0" = output_1 * sigmoid_1; output_1 = sigmoid_1 = None
linear_2: "bf16[4096, 512][512, 1]cuda:0" = torch._C._nn.linear(mul, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_, None); mul = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_ = None
res_user_4: "bf16[4096, 256][512, 1]cuda:0" = linear_2[(slice(None, None, None), slice(None, 256, None))]
linear_3: "bf16[4096, 256][256, 1]cuda:0" = torch._C._nn.linear(mul_1, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_, None); mul_1 = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_ = None
getitem_9: "bf16[4096, 256][512, 1]cuda:0" = linear_2[(slice(None, None, None), slice(256, None, None))]; linear_2 = None
index_select_4: "bf16[4096, 256][256, 1]cuda:0" = torch.index_select(getitem_9, dim = 0, index = l_x_ads_to_user_map); getitem_9 = None
res_ads_4: "bf16[4096, 256][256, 1]cuda:0" = linear_3 + index_select_4; linear_3 = index_select_4 = None
output_2: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_user_4, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_4 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_3: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_ads_4, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_4 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid_2: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output_2)
mul_2: "f32[4096, 256][256, 1]cuda:0" = output_2 * sigmoid_2; output_2 = sigmoid_2 = None
sigmoid_3: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output_3)
mul_3: "f32[4096, 256][256, 1]cuda:0" = output_3 * sigmoid_3; output_3 = sigmoid_3 = None
linear_4: "bf16[4096, 41024][41024, 1]cuda:0" = torch._C._nn.linear(mul_2, l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_, None); mul_2 = l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_ = None
res_user_5: "bf16[4096, 20512][41024, 1]cuda:0" = linear_4[(slice(None, None, None), slice(None, 20512, None))]
linear_5: "bf16[4096, 20512][20512, 1]cuda:0" = torch._C._nn.linear(mul_3, l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_, None); mul_3 = l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_ = None
getitem_11: "bf16[4096, 20512][41024, 1]cuda:0" = linear_4[(slice(None, None, None), slice(20512, None, None))]; linear_4 = None
index_select_5: "bf16[4096, 20512][20512, 1]cuda:0" = torch.index_select(getitem_11, dim = 0, index = l_x_ads_to_user_map); getitem_11 = None
res_ads_5: "bf16[4096, 20512][20512, 1]cuda:0" = linear_5 + index_select_5; linear_5 = index_select_5 = None
view_2: "bf16[4096, 16, 1282][41024, 1282, 1]cuda:0" = res_user_5.view(4096, -1, 1282); res_user_5 = None
view_3: "bf16[4096, 16, 1282][20512, 1282, 1]cuda:0" = res_ads_5.view(4096, -1, 1282); res_ads_5 = None
index_select_6: "f32[4096, 759, 128][97152, 128, 1]cuda:0" = torch.index_select(l_x_user, dim = 0, index = l_x_ads_to_user_map)
cat: "f32[4096, 1282, 128][164096, 128, 1]cuda:0" = torch.cat([index_select_6, l_x_ads], dim = 1); index_select_6 = l_x_ads = None
- matmul_6: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = view_3 @ cat; view_3 = None
+ matmul_3: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = view_3 @ cat; view_3 = None
getitem_12: "bf16[4096, 16, 759][41024, 1282, 1]cuda:0" = view_2[(slice(None, None, None), slice(None, None, None), slice(None, 759, None))]; view_2 = None
- matmul_7: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = getitem_12 @ l_x_user; getitem_12 = None
+ matmul_4: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = getitem_12 @ l_x_user; getitem_12 = None
permute: "f32[4096, 128, 1282][164096, 1, 128]cuda:0" = cat.permute(0, 2, 1); cat = None
- res_ads_6: "bf16[4096, 16, 1282][20512, 1282, 1]cuda:0" = matmul_6 @ permute; matmul_6 = permute = None
+ res_ads_6: "bf16[4096, 16, 1282][20512, 1282, 1]cuda:0" = matmul_3 @ permute; matmul_3 = permute = None
permute_1: "f32[4096, 128, 759][97152, 1, 128]cuda:0" = l_x_user.permute(0, 2, 1); l_x_user = None
- res_user_6: "bf16[4096, 16, 759][12144, 759, 1]cuda:0" = matmul_7 @ permute_1; matmul_7 = permute_1 = None
+ res_user_6: "bf16[4096, 16, 759][12144, 759, 1]cuda:0" = matmul_4 @ permute_1; matmul_4 = permute_1 = None
view_4: "bf16[4096, 12144][12144, 1]cuda:0" = res_user_6.view(4096, -1)
view_5: "bf16[4096, 20512][20512, 1]cuda:0" = res_ads_6.view(4096, -1)
cat_1: "f32[4096, 16240][16240, 1]cuda:0" = torch.cat([view_4, l_dense_proj_user], dim = 1); view_4 = l_dense_proj_user = None
cat_2: "f32[4096, 24608][24608, 1]cuda:0" = torch.cat([view_5, l_dense_proj_ads], dim = 1); view_5 = l_dense_proj_ads = None
res_user_7: "f32[4096, 16240][16240, 1]cuda:0" = torch.rms_norm(cat_1, (16240,), l_self_modules_ln1_modules_ln_user_parameters_weight_, None); cat_1 = l_self_modules_ln1_modules_ln_user_parameters_weight_ = None
res_ads_7: "f32[4096, 24608][24608, 1]cuda:0" = torch.rms_norm(cat_2, (24608,), l_self_modules_ln1_modules_ln_ads_parameters_weight_, None); cat_2 = l_self_modules_ln1_modules_ln_ads_parameters_weight_ = None
linear_6: "bf16[4096, 9216][9216, 1]cuda:0" = torch._C._nn.linear(res_user_7, l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_, None); res_user_7 = l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_ = None
res_user_8: "bf16[4096, 4608][9216, 1]cuda:0" = linear_6[(slice(None, None, None), slice(None, 4608, None))]
linear_7: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(res_ads_7, l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_, None); res_ads_7 = l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_ = None
getitem_14: "bf16[4096, 4608][9216, 1]cuda:0" = linear_6[(slice(None, None, None), slice(4608, None, None))]; linear_6 = None
index_select_7: "bf16[4096, 4608][4608, 1]cuda:0" = torch.index_select(getitem_14, dim = 0, index = l_x_ads_to_user_map); getitem_14 = None
res_ads_8: "bf16[4096, 4608][4608, 1]cuda:0" = linear_7 + index_select_7; linear_7 = index_select_7 = None
output_4: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_user_8, (4608,), l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_8 = l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_5: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_ads_8, (4608,), l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_8 = l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid_4: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_4)
mul_4: "f32[4096, 4608][4608, 1]cuda:0" = output_4 * sigmoid_4; output_4 = sigmoid_4 = None
sigmoid_5: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_5)
mul_5: "f32[4096, 4608][4608, 1]cuda:0" = output_5 * sigmoid_5; output_5 = sigmoid_5 = None
linear_8: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(mul_4, l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_ = None
res_user_9: "bf16[4096, 2304][4608, 1]cuda:0" = linear_8[(slice(None, None, None), slice(None, 2304, None))]
linear_9: "bf16[4096, 2304][2304, 1]cuda:0" = torch._C._nn.linear(mul_5, l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_ = None
getitem_16: "bf16[4096, 2304][4608, 1]cuda:0" = linear_8[(slice(None, None, None), slice(2304, None, None))]; linear_8 = None
index_select_8: "bf16[4096, 2304][2304, 1]cuda:0" = torch.index_select(getitem_16, dim = 0, index = l_x_ads_to_user_map); getitem_16 = None
res_ads_9: "bf16[4096, 2304][2304, 1]cuda:0" = linear_9 + index_select_8; linear_9 = index_select_8 = None
output_6: "f32[4096, 2304][2304, 1]cuda:0" = torch.rms_norm(res_user_9, (2304,), l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_9 = l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_7: "f32[4096, 2304][2304, 1]cuda:0" = torch.rms_norm(res_ads_9, (2304,), l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_9 = l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid_6: "f32[4096, 2304][2304, 1]cuda:0" = torch.sigmoid(output_6)
mul_6: "f32[4096, 2304][2304, 1]cuda:0" = output_6 * sigmoid_6; output_6 = sigmoid_6 = None
sigmoid_7: "f32[4096, 2304][2304, 1]cuda:0" = torch.sigmoid(output_7)
mul_7: "f32[4096, 2304][2304, 1]cuda:0" = output_7 * sigmoid_7; output_7 = sigmoid_7 = None
linear_10: "bf16[4096, 9216][9216, 1]cuda:0" = torch._C._nn.linear(mul_6, l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_, None); mul_6 = l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_ = None
res_user_10: "bf16[4096, 4608][9216, 1]cuda:0" = linear_10[(slice(None, None, None), slice(None, 4608, None))]
linear_11: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(mul_7, l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_, None); mul_7 = l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_ = None
getitem_18: "bf16[4096, 4608][9216, 1]cuda:0" = linear_10[(slice(None, None, None), slice(4608, None, None))]; linear_10 = None
index_select_9: "bf16[4096, 4608][4608, 1]cuda:0" = torch.index_select(getitem_18, dim = 0, index = l_x_ads_to_user_map); getitem_18 = None
res_ads_10: "bf16[4096, 4608][4608, 1]cuda:0" = linear_11 + index_select_9; linear_11 = index_select_9 = None
res_user_11: "f32[4096, 4608][4608, 1]cuda:0" = mul_4 + res_user_10; mul_4 = res_user_10 = None
res_ads_11: "f32[4096, 4608][4608, 1]cuda:0" = mul_5 + res_ads_10; mul_5 = res_ads_10 = None
output_8: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_user_11, (4608,), l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_11 = l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_9: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_ads_11, (4608,), l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_11 = l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid_8: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_8)
mul_8: "f32[4096, 4608][4608, 1]cuda:0" = output_8 * sigmoid_8; output_8 = sigmoid_8 = None
sigmoid_9: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_9)
mul_9: "f32[4096, 4608][4608, 1]cuda:0" = output_9 * sigmoid_9; output_9 = sigmoid_9 = None
linear_12: "bf16[4096, 3072][3072, 1]cuda:0" = torch._C._nn.linear(mul_8, l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_shared_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_shared_parameters_weight_ = None
res_user_12: "bf16[4096, 1536][3072, 1]cuda:0" = linear_12[(slice(None, None, None), slice(None, 1536, None))]
linear_13: "bf16[4096, 1536][1536, 1]cuda:0" = torch._C._nn.linear(mul_9, l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_ads_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_ads_parameters_weight_ = None
getitem_20: "bf16[4096, 1536][3072, 1]cuda:0" = linear_12[(slice(None, None, None), slice(1536, None, None))]; linear_12 = None
index_select_10: "bf16[4096, 1536][1536, 1]cuda:0" = torch.index_select(getitem_20, dim = 0, index = l_x_ads_to_user_map); getitem_20 = None
res_ads_12: "bf16[4096, 1536][1536, 1]cuda:0" = linear_13 + index_select_10; linear_13 = index_select_10 = None
output_10: "f32[4096, 1536][1536, 1]cuda:0" = torch.rms_norm(res_user_12, (1536,), l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_12 = l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_11: "f32[4096, 1536][1536, 1]cuda:0" = torch.rms_norm(res_ads_12, (1536,), l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_12 = l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid_10: "f32[4096, 1536][1536, 1]cuda:0" = torch.sigmoid(output_10)
mul_10: "f32[4096, 1536][1536, 1]cuda:0" = output_10 * sigmoid_10; output_10 = sigmoid_10 = None
sigmoid_11: "f32[4096, 1536][1536, 1]cuda:0" = torch.sigmoid(output_11)
mul_11: "f32[4096, 1536][1536, 1]cuda:0" = output_11 * sigmoid_11; output_11 = sigmoid_11 = None
linear_14: "bf16[4096, 9216][9216, 1]cuda:0" = torch._C._nn.linear(mul_10, l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_shared_parameters_weight_, None); mul_10 = l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_shared_parameters_weight_ = None
res_user_13: "bf16[4096, 4608][9216, 1]cuda:0" = linear_14[(slice(None, None, None), slice(None, 4608, None))]
linear_15: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(mul_11, l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_ads_parameters_weight_, None); mul_11 = l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_ads_parameters_weight_ = None
getitem_22: "bf16[4096, 4608][9216, 1]cuda:0" = linear_14[(slice(None, None, None), slice(4608, None, None))]; linear_14 = None
index_select_11: "bf16[4096, 4608][4608, 1]cuda:0" = torch.index_select(getitem_22, dim = 0, index = l_x_ads_to_user_map); getitem_22 = None
res_ads_13: "bf16[4096, 4608][4608, 1]cuda:0" = linear_15 + index_select_11; linear_15 = index_select_11 = None
res_user_14: "f32[4096, 4608][4608, 1]cuda:0" = mul_8 + res_user_13; mul_8 = res_user_13 = None
res_ads_14: "f32[4096, 4608][4608, 1]cuda:0" = mul_9 + res_ads_13; mul_9 = res_ads_13 = None
output_12: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_user_14, (4608,), l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_14 = l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None
output_13: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_ads_14, (4608,), l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_14 = l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None
sigmoid_12: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_12)
mul_12: "f32[4096, 4608][4608, 1]cuda:0" = output_12 * sigmoid_12; output_12 = sigmoid_12 = None
sigmoid_13: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_13)
mul_13: "f32[4096, 4608][4608, 1]cuda:0" = output_13 * sigmoid_13; output_13 = sigmoid_13 = None
linear_16: "bf16[4096, 12288][12288, 1]cuda:0" = torch._C._nn.linear(mul_12, l_self_modules_bitmlp_modules_output_fc_modules_linear_shared_parameters_weight_, None); l_self_modules_bitmlp_modules_output_fc_modules_linear_shared_parameters_weight_ = None
res_user_15: "bf16[4096, 6144][12288, 1]cuda:0" = linear_16[(slice(None, None, None), slice(None, 6144, None))]
linear_17: "bf16[4096, 6144][6144, 1]cuda:0" = torch._C._nn.linear(mul_13, l_self_modules_bitmlp_modules_output_fc_modules_linear_ads_parameters_weight_, None); l_self_modules_bitmlp_modules_output_fc_modules_linear_ads_parameters_weight_ = None
getitem_24: "bf16[4096, 6144][12288, 1]cuda:0" = linear_16[(slice(None, None, None), slice(6144, None, None))]; linear_16 = None
index_select_12: "bf16[4096, 6144][6144, 1]cuda:0" = torch.index_select(getitem_24, dim = 0, index = l_x_ads_to_user_map); getitem_24 = None
res_ads_15: "bf16[4096, 6144][6144, 1]cuda:0" = linear_17 + index_select_12; linear_17 = index_select_12 = None
view_6: "bf16[4096, 48, 128][12288, 128, 1]cuda:0" = res_user_15.view(4096, -1, 128); res_user_15 = None
view_7: "bf16[4096, 48, 128][6144, 128, 1]cuda:0" = res_ads_15.view(4096, -1, 128); res_ads_15 = None
- matmul_10: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_ @ view_6; l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_ = view_6 = None
+ helion_transposed_matmul_default_3: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(view_6, l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); view_6 = l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_ = None
- res_user_16: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_10[(slice(None, None, None), slice(None, 96, None))]
+ res_user_16: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_3[(slice(None, None, None), slice(None, 96, None))]
- matmul_11: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ @ view_7; l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ = view_7 = None
+ matmul_7: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ @ view_7; l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ = view_7 = None
- getitem_26: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_10[(slice(None, None, None), slice(96, None, None))]; matmul_10 = None
+ getitem_26: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_3[(slice(None, None, None), slice(96, None, None))]; helion_transposed_matmul_default_3 = None
- index_select_13: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_26, dim = 0, index = l_x_ads_to_user_map); getitem_26 = l_x_ads_to_user_map = None
+ index_select_13: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_26, dim = 0, index = l_x_ads_to_user_map); getitem_26 = l_x_ads_to_user_map = None
- res_ads_16: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_11 + index_select_13; matmul_11 = index_select_13 = None
+ res_ads_16: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_7 + index_select_13; matmul_7 = index_select_13 = None
- res_user_17: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = res_user_16 + res_user; res_user_16 = res_user = None
+ res_user_17: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = res_user_16 + res_user; res_user_16 = res_user = None
- res_ads_17: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = res_ads_16 + res_ads; res_ads_16 = res_ads = None
+ res_ads_17: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = res_ads_16 + res_ads; res_ads_16 = res_ads = None
- cat_3: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_user_17, res_user_2], dim = 1); res_user_17 = res_user_2 = None
+ cat_3: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_user_17, res_user_2], dim = 1); res_user_17 = res_user_2 = None
- cat_4: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_ads_17, res_ads_2], dim = 1); res_ads_17 = res_ads_2 = None
+ cat_4: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_ads_17, res_ads_2], dim = 1); res_ads_17 = res_ads_2 = None
res_user_18: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.rms_norm(cat_3, (128,), l_self_modules_ln2_modules_ln_user_parameters_weight_, None); cat_3 = l_self_modules_ln2_modules_ln_user_parameters_weight_ = None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment