
Commit bea117c

chang-l authored and Wong4j committed
[TRTLLM-7410][feat] Support hashing and KV cache reuse for videos (NVIDIA#7360)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
1 parent 86b034c commit bea117c

File tree

6 files changed: +323 -84 lines changed


tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 3 additions & 14 deletions
@@ -14,8 +14,8 @@
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...llmapi.utils import download_hf_model
@@ -32,7 +32,7 @@
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
 
 
-class LlavaNextInputProcessor(InputProcessor):
+class LlavaNextInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
@@ -56,17 +56,6 @@ def __init__(self,
         self.vocab_size = model_config.vocab_size
         self.config = model_config.vision_config
 
-    def get_num_tokens_per_image(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-    ) -> int:
-        image_size = (image_height, image_width)
-        num_image_tokens = self.processor._get_num_multimodal_tokens(
-            [image_size])["num_image_tokens"][0]
-        return num_image_tokens
-
     def _postprocess(
         self, input_ids: torch.Tensor, mm_features: Union[torch.Tensor,
                                                           List[torch.Tensor]]
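
LlavaNextInputProcessor now relies on the shared BaseMultimodalInputProcessor mixin instead of the per-model get_num_tokens_per_image override deleted above. The base-class body is not part of this diff; the sketch below only mirrors the removed override, assuming a HuggingFace processor that exposes _get_num_multimodal_tokens, to show what such a shared default could look like.

from typing import Any


def get_num_tokens_per_image_sketch(hf_processor: Any, *, image_width: int,
                                    image_height: int) -> int:
    # Hypothetical restatement of the override removed above: HF multimodal
    # processors report per-image token counts for (height, width) sizes.
    image_size = (image_height, image_width)
    return hf_processor._get_num_multimodal_tokens(
        [image_size])["num_image_tokens"][0]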

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 6 additions & 36 deletions
@@ -7,14 +7,13 @@
 from transformers import (AutoProcessor, AutoTokenizer, PretrainedConfig,
                           PreTrainedModel, Qwen2_5_VLForConditionalGeneration,
                           Qwen2VLForConditionalGeneration)
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 
 from tensorrt_llm.inputs.multimodal import MultimodalParams
 
 from ..._utils import nvtx_range_debug
 from ...functional import RopeEmbeddingUtils, RotaryScalingType
-from ...inputs import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from ...inputs import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, TextPrompt,
                        register_input_processor)
 from ...logger import logger
@@ -29,7 +28,7 @@
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
 
 
-class Qwen2VLInputProcessorBase(InputProcessor):
+class Qwen2VLInputProcessorBase(BaseMultimodalInputProcessor, InputProcessor):
 
     def __init__(self,
                  model_path: str,
@@ -45,6 +44,9 @@ def __init__(self,
             trust_remote_code=trust_remote_code)
 
         self.tllm_multimodal_token_id = self.model_config.vocab_size + 1
+        # temporal patch size for video frames
+        self.temporal_patch_size = getattr(model_config.vision_config,
+                                           'temporal_patch_size', 1)
 
     @classmethod
     def get_rope_index(
@@ -220,38 +222,6 @@ def get_rope_index(
             mrope_position_deltas, device=input_ids.device).unsqueeze(1)
         return position_ids, mrope_position_deltas
 
-    def get_num_tokens_per_image(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        num_frames: int = 1,
-        do_resize: bool = True,
-    ):
-        patch_size = self.model_config.vision_config.patch_size
-        merge_size = self.model_config.vision_config.spatial_merge_size
-        temporal_patch_size = self.model_config.vision_config.temporal_patch_size
-        if do_resize:
-            resized_height, resized_width = smart_resize(
-                height=image_height,
-                width=image_width,
-                factor=patch_size * merge_size,
-                min_pixels=self.processor.image_processor.min_pixels,
-                max_pixels=self.processor.image_processor.max_pixels,
-            )
-            image_width, image_height = resized_width, resized_height
-
-        padded_num_frames = num_frames + num_frames % temporal_patch_size
-
-        grid_t = max(padded_num_frames // temporal_patch_size, 1)
-        grid_h = image_height // patch_size
-        grid_w = image_width // patch_size
-
-        num_patches = grid_t * grid_h * grid_w
-        num_vision_tokens = num_patches // (merge_size**2)
-
-        return num_vision_tokens
-
     def _preprocess(self, text: dict[str, any], mm_data: dict[str, any],
                     mm_processor_kwargs: Dict[str, Any]):
         images = mm_data.get("image")
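
The override removed above is the arithmetic that turns image/video geometry into a vision-token count for Qwen2-VL. For reference, here is a standalone restatement of that arithmetic with the smart_resize step omitted; the default patch/merge/temporal sizes below are illustrative Qwen2-VL-style values, not values read from this diff.

def qwen2vl_num_vision_tokens(image_width: int,
                              image_height: int,
                              num_frames: int = 1,
                              *,
                              patch_size: int = 14,
                              merge_size: int = 2,
                              temporal_patch_size: int = 2) -> int:
    # Pad the frame count as the removed code did (num_frames + num_frames %
    # temporal_patch_size), count spatio-temporal patches, then collapse them
    # spatially by merge_size**2.
    padded_num_frames = num_frames + num_frames % temporal_patch_size
    grid_t = max(padded_num_frames // temporal_patch_size, 1)
    grid_h = image_height // patch_size
    grid_w = image_width // patch_size
    return (grid_t * grid_h * grid_w) // (merge_size**2)


# A single 28x28 frame: 2x2 spatial grid, 1 temporal slot -> 4 patches -> 1 token.
assert qwen2vl_num_vision_tokens(28, 28) == 1
# A 4-frame 56x56 video: 4x4 spatial grid, 2 temporal slots -> 32 patches -> 8 tokens.
assert qwen2vl_num_vision_tokens(56, 56, num_frames=4) == 8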

tensorrt_llm/inputs/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,7 @@
 from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
 from .multimodal import MultimodalInput
-from .registry import (ExtraProcessedInputs, InputProcessor,
-                       MultimodalPlaceholderMetadata,
+from .registry import (BaseMultimodalInputProcessor, ExtraProcessedInputs,
+                       InputProcessor, MultimodalPlaceholderMetadata,
                        MultimodalPlaceholderPlacement, create_input_processor,
                        create_input_processor_with_hash,
                        register_input_processor)
@@ -27,6 +27,7 @@
     "create_input_processor_with_hash",
     "register_input_processor",
     "ExtraProcessedInputs",
+    "BaseMultimodalInputProcessor",
     "MultimodalPlaceholderMetadata",
     "MultimodalPlaceholderPlacement",
     "ConversationMessage",

tensorrt_llm/inputs/multimodal.py

Lines changed: 58 additions & 27 deletions
@@ -9,6 +9,8 @@
 from blake3 import blake3
 from torchvision.transforms import ToPILImage
 
+from tensorrt_llm.logger import logger
+
 # Default hasher
 default_hasher = blake3
 
@@ -435,13 +437,22 @@ def apply_mm_hashes(mm_data: Dict[str, Any],
     """Apply hashing to multimodal data items."""
 
     def _hash_image(image):
-        # only support single modality w/ PIL.Image.Image for now
         # TODO: possible hash collision w/ this simplified version (vllm/PR/17378)
         hasher = hash_lib()
         if isinstance(image, torch.Tensor):
-            # TODO: Device tensor hashing is an open issue. Limited hashing to CPU for now.
-            image = image.cpu()
-        hasher.update(serialize_item(image))
+            # Ensure tensor is on CPU and contiguous for consistent hashing
+            image = image.detach().cpu().contiguous()
+            hasher.update(serialize_item(image))
+        elif isinstance(image, list):
+            # Hash each frame with a separator to avoid collisions between [A,B] and [AB]
+            for frame in image:
+                hasher.update(b"<frame>")
+                if isinstance(frame, torch.Tensor):
+                    frame = frame.detach().cpu().contiguous()
+                hasher.update(serialize_item(frame))
+        else:
+            hasher.update(serialize_item(image))
+
         return hasher.hexdigest()
 
     mm_items = {
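
The per-frame b"<frame>" separator added above is what keeps a two-frame video [A, B] from hashing identically to a single frame whose bytes are A followed by B. A small illustration with raw bytes, using the same blake3 import this file already has; the helper below is hypothetical and bypasses serialize_item().

from blake3 import blake3


def hash_frames(frames, use_separator):
    # Hypothetical helper: hashes raw frame bytes, optionally with the same
    # b"<frame>" separator used by the new _hash_image above.
    hasher = blake3()
    for frame in frames:
        if use_separator:
            hasher.update(b"<frame>")
        hasher.update(frame)
    return hasher.hexdigest()


a, b = b"\x01\x02", b"\x03\x04"
# Without the separator, the frame lists [A, B] and [A+B] feed the hasher the
# exact same byte stream and collide.
assert hash_frames([a, b], False) == hash_frames([a + b], False)
# With the separator they hash differently, which is what the new code relies on.
assert hash_frames([a, b], True) != hash_frames([a + b], True)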
@@ -483,54 +494,71 @@ def find_mm_token_lengths(mm_data: Dict[str, Any],
     num_mm_tokens = {}
 
     for modality, items in mm_items.items():
-        if modality != "image":
-            #TODO: support other modalities
-            raise ValueError(
-                f"Unsupported modality: {modality}. Only 'image' modality is currently supported for hashing."
-            )
-        if not hasattr(input_processor, "get_num_tokens_per_image"):
-            #TODO: backward compatibility for models that don't yet have get_num_tokens_per_image implemented
-            #TODO: only support qwen2_vl for now
+        if not hasattr(input_processor, f"get_num_tokens_per_{modality}"):
             raise AttributeError(
-                f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_image' method required for multimodal hashing."
+                f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_{modality}' method required for multimodal hashing."
             )
 
         modality_token_lengths = []
         for item in items:
-            if isinstance(item, torch.Tensor):
-                item = ToPILImage()(item)
-            num_tokens = input_processor.get_num_tokens_per_image(
-                image_width=item.width,
-                image_height=item.height,
-            )
-            modality_token_lengths.append(num_tokens)
+            if modality == "image":
+                if isinstance(item, torch.Tensor):
+                    item = ToPILImage()(item)
+                num_tokens = input_processor.get_num_tokens_per_image(
+                    image_width=item.width,
+                    image_height=item.height,
+                )
+                modality_token_lengths.append(num_tokens)
+            elif modality == "video":
+                assert isinstance(item, list), "Video must be a list of frames"
+                if isinstance(item[0], torch.Tensor):
+                    item = [ToPILImage()(frame) for frame in item]
+                num_tokens = input_processor.get_num_tokens_per_video(
+                    video_width=item[0].width,
+                    video_height=item[0].height,
+                    num_frames=len(item),
+                )
+                modality_token_lengths.append(num_tokens)
+            else:
+                # TODO: add audio support if needed
+                raise ValueError(f"Unsupported modality: {modality}")
 
         num_mm_tokens[modality] = modality_token_lengths
 
-    return num_mm_tokens['image']  # flatten all mm instances to a single list
+    return num_mm_tokens  # flatten all mm instances to a single list
 
 
-def find_mm_token_positions(input_ids: Union[torch.Tensor, List[int],
-                                             np.ndarray],
-                            num_mm_tokens: List[int],
-                            vocab_size: int,
-                            mm_token_ids: torch.Tensor = None) -> List[int]:
+def find_mm_token_positions(
+        input_ids: Union[torch.Tensor, List[int], np.ndarray],
+        num_mm_tokens: List[int],
+        vocab_size: Optional[int] = None,
+        mm_token_ids: Optional[torch.Tensor] = None) -> List[int]:
     """Get multimodal token positions using IDs > vocab_size and known lengths.
 
     This function finds multimodal tokens (with IDs > vocab_size) and uses the
     provided lengths in num_mm_tokens to identify where each chunk starts.
     This works even when there are no gaps between different image sequences
     (e.g., when all images use the same token IDs).
+    Note at least one of vocab_size or mm_token_ids must be provided. If mm_token_ids is provided, vocab_size is ignored.
 
     Args:
         input_ids: Token sequence (tensor, list, or numpy array)
         num_mm_tokens: List of lengths for each multimodal token chunk
         vocab_size: Size of the model's vocabulary
-        mm_token_ids (optional): possible token ids for multimodal tokens
+        mm_token_ids: Possible token ids for multimodal tokens
 
     Returns:
         List of starting positions for each multimodal token chunk
     """
+    if mm_token_ids is None and vocab_size is None:
+        raise ValueError(
+            "Provide either mm_token_ids or vocab_size to find multimodal token positions"
+        )
+    if mm_token_ids is not None and vocab_size is not None:
+        logger.warning(
+            "Both mm_token_ids and vocab_size are provided, using mm_token_ids and ignoring vocab_size"
+        )
+
     # Convert input_ids to tensor if needed
     if not isinstance(input_ids, torch.Tensor):
         if isinstance(input_ids, list):
@@ -542,6 +570,9 @@ def find_mm_token_positions(input_ids: Union[torch.Tensor, List[int],
     if mm_token_ids is None:
         mm_mask = input_ids >= vocab_size
     else:
+        if mm_token_ids.ndim != 1:
+            raise ValueError("mm_token_ids must be a 1D tensor")
+        mm_token_ids = torch.unique(mm_token_ids)
         mm_mask = torch.isin(input_ids, mm_token_ids)
 
     # If no multimodal tokens found, return empty list
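
With the relaxed signature above, callers can pass mm_token_ids on its own and omit vocab_size. A synthetic usage sketch, assuming a tensorrt_llm build that contains this commit; the placeholder token id 32001 and the chunk lengths are made-up values.

import torch

from tensorrt_llm.inputs.multimodal import find_mm_token_positions

# Two multimodal chunks of 3 and 2 placeholder tokens (id 32001) packed back to
# back at positions 2..4 and 5..6 of the sequence.
input_ids = torch.tensor([1, 5, 32001, 32001, 32001, 32001, 32001, 7])
starts = find_mm_token_positions(
    input_ids,
    num_mm_tokens=[3, 2],
    mm_token_ids=torch.tensor([32001]),  # vocab_size can now be left out
)
print(starts)  # expected chunk start positions: [2, 5]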
