47 changes: 47 additions & 0 deletions docs/source/1.0/examples/kvcacheconfig.md
@@ -0,0 +1,47 @@
# How to Change KV Cache Behavior

Set KV cache behavior by providing the optional ```kv_cache_config``` argument when you create the LLM engine. Consider the quick start example found in ```examples/pytorch/quickstart.py```:

```python
from tensorrt_llm import LLM, SamplingParams


def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)

    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')
    outputs = llm.generate(prompts, sampling_params)

    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
```

This example runs with default KV cache properties. The default value of ```free_gpu_memory_fraction``` is 0.9, which means TensorRT-LLM tries to allocate 90% of the GPU memory that remains free after loading model weights for the KV cache. Depending on your use case, this allocation can be too aggressive. You can reduce this value to 0.7 by adding the following lines to the quick start example:

```python
from tensorrt_llm.llmapi import KvCacheConfig
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0', kv_cache_config=kv_cache_config)
```

You can also set properties after you create ```KvCacheConfig```. For example:

```python
kv_cache_config = KvCacheConfig()
kv_cache_config.enable_block_reuse = False
llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0', kv_cache_config=kv_cache_config)
```

This code disables block reuse for the quick start example.
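
The two snippets above can also be combined into a single configuration. The following minimal sketch only uses the ```KvCacheConfig``` properties already introduced on this page:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Cap the KV cache at 70% of free GPU memory and disable cross-request block reuse.
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
kv_cache_config.enable_block_reuse = False

llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0', kv_cache_config=kv_cache_config)
```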
68 changes: 68 additions & 0 deletions docs/source/1.0/examples/kvcacheretentionconfig.md
@@ -0,0 +1,68 @@
# How to Change Block Priorities

You can change block priority by providing the optional ```kv_cache_retention_config``` argument when you submit a request to the LLM engine. Consider the quick start example found in ```examples/pytorch/quickstart.py```:

```python
from tensorrt_llm import LLM, SamplingParams


def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)

    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')
    outputs = llm.generate(prompts, sampling_params)

    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
```

The blocks from the prompts are stored for reuse with the default priority of 35 on a scale from 1 to 100, where 100 is the highest priority and 1 is the lowest. Assume you know that the first four tokens of each prompt represent a system prompt that should be stored with high priority (100). You can achieve this by providing a ```KvCacheRetentionConfig``` object when you submit the prompts for generation:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheRetentionConfig


def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=32)

    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')

    # Set priority for first 4 prompt tokens to 100. All other tokens set to default (35) priority.
    # This policy never lapses.
    token_range_retention_config = KvCacheRetentionConfig.TokenRangeRetentionConfig(0, 4, 100, None)
    kv_cache_retention_config = KvCacheRetentionConfig(
        token_range_retention_configs=[token_range_retention_config],
        decode_retention_priority=35,  # Set generated tokens to default priority
        decode_duration_ms=None)
    outputs = llm.generate(prompts, sampling_params, kv_cache_retention_config=kv_cache_retention_config)

    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()
```

This example uses a single ```kv_cache_retention_config``` object that applies to all the prompts. You can also provide a list of retention configs with the same length as the list of prompts, one per prompt.
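
For instance, the following sketch passes one retention config per prompt. It only reuses the calls already shown above; the ```make_retention_config``` helper is an illustrative convenience, not part of the API:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheRetentionConfig


def make_retention_config(priority):
    # Apply `priority` to the first 4 prompt tokens; generated tokens keep the default (35).
    token_range = KvCacheRetentionConfig.TokenRangeRetentionConfig(0, 4, priority, None)
    return KvCacheRetentionConfig(
        token_range_retention_configs=[token_range],
        decode_retention_priority=35,
        decode_duration_ms=None)


def main():
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
    ]
    sampling_params = SamplingParams(max_tokens=32)
    llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0')

    # One retention config per prompt, in the same order as the prompts:
    # high priority (100) for the first prompt, a milder boost (50) for the second.
    retention_configs = [make_retention_config(100), make_retention_config(50)]
    outputs = llm.generate(prompts, sampling_params,
                           kv_cache_retention_config=retention_configs)

    for i, output in enumerate(outputs):
        print(f"[{i}] Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")


if __name__ == '__main__':
    main()
```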
12 changes: 8 additions & 4 deletions docs/source/1.0/features/kvcache.md
@@ -22,13 +22,15 @@ One caveat in the current code is that only leaf blocks can be evicted (leaves a

### Retention Policy

Blocks are assigned priority in line with the [retention policy](llm-api/reference.html#tensorrt_llm.llmapi.KvCacheRetentionConfig) of the request. The retention policy is a list of [TokenRangeRetentionConfig](llm-api/reference.html#tensorrt_llm.llmapi.KvCacheRetentionConfig.KvCacheRetentionConfig) objects, each specifying priority for a given range of tokens, such as "assign priority X to tokens 10 through 61". You can also assign a duration in milliseconds for this to remain in effect, priority will revert to the default after a period of ```duration_ms``` has elapsed from the first time the block was made available for reuse. TokenRangeRetentionConfig only applies to input (prompt) tokens. The property ```decode_retention_policy``` specifies what priority to assign to blocks with generated (decoded) tokens and ```decode_duration_ms``` specifies how long this should remain in effect, after which priority will revert to the default. Default priority is 35. Any property that expects a duration can be set to None, which indicates retention policy never expires.
Blocks are assigned priority in line with the [retention policy](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.KvCacheRetentionConfig) of the request. Blocks with lower priority are evicted before blocks with higher priority. The retention policy is a list of [TokenRangeRetentionConfig](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.KvCacheRetentionConfig.TokenRangeRetentionConfig) objects, each specifying a priority for a given range of tokens, such as "assign priority X to tokens 10 through 61". You can also assign a duration in milliseconds for which that priority remains in effect; the priority reverts to the default of 35 once ```duration_ms``` has elapsed from the first time the block was made available for reuse. TokenRangeRetentionConfig applies only to input (prompt) tokens. The property ```decode_retention_priority``` specifies the priority assigned to blocks with generated (decoded) tokens, and ```decode_duration_ms``` specifies how long that priority remains in effect; after it expires, the priority reverts to the default. Any property that expects a duration can be set to None, which means that part of the retention policy never expires.

Not in use: ```transfer_mode``` is a debug option and should not be used.

See [this example](../examples/kvcacheretentionconfig.md) to learn how to change the block priorities of specific requests by altering their retention policy.
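
For orientation, here is a minimal, untested sketch of a retention policy that uses the pieces described above, including a limited duration. Passing a ```datetime.timedelta``` value for ```duration_ms``` is an assumption made for illustration; consult the API reference for the exact type expected.

```python
import datetime

from tensorrt_llm.llmapi import KvCacheRetentionConfig

# Keep the first 16 prompt tokens at priority 80 for roughly 30 seconds after their
# blocks first become available for reuse; afterwards they revert to the default (35).
# NOTE: the duration type (datetime.timedelta) is an assumption; see the API reference.
boost = KvCacheRetentionConfig.TokenRangeRetentionConfig(
    0, 16, 80, datetime.timedelta(seconds=30))

kv_cache_retention_config = KvCacheRetentionConfig(
    token_range_retention_configs=[boost],
    decode_retention_priority=10,  # store blocks of generated tokens at low priority
    decode_duration_ms=None)       # and keep that policy until the blocks are evicted
```

The resulting object is passed per request through ```llm.generate(..., kv_cache_retention_config=kv_cache_retention_config)```, as in the linked example.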

### Speculative Decoding

Reuse across requests is only supported for one model MTP, all other [speculative decoding](speculative-decoding.md) algorithms must disable block reuse.
Reuse across requests is supported by all speculative decoding models. Please see [speculative decoding](speculative-decoding.md) for more details.

## Limited Attention Window Size

@@ -40,7 +42,9 @@ TensorRT-LLM takes advantage of grouped query attention in order to save memory.

## Controlling KV Cache Behavior

Many of the features in the KV cache system are optional or have user defined properties that alter how they work. Users can control KV cache features through class [KVCacheConfig](llm-api/reference.html#tensorrt_llm.llmapi.KvCacheConfig). The remainder of this section describes how to change the most important behaviors of KV cache system.
Many of the features in the KV cache system are optional or have user-defined properties that alter how they work. Users can control KV cache features through the [KvCacheConfig](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.KvCacheConfig) class. The remainder of this section describes how to change the most important behaviors of the KV cache system.

See [this example](../examples/kvcacheconfig.md) to learn how to use ```KvCacheConfig``` to control KV cache behavior.

### Datatype

@@ -52,7 +56,7 @@ Property ```free_gpu_memory_fraction``` is a ratio > 0 and < 1 that specifies ho

### Enable/Disable Cross Request Reuse

Block reuse across requests is enabled by default, but can be disabled by setting ```enable_block_reuse``` to False. Note that block reuse requires the model engine to support paged context attention, which is enabled by default, but can be enabled or disabled explicitly when the model engine is built: ```trtllm-build --use_paged_context_fmha enable```.
Block reuse across requests is enabled by default, but can be disabled by setting ```enable_block_reuse``` to False.
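
A minimal sketch of the setting described above, reusing the TinyLlama model from the quick start examples:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig()
kv_cache_config.enable_block_reuse = False  # turn off cross-request block reuse

llm = LLM(model='TinyLlama/TinyLlama-1.1B-Chat-v1.0', kv_cache_config=kv_cache_config)
```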

### Enable Offloading to Host Memory

10 changes: 1 addition & 9 deletions docs/source/1.0/features/speculative-decoding.md
@@ -29,10 +29,6 @@ The two model implementation supports the following speculative decoding algorit

For all speculation algorithms, when speculation is enabled, a single sequence of draft tokens with length `max_draft_len` is created for every request. There is currently no way to dynamically disable speculation, thus speed ups are only observable at low batch sizes.

All two-model based speculation implementations have the following additional constraints:
* KV cache reuse must be disabled (this occurs implicitly).
* Overlap scheduling must be disabled.

### Draft/Target

Draft/target is the simplest form of speculative decoding. In this approach, an arbitrary draft model is used to produce draft tokens. It is important to make sure that the draft and target models were trained with the same tokenizer, else the acceptance rate is extremely low and performance is regressed.
@@ -99,8 +95,6 @@ MTP is currently only supported by Deepseek. MTP can be tuned with the followin
* `relaxed_topk`: The top K tokens are sampled from the target model's logits to create the initial candidate set for relaxed decoding.
* `relaxed_delta`: Used to further filter the top K candidate set for relaxed decoding. We remove tokens `t` for which `log(P(top 1 token)) - log(P(t)) > relaxed_delta`.

Unlike the other speculation algorithms, MTP supports the overlap scheduler and KV cache reuse.

```python
from tensorrt_llm.llmapi import MTPDecodingConfig

@@ -190,9 +184,7 @@ draft tokens to be attached to the `py_draft_tokens` field of request that specu

## Two Model Speculative Decoding Architecture

Note that there are currently a few limitations on the two model implementation:
* KV cache reuse must be disabled.
* Overlap scheduling must be disabled.
Two-model speculation implementations do not support the overlap scheduler; it is disabled automatically.

In this approach, there are two new steps to the `PyExecutor`'s `_executor_loop`.
* `_prepare_draft_requests`