@@ -9,6 +9,8 @@
 from blake3 import blake3
 from torchvision.transforms import ToPILImage
 
+from tensorrt_llm.logger import logger
+
 # Default hasher
 default_hasher = blake3
 
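For reference, the hashing below relies on the incremental `hashlib`-style API that the `blake3` package exposes; a minimal sketch of the calls `_hash_image` makes (the byte strings are placeholders):

```python
from blake3 import blake3

hasher = blake3()
hasher.update(b"<frame>")       # feed bytes incrementally
hasher.update(b"\x00\x01\x02")
digest = hasher.hexdigest()     # stable hex string for identical byte streams
```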
@@ -435,13 +437,22 @@ def apply_mm_hashes(mm_data: Dict[str, Any],
     """Apply hashing to multimodal data items."""
 
     def _hash_image(image):
-        # only support single modality w/ PIL.Image.Image for now
         # TODO: possible hash collision w/ this simplified version (vllm/PR/17378)
         hasher = hash_lib()
         if isinstance(image, torch.Tensor):
-            # TODO: Device tensor hashing is an open issue. Limited hashing to CPU for now.
-            image = image.cpu()
-            hasher.update(serialize_item(image))
+            # Ensure tensor is on CPU and contiguous for consistent hashing
+            image = image.detach().cpu().contiguous()
+            hasher.update(serialize_item(image))
+        elif isinstance(image, list):
+            # Hash each frame with a separator to avoid collisions between [A,B] and [AB]
+            for frame in image:
+                hasher.update(b"<frame>")
+                if isinstance(frame, torch.Tensor):
+                    frame = frame.detach().cpu().contiguous()
+                hasher.update(serialize_item(frame))
+        else:
+            hasher.update(serialize_item(image))
+
         return hasher.hexdigest()
 
     mm_items = {
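A toy illustration of why `_hash_image` prepends a separator before each frame: without it, different frame groupings that concatenate to the same bytes would collide. This sketch assumes, for simplicity, that frames are already raw bytes (in the real code `serialize_item` produces the bytes):

```python
from blake3 import blake3

def hash_frames(frames):
    h = blake3()
    for f in frames:
        h.update(b"<frame>")  # separator before each frame
        h.update(f)
    return h.hexdigest()

# [b"ab", b"c"] and [b"a", b"bc"] concatenate to the same b"abc",
# but the separator keeps their digests distinct:
assert hash_frames([b"ab", b"c"]) != hash_frames([b"a", b"bc"])
```

Note the TODO above still applies: a frame whose serialized bytes happen to contain `b"<frame>"` could in principle still collide.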
@@ -483,54 +494,71 @@ def find_mm_token_lengths(mm_data: Dict[str, Any],
     num_mm_tokens = {}
 
     for modality, items in mm_items.items():
-        if modality != "image":
-            #TODO: support other modalities
-            raise ValueError(
-                f"Unsupported modality: {modality}. Only 'image' modality is currently supported for hashing."
-            )
-        if not hasattr(input_processor, "get_num_tokens_per_image"):
-            #TODO: backward compatibility for models that don't yet have get_num_tokens_per_image implemented
-            #TODO: only support qwen2_vl for now
+        if not hasattr(input_processor, f"get_num_tokens_per_{modality}"):
             raise AttributeError(
-                f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_image' method required for multimodal hashing."
+                f"Input processor {type(input_processor).__name__} does not have 'get_num_tokens_per_{modality}' method required for multimodal hashing."
             )
 
         modality_token_lengths = []
         for item in items:
-            if isinstance(item, torch.Tensor):
-                item = ToPILImage()(item)
-            num_tokens = input_processor.get_num_tokens_per_image(
-                image_width=item.width,
-                image_height=item.height,
-            )
-            modality_token_lengths.append(num_tokens)
+            if modality == "image":
+                if isinstance(item, torch.Tensor):
+                    item = ToPILImage()(item)
+                num_tokens = input_processor.get_num_tokens_per_image(
+                    image_width=item.width,
+                    image_height=item.height,
+                )
+                modality_token_lengths.append(num_tokens)
+            elif modality == "video":
+                assert isinstance(item, list), "Video must be a list of frames"
+                if isinstance(item[0], torch.Tensor):
+                    item = [ToPILImage()(frame) for frame in item]
+                num_tokens = input_processor.get_num_tokens_per_video(
+                    video_width=item[0].width,
+                    video_height=item[0].height,
+                    num_frames=len(item),
+                )
+                modality_token_lengths.append(num_tokens)
+            else:
+                # TODO: add audio support if needed
+                raise ValueError(f"Unsupported modality: {modality}")
 
         num_mm_tokens[modality] = modality_token_lengths
 
-    return num_mm_tokens['image']  # flatten all mm instances to a single list
+    return num_mm_tokens  # dict of modality -> per-item token lengths
 
 
-def find_mm_token_positions(input_ids: Union[torch.Tensor, List[int],
-                                             np.ndarray],
-                            num_mm_tokens: List[int],
-                            vocab_size: int,
-                            mm_token_ids: torch.Tensor = None) -> List[int]:
+def find_mm_token_positions(
+        input_ids: Union[torch.Tensor, List[int], np.ndarray],
+        num_mm_tokens: List[int],
+        vocab_size: Optional[int] = None,
+        mm_token_ids: Optional[torch.Tensor] = None) -> List[int]:
     """Get multimodal token positions using IDs > vocab_size and known lengths.
 
     This function finds multimodal tokens (with IDs > vocab_size) and uses the
     provided lengths in num_mm_tokens to identify where each chunk starts.
     This works even when there are no gaps between different image sequences
     (e.g., when all images use the same token IDs).
+    Note: at least one of vocab_size or mm_token_ids must be provided; if mm_token_ids is provided, vocab_size is ignored.
 
     Args:
         input_ids: Token sequence (tensor, list, or numpy array)
         num_mm_tokens: List of lengths for each multimodal token chunk
         vocab_size: Size of the model's vocabulary
-        mm_token_ids (optional): possible token ids for multimodal tokens
+        mm_token_ids: Possible token ids for multimodal tokens
 
     Returns:
         List of starting positions for each multimodal token chunk
     """
+    if mm_token_ids is None and vocab_size is None:
+        raise ValueError(
+            "Provide either mm_token_ids or vocab_size to find multimodal token positions"
+        )
+    if mm_token_ids is not None and vocab_size is not None:
+        logger.warning(
+            "Both mm_token_ids and vocab_size are provided, using mm_token_ids and ignoring vocab_size"
+        )
+
     # Convert input_ids to tensor if needed
     if not isinstance(input_ids, torch.Tensor):
         if isinstance(input_ids, list):
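A hypothetical usage sketch of the new signature, based on the docstring (the token ids and lengths here are made up for illustration):

```python
import torch

# Two multimodal chunks of lengths 3 and 2, back to back with text tokens.
input_ids = torch.tensor([1, 5, 32000, 32000, 32000, 7, 32001, 32001, 9])

# Threshold path: every id >= vocab_size is treated as multimodal.
starts = find_mm_token_positions(input_ids,
                                 num_mm_tokens=[3, 2],
                                 vocab_size=32000)
# Expected per the docstring: [2, 6]

# Explicit-id path; vocab_size may be omitted (and is ignored if also given).
starts = find_mm_token_positions(input_ids,
                                 num_mm_tokens=[3, 2],
                                 mm_token_ids=torch.tensor([32000, 32001]))
```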
@@ -542,6 +570,9 @@ def find_mm_token_positions(input_ids: Union[torch.Tensor, List[int],
     if mm_token_ids is None:
         mm_mask = input_ids >= vocab_size
     else:
+        if mm_token_ids.ndim != 1:
+            raise ValueError("mm_token_ids must be a 1D tensor")
+        mm_token_ids = torch.unique(mm_token_ids)  # dedupe ids (unique also sorts)
         mm_mask = torch.isin(input_ids, mm_token_ids)
 
     # If no multimodal tokens found, return empty list
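The rest of the function falls outside this hunk; one plausible continuation, consistent with the docstring, recovers each chunk's start from the mask plus the known lengths. A sketch only, not the actual code (the helper name `_starts_from_mask` is hypothetical):

```python
from typing import List
import torch

def _starts_from_mask(mm_mask: torch.Tensor,
                      num_mm_tokens: List[int]) -> List[int]:
    # Indices of all multimodal tokens, in order of appearance.
    positions = torch.nonzero(mm_mask, as_tuple=True)[0]
    if positions.numel() == 0:
        return []  # no multimodal tokens found

    # Walk the flat index list in strides given by num_mm_tokens:
    # the first index of each stride is that chunk's starting position.
    starts, cursor = [], 0
    for length in num_mm_tokens:
        starts.append(int(positions[cursor]))
        cursor += length
    return starts
```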