64 changes: 41 additions & 23 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -10,7 +10,8 @@
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, SamplerType
from tensorrt_llm.llmapi.llm_args import (PeftCacheConfig, SamplerType,
SpeculativeConfig)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
@@ -670,66 +671,83 @@ def create_py_executor_instance(
peft_cache_config=peft_cache_config)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
*, max_seq_len: int, enable_mixed_sampler: bool):
max_num_sequences = executor_config.max_batch_size * mapping.pp_size
max_draft_len = (0 if executor_config.speculative_config is None else
executor_config.speculative_config.max_draft_len)
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
enable_mixed_sampler: bool, max_batch_size: int,
speculative_config: SpeculativeConfig,
max_beam_width: int):
max_num_sequences = max_batch_size * mapping.pp_size
max_draft_len = (0 if speculative_config is None else
speculative_config.max_draft_len)
return TorchSampler.Args(
max_seq_len=max_seq_len,
max_draft_len=max_draft_len,
max_num_sequences=max_num_sequences,
max_beam_width=executor_config.max_beam_width,
max_beam_width=max_beam_width,
enable_mixed_sampler=enable_mixed_sampler,
)


def instantiate_sampler(engine: PyTorchModelEngine,
executor_config: ExecutorConfig,
pytorch_backend_config: PyTorchConfig,
mapping: Mapping):
pytorch_backend_config: PyTorchConfig, mapping: Mapping,
max_batch_size: int, max_beam_width: int,
max_seq_len: int, mm_encoder_only: bool,
speculative_config: SpeculativeConfig,
decoding_config: trtllm.DecodingConfig,
kv_cache_config: trtllm.KvCacheConfig):
sampler_args = create_torch_sampler_args(
executor_config,
mapping,
max_seq_len=engine.max_seq_len,
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler)
decoding_mode = get_decoding_mode(executor_config)
enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
max_batch_size=max_batch_size,
speculative_config=speculative_config,
max_beam_width=max_beam_width)
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
max_beam_width=max_beam_width)
if mapping.cp_config.get('cp_type') == CpType.STAR:
assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
return TorchSampler(sampler_args)
if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
):
return get_spec_decoder(sampler_args, engine.spec_config)

if executor_config.mm_encoder_only:
if mm_encoder_only:
# NOTE: handle model outputs specially for mm encoder executor/engine
return EarlyStopWithMMResult()
if pytorch_backend_config.sampler_type == SamplerType.TRTLLMSampler or (
pytorch_backend_config.sampler_type == SamplerType.auto
and decoding_mode.isBeamSearch()):
logger.debug(f"DecodingMode: {decoding_mode.name}")
return TRTLLMSampler(executor_config, engine.model, engine.dtype,
mapping, decoding_mode,
pytorch_backend_config.disable_overlap_scheduler)
return TRTLLMSampler(engine.model,
engine.dtype,
mapping,
decoding_mode,
pytorch_backend_config.disable_overlap_scheduler,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
decoding_config=decoding_config,
kv_cache_config=kv_cache_config)
if not engine.model.model_config.is_generation:
# NOTE: choose sampler based on model type
return EarlyStopSampler()
return TorchSampler(sampler_args)


def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
def get_decoding_mode(
decoding_config: trtllm.DecodingConfig,
max_beam_width: int,
) -> DecodingMode:
'''This implementation is based off trtGptModelInflightBatching.cpp getDecodingMode().'''

if executor_config.decoding_config and executor_config.decoding_config.decoding_mode and not executor_config.decoding_config.decoding_mode.isAuto(
if decoding_config and decoding_config.decoding_mode and not decoding_config.decoding_mode.isAuto(
):
decoding_mode = executor_config.decoding_config.decoding_mode
elif executor_config.max_beam_width == 1:
decoding_mode = decoding_config.decoding_mode
elif max_beam_width == 1:
decoding_mode = DecodingMode.TopKTopP()
else:
decoding_mode = DecodingMode.BeamSearch()

# Override decoding mode when beam width is one
if executor_config.max_beam_width == 1 and decoding_mode.isBeamSearch():
if max_beam_width == 1 and decoding_mode.isBeamSearch():
logger.warning(
"Beam width is set to 1, but decoding mode is BeamSearch. Overwriting decoding mode to TopKTopP."
)
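A minimal sketch of how the refactored `get_decoding_mode()` resolves a mode from its explicit arguments (the function and imports come from `_util.py` above; the scenarios are illustrative, and the last assertion assumes the truncated tail of the function resets the mode to TopKTopP, as its warning message indicates):

```python
from tensorrt_llm.bindings.executor import DecodingConfig, DecodingMode

# No decoding config and a single beam: falls through to TopKTopP.
mode = get_decoding_mode(decoding_config=None, max_beam_width=1)
assert not mode.isBeamSearch()

# An explicit, non-Auto mode is honoured when the beam width allows it.
beam_cfg = DecodingConfig(DecodingMode.BeamSearch())
mode = get_decoding_mode(decoding_config=beam_cfg, max_beam_width=4)
assert mode.isBeamSearch()

# With max_beam_width == 1, a BeamSearch request is overridden back to
# TopKTopP (get_decoding_mode logs a warning in that case).
mode = get_decoding_mode(decoding_config=beam_cfg, max_beam_width=1)
assert not mode.isBeamSearch()
```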
28 changes: 26 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -209,6 +209,19 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
return mapping


def update_sampler_max_seq_len(max_seq_len, sampler):
    # TRTLLMSampler is originally constructed from executor_config-derived values, but
    # _create_kv_cache_manager (via build_managers) may later overwrite executor_config.max_seq_len.
    # Because TRTLLMSampler.sample_async still needs the updated limit and executor_config is
    # deprecated inside TRTLLMSampler, keep TRTLLMSampler.max_seq_len in sync with
    # executor_config.max_seq_len.
from .sampler import TRTLLMSampler

if isinstance(sampler, TRTLLMSampler):
assert hasattr(sampler, "max_seq_len")
sampler.max_seq_len = max_seq_len


def create_py_executor(
llm_args: TorchLlmArgs,
checkpoint_dir: str = None,
@@ -415,8 +428,17 @@ def drafting_loop_wrapper(model):
)

with mem_monitor.observe_creation_stage(_ExecutorCreationStage.SAMPLER):
sampler = instantiate_sampler(model_engine, executor_config,
pytorch_backend_config, mapping)
sampler = instantiate_sampler(
model_engine,
pytorch_backend_config,
mapping,
max_batch_size=executor_config.max_batch_size,
max_beam_width=executor_config.max_beam_width,
max_seq_len=executor_config.max_seq_len,
mm_encoder_only=executor_config.mm_encoder_only,
speculative_config=executor_config.speculative_config,
decoding_config=executor_config.decoding_config,
kv_cache_config=executor_config.kv_cache_config)
logger.info(f"Using Sampler: {type(sampler).__name__}")

if kv_connector_config is not None:
@@ -482,6 +504,7 @@ def drafting_loop_wrapper(model):
_ExecutorCreationStage.INIT_KV_CACHE
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
kv_cache_creator.build_managers(resources, estimating_kv_cache)
update_sampler_max_seq_len(executor_config.max_seq_len, sampler)

# Resource managers for speculative decoding
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
@@ -545,6 +568,7 @@ def drafting_loop_wrapper(model):
# the original value before creating the final KV cache.
executor_config.max_seq_len = max_seq_len
kv_cache_creator.build_managers(resources, False)
update_sampler_max_seq_len(executor_config.max_seq_len, sampler)

for eng in [model_engine, draft_model_engine]:
if eng is None:
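A condensed, hedged sketch of the ordering create_py_executor now relies on (every name is taken from the diff above; the surrounding objects are assumed to already exist): the sampler is built from executor_config values first, and because build_managers() may lower executor_config.max_seq_len, the new helper re-syncs the sampler right after each build_managers() call.

```python
sampler = instantiate_sampler(
    model_engine,
    pytorch_backend_config,
    mapping,
    max_batch_size=executor_config.max_batch_size,
    max_beam_width=executor_config.max_beam_width,
    max_seq_len=executor_config.max_seq_len,
    mm_encoder_only=executor_config.mm_encoder_only,
    speculative_config=executor_config.speculative_config,
    decoding_config=executor_config.decoding_config,
    kv_cache_config=executor_config.kv_cache_config)

# build_managers() may lower executor_config.max_seq_len (e.g. while estimating
# KV-cache capacity), so the sampler's own limit is re-synced after every call.
kv_cache_creator.build_managers(resources, estimating_kv_cache)
update_sampler_max_seq_len(executor_config.max_seq_len, sampler)
```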
36 changes: 21 additions & 15 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -11,7 +11,7 @@
from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
WorldConfig, make_sampling_config)
from tensorrt_llm.bindings.executor import (DecodingConfig, DecodingMode,
ExecutorConfig, FinishReason)
FinishReason, KvCacheConfig)
from tensorrt_llm.bindings.internal.algorithms import CreateNewDecoderRequests
from tensorrt_llm.bindings.internal.batch_manager import (
DecoderInputBuffers, add_new_tokens_to_requests, make_decoding_batch_input)
@@ -786,12 +786,16 @@ def is_generation_model(self) -> bool:

def __init__(
self,
executor_config: ExecutorConfig,
model,
model_dtype,
mapping: Mapping,
decoding_mode: DecodingMode,
disable_overlap_scheduler: bool,
max_seq_len: int,
max_batch_size: int,
max_beam_width: int,
decoding_config: Optional[DecodingConfig] = None,
kv_cache_config: Optional[KvCacheConfig] = None,
):

vocab_size = model.config.vocab_size
@@ -802,14 +806,15 @@ def __init__(self):
self.model_datatype = torch_dtype_to_binding(model_dtype)
self.logits_datatype = DataType.FLOAT
self.decoding_mode = decoding_mode
self.executor_config = executor_config
self.decoding_config = self.executor_config.decoding_config if self.executor_config.decoding_config else DecodingConfig(
self.decoding_config = decoding_config if decoding_config else DecodingConfig(
decoding_mode)
max_attn_window = self.executor_config.kv_cache_config.max_attention_window
max_attn_window = kv_cache_config.max_attention_window
self.max_seq_len = max_seq_len
self.max_attention_window = max(
max_attn_window
) if max_attn_window is not None else executor_config.max_seq_len
self.max_num_sequences = mapping.pp_size * self.executor_config.max_batch_size
max_attn_window) if max_attn_window is not None else max_seq_len
self.max_batch_size = max_batch_size
self.max_beam_width = max_beam_width
self.max_num_sequences = mapping.pp_size * max_batch_size
self.max_seq_idle_microseconds = 180 * 1000 * 1000
self.is_trt_overlap = not disable_overlap_scheduler
self.num_micro_batches = mapping.pp_size if mapping.pp_size > 1 else (
@@ -838,14 +843,14 @@ def _initialize_store(self):
"buffer_manager":
buffer_manager,
"decoder_input_buffers": [
DecoderInputBuffers(self.executor_config.max_batch_size,
DecoderInputBuffers(self.max_batch_size,
self.MAX_DECODING_TOKENS, buffer_manager)
for _ in range(self.num_micro_batches)
],
"sequence_lengths_host":
torch.empty((
self.max_num_sequences,
self.executor_config.max_beam_width,
self.max_beam_width,
),
dtype=torch.int),
"decoder_state":
@@ -855,10 +860,10 @@

self.store["decoder_state"].setup(
max_num_sequences=self.max_num_sequences,
max_beam_width=self.executor_config.max_beam_width,
max_beam_width=self.max_beam_width,
max_attention_window=self.max_attention_window,
sink_token_length=0,
max_sequence_length=self.executor_config.max_seq_len,
max_sequence_length=self.max_seq_len,
dtype=self.logits_datatype,
model_config=self.model_config,
world_config=self.world_config,
@@ -871,10 +876,11 @@ def _instantiate_algorithms(self):
self.algs.decoder.setup(
mode=self.decoding_mode,
max_num_sequences=self.max_num_sequences,
max_beam_width=self.executor_config.max_beam_width,
max_beam_width=self.max_beam_width,
dtype=self.logits_datatype,
model_config=self.model_config,
world_config=self.world_config)
world_config=self.world_config,
)
self.algs.create_new_decoder_requests = CreateNewDecoderRequests(
speculative_decoding_fast_logits=False,
is_leader_in_orch_mode=False,
@@ -890,7 +896,7 @@ def setup_sampler_step(self, requests):
requests.context_requests, self.logits_datatype,
self.store["decoder_input_buffers"][self.micro_batch_idx],
self.store["decoder_state"], self.store["cuda_stream"],
self.algs.decoder.decoder_stream, self.executor_config.max_seq_len,
self.algs.decoder.decoder_stream, self.max_seq_len,
self.beam_width(requests.context_requests))

local_batch_size = len(batch_slots)
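For reference, a hedged sketch of constructing the refactored TRTLLMSampler directly with explicit limits instead of an ExecutorConfig. `engine` and `mapping` are assumed to exist, the numeric values are placeholders, and a default-constructed KvCacheConfig() is used as a stand-in; none of this is taken from the PR itself.

```python
from tensorrt_llm.bindings.executor import DecodingMode, KvCacheConfig

sampler = TRTLLMSampler(
    engine.model,
    engine.dtype,
    mapping,
    DecodingMode.TopKTopP(),
    disable_overlap_scheduler=True,
    max_seq_len=4096,
    max_batch_size=8,
    max_beam_width=1,
    decoding_config=None,             # falls back to DecodingConfig(decoding_mode)
    kv_cache_config=KvCacheConfig())  # no max_attention_window set -> window = max_seq_len
```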