Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -2027,7 +2027,7 @@ class GenericLlmRequest

// Scatter the input tokens to other beam
mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens);
mLastTokens = VecTokens(mSamplingConfig.beamWidth);
mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back());

// Init mUniqueTokens
VecUniqueTokens uniqueTokens{inputTokens.size()};
Expand Down
10 changes: 5 additions & 5 deletions docs/source/torch/features/feature_combination_matrix.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
| Disaggregated Serving | Yes | Yes | Yes | --- | | | | | | | | | | |
| Chunked Prefill | Yes | Yes | Yes | Untested | --- | | | | | | | | | |
| MTP | Yes | Yes | Yes | Yes | Untested | --- | | | | | | | | |
| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | |
| EAGLE-3(Two Model Engine) | No | Yes | Yes | No | Yes | No | No | --- | | | | | | |
| EAGLE-3(One Model Engine) | Yes | Yes | Yes | No | Yes | No | --- | | | | | | | |
| EAGLE-3(Two Model Engine) | No | Yes | Yes | No | Yes | No | No | --- | | | | | | |
| Torch Sampler | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | --- | | | | | |
| TLLM C++ Sampler | Yes | Yes | Yes | Yes | Yes | No | No | No | No | --- | | | | |
| KV Cache Reuse | Yes | Yes | Yes | Untested | Yes | Untested | Yes | No | Yes | Yes | --- | | | |
| Sliding Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | Yes | --- |
| Sliding Window Attention | Yes | Yes | Yes | Untested | No | Untested | Untested | Untested | Yes | Yes | WIP | --- | | |
| Logits Post Processor | No | Yes | Yes | No | Yes | No | No | No | Yes | Yes | Yes | Yes | --- | |
| Guided Decoding | Yes | Yes | Yes | No | Yes | No | No | Yes | Yes | Yes | Yes | Yes | Yes | --- |
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ einops
flashinfer-python==0.2.5
opencv-python-headless
xgrammar==0.1.21
llguidance==0.7.29
jsonschema
backoff
nvtx
matplotlib # FIXME: this is added to make nvtx happy
meson
ninja
etcd3
blake3
llguidance==0.7.29
soundfile
triton==3.3.1; platform_machine == "x86_64"
tiktoken
Expand Down
52 changes: 45 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/grammar_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,19 @@ class GrammarMatcher(ABC):
def accept_token(self, token_id: int) -> bool:
pass

@abstractmethod
def rollback(self, num_tokens: int) -> None:
pass

@abstractmethod
def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
pass

@abstractmethod
def is_terminated(self) -> bool:
pass


class GrammarMatcherFactory(ABC):

Expand All @@ -39,15 +47,23 @@ def __init__(self, matcher: xgrammar.GrammarMatcher):
def accept_token(self, token_id: int) -> bool:
return self._matcher.accept_token(token_id)

def rollback(self, num_tokens: int) -> None:
self._matcher.rollback(num_tokens)

def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
self._matcher.fill_next_token_bitmask(next_token_bitmask, index)

def is_terminated(self) -> bool:
return self._matcher.is_terminated()


class XGrammarMatcherFactory(GrammarMatcherFactory):

def __init__(self, guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int):
def __init__(self,
guided_decoding_config: GuidedDecodingConfig,
vocab_size_padded: int,
max_num_draft_tokens: int = 0):
super().__init__()
vocab_type = xgrammar.VocabType.RAW
add_prefix_space = False
Expand All @@ -72,6 +88,7 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
cache_enabled=True,
cache_limit_bytes=cache_limit_bytes,
)
self.max_num_draft_tokens = max_num_draft_tokens

def create(self,
guided_decoding_params: GuidedDecodingParams) -> XGrammarMatcher:
Expand Down Expand Up @@ -106,27 +123,48 @@ def create(self,
case _:
raise ValueError(f"Unsupported guide type: {guide_type}.")

matcher = xgrammar.GrammarMatcher(compiled_grammar)
matcher = xgrammar.GrammarMatcher(
compiled_grammar, max_rollback_tokens=self.max_num_draft_tokens)
return XGrammarMatcher(matcher)


class LLGuidanceMatcher(GrammarMatcher):

def __init__(self, matcher: llguidance.LLMatcher):
def __init__(self, matcher: llguidance.LLMatcher, eos_token: int):
super().__init__()
self._matcher = matcher
self._eos_token = eos_token
self._is_terminated = False

def accept_token(self, token_id: int) -> bool:
result = self._matcher.consume_token(token_id)
if self._matcher.is_stopped():
# Accept EOS token only if the matcher is stopped.
if token_id == self._eos_token:
self._is_terminated = True
return True
else:
return False

num_accepted = self._matcher.try_consume_tokens([token_id])
self._check_err()
return num_accepted > 0

def rollback(self, num_tokens: int) -> None:
if self._is_terminated:
self._is_terminated = False
num_tokens -= 1
self._matcher.rollback(num_tokens)
self._check_err()
return result

def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
index: int) -> None:
llguidance.torch.fill_next_token_bitmask(self._matcher,
next_token_bitmask, index)
self._check_err()

def is_terminated(self) -> bool:
return self._is_terminated

def _check_err(self) -> None:
if self._matcher.is_error():
raise ValueError(
Expand Down Expand Up @@ -181,4 +219,4 @@ def create(
if matcher.is_error():
raise ValueError(f"LLGuidance matcher error: {matcher.get_error()}")

return LLGuidanceMatcher(matcher)
return LLGuidanceMatcher(matcher, self._tokenizer.eos_token)
Loading