diff --git a/tensorrt_llm/_torch/modules/embedding.py b/tensorrt_llm/_torch/modules/embedding.py
index 94a088c5d2c..6a217ed39c0 100644
--- a/tensorrt_llm/_torch/modules/embedding.py
+++ b/tensorrt_llm/_torch/modules/embedding.py
@@ -123,7 +123,7 @@ def get_masked_input_and_mask(
 
 
 # We use torch.compile() to fuse the tiny pointwise ops before all_reduce/all_gather for Embedding module.
-@torch.compile(mode="max-autotune-no-cudagraphs")
+@torch.compile(options={"max-autotune": True})
 def pre_comm_embedding_ops(
     input_: torch.Tensor,
     weight: torch.Tensor,
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index a76a6792f82..7b5fd7dd441 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -335,7 +335,7 @@ def __init__(self, spec_config: Eagle3Config, mapping: Mapping):
         self.max_draft_tokens = self.spec_config.max_draft_tokens
         self.mapping = mapping
 
-    @torch.compile(mode="max-autotune-no-cudagraphs")
+    @torch.compile(options={"max-autotune": True})
     def forward(self, input_ids, position_ids, hidden_states, logits,
                 attn_metadata, spec_metadata, draft_model):
         batch_size = attn_metadata.num_seqs
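
Not part of the patch: below is a minimal, standalone sketch contrasting the two decorator forms the diff swaps. The toy functions `toy_pointwise_old`/`toy_pointwise_new` and the tensor shapes are illustrative assumptions, not code from the repository. The preset mode string `"max-autotune-no-cudagraphs"` enables max-autotune tuning while keeping CUDA graphs off, whereas the `options` dict sets individual inductor flags, here only max-autotune, leaving the other knobs at their defaults.

```python
import torch


# Old form: a preset mode string; enables max-autotune tuning with CUDA graphs disabled.
@torch.compile(mode="max-autotune-no-cudagraphs")
def toy_pointwise_old(x: torch.Tensor) -> torch.Tensor:
    # Tiny pointwise ops, stand-ins for the fused pre-communication math.
    return torch.relu(x) * 2.0 + 1.0


# New form: per-option inductor settings; only max-autotune is turned on here.
@torch.compile(options={"max-autotune": True})
def toy_pointwise_new(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2.0 + 1.0


if __name__ == "__main__":
    x = torch.randn(8, 16)
    # Both compiled variants should match eager results; only the tuning knobs differ.
    torch.testing.assert_close(toy_pointwise_old(x), toy_pointwise_new(x))
```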