diff --git a/examples/multimodal_disaggregated/README.md b/examples/multimodal_disaggregated/README.md new file mode 100755 index 00000000000..50eba11f7bd --- /dev/null +++ b/examples/multimodal_disaggregated/README.md @@ -0,0 +1,174 @@ +# Multimodal Disaggregated Serving (Experimental) + +This example demonstrates how to set up disaggregated multimodal serving with TensorRT-LLM, where the vision encoder and language model decoder run as separate services for improved scalability and resource utilization. + +## ⚠️ Disclaimer + +**This is a Proof-of-Concept (POC) and early demonstration with several limitations:** +1. **Model Support**: Limited to LLaVA-Next models only +2. **Modality Support**: Image modality only (no video support yet) +3. **Server Configuration**: Only supports 1 encoder server and 1 LLM server (though the LLM server can have multiple workers via tensor parallelism) + +## Overview + +Disaggregated multimodal serving separates the multimodal pipeline into distinct components: + +- **Encoder Server**: Handles vision processing (images), including pre-processing and encoding, using the multimodal encoder +- **LLM Decoder Server**: Processes text generation using the language model +- **Disaggregated Server**: Orchestrates requests between encoder and decoder services + +This architecture enables better resource utilization and scalability by allowing independent scaling of vision and language processing components. + +## Setup Instructions + +### Step 1: Prepare Configuration Files + +Create the required configuration files in your working directory: + +#### LLM API Configuration (`extra-llm-api-config.yml`) +```bash +# Note: Current multimodal implementation does not support KV cache reuse, +# so we disable it for all cases +cat > ./extra-llm-api-config.yml << EOF +kv_cache_config: + enable_block_reuse: false +EOF +``` + +#### Disaggregated Server Configuration (`disagg_config.yaml`) +```bash +cat > ./disagg_config.yaml << EOF +hostname: localhost +port: 8000 +backend: pytorch +multimodal_servers: + num_instances: 1 + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + urls: + - "localhost:8002" +EOF +``` + +### Step 2: Start the Encoder Server + +Launch the multimodal encoder server on GPU 0: + +```bash +mkdir -p Logs/ +CUDA_VISIBLE_DEVICES=0 trtllm-serve encoder llava-hf/llava-v1.6-mistral-7b-hf \ + --host localhost \ + --port 8001 \ + --backend pytorch \ + &> Logs/log_encoder_0 & +``` + +### Step 3: Start the LLM Decoder Server + +Launch the language model decoder server on GPU 1: + +```bash +CUDA_VISIBLE_DEVICES=1 trtllm-serve llava-hf/llava-v1.6-mistral-7b-hf \ + --host localhost \ + --port 8002 \ + --backend pytorch \ + --extra_llm_api_options ./extra-llm-api-config.yml \ + &> Logs/log_pd_tp1 & +``` + +### Step 4: Start the Disaggregated Orchestrator + +Launch the disaggregated server that coordinates between encoder and decoder: + +```bash +trtllm-serve disaggregated_mm -c disagg_config.yaml &> Logs/log_disagg_server & +``` + +## Alternative Setup + +Instead of running Steps 2-4 manually, you can start all services at once using the provided script: + +```bash +./start_disagg_mm.sh +``` + +This script will start the encoder server, LLM decoder server, and disaggregated orchestrator automatically with the same configuration as the manual steps above. 
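+## Verifying the Services
+
+Before sending requests, wait for all three services to finish loading their models. The snippet below is a minimal readiness check; it assumes the servers expose the standard `trtllm-serve` `/health` endpoint, so if yours do not, check the log files under `Logs/` instead.
+
+```bash
+# Poll the encoder (8001), the decoder (8002), and the orchestrator (8000) until each responds.
+for port in 8001 8002 8000; do
+  until curl -sf "http://localhost:${port}/health" > /dev/null; do
+    echo "Waiting for the service on port ${port} ..."
+    sleep 5
+  done
+  echo "Service on port ${port} is ready."
+done
+```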
+ +## Multi-GPU Decoder Configuration + +For larger models and higher throughput, you can run the decoder with tensor parallelism (TP>1) across multiple GPUs: + +```bash +CUDA_VISIBLE_DEVICES=1,2 trtllm-serve llava-hf/llava-v1.6-mistral-7b-hf \ + --host localhost \ + --port 8002 \ + --backend pytorch \ + --tp_size 2 \ + --extra_llm_api_options ./extra-llm-api-config.yml \ + &> Logs/log_pd_tp2 & +``` + +## Testing the Setup + +### Basic Functionality Test + +Test the setup with a multimodal chat completion request: + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llava-hf/llava-v1.6-mistral-7b-hf", + "messages":[{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe the natural environment in the image." + }, + { + "type":"image_url", + "image_url": { + "url": "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png" + } + } + ] + }], + "max_tokens": 64, + "temperature": 0 + }' +``` + +### Performance Testing + +Use the provided performance testing script for load testing (assuming you've already set up the multimodal disaggregated server): + +#### Prerequisites +```bash +pip install genai_perf +``` + +#### Concurrency Testing +```bash +./test_client_disag_mm.sh --concurrency 1 --port 8000 +``` + +#### Request Rate Testing +```bash +./test_client_disag_mm.sh --request-rate 10 --port 8000 +``` + + +## Roadmap & Future Improvements + +- **Model Support**: Add support for more multimodal models beyond LLaVA-Next +- **Communication**: NIXL integration for transferring multimodal embeddings between servers +- **Scalability**: Enable support for multiple LLM servers and multimodal servers with a routing manager +- **Parallelism**: Enable data parallelism (DP) in multimodal server +- **Configuration**: Test/verify/enable major parallel configurations in LLM decoder server +- **Optimization**: Performance optimization and tuning diff --git a/examples/multimodal_disaggregated/start_disagg_mm.sh b/examples/multimodal_disaggregated/start_disagg_mm.sh new file mode 100644 index 00000000000..a525bad64b4 --- /dev/null +++ b/examples/multimodal_disaggregated/start_disagg_mm.sh @@ -0,0 +1,14 @@ +#!/bin/bash +mkdir -p Logs/ +CUDA_VISIBLE_DEVICES=0 trtllm-serve encoder llava-hf/llava-v1.6-mistral-7b-hf \ + --host localhost \ + --port 8001 \ + --backend pytorch \ + &> Logs/log_encoder_0 & +CUDA_VISIBLE_DEVICES=1 trtllm-serve llava-hf/llava-v1.6-mistral-7b-hf \ + --host localhost \ + --port 8002 \ + --backend pytorch \ + --extra_llm_api_options ./extra-llm-api-config.yml \ + &> Logs/log_pd_tp1 & +trtllm-serve disaggregated_mm -c disagg_config.yaml &> Logs/log_disagg_server & diff --git a/examples/multimodal_disaggregated/test_client_disag_mm.sh b/examples/multimodal_disaggregated/test_client_disag_mm.sh new file mode 100644 index 00000000000..0b4f9a3c162 --- /dev/null +++ b/examples/multimodal_disaggregated/test_client_disag_mm.sh @@ -0,0 +1,174 @@ +#!/bin/bash + +# This script runs genai-perf to profile a multimodal model. 
+# Supports two modes: concurrency or request_rate + +# --- Command Line Arguments Parsing --- +usage() { + echo "Usage: $0 [--concurrency <value> | --request-rate <value>] --port <port>" + echo "" + echo "Options:" + echo " --concurrency <value> Run in concurrency mode with the specified concurrency level" + echo " --request-rate <value> Run in request rate mode with the specified rate (requests/sec)" + echo " --port <port> Server port number (e.g., 8001, 8003)" + echo "" + echo "Examples:" + echo " $0 --concurrency 2 --port 8003" + echo " $0 --request-rate 15 --port 8001" + echo " $0 --concurrency 1 --port 9000" + exit 1 +} + +# Initialize variables +MODE="" +CONCURRENCY="" +REQUEST_RATE="" +PORT="" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --concurrency) + MODE="concurrency" + CONCURRENCY="$2" + shift 2 + ;; + --request-rate) + MODE="request_rate" + REQUEST_RATE="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + -h|--help) + usage + ;; + *) + echo "Unknown option: $1" + usage + ;; + esac +done + +# Validate arguments +if [ -z "$MODE" ]; then + echo "Error: Must specify either --concurrency or --request-rate" + usage +fi + +if [ -z "$PORT" ]; then + echo "Error: Must specify --port" + usage +fi + +# Validate PORT +if ! [[ "${PORT}" =~ ^[0-9]+$ ]] || [ "${PORT}" -lt 1 ] || [ "${PORT}" -gt 65535 ]; then + echo "Error: PORT must be a valid port number (1-65535)" + echo "You provided: '${PORT}'" + exit 1 +fi + +# Validate and set mode-specific values +if [ "${MODE}" = "concurrency" ]; then + if ! [[ "${CONCURRENCY}" =~ ^[0-9]+$ ]] || [ "${CONCURRENCY}" -lt 1 ]; then + echo "Error: CONCURRENCY must be a positive integer" + echo "You provided: '${CONCURRENCY}'" + exit 1 + fi + if [ "${CONCURRENCY}" -gt 1 ]; then + REQUEST_COUNT=$((CONCURRENCY*5)) + else + REQUEST_COUNT=$((CONCURRENCY*50)) + fi + echo "Running in CONCURRENCY mode: CONCURRENCY=${CONCURRENCY}, REQUEST_COUNT=${REQUEST_COUNT}, PORT=${PORT}" +elif [ "${MODE}" = "request_rate" ]; then + if !
[[ "${REQUEST_RATE}" =~ ^[0-9]+$ ]] || [ "${REQUEST_RATE}" -lt 1 ]; then + echo "Error: REQUEST_RATE must be a positive integer" + echo "You provided: '${REQUEST_RATE}'" + exit 1 + fi + REQUEST_COUNT=$((REQUEST_RATE*10)) + echo "Running in REQUEST_RATE mode: REQUEST_RATE=${REQUEST_RATE}, REQUEST_COUNT=${REQUEST_COUNT}, PORT=${PORT}" +fi + +ISL=64 +OSL=64 + +# --- Configuration for genai-perf --- +MODEL_NAME="llava-hf/llava-v1.6-mistral-7b-hf" +TOKENIZER_NAME="llava-hf/llava-v1.6-mistral-7b-hf" +SERVICE_KIND="openai" +ENDPOINT_TYPE="multimodal" +INPUT_FILE="./mm_data_oai.json" +SERVER_URL="localhost:${PORT}" + +# Set append name based on port +if [ "${PORT}" = "8000" ]; then + APPEND_NAME="disagg" +elif [ "${PORT}" = "8002" ]; then + APPEND_NAME="agg" +else + APPEND_NAME="port${PORT}" +fi + +if [ "${MODE}" = "concurrency" ]; then + PROFILE_EXPORT_FILE="ISL_${ISL}_OSL_${OSL}_CONCURRENCY_${CONCURRENCY}_${APPEND_NAME}.json" +else + PROFILE_EXPORT_FILE="ISL_${ISL}_OSL_${OSL}_RATE_${REQUEST_RATE}_${APPEND_NAME}.json" +fi + +RANDOM_SEED=123 +# Set to true if your endpoint supports streaming and you want to test it +ADD_STREAMING_FLAG=true # or true + +# --- Build the genai-perf command --- +CMD="genai-perf profile" +CMD="${CMD} -m \"${MODEL_NAME}\"" +CMD="${CMD} --tokenizer \"${TOKENIZER_NAME}\"" +#CMD="${CMD} --service-kind \"${SERVICE_KIND}\"" +CMD="${CMD} --endpoint-type \"${ENDPOINT_TYPE}\"" +#CMD="${CMD} --input-file \"${INPUT_FILE}\"" +CMD="${CMD} --output-tokens-mean ${OSL}" +#CMD="${CMD} --output-tokens-stddev ${OUTPUT_TOKENS_STDDEV}" +CMD="${CMD} --request-count ${REQUEST_COUNT}" +CMD="${CMD} --profile-export-file \"${PROFILE_EXPORT_FILE}\"" +CMD="${CMD} --url \"${SERVER_URL}\"" +CMD="${CMD} --random-seed ${RANDOM_SEED}" + +# --- Mode-specific flags --- +if [ "${MODE}" = "concurrency" ]; then + CMD="${CMD} --num-prompts ${CONCURRENCY}" + CMD="${CMD} --concurrency ${CONCURRENCY}" + echo "Added concurrency flags: --num-prompts ${CONCURRENCY} --concurrency ${CONCURRENCY}" +elif [ "${MODE}" = "request_rate" ]; then + CMD="${CMD} --request-rate ${REQUEST_RATE}" + echo "Added request rate flag: --request-rate ${REQUEST_RATE}" +fi + +CMD="${CMD} --image-width-mean 512" +CMD="${CMD} --image-width-stddev 0" +CMD="${CMD} --image-height-mean 512" +CMD="${CMD} --image-height-stddev 0" +CMD="${CMD} --image-format png" +CMD="${CMD} --synthetic-input-tokens-mean ${ISL}" +CMD="${CMD} --synthetic-input-tokens-stddev 0" + +if [ "${ADD_STREAMING_FLAG}" = true ] ; then + CMD="${CMD} --streaming" +fi +CMD="${CMD} --extra-inputs \"max_tokens:${OSL}\"" +CMD="${CMD} --extra-inputs \"min_tokens:${OSL}\"" +CMD="${CMD} --extra-inputs \"ignore_eos:true\"" +CMD="${CMD} -- -v" +CMD="${CMD} --max-threads 1" + +# --- Execute the command --- +echo "Executing command:" +echo "${CMD}" +eval "${CMD}" + +# Example usage: +# ./test_client_disag_mm.sh --concurrency 2 --port 8003 +# ./test_client_disag_mm.sh --request-rate 15 --port 8001 diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 9c59d5bee25..eda1d8ff37e 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -44,6 +44,7 @@ def _add_trt_llm_dll_directory(): from .auto_parallel import AutoParallelConfig, auto_parallel from .builder import BuildConfig, Builder, BuilderConfig, build from .disaggregated_params import DisaggregatedParams +from .multimodal_params import MultimodalParams from .functional import Tensor, constant from .llmapi import LLM, LlmArgs from .logger import logger @@ -101,6 +102,7 @@ def _add_trt_llm_dll_directory(): 
'SamplingParams', 'DisaggregatedParams', 'KvCacheConfig', + 'MultimodalParams', '__version__', ] diff --git a/tensorrt_llm/_torch/distributed/__init__.py b/tensorrt_llm/_torch/distributed/__init__.py index 82f5a23b614..a14350038b5 100644 --- a/tensorrt_llm/_torch/distributed/__init__.py +++ b/tensorrt_llm/_torch/distributed/__init__.py @@ -1,6 +1,6 @@ from tensorrt_llm.functional import AllReduceFusionOp -from .communicator import Distributed, MPIDist, PPComm, TorchDist +from .communicator import Distributed, MPIDist, PPComm, TorchDist, MMEmbeddingComm from .ops import (AllReduce, AllReduceParams, AllReduceStrategy, MoEAllReduce, allgather, reducescatter, userbuffers_allreduce_finalize) @@ -17,4 +17,5 @@ "PPComm", "MPIDist", "Distributed", + "MMEmbeddingComm", ] diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 83eb7157495..830b5f64be3 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -1,3 +1,4 @@ +import atexit import os from abc import ABC, abstractmethod from typing import Optional @@ -241,3 +242,33 @@ def pp_recv(tensor): def pp_send(tensor): """Send tensors to next pp rank.""" _pp_comm.send(tensor) + + +class MMEmbeddingComm: + # MMEmbeddingComm handles multimodal-embedding communication using torch.distributed with the NCCL backend. + # The current trtllm NCCL communicator only supports p2p communication, so torch.distributed is used for broadcast here. + def __init__(self, global_mapping: Mapping): + self.mapping = global_mapping + if not dist.is_initialized(): + master_ip = os.getenv("MASTER_ADDR", "localhost") + master_port = os.getenv("MASTER_PORT", "6000") + init_method = f"tcp://{master_ip}:{master_port}" + dist.init_process_group(backend="nccl", + init_method=init_method, + world_size=global_mapping.world_size, + rank=global_mapping.rank) + atexit.register(self._cleanup) + + # Force NCCL initialization and rank population via a PyTorch distributed barrier. + # This is currently necessary when using pp + tp because our custom NCCL allreduce + # op for tp groups can interfere with PyTorch's NCCL initialization when PyTorch + # distributed performs its first communication op and kicks off NCCL init. The barrier here + # ensures proper NCCL setup and GPU-process binding at the beginning.
+ dist.barrier(device_ids=[torch.cuda.current_device()]) + + def _cleanup(self): + if dist.is_initialized(): + dist.destroy_process_group() + + def broadcast(self, tensor: torch.Tensor, root=0): + dist.broadcast(tensor, src=root) \ No newline at end of file diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py index ccede150f07..613a40eec67 100644 --- a/tensorrt_llm/_torch/models/modeling_llava_next.py +++ b/tensorrt_llm/_torch/models/modeling_llava_next.py @@ -22,7 +22,32 @@ from .modeling_clip import CLIPVisionModel from .modeling_multimodal_utils import fuse_input_embeds from .modeling_utils import ModelConfig, filter_weights, register_auto_model +from .modeling_utils import ModelConfig, register_auto_model +from ...executor.request import MultimodalParams +class CLIPEncoderInfo: + + def __init__(self, hf_config): + self.vision_config = hf_config.vision_config + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + return self.get_patch_grid_length()**2 + 1 + + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: + image_size, patch_size = self.get_image_size(), self.get_patch_size() + assert image_size % patch_size == 0 + return image_size // patch_size class LlavaNextInputProcessor(InputProcessor): @@ -66,6 +91,96 @@ def __init__(self, model_path, model_config, tokenizer): # Use HF multi-modal projector self.mm_projector = hf_mm_projector + self.hf_model_config = hf_model_config + + def image_size_to_num_patches(self, image_size, grid_pinpoints = None, patch_size: int = None): + """ + Calculate the number of patches after the preprocessing for images of any resolution. + + Args: + image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`): + The size of the input image in the format (height, width). ? + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + + Returns: + int: the number of patches + """ + def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + This is done by calculating the effective and wasted resolution for each possible resolution. + + The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. + + Args: + original_size (tuple): + The original size of the image in the format (height, width). + possible_resolutions (list): + A list of possible resolutions in the format [(height1, width1), (height2, width2), ...]. + + Returns: + tuple: The best fit resolution in the format (height, width). 
+ """ + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit + + if grid_pinpoints is None: + grid_pinpoints = self.hf_model_config.image_grid_pinpoints + if patch_size is None: + patch_size = self.hf_model_config.vision_config.image_size + + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError(f"image_size invalid type {type(image_size)} with value {image_size}") + image_size = image_size.tolist() + + best_resolution = select_best_resolution(image_size, grid_pinpoints) + height, width = best_resolution + num_patches = 0 + # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1 + for i in range(0, height, patch_size): + for j in range(0, width, patch_size): + num_patches += 1 + # add the base patch + num_patches += 1 + return num_patches + + def image_size_to_num_tokens(self, image_size): + num_patches = self.image_size_to_num_patches(image_size) + vision_encoder_info = CLIPEncoderInfo(self.hf_model_config) + base_feature_size = vision_encoder_info.get_num_image_tokens( + image_width=image_size[1], + image_height=image_size[0], + ) + if self.hf_model_config.vision_feature_select_strategy == "default": + base_feature_size = base_feature_size - 1 + return num_patches * base_feature_size @nvtx_range("[Vision] preprocess") def _preprocess(self, images): @@ -191,6 +306,80 @@ def __call__( } + @nvtx_range("[Vision] postprocess") + def _postprocess_ids_only(self, input_ids, total_mm_tokens): + # Define model specific variables here before shared logic + mm_tokens = torch.tensor([self.model_config.image_token_index + ]).to(input_ids.device) + vocab_size = self.model_config.text_config.vocab_size + start_len = end_len = 0 # for llava, need not append start/end token around each image token + # End model specific variables + + ## find mm token positions in input_ids + mm_token_positions = torch.where(torch.isin(input_ids, mm_tokens))[0] + num_medias = len(mm_token_positions) + mm_tokens_per_media = total_mm_tokens // num_medias + assert mm_tokens_per_media > 0, "Number of multimodal tokens per media must be greater than 0" + + # TODO: 1 prompt + N media (N>=1) only one frame per media (image only) + mm_lengths_per_frame = [mm_tokens_per_media] * num_medias + mm_lengths_per_split = [mm_tokens_per_media] * num_medias + mm_total_length = sum(mm_lengths_per_split) + + + ## split input_ids into segments by isolating mm tokens + mm_split_positions = torch.cat( + [mm_token_positions, mm_token_positions + 1]).unique() + input_ids_splits = 
list(input_ids.tensor_split(mm_split_positions.cpu( + ))) # len(input_ids_splits) = num_segments after mm tokens are isolated + mm_ids_splits = list( + torch.arange(vocab_size, + vocab_size + mm_total_length, + device=input_ids.device).split(mm_lengths_per_split) + ) # len(mm_ids_splits) = num_mm_segments + + for i, mm_ids in enumerate(mm_ids_splits): + mm_ids = mm_ids.reshape(-1, mm_lengths_per_frame[i]) + mm_ids_splits[i] = mm_ids.flatten() + + ## replace mm token ids with the expanded out-of-vocab ids + mm_split_idx = 0 + for i, split in enumerate(input_ids_splits): + if torch.isin(split, mm_tokens).any().item(): + input_ids_splits[i] = mm_ids_splits[mm_split_idx] + mm_split_idx += 1 + assert mm_split_idx == len( + mm_ids_splits), "All mm_ids_splits should be consumed" + + ## concat text & mm input_ids, wrap mm feature in prompt tuning config + fused_input_ids = torch.cat(input_ids_splits).to( + device=input_ids.device) + assert len(fused_input_ids) == len(input_ids) + mm_total_length - num_medias, "Fused input_ids length should match the sum of text and multimodal embedding lengths" + #fused_length = len(input_ids) + mm_total_length + num_frames * ( + # start_len + end_len) - num_medias + + return fused_input_ids + + @torch.inference_mode() + def postprocess( + self, inputs: TextPrompt, sampling_params: SamplingParams, disagg_mm_params: MultimodalParams, + ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]: + text_prompt, mm_data = inputs.get("prompt"), inputs.get( + "multi_modal_data", {}) + assert 'image' in mm_data + model_hidden_size = self.model_config.text_config.hidden_size + + input_ids = self.tokenizer( + text_prompt, return_tensors="pt").input_ids[0].to(self.device) + assert len(disagg_mm_params.embeddings) == 1, "Only one fused multimodal embedding is supported" + mm_handle = disagg_mm_params.embeddings[0] + total_mm_tokens = mm_handle['tensor_size'][0] + hidden_size = mm_handle['tensor_size'][-1] + assert model_hidden_size == hidden_size, "Multimodal embedding hidden size must match model hidden size" + + fused_input_ids = self._postprocess_ids_only(input_ids, total_mm_tokens) + return fused_input_ids.to(torch.int32).tolist() + @register_auto_model("LlavaNextForConditionalGeneration") @register_input_processor(LlavaNextInputProcessor, model_type="llava_next") class LlavaNextModel(PreTrainedModel): diff --git a/tensorrt_llm/_torch/multimodal/__init__.py b/tensorrt_llm/_torch/multimodal/__init__.py new file mode 100644 index 00000000000..993ed24555b --- /dev/null +++ b/tensorrt_llm/_torch/multimodal/__init__.py @@ -0,0 +1,9 @@ +from .mm_utils import _SharedTensorRebuildMethodRegistry, SharedTensorContainer + +# Initialize the registry when the package is imported +_SharedTensorRebuildMethodRegistry.initialize() + +# Export the classes for easy access +__all__ = [ + 'SharedTensorContainer', +] diff --git a/tensorrt_llm/_torch/multimodal/mm_encoder.py b/tensorrt_llm/_torch/multimodal/mm_encoder.py new file mode 100644 index 00000000000..46c00b0c632 --- /dev/null +++ b/tensorrt_llm/_torch/multimodal/mm_encoder.py @@ -0,0 +1,220 @@ +from pathlib import Path +from typing import Any, Optional, Union, List + +from tensorrt_llm.executor import GenerationExecutor +from tensorrt_llm.llmapi.llm import TorchLlmArgs, TrtLlmArgs +from tensorrt_llm.llmapi.utils import exception_handler, get_device_count, print_colored_debug +from tensorrt_llm.llmapi.llm_utils import LlmBuildStats, CachedModelLoader, _ModelRuntimeContext +from tensorrt_llm.logger import logger +from 
tensorrt_llm.executor.utils import get_spawn_proxy_process_env, create_mpi_comm_session +from tensorrt_llm.llmapi.mpi_session import external_mpi_comm_available, MpiPoolSession +import tempfile +import atexit +import weakref +from tensorrt_llm._utils import nvtx_range_debug +from tensorrt_llm.executor.multimodal import MultimodalRequest +import asyncio +from tensorrt_llm.bindings import executor as tllm + +class MultimodalEncoder: + def __init__(self, + model: Union[str, Path], + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, # TP should never be used for mm-encoder + data_parallel_size: int = 1, # TODO: Placeholder for future use in multimodal encoder server + dtype: str = "auto", + revision: Optional[str] = None, + **kwargs: Any) -> None: + + self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor) + + kwargs_dict = dict(kwargs) + kwargs_dict['backend'] = 'pytorch' + try: + # Reuse the LLM arg parser for mm-encoder for now as some configs/args can be shared + # e.g., max_batch_size, parallel_config, mpi_session, etc. + self.args = TorchLlmArgs.from_kwargs( + model=model, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + revision=revision, + **kwargs_dict) + + except Exception as e: + logger.error( + f"Failed to parse the arguments for the mm encoder constructor: {e}") + raise e + + + print_colored_debug(f"Encoder.args.mpi_session: {self.args.mpi_session}\n", + "yellow") + self.mpi_session = self.args.mpi_session + + if self.args.parallel_config.is_multi_gpu: + if get_device_count( + ) < self.args.parallel_config.world_size_per_node: + raise RuntimeError( + f"Only {get_device_count()} GPUs are available, but {self.args.parallel_config.world_size} are required." + ) + + logger.info( + f'start MpiSession with {self.args.parallel_config.world_size} workers' + ) + if not self.mpi_session: + mpi_process_pre_spawned: bool = get_spawn_proxy_process_env() + if not mpi_process_pre_spawned: + print_colored_debug(f"Encoder create MpiPoolSession\n", + "yellow") + self.mpi_session = MpiPoolSession( + n_workers=self.args.parallel_config.world_size) + else: + print_colored_debug(f"Encoder create MpiCommSession\n", + "yellow") + self.mpi_session = create_mpi_comm_session( + self.args.parallel_config.world_size) + + try: + # Due to the Executor can only accept a engine path, we need to save the engine to a directory + self._engine_dir: Optional[Path] = None + self._executor: Optional[GenerationExecutor] = None + if self._on_trt_backend: + self._workspace = tempfile.TemporaryDirectory( + suffix="-mm-encoder-workspace", dir=self.args.workspace) + else: + self._workspace = None + + self._hf_model_dir: Optional[Path] = None + + self.runtime_context: Optional[_ModelRuntimeContext] = None + self.llm_build_stats = LlmBuildStats() + + self._build_model() + + except Exception as e: + if self.mpi_session is not None: + self.mpi_session.shutdown() + raise e + + exception_handler.register(self, 'shutdown') + atexit.register(MultimodalEncoder._shutdown_wrapper, weakref.ref(self)) + + @property + def workspace(self) -> Path: + return Path(self._workspace.name) if self._on_trt_backend else None + + def generate_from_mm_request( + self, + mm_requests: List[MultimodalRequest], + ): + """Generate embeddings for multiple multimodal requests in parallel. 
+ + Args: + mm_requests: List of multimodal requests to process + + Returns: + List of generation results + """ + async def _process_requests(): + # Submit all requests first + futures = [] + for request in mm_requests: + future = await self.generate_async(request) + futures.append(future) + + # Then wait for all results + results = [] + for future in futures: + result = await future.aresult() + results.append(result) + return results + + # Run the async operations in an event loop + return asyncio.run(_process_requests()) + + @nvtx_range_debug("Encoder.generate_async", color="green", category="Encoder") + async def generate_async( + self, + mm_request: MultimodalRequest, + ): + """Generate embeddings for a multimodal request asynchronously. + + Args: + mm_request: The multimodal request containing items to process + + Returns: + A promise that will be resolved with the generation results + """ + # First fetch and load all the data + await mm_request.fetch() + # Then generate the embeddings asynchronously + result = self._executor.generate_multimodal_async( + mm_request, + ) + return result + + def _build_model(self): + model_loader = CachedModelLoader(self.args, + mpi_session=self.mpi_session, + workspace=self.workspace, + llm_build_stats=weakref.proxy( + self.llm_build_stats)) + self._engine_dir, self._hf_model_dir = model_loader() + # update the model_dir to a local dir for the runtime, such as tokenizer loading. + if self._engine_dir is not None: + self.args.model = self._engine_dir + + max_batch_size = self.args.max_batch_size or self.args.build_config.max_batch_size + # In _build_model method: + executor_config = tllm.ExecutorConfig(1) + executor_config.backend = "pytorch" + executor_config.mm_encoder_only = True + executor_config.mapping = self.args.parallel_config.to_mapping() + executor_config.build_config = self.args.build_config + executor_config.hf_model_dir = self._hf_model_dir + executor_config.trt_engine_dir = self._engine_dir + executor_config.max_batch_size = max_batch_size + executor_config.max_num_active_requests = 2048 + + self._executor = self._executor_cls.create( + self._engine_dir, + executor_config=executor_config, + model_world_size=self.args.parallel_config.world_size, + mpi_session=self.mpi_session, + reuse_mpi_comm=external_mpi_comm_available( + self.args.parallel_config.world_size), + is_llm_executor=False) + + @property + def _on_trt_backend(self) -> bool: + return isinstance(self.args, TrtLlmArgs) + + def shutdown(self) -> None: + if hasattr(self, "_executor") and self._executor is not None: + self._executor.shutdown() + self._executor = None + + if hasattr(self, 'mpi_session') and self.mpi_session is not None: + self.mpi_session.shutdown() + self.mpi_session = None + + @staticmethod + def _shutdown_wrapper(self_ref): + # Retrieve the instance if it still exists + instance = self_ref() + if instance is not None: + instance.shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + del exc_value, traceback + self.shutdown() + return False # propagate exceptions + + def __getstate__(self): + raise RuntimeError("Encoder object can not be pickled.") + + def __del__(self): + self.shutdown() \ No newline at end of file diff --git a/tensorrt_llm/_torch/multimodal/mm_utils.py b/tensorrt_llm/_torch/multimodal/mm_utils.py new file mode 100644 index 00000000000..f467124a3b8 --- /dev/null +++ b/tensorrt_llm/_torch/multimodal/mm_utils.py @@ -0,0 +1,224 @@ +import logging +import base64 +from typing import Any, 
Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch.multiprocessing.reductions import rebuild_cuda_tensor, reduce_tensor, rebuild_tensor, rebuild_meta_tensor + +logger = logging.getLogger(__name__) + + +class _SharedTensorRebuildMethodRegistry: + """Registry for tensor rebuild methods with fixed keys for common methods. + + This class maintains a mapping of numeric keys to rebuild methods. + Common methods are pre-registered with fixed keys for consistency. + """ + # Fixed keys for common rebuild methods + REBUILD_CUDA = 1 + REBUILD_CPU = 2 + REBUILD_META = 3 + + _registry: Dict[int, Callable] = {} + + @classmethod + def initialize(cls): + """Initialize the registry with common rebuild methods.""" + # Register common methods with fixed keys + cls._registry[cls.REBUILD_CUDA] = rebuild_cuda_tensor + cls._registry[cls.REBUILD_CPU] = rebuild_tensor + cls._registry[cls.REBUILD_META] = rebuild_meta_tensor + + @classmethod + def register(cls, method: Callable) -> int: + """Register a rebuild method and return its key. + + Args: + method: The rebuild method to register + + Returns: + The numeric key assigned to the method + """ + if method == rebuild_cuda_tensor: + return cls.REBUILD_CUDA + if method == rebuild_tensor: + return cls.REBUILD_CPU + if method == rebuild_meta_tensor: + return cls.REBUILD_META + raise NotImplementedError("Other rebuild methods are not supported yet") + + @classmethod + def get_method(cls, key: int) -> Callable: + """Get a rebuild method by its key. + + Args: + key: The numeric key of the method + + Returns: + The registered rebuild method + + Raises: + KeyError: If the key is not found in the registry + """ + if key not in cls._registry: + raise KeyError(f"No rebuild method registered with key {key}") + return cls._registry[key] + + +class SharedTensorContainer: + """A class for sharing tensors between processes. + + This class provides a simple way to share tensors between processes + using pytorch reduce/rebuild methods. + + Note: Whenever you call reduce_tensor, you must call the corresponding rebuild method at consumer process(es), otherwise, the producer process cannot release the memory + to caching allocator as the innner refcount never reach zero. + """ + def __init__(self, method_key: int, tensor_handle: Dict[str, Any]): + self.method_key = method_key + self.tensor_handle = tensor_handle + + @staticmethod + def handle_to_dict(tensor_handle) -> Dict[str, Any]: + """Convert the shared tensor handle to a dictionary that can be serialized. + + This method converts the tensor handle information into a format that can be + safely serialized (e.g., to JSON). It handles binary data by encoding it in base64. 
+ + Returns: + Dictionary containing the serialized tensor information with the following keys: + - method_key: The registry key for the rebuild method + - tensor_size: List of tensor dimensions + - tensor_stride: List of tensor strides + - tensor_offset: Offset in the storage + - dtype: String representation of the tensor's data type + - storage_device: Device where the tensor is stored + - storage_handle: Base64 encoded storage handle + - storage_size_bytes: Size of the storage in bytes + - storage_offset_bytes: Offset in the storage in bytes + - requires_grad: Whether the tensor requires gradients + - ref_counter_handle: Base64 encoded reference counter handle + - ref_counter_offset: Offset in the reference counter + - event_handle: Base64 encoded CUDA event handle + - event_sync_required: Whether CUDA event synchronization is required + + Raises: + KeyError: If required tensor information is missing + ValueError: If tensor information cannot be serialized + """ + try: + # tensor_handle is a tuple returned by reduce_tensor + tensor_info = tensor_handle + # Convert tensor info to a basic dict with only serializable values + serializable_info = { + # tensor_info[0] is the type of the tensor, which is "torch.Tensor" + "tensor_size": list(tensor_info[1]), + "tensor_stride": list(tensor_info[2]), + "tensor_offset": tensor_info[3], + "dtype": str(tensor_info[5]), + "storage_device": tensor_info[6], + "storage_handle": base64.b64encode(tensor_info[7]).decode('utf-8'), + "storage_size_bytes": tensor_info[8], + "storage_offset_bytes": tensor_info[9], + "requires_grad": tensor_info[10], + "ref_counter_handle": base64.b64encode(tensor_info[11]).decode('utf-8'), + "ref_counter_offset": tensor_info[12], + "event_handle": base64.b64encode(tensor_info[13]).decode('utf-8'), + "event_sync_required": tensor_info[14] + } + return serializable_info + except IndexError as e: + raise KeyError(f"Missing required tensor information: {e}") + except Exception as e: + raise ValueError(f"Failed to serialize tensor information: {e}") + + @staticmethod + def dict_to_handle(tensor_info: Dict[str, Any]) -> Tuple: + """Create a tensor handle from a serialized dictionary. + + This method reconstructs a tensor handle from a previously serialized + dictionary. It handles base64 encoded binary data by decoding it back to bytes. 
+ + Args: + tensor_info: Dictionary containing the serialized tensor information + with the same keys as returned by to_dict() + + Returns: + A new SharedTensorContainer instance + + Raises: + KeyError: If required tensor information is missing + ValueError: If tensor information cannot be deserialized + """ + try: + # Decode base64 encoded binary data + storage_handle = base64.b64decode(tensor_info['storage_handle']) + ref_counter_handle = base64.b64decode(tensor_info['ref_counter_handle']) + event_handle = base64.b64decode(tensor_info['event_handle']) + + # Reconstruct the tensor handle + tensor_handle = (torch.Tensor, + tuple(tensor_info['tensor_size']), + tuple(tensor_info['tensor_stride']), + tensor_info['tensor_offset'], + torch.storage.TypedStorage, + eval(tensor_info['dtype']), + tensor_info['storage_device'], + storage_handle, + tensor_info['storage_size_bytes'], + tensor_info['storage_offset_bytes'], + tensor_info['requires_grad'], + ref_counter_handle, + tensor_info['ref_counter_offset'], + event_handle, + tensor_info['event_sync_required']) + + return tensor_handle + except KeyError as e: + raise KeyError(f"Missing required tensor information: {e}") + except Exception as e: + raise ValueError(f"Failed to deserialize tensor information: {e}") + + + @classmethod + def from_tensor(cls, tensor: torch.Tensor) -> 'SharedTensorContainer': + """Create a SharedTensorContainer from a local tensor. + + Args: + tensor: The tensor to share + + Returns: + SharedTensorContainer instance that can be shared between processes + """ + rebuild_method, tensor_handle = reduce_tensor(tensor) + method_key = _SharedTensorRebuildMethodRegistry.register(rebuild_method) + # hack to make it serializable + tensor_handle = SharedTensorContainer.handle_to_dict(tensor_handle) + return cls(method_key, tensor_handle) + + @classmethod + def from_dict(cls, tensor_info: Dict[str, Any]) -> 'SharedTensorContainer': + """Create a SharedTensorContainer from a serialized dictionary. + """ + method_key = tensor_info['method_key'] + tensor_handle = SharedTensorContainer.dict_to_handle(tensor_info) + return cls(method_key, tensor_handle) + + def get_local_view(self) -> torch.Tensor: + """Convert the shared tensor back to a local tensor. + + Returns: + The reconstructed tensor + """ + rebuild_method = _SharedTensorRebuildMethodRegistry.get_method(self.method_key) + return rebuild_method(*self.tensor_handle) + + def dump_to_dict(self) -> Dict[str, Any]: + """Convert this class instance to a dictionary that can be JSON serialized. 
+ + Returns: + Dictionary containing the serialized tensor information + """ + result = self.tensor_handle.copy() + result["method_key"] = self.method_key + return result diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 3af5849e82a..1572a39b3ab 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -489,6 +489,9 @@ def create_py_executor_instance( cache_transceiver_config = executor_config.cache_transceiver_config kv_cache_transceiver = create_kv_cache_transceiver( mapping, kv_cache_manager, attention_type, cache_transceiver_config) + # Unfortunately, we cannot init this comm lazily only for mm_disagg request + # TODO: Add ncclBcast support in trtllm nccl communicator + model_engine._setup_mm_emb_comm() # Keep diagg mm_emb communicator close to where kvcache_transceiver created return PyExecutor(resource_manager, scheduler, diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 63ac568f4dd..aec59ef8f49 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -222,6 +222,8 @@ def __init__( **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", None) + self.py_disagg_mm_params = kwargs.pop("disagg_mm_params", None) + super().__init__( *args, client_id=client_id, @@ -323,6 +325,7 @@ def executor_request_to_llm_request( stop_words_list = convert_wordlist( executor_request.stop_words) if executor_request.stop_words else None + disagg_mm_params = getattr(executor_request, 'disagg_mm_params', None) llm_request = LlmRequest( request_id=req_id, max_new_tokens=executor_request.max_tokens, @@ -375,6 +378,7 @@ def executor_request_to_llm_request( if executor_request.client_id is not None else req_id, priority=0.5, llm_request_type=llm_request_type, - context_phase_params=executor_request.context_phase_params) + context_phase_params=executor_request.context_phase_params, + disagg_mm_params=disagg_mm_params) return llm_request diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index c6cbc47bfbc..b568e528fe7 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -39,7 +39,7 @@ from ..autotuner import AutoTuner, autotune from ..compilation.backend import Backend from ..compilation.utils import set_enable_piecewise_cuda_graph_capture_flag -from ..distributed import MPIDist +from ..distributed import MPIDist, MMEmbeddingComm from ..distributed.communicator import init_pp_comm from ..expert_statistic import ExpertStatistic from ..metadata import KVCacheParams @@ -58,6 +58,8 @@ from .resource_manager import (BaseResourceManager, KVCacheManager, ResourceManager) from .scheduler import ScheduledRequests +from tensorrt_llm._torch.multimodal import SharedTensorContainer +from tensorrt_llm._torch.pyexecutor.multimodal.shared_tensor_handle_pool import get_handle_buffer MAX_UINT64 = (1 << 64) - 1 @@ -314,6 +316,7 @@ def __init__( init_pp_comm(mapping) self.dist = dist ExpertStatistic.create(self.dist.rank) + self.mm_emb_dist = None self.pytorch_backend_config = pytorch_backend_config self.spec_config = spec_config self.is_spec_decode = spec_config is not None @@ -799,6 +802,11 @@ def _set_up_spec_metadata( spec_resource_manager=spec_resource_manager) return self.spec_metadata + def _setup_mm_emb_comm(self): + if self.mm_emb_dist is None: + self.mm_emb_dist = 
MMEmbeddingComm(self.mapping) + return self.mm_emb_dist + def _get_padded_batch(self, scheduled_requests: ScheduledRequests, kv_cache_manager) -> int: can_run_cuda_graph = scheduled_requests.can_run_cuda_graph @@ -1131,6 +1139,26 @@ def _prepare_tp_inputs( multi_modal_data.append(multimodal_embedding) mrope_rotary_cos_sin = request.mrope_rotary_cos_sin + if request.py_disagg_mm_params is not None and hasattr(request.py_disagg_mm_params, 'embeddings'): + assert multimodal_embedding is None, "multimodal_embedding and disagg_mm_params are not supported at the same time" + assert self.mm_emb_dist is not None, "mm_emb_dist is not initialized" + mm_tensor_handle = request.py_disagg_mm_params.embeddings[0] + tensor_shape = mm_tensor_handle['tensor_size'] + tensor_dtype = mm_tensor_handle['dtype'] + multimodal_embedding = torch.empty(tensor_shape, dtype=eval(tensor_dtype), device='cuda') + if self.mapping.rank == 0: + # Leading rank will rebuild the tensor in local device + shared_tensor = SharedTensorContainer.from_dict(mm_tensor_handle).get_local_view() + # TODO: this is a temp solution to avoid immediate free the handle + tensor_pool = get_handle_buffer() + tensor_pool.add_handle(str(request.py_request_id), shared_tensor) + multimodal_embedding.copy_(shared_tensor) + self.mm_emb_dist.broadcast(multimodal_embedding) + else: + # Other ranks will broadcast the tensor from leading rank + self.mm_emb_dist.broadcast(multimodal_embedding) + multi_modal_data.append(multimodal_embedding) + if mrope_rotary_cos_sin is not None: mrope_config['mrope_rotary_cos_sin'].append( mrope_rotary_cos_sin) diff --git a/tensorrt_llm/_torch/pyexecutor/multimodal/multimodal_executor.py b/tensorrt_llm/_torch/pyexecutor/multimodal/multimodal_executor.py new file mode 100755 index 00000000000..2fc8855e74b --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/multimodal/multimodal_executor.py @@ -0,0 +1,461 @@ +import datetime +import logging +from typing import Dict, List, Optional +import traceback +import torch +import asyncio +from tensorrt_llm._utils import nvtx_range +from ...distributed import Distributed, MPIDist + +from ..py_executor import PyExecutor, _get_from_request_queue +from ..model_engine import ModelEngine, PyTorchModelEngine + +from tensorrt_llm.inputs import create_input_processor +import queue +import threading +from tensorrt_llm.executor.multimodal import MultimodalRequest, MultimodalResponse +from ..llm_request import ExecutorResponse +from tensorrt_llm._torch.multimodal import SharedTensorContainer +from torchvision.transforms import ToTensor +from ..py_executor import RequestQueueItem +logger = logging.getLogger(__name__) + +class MMExecutor(PyExecutor): + + def __init__(self, + resource_manager, + scheduler, + model_engine: ModelEngine, + dist: Distributed, + enable_overlap_scheduler: bool = False, + max_num_active_requests: int = 10, + max_batch_size: int = 8, + start_worker: bool = True): + + self.device_id = torch.cuda.current_device() + self.global_rank = dist.rank + self.request_queue = queue.Queue() + self.next_req_id = 0 + + # related modules + self.resource_manager = resource_manager + self.scheduler = scheduler + self.model_engine = model_engine + self.dist = dist + self.enable_overlap_scheduler = enable_overlap_scheduler + + # enqueue and _fetch_new_requests used data + self.enqueue_lock = threading.Lock() + self.active = True + self.shutdown_event = threading.Event() + + # response used data + self.response_lock = threading.Lock() + self.response_cv = threading.Condition(self.response_lock) 
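+ # Completed responses are buffered here keyed by request id; _enqueue_responses
+ # appends to this dict under response_cv and notifies any thread waiting on results.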
+ self.responses = {} + self.canceled_req_ids = set() + # _executor_loop private data + self.max_num_active_requests = max_num_active_requests # TODO: remove this should be the same as max_batch_size + self.active_requests = [] + self.is_shutdown = False + self.event_loop = self._executor_loop + self.max_batch_size = max_batch_size + print(f"max_batch_size: {self.max_batch_size}") + print(f"max_num_active_requests: {self.max_num_active_requests}") + + # Start worker if needed + self.worker_started = False + self.worker_lock = threading.Lock() + if start_worker: + self.start_worker() + + def start_worker(self): + self.worker_lock.acquire() + try: + if self.worker_started == False: + self.worker_thread = threading.Thread(target=self.event_loop, + daemon=True) + self.worker_thread.start() + self.worker_started = True + finally: + self.worker_lock.release() + + @nvtx_range("_fetch_new_requests") + def _fetch_new_requests(self): + total_num_active_requests = len(self.active_requests) + total_max_num_active_requests = self.max_num_active_requests + + timeout = None if total_num_active_requests == 0 else datetime.timedelta( + 0) + new_requests = [] + if self.dist.rank == 0: + new_requests = _get_from_request_queue( + self.request_queue, timeout, + total_max_num_active_requests - total_num_active_requests) + + if self.dist.world_size > 1: + new_requests = self.dist.broadcast(new_requests, root=0) + return new_requests + + def _merge_tp_requests(self, new_requests: List[RequestQueueItem]): + for request in new_requests: + if request is None: + return True + for req_item in new_requests: + self.active_requests.append(req_item.request) # type: ignore + return False + + def _executor_loop(self): + """ + Simplified version of the executor loop that handles multimodal requests. + Focuses only on basic functionality without complex features. + """ + torch.cuda.set_device(self.device_id) + got_finish_signal = False + iter_count = 0 + while not got_finish_signal or len(self.active_requests) > 0: + # Get new requests + new_requests = self._fetch_new_requests() + # TODO: support DP across all requests in the batch + got_finish_signal = (new_requests is not None and self._merge_tp_requests(new_requests)) or got_finish_signal + + # Exit if no more work to do + if got_finish_signal and len(self.active_requests) == 0: + break + + # Schedule requests + scheduled_batch, batch_size = self._schedule() + + assert batch_size > 0, ( + "fail to schedule any pending request, " + "probably run out of resource.") + print(f"in executor loop, iter_count: {iter_count}, batch_size: {batch_size}, active_requests: {len(self.active_requests)}") + + self.num_scheduled_requests = batch_size + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch)} requests' + ) + + finished_requests = [] + # Process batch + if batch_size > 0: + # TODO: add resource manager for multimodal executor + # self.resource_manager.prepare_resources(scheduled_batch) # only sequency manager? 
+ + batch_outputs = self._forward(scheduled_batch) + + # TODO: Handle canceled requests for multimodal executor + self._handle_cancelled_requests() + finished_requests = self._handle_responses(scheduled_batch, batch_outputs) + + # Free resources + #self.resource_manager.update_resources(scheduled_batch) + + # TODO: add iter perf stats for multimodal executor + iter_count += 1 + # if self.enable_iter_perf_stats: + + # Cleanup when loop is done + self._executor_loop_cleanup() + + def _schedule(self): + # This is a simple static scheduler that only considers active requests, and max batch size + num_to_schedule = min(len(self.active_requests), self.max_batch_size) + if num_to_schedule == 0: + return [] + + scheduled_requests = self.active_requests[:num_to_schedule] + batch_size = len(scheduled_requests) + return scheduled_requests, batch_size + + @nvtx_range("_forward") + def _forward(self, + scheduled_requests): + @nvtx_range( + f"[Executor] _forward_step: {len(scheduled_requests)} mm reqs" + ) + def forward(scheduled_requests, resource_manager): + return self.model_engine.forward(scheduled_requests, + resource_manager) + + try: + outputs = forward(scheduled_requests, self.resource_manager) + return outputs + except Exception as e: + traceback.print_exc() + error_msg = str(e) + logger.error( + f"Encountered an error in forward function: {error_msg}") + self._handle_errors(error_msg) + return None + + @nvtx_range("_enqueue_responses") + def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]): + logger.debug( + f'before enqueue, rank = {self.dist.rank}, responses = {responses}') + if self.dist.rank == 0: + with self.response_cv: + for req_id, resp in responses.items(): + if isinstance(resp, MultimodalResponse): + if resp.cp_event is not None: + # if we need to cpy embedding to host, we can sync here + resp.cp_event.synchronize() + # We only store/enqueue the handle here + resp.embedding_handle = [SharedTensorContainer.from_tensor(resp.embeddings)] + resp.embeddings = None + resp.cp_event = None + + if req_id in self.responses.keys(): + self.responses[req_id].append(resp) + else: + self.responses.update({req_id: [resp]}) + self.response_cv.notify_all() + + @nvtx_range("_handle_responses") + def _handle_responses(self, scheduled_requests, batch_outputs): + """Handle responses from postprocess_batch_outputs using existing infrastructure.""" + new_responses = {} + new_active_requests = [] + if batch_outputs is None: + for request in scheduled_requests: + if request.has_error() or request not in self.active_requests: + continue + response = request.create_response() + if response: + response.set_final() + new_responses[request.id] = response + + self._enqueue_responses(new_responses) + self.active_requests = [req for req in self.active_requests if req.id not in new_responses] + return scheduled_requests + + mm_embeddings = batch_outputs['mm_embeddings'] + mrope_config = batch_outputs['mrope_config'] + batch_request_offsets = batch_outputs['batch_request_offsets'] + + # Process each request's portion of the fused embeddings + for i, request in enumerate(scheduled_requests): + assert isinstance(request, MultimodalRequest), "request should be a MultimodalRequest" + if request.has_error() or request not in self.active_requests: + continue + start_idx = batch_request_offsets[i] + end_idx = batch_request_offsets[i + 1] + + # Create response for this request + response = request.create_response() + if response: + # Extract this request's portion of embeddings + request_embedding = 
mm_embeddings[start_idx:end_idx] + + # Attach the fused embedding directly to the response + response.set_embeddings(request_embedding, cp_event=None) + + # Attach mrope config if available + if mrope_config is not None: + response.set_mrope_config(mrope_config) + + response.set_final() + new_responses.update({request.id: response}) + + self._enqueue_responses(new_responses) + self.active_requests = [req for req in self.active_requests if req.id not in new_responses] + return scheduled_requests # finished requests + + def _handle_cancelled_requests(self): + #TODO: properly handle canceled ids in pp case + if self.dist.has_tp: + self.canceled_req_ids = self.dist.broadcast(self.canceled_req_ids, + root=0) + + if len(self.canceled_req_ids) == 0: + return + + cancelled_responses = {} + left_requests = [] + for request in self.active_requests: + req_id = request.id + if req_id in self.canceled_req_ids: + # TODO: As for now, all resources are on-the-fly, so we don't need to free resources here + # but in future, when we add embedding tensor pool, we need to evict and free resources here + # self._terminate_request(request) + cancelled_responses[req_id] = request.create_response() + self.canceled_req_ids.remove(req_id) + else: + left_requests.append(request) + self.active_requests = left_requests + + # When enable attention dp, each rank does not have full copy of requests + # so we need to remove the cancel requests not in the local rank + self.canceled_req_ids.clear() + + # enqueue the cancelled requests' responses as they are not + # active_requests and be discarded in the sampler loop. + self._enqueue_responses(cancelled_responses) + + def shutdown(self): + """ + Signals the server to shutdown. + """ + try: + self.enqueue_lock.acquire() + self.request_queue.put(None) + self.active = False + finally: + self.enqueue_lock.release() + self.shutdown_event.wait() + self.worker_thread.join() + self.worker_started = False + del self.model_engine + + def enqueue_request(self, + request: MultimodalRequest, + query: Optional[List] = None): + try: + self.enqueue_lock.acquire() + assert self.active, "PyExecutor has already been shutdown." + req_id = self.next_req_id + self.request_queue.put(RequestQueueItem(req_id, request)) + self.next_req_id += 1 + finally: + self.enqueue_lock.release() + return req_id + + def enqueue_requests(self, requests: List[MultimodalRequest]): + """ + Enqueue new requests + """ + req_ids = [] + try: + self.enqueue_lock.acquire() + assert self.active, "MMPyExecutor has already been shutdown." 
+ for request in requests: + self.request_queue.put( + RequestQueueItem(self.next_req_id, request)) + req_ids.append(self.next_req_id) + self.next_req_id += 1 + finally: + self.enqueue_lock.release() + return req_ids + + + def _handle_errors(self, error_msg: Optional[str] = None): + error_responses = {} + error_msg = error_msg or "error" + for request in self.active_requests: + req_id = request.id + # use the same error response as the llm executor + error_responses[req_id] = ExecutorResponse( + req_id, error_msg, client_id=request.id) + self.active_requests.clear() + self._enqueue_responses(error_responses) + + def cancel_request(self, id: int): + """ + Cancel the request with provided request id + Args: + id (int): The request id for which to cancel the response + """ + self.canceled_req_ids.add(id) + + def get_latest_kv_cache_events(self): + return [] + + def get_latest_iteration_stats(self): + return [] + + +class MultimodalModelEngine(PyTorchModelEngine): + def __init__( + self, + model_path: str, + pytorch_backend_config = None, + max_batch_size: Optional[int] = 8, + dist: Optional[MPIDist] = None, + ): + self.pytorch_backend_config = pytorch_backend_config + self.dist = dist + self.max_batch_size = max_batch_size + self.model = create_input_processor(model_path, None) + + def _prepare_inputs(self, scheduled_requests): + """Prepare inputs for batch processing. + + Args: + scheduled_requests: List[MultimodalRequest] + + Returns: + Tuple[Dict[str, List[MultimodalItem]], List[int]]: + - Dict mapping modality to ordered list of items + - List of offsets for each request (based on token lengths) + """ + all_mm_items = [] + for request in scheduled_requests: + all_mm_items.extend(request.items) + processed_items = {} # Dict to track processed items by (req_id, item_id) + + # Process items asynchronously + for ready_item in all_mm_items: + # Calculate token length and preprocess + # TODO: Need to converge for all models on this, currently we don't have an uniform way to get the token length + # https://github.com/vllm-project/vllm/blob/54631f826233dbd1c046f9a70e98bc2e25edff1a/vllm/model_executor/models/llava.py#L151 + # ready_item.length = self.model.get_num_image_tokens(image_width=ready_item.data.width, image_height=ready_item.data.height) + # TODO: VLLM output lenght is not correct, need to fix it + image_size = (ready_item.data.height, ready_item.data.width) + # TODO: Add other modalities. We should know length for each item here + ready_item.length = self.model.image_size_to_num_tokens(image_size) + ready_item.data = ToTensor()(ready_item.data) + ready_item.data = self.model._preprocess([ready_item.data])[0] # _preprocess involves H2D transfer + processed_items[(ready_item.req_id, ready_item.id)] = ready_item + + # 3. 
Reconstruct batch_mm_items in correct order and calculate offsets + batch_mm_items = {} + batch_request_offsets = [] + current_offset = 0 + for request in scheduled_requests: + batch_request_offsets.append(current_offset) + request_offset = 0 + for item in request.items: + # Get the processed item in original order + processed_item = processed_items[(item.req_id, item.id)] + + # Set the start position within this request + processed_item.offset = request_offset + # Update request_offset for next item + request_offset += processed_item.length + + # Add to batch_mm_items maintaining request/item order + if processed_item.modality_type not in batch_mm_items: + batch_mm_items[processed_item.modality_type] = [] + batch_mm_items[processed_item.modality_type].append(processed_item) + + current_offset += request_offset + batch_request_offsets.append(current_offset) + return batch_mm_items, batch_request_offsets + + @torch.inference_mode() + def _model_forward(self, batch_mm_items, batch_request_offsets): + batch_mm_data = { + modality: [item.data for item in items] + for modality, items in batch_mm_items.items() + } + # Shape after torch.cat (total_patches, 3, pix_height, pix_width) - image + batch_image_input = torch.cat(batch_mm_data['image']) + batch_image_features = self.model._process(batch_image_input) if len(batch_mm_data['image']) > 0 else None + assert batch_image_features.shape[0] == sum([item.length for item in batch_mm_items['image']]), "batch_mm_features should have the same length as sum of item.length" + assert batch_image_features.dim() == 2, "batch_mm_features should be a 2D tensor" + + # TODO: add mrope config which seems need input ids of llm request, deferring for now + mrope_config = None + if batch_image_features is None: + return None + + return { + "mm_embeddings": batch_image_features, # maybe extend to dict if other modality added + "mrope_config": mrope_config, + "batch_request_offsets": batch_request_offsets + } + + def forward(self, scheduled_requests, resource_manager = None): + batch_mm_items, batch_request_offsets = self._prepare_inputs(scheduled_requests) + return self._model_forward(batch_mm_items, batch_request_offsets) \ No newline at end of file diff --git a/tensorrt_llm/_torch/pyexecutor/multimodal/multimodal_pyexecutor_creator.py b/tensorrt_llm/_torch/pyexecutor/multimodal/multimodal_pyexecutor_creator.py new file mode 100755 index 00000000000..00f5a148edc --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/multimodal/multimodal_pyexecutor_creator.py @@ -0,0 +1,42 @@ +import copy +import tensorrt_llm +from tensorrt_llm.mapping import Mapping + +from ...distributed import MPIDist + +from .multimodal_executor import MMExecutor, MultimodalModelEngine +from tensorrt_llm.bindings.executor import ExecutorConfig + +def create_multimodal_pyexecutor(executor_config: ExecutorConfig, + checkpoint_dir: str = None): + + if executor_config.mapping is None: + mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), + tp_size=tensorrt_llm.mpi_world_size(), + gpus_per_node=tensorrt_llm.default_gpus_per_node(), + rank=tensorrt_llm.mpi_rank()) + else: + mapping = copy.deepcopy(executor_config.mapping) + mapping.rank = tensorrt_llm.mpi_rank() + + dist = MPIDist(mapping=mapping) + + model_engine = MultimodalModelEngine( + checkpoint_dir, + #mapping=mapping, + max_batch_size=executor_config.max_batch_size, + dist=dist, + ) + resources_manager = {} + scheduler = None + py_executor = MMExecutor(resources_manager, + scheduler, + model_engine=model_engine, + dist=dist, + 
enable_overlap_scheduler=False, + max_num_active_requests=executor_config.max_num_active_requests, + max_batch_size=executor_config.max_batch_size, + start_worker=False) + + py_executor.start_worker() + return py_executor diff --git a/tensorrt_llm/_torch/pyexecutor/multimodal/shared_tensor_handle_pool.py b/tensorrt_llm/_torch/pyexecutor/multimodal/shared_tensor_handle_pool.py new file mode 100644 index 00000000000..7a987f76829 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/multimodal/shared_tensor_handle_pool.py @@ -0,0 +1,55 @@ +import torch +from collections import OrderedDict +from typing import Any + +class SharedTensorHandleBuffer: + """ This is a container to temporary hold the shared tensor handles recv by consumer process. + Everytime when the consumer is done with accessing the shared tensor from producer, to avoid immediate + calling of release/close the cudaIPC handle (it could introduce severe overheads), we buffered it. For many + scenarios, it can be helpful. + + TODO: In fact, how is the impact of cudaIPC overhead needs to be studied. Ideally, we should manage a shared tensor pool + in the producer; therefore we can avoid such overhead of open/close cudaIPC handles in consumer processes. + + Hopefully, NIXL integration can help address this issue. + """ + + def __init__(self, max_handles: int = 10): + self.active_handles: OrderedDict[Any, torch.Tensor] = OrderedDict() + self.max_handles = max_handles + + def _remove_handle(self, key: str) -> None: + """Internal method to remove a handle without acquiring lock. + + Args: + key: Identifier of the handle to remove + """ + if key in self.active_handles: + del self.active_handles[key] + + + def add_handle(self, key: str, tensor_info: Any) -> None: + """Add a new tensor handle to the pool. + + Args: + key: Unique identifier for the handle + tensor_info: Information about the tensor to be stored + """ + if len(self.active_handles) >= self.max_handles: + oldest_key = next(iter(self.active_handles)) + self._remove_handle(oldest_key) + self.active_handles[key] = tensor_info + +_tensor_pool = None + +def get_handle_buffer(): + """Get or create the global tensor pool instance. + + This function ensures the tensor pool is created only when needed and after + multiprocessing is properly set up. 
+ """ + global _tensor_pool + if _tensor_pool is None: + _tensor_pool = SharedTensorHandleBuffer() + + return _tensor_pool \ No newline at end of file diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 415e92445b6..1f0f6865c5b 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1198,8 +1198,11 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]: total_max_num_active_requests - total_num_active_requests) if self.dist.rank == 0: - py_request_objects = self._collect_py_objects_from_requests( + py_logits_post_processors = self._collect_py_objects_from_requests( new_requests, "py_logits_post_processors") + py_disagg_mm_params = self._collect_py_objects_from_requests( + new_requests, "disagg_mm_params") + py_request_objects = tuple(filter(None, [py_logits_post_processors, py_disagg_mm_params])) else: py_request_objects = None @@ -1223,10 +1226,10 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]: if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp) and self.dist.rank > 0: - attr_name, req_obj_dict = py_request_objects - self._attach_py_objects_to_requests(new_requests, attr_name, - req_obj_dict) - + for attr_name, req_obj_dict in py_request_objects: + self._attach_py_objects_to_requests(new_requests, attr_name, + req_obj_dict) + if not self.enable_attention_dp: self._update_new_active_requests_queue_latency(new_requests) new_requests = self._merge_requests(new_requests) diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 537c9e3e683..257f3c91a75 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -3,7 +3,7 @@ import signal # Added import import subprocess # nosec B404 import sys -from typing import Any, List, Optional +from typing import Any, List, Optional, Union, Tuple import click import torch @@ -17,15 +17,18 @@ from tensorrt_llm.llmapi import (LLM, BuildConfig, CapacitySchedulerPolicy, DynamicBatchConfig, KvCacheConfig, SchedulerConfig) -from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig, +from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig, MultimodalServerConfig, MetadataServerConfig, ServerRole, parse_disagg_config_file, - parse_metadata_server_config_file) + parse_metadata_server_config_file, + parse_mm_disagg_config_file) from tensorrt_llm.llmapi.llm_utils import update_llm_args_with_extra_dict from tensorrt_llm.llmapi.mpi_session import find_free_port from tensorrt_llm.llmapi.reasoning_parser import ReasoningParserFactory from tensorrt_llm.logger import logger, severity_map -from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer +from tensorrt_llm.serve import OpenAIDisaggServer, OpenAIServer, OpenAIMultiModalDisaggServer +from tensorrt_llm._torch.multimodal.mm_encoder import MultimodalEncoder +from tensorrt_llm.serve.encoder_server import OpenAIEncoderServer # Global variable to store the Popen object of the child process _child_p_global: Optional[subprocess.Popen] = None @@ -150,6 +153,17 @@ def launch_server(host: str, asyncio.run(server(host, port)) +def launch_encoder_server(host: str, port: int, encoder_args: dict): + backend = encoder_args["backend"] + model = encoder_args["model"] + if backend == 'pytorch': + encoder = MultimodalEncoder(**encoder_args) + else: + raise ValueError(f"Unsupported backend: {backend}") + + server = OpenAIEncoderServer(encoder=encoder, model=model) + + asyncio.run(server(host, port)) 
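For reference, a rough sketch of driving this helper directly from Python instead of through the `trtllm-serve encoder` CLI. The CLI builds `encoder_args` from `get_llm_args`, so the minimal key set shown here is an assumption about what `MultimodalEncoder` accepts, not a documented signature:

```python
# Hypothetical programmatic launch; normally reached via
# `trtllm-serve encoder llava-hf/llava-v1.6-mistral-7b-hf --backend pytorch`.
# Which keys MultimodalEncoder ultimately requires is an assumption here.
from tensorrt_llm.commands.serve import launch_encoder_server  # added by this patch

encoder_args = {
    "model": "llava-hf/llava-v1.6-mistral-7b-hf",
    "backend": "pytorch",
    "max_batch_size": 8,
}

if __name__ == "__main__":
    launch_encoder_server(host="localhost", port=8001, encoder_args=encoder_args)
```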
@click.command("serve") @click.argument("model", type=str) @@ -301,17 +315,90 @@ def serve(model: str, tokenizer: Optional[str], host: str, port: int, launch_server(host, port, llm_args, metadata_server_cfg, server_role) +@click.command("encoder") +@click.argument("model", type=str) +@click.option("--host", + type=str, + default="localhost", + help="Hostname of the server.") +@click.option("--port", type=int, default=8000, help="Port of the server.") +@click.option("--backend", + type=click.Choice(["pytorch"]), + default=None, + help="Set to 'pytorch' for pytorch path. Default is cpp path.") +@click.option('--log_level', + type=click.Choice(severity_map.keys()), + default='info', + help="The logging level.") +@click.option("--max_batch_size", + type=int, + default=BuildConfig.max_batch_size, + help="Maximum number of requests that the engine can schedule.") +@click.option("--gpus_per_node", + type=int, + default=None, + help="Number of GPUs per node. Default to None, and it will be " + "detected automatically.") +@click.option("--trust_remote_code", + is_flag=True, + default=False, + help="Flag for HF transformers.") +@click.option( + "--extra_encoder_options", + type=str, + default=None, + help= + "Path to a YAML file that overwrites the parameters specified by trtllm-serve." +) +def serve_encoder(model: str, host: str, port: int, + log_level: str, backend: str, max_batch_size: int, + gpus_per_node: Optional[int], + trust_remote_code: bool, + extra_encoder_options: Optional[str]): + """Running an OpenAI API compatible server + + MODEL: model name | HF checkpoint path | TensorRT engine path + """ + logger.set_level(log_level) + + llm_args, _ = get_llm_args( + model=model, + backend=backend, + max_batch_size=max_batch_size, + gpus_per_node=gpus_per_node, + trust_remote_code=trust_remote_code) + + encoder_args_extra_dict = {} + if extra_encoder_options is not None: + with open(extra_encoder_options, 'r') as f: + encoder_args_extra_dict = yaml.safe_load(f) + encoder_args = update_llm_args_with_extra_dict(llm_args, encoder_args_extra_dict) + + # TODO: add DP for encoder + assert encoder_args["tensor_parallel_size"] == 1, "TP should be 1 for encoder" + assert encoder_args["pipeline_parallel_size"] == 1, "PP should be 1 for encoder" + assert encoder_args["moe_expert_parallel_size"] is None, "EP should be None for encoder" + launch_encoder_server(host, port, encoder_args) + + def get_ctx_gen_server_urls( - server_configs: List[CtxGenServerConfig]) -> List[str]: + server_configs: List[Union[CtxGenServerConfig, MultimodalServerConfig]]) -> Tuple[List[str], List[str]]: ctx_server_urls = [] gen_server_urls = [] + mm_server_urls = [] for cfg in server_configs: if cfg.type == "ctx": ctx_server_urls.append(f"http://{cfg.hostname}:{cfg.port}") - else: + elif cfg.type == "gen": gen_server_urls.append(f"http://{cfg.hostname}:{cfg.port}") + elif cfg.type == "mm": + mm_server_urls.append(f"http://{cfg.hostname}:{cfg.port}") - return ctx_server_urls, gen_server_urls + if len(mm_server_urls) > 0: + # TODO: enable fully disagg mode (e+p+d) later + return mm_server_urls, gen_server_urls + else: + return ctx_server_urls, gen_server_urls @click.command("disaggregated") @@ -368,6 +455,39 @@ def disaggregated(config_file: Optional[str], asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port)) +@click.command("multimodal_disaggregated") +@click.option("-c", + "--config_file", + type=str, + default=None, + help="Specific option for disaggregated mode.") +@click.option("-t", + "--server_start_timeout", + type=int, + 
default=180, + help="Server start timeout") +@click.option("-r", + "--request_timeout", + type=int, + default=180, + help="Request timeout") +def multimodal_disaggregated(config_file: Optional[str], server_start_timeout: int, + request_timeout: int): + """Running server in multimodal disaggregated mode""" + disagg_cfg = parse_mm_disagg_config_file(config_file) + + mm_server_urls, gen_server_urls = get_ctx_gen_server_urls( + disagg_cfg.server_configs) + print(f"mm_server_urls: {mm_server_urls}, gen_server_urls: {gen_server_urls}, disaggregated_cfg: {disagg_cfg.server_configs}") + + server = OpenAIMultiModalDisaggServer(gen_servers=gen_server_urls, + mm_servers=mm_server_urls, + req_timeout_secs=request_timeout, + server_start_timeout_secs=server_start_timeout, + ctx_router_config=None, + gen_router_config=None) + asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port)) + def set_cuda_device(): if (os.getenv("OMPI_COMM_WORLD_RANK")): @@ -604,7 +724,9 @@ def resolve_command(self, ctx, args): commands={ "serve": serve, "disaggregated": disaggregated, - "disaggregated_mpi_worker": disaggregated_mpi_worker + "disaggregated_mpi_worker": disaggregated_mpi_worker, + "disaggregated_mm": multimodal_disaggregated, + "encoder": serve_encoder }) if __name__ == "__main__": diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 9a7028cd564..f2c4e19816d 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -20,6 +20,7 @@ from ..bindings import executor as tllm from ..builder import Engine from ..disaggregated_params import DisaggregatedParams +from ..multimodal_params import MultimodalParams from ..llmapi.llm_utils import KvCacheRetentionConfig from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available, need_spawn_mpi_workers) @@ -120,7 +121,8 @@ def generate_async( mrope_config: Optional[dict] = None, kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, disaggregated_params: Optional[DisaggregatedParams] = None, - postproc_params: Optional[PostprocParams] = None + postproc_params: Optional[PostprocParams] = None, + disagg_mm_params: Optional[MultimodalParams] = None, ) -> GenerationResult: """Generate output for the given prompt token ids in the asynchronous mode. Asynchronous generation accepts single prompt only. 
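For illustration, here is roughly what a caller might hand to the new `disagg_mm_params` argument. The field values are placeholders (in the disaggregated flow they come from the encoder server's response), so treat this as a sketch rather than reference usage:

```python
from tensorrt_llm.multimodal_params import MultimodalParams

# Placeholder values: in the disaggregated flow these come from the encoder
# server's JSON response rather than from hand-written literals.
mm_params = MultimodalParams(
    embeddings=None,          # in practice, a list of serialized tensor-handle dicts
    mrope_config=None,
    num_items=1,              # one image in the request
    item_offsets=[0],         # token offset of that image within the request
    item_token_length=[576],  # assumed token count for a single image item
)

# The executor-level call then forwards these alongside the usual arguments,
# e.g. executor.generate_async(prompt_token_ids, sampling_params,
#                              disagg_mm_params=mm_params)
```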
@@ -145,7 +147,16 @@ def generate_async( multimodal_embedding=multimodal_embedding, mrope_config=mrope_config, kv_cache_retention_config=kv_cache_retention_config, - disaggregated_params=disaggregated_params)) + disaggregated_params=disaggregated_params, + disagg_mm_params=disagg_mm_params)) + return result + + def generate_multimodal_async( + self, + mm_request, + ): + result = self.submit_mm( + mm_request) return result def generate( diff --git a/tensorrt_llm/executor/multimodal/__init__.py b/tensorrt_llm/executor/multimodal/__init__.py new file mode 100644 index 00000000000..0d2cab91fd4 --- /dev/null +++ b/tensorrt_llm/executor/multimodal/__init__.py @@ -0,0 +1,8 @@ +from .request import * +from .result import * + +__all__ = [ + "MultimodalRequest", + "MultimodalResponse", + "MultimodalResult", +] \ No newline at end of file diff --git a/tensorrt_llm/executor/multimodal/request.py b/tensorrt_llm/executor/multimodal/request.py new file mode 100644 index 00000000000..c989e4c8b7d --- /dev/null +++ b/tensorrt_llm/executor/multimodal/request.py @@ -0,0 +1,283 @@ +import asyncio +from dataclasses import dataclass, field +from typing import (Any, List, Optional, AsyncIterator) + +import torch + +from typing import cast +from tensorrt_llm.inputs.utils import load_image, async_load_image +from tensorrt_llm.multimodal_params import MultimodalParams + +__all__ = [ + "MultimodalRequest", + "MultimodalResponse", +] + +@dataclass(slots=True) +class MultimodalItem: + # request id for this item + req_id: int + # the id of the mm item within each request + id: int + # the modality type of the item + modality_type: str + # The url of the item + url: str + # The data of the item + data: Optional[Any] = None # Any: can be raw tensor or processed tensor + data_handle: Optional[bytes] = None + # Whether the item has been materialized + materialized: bool = False + + # The # of tokens offset of the item within each request after encoder + offset: int = 0 + # The # of tokens length of the item + length: int = 0 + + # The coroutine for the item + coroutine: Optional[asyncio.Task] = None + error_message: Optional[str] = None + + def __post_init__(self): + if self.data is not None or self.data_handle is not None: + self.materialized = True + + + async def async_prefetch(self): + try: + if self.modality_type == "image": + self.coroutine = async_load_image(self.url, format="pil", device="cpu") + elif self.modality_type == "video": + assert False, "video is not supported yet" + #self.coroutine = async_load_video(self.url) + else: + raise ValueError(f"Unknown modality type: {self.modality_type}") + except Exception as e: + self.materialized = False + self.error_message = str(e) + + async def retrieve(self): + if self.coroutine: + try: + self.data = await self.coroutine + if self.data is not None: # Only set materialized if we got valid data + self.materialized = True + return self.data + except Exception as e: + self.materialized = False + self.error_message = str(e) + return None + else: + return self.data + + def load(self): + try: + if self.modality_type == "image": + self.data = load_image(self.url, format="pil", device="cpu") + self.materialized = True + elif self.modality_type == "video": + assert False, "video is not supported yet" + #self.coroutine = load_video(self.url) + else: + raise ValueError(f"Unknown modality type: {self.modality_type}") + except Exception as e: + self.materialized = False + self.error_message = str(e) + + @staticmethod + async def process_items(items: List['MultimodalItem']) -> 
AsyncIterator['MultimodalItem']: + """Process a list of items concurrently and yield them as they complete. + + Args: + items: List of MultimodalItems to process + + Yields: + MultimodalItems as they complete loading + """ + # Create tasks for all items + tasks = {} + for item in items: + # Create and store the task + task = asyncio.create_task(item.retrieve()) + tasks[task] = item # Map task to item directly + + # Process results as they arrive + while tasks: + # Wait for any task to complete + done, pending = await asyncio.wait( + tasks.keys(), + return_when=asyncio.FIRST_COMPLETED + ) + + # Process completed tasks + for task in done: + item = tasks.pop(task) # Get the item directly from the task mapping + try: + result = await task + if result is not None: + yield item + else: + # Item failed to load but we still yield it with error state + item.materialized = False + item.error_message = "Failed to load item" + yield item + except Exception as e: + item.materialized = False + item.error_message = str(e) + yield item + + +class MultimodalRequest: + """ + A request class for multimodal encoding. + Handles requests containing URLs for different modalities (images, videos, etc.) + and returns embeddings for each item. + """ + + def __init__(self, items: Optional[List[MultimodalItem]] = None): + self.items = items or [] # type: List[MultimodalItem] + self.id: Optional[int] = None + self.error_message = None + + def set_id(self, id): + self.id = id + for item in self.items: + item.req_id = id + return self + + def has_error(self) -> bool: + # TODO: check if any item AND output better err msg + return any(item.error_message for item in self.items) + + async def prefetch(self): + await asyncio.gather(*[item.async_prefetch() for item in self.items]) + + async def fetch(self): + """Load and fill data for all items in the request. + + This method will: + 1. Initialize loading for all items using prefetch() + 2. Process all items concurrently and wait for them to complete + 3. Update the request state based on the results + + Returns: + bool: True if all items were loaded successfully, False otherwise + """ + await self.prefetch() + async for item in MultimodalItem.process_items(self.items): + item.coroutine = None + item.data_handle = None # TODO: need to remove this + + def load(self): + for item in self.items: + item.load() + + @classmethod + def from_chat_messages(cls, messages) -> "MultimodalRequest": + request = cls() + count = 0 + + for message in messages: + content = message.get("content", []) + # Ignore empty and txt content + if isinstance(content, str) or content is None: + content = [] + # Process each content part + for part in content: + assert isinstance(part, dict) + part_type = part.get("type", None) + if part_type is None or part_type == "text": + continue + + # Handle image_url type + if part_type == "image_url": + url = part.get("image_url", {}).get("url") + if url: + url = cast(str, url) + request.items.append(MultimodalItem(req_id=request.id, id=count, modality_type="image", url=url)) + count += 1 + # TODO: Handle video_url type hasn't been tested yet + elif part_type == "video_url": + url = part.get("video_url", {}).get("url") + if url: + url = cast(str, url) + request.items.append(MultimodalItem(req_id=request.id, id=count, modality_type="video", url=url)) + count += 1 + + return request + + def create_response(self): + """Create a response object and set up IPC communication. 
+ + Returns: + MultimodalResponse: The response object that will be populated with results + """ + num_items = len(self.items) + item_offsets = [item.offset for item in self.items] + item_token_length = [item.length for item in self.items] + # TODO: how to set client id? is it always the same as request id? + response = MultimodalResponse(request_id=self.id, client_id=self.id, num_items=num_items, item_offsets=item_offsets, item_token_length=item_token_length) + return response + +@dataclass(slots=True) +class MultimodalResponse: + """Response for multimodal requests containing embeddings for each item.""" + request_id: int + client_id: Optional[int] = None + num_items: int = 0 + item_offsets: List[int] = field(default_factory=list) + item_token_length: List[int] = field(default_factory=list) + embeddings: Optional[torch.Tensor] = None + embedding_handle: Optional[bytes] = None + mrope_config: Optional[dict] = None + _is_final: bool = False + error_msg: Optional[str] = None + cp_event: Optional[torch.cuda.Event] = None + + def set_embeddings(self, embeddings: torch.Tensor, cp_event: Optional[torch.cuda.Event] = None) -> None: + """Set the embeddings for the response.""" + self.embeddings = embeddings + self.cp_event = cp_event + + # TODO: error handling is missing; hopefully pass the error from mmItem or during processing + def set_error(self, error_msg: str) -> None: + """Set an error message for the response.""" + self.error_msg = error_msg + + def set_mrope_config(self, mrope_config: dict) -> None: + """Set the mrope config for the response.""" + self.mrope_config = mrope_config + + def has_error(self) -> bool: + """Check if the response has an error.""" + return self.error_msg is not None + + def set_final(self) -> None: + """Set the response to final.""" + self._is_final = True + + # TODO: this is a hack to make the result compatible with the proxy/worker architecture + @property + def result(self): + """Return a result object compatible with the proxy/worker architecture.""" + return type('Result', (), { + 'is_final': self._is_final, # Multimodal responses are always final + 'error_msg': self.error_msg + }) + + def get_params(self): + if self.embedding_handle: + # Convert the serialized tensor info to a JSON-serializable format + embeddings = [] + for tensor_info in self.embedding_handle: + embeddings.append(tensor_info.dump_to_dict()) + else: + embeddings = None + + return MultimodalParams( + embeddings=embeddings, + mrope_config=self.mrope_config, + num_items=self.num_items, + item_offsets=self.item_offsets, + item_token_length=self.item_token_length) + diff --git a/tensorrt_llm/executor/multimodal/result.py b/tensorrt_llm/executor/multimodal/result.py new file mode 100644 index 00000000000..b91826e3fc4 --- /dev/null +++ b/tensorrt_llm/executor/multimodal/result.py @@ -0,0 +1,166 @@ +import weakref +from queue import Queue +from typing import (TYPE_CHECKING, Callable, Optional, Union, Any, cast) + +from tensorrt_llm._utils import nvtx_range_debug +from tensorrt_llm.llmapi.tracer import global_tracer +from tensorrt_llm.llmapi.utils import AsyncQueue +from tensorrt_llm.executor.utils import ErrorResponse, has_event_loop + +if TYPE_CHECKING: + from tensorrt_llm.executor import GenerationExecutor +from .request import MultimodalRequest, MultimodalResponse + +__all__ = [ + "MultimodalResult", +] + +class MultimodalResult: + def __init__( + self, + mm_request: MultimodalRequest, + background_error_handler: Optional[Callable] = None, + executor: Optional["GenerationExecutor"] = None, + ) 
-> None: + self.request_id = mm_request.id # abort_request is using request_id + self._background_error_handler = background_error_handler + self._done = False + self._timeout = 2 + self._executor: Optional[weakref.ReferenceType[ + "GenerationExecutor"]] = weakref.ref(executor) if executor else None + self._aborted = False + self.multimodal_params = None + if has_event_loop(): + self.aqueue = AsyncQueue() + self.queue = self.aqueue.sync_q + else: + self.queue = Queue() + self.aqueue = None + + def set_timeout(self, timeout: float) -> None: + """Set the timeout for getting results.""" + self._timeout = timeout + + def mark_undone(self) -> None: + """Should be called when new prompts are submitted.""" + self._done = False + + @property + def finished(self) -> bool: + return self._done + + async def _aresult_step(self) -> None: + assert self.aqueue is not None, "The asyncio event loop was not present during initialization, so async operations are not available." + response = await self.aqueue.get() + global_tracer().log_instant("result_step.get") + self._handle_response(response) + + async def aresult(self) -> "MultimodalResult": + """Wait for the completion of the request, and return the result. + + Returns: + MultimodalResult: The result object. + """ + while not self._done: + await self._aresult_step() + return self + + def _result_step(self, timeout: Optional[float] = None) -> None: + response = self.queue.get(timeout=timeout) + self._handle_response(response) + + @nvtx_range_debug("handle_response", + color="red", + category="MultimodalResult") + def _handle_response(self, + response: Union[MultimodalResponse, ErrorResponse]) -> None: + if isinstance(response, MultimodalResponse): + if response.has_error(): + if self._background_error_handler is not None and ( + handler := self._background_error_handler()): + handler(response.error_msg) + + response_result = response.result + self._done = response_result.is_final + self.multimodal_params = response.get_params() + + if self._background_error_handler and ( + handler := self._background_error_handler()): + handler() + elif isinstance(response, ErrorResponse): + if self._background_error_handler is not None and ( + handler := self._background_error_handler()): + handler(response.error_msg) + # TODO: we should not need to set done here; but proxy error_queue is always empty (?) + # WAR: we set done to unblock the result.get() + self._done = True + else: + raise ValueError(f"Unknown response type: {response}") + + def result(self, timeout: Optional[float] = None) -> "MultimodalResult": + """Wait for the completion of the request, and return the result. + + Args: + timeout (float, optional): Timeout. Defaults to None. + + Returns: + tensorrt_llm.executor.result.GenerationResult: generation result. + """ + while not self._done: + self._result_step(timeout) + return self + + def abort(self) -> None: + """Abort the multimodal request.""" + assert self._executor is not None, "The executor is not set for this result." + executor = self._executor() + assert executor is not None, "The executor has been garbage collected." + assert self.request_id is not None, "The request ID is not set." 
+ executor.abort_request(self.request_id) + self._aborted = True + + def __await__(self): + """Make the result awaitable.""" + return self.aresult().__await__() + + def __iter__(self): + """Make the result iterable.""" + return self + + def __next__(self): + """Get the next result.""" + if self._done: + raise StopIteration + return self.result() + + def __aiter__(self): + """Make the result async iterable.""" + return self + + async def __anext__(self): + """Get the next result asynchronously.""" + if self._done: + raise StopAsyncIteration + return await self.aresult() + + def _exception(self, timeout: Optional[float] = None) -> Optional[Exception]: + """Get any exception that occurred during processing.""" + try: + self._result_step(timeout) + except Exception as e: + return e + return None + + def _repr_fields(self) -> dict[str, Any]: + """Get fields for string representation.""" + return { + "request_id": self.request_id, + "done": self._done, + } + + def __repr__(self) -> str: + """Get string representation of the result.""" + fields = self._repr_fields() + fields_str = ", ".join(f"{k}={v!r}" for k, v in fields.items()) + return f"{self.__class__.__name__}({fields_str})" + diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 76cb2737c6e..b4c77e9838c 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -19,6 +19,7 @@ from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, print_colored, print_colored_debug) from .executor import GenerationExecutor +from .multimodal import MultimodalResult from .ipc import FusedIpcQueue, IpcQueue from .postproc_worker import PostprocWorkerConfig from .request import CancellingRequest, GenerationRequest @@ -386,6 +387,27 @@ def shutdown(self): # Process the errors in-case error during shutting down the threads self._handle_background_error() + def submit_mm(self, request): + """Submit a multimodal request for processing. + + Args: + request: The multimodal request containing items to process + + Returns: + MultimodalResponse: The response object that will be populated with results + """ + self._start_dispatch_threads() + request.set_id(self._get_next_client_id()) + result = MultimodalResult( + request, + background_error_handler=self._handle_background_error, + executor=self) + self._results[request.id] = result + self.request_queue.put(request) + + self._handle_background_error() + return result + def submit(self, request: GenerationRequest) -> GenerationResult: """ Low-level API to the executor. 
Return a "future" GenerationResult diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index f58252df1cc..f67b52d88f4 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -9,6 +9,7 @@ from ..llmapi.llm_utils import KvCacheRetentionConfig from ..sampling_params import SamplingParams from .postproc_worker import PostprocParams +from ..multimodal_params import MultimodalParams __all__ = [ "LoRARequest", @@ -85,6 +86,7 @@ def __init__( kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, disaggregated_params: Optional[DisaggregatedParams] = None, postproc_params: Optional[PostprocParams] = None, + disagg_mm_params: Optional[MultimodalParams] = None, ): if isinstance(prompt_token_ids, list): self.prompt_token_ids = prompt_token_ids @@ -108,6 +110,7 @@ def __init__( self.kv_cache_retention_config = kv_cache_retention_config self.id: Optional[int] = None self.disaggregated_params = disaggregated_params + self.disagg_mm_params = disagg_mm_params def set_id(self, id): assert self.id is None, f"Request ID is already set: {self.id}" diff --git a/tensorrt_llm/executor/serialization.py b/tensorrt_llm/executor/serialization.py index a3bd47ea6a1..d1936993189 100644 --- a/tensorrt_llm/executor/serialization.py +++ b/tensorrt_llm/executor/serialization.py @@ -15,6 +15,7 @@ "collections": ["OrderedDict"], "datetime": ["timedelta"], "pathlib": ["PosixPath"], + "PIL.Image": ["Image"], "llmapi.run_llm_with_postproc": ["perform_faked_oai_postprocess" ], # only used in tests ### starting import of torch models classes. They are used in test_llm_multi_gpu.py. @@ -70,6 +71,7 @@ "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"], "tensorrt_llm.builder": ["BuildConfig"], "tensorrt_llm.disaggregated_params": ["DisaggregatedParams"], + "tensorrt_llm.multimodal_params": ["MultimodalParams"], "tensorrt_llm.executor.postproc_worker": [ "PostprocArgs", "PostprocParams", "PostprocWorkerConfig", "PostprocWorker.Input", "PostprocWorker.Output" @@ -78,6 +80,10 @@ "CancellingRequest", "GenerationRequest", "LoRARequest", "PromptAdapterRequest" ], + "tensorrt_llm.executor.multimodal.request": [ + "MultimodalRequest", "MultimodalResponse", "MultimodalResult", "MultimodalItem" + ], + "tensorrt_llm._torch.multimodal.mm_utils": ["SharedTensorContainer"], "tensorrt_llm.executor.result": [ "CompletionOutput", "DetokenizedGenerationResultBase", "GenerationResult", "GenerationResultBase", "IterationResult", diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 5ed6807731d..ffdc092db96 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -36,6 +36,8 @@ PostprocWorkerConfig, postproc_worker_main) from .request import (CancellingRequest, GenerationRequest, LoRARequest, PromptAdapterRequest) +from .multimodal.request import MultimodalRequest, MultimodalResponse +from .multimodal.result import MultimodalResult from .result import (GenerationResult, IterationResult, LogProbsResult, ResponseWrapper, compute_logprobs) from .utils import (ErrorResponse, IntraProcessQueue, RequestError, @@ -81,6 +83,8 @@ def __init__( self._executor_config = executor_config self._is_pytorch_backend = getattr(self._executor_config, "backend", None) == "pytorch" + self._is_mm_encoder_only = getattr(self._executor_config, "mm_encoder_only", + False) if global_mpi_size() > 1: logger.set_rank(self.global_rank) @@ -122,10 +126,17 @@ def _create_engine(): "engine_dir": executor_config.trt_engine_dir, } if 
executor_config.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ - create_py_executor - create_executor = create_py_executor - args["lora_config"] = lora_config + # If mm_encoder_only is True, we use the multimodal pyexecutor that only executes the mm encoder and returns the mm embedding. + if self._is_mm_encoder_only: + from tensorrt_llm._torch.pyexecutor.multimodal.multimodal_pyexecutor_creator import \ + create_multimodal_pyexecutor + create_executor = create_multimodal_pyexecutor + args.pop("engine_dir") # remove engine_dir as mm executor for now does not accept it + else: + from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ + create_py_executor + create_executor = create_py_executor + args["lora_config"] = lora_config elif executor_config.backend == "_autodeploy": from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ create_autodeploy_executor @@ -471,6 +482,9 @@ def _deduce_max_tokens(request: GenerationRequest, executor_request.py_logits_post_processors = lp if isinstance( lp, list) else [lp] + if request.disagg_mm_params is not None: + executor_request.disagg_mm_params = request.disagg_mm_params + if request.query_token_ids is not None: # pytorch star attention workflow # a workaround to avoid public interface update @@ -482,6 +496,22 @@ def _deduce_max_tokens(request: GenerationRequest, except Exception as e: raise RequestError(str(e)) from e + def _enqueue_mm_request(self, request: MultimodalRequest) -> int: + """Enqueue a multimodal request directly without converting to tllm.Request. + + Args: + request: The multimodal request to enqueue + + Returns: + int: The request ID assigned by the engine + """ + assert request.id is not None + + # For multimodal, we just need to pass the data directly + # No need for complex request type conversion + req_id = self.engine.enqueue_request(request) + return req_id + def submit(self, request: GenerationRequest) -> GenerationResult: """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """ self.start() @@ -516,6 +546,33 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result + @nvtx_range_debug("submit_mm", + color="yellow", + category="worker_submit") + def submit_mm(self, request: MultimodalRequest): + """Submit a multimodal request and return a MultimodalResponse.""" + self.start() + + if self.rank != 0: + raise RuntimeError( + "Only rank 0 can submit requests.\n" + "To fix this, ensure that the llm.generate(...) 
method is " + "guarded with the `if __name__ == '__main__':` block.") + + client_id = request.id if request.id is not None else self._get_next_client_id() + if request.id is None: + request.set_id(client_id) + + result = MultimodalResult(request) + self._results[client_id] = result + + request_id = self._enqueue_mm_request(request) + self._client_id_to_request_id[client_id] = request_id + + self._handle_background_error() + + return result + def _pop_result(self, client_id: int): self._results.pop(client_id, None) self._client_id_to_request_id.pop(client_id, None) @@ -751,6 +808,13 @@ def notify_proxy_threads_to_quit(): logger.error(f"submit request failed: {e}") worker._await_response_helper.temp_error_responses.put( ErrorResponse(req.id, e, req.id)) + elif isinstance(req, MultimodalRequest): + try: + worker.submit_mm(req) + except RequestError as e: + logger.error(f"submit mm request failed: {e}") + worker._await_response_helper.temp_error_responses.put( + ErrorResponse(req.id, e, req.id)) else: raise ValueError(f"Unknown request type: {type(req)}") @@ -841,7 +905,7 @@ def handle_for_worker(self, responses: List[tllm.Response]) -> None: queue = self.worker.return_queue(response.client_id) logprobs_result = _get_logprobs(self.worker, response, - self.worker._is_pytorch_backend) + self.worker._is_pytorch_backend) if not self.worker._is_mm_encoder_only else None if logprobs_result: response = ResponseWrapper(response, logprobs_result) @@ -882,7 +946,7 @@ def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None: response.request_id) else: logprobs_result = _get_logprobs(self.worker, response, - self.worker._is_pytorch_backend) + self.worker._is_pytorch_backend) if not self.worker._is_mm_encoder_only else None if logprobs_result: response = ResponseWrapper(response, logprobs_result) @@ -985,7 +1049,7 @@ def _send_rsp( # Eliminate the finished GenerationRequest instances timely, which may # take considerable memory. 
- if is_llm_response(response): + if is_llm_response(response) or isinstance(response, MultimodalResponse): if response.has_error() or response.result.is_final: worker._pop_result(response.client_id) elif isinstance(response, ErrorResponse): diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index 2fa3987d6e1..4577eec6fd4 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass, field from enum import Enum -from typing import Any, List, Literal, Optional, Tuple +from typing import Any, List, Literal, Optional, Tuple, Union import yaml from mpi4py.MPI import COMM_WORLD, Comm @@ -29,6 +29,14 @@ class CtxGenServerConfig(): instance_num_ranks: int = 1 other_args: dict = field(default_factory=dict) +@dataclass +class MultimodalServerConfig(): + type: Literal['mm'] + hostname: Optional[str] = None + port: Optional[int] = None + instance_num_ranks: int = 1 + other_args: dict = field(default_factory=dict) + @dataclass class RouterConfig(): @@ -42,6 +50,15 @@ class ConditionalDisaggConfig(): max_local_prefill_length: int = 0 +@dataclass +class MultimodalDisaggServerConfig(): + server_configs: List[Union[MultimodalServerConfig, CtxGenServerConfig]] + hostname: str = "localhost" + port: int = 8000 + # TODO: add router support for multimodal disagg + # mm_router_config: Optional[RouterConfig] = None + # gen_router_config: Optional[RouterConfig] = None + @dataclass class DisaggServerConfig(): server_configs: List[CtxGenServerConfig] @@ -70,6 +87,45 @@ def parse_disagg_config_file(yaml_config_file: str): return disagg_server_config +def parse_mm_disagg_config_file(yaml_config_file: str): + + with open(yaml_config_file, 'r') as file: + + config = yaml.safe_load(file) + + disagg_server_config = extract_mm_disagg_cfg(**config) + + return disagg_server_config + +def extract_mm_disagg_cfg(hostname: str = 'localhost', + port: int = 8000, + multimodal_servers: dict = dict(), + generation_servers: dict = dict(), + **kwargs: Any) -> MultimodalDisaggServerConfig: + + # If parameters are specified outside the context_severs and generation_servers sections, + # make sure they match + # Also inherit the values from the top-level + for key, value in kwargs.items(): + for server_type, servers in [("multimodal_servers", multimodal_servers), + ("generation_servers", generation_servers) + ]: + if key in servers: + if servers[key] != value: + raise ValueError( + f"Parameter {key} is specified both in the top-level and in the {server_type} section, but with different values" + ) + else: + # Inherit the value from the top-level + servers[key] = value + + server_configs = extract_multimodal_cfgs( + type="mm", **multimodal_servers) + extract_ctx_gen_cfgs( + type="gen", **generation_servers) + + # TODO: add router config later for multimodal server + + return MultimodalDisaggServerConfig(server_configs, hostname, port) def extract_disagg_cfg(hostname: str = 'localhost', port: int = 8000, @@ -113,6 +169,47 @@ def extract_disagg_cfg(hostname: str = 'localhost', return config +def extract_multimodal_cfgs(type: Literal['mm'], + num_instances: int = 1, + urls: Optional[List[str]] = None, + **kwargs: Any) -> List[MultimodalServerConfig]: + + hostnames = [] + ports = [] + if urls: + for url in urls: + hostname, port_str = url.split(':') + port = int(port_str) + hostnames.append(hostname) + ports.append(port) + + if len(hostnames) != num_instances: + raise ValueError( + f"Number of hostnames 
({len(hostnames)}) should be equal to the number of instances ({num_instances})" + ) + + if len(ports) != num_instances: + raise ValueError( + f"Number of ports ({len(ports)}) should be equal to the number of instances ({num_instances})" + ) + + else: + hostnames = [None] * num_instances + ports = [None] * num_instances + + # Compute the number of ranks per instance for multimodal server + instance_num_ranks = kwargs.get('data_parallel_size', 1) + + cfgs = [] + for hostname, port in zip(hostnames, ports): + cfgs.append( + MultimodalServerConfig(type=type, + hostname=hostname, + port=port, + instance_num_ranks=instance_num_ranks, + other_args=kwargs)) + return cfgs + def extract_ctx_gen_cfgs(type: Literal['ctx', 'gen'], num_instances: int = 1, urls: Optional[List[str]] = None, diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index ac6abe0e67b..5ee4aad3ae8 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -39,7 +39,7 @@ # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import from .utils import (append_docstring, exception_handler, get_device_count, print_colored_debug) - +from ..multimodal_params import MultimodalParams class RequestOutput(DetokenizedGenerationResultBase, GenerationResult): """The output data of a completion request to the LLM. @@ -219,6 +219,8 @@ def generate( KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None, disaggregated_params: Optional[Union[ DisaggregatedParams, Sequence[DisaggregatedParams]]] = None, + disagg_mm_params: Optional[Union[ + MultimodalParams, Sequence[MultimodalParams]]] = None, ) -> Union[RequestOutput, List[RequestOutput]]: """Generate output for the given prompts in the synchronous mode. Synchronous generation accepts either single prompt or batched prompts. @@ -266,6 +268,7 @@ def _item_at(maybe_batched: Union[Any, Sequence[Any]], pos: int) -> Any: kv_cache_retention_config=_item_at(kv_cache_retention_config, i), disaggregated_params=_item_at(disaggregated_params, i), + disagg_mm_params=_item_at(disagg_mm_params, i), streaming=False) futures.append(future) @@ -291,6 +294,7 @@ def generate_async( kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None, disaggregated_params: Optional[DisaggregatedParams] = None, _postproc_params: Optional[PostprocParams] = None, + disagg_mm_params: Optional[MultimodalParams] = None, ) -> RequestOutput: """Generate output for the given prompt in the asynchronous mode. Asynchronous generation accepts single prompt only. 
@@ -343,10 +347,18 @@ def generate_async( prompt_token_ids = inputs['prompt_token_ids'] prompt = None query_token_ids = inputs.get("query_token_ids", None) + extra_processed_inputs = None elif "prompt" in inputs: - with nvtx_range_debug("input_processor"): - prompt_token_ids, extra_processed_inputs = self.input_processor( - inputs, sampling_params) + + if disagg_mm_params is not None: + with nvtx_range_debug("mm_postprocess_on_llm"): + prompt_token_ids = self.input_processor.postprocess(inputs, sampling_params, disagg_mm_params) + extra_processed_inputs = None + else: + with nvtx_range_debug("input_processor"): + prompt_token_ids, extra_processed_inputs = self.input_processor( + inputs, sampling_params) + prompt = inputs['prompt'] if extra_processed_inputs is not None: query_token_ids = extra_processed_inputs.get('query_token_ids') @@ -377,6 +389,7 @@ def generate_async( kv_cache_retention_config=kv_cache_retention_config, disaggregated_params=disaggregated_params, postproc_params=_postproc_params, + disagg_mm_params=disagg_mm_params, ) return RequestOutput._from_generation_result(result, prompt, diff --git a/tensorrt_llm/multimodal_params.py b/tensorrt_llm/multimodal_params.py new file mode 100644 index 00000000000..59b97289519 --- /dev/null +++ b/tensorrt_llm/multimodal_params.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + + +@dataclass(slots=True, kw_only=True) +class MultimodalParams: + """ + Parameters for multimodal parameters in disaggregated serving as the interface between mm_encoder and llm servers. + + This class holds information needed to reconstruct prompt_token_ids and mm_embedding in LLM servers. + + Args: + embeddings (Optional[Dict[str, Any]]): + Metadata for reconstructing embedding tensors via CUDA IPC. + The tensor data is stored in shared memory or device memory (shared by cudaIPC) for LLM server reconstruction. + + mrope_config (Optional[dict]): + Configuration for multimodal Rotary Position Embedding parameters. + + num_items (Optional[int]): + Number of multimodal items in the batch. Used to reconstruct input_ids. + + item_offsets (Optional[List[int]]): + Offsets for positioning each multimodal item in the sequence. + + item_token_length (Optional[List[int]]): + Token lengths for each multimodal item. + + Note: + As an experimental feature, all fields are currently optional to allow flexibility during development. + In future, we should stabilize the interface by defining a fixed set of required fields. 
+ """ + embeddings: Optional[List[Dict[str, Any]]] = None + mrope_config: Optional[dict] = None + num_items: Optional[int] = 0 + item_offsets: Optional[List[int]] = None + item_token_length: Optional[List[int]] = None \ No newline at end of file diff --git a/tensorrt_llm/serve/__init__.py b/tensorrt_llm/serve/__init__.py index 0f8e6b1b282..1c9ef3f9d8a 100644 --- a/tensorrt_llm/serve/__init__.py +++ b/tensorrt_llm/serve/__init__.py @@ -1,4 +1,5 @@ from .openai_disagg_server import OpenAIDisaggServer +from .openai_disagg_multimodal import OpenAIMultiModalDisaggServer from .openai_server import OpenAIServer __all__ = ['OpenAIServer', 'OpenAIDisaggServer'] diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index fd56bfa161b..66d847ddc08 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -67,7 +67,7 @@ def _parse_chat_message_content_mm_part( def parse_chat_message_content_part( - part: ChatCompletionMessageParam, ) -> Optional[Any]: + part: ChatCompletionMessageParam, skip_loading: bool = False) -> Optional[Any]: """Parse a single part of a chat message.""" if isinstance(part, str): return part @@ -93,8 +93,10 @@ async def load_image_async(): except Exception as e: logger.error(f"Failed to load image: {str(e)}") return None + async def noop_coroutine(): + return str_content - return MultimodalData(modality="image", data=load_image_async()) + return MultimodalData(modality="image", data=load_image_async() if not skip_loading else noop_coroutine()) if part_type == "video_url": str_content = cast(str, content) @@ -105,8 +107,10 @@ async def load_video_async(): except Exception as e: logger.error(f"Failed to load video: {str(e)}") return None + async def noop_coroutine(): + return str_content - return MultimodalData(modality="video", data=load_video_async()) + return MultimodalData(modality="video", data=load_video_async() if not skip_loading else noop_coroutine()) raise NotImplementedError(f"Unknown part type: {part_type}") @@ -114,12 +118,13 @@ async def load_video_async(): def parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionMessageParam], + skip_loading: bool = False, ) -> ConversationMessage: """Parse multiple parts of a chat message.""" text_parts = [] media_parts = [] for part in parts: - parse_res = parse_chat_message_content_part(part) + parse_res = parse_chat_message_content_part(part, skip_loading) if parse_res: if isinstance(parse_res, str): text_parts.append(parse_res) @@ -134,7 +139,7 @@ def parse_chat_message_content_parts( def parse_chat_message_content( - message: ChatCompletionMessageParam, ) -> ConversationMessage: + message: ChatCompletionMessageParam, skip_loading: bool = False) -> ConversationMessage: """Parse the content of a chat message.""" role = message["role"] content = message.get("content") @@ -149,6 +154,7 @@ def parse_chat_message_content( result = parse_chat_message_content_parts( role, content, + skip_loading, ) return result @@ -156,6 +162,7 @@ def parse_chat_message_content( def parse_chat_messages_coroutines( messages: List[ChatCompletionMessageParam], model_config: AutoConfig, + skip_loading: bool = False, ) -> Tuple[List[ConversationMessage], Optional[Coroutine[ Any, Any, Optional[Dict[str, List[Any]]]]]]: """Parse multiple chat messages and return conversation and coroutine.""" @@ -163,7 +170,7 @@ def parse_chat_messages_coroutines( mm_data_tracker = MultimodalDataTracker(model_config.model_type) for msg in messages: - parsed_msg = parse_chat_message_content(msg) + 
parsed_msg = parse_chat_message_content(msg, skip_loading) conversation.append(parsed_msg) if parsed_msg["media"]: for mdata in parsed_msg["media"]: diff --git a/tensorrt_llm/serve/encoder_server.py b/tensorrt_llm/serve/encoder_server.py new file mode 100755 index 00000000000..c7f160dc947 --- /dev/null +++ b/tensorrt_llm/serve/encoder_server.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +import asyncio +import logging +import signal +from contextlib import asynccontextmanager +from http import HTTPStatus +from typing import Dict, List, Optional, Union +from fastapi import FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse, Response +import uvicorn + +from tensorrt_llm.serve.openai_protocol import ChatCompletionRequest +from tensorrt_llm.executor import CppExecutorError +from tensorrt_llm.version import __version__ as VERSION +from tensorrt_llm._torch.multimodal.mm_encoder import MultimodalEncoder +from tensorrt_llm.executor.multimodal import MultimodalRequest +from tensorrt_llm.multimodal_params import MultimodalParams +from pathlib import Path +from tensorrt_llm.serve.openai_protocol import ModelList, ModelCard +from tensorrt_llm.serve.openai_server import OpenAIServer +from dataclasses import asdict + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# yapf: enale +TIMEOUT_KEEP_ALIVE = 5 # seconds. + + +class OpenAIEncoderServer: + """ + Encoder server that processes image URLs and returns structured embeddings + Compatible with the OpenAI disaggregated server architecture. + """ + + def __init__(self, encoder: MultimodalEncoder, model: str): + """ + Initialize the encoder server. + + Args: + encoder: ImageEncoder instance specialized for encoding images + model: Name or identifier for the encoder model + """ + self.encoder = encoder + self.model = model + + model_dir = Path(model) + if model_dir.exists() and model_dir.is_dir(): + self.model = model_dir.name + else: + self.model = model + + @asynccontextmanager + async def lifespan(app: FastAPI): + # terminate rank0 worker + yield + self.encoder.shutdown() + + self.app = FastAPI(lifespan=lifespan) + + @self.app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + return OpenAIServer.create_error_response(message=str(exc)) + + self.register_routes() + + async def await_disconnected(self, raw_request: Request, promise): + while not await raw_request.is_disconnected(): + await asyncio.sleep(1) + if not promise.finished: + promise.abort() + logger.info( + f"{raw_request.client} is disconnected, abort {promise.request_id}") + + def register_routes(self): + self.app.add_api_route("/health", self.health, methods=["GET"]) + self.app.add_api_route("/version", self.version, methods=["GET"]) + self.app.add_api_route("/v1/models", self.get_model, methods=["GET"]) + self.app.add_api_route("/v1/multimodal_encoder", + self.encode_image, + methods=["POST"]) + + async def health(self) -> Response: + """Health check endpoint.""" + return Response(status_code=200) + + async def version(self) -> JSONResponse: + """Version information endpoint.""" + ver = {"version": VERSION} + return JSONResponse(content=ver) + + async def get_model(self) -> JSONResponse: + model_list = ModelList(data=[ModelCard(id=self.model)]) + return JSONResponse(content=model_list.model_dump()) + + async def encode_image(self, request: ChatCompletionRequest, raw_request: Request) -> Response: + + async def create_mm_embedding_response( + 
promise) -> MultimodalParams: + await promise.aresult() + return promise.multimodal_params + + try: + mm_request = MultimodalRequest.from_chat_messages(request.messages) + if len(mm_request.items) == 0: + return JSONResponse(content={}) + promise = await self.encoder.generate_async(mm_request) + asyncio.create_task(self.await_disconnected(raw_request, promise)) + response = await create_mm_embedding_response(promise) + if isinstance(response, MultimodalParams): + return JSONResponse(content=asdict(response)) + else: + return JSONResponse(content=response.model_dump()) + + except CppExecutorError: + # If internal executor error is raised, shutdown the server + signal.raise_signal(signal.SIGINT) + except Exception as e: + return OpenAIServer.create_error_response(str(e)) + + async def __call__(self, host: str, port: int): + """Run the server.""" + config = uvicorn.Config(self.app, + host=host, + port=port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + await uvicorn.Server(config).serve() + + diff --git a/tensorrt_llm/serve/openai_disagg_multimodal.py b/tensorrt_llm/serve/openai_disagg_multimodal.py new file mode 100755 index 00000000000..575b150da50 --- /dev/null +++ b/tensorrt_llm/serve/openai_disagg_multimodal.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python +import asyncio +import copy +import json +import logging +import os +import signal +from contextlib import asynccontextmanager +from http import HTTPStatus +from typing import List, Optional, Type, Union, Dict, Any + +import aiohttp +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse, Response, StreamingResponse + +# yapf: disable +from tensorrt_llm.executor import CppExecutorError +from tensorrt_llm.llmapi.disagg_utils import RouterConfig +from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + ErrorResponse) +from tensorrt_llm.multimodal_params import MultimodalParams + +from tensorrt_llm.serve.router import create_router +from tensorrt_llm.version import __version__ as VERSION + +logging.basicConfig(level=logging.INFO) + +# yapf: enale +TIMEOUT_KEEP_ALIVE = 10 # seconds. 
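For orientation, a minimal sketch of the encoder leg of the flow implemented below: the orchestrator forwards the chat request to the multimodal server's `/v1/multimodal_encoder` route and reads back a `MultimodalParams`-shaped JSON payload that it later attaches to the generation request. The host and port are assumptions taken from the example configuration:

```python
import asyncio

import aiohttp


async def fetch_mm_params(chat_request_json: dict) -> dict:
    """Ask the encoder server for embeddings of the request's image parts."""
    # localhost:8001 matches the example configuration; adjust as needed.
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:8001/v1/multimodal_encoder",
                                json=chat_request_json) as resp:
            resp.raise_for_status()
            # An empty dict means no multimodal items were found; otherwise the
            # payload mirrors the MultimodalParams dataclass (embeddings,
            # mrope_config, num_items, item_offsets, item_token_length).
            return await resp.json()


# Example usage (payload follows the OpenAI chat schema used elsewhere here):
# mm_params = asyncio.run(fetch_mm_params({"model": "...", "messages": [...]}))
```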
+ +class OpenAIMultiModalDisaggServer: + + def __init__(self, + gen_servers: List[str] = None, + mm_servers: List[str] = None, + req_timeout_secs: int = 180, + server_start_timeout_secs: int = 180, + ctx_router_config: Optional[RouterConfig] = None, + gen_router_config: Optional[RouterConfig] = None): + + self.ctx_servers = None + self.gen_servers = gen_servers + self.mm_servers = mm_servers + assert len(mm_servers) == 1, "Currently only one multimodal server is supported" + # We should remove this restriction pretty soon (also need to modify the broadcast mm_embed logic in model runner) + assert len(gen_servers) == 1, "Currently only one generation server is supported" + + assert os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") != "1", "Multimodal disaggregated mode is not supported in disaggregated_gen benchmark mode" + + # Session will be initialized in lifespan + self.session: Optional[aiohttp.ClientSession] = None + + @asynccontextmanager + async def lifespan(app: FastAPI): + # Create a persistent aiohttp ClientSession + self.session = aiohttp.ClientSession( + connector=aiohttp.TCPConnector(limit=0, limit_per_host=0, keepalive_timeout=300), + timeout=aiohttp.ClientTimeout(total=req_timeout_secs)) + + logging.info("Waiting for multimodal and LLM decoder servers to be ready") + await self.wait_for_servers_ready(server_start_timeout_secs) + yield + await self.session.close() # Ensure session cleanup + + self.app = FastAPI(lifespan=lifespan) + + @self.app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + return JSONResponse(status_code=400, content={"error": str(exc)}) + + self.register_routes() + + @staticmethod + def create_error_response( + message: str, + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + raise HTTPException(status_code=500, detail=f"Internal server error {message}") + + def register_routes(self): + self.app.add_api_route("/health", self.health, methods=["GET"]) + self.app.add_api_route("/version", self.version, methods=["GET"]) + self.app.add_api_route("/v1/completions", + self.openai_completion, + methods=["POST"]) + self.app.add_api_route("/v1/chat/completions", + self.openai_chat_completion, + methods=["POST"]) + + async def health(self) -> Response: + return Response(status_code=200) + + async def version(self) -> JSONResponse: + ver = {"version": VERSION} + return JSONResponse(content=ver) + + async def merge_streaming_responses(self, ctx_response, + gen_server: str, + gen_req: Union[CompletionRequest, ChatCompletionRequest]): + try: + # First yield the context response if it's not None + if ctx_response is not None: + # Remove the disaggregated params from the context response + data = ctx_response.model_dump() + del data['choices'][0]['disaggregated_params'] + data = json.dumps(data) + yield f"data: {data}\n\n".encode('utf-8') + + # Then yield the generation responses + if isinstance(gen_req, CompletionRequest): + gen_response = await self.send_completion_request(gen_server, gen_req) + elif isinstance(gen_req, ChatCompletionRequest): + gen_response = await self.send_chat_request(gen_server, gen_req) + else: + raise TypeError(f"Invalid request type: {type(gen_req).__name__}") + + async for chunk in gen_response.body_iterator: + yield chunk + + finally: + #await self.gen_router.finish_request(gen_req) + pass + + async def openai_completion(self, req: CompletionRequest) -> Response: + raise ValueError("Completion mode is not yet supported for multimodal disaggregated server.") + try: + gen_req =
copy.deepcopy(req) + if not isinstance(req.prompt, str): + # Check if it's a list and contains integers + if type(req.prompt) is list and len(req.prompt) == 1: + req.prompt = req.prompt[0] + elif not isinstance(req.prompt, list) or not all(isinstance(x, int) for x in req.prompt): + raise ValueError("Disaggregated server currently only supports single string prompt or list of integers in request") + + ctx_response = await self._process_context_server_request(req, "completion") + + return await self._process_generation_server_request(gen_req, ctx_response) + + except Exception as e: + await self._handle_exception(e) + + async def openai_chat_completion(self, req: ChatCompletionRequest) -> Response: + try: + # Step 1: Process multimodal request and get response + mm_req = copy.deepcopy(req) + mm_response = await self._process_multimodal_server_request(mm_req) + + # Step 2: Append multimodal response directly to the original request + if mm_response and 'embeddings' in mm_response: + req.mm_params = MultimodalParams(**mm_response) + + return await self._process_generation_server_request(req) + + except Exception as e: + await self._handle_exception(e) + + async def _handle_exception(self, exception): + if isinstance(exception, CppExecutorError): + logging.exception(exception) + signal.raise_signal(signal.SIGINT) + elif isinstance(exception, HTTPException): + raise exception # Re-raise HTTP exceptions properly + else: + logging.exception(exception) + raise HTTPException(status_code=500, detail=f"Internal server error {str(exception)}") + + async def _process_multimodal_server_request(self, mm_req: ChatCompletionRequest) -> Optional[Dict[str, Any]]: + """ + Process multimodal request and return response from multimodal server. + + Returns: + Optional[Dict[str, Any]]: Response dictionary from multimodal server or None if processing failed + """ + try: + # Disable streaming for multimodal requests + mm_req.stream = False + + # Send request to multimodal server + async with self.session.post( + self.mm_servers[0] + "/v1/multimodal_encoder", + json=mm_req.model_dump(exclude_unset=True) + ) as response: + if not response.ok: + error_msg = f"Multimodal server returned error: {response.status} {response.reason}" + logging.error(error_msg) + raise HTTPException(status_code=response.status, detail=error_msg) + return await response.json() + + except Exception as e: + logging.error(f"Unexpected error in multimodal processing: {str(e)}") + raise HTTPException(status_code=500, detail=f"Multimodal processing failed: {str(e)}") + + async def _process_generation_server_request(self, gen_req, ctx_response=None): + if ctx_response is not None: + choices = ctx_response.choices + if len(choices) > 1: + raise ValueError("Disagg server returned more than one choice. 
This is currently not supported in disaggregated server.") + if choices[0].disaggregated_params is None: + raise ValueError("Context server did not return disaggregated params") + + # Append disaggregates parameters to generation request + gen_req.disaggregated_params = choices[0].disaggregated_params + gen_req.disaggregated_params.request_type = "generation_only" + else: + if gen_req.disaggregated_params is not None: + # TODO: support E+PD for now; later we can support E+P+D + del gen_req.disaggregated_params + + # Pick a generation server and send request + # gen_server, _ = await self.gen_router.get_next_server(gen_req) + # TODO: support gen_server routing + gen_server = self.gen_servers[0] + + if not gen_req.stream: + try: + if isinstance(gen_req, CompletionRequest): + raise ValueError("Completion mode is not supported in multimodal disaggregated mode.") + elif isinstance(gen_req, ChatCompletionRequest): + gen_response = await self.send_chat_request(gen_server, gen_req) + return gen_response + finally: + # TODO: support gen_router + #await self.gen_router.finish_request(gen_req) + pass + else: + # Return a streaming response that combines both context and generation responses + return StreamingResponse( + self.merge_streaming_responses(ctx_response, gen_server, gen_req), + media_type="text/event-stream" + ) + + + async def __call__(self, host, port): + config = uvicorn.Config(self.app, + host=host, + port=port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + await uvicorn.Server(config).serve() + + async def create_generator(self, url: str, request: Union[CompletionRequest, ChatCompletionRequest], end_point: str): + async with self.session.post(url + end_point, json=request.model_dump(exclude_unset=True)) as response: + content_type = response.headers.get("Content-Type", "") + if "text/event-stream" in content_type: + if not request.stream: + raise ValueError("Received an event-stream although request stream was False") + + try: + async for line in response.content.iter_any(): + if line: + yield line + await asyncio.sleep(0) + except Exception as e: + logging.error(f"Unexpected error in stream: {e}") + raise + + async def create_completion_generator(self, url: str, request: CompletionRequest): + async for chunk in self.create_generator(url, request, "/v1/completions"): + yield chunk + + async def create_chat_generator(self, url: str, request: ChatCompletionRequest): + async for chunk in self.create_generator(url, request, "/v1/chat/completions"): + yield chunk + + async def send_request(self, url: str, + request: Union[CompletionRequest, ChatCompletionRequest], + endpoint: str, + response_type: Type[Union[CompletionResponse, ChatCompletionResponse]], + create_generator: callable) -> Union[CompletionResponse, ChatCompletionResponse, StreamingResponse]: + if request.stream: + response_generator = create_generator(url, request) + return StreamingResponse(content=response_generator, media_type="text/event-stream") + else: + request_json = request.model_dump(exclude_unset=True) + async with self.session.post(url + endpoint, json=request_json) as response: + content_type = response.headers.get("Content-Type", "") + if "text/event-stream" in content_type: + raise ValueError("Received an event-stream although request stream was False") + + response_dict = await response.json() + if not response.ok: + logging.error(f"Request failed with status {response.status}") + logging.error(f"Response body: {response_dict}") + response.raise_for_status() + return 
response_type(**response_dict) + + async def send_completion_request(self, url: str, request: CompletionRequest) -> Union[CompletionResponse, StreamingResponse]: + return await self.send_request(url, request, "/v1/completions", CompletionResponse, self.create_completion_generator) + + async def send_chat_request(self, url: str, request: ChatCompletionRequest) -> ChatCompletionResponse: + return await self.send_request(url, request, "/v1/chat/completions", ChatCompletionResponse, self.create_chat_generator) + + async def check_server_ready(self, server_url: str) -> bool: + try: + async with self.session.get(server_url+"/health") as response: + return response.status == 200 + except Exception: + return False + + async def wait_for_servers_ready(self, server_start_timeout_secs: int = 180): + async def are_servers_ready(): + context_ready = True + if self.ctx_servers is not None: + context_ready = all([await self.check_server_ready(url) for url in self.ctx_servers]) + generation_ready = all([await self.check_server_ready(url) for url in self.gen_servers]) + return context_ready and generation_ready + + async def check_all_servers_ready(): + while not await are_servers_ready(): + wait_time = 3 + logging.info("Context and generation servers are not ready. Waiting...") + await asyncio.sleep(wait_time) + + try: + await asyncio.wait_for(check_all_servers_ready(), timeout=server_start_timeout_secs) + except asyncio.CancelledError: + raise TimeoutError("Timeout waiting for multimodal and LLM decoder servers to be ready") diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 61111c49717..84dfa511cd5 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -15,7 +15,7 @@ from tensorrt_llm.executor.serialization import register_approved_ipc_class from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams - +from tensorrt_llm.multimodal_params import MultimodalParams class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields & allow to initialize by both alias and field name @@ -508,6 +508,11 @@ class ChatCompletionRequest(OpenAIBaseModel): description=("Parameters for disaggregated serving"), ) + mm_params: Optional[MultimodalParams] = Field( + default=None, + description=("Parameters for multimodal serving"), + ) + # doc: end-chat-completion-extra-params def to_sampling_params(self) -> SamplingParams: diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index dcc71e43bf4..fa19ad78493 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -252,8 +252,9 @@ async def create_chat_response( sampling_params = request.to_sampling_params() postproc_args = ChatPostprocArgs.from_request(request) disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params) - - conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(request.messages, self.model_config) + # Skip loading when mm_params is provided, i.e., loading + processing is done by encoder server. 
+ skip_loading = True if request.mm_params is not None else False + conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(request.messages, self.model_config, skip_loading) if request.prompt_token_ids is not None: prompt = request.prompt_token_ids @@ -291,7 +292,8 @@ async def create_chat_response( sampling_params=sampling_params, _postproc_params=postproc_params if self.postproc_worker_enabled else None, streaming=request.stream, - disaggregated_params=disaggregated_params + disaggregated_params=disaggregated_params, + disagg_mm_params=request.mm_params ) asyncio.create_task(self.await_disconnected(raw_request, promise)) if not self.postproc_worker_enabled: diff --git a/tests/unittest/_torch/multimodal_disagg/test_encoder_llm_disagg.py b/tests/unittest/_torch/multimodal_disagg/test_encoder_llm_disagg.py new file mode 100644 index 00000000000..5b1293f69a9 --- /dev/null +++ b/tests/unittest/_torch/multimodal_disagg/test_encoder_llm_disagg.py @@ -0,0 +1,200 @@ +import os +import pytest +import copy +import json + +from tensorrt_llm.executor.multimodal.request import MultimodalRequest +from tensorrt_llm._torch.multimodal.mm_encoder import MultimodalEncoder +from tensorrt_llm.llmapi.llm import LLM, SamplingParams +from tensorrt_llm.llmapi import KvCacheConfig +from tensorrt_llm.inputs import default_multimodal_input_loader + +example_images = [ + "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png", + "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg", +] + + +@pytest.fixture(scope="function") +def multimodal_model_config(): + """Get multimodal model configuration similar to integration tests""" + # You can extend this to support multiple models or get from environment + model_configs = { + 'llava-v1.6-mistral-7b-hf': { + 'model_name': 'llava-v1.6-mistral-7b-hf', + 'hf_model_dir': 'llava-hf/llava-v1.6-mistral-7b-hf', # HuggingFace model ID + } + } + + return model_configs['llava-v1.6-mistral-7b-hf'] + + +@pytest.mark.parametrize("model_key", [ + "llava-v1.6-mistral-7b-hf", +]) +def test_single_image_chat(model_key, multimodal_model_config): + """Test processing single image using disaggregated encoder + LLM API. + + This test verifies that disaggregated multimodal generation produces identical + results to standard multimodal generation by comparing outputs. 
+ """ + # Get model configuration + if model_key != "llava-v1.6-mistral-7b-hf": + pytest.skip(f"Skipping test for {model_key} - only testing llava-v1.6-mistral-7b-hf for now") + + # Extract model information from config + model_name = multimodal_model_config['model_name'] + encoder_model_dir = multimodal_model_config['hf_model_dir'] + + # Test configuration + max_tokens = 64 + free_gpu_memory_fraction = 0.6 + max_batch_size = 1 + + # Test data - OpenAI chat completion format + prompts = ["Describe the natural environment in the image."] + media = [example_images[0]] + + # Create OpenAI chat messages format + messages_list = [] + for prompt, image_url in zip(prompts, media): + messages = [{ + "role": "user", + "content": [{ + "type": "text", + "text": prompt + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }] + }] + messages_list.append(messages) + + # Sampling configuration + sampling_params = SamplingParams(max_tokens=max_tokens) + kv_cache_config = KvCacheConfig( + enable_block_reuse=False, + free_gpu_memory_fraction=free_gpu_memory_fraction, + ) + + # Step 1: Process multimodal data using disaggregated encoder + encoder = None + llm = None + + try: + encoder = MultimodalEncoder(model=encoder_model_dir, max_batch_size=max_batch_size) + + # Process all messages through the encoder + multimodal_requests = [MultimodalRequest.from_chat_messages(msgs) for msgs in messages_list] + results = encoder.generate_from_mm_request(multimodal_requests) + + # Validate encoder output + assert results is not None, "Encoder returned None results" + assert len(results) == len(prompts), f"Expected {len(prompts)} results, got {len(results)}" + + mm_params_list = [] + for i, result in enumerate(results): + mm_params = result.multimodal_params + assert mm_params is not None, f"Multimodal parameters are None for request {i}" + assert hasattr(mm_params, 'embeddings'), f"Multimodal parameters missing embeddings attribute for request {i}" + assert mm_params.num_items > 0, f"Expected multimodal items > 0 for request {i}, got {mm_params.num_items}" + mm_params_list.append(mm_params) + + # Step 2: Initialize LLM and prepare inputs + llm = LLM( + model=encoder_model_dir, + backend='pytorch', + kv_cache_config=kv_cache_config, + trust_remote_code=True + ) + + # Load model configuration + config_path = os.path.join(llm._hf_model_dir, 'config.json') + assert os.path.exists(config_path), f"Model config not found at {config_path}" + + with open(config_path, 'r') as f: + model_config = json.load(f) + model_type = model_config['model_type'] + + # Prepare multimodal inputs + inputs = default_multimodal_input_loader( + tokenizer=llm.tokenizer, + model_dir=llm._hf_model_dir, + model_type=model_type, + modality="image", + prompts=prompts, + media=media, + image_data_format="pt" + ) + + # Validate inputs structure + assert len(inputs) == len(prompts), f"Expected {len(prompts)} inputs, got {len(inputs)}" + # Step 3: Generate reference output with raw multimodal inputs + outputs_ref = llm.generate(inputs, sampling_params=sampling_params) + + # Validate reference outputs + assert outputs_ref is not None, "Reference generation returned None" + assert len(outputs_ref) == len(prompts), f"Expected {len(prompts)} reference outputs, got {len(outputs_ref)}" + for i, output in enumerate(outputs_ref): + assert len(output.outputs) > 0, f"Reference generation has no output text for input {i}" + + # Step 4: Prepare inputs for disaggregated multimodal generation + inputs_disagg = copy.deepcopy(inputs) + for i, input_data in 
enumerate(inputs_disagg): + # disaggregated generation doesn't need raw image data, but keep the key + input_data["multi_modal_data"]["image"] = [] + + # Step 5: Generate output using disaggregated multimodal parameters + # Note: For batch processing, we need to match mm_params with inputs + outputs = llm.generate(inputs_disagg, sampling_params=sampling_params, disagg_mm_params=mm_params_list) + + # Validate disaggregated outputs + assert outputs is not None, "Disaggregated generation returned None" + assert len(outputs) == len(prompts), f"Expected {len(prompts)} disaggregated outputs, got {len(outputs)}" + for i, output in enumerate(outputs): + assert len(output.outputs) > 0, f"Disaggregated generation has no output text for input {i}" + + # Step 6: Compare outputs - they should match exactly + assert len(outputs_ref) == len(outputs), f"Number of outputs don't match: {len(outputs_ref)} vs {len(outputs)}" + + for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs)): + # Compare prompts + assert ref_output.prompt == test_output.prompt, \ + f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}" + + # Compare number of generated outputs + assert len(ref_output.outputs) == len(test_output.outputs), \ + f"Number of generated outputs don't match for output {i}: {len(ref_output.outputs)} vs {len(test_output.outputs)}" + + # Compare generated text and other attributes + for j, (ref_gen, test_gen) in enumerate(zip(ref_output.outputs, test_output.outputs)): + assert ref_gen.text == test_gen.text, \ + f"Generated text doesn't match for output {i}, generation {j}:\nReference: {ref_gen.text!r}\nTest: {test_gen.text!r}" + + # Compare token IDs if available + if hasattr(ref_gen, 'token_ids') and hasattr(test_gen, 'token_ids'): + assert ref_gen.token_ids == test_gen.token_ids, \ + f"Token IDs don't match for output {i}, generation {j}" + + # Compare log probabilities if available + if hasattr(ref_gen, 'logprobs') and hasattr(test_gen, 'logprobs'): + assert ref_gen.logprobs == test_gen.logprobs, \ + f"Log probabilities don't match for output {i}, generation {j}" + + # Verify non-empty generation for all outputs + for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs)): + ref_text = ref_output.outputs[0].text.strip() + test_text = test_output.outputs[0].text.strip() + assert len(ref_text) > 0, f"Reference generation produced empty text for input {i}" + assert len(test_text) > 0, f"Disaggregated generation produced empty text for input {i}" + + finally: + # Cleanup resources + if encoder is not None: + del encoder + if llm is not None: + del llm + diff --git a/tests/unittest/_torch/multimodal_disagg/test_share_cuda_tensor.py b/tests/unittest/_torch/multimodal_disagg/test_share_cuda_tensor.py new file mode 100644 index 00000000000..166f821538c --- /dev/null +++ b/tests/unittest/_torch/multimodal_disagg/test_share_cuda_tensor.py @@ -0,0 +1,54 @@ +import unittest +import multiprocessing as mp +import torch + +from tensorrt_llm._torch.multimodal.mm_utils import SharedTensorContainer, _SharedTensorRebuildMethodRegistry + + +class TestShareCudaTensor(unittest.TestCase): + """Test cases for sharing CUDA tensor between processes.""" + + @classmethod + def setUpClass(cls): + """Initialize the registry before running tests.""" + _SharedTensorRebuildMethodRegistry.initialize() + + def setUp(self): + """Set up test fixtures.""" + self.ref_tensor = torch.randn(3, 4, 5) + self.cuda_available = torch.cuda.is_available() + if 
self.cuda_available: + torch.cuda.set_device(0) + + @staticmethod + def _producer(q, tensor, device=None): + """Producer: create CUDA tensor and share it.""" + try: + if device is not None: + torch.cuda.set_device(device) + tensor = tensor.cuda() + container = SharedTensorContainer.from_tensor(tensor) + q.put(('success', container.dump_to_dict())) + except Exception as e: + q.put(('error', str(e))) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_share_cuda_tensor(self): + """Test tensor sharing between processes.""" + mp.set_start_method('spawn', force=True) + queue = mp.Queue() + + # Producer process + producer = mp.Process(target=self._producer, args=(queue, self.ref_tensor, 0)) + producer.start() + status, data = queue.get(timeout=100) + # Verify + self.assertEqual(status, 'success') + reconstructed = SharedTensorContainer.from_dict(data).get_local_view() + self.assertTrue(torch.allclose(reconstructed.cpu(), self.ref_tensor)) + producer.join() + + +if __name__ == '__main__': + mp.set_start_method('spawn', force=True) + unittest.main() \ No newline at end of file
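
The orchestrator above returns chat completions either as a single JSON body or, when the request sets `stream`, as a `text/event-stream` built by `merge_streaming_responses`. A minimal consumer sketch follows; it assumes the orchestrator runs on `localhost:8000` (as in the example configuration) and that the proxied generation server uses the usual OpenAI-style `data: ...` / `[DONE]` SSE framing.

```python
# Sketch: stream tokens from the disaggregated orchestrator's chat endpoint.
# Assumption: the orchestrator listens on localhost:8000.
import json
import requests

DISAGG_URL = "http://localhost:8000"

payload = {
    "model": "llava-hf/llava-v1.6-mistral-7b-hf",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the natural environment in the image."},
            {"type": "image_url", "image_url": {
                "url": "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png"}},
        ],
    }],
    "max_tokens": 64,
    "stream": True,
}

with requests.post(f"{DISAGG_URL}/v1/chat/completions",
                   json=payload, stream=True, timeout=180) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank SSE separators
        data = line[len("data: "):]
        if data == "[DONE]":  # conventional end-of-stream marker, if emitted
            break
        choices = json.loads(data).get("choices") or []
        if choices:
            print(choices[0].get("delta", {}).get("content") or "", end="", flush=True)
```
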
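The `mm_params` field added to `ChatCompletionRequest`, together with the `skip_loading` path in `openai_server.py`, is what lets the decoder accept pre-computed vision embeddings instead of loading images itself. Below is a sketch of wiring the two hops by hand, mirroring `_process_multimodal_server_request` and `openai_chat_completion`; the ports (8001 for the encoder, 8002 for the decoder) and the exact response shapes are assumptions taken from the example setup, not guarantees.

```python
# Sketch: manually replicate the orchestrator's encoder -> decoder hand-off.
# Assumptions: encoder on localhost:8001, decoder (trtllm-serve OpenAI server)
# on localhost:8002, both serving llava-hf/llava-v1.6-mistral-7b-hf.
import requests

ENCODER_URL = "http://localhost:8001"
DECODER_URL = "http://localhost:8002"

chat_request = {
    "model": "llava-hf/llava-v1.6-mistral-7b-hf",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the natural environment in the image."},
            {"type": "image_url", "image_url": {
                "url": "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png"}},
        ],
    }],
    "max_tokens": 64,
}

# Hop 1: the encoder returns the serialized MultimodalParams for the image(s).
mm_response = requests.post(f"{ENCODER_URL}/v1/multimodal_encoder",
                            json=chat_request, timeout=180).json()

# Hop 2: forward the embeddings via the new mm_params field, so the decoder's
# chat handler sets skip_loading and never touches the raw image itself.
if mm_response and "embeddings" in mm_response:
    chat_request["mm_params"] = mm_response

completion = requests.post(f"{DECODER_URL}/v1/chat/completions",
                           json=chat_request, timeout=180).json()
print(completion["choices"][0]["message"]["content"])
```
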
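`test_share_cuda_tensor.py` exercises the producer/consumer split across spawned processes; the same container API can also be read as a plain serialize/rebuild round trip. A condensed sketch using only the calls the test relies on, with the single-process usage being an assumption (the test itself only covers the multi-process case):

```python
# Sketch: round trip a CUDA tensor through SharedTensorContainer, using the
# same calls as the unit test above. Single-process use is an assumption.
import torch

from tensorrt_llm._torch.multimodal.mm_utils import (
    SharedTensorContainer, _SharedTensorRebuildMethodRegistry)

# One-time registry setup, as in TestShareCudaTensor.setUpClass.
_SharedTensorRebuildMethodRegistry.initialize()

if not torch.cuda.is_available():
    raise SystemExit("CUDA is required for this sketch, as in the unit test.")

embedding = torch.randn(3, 4, 5).cuda()

# Producer side: wrap the tensor and serialize a shareable handle (a dict).
handle = SharedTensorContainer.from_tensor(embedding).dump_to_dict()

# Consumer side: rebuild a local view of the same storage from the handle.
restored = SharedTensorContainer.from_dict(handle).get_local_view()

assert torch.allclose(restored.cpu(), embedding.cpu())
print("round trip ok:", tuple(restored.shape))
```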