2 changes: 1 addition & 1 deletion docs/source/torch/features/feature_combination_matrix.md
@@ -15,4 +15,4 @@
| KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
| Slide Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
| Guided Decoding | Yes | Yes | Yes | Yes | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
34 changes: 26 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -39,7 +39,7 @@ def __init__(self,
guided_decoding_config, vocab_size_padded)
else:
raise ValueError(
f"invalid guided decoding backend: {self.guided_decoding_backend}"
f"Invalid guided decoding backend: {self.guided_decoding_backend}"
)
logger.info(
f"Guided decoder initialized with backend: {self.guided_decoding_backend}"
@@ -71,15 +71,15 @@ def __init__(self,
def bitmask_size(self) -> int:
return math.ceil(self.vocab_size_padded / 32)

def _is_matcher_init(self, llm_req: LlmRequest) -> bool:
def _require_matcher_init(self, llm_req: LlmRequest) -> bool:
if llm_req.guided_decoding_params is None:
return False
if llm_req.py_is_draft:
return False
# The request is in the last chunk of a context forward step.
return llm_req.is_context_init_state and llm_req.is_last_context_chunk

def _is_matcher_in_progress(self, llm_req: LlmRequest) -> bool:
def _require_matcher_advance(self, llm_req: LlmRequest) -> bool:
if llm_req.guided_decoding_params is None:
return False
if llm_req.py_is_draft:
@@ -102,12 +102,17 @@ def build(self, scheduled_requests: ScheduledRequests) -> None:
self.num_advanced_tokens[slot] = 0
self.num_guided_tokens[slot] = 0

if self._is_matcher_init(llm_req):
matcher_init: bool = self._require_matcher_init(llm_req)
matcher_advance: bool = self._require_matcher_advance(llm_req)
if not (matcher_init or matcher_advance):
continue

if matcher_init:
matcher = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)
self.grammar_matchers[slot] = matcher

elif self._is_matcher_in_progress(llm_req):
if matcher_advance:
matcher = self.grammar_matchers[slot]
# The last new token must be acceptable unless the matcher is terminated in a drafting loop.
if llm_req.py_is_draft and (matcher.is_terminated()
@@ -127,9 +132,6 @@ def build(self, scheduled_requests: ScheduledRequests) -> None:
f"Request {llm_req.py_request_id} failed to accept last new token: {last_new_token}."
)

else:
continue

self.num_advanced_tokens[slot] += 1
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
@@ -244,3 +246,19 @@ def rollback_draft_tokens(self,
# Reset the drafting states.
self.num_advanced_draft_tokens[slot] = 0
self.is_draft_terminated[slot] = False

@nvtx_range("GuidedDecoder.init_disagg_gen_requests")
def init_disagg_gen_requests(self,
scheduled_requests: ScheduledRequests) -> None:
"""Initialize the grammar matchers for disagg gen requests.
"""
for llm_req in scheduled_requests.generation_requests:
if llm_req.guided_decoding_params is None:
continue
assert not llm_req.py_is_draft
slot: int = llm_req.py_seq_slot
if llm_req.context_phase_params is not None and llm_req.py_decoding_iter == 1:
# The request is in the first generation forward step at the disagg gen instance.
self.grammar_matchers[
slot] = self.grammar_matcher_factory.create(
llm_req.guided_decoding_params)
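
Taken together, the `guided_decoder.py` changes split per-request handling into explicit cases: create a matcher on the last context chunk (`_require_matcher_init`), advance an existing matcher on a newly sampled token (`_require_matcher_advance`), and, via the new `init_disagg_gen_requests`, create a matcher on the first generation step of a disaggregated generation instance, whose context phase ran on another server. The following is a minimal, hedged sketch of that decision only; `Req` and `classify` are illustrative names, the draft-request path is elided, and matcher creation/advance details are simplified relative to the real `GuidedDecoder`.

```python
# Hedged sketch of the matcher lifecycle implied by the diff above. Only the
# field names that appear in the diff (guided_decoding_params, py_is_draft,
# context_phase_params, py_decoding_iter, py_seq_slot) are taken from the PR;
# Req and classify are placeholders, not part of the TensorRT-LLM API.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Req:
    guided_decoding_params: Optional[Any]
    py_is_draft: bool
    is_last_context_chunk: bool          # final chunk of the context phase
    is_generation: bool                  # request is in the generation phase
    context_phase_params: Optional[Any]  # set when context ran on another server
    py_decoding_iter: int
    py_seq_slot: int


def classify(req: Req) -> str:
    """Which action the guided decoder would take for this request (sketch only)."""
    if req.guided_decoding_params is None:
        return "skip: no guided decoding requested"
    if req.py_is_draft:
        return "draft-specific handling (elided in this sketch)"
    # Disaggregated generation: the context phase ran remotely, so the matcher
    # is created on the first local generation step (init_disagg_gen_requests).
    if req.context_phase_params is not None and req.py_decoding_iter == 1:
        return "init matcher (disagg gen, first generation step)"
    if req.is_last_context_chunk:
        return "init matcher (last context chunk)"        # _require_matcher_init
    if req.is_generation:
        return "advance matcher (accept last token, fill bitmask)"  # _require_matcher_advance
    return "skip"


# Example: a request arriving at a disaggregated generation instance.
print(classify(Req(guided_decoding_params={"json": "{}"}, py_is_draft=False,
                   is_last_context_chunk=False, is_generation=True,
                   context_phase_params="<kv-transfer metadata>",
                   py_decoding_iter=1, py_seq_slot=0)))
```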
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -329,7 +329,9 @@ def __init__(
self.py_return_generation_logits = return_generation_logits
self.py_return_logits_device_memory = return_logits_device_memory
self.py_is_draft = is_draft
# The request's sequence slot ID, an index between 0 (inclusive) and max_batch_size (exclusive).
self.py_seq_slot = seq_slot
# If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
self.py_target_seq_slot = target_seq_slot

# TODO: remove this when use DynamicDecodeOp in pytorch flow.
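
The two new comments in `llm_request.py` pin down the slot bookkeeping that per-request guided-decoding state relies on: `py_seq_slot` indexes per-slot buffers such as the token bitmask, and a draft request carries `py_target_seq_slot` so that its target request's slot can be located. Below is a hedged illustration of how a slot-indexed bitmask is sized and looked up; the tensor shapes and the helper function are assumptions for the example, not part of the PR (only the ceil-divide-by-32 packing mirrors `GuidedDecoder.bitmask_size`).

```python
# Hedged illustration only: shapes and slot_for_bitmask are made up for this
# example; the /32 packing matches GuidedDecoder.bitmask_size in this PR.
import math

import torch

max_batch_size = 8
vocab_size_padded = 128_256
bitmask_words = math.ceil(vocab_size_padded / 32)  # 32 vocab entries per int32 word

# One bitmask row per sequence slot; py_seq_slot lies in [0, max_batch_size).
bitmask_host = torch.zeros(max_batch_size, bitmask_words, dtype=torch.int32)


def slot_for_bitmask(py_seq_slot: int, py_is_draft: bool,
                     py_target_seq_slot: int | None) -> int:
    # Illustrative: a draft request can reach its target's per-slot state
    # through py_target_seq_slot (whether the guided decoder does exactly this
    # is not shown in the diff).
    if py_is_draft and py_target_seq_slot is not None:
        return py_target_seq_slot
    return py_seq_slot


row = bitmask_host[slot_for_bitmask(py_seq_slot=3, py_is_draft=False,
                                    py_target_seq_slot=None)]
```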
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -745,6 +745,9 @@ def _executor_loop_pp(self):
if self._need_return_logits(scheduled_batch):
logits_host = batch_outputs["logits"].to(
"cpu", non_blocking=True)
if self.kv_cache_transceiver and self.guided_decoder:
self.guided_decoder.init_disagg_gen_requests(
scheduled_batch)
self._execute_guided_decoder(
scheduled_batch, batch_outputs['logits'])

@@ -931,6 +934,10 @@ def _executor_loop(self):
self._handle_first_token_response(scheduled_batch)

self.resource_manager.prepare_resources(scheduled_batch)

if self.kv_cache_transceiver and self.guided_decoder:
self.guided_decoder.init_disagg_gen_requests(
scheduled_batch)
if self.drafter is not None and self.use_spec_decode:
if self.guided_decoder is not None:
self.guided_decoder.rollback_rejected_tokens(
@@ -1055,6 +1062,9 @@ def _executor_loop_overlap(self):
if self.previous_batch is not None:
self._update_requests(self.previous_batch.sample_state)

if self.kv_cache_transceiver and self.guided_decoder:
self.guided_decoder.init_disagg_gen_requests(
scheduled_batch)
self._execute_guided_decoder(scheduled_batch,
batch_outputs['logits'])

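
All three executor loops (`_executor_loop_pp`, `_executor_loop`, `_executor_loop_overlap`) gain the same guarded call: `init_disagg_gen_requests` runs only when a KV-cache transceiver is present, i.e. the instance participates in disaggregated serving, and a guided decoder is configured; in the loops shown it runs before the guided decoder is executed for the batch. A condensed, hedged sketch of that ordering follows; the surrounding loop body is heavily simplified and only the attribute and method names taken from the diff are real.

```python
# Hedged sketch of the ordering added to the executor loops; everything except
# kv_cache_transceiver, guided_decoder, init_disagg_gen_requests and
# _execute_guided_decoder (all visible in the diff) is simplified away.
def _guided_decoding_step(self, scheduled_batch, batch_outputs):
    # Disaggregated generation only: matchers for requests whose context phase
    # ran on another server must exist before bitmasks are built.
    if self.kv_cache_transceiver and self.guided_decoder:
        self.guided_decoder.init_disagg_gen_requests(scheduled_batch)

    # Builds and applies the token bitmasks to this step's logits.
    self._execute_guided_decoder(scheduled_batch, batch_outputs["logits"])
```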
141 changes: 126 additions & 15 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -4,6 +4,7 @@
# Please take a look at the existing test_llm_api_pytorch.py file for reference.
import concurrent
import contextlib
import json
import os
import tempfile
import time
@@ -19,12 +20,13 @@
from tensorrt_llm.executor.result import GenerationResultBase
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
from tensorrt_llm.llmapi.llm_args import LlmArgs
from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer

from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
skip_pre_hopper)
from ..trt_test_alternative import popen
from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
get_accuracy_task)
from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
LlmapiAccuracyTestHarness, get_accuracy_task)


class Result(GenerationResultBase):
@@ -43,7 +45,7 @@ def result(self):
return self


DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async'])
DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])


class MyThreadPoolExecutor(ThreadPoolExecutor):
@@ -162,17 +164,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],

def send_request(prompt: str, sampling_params: SamplingParams,
streaming: bool):
response = client.completions.create(
model=model_name,
prompt=prompt,
stream=streaming,
**({
"max_tokens": sampling_params.max_tokens,
"temperature": sampling_params.temperature,
"top_p": sampling_params.top_p,
"stop": sampling_params.stop,
"seed": sampling_params.seed
} if sampling_params else {}))
kwargs = {}
if sampling_params is not None:
kwargs.update(max_tokens=sampling_params.max_tokens,
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
stop=sampling_params.stop,
seed=sampling_params.seed)
if (guided_decoding_params :=
sampling_params.guided_decoding) is not None:
extra_body = {}
if (schema := guided_decoding_params.json) is not None:
extra_body.update(response_format={
"type": "json",
"schema": json.loads(schema)
})
elif guided_decoding_params.json_object:
extra_body.update(
response_format={"type": "json_object"})
else:
# TODO: Support other guided decoding types
raise ValueError(
f"Unsupported guided decoding params: {guided_decoding_params}."
)
kwargs.update(extra_body=extra_body)

response = client.completions.create(model=model_name,
prompt=prompt,
stream=streaming,
**kwargs)
result = Result(id=0,
sampling_params=sampling_params,
outputs=[
@@ -192,8 +212,10 @@ def generate_async(prompt: str,
thread_pool.futures.append(future)
return future

tokenizer = load_hf_tokenizer(model_name)

try:
yield DuckLLM(args, generate_async)
yield DuckLLM(args, tokenizer, generate_async)
finally:
ctx_server.terminate()
gen_server.terminate()
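
The updated `send_request` helper translates `SamplingParams.guided_decoding` into the OpenAI-compatible `extra_body` payload. For reference, a schema-constrained request against the disaggregated server would look roughly like the sketch below; the endpoint, model name, and schema are illustrative placeholders, and only the `response_format` shape comes from the diff.

```python
# Hedged example of the request shape produced by send_request(); URL, model
# name and schema are placeholders, only the extra_body layout mirrors the diff.
import json

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

response = client.completions.create(
    model="llama-3.1-8b-instruct",            # placeholder model name
    prompt="Reply in JSON: what is 2 + 2?",
    max_tokens=64,
    temperature=0.0,
    extra_body={
        "response_format": {
            "type": "json",                   # schema-constrained decoding
            "schema": schema,
        }
    },
)
print(response.choices[0].text)
```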
@@ -394,6 +416,95 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
ctx_server_config = {
"disable_overlap_scheduler": True,
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding_with_eagle3(self, backend: str, mocker):
mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
speculative_decoding_config = {
"decoding_type": "Eagle",
"max_draft_len": 3,
"speculative_model_dir":
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
"eagle3_one_model": False
}

ctx_server_config = {
"disable_overlap_scheduler": True,
"speculative_config": speculative_decoding_config,
"kv_cache_config": {
"free_gpu_memory_fraction": 0.8,
},
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
gen_server_config = {
"disable_overlap_scheduler": True,
"speculative_config": speculative_decoding_config,
"kv_cache_config": {
"free_gpu_memory_fraction": 0.8,
},
"guided_decoding_backend": backend,
"cache_transceiver_config": {
"backend": "default"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
ids=["tp1pp2", "tp2pp1", "tp2pp2"])
8 changes: 8 additions & 0 deletions tests/integration/test_lists/qa/llm_function_full.txt
@@ -448,6 +448,10 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[xgrammar]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_4gpus[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_eagle3_tp8[eagle3_one_model=True]
@@ -520,6 +524,10 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -42,6 +42,8 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[xgrammar]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1]