-
-
Save shunting314/cf5a06a3b92c9b629ce885b491571e7a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| --- /tmp/inductor.py 2026-04-30 10:22:24.247513427 -0700 | |
| +++ /tmp/helion.py 2026-04-30 10:22:40.475653971 -0700 | |
| @@ -1,8 +1,8 @@ | |
| class GraphModule(torch.nn.Module): | |
| - def forward(self, L_x_ads: "f32[4096, 523, 128][66944, 128, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_x_user: "f32[4096, 759, 128][97152, 128, 1]cuda:0", L_x_ads_to_user_map: "i64[4096][1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_: "f32[32, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_: "f32[16, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[512, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[256, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[512, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[256, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_: "f32[41024, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_: "f32[20512, 256][256, 1]cuda:0", L_dense_proj_user: "f32[4096, 4096][4096, 1]cuda:0", L_dense_proj_ads: "f32[4096, 4096][4096, 1]cuda:0", L_self_modules_ln1_modules_ln_user_parameters_weight_: "f32[16240][1]cuda:0", L_self_modules_ln1_modules_ln_ads_parameters_weight_: "f32[24608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[9216, 16240][16240, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[4608, 24608][24608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[4608, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[2304, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_: "f32[9216, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_: "f32[4608, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_p | |
| + def forward(self, L_x_ads: "f32[4096, 523, 128][66944, 128, 1]cuda:0", L_x_user: "f32[4096, 759, 128][97152, 128, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_x_ads_to_user_map: "i64[4096][1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_: "f32[32, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_: "f32[16, 523][523, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_: "f32[192, 759][759, 1]cuda:0", L_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_: "f32[96, 523][523, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[512, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[256, 2048][2048, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[512, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[256, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[256][1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_: "f32[41024, 256][256, 1]cuda:0", L_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_: "f32[20512, 256][256, 1]cuda:0", L_dense_proj_user: "f32[4096, 4096][4096, 1]cuda:0", L_dense_proj_ads: "f32[4096, 4096][4096, 1]cuda:0", L_self_modules_ln1_modules_ln_user_parameters_weight_: "f32[16240][1]cuda:0", L_self_modules_ln1_modules_ln_ads_parameters_weight_: "f32[24608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_: "f32[9216, 16240][16240, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_: "f32[4608, 24608][24608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_: "f32[4608, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_: "f32[2304, 4608][4608, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_: "f32[2304][1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_: "f32[9216, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_: "f32[4608, 2304][2304, 1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_: "f32[4608][1]cuda:0", L_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_p | |
| l_x_ads = L_x_ads | |
| - l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ | |
| l_x_user = L_x_user | |
| + l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ | |
| l_x_ads_to_user_map = L_x_ads_to_user_map | |
| l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ | |
| l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ = L_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ | |
| @@ -50,430 +50,430 @@ | |
| l_self_modules_ln2_modules_ln_user_parameters_weight_ = L_self_modules_ln2_modules_ln_user_parameters_weight_ | |
| l_self_modules_ln2_modules_ln_ads_parameters_weight_ = L_self_modules_ln2_modules_ln_ads_parameters_weight_ | |
| - matmul: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ @ l_x_user; l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| + helion_transposed_matmul_default: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(l_x_user, l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| - res_user: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul[(slice(None, None, None), slice(None, 96, None))] | |
| + res_user: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default[(slice(None, None, None), slice(None, 96, None))] | |
| - matmul_1: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ = None | |
| + matmul: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_0_modules_fc_modules_linear_ads_parameters_weight_ = None | |
| - getitem_1: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul[(slice(None, None, None), slice(96, None, None))]; matmul = None | |
| + getitem_1: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default[(slice(None, None, None), slice(96, None, None))]; helion_transposed_matmul_default = None | |
| - index_select: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_1, dim = 0, index = l_x_ads_to_user_map); getitem_1 = None | |
| + index_select: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_1, dim = 0, index = l_x_ads_to_user_map); getitem_1 = None | |
| - res_ads: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_1 + index_select; matmul_1 = index_select = None | |
| + res_ads: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = matmul + index_select; matmul = index_select = None | |
| - matmul_2: "bf16[4096, 32, 128][4096, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ @ l_x_user; l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| + helion_transposed_matmul_default_1: "f32[4096, 32, 128][4096, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(l_x_user, l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| - res_user_1: "bf16[4096, 16, 128][4096, 128, 1]cuda:0" = matmul_2[(slice(None, None, None), slice(None, 16, None))] | |
| + res_user_1: "f32[4096, 16, 128][4096, 128, 1]cuda:0" = helion_transposed_matmul_default_1[(slice(None, None, None), slice(None, 16, None))] | |
| - matmul_3: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ = None | |
| + matmul_1: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_1_modules_fc_modules_linear_ads_parameters_weight_ = None | |
| - getitem_3: "bf16[4096, 16, 128][4096, 128, 1]cuda:0" = matmul_2[(slice(None, None, None), slice(16, None, None))]; matmul_2 = None | |
| + getitem_3: "f32[4096, 16, 128][4096, 128, 1]cuda:0" = helion_transposed_matmul_default_1[(slice(None, None, None), slice(16, None, None))]; helion_transposed_matmul_default_1 = None | |
| - index_select_1: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = torch.index_select(getitem_3, dim = 0, index = l_x_ads_to_user_map); getitem_3 = None | |
| + index_select_1: "f32[4096, 16, 128][2048, 128, 1]cuda:0" = torch.index_select(getitem_3, dim = 0, index = l_x_ads_to_user_map); getitem_3 = None | |
| - res_ads_1: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = matmul_3 + index_select_1; matmul_3 = index_select_1 = None | |
| + res_ads_1: "f32[4096, 16, 128][2048, 128, 1]cuda:0" = matmul_1 + index_select_1; matmul_1 = index_select_1 = None | |
| - matmul_4: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_ @ l_x_user; l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| + helion_transposed_matmul_default_2: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(l_x_user, l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| - res_user_2: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_4[(slice(None, None, None), slice(None, 96, None))] | |
| + res_user_2: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_2[(slice(None, None, None), slice(None, 96, None))] | |
| - matmul_5: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ = None | |
| + matmul_2: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ @ l_x_ads; l_self_modules_maybe_fused_lce_modules_lces_modules_2_modules_fc_modules_linear_ads_parameters_weight_ = None | |
| - getitem_5: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_4[(slice(None, None, None), slice(96, None, None))]; matmul_4 = None | |
| + getitem_5: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_2[(slice(None, None, None), slice(96, None, None))]; helion_transposed_matmul_default_2 = None | |
| - index_select_2: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_5, dim = 0, index = l_x_ads_to_user_map); getitem_5 = None | |
| + index_select_2: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_5, dim = 0, index = l_x_ads_to_user_map); getitem_5 = None | |
| - res_ads_2: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_5 + index_select_2; matmul_5 = index_select_2 = None | |
| + res_ads_2: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_2 + index_select_2; matmul_2 = index_select_2 = None | |
| - view: "bf16[4096, 2048][4096, 1]cuda:0" = res_user_1.view(4096, -1); res_user_1 = None | |
| + view: "f32[4096, 2048][4096, 1]cuda:0" = res_user_1.view(4096, -1); res_user_1 = None | |
| - view_1: "bf16[4096, 2048][2048, 1]cuda:0" = res_ads_1.view(4096, -1); res_ads_1 = None | |
| + view_1: "f32[4096, 2048][2048, 1]cuda:0" = res_ads_1.view(4096, -1); res_ads_1 = None | |
| linear: "bf16[4096, 512][512, 1]cuda:0" = torch._C._nn.linear(view, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_, None); view = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_shared_parameters_weight_ = None | |
| res_user_3: "bf16[4096, 256][512, 1]cuda:0" = linear[(slice(None, None, None), slice(None, 256, None))] | |
| linear_1: "bf16[4096, 256][256, 1]cuda:0" = torch._C._nn.linear(view_1, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_, None); view_1 = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_0_modules_linear_ads_parameters_weight_ = None | |
| getitem_7: "bf16[4096, 256][512, 1]cuda:0" = linear[(slice(None, None, None), slice(256, None, None))]; linear = None | |
| index_select_3: "bf16[4096, 256][256, 1]cuda:0" = torch.index_select(getitem_7, dim = 0, index = l_x_ads_to_user_map); getitem_7 = None | |
| res_ads_3: "bf16[4096, 256][256, 1]cuda:0" = linear_1 + index_select_3; linear_1 = index_select_3 = None | |
| output: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_user_3, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_3 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_1: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_ads_3, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_3 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output) | |
| mul: "f32[4096, 256][256, 1]cuda:0" = output * sigmoid; output = sigmoid = None | |
| sigmoid_1: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output_1) | |
| mul_1: "f32[4096, 256][256, 1]cuda:0" = output_1 * sigmoid_1; output_1 = sigmoid_1 = None | |
| linear_2: "bf16[4096, 512][512, 1]cuda:0" = torch._C._nn.linear(mul, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_, None); mul = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_shared_parameters_weight_ = None | |
| res_user_4: "bf16[4096, 256][512, 1]cuda:0" = linear_2[(slice(None, None, None), slice(None, 256, None))] | |
| linear_3: "bf16[4096, 256][256, 1]cuda:0" = torch._C._nn.linear(mul_1, l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_, None); mul_1 = l_self_modules_bitattn_modules_v_proj_modules_fcs_modules_1_modules_linear_ads_parameters_weight_ = None | |
| getitem_9: "bf16[4096, 256][512, 1]cuda:0" = linear_2[(slice(None, None, None), slice(256, None, None))]; linear_2 = None | |
| index_select_4: "bf16[4096, 256][256, 1]cuda:0" = torch.index_select(getitem_9, dim = 0, index = l_x_ads_to_user_map); getitem_9 = None | |
| res_ads_4: "bf16[4096, 256][256, 1]cuda:0" = linear_3 + index_select_4; linear_3 = index_select_4 = None | |
| output_2: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_user_4, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_4 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_3: "f32[4096, 256][256, 1]cuda:0" = torch.rms_norm(res_ads_4, (256,), l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_4 = l_self_modules_bitattn_modules_v_proj_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid_2: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output_2) | |
| mul_2: "f32[4096, 256][256, 1]cuda:0" = output_2 * sigmoid_2; output_2 = sigmoid_2 = None | |
| sigmoid_3: "f32[4096, 256][256, 1]cuda:0" = torch.sigmoid(output_3) | |
| mul_3: "f32[4096, 256][256, 1]cuda:0" = output_3 * sigmoid_3; output_3 = sigmoid_3 = None | |
| linear_4: "bf16[4096, 41024][41024, 1]cuda:0" = torch._C._nn.linear(mul_2, l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_, None); mul_2 = l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_shared_parameters_weight_ = None | |
| res_user_5: "bf16[4096, 20512][41024, 1]cuda:0" = linear_4[(slice(None, None, None), slice(None, 20512, None))] | |
| linear_5: "bf16[4096, 20512][20512, 1]cuda:0" = torch._C._nn.linear(mul_3, l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_, None); mul_3 = l_self_modules_bitattn_modules_v_proj_modules_output_fc_modules_linear_ads_parameters_weight_ = None | |
| getitem_11: "bf16[4096, 20512][41024, 1]cuda:0" = linear_4[(slice(None, None, None), slice(20512, None, None))]; linear_4 = None | |
| index_select_5: "bf16[4096, 20512][20512, 1]cuda:0" = torch.index_select(getitem_11, dim = 0, index = l_x_ads_to_user_map); getitem_11 = None | |
| res_ads_5: "bf16[4096, 20512][20512, 1]cuda:0" = linear_5 + index_select_5; linear_5 = index_select_5 = None | |
| view_2: "bf16[4096, 16, 1282][41024, 1282, 1]cuda:0" = res_user_5.view(4096, -1, 1282); res_user_5 = None | |
| view_3: "bf16[4096, 16, 1282][20512, 1282, 1]cuda:0" = res_ads_5.view(4096, -1, 1282); res_ads_5 = None | |
| index_select_6: "f32[4096, 759, 128][97152, 128, 1]cuda:0" = torch.index_select(l_x_user, dim = 0, index = l_x_ads_to_user_map) | |
| cat: "f32[4096, 1282, 128][164096, 128, 1]cuda:0" = torch.cat([index_select_6, l_x_ads], dim = 1); index_select_6 = l_x_ads = None | |
| - matmul_6: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = view_3 @ cat; view_3 = None | |
| + matmul_3: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = view_3 @ cat; view_3 = None | |
| getitem_12: "bf16[4096, 16, 759][41024, 1282, 1]cuda:0" = view_2[(slice(None, None, None), slice(None, None, None), slice(None, 759, None))]; view_2 = None | |
| - matmul_7: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = getitem_12 @ l_x_user; getitem_12 = None | |
| + matmul_4: "bf16[4096, 16, 128][2048, 128, 1]cuda:0" = getitem_12 @ l_x_user; getitem_12 = None | |
| permute: "f32[4096, 128, 1282][164096, 1, 128]cuda:0" = cat.permute(0, 2, 1); cat = None | |
| - res_ads_6: "bf16[4096, 16, 1282][20512, 1282, 1]cuda:0" = matmul_6 @ permute; matmul_6 = permute = None | |
| + res_ads_6: "bf16[4096, 16, 1282][20512, 1282, 1]cuda:0" = matmul_3 @ permute; matmul_3 = permute = None | |
| permute_1: "f32[4096, 128, 759][97152, 1, 128]cuda:0" = l_x_user.permute(0, 2, 1); l_x_user = None | |
| - res_user_6: "bf16[4096, 16, 759][12144, 759, 1]cuda:0" = matmul_7 @ permute_1; matmul_7 = permute_1 = None | |
| + res_user_6: "bf16[4096, 16, 759][12144, 759, 1]cuda:0" = matmul_4 @ permute_1; matmul_4 = permute_1 = None | |
| view_4: "bf16[4096, 12144][12144, 1]cuda:0" = res_user_6.view(4096, -1) | |
| view_5: "bf16[4096, 20512][20512, 1]cuda:0" = res_ads_6.view(4096, -1) | |
| cat_1: "f32[4096, 16240][16240, 1]cuda:0" = torch.cat([view_4, l_dense_proj_user], dim = 1); view_4 = l_dense_proj_user = None | |
| cat_2: "f32[4096, 24608][24608, 1]cuda:0" = torch.cat([view_5, l_dense_proj_ads], dim = 1); view_5 = l_dense_proj_ads = None | |
| res_user_7: "f32[4096, 16240][16240, 1]cuda:0" = torch.rms_norm(cat_1, (16240,), l_self_modules_ln1_modules_ln_user_parameters_weight_, None); cat_1 = l_self_modules_ln1_modules_ln_user_parameters_weight_ = None | |
| res_ads_7: "f32[4096, 24608][24608, 1]cuda:0" = torch.rms_norm(cat_2, (24608,), l_self_modules_ln1_modules_ln_ads_parameters_weight_, None); cat_2 = l_self_modules_ln1_modules_ln_ads_parameters_weight_ = None | |
| linear_6: "bf16[4096, 9216][9216, 1]cuda:0" = torch._C._nn.linear(res_user_7, l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_, None); res_user_7 = l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_shared_parameters_weight_ = None | |
| res_user_8: "bf16[4096, 4608][9216, 1]cuda:0" = linear_6[(slice(None, None, None), slice(None, 4608, None))] | |
| linear_7: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(res_ads_7, l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_, None); res_ads_7 = l_self_modules_bitmlp_modules_fcs_modules_0_modules_linear_ads_parameters_weight_ = None | |
| getitem_14: "bf16[4096, 4608][9216, 1]cuda:0" = linear_6[(slice(None, None, None), slice(4608, None, None))]; linear_6 = None | |
| index_select_7: "bf16[4096, 4608][4608, 1]cuda:0" = torch.index_select(getitem_14, dim = 0, index = l_x_ads_to_user_map); getitem_14 = None | |
| res_ads_8: "bf16[4096, 4608][4608, 1]cuda:0" = linear_7 + index_select_7; linear_7 = index_select_7 = None | |
| output_4: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_user_8, (4608,), l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_8 = l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_5: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_ads_8, (4608,), l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_8 = l_self_modules_bitmlp_modules_acts_modules_0_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid_4: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_4) | |
| mul_4: "f32[4096, 4608][4608, 1]cuda:0" = output_4 * sigmoid_4; output_4 = sigmoid_4 = None | |
| sigmoid_5: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_5) | |
| mul_5: "f32[4096, 4608][4608, 1]cuda:0" = output_5 * sigmoid_5; output_5 = sigmoid_5 = None | |
| linear_8: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(mul_4, l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_shared_parameters_weight_ = None | |
| res_user_9: "bf16[4096, 2304][4608, 1]cuda:0" = linear_8[(slice(None, None, None), slice(None, 2304, None))] | |
| linear_9: "bf16[4096, 2304][2304, 1]cuda:0" = torch._C._nn.linear(mul_5, l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_1_modules_linear_ads_parameters_weight_ = None | |
| getitem_16: "bf16[4096, 2304][4608, 1]cuda:0" = linear_8[(slice(None, None, None), slice(2304, None, None))]; linear_8 = None | |
| index_select_8: "bf16[4096, 2304][2304, 1]cuda:0" = torch.index_select(getitem_16, dim = 0, index = l_x_ads_to_user_map); getitem_16 = None | |
| res_ads_9: "bf16[4096, 2304][2304, 1]cuda:0" = linear_9 + index_select_8; linear_9 = index_select_8 = None | |
| output_6: "f32[4096, 2304][2304, 1]cuda:0" = torch.rms_norm(res_user_9, (2304,), l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_9 = l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_7: "f32[4096, 2304][2304, 1]cuda:0" = torch.rms_norm(res_ads_9, (2304,), l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_9 = l_self_modules_bitmlp_modules_acts_modules_1_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid_6: "f32[4096, 2304][2304, 1]cuda:0" = torch.sigmoid(output_6) | |
| mul_6: "f32[4096, 2304][2304, 1]cuda:0" = output_6 * sigmoid_6; output_6 = sigmoid_6 = None | |
| sigmoid_7: "f32[4096, 2304][2304, 1]cuda:0" = torch.sigmoid(output_7) | |
| mul_7: "f32[4096, 2304][2304, 1]cuda:0" = output_7 * sigmoid_7; output_7 = sigmoid_7 = None | |
| linear_10: "bf16[4096, 9216][9216, 1]cuda:0" = torch._C._nn.linear(mul_6, l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_, None); mul_6 = l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_shared_parameters_weight_ = None | |
| res_user_10: "bf16[4096, 4608][9216, 1]cuda:0" = linear_10[(slice(None, None, None), slice(None, 4608, None))] | |
| linear_11: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(mul_7, l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_, None); mul_7 = l_self_modules_bitmlp_modules_fcs_modules_2_modules_linear_ads_parameters_weight_ = None | |
| getitem_18: "bf16[4096, 4608][9216, 1]cuda:0" = linear_10[(slice(None, None, None), slice(4608, None, None))]; linear_10 = None | |
| index_select_9: "bf16[4096, 4608][4608, 1]cuda:0" = torch.index_select(getitem_18, dim = 0, index = l_x_ads_to_user_map); getitem_18 = None | |
| res_ads_10: "bf16[4096, 4608][4608, 1]cuda:0" = linear_11 + index_select_9; linear_11 = index_select_9 = None | |
| res_user_11: "f32[4096, 4608][4608, 1]cuda:0" = mul_4 + res_user_10; mul_4 = res_user_10 = None | |
| res_ads_11: "f32[4096, 4608][4608, 1]cuda:0" = mul_5 + res_ads_10; mul_5 = res_ads_10 = None | |
| output_8: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_user_11, (4608,), l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_11 = l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_9: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_ads_11, (4608,), l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_11 = l_self_modules_bitmlp_modules_acts_modules_2_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid_8: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_8) | |
| mul_8: "f32[4096, 4608][4608, 1]cuda:0" = output_8 * sigmoid_8; output_8 = sigmoid_8 = None | |
| sigmoid_9: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_9) | |
| mul_9: "f32[4096, 4608][4608, 1]cuda:0" = output_9 * sigmoid_9; output_9 = sigmoid_9 = None | |
| linear_12: "bf16[4096, 3072][3072, 1]cuda:0" = torch._C._nn.linear(mul_8, l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_shared_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_shared_parameters_weight_ = None | |
| res_user_12: "bf16[4096, 1536][3072, 1]cuda:0" = linear_12[(slice(None, None, None), slice(None, 1536, None))] | |
| linear_13: "bf16[4096, 1536][1536, 1]cuda:0" = torch._C._nn.linear(mul_9, l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_ads_parameters_weight_, None); l_self_modules_bitmlp_modules_fcs_modules_3_modules_linear_ads_parameters_weight_ = None | |
| getitem_20: "bf16[4096, 1536][3072, 1]cuda:0" = linear_12[(slice(None, None, None), slice(1536, None, None))]; linear_12 = None | |
| index_select_10: "bf16[4096, 1536][1536, 1]cuda:0" = torch.index_select(getitem_20, dim = 0, index = l_x_ads_to_user_map); getitem_20 = None | |
| res_ads_12: "bf16[4096, 1536][1536, 1]cuda:0" = linear_13 + index_select_10; linear_13 = index_select_10 = None | |
| output_10: "f32[4096, 1536][1536, 1]cuda:0" = torch.rms_norm(res_user_12, (1536,), l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_12 = l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_11: "f32[4096, 1536][1536, 1]cuda:0" = torch.rms_norm(res_ads_12, (1536,), l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_12 = l_self_modules_bitmlp_modules_acts_modules_3_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid_10: "f32[4096, 1536][1536, 1]cuda:0" = torch.sigmoid(output_10) | |
| mul_10: "f32[4096, 1536][1536, 1]cuda:0" = output_10 * sigmoid_10; output_10 = sigmoid_10 = None | |
| sigmoid_11: "f32[4096, 1536][1536, 1]cuda:0" = torch.sigmoid(output_11) | |
| mul_11: "f32[4096, 1536][1536, 1]cuda:0" = output_11 * sigmoid_11; output_11 = sigmoid_11 = None | |
| linear_14: "bf16[4096, 9216][9216, 1]cuda:0" = torch._C._nn.linear(mul_10, l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_shared_parameters_weight_, None); mul_10 = l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_shared_parameters_weight_ = None | |
| res_user_13: "bf16[4096, 4608][9216, 1]cuda:0" = linear_14[(slice(None, None, None), slice(None, 4608, None))] | |
| linear_15: "bf16[4096, 4608][4608, 1]cuda:0" = torch._C._nn.linear(mul_11, l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_ads_parameters_weight_, None); mul_11 = l_self_modules_bitmlp_modules_fcs_modules_4_modules_linear_ads_parameters_weight_ = None | |
| getitem_22: "bf16[4096, 4608][9216, 1]cuda:0" = linear_14[(slice(None, None, None), slice(4608, None, None))]; linear_14 = None | |
| index_select_11: "bf16[4096, 4608][4608, 1]cuda:0" = torch.index_select(getitem_22, dim = 0, index = l_x_ads_to_user_map); getitem_22 = None | |
| res_ads_13: "bf16[4096, 4608][4608, 1]cuda:0" = linear_15 + index_select_11; linear_15 = index_select_11 = None | |
| res_user_14: "f32[4096, 4608][4608, 1]cuda:0" = mul_8 + res_user_13; mul_8 = res_user_13 = None | |
| res_ads_14: "f32[4096, 4608][4608, 1]cuda:0" = mul_9 + res_ads_13; mul_9 = res_ads_13 = None | |
| output_12: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_user_14, (4608,), l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_user_modules_ln_parameters_weight_, None); res_user_14 = l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_user_modules_ln_parameters_weight_ = None | |
| output_13: "f32[4096, 4608][4608, 1]cuda:0" = torch.rms_norm(res_ads_14, (4608,), l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_ads_modules_ln_parameters_weight_, None); res_ads_14 = l_self_modules_bitmlp_modules_acts_modules_4_modules_activation_modules_ln_ads_modules_ln_parameters_weight_ = None | |
| sigmoid_12: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_12) | |
| mul_12: "f32[4096, 4608][4608, 1]cuda:0" = output_12 * sigmoid_12; output_12 = sigmoid_12 = None | |
| sigmoid_13: "f32[4096, 4608][4608, 1]cuda:0" = torch.sigmoid(output_13) | |
| mul_13: "f32[4096, 4608][4608, 1]cuda:0" = output_13 * sigmoid_13; output_13 = sigmoid_13 = None | |
| linear_16: "bf16[4096, 12288][12288, 1]cuda:0" = torch._C._nn.linear(mul_12, l_self_modules_bitmlp_modules_output_fc_modules_linear_shared_parameters_weight_, None); l_self_modules_bitmlp_modules_output_fc_modules_linear_shared_parameters_weight_ = None | |
| res_user_15: "bf16[4096, 6144][12288, 1]cuda:0" = linear_16[(slice(None, None, None), slice(None, 6144, None))] | |
| linear_17: "bf16[4096, 6144][6144, 1]cuda:0" = torch._C._nn.linear(mul_13, l_self_modules_bitmlp_modules_output_fc_modules_linear_ads_parameters_weight_, None); l_self_modules_bitmlp_modules_output_fc_modules_linear_ads_parameters_weight_ = None | |
| getitem_24: "bf16[4096, 6144][12288, 1]cuda:0" = linear_16[(slice(None, None, None), slice(6144, None, None))]; linear_16 = None | |
| index_select_12: "bf16[4096, 6144][6144, 1]cuda:0" = torch.index_select(getitem_24, dim = 0, index = l_x_ads_to_user_map); getitem_24 = None | |
| res_ads_15: "bf16[4096, 6144][6144, 1]cuda:0" = linear_17 + index_select_12; linear_17 = index_select_12 = None | |
| view_6: "bf16[4096, 48, 128][12288, 128, 1]cuda:0" = res_user_15.view(4096, -1, 128); res_user_15 = None | |
| view_7: "bf16[4096, 48, 128][6144, 128, 1]cuda:0" = res_ads_15.view(4096, -1, 128); res_ads_15 = None | |
| - matmul_10: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_ @ view_6; l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_ = view_6 = None | |
| + helion_transposed_matmul_default_3: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.ops.ads_mkl.helion_transposed_matmul.default(view_6, l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_, None, True, True); view_6 = l_self_modules_post_snn_lce_modules_fc_modules_linear_shared_parameters_weight_ = None | |
| - res_user_16: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_10[(slice(None, None, None), slice(None, 96, None))] | |
| + res_user_16: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_3[(slice(None, None, None), slice(None, 96, None))] | |
| - matmul_11: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ @ view_7; l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ = view_7 = None | |
| + matmul_7: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ @ view_7; l_self_modules_post_snn_lce_modules_fc_modules_linear_ads_parameters_weight_ = view_7 = None | |
| - getitem_26: "bf16[4096, 96, 128][24576, 128, 1]cuda:0" = matmul_10[(slice(None, None, None), slice(96, None, None))]; matmul_10 = None | |
| + getitem_26: "f32[4096, 96, 128][24576, 128, 1]cuda:0" = helion_transposed_matmul_default_3[(slice(None, None, None), slice(96, None, None))]; helion_transposed_matmul_default_3 = None | |
| - index_select_13: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_26, dim = 0, index = l_x_ads_to_user_map); getitem_26 = l_x_ads_to_user_map = None | |
| + index_select_13: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = torch.index_select(getitem_26, dim = 0, index = l_x_ads_to_user_map); getitem_26 = l_x_ads_to_user_map = None | |
| - res_ads_16: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_11 + index_select_13; matmul_11 = index_select_13 = None | |
| + res_ads_16: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = matmul_7 + index_select_13; matmul_7 = index_select_13 = None | |
| - res_user_17: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = res_user_16 + res_user; res_user_16 = res_user = None | |
| + res_user_17: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = res_user_16 + res_user; res_user_16 = res_user = None | |
| - res_ads_17: "bf16[4096, 96, 128][12288, 128, 1]cuda:0" = res_ads_16 + res_ads; res_ads_16 = res_ads = None | |
| + res_ads_17: "f32[4096, 96, 128][12288, 128, 1]cuda:0" = res_ads_16 + res_ads; res_ads_16 = res_ads = None | |
| - cat_3: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_user_17, res_user_2], dim = 1); res_user_17 = res_user_2 = None | |
| + cat_3: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_user_17, res_user_2], dim = 1); res_user_17 = res_user_2 = None | |
| - cat_4: "bf16[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_ads_17, res_ads_2], dim = 1); res_ads_17 = res_ads_2 = None | |
| + cat_4: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.cat([res_ads_17, res_ads_2], dim = 1); res_ads_17 = res_ads_2 = None | |
| res_user_18: "f32[4096, 192, 128][24576, 128, 1]cuda:0" = torch.rms_norm(cat_3, (128,), l_self_modules_ln2_modules_ln_user_parameters_weight_, None); cat_3 = l_self_modules_ln2_modules_ln_user_parameters_weight_ = None |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment