diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 44d69076fc0..0128cc8d6bb 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -604,9 +604,12 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         load_weights_vanilla_helper(module, weights)

         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -619,13 +622,23 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         fused_weight = torch.cat((q_weight, k_weight, v_weight))

         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
@@ -637,11 +650,18 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         fused_weight = torch.cat((gate_weight, up_weight))

         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
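For reviewers, a minimal standalone sketch of the reshaping that all three loaders above now share. The (out_blocks, 1, in_blocks, 1) shape below is an assumption about how modelopt fp8_pb_wo checkpoints store per-block weight scales; only the dim() == 4 guard and the squeeze(1).squeeze(-1) calls come from the diff itself, and normalize_block_scale is a hypothetical helper used purely for illustration.

import torch


def normalize_block_scale(scale: torch.Tensor) -> torch.Tensor:
    """Drop the two singleton dims an fp8_pb_wo-style scale may carry.

    Mirrors the guard added in the diff: a 4-D scale such as
    (out_blocks, 1, in_blocks, 1) becomes (out_blocks, in_blocks), while a
    scale that is already 2-D passes through unchanged.
    """
    if scale.dim() == 4:
        scale = scale.squeeze(1).squeeze(-1)
    return scale


# Hypothetical block counts, for illustration only.
scale_4d = torch.rand(16, 1, 8, 1)  # layout assumed for fp8_pb_wo checkpoints
scale_2d = torch.rand(16, 8)        # layout the sharding/concat code expects

assert normalize_block_scale(scale_4d).shape == (16, 8)
assert normalize_block_scale(scale_2d).shape == (16, 8)

Squeezing only dims 1 and -1 before load_weight_shard, rather than calling an unconditional squeeze() on the sharded or concatenated result as the old code did, keeps the scales 2-D for tensor-parallel sharding and avoids collapsing a genuine dimension that happens to have size 1.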