diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index 8d3ee666b4e..55bcac82bde 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -1,6 +1,6 @@ import copy import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from PIL.Image import Image @@ -978,7 +978,10 @@ def __init__(self, self.tokenizer = tokenizer self.vocab_size = model_config.text_config.vocab_size self.image_token_index = model_config.image_token_index - + self.fake_image_token = self.processor.fake_image_token + self.image_token = self.processor.img_patch_token + self.image_token_start_index = self.model_config.boi_token_index + self.image_token_end_index = self.model_config.eoi_token_index self.encoder = nn.ModuleDict({ "vision_model": Llama4VisionModel(model_config.vision_config), @@ -987,6 +990,134 @@ def __init__(self, }).cuda() load_sharded_checkpoint(self.encoder, model_path, strict=False) + def attach_multimodal_embeddings( + self, inputs: TextPrompt, multimodal_embedding: Dict[str, + List[Dict[str, + Any]]], + sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """ + Attach pre-processed multimodal embeddings into text token stream for Llama4 model. + + This method skips vision processing and works with externally provided embeddings. + It replaces/expands image placeholders in the text with appropriate tokens and prepares + the embeddings for model forward pass. + + Args: + inputs: Text prompt containing image placeholders + multimodal_embedding: Dictionary containing pre-processed image embedding data with special token information. + Consider adding metadata fields (e.g., model_type, model_name, version) for validation. 
+ Returns: + Tuple of (token_ids, extra_processed_inputs) where: + - token_ids: List of processed token IDs with image placeholders + - extra_processed_inputs: Optional dictionary containing multimodal embeddings + """ + text_prompt = inputs.get("prompt") + if not text_prompt: + raise ValueError("Text prompt is required but not provided") + + if not isinstance(multimodal_embedding, dict): + raise ValueError("multimodal_embedding must be a dictionary") + + if 'image' not in multimodal_embedding: + raise ValueError( + "Only image modality is supported for external multimodal embedding" + ) + + mm_embedding_info = multimodal_embedding['image'] + if not mm_embedding_info or not isinstance(mm_embedding_info[0], dict): + raise ValueError( + "Llama4 image embedding must contain special token information") + + # Extract embedding components + try: + mm_embeddings = [ + mm_embedding['mm_embeddings'] + for mm_embedding in mm_embedding_info + ] + mm_embedding_special_tokens = [ + mm_embedding['image_special_tokens'] + for mm_embedding in mm_embedding_info + ] + mm_embedding_special_offsets = [ + mm_embedding['image_special_token_offsets'] + for mm_embedding in mm_embedding_info + ] + except KeyError as e: + raise ValueError( + f"Missing required key in multimodal embedding: {e}") + + # Validate embedding dimensions + model_hidden_size = self.model_config.text_config.hidden_size + for i, embedding in enumerate(mm_embeddings): + if embedding.shape[-1] != model_hidden_size: + raise ValueError( + f"Multimodal embedding {i} hidden size {embedding.shape[-1]} " + f"must match model hidden size {model_hidden_size}") + + # Count image placeholders (number of images) in the prompt + total_placeholders = text_prompt.count(self.fake_image_token) + if total_placeholders == 0: + raise ValueError( + "No image placeholders found in the prompt, but multimodal embedding was provided" + ) + + if total_placeholders != len(mm_embeddings): + raise ValueError( + f"Number of image placeholders ({total_placeholders}) " + f"does not match number of embeddings ({len(mm_embeddings)})") + + # Process prompt with image embeddings + prompt_splits = text_prompt.split(self.fake_image_token) + new_prompt_parts = [] + + for local_image_index, split_part in enumerate(prompt_splits): + new_prompt_parts.append(split_part) + + if local_image_index < total_placeholders: + # Calculate total tokens for this image + num_tokens = len(mm_embeddings[local_image_index]) + len( + mm_embedding_special_tokens[local_image_index]) + + # Create image token sequence + image_tokens = [self.image_token] * num_tokens + + # Replace special tokens with actual decoded tokens + for offset, token_id in zip( + mm_embedding_special_offsets[local_image_index], + mm_embedding_special_tokens[local_image_index]): + if offset < 0 or offset >= len(image_tokens): + raise ValueError( + f"Image special token offset {offset} is out of range with the total image tokens length {len(image_tokens)}" + ) + if offset < len(image_tokens): + image_tokens[offset] = self.tokenizer.decode([token_id]) + + # Join tokens without spaces + image_str = "".join(image_tokens) + new_prompt_parts.append(image_str) + + # Combine all parts and tokenize + processed_text = "".join(new_prompt_parts) + kwargs = {} + if sampling_params.truncate_prompt_tokens is not None: + kwargs = dict(truncation=True, + max_length=sampling_params.truncate_prompt_tokens) + text_inputs = self.tokenizer( + processed_text, + return_tensors="pt", + add_special_tokens=sampling_params.add_special_tokens, + **kwargs) + 
token_ids = text_inputs.input_ids.squeeze() + + # Replace image token indices with out-of-vocabulary tokens + token_ids[token_ids == self.image_token_index] = self.vocab_size + 1 + # Concatenate all multimodal embeddings + multimodal_data = {} + multimodal_data["multimodal_embedding"] = torch.cat(mm_embeddings, + dim=0) + return token_ids.tolist(), {"multimodal_data": multimodal_data} + @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index b85077f0d6a..c8abead3b33 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -1,6 +1,6 @@ import copy import os -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -195,6 +195,48 @@ def _postprocess(self, input_ids, mm_features): mm_features = mm_features.view(-1, mm_features.shape[-1]) return fused_input_ids, mm_features + def attach_multimodal_embeddings( + self, inputs: TextPrompt, + multimodal_embedding: Dict[str, List[torch.Tensor]], + sampling_params: SamplingParams + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + """ + Attach pre-processed multimodal embeddings into text token stream for LlavaNext model. + + This method skips vision processing and works with externally provided embeddings. + It replaces/expands image placeholders in the text with appropriate tokens and prepares + the embeddings for model forward pass. + + Args: + inputs: Text prompt containing image placeholders + multimodal_embedding: Dictionary containing pre-processed image embedding data + Returns: + Tuple of (token_ids, extra_processed_inputs) where: + - token_ids: List of processed token IDs with image placeholders + - extra_processed_inputs: Optional dictionary containing multimodal embeddings + """ + text_prompt = inputs.get("prompt") + if not text_prompt: + raise ValueError("Text prompt is required but not provided") + + if not isinstance(multimodal_embedding, dict): + raise ValueError("multimodal_embedding must be a dictionary") + + if 'image' not in multimodal_embedding: + raise ValueError( + "Only image modality is supported for external multimodal embedding" + ) + + input_ids = self.tokenizer( + text_prompt, return_tensors="pt").input_ids[0].to(self.device) + mm_features = torch.stack(multimodal_embedding['image']) + fused_input_ids, mm_features = self._postprocess(input_ids, mm_features) + multimodal_data = {} + multimodal_data["multimodal_embedding"] = mm_features + return fused_input_ids.to(torch.int32).tolist(), { + "multimodal_data": multimodal_data + } + @torch.inference_mode() def __call__( self, inputs: TextPrompt, sampling_params: SamplingParams diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 6ebd7adc03d..577325bfcf8 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -500,6 +500,17 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: + # Convert back to tensor, as opposite to `to_handle` in `llm.generate_async` + # for values with non-selected keys, it's no-op + request.multimodal_params.to_tensor( + "multimodal_data", key="multimodal_embedding") + embedding = request.multimodal_params.multimodal_data.get( + "multimodal_embedding") + if embedding 
is not None and embedding.is_cuda: + # make sure the embedding resides on the local device + request.multimodal_params.multimodal_data[ + "multimodal_embedding"] = embedding.to("cuda") + executor_request.py_multimodal_data = request.multimodal_params.multimodal_data if self._is_pytorch_backend and request.sampling_params.logits_processor: diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index 19d55ae7744..93689065879 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -234,6 +234,152 @@ def _to_device( f"MultimodalParams: Unsupported element '{element}' to move to device. " f"Supported elements: 'multimodal_data', 'multimodal_input'") + def to_handle(self, element: str, key: Optional[str] = None) -> None: + """Convert multimodal data to tensor handle. + + Converts torch.Tensor objects to SharedTensorContainer handles (serializable dictionaries) + for efficient IPC. This function is a in-place operation. + + Args: + element: Element to convert ("multimodal_data" or "multimodal_input") + key: Specific key to convert. If None, converts all tensor values in multimodal_data. + Defaults to None. + + Example: + # Convert all tensors in multimodal_data to handles + params.to_handle("multimodal_data", key=None) + + # Convert only multimodal_embedding section tensors to handles + params.to_handle("multimodal_data", key="multimodal_embedding") + """ + # Lazy import to avoid circular dependency + from tensorrt_llm._torch.shared_tensor import SharedTensorContainer + + def _to_tensor_handle(data): + for k, v in data.items(): + if isinstance(v, torch.Tensor): + # Convert tensor to handle + handle = SharedTensorContainer.from_tensor(v).dump_to_dict() + data[k] = handle + elif isinstance(v, dict): + _to_tensor_handle(v) + elif isinstance(v, list): + for i, item in enumerate(v): + if isinstance(item, torch.Tensor): + handle = SharedTensorContainer.from_tensor( + item).dump_to_dict() + v[i] = handle + + if element == "multimodal_data": + if self.multimodal_data is None: + return + if key is None: + _to_tensor_handle(self.multimodal_data) + else: + if key not in self.multimodal_data: + return # no-op if key not found + + value = self.multimodal_data[key] + if isinstance(value, torch.Tensor): + handle = SharedTensorContainer.from_tensor( + value).dump_to_dict() + self.multimodal_data[key] = handle + elif isinstance(value, dict): + _to_tensor_handle(value) + else: + raise ValueError( + f"Unsupported value type for multimodal_data: {type(value)}" + ) + elif element == "multimodal_input": + # No-op for multimodal_input + return + else: + raise ValueError( + f"Unsupported element '{element}' to convert to handle.") + + def to_tensor(self, element: str, key: Optional[str] = None) -> None: + """Convert multimodal tensor handles back to tensors. This is the dual operation to to_handle. + + Converts SharedTensorContainer handles (serializable dictionaries) back to torch.Tensor objects + for local computation. This function performs in-place modifications to the multimodal_data. + + Args: + element: Element to convert ("multimodal_data" or "multimodal_input") + key: Specific key to convert. If None, converts all tensor handles in multimodal_data. + Defaults to None. 
+ + Example: + # Convert all handles back to tensors + params.to_tensor("multimodal_data", key=None) + + # Convert only multimodal_embedding section handles back to tensors + params.to_tensor("multimodal_data", key="multimodal_embedding") + """ + # Lazy import to avoid circular dependency + from tensorrt_llm._torch.shared_tensor import SharedTensorContainer + + def _to_tensor(data): + for k, v in data.items(): + if isinstance(v, dict) and 'method_key' in v: + # This is a tensor handle (dict with method_key) + try: + tensor = SharedTensorContainer.from_dict( + v).get_local_view() + data[k] = tensor + except Exception as e: + raise ValueError( + f"Failed to convert handle to tensor for key '{k}': {e}" + ) + elif isinstance(v, dict): + _to_tensor(v) + elif isinstance(v, list): + for i, item in enumerate(v): + if isinstance(item, dict) and 'method_key' in item: + try: + tensor = SharedTensorContainer.from_dict( + item).get_local_view() + v[i] = tensor + except Exception as e: + raise ValueError( + f"Failed to convert handle to tensor in list at index {i}: {e}" + ) + + if element == "multimodal_data": + if self.multimodal_data is None: + return + + if key is None: + _to_tensor(self.multimodal_data) + else: + if key not in self.multimodal_data: + return # no-op if key not found + + value = self.multimodal_data[key] + if isinstance( + value, dict + ) and 'method_key' in value: # This is a tensor handle + try: + tensor = SharedTensorContainer.from_dict( + value).get_local_view() + self.multimodal_data[key] = tensor + except Exception as e: + raise ValueError( + f"Failed to convert handle to tensor for key '{key}': {e}" + ) + elif isinstance(value, dict): + _to_tensor(value) + else: + raise ValueError( + f"Unsupported value type for multimodal_data: {type(value)}" + ) + + elif element == "multimodal_input": + # No-op for multimodal_input + return + else: + raise ValueError( + f"Unsupported element '{element}' to convert to tensor.") + def strip_for_context(self) -> None: """Strip multimodal data for context processing. 
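Note on the `to_handle`/`to_tensor` pair added above: `llm.generate_async` serializes the `multimodal_embedding` tensor into a `SharedTensorContainer` handle before the request crosses the IPC boundary, and the executor worker rebuilds a local tensor from that handle. The sketch below is not part of the diff; it runs both ends in a single process with a CPU tensor (mirroring `test_share_multiparams.py`, since CUDA-IPC handles are meant for cross-process use) just to make the round-trip concrete.

```python
# Single-process sketch of the handle round-trip (illustration only).
import torch

from tensorrt_llm.inputs.multimodal import MultimodalParams

params = MultimodalParams()
params.multimodal_data = {"multimodal_embedding": torch.randn(16, 4096)}

# Sender side (llm.generate_async): tensor -> serializable handle dict.
params.to_handle("multimodal_data", key="multimodal_embedding")
assert isinstance(params.multimodal_data["multimodal_embedding"], dict)

# Receiver side (executor worker): handle dict -> local tensor view.
params.to_tensor("multimodal_data", key="multimodal_embedding")
restored = params.multimodal_data["multimodal_embedding"]
assert isinstance(restored, torch.Tensor) and restored.shape == (16, 4096)
```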
diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index a6b984b330e..a4711b9079e 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -322,7 +322,7 @@ class ConversationMessage(TypedDict): """Type definition for conversation message structure.""" role: str content: List[dict[str, Any]] - media: List[MultimodalData] + media: List[MultimodalData] | List[torch.Tensor] | List[Dict[str, Any]] # @classmethod # def fromSample(cls, sample: dict[str, str]) -> "ConversationMessage": @@ -477,24 +477,44 @@ def default_multimodal_input_loader( model_type: str, modality: str, prompts: List[str], - media: Union[List[str], List[List[str]]], + media: Optional[Union[List[str], List[List[str]]]] = None, image_data_format: str = "pt", num_frames: int = 8, + mm_embeddings: Optional[Union[List[torch.Tensor], + List[List[torch.Tensor]]]] = None, device: str = "cpu") -> List[dict[str, Union[str, torch.Tensor]]]: - def convert_to_conversation_message(prompt: str, media: Union[str, - List[str]], - modality: str) -> ConversationMessage: + def convert_to_conversation_message( + prompt: str, + media: Union[Any, List[Any]], + modality: str, + is_embedding: bool = False, + ) -> ConversationMessage: if isinstance(media, str): media = [media] if modality in ["image", "multiple_image"]: - mm_data = [ - MultimodalData(modality="image", - data=load_image(i, - format=image_data_format, - device=device)) for i in media - ] + if is_embedding: + # each mm_embedding corresponds to each image placeholder + if not isinstance(media, list): + media = [media] + + mm_data = [{ + 'modality': modality, + 'mm_embedding_info': mm + } for mm in media] + else: + mm_data = [ + MultimodalData(modality=modality, + data=load_image(i, + format=image_data_format, + device=device)) + for i in media + ] elif modality == "video": + if is_embedding: + raise ValueError( + "External embedding is not supported for video modality yet." + ) mm_data = [ MultimodalData(modality=modality, data=load_video(i, @@ -503,11 +523,19 @@ def convert_to_conversation_message(prompt: str, media: Union[str, device=device)) for i in media ] elif modality == "audio": + if is_embedding: + raise ValueError( + "External embedding is not supported for audio modality yet." + ) mm_data = [ MultimodalData(modality=modality, data=load_audio(i, device=device)) for i in media ] elif modality == "image_audio": + if is_embedding: + raise ValueError( + "External embedding is not supported for image_audio modality yet." + ) # Use different load_xxx functions to match the modality. mm_data = [] for m in media: @@ -543,12 +571,18 @@ def convert_to_conversation_message(prompt: str, media: Union[str, raise ValueError(f"Unknown modality: {modality}") return ConversationMessage(role="user", content=prompt, media=mm_data) - if len(media) > len(prompts) and len(prompts) == 1: + assert media is not None or mm_embeddings is not None, "Either media or mm_embeddings must be provided." + assert media is None or mm_embeddings is None, "Either media or mm_embeddings must be provided, not both." 
+ media_or_embeddings = media if media is not None else mm_embeddings + is_embedding = mm_embeddings is not None + + if len(media_or_embeddings) > len(prompts) and len(prompts) == 1: # 1 prompt + N media assert not isinstance( - media[0], list) # media cannot be a list of lists in this case - media = [media] - assert len(media) == len(prompts) + media_or_embeddings[0], + list) # media cannot be a list of lists in this case + media_or_embeddings = [media_or_embeddings] + assert len(media_or_embeddings) == len(prompts) if tokenizer is None and model_type not in HF_CHAT_TEMPLATE_EXCEPTIONS: tokenizer = ModelLoader.load_hf_tokenizer(model_dir, use_fast=True) @@ -560,11 +594,20 @@ def convert_to_conversation_message(prompt: str, media: Union[str, trust_remote_code=True) inputs = [] - for prompt, media in zip(prompts, media): - conv = convert_to_conversation_message(prompt, media, modality) + for prompt_idx, (prompt, + media) in enumerate(zip(prompts, media_or_embeddings)): + conv = convert_to_conversation_message(prompt, media, modality, + is_embedding) mm_data_tracker = MultimodalDataTracker(model_type) for mdata in conv["media"]: - mm_data_tracker.add_data(mdata["modality"], mdata["data"]) + # Check if mdata is a MultimodalData + if isinstance(mdata, + dict) and "modality" in mdata and "data" in mdata: + mm_data_tracker.add_data(mdata["modality"], mdata["data"]) + else: + # Add embeddings to the tracker for placeholder handling + mm_data_tracker.add_data(mdata["modality"], + mdata["mm_embedding_info"]) mm_placeholder_counts = mm_data_tracker.placeholder_counts() prompt = conv["content"] if mm_placeholder_counts: @@ -579,7 +622,12 @@ def convert_to_conversation_message(prompt: str, media: Union[str, mm_placeholder_counts=mm_placeholder_counts) input = {"prompt": prompt} if mm_placeholder_counts: - input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync() + if mm_embeddings is not None: + input[ + "multi_modal_embeddings"] = mm_data_tracker.retrieve_all_sync( + ) + else: + input["multi_modal_data"] = mm_data_tracker.retrieve_all_sync() inputs.append(input) return inputs diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 0d1a1e80201..64a9b725730 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -342,10 +342,10 @@ def generate_async( inputs = prompt_inputs(inputs) - if not inputs.get("prompt") and inputs.get( - "prompt_token_ids") and inputs.get( - "multi_modal_data") and not isinstance( - self.input_processor, DefaultInputProcessor): + if not inputs.get("prompt") and inputs.get("prompt_token_ids") and ( + inputs.get("multi_modal_data") + or inputs.get("multi_modal_embeddings")) and not isinstance( + self.input_processor, DefaultInputProcessor): # VLMs need to process/tokenize the prompt in their own way prompt = self.tokenizer.decode(inputs['prompt_token_ids']) inputs = TextPrompt( @@ -378,6 +378,10 @@ def generate_async( with nvtx_range_debug("input_processor_with_hash"): prompt_token_ids, extra_processed_inputs = input_processor_with_hash( inputs, sampling_params) + elif 'multi_modal_embeddings' in inputs: + mm_embedding_info = inputs['multi_modal_embeddings'] + prompt_token_ids, extra_processed_inputs = self.input_processor.attach_multimodal_embeddings( + inputs, mm_embedding_info, sampling_params) else: with nvtx_range_debug("input_processor"): prompt_token_ids, extra_processed_inputs = self.input_processor( @@ -391,6 +395,10 @@ def generate_async( 'multimodal_input'), multimodal_data=extra_processed_inputs.get( 'multimodal_data')) + # 
Convert to shared tensor handle to reduce IPC overhead + # for values with non-selected keys, it's no-op + multimodal_params.to_handle("multimodal_data", + key="multimodal_embedding") # Only pass it if it has content if not multimodal_params.has_content(): multimodal_params = None diff --git a/tests/unittest/_torch/multimodal/test_share_multiparams.py b/tests/unittest/_torch/multimodal/test_share_multiparams.py new file mode 100644 index 00000000000..d4ce40f6332 --- /dev/null +++ b/tests/unittest/_torch/multimodal/test_share_multiparams.py @@ -0,0 +1,94 @@ +import unittest + +import torch + +from tensorrt_llm.inputs.multimodal import MultimodalInput, MultimodalParams + + +class TestMultimodalParamsHandleConversion(unittest.TestCase): + """Test cases for to_handle and to_tensor methods in MultimodalParams.""" + + def setUp(self): + """Set up test fixtures.""" + # Create sample cpu tensors for testing (shared cuda tensor using cudaIPC only works between processes) + self.mm_embedding = torch.randn(3, 4, 5) + self.mrope_config = { + "mrope_rotary_cos_sin": torch.randn(2, 3), + "mrope_position_deltas": torch.randn(5), + } + self.image = { + "pixel_values": torch.randn(1, 3, 224, 224), + "image_height": [224], + "image_width": [224], + } + # Create sample multimodal data structure + self.sample_multimodal_data = { + "multimodal_embedding": self.mm_embedding, + "mrope_config": self.mrope_config, + "image": self.image, + } + + def test_to_handle_none_multimodal_data(self): + """Test to_handle with None multimodal_data.""" + params = MultimodalParams() + params.multimodal_data = None + + params.to_handle("multimodal_data") + self.assertIsNone(params.multimodal_data) + params.multimodal_data = {} + params.to_handle("multimodal_data") + self.assertEqual(params.multimodal_data, {}) + + params = MultimodalParams() + multimodal_input = MultimodalInput( + multimodal_hashes=[[1, 2, 3, 4, 5, 6, 7, 8]] * 2, + multimodal_positions=[0, 10], + multimodal_lengths=[2, 2]) + params.multimodal_input = multimodal_input + params.to_handle("multimodal_input") + self.assertEqual(params.multimodal_input, multimodal_input) + + def test_to_tensor_basic_handle(self): + """Test converting a basic handle back to tensor.""" + params = MultimodalParams() + params.multimodal_data = {"multimodal_embedding": self.mm_embedding} + + # Convert to handle + params.to_handle("multimodal_data", key="multimodal_embedding") + # Convert back to tensor + params.to_tensor("multimodal_data", key="multimodal_embedding") + + result = params.multimodal_data["multimodal_embedding"] + self.assertIsInstance(result, torch.Tensor) + self.assertTrue(torch.allclose(result, self.mm_embedding)) + + def test_to_tensor_all_handles(self): + """Test that to_handle followed by to_tensor preserves data integrity.""" + params = MultimodalParams() + params.multimodal_data = self.sample_multimodal_data.copy() + + params.to_handle("multimodal_data", key=None) + params.to_tensor("multimodal_data", key=None) + + self.assertTrue( + torch.allclose(params.multimodal_data["multimodal_embedding"], + self.mm_embedding)) + self.assertTrue( + torch.allclose( + params.multimodal_data["mrope_config"]["mrope_rotary_cos_sin"], + self.mrope_config["mrope_rotary_cos_sin"])) + self.assertTrue( + torch.allclose( + params.multimodal_data["mrope_config"]["mrope_position_deltas"], + self.mrope_config["mrope_position_deltas"])) + self.assertTrue( + torch.allclose(params.multimodal_data["image"]["pixel_values"], + self.image["pixel_values"])) + 
self.assertEqual(params.multimodal_data["image"]["image_height"], + self.image["image_height"]) + self.assertEqual(params.multimodal_data["image"]["image_width"], + self.image["image_width"]) + + +if __name__ == "__main__": + unittest.main()
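End-to-end context for the pieces above: `default_multimodal_input_loader` can now be fed precomputed features via `mm_embeddings`, which land in the prompt dict under `multi_modal_embeddings`; `generate_async` then routes that key to the input processor's `attach_multimodal_embeddings` instead of running the vision encoder. Below is a hedged sketch for a LLaVA-NeXT style model; the checkpoint path, placeholder string, and embedding shape are assumptions for illustration, not taken from this PR.

```python
# Illustrative only: drive the external-embedding path through the LLM API by
# building the prompt dict directly with a `multi_modal_embeddings` entry
# (the same key default_multimodal_input_loader now emits).
import torch

from tensorrt_llm import LLM, SamplingParams

# Assumptions: a LLaVA-NeXT checkpoint, and an image-feature tensor computed
# offline with the matching vision tower / multimodal projector.
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
image_embedding = torch.load("image_embedding.pt")  # assumed [num_image_tokens, hidden_size]

inputs = [{
    # The prompt keeps the model's usual image placeholder; for LlavaNext,
    # attach_multimodal_embeddings tokenizes it and fuses the provided
    # features in _postprocess rather than calling the vision encoder.
    "prompt": "USER: <image>\nDescribe the picture. ASSISTANT:",
    "multi_modal_embeddings": {"image": [image_embedding]},
}]

outputs = llm.generate(inputs, SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```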