diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index b84b345f827..232d2ccecd6 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -299,48 +299,6 @@ def get_bindings_model_config(self,
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)
 
-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
 
         model_config_cpp = ModelConfigCpp(
@@ -361,9 +319,18 @@ def get_bindings_model_config(self,
         else:
             model_config_cpp.tokens_per_block = tokens_per_block
 
-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
 
         mlp_hidden_size = None
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index f6eb8de4abc..88b86b03f54 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -451,18 +451,16 @@ def create_py_executor_instance(
 
     num_experts = _try_infer_num_experts(model_engine.model.model_config)
 
-    num_attn_layers = model_binding_config.num_attention_layers()
-    per_layer_kv_heads = [
-        model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-    ]
-    num_kv_attention_heads = max(per_layer_kv_heads)
-    if len(set(per_layer_kv_heads)) > 1:
-        # NOTE: This code-path is currently untested and not validated. Can fail!
-        # This support is tracked in TRTLLM-6561
+    num_kv_attention_heads_per_layer = model_binding_config.num_kv_heads_per_layer
+    if max(num_kv_attention_heads_per_layer) != min(
+            num_kv_attention_heads_per_layer):
         logger.warning(
-            f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-            "This code-path is currently untested and not validated. May fail!"
+            "Non-uniform KV heads per layer are not supported with LoRA; using the max number of KV heads across layers"
         )
+        num_kv_attention_heads = max(num_kv_attention_heads_per_layer)
+    else:
+        # all layers have the same number of KV heads
+        num_kv_attention_heads = num_kv_attention_heads_per_layer[0]
 
     lora_modules = LoraModule.create_lora_modules(
         lora_module_names=lora_config.lora_target_modules,
diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py
index f9e636ec678..c9e53286908 100644
--- a/tests/unittest/llmapi/test_llm_pytorch.py
+++ b/tests/unittest/llmapi/test_llm_pytorch.py
@@ -350,7 +350,6 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():
 
 # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
 # https://jirasw.nvidia.com/browse/TRTLLM-5045
-@pytest.mark.skip(reason="https://nvbugs/5401210")
 @skip_gpu_memory_less_than_138gb
 def test_nemotron_nas_lora() -> None:
     lora_config = LoraConfig(lora_dir=[