diff --git a/lib/bindings/python/rust/llm/block_manager/vllm.rs b/lib/bindings/python/rust/llm/block_manager/vllm.rs index 72aec067a2..52c0d1df34 100644 --- a/lib/bindings/python/rust/llm/block_manager/vllm.rs +++ b/lib/bindings/python/rust/llm/block_manager/vllm.rs @@ -136,7 +136,7 @@ impl KvbmCacheManager { }; let disk_blocks = if let Some(disk) = self.block_manager().disk() { - disk.match_sequence_hashes_blocking(&sequence_hashes) + disk.match_sequence_hashes_blocking(&sequence_hashes[host_blocks.len()..]) .map_err(to_pyerr)? } else { vec![] diff --git a/lib/bindings/python/src/dynamo/llm/vllm_integration/kv_cache_manager.py b/lib/bindings/python/src/dynamo/llm/vllm_integration/kv_cache_manager.py index ca147aea68..d735942e13 100644 --- a/lib/bindings/python/src/dynamo/llm/vllm_integration/kv_cache_manager.py +++ b/lib/bindings/python/src/dynamo/llm/vllm_integration/kv_cache_manager.py @@ -117,19 +117,22 @@ def get_offloaded_computed_blocks( sequence_hashes = self._create_slot(request) - host_owned_blocks, disk_owned_blocks = self.cache_manager.get_offloaded_computed_blocks(sequence_hashes) + remaining_sequence_hashes = sequence_hashes[num_computed_tokens // self.block_size:] + + host_owned_blocks, disk_owned_blocks = self.cache_manager.get_offloaded_computed_blocks(remaining_sequence_hashes) host_block_count = host_owned_blocks.block_count() disk_block_count = disk_owned_blocks.block_count() num_host_computed_tokens = host_block_count * self.block_size num_disk_computed_tokens = disk_block_count * self.block_size - num_external_hit_tokens = max(num_disk_computed_tokens, num_host_computed_tokens) + num_external_hit_tokens = num_host_computed_tokens + num_disk_computed_tokens + + need_to_allocate = num_external_hit_tokens - need_to_allocate = num_external_hit_tokens - num_computed_tokens # In a full-prompt-hit case, we need to recompute the last token - if num_external_hit_tokens == request.num_tokens: + if num_computed_tokens + num_external_hit_tokens == request.num_tokens: need_to_allocate -= 1 # TODO: add stats for offloaded computed tokens