diff --git a/examples/llm-api/quickstart_advanced.py b/examples/llm-api/quickstart_advanced.py
index 65c2f5a644c..6fc6fc28f4e 100644
--- a/examples/llm-api/quickstart_advanced.py
+++ b/examples/llm-api/quickstart_advanced.py
@@ -108,6 +108,9 @@ def add_llm_args(parser):
                         default=False,
                         action='store_true',
                         help='Use piecewise CUDA graph to optimize the model')
+    parser.add_argument('--apply_chat_template',
+                        default=False,
+                        action='store_true')
 
     # Sampling
     parser.add_argument("--max_tokens", type=int, default=64)
@@ -273,6 +276,15 @@ def main():
     prompts = args.prompt if args.prompt else example_prompts
 
     llm, sampling_params = setup_llm(args)
+    new_prompts = []
+    if args.apply_chat_template:
+        for prompt in prompts:
+            messages = [{"role": "user", "content": f"{prompt}"}]
+            new_prompts.append(
+                llm.tokenizer.apply_chat_template(messages,
+                                                  tokenize=False,
+                                                  add_generation_prompt=True))
+        prompts = new_prompts
     outputs = llm.generate(prompts, sampling_params)
 
     for i, output in enumerate(outputs):
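
Note: with `--apply_chat_template`, each raw prompt is wrapped in the model's chat template before generation instead of being passed as plain text. A minimal sketch of the same transformation using Hugging Face's `AutoTokenizer` directly (the model path is a placeholder):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./Hunyuan-A13B-Instruct")  # placeholder path
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    # `prompt` now carries the model's role tags and generation marker,
    # which is what instruction-tuned checkpoints expect to see.
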
diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py
index 4173b338c22..8bd4c49bbc9 100644
--- a/tensorrt_llm/_torch/attention_backend/interface.py
+++ b/tensorrt_llm/_torch/attention_backend/interface.py
@@ -342,6 +342,7 @@ def __call__(self, position_ids: torch.Tensor, q: torch.Tensor,
 class RopeParams:
     dim: int = 0
     theta: float = 10000.0
+    alpha: float = 1.0
     scale_type: RotaryScalingType = RotaryScalingType.none
     scale: float = 1.0
     low_freq_factor: float = 1.0
@@ -384,6 +385,7 @@ def from_config(config) -> "RopeParams":
             rope_params.scale_type = RotaryScalingType.none
             rope_params.scale = 1.0
         if rope_scaling is not None:
+            rope_params.alpha = rope_scaling.get("alpha", 1.0)
             rotary_scaling_type = rope_scaling.get(
                 "type", None) or rope_scaling.get("rope_type")
             rope_params.scale_type = RotaryScalingType.from_string(
@@ -462,6 +464,7 @@ def create_rope_const_params(self, interleave: bool = True):
             self.scale_type,
             rope_scaling_config={
                 "factor": self.scale,
                "alpha": self.alpha,
                 "low_freq_factor": self.low_freq_factor,
                 "high_freq_factor": self.high_freq_factor,
                 "original_max_position_embeddings":
diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 125a637a493..0158e23858f 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -151,8 +151,11 @@ def fuse_pos_embd(self):
     @property
     def enable_flash_mla(self):
         if self.attn_backend == 'TRTLLM':
-            if hasattr(self.pretrained_config, "kv_lora_rank") and hasattr(
-                    self.pretrained_config, "qk_rope_head_dim"):
+            if hasattr(
+                    self.pretrained_config, "kv_lora_rank"
+            ) and self.pretrained_config.kv_lora_rank is not None and hasattr(
+                    self.pretrained_config, "qk_rope_head_dim"
+            ) and self.pretrained_config.qk_rope_head_dim is not None:
                 head_dim = self.pretrained_config.kv_lora_rank + self.pretrained_config.qk_rope_head_dim
                 if head_dim == 576 and torch.cuda.get_device_capability() == (
                         9, 0):
diff --git a/tensorrt_llm/_torch/models/__init__.py b/tensorrt_llm/_torch/models/__init__.py
index 6d6b12d06ab..668a6730b0a 100644
--- a/tensorrt_llm/_torch/models/__init__.py
+++ b/tensorrt_llm/_torch/models/__init__.py
@@ -8,6 +8,7 @@
 from .modeling_gemma3 import Gemma3ForCausalLM
 from .modeling_gemma3vl import Gemma3VLM
 from .modeling_gpt_oss import GptOssForCausalLM
+from .modeling_hunyuan_moe import HunYuanMoEV1ForCausalLM
 from .modeling_hyperclovax import HCXVisionForCausalLM
 from .modeling_llama import LlamaForCausalLM
 from .modeling_llava_next import LlavaNextModel
@@ -38,6 +39,8 @@
     "Gemma3ForCausalLM",
     "Gemma3VLM",
     "HCXVisionForCausalLM",
+    "HunYuanMoEV1ForCausalLM",
+    "Gemma3Model",
     "LlamaForCausalLM",
     "LlavaNextModel",
     "Mistral3VLM",
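
Note: the added `is not None` guards in `enable_flash_mla` (and in `is_mla` further below) matter because `hasattr` alone cannot distinguish "attribute missing" from "attribute explicitly set to None"; the HunYuan config presumably carries `kv_lora_rank`/`qk_rope_head_dim` keys with `None` values, which previously misclassified it as an MLA model. A minimal illustration:

    from types import SimpleNamespace

    cfg = SimpleNamespace(kv_lora_rank=None, qk_rope_head_dim=None)
    hasattr(cfg, "kv_lora_rank")    # True, even though the value is None
    cfg.kv_lora_rank is not None    # False -> correctly treated as a non-MLA model
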
diff --git a/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py
new file mode 100644
index 00000000000..f8072e71fc6
--- /dev/null
+++ b/tensorrt_llm/_torch/models/modeling_hunyuan_moe.py
@@ -0,0 +1,434 @@
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+from tqdm import tqdm
+from transformers import PretrainedConfig
+
+from tensorrt_llm._torch.distributed import AllReduceParams
+from tensorrt_llm.functional import PositionEmbeddingType
+
+from ..attention_backend import AttentionMetadata
+from ..attention_backend.interface import (PositionalEmbeddingParams,
+                                           PredefinedAttentionMask, RopeParams)
+from ..model_config import ModelConfig
+from ..modules.attention import Attention, QkNormType
+from ..modules.decoder_layer import DecoderLayer
+from ..modules.embedding import Embedding
+from ..modules.fused_moe import (CutlassFusedMoE, RenormalizeMoeRoutingMethod,
+                                 VanillaMoE, create_moe)
+from ..modules.gated_mlp import GatedMLP
+from ..modules.linear import Linear, TensorParallelMode
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
+from ..modules.rms_norm import RMSNorm
+from ..utils import AuxStreamType, Fp4QuantizedTensor
+from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
+                             duplicate_kv_weight, register_auto_model)
+
+
+class HunyuanMoE(nn.Module):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        aux_stream: torch.cuda.Stream,
+    ):
+        super().__init__()
+        config = model_config.pretrained_config
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.moe_intermediate_size = config.moe_intermediate_size[0] \
+            if isinstance(config.moe_intermediate_size, list) else config.moe_intermediate_size
+        self.num_experts = config.num_experts
+        self.top_k = config.moe_topk[0] \
+            if isinstance(config.moe_topk, list) else config.moe_topk
+        self.enable_attention_dp = model_config.mapping.enable_attention_dp
+
+        # moe gate (linear layer) only runs in half/full precision for now
+        self.gate = Linear(self.hidden_dim,
+                           self.num_experts,
+                           bias=False,
+                           dtype=config.torch_dtype)
+
+        reduce_results = True
+
+        self.experts = create_moe(
+            num_experts=self.num_experts,
+            routing_method=RenormalizeMoeRoutingMethod(top_k=self.top_k),
+            hidden_size=self.hidden_dim,
+            intermediate_size=self.moe_intermediate_size,
+            aux_stream=aux_stream,
+            dtype=config.torch_dtype,
+            reduce_results=reduce_results,
+            model_config=model_config)
+
+        self.shared_mlp = GatedMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            bias=config.mlp_bias if hasattr(config, 'mlp_bias') else False,
+            dtype=config.torch_dtype,
+            config=model_config,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        assert hidden_states.shape[-1] == self.hidden_dim
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_dim)
+
+        shared_expert_output = self.shared_mlp(hidden_states)
+        all_rank_num_tokens = attn_metadata.all_rank_num_tokens
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states,
+            router_logits,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=False)
+
+        final_hidden_states = shared_expert_output + final_hidden_states
+
+        return final_hidden_states.view(orig_shape)
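
Note: conceptually, the computation `HunyuanMoE.forward` delegates to the fused MoE kernels is a shared expert plus a renormalized top-k mixture of routed experts. A dense, unoptimized sketch (assuming `RenormalizeMoeRoutingMethod` softmaxes the router logits and renormalizes the selected top-k weights):

    import torch
    import torch.nn.functional as F

    def moe_reference(x, gate, experts, shared_mlp, top_k):
        # x: [num_tokens, hidden]; experts: list of per-expert MLP callables
        scores = F.softmax(gate(x), dim=-1)                 # [num_tokens, num_experts]
        topk_w, topk_idx = scores.topk(top_k, dim=-1)
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)  # renormalize over the selected experts
        routed = torch.zeros_like(x)
        for t in range(x.size(0)):
            for w, e in zip(topk_w[t], topk_idx[t]):
                routed[t] += w * experts[e](x[t])
        return shared_mlp(x) + routed                       # shared expert output is always added
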
+
+
+class HunYuanAttention(Attention):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: Optional[int] = None,
+        use_qk_norm: bool = True,
+        nope_layer: bool = False,
+        aux_stream: Optional[torch.cuda.Stream] = None,
+    ):
+        config = model_config.pretrained_config
+
+        self.use_rope = not nope_layer
+        pos_embd_params = PositionalEmbeddingParams(
+            type=PositionEmbeddingType.rope_gpt_neox,
+            rope=RopeParams.from_config(config),
+            is_neox=True,
+        ) if self.use_rope else None
+        self.use_qk_norm = use_qk_norm
+
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=config.attention_bias,
+            pos_embd_params=pos_embd_params,
+            qk_norm_type=QkNormType.post_rope
+            if use_qk_norm else QkNormType.none,
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            config=model_config,
+        )
+
+        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.query_layernorm = RMSNorm(hidden_size=self.head_dim,
+                                       eps=config.rms_norm_eps,
+                                       dtype=config.torch_dtype)
+        self.key_layernorm = RMSNorm(hidden_size=self.head_dim,
+                                     eps=config.rms_norm_eps,
+                                     dtype=config.torch_dtype)
+        self.aux_stream = aux_stream
+        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
+
+    def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
+                   v: Optional[torch.Tensor], position_ids: torch.Tensor):
+        q, k, v = self.split_qkv(q, k, v)
+        if position_ids is not None:
+            q, k, v = super().apply_rope(q, k, v, position_ids)
+        # Like Llama4, HunYuan applies the QK norm after RoPE.
+        if self.use_qk_norm:
+            q, k = self.apply_qk_norm(q, k)
+
+        return q, k, v
+
+    def apply_qk_norm(self, q, k):
+
+        def q_l2norm():
+            return self.query_layernorm(q.reshape(-1, self.head_dim)).reshape(
+                -1, self.q_size)
+
+        def k_l2norm():
+            return self.key_layernorm(k.reshape(-1, self.head_dim)).reshape(
+                -1, self.kv_size)
+
+        q, k = maybe_execute_in_parallel(
+            q_l2norm,
+            k_l2norm,
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
+        return q, k
+
+    def forward(
+        self,
+        position_ids: Optional[torch.IntTensor],
+        hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
+        attn_metadata: AttentionMetadata,
+        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
+        CAUSAL,
+        mrope_config: Optional[dict] = None,
+        all_reduce_params: Optional[AllReduceParams] = None,
+        lora_params: Optional[dict] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        assert lora_params is None, "LoRA is not supported for HunYuanAttention"
+        return super().forward(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            attention_mask=attention_mask,
+            mrope_config=mrope_config,
+            all_reduce_params=all_reduce_params,
+            lora_params=lora_params,
+            **kwargs,
+        )
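
Note: `QkNormType.post_rope` means the per-head RMSNorm runs on Q and K after the rotary embedding, which is why the fused-RoPE path is disabled for such layers (see the `attention.py` change below). The reshape trick in `apply_qk_norm` normalizes each head independently; a small sketch (`nn.RMSNorm` needs PyTorch >= 2.4 and stands in for TensorRT-LLM's own `RMSNorm`):

    import torch
    from torch import nn

    num_tokens, num_heads, head_dim = 4, 32, 128
    q_norm = nn.RMSNorm(head_dim)

    # q is flat per token ([num_tokens, num_heads * head_dim]) and already rotated by RoPE.
    q = torch.randn(num_tokens, num_heads * head_dim)
    q = q_norm(q.reshape(-1, head_dim)).reshape(num_tokens, -1)  # RMSNorm each head on its own
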
+
+
+class HunYuanDecoderLayer(DecoderLayer):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int, aux_stream_dict: Dict[AuxStreamType,
+                                                       torch.cuda.Stream]):
+        super().__init__()
+        config = model_config.pretrained_config
+        self.layer_idx = layer_idx
+
+        # attention
+        self.self_attn = HunYuanAttention(
+            model_config,
+            layer_idx=layer_idx,
+        )
+
+        is_experts_valid = ((isinstance(config.num_experts, int)
+                             and config.num_experts > 1)
+                            or (isinstance(config.num_experts, list)
+                                and max(config.num_experts) > 1))
+        is_moe_single_node = is_experts_valid and layer_idx >= config.moe_layer_num_skipped  # only single-node MoE is supported yet
+
+        if is_moe_single_node:
+            self.mlp = HunyuanMoE(
+                model_config,
+                aux_stream_dict[AuxStreamType.MoeChunkingOverlap])
+        else:
+            self.mlp = GatedMLP(hidden_size=config.hidden_size,
+                                intermediate_size=config.intermediate_size,
+                                bias=config.mlp_bias,
+                                dtype=config.torch_dtype,
+                                config=model_config)
+
+        norm_type = getattr(config, 'norm_type', 'rms')
+        if norm_type == 'hf_rms' or norm_type == 'rms':
+            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
+                                           eps=config.rms_norm_eps,
+                                           dtype=config.torch_dtype)
+            self.post_attention_layernorm = RMSNorm(
+                hidden_size=config.hidden_size,
+                eps=config.rms_norm_eps,
+                dtype=config.torch_dtype)
+        elif norm_type == 'fused' or norm_type == 'torch_nn':
+            self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+            self.post_attention_layernorm = nn.LayerNorm(
+                config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            assert False, "other norm_type values are not supported"
+
+    def forward(
+        self,
+        position_ids: torch.LongTensor,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            attn_metadata=attn_metadata,
+            **kwargs,
+        )
+        # Fully Connected
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states, attn_metadata)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class HunYuanModel(DecoderModel):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+        super().__init__(model_config)
+        config = model_config.pretrained_config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.num_hidden_layers = config.num_hidden_layers
+        self.aux_stream_dict = {
+            key: torch.cuda.Stream()
+            for key in [
+                AuxStreamType.Attention, AuxStreamType.MoeShared,
+                AuxStreamType.MoeChunkingOverlap
+            ]
+        }
+
+        self.embed_tokens = Embedding(
+            config.vocab_size,
+            config.hidden_size,
+            dtype=config.torch_dtype,
+            mapping=model_config.mapping,
+            tensor_parallel_mode=TensorParallelMode.COLUMN,
+            gather_output=True,
+        )
+
+        self.layers = nn.ModuleList([
+            HunYuanDecoderLayer(model_config, layer_idx, self.aux_stream_dict)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(hidden_size=config.hidden_size,
+                            eps=config.rms_norm_eps,
+                            dtype=config.torch_dtype)
+
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+
+        for layer_idx, decoder_layer in enumerate(self.layers):
+            kwargs['layer_idx'] = layer_idx
+            hidden_states = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                **kwargs,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+@register_auto_model("HunYuanMoEV1ForCausalLM")
+class HunYuanMoEV1ForCausalLM(DecoderModelForCausalLM[HunYuanModel,
+                                                      PretrainedConfig]):
+
+    def __init__(self, model_config: ModelConfig[PretrainedConfig]):
+        super().__init__(HunYuanModel(model_config),
+                         config=model_config,
+                         hidden_size=model_config.pretrained_config.hidden_size,
+                         vocab_size=model_config.pretrained_config.vocab_size)
+        self._execution_stats = None
+        print("---debug model_config: ", model_config)
+
+    def load_weights(self, weights: Dict):
+        tp_size = self.model_config.mapping.tp_size
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+        def filter_weights(prefix, weights: Dict):
+            result = {}
+            for k, v in weights.items():
+                if k.startswith(prefix):
+                    new_k = k[len(prefix) + 1:]
+                    result[new_k] = v
+            return result
+
+        params_map = {
+            'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
+            'gate_up_proj': ['gate_proj', 'up_proj']
+        }
+        for name, module in tqdm(list(self.named_modules()),
+                                 desc="Loading weights"):
+            if len(module._parameters) > 0:
+                # skip load weights if tie word embeddings is enabled and layer is lm_head
+                if self.config.tie_word_embeddings and name.startswith(
+                        "lm_head"):
+                    continue
+                names = name.split('.')
+                if names[-1] in params_map:
+                    # model.layers.{idx}.mlp.shared_mlp.gate_up_proj or model.layers.{idx}.self_attn.qkv_proj
+                    module_weights = []
+                    for new_name in params_map[names[-1]]:
+                        fw = filter_weights('.'.join(names[:-1] + [new_name]),
+                                            weights)
+                        if new_name in ['k_proj', 'v_proj']:
+                            fw = {
+                                k:
+                                duplicate_kv_weight(
+                                    weight=v[:],
+                                    num_kv_heads=v[:].shape[0] // head_dim,
+                                    tensor_parallel_size=tp_size)
+                                if k in ["weight", "bias"] else v
+                                for k, v in fw.items()
+                            }
+                        module_weights.append(fw)
+                    module.load_weights(weights=module_weights)
+                else:
+                    name = name.replace('gate', 'gate.wg')
+                    module_weights = filter_weights(name, weights)
+                    if isinstance(module, CutlassFusedMoE) or isinstance(
+                            module, VanillaMoE):
+                        # model.layers.{idx}.mlp.experts
+                        updated_module_weights = {}
+                        for weight_name, weight_value in module_weights.items():
+                            new_weight_name = weight_name.replace(
+                                "gate_proj", "w1").replace(
+                                    "up_proj", "w3").replace("down_proj", "w2")
+                            updated_module_weights[
+                                new_weight_name] = weight_value
+                        del module_weights
+                        module.load_weights(weights=[updated_module_weights])
+                    elif hasattr(module, 'load_weights'):
+                        # model.layers.{idx}.self_attn.o_proj or model.layers.{idx}.mlp.shared_mlp.down_proj
+                        # or model.layers.{idx}.mlp.experts.gate
+                        module.load_weights(weights=[module_weights])
+                    else:
+                        for n, p in module._parameters.items():
+                            if p is not None:
+                                p.data.copy_(module_weights[n][:])
+
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_context_logits: bool = False,
+        **kwargs,
+    ) -> torch.Tensor:
+        output = self.model(
+            input_ids=input_ids,
+            attn_metadata=attn_metadata,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return self.logits_processor.forward(
+            output,
+            self.lm_head,
+            attn_metadata,
+            return_context_logits,
+        )
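
Note: `load_weights` bridges the HuggingFace checkpoint layout and TensorRT-LLM's fused modules: `q_proj`/`k_proj`/`v_proj` are concatenated into `qkv_proj` (with KV heads duplicated for tensor parallelism when needed), `gate_proj`/`up_proj` into `gate_up_proj`, the router weight is read from the checkpoint's `gate.wg`, and per-expert projections are renamed to the fused MoE convention. A sketch of the expert-side rename (checkpoint-side names are assumptions based on the HF Hunyuan modeling code):

    MOE_RENAME = {"gate_proj": "w1", "up_proj": "w3", "down_proj": "w2"}

    def rename_expert_weight(name: str) -> str:
        # e.g. "0.gate_proj.weight" -> "0.w1.weight", mirroring what load_weights does
        for old, new in MOE_RENAME.items():
            name = name.replace(old, new)
        return name
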
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index bfa20eff407..8bbf814e2a6 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1,5 +1,6 @@
 import math
 import weakref
+from enum import IntEnum
 from typing import Optional, Union, cast
 
 import torch
@@ -27,6 +28,15 @@
 from .rotary_embedding import RotaryEmbedding
 
 
+class QkNormType(IntEnum):
+    """
+    The type of QK normalization.
+    """
+    none = 0  # No normalization applied to Q and K
+    pre_rope = 1  # Apply normalization before RoPE
+    post_rope = 2  # Apply normalization after RoPE
+
+
 def extract_extra_attrs(layer_idx: str, attn_type: str):
     assert attn_type in ["mla", "attn"], "Invalid attention type"
     extra_attrs = get_model_extra_attrs()
@@ -113,6 +123,7 @@ def __init__(
         dense_bias: Optional[bool] = None,
         config: Optional[ModelConfig] = None,
         q_scaling: float = 1.0,
+        qk_norm_type: QkNormType = QkNormType.none,
         attention_chunk_size: Optional[int] = None,
     ):
         """
@@ -130,6 +141,7 @@ def __init__(
             dtype (torch.dtype): The data type.
             dense_bias (Optional[bool]): Whether to use bias in the output projection layer.
             config (Optional[ModelConfig]): The model configuration.
+            qk_norm_type (QkNormType): The type of QK normalization.
             q_scaling (float): The scaling factor for the qk_scale. The definition is $O = softmax(QK^T * qk_scale) * V, qk_scale = 1 / (sqrt(head_dim) * q_scaling)$. The default value is 1.0.
             attention_chunk_size (Optional[int]): See [Chunked Attention] below.
         """
@@ -156,6 +168,7 @@ def __init__(
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = max_position_embeddings
         self.pos_embd_params = pos_embd_params
+        self.qk_norm_type = qk_norm_type
         self.dense_bias = dense_bias
         self.q_scaling = q_scaling
 
@@ -258,7 +271,8 @@ def __init__(
             self.rope_fusion = False
         # If rope_fusion is not specified, enable if the attention backend supports it.
         if self.rope_fusion is None:
-            self.rope_fusion = attn_cls.support_fused_rope()
+            self.rope_fusion = attn_cls.support_fused_rope(
+            ) and qk_norm_type != QkNormType.post_rope
 
         self.rotary_emb = None
         if not self.rope_fusion and self.pos_embd_params is not None:
@@ -430,7 +444,15 @@ def forward(
         output = None
 
         q, k, v = qkv, None, None
-        q, k, v = self.apply_rope(q, k, v, position_ids)
+        if self.qk_norm_type == QkNormType.pre_rope:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.apply_qk_norm(q, k)
+        if not self.rope_fusion and position_ids is not None:
+            q, k, v = self.split_qkv(q, k, v)
+            q, k = self.rotary_emb(position_ids, [q, k])
+        if self.qk_norm_type == QkNormType.post_rope:
+            q, k = self.apply_qk_norm(q, k)
+        #q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)
 
         # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
@@ -499,6 +521,11 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
             q, k = self.rotary_emb(position_ids, [q, k])
         return q, k, v
 
+    def apply_qk_norm(self, q, k):
+        raise NotImplementedError(
+            f"QK norm is not implemented for {self.__class__.__name__}."
+            "Please override the `apply_qk_norm` method in the subclass.")
+
 
 @torch.library.custom_op("trtllm::mla_custom_op_inplace",
                          mutates_args=("output", ))
diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py
index c0f0482674e..a68333b8fb1 100644
--- a/tensorrt_llm/_torch/pyexecutor/config_utils.py
+++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py
@@ -5,9 +5,8 @@ def is_nemotron_hybrid(config):
 
 
 def is_mla(config):
-    if hasattr(config, "kv_lora_rank"):
-        assert hasattr(
-            config, "qk_rope_head_dim"
-        ), "both of kv_lora_rank and qk_rope_head_dim are required."
+    if (hasattr(config, "kv_lora_rank") and config.kv_lora_rank is not None
+            and hasattr(config, "qk_rope_head_dim")
+            and config.qk_rope_head_dim is not None):
         return True
     return False
diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py
index 59c42d32ab4..20db397cbef 100755
--- a/tensorrt_llm/functional.py
+++ b/tensorrt_llm/functional.py
@@ -4734,6 +4734,15 @@ def create_sinusoidal_positions_for_attention_plugin(
         inv_freq = 1.0 / (theta**(np.arange(0, dim, 2) / dim)).astype(dtype)
         inv_freq = RopeEmbeddingUtils.apply_llama3_scaling(
             inv_freq, rope_scaling_config)
+    elif scale_type == RotaryScalingType.dynamic:
+        # Make sure scaling_alpha exists in rope_scaling
+        # Ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct-FP8/blob/main/modeling_hunyuan.py#L346
+        assert rope_scaling_config[
+            "alpha"] is not None, "rope_scaling_config.alpha must be provided."
+        scaling_alpha = rope_scaling_config["alpha"]
+        adjusted_base = theta * (scaling_alpha**(dim / (dim - 2)))
+        inv_freq = 1.0 / (adjusted_base**(
+            np.arange(0, dim, 2, dtype=dtype) / dim)).astype(dtype)
     else:
         inv_freq = scale / (theta
                             **(np.arange(0, dim, 2) / dim)).astype(dtype)
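
Note: the new `RotaryScalingType.dynamic` branch implements NTK-by-alpha scaling. Rather than shrinking position indices, it enlarges the RoPE base by `alpha ** (dim / (dim - 2))`, which stretches the rotation periods so longer contexts stay in-distribution. A standalone sketch of the same computation (the `alpha` value below is only illustrative; the real one comes from the checkpoint's `rope_scaling` config):

    import numpy as np

    def ntk_alpha_inv_freq(dim, theta=10000.0, alpha=1000.0):
        adjusted_base = theta * alpha**(dim / (dim - 2))
        return 1.0 / adjusted_base**(np.arange(0, dim, 2) / dim)

    # Larger alpha -> larger adjusted base -> slower-rotating high-dimension frequencies.
    print(ntk_alpha_inv_freq(dim=128)[:4])
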
diff --git a/tensorrt_llm/llmapi/tokenizer.py b/tensorrt_llm/llmapi/tokenizer.py
index c006169be7b..7e13643fb82 100644
--- a/tensorrt_llm/llmapi/tokenizer.py
+++ b/tensorrt_llm/llmapi/tokenizer.py
@@ -57,6 +57,11 @@ def decode(self, token_ids: List[int], *args, **kwargs) -> str:
     def batch_encode_plus(self, texts: List[str], *args, **kwargs) -> dict:
         return self.tokenizer.batch_encode_plus(texts, *args, **kwargs)
 
+    def get_chat_template(self,
+                          chat_template: Optional[str] = None,
+                          tools: Optional[List[Dict]] = None) -> str:
+        return self.tokenizer.get_chat_template(chat_template, tools)
+
     def apply_chat_template(
             self, conversation: Union[List[Dict[str, str]],
                                       List[List[Dict[str, str]]]], *args,
@@ -353,5 +358,8 @@ def load_hf_tokenizer(model_dir: str,
                           use_fast=use_fast,
                           **kwargs)
-    except Exception:
+    except Exception as e:
+        logger.warning(
+            f"Failed to load hf tokenizer from {model_dir}, encountered error: {e}"
+        )
         return None
 
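
Note: the tokenizer changes can also be exercised directly from the LLM API. A small sketch (the checkpoint path is a placeholder; `load_hf_tokenizer` returns `None`, now with a logged warning, if loading fails):

    from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer

    tokenizer = load_hf_tokenizer("./Hunyuan-A13B-Instruct")  # placeholder path
    if tokenizer is not None:
        print(tokenizer.get_chat_template())  # raw chat-template string
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "Hello"}],
            tokenize=False,
            add_generation_prompt=True)
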