16 changes: 4 additions & 12 deletions tensorrt_llm/_torch/models/modeling_gemma3.py
@@ -14,7 +14,6 @@
from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from ..distributed import AllReduceParams
from ..model_config import ModelConfig
from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
@@ -105,9 +104,6 @@ def forward(
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
mrope_config: Optional[dict] = None,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
attention_mask_data: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -121,9 +117,6 @@
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=attention_mask,
mrope_config=mrope_config,
all_reduce_params=all_reduce_params,
lora_params=lora_params,
attention_window_size=self.attention_window_size,
attention_mask_data=attention_mask_data,
**kwargs)
@@ -209,7 +202,6 @@ def forward(
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
attention_mask_data: Optional[torch.Tensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:

@@ -222,14 +214,14 @@
attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
is not None else PredefinedAttentionMask.CAUSAL,
attention_mask_data=attention_mask_data,
lora_params=lora_params,
**kwargs,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states, lora_params=lora_params)
hidden_states = self.mlp(hidden_states,
lora_params=kwargs.get("lora_params", None))
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states

@@ -270,7 +262,6 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
local_attention_mask_data: Optional[torch.Tensor] = None,
global_attention_mask_data: Optional[torch.Tensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@@ -291,7 +282,7 @@ def forward(
attention_mask_data=local_attention_mask_data
if decoder_layer.self_attn.is_sliding else
global_attention_mask_data,
lora_params=lora_params,
**kwargs,
)

hidden_states = self.norm(hidden_states)
@@ -465,6 +456,7 @@ def forward(
inputs_embeds=inputs_embeds,
local_attention_mask_data=local_attention_mask_data,
global_attention_mask_data=global_attention_mask_data,
**kwargs,
)

return self.logits_processor.forward(
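Net effect of the changes above: Gemma3's attention, decoder-layer, and model forward methods no longer declare mrope_config, all_reduce_params, and lora_params as explicit keyword arguments. Optional per-request inputs now travel through **kwargs, and the decoder layer pulls lora_params back out with kwargs.get("lora_params", None) before calling the MLP. Below is a minimal, self-contained sketch of that forwarding pattern using toy stand-in modules; the class names and sizes are illustrative, not the actual TRT-LLM implementations.

# Sketch of kwargs-based forwarding of optional per-request inputs.
# ToyMLP and ToyDecoderLayer are hypothetical stand-ins for the TRT-LLM modules.
from typing import Optional

import torch
import torch.nn as nn


class ToyMLP(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                x: torch.Tensor,
                lora_params: Optional[dict] = None) -> torch.Tensor:
        # A real implementation would apply the LoRA delta when lora_params is
        # set; here we only show how the argument arrives.
        return self.proj(x)


class ToyDecoderLayer(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()
        self.mlp = ToyMLP(hidden_size)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # Optional inputs such as lora_params ride along in **kwargs instead of
        # being explicit parameters, so the layer signature stays stable as
        # optional features are added or removed.
        return self.mlp(hidden_states,
                        lora_params=kwargs.get("lora_params", None))


layer = ToyDecoderLayer(hidden_size=16)
out = layer(torch.randn(2, 16), lora_params=None)  # lora_params flows via **kwargs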
1 change: 1 addition & 0 deletions tests/integration/test_lists/qa/examples_test_list.txt
@@ -551,6 +551,7 @@ test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora
examples/test_medusa.py::test_codellama_medusa_1gpu[CodeLlama-7b-Instruct]
examples/test_medusa.py::test_mistral_medusa_1gpu[mistral-7b-v0.1]
examples/test_medusa.py::test_qwen_medusa_1gpu[qwen_7b_chat]
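The QA list entry above points at the new unit test added in the next file. One illustrative way to run just that test locally, assuming pytest is installed and the command is issued from the repository root (the QA harness may wrap this differently):

# Hypothetical direct invocation of the newly listed test via pytest's Python API.
import pytest

raise SystemExit(
    pytest.main([
        "tests/unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora",
        "-v",
    ]))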
57 changes: 57 additions & 0 deletions tests/unittest/llmapi/test_llm_pytorch.py
@@ -1,6 +1,7 @@
import pytest

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.sampling_params import SamplingParams
@@ -492,6 +493,62 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
assert len(outputs) == 2


def test_gemma3_1b_instruct_multi_lora() -> None:
model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"

target_modules = ['attn_q', 'attn_k', 'attn_v']

# Set up temporary directory for LoRA adapters
with tempfile.TemporaryDirectory() as lora_dir:
print("Creating dummy LoRAs...")

model = AutoModelForCausalLM.from_pretrained(model_dir,
torch_dtype=torch.bfloat16,
device_map="auto")
hf_modules = ["q_proj", "k_proj", "v_proj"]
peft_lora_config = PeftLoraConfig(r=8,
target_modules=hf_modules,
bias="none",
task_type="CAUSAL_LM")
lora_paths = []
for i in range(2):
lora_model = get_peft_model(model, peft_lora_config)
for param in lora_model.parameters():
param.data.zero_()
lora_path = f"{lora_dir}/lora_{i}"
lora_model.save_pretrained(lora_path)
lora_paths.append(lora_path)

trtllm_lora_config = LoraConfig(lora_dir=lora_paths,
lora_target_modules=target_modules,
max_lora_rank=8,
max_loras=2,
max_cpu_loras=2)
# Disabling kv cache reuse as a WAR to deal with gaps in kernel support for Gemma3's non-inclusive sliding window size.
kv_cache_config = KvCacheConfig(
enable_block_reuse=False,
enable_partial_reuse=False,
)
llm = LLM(model_dir,
lora_config=trtllm_lora_config,
kv_cache_config=kv_cache_config)

prompts = [
"Is it ok to fill diesel in a petrol car?",
"What is the capital of France?",
]
lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
lora_requests = [lora_req1, lora_req2]
sampling_params = SamplingParams(max_tokens=200)

outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_requests)

assert len(outputs) == 2


@pytest.mark.parametrize(
"lora_rank,max_lora_rank,description",
[