3 changes: 0 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -296,9 +296,6 @@ def __init__(self,
self.event_loop = self._executor_loop_pp
else:
self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
if not disable_overlap_scheduler and model_engine.max_beam_width > 1:
raise NotImplementedError(
"Overlap scheduler is not supported for beam search.")
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
self.event_loop = trace_func(self.event_loop)

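Removing this guard is possible because of how the overlap loop schedules work: the host-side update for iteration N is deferred until iteration N+1 has already been launched on the device, so anything the update needs must be captured when sampling is issued. Below is a rough, self-contained sketch of that deferred-update pattern, using illustrative stand-ins rather than the executor's real functions.

def sample_async_stub(step: int) -> dict:
    """Stand-in for the sampler: launch device work for this step and capture
    whatever the later host-side update will need (here, just the step index)."""
    return {"step": step}

def update_requests_stub(state: dict) -> None:
    """Stand-in for the host-side update that consumes the captured state."""
    print(f"updating requests for step {state['step']}")

previous_state = None
for step in range(3):
    current_state = sample_async_stub(step)   # device work for step N is launched here
    if previous_state is not None:
        update_requests_stub(previous_state)  # update for step N-1 overlaps with step N
    previous_state = current_state
if previous_state is not None:
    update_requests_stub(previous_state)      # drain the final pending state
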
59 changes: 35 additions & 24 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -473,10 +473,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
finish_reasons: torch.Tensor
sequence_lengths: torch.Tensor
cum_log_probs: torch.Tensor | None = None
gathered_ids: torch.Tensor | None = None


@dataclass(kw_only=True)
class SampleStateTRTLLM(SampleState):
finalize_events: dict[str, CudaEvent]
host: SampleStateTensorsHostTRTLLM


@@ -672,6 +674,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
self.store["decoder_state"],
self.store["decoding_input"][self.micro_batch_idx])

finalize_events = {}
gathered_ids = None
if beam_width > 1:
finished_sum_device = self.store["decoder_state"].finished_sum

for request in scheduled_requests.all_requests():
if request.is_context_init_state:
continue
if finished_sum_device[request.seq_slot] == beam_width:
finalize_events[
request.request_id] = self._finalize_request(
request, False)
elif request.streaming:
finalize_events[
request.request_id] = self._finalize_request(
request, True)
gathered_ids = self.store["decoder_state"].gathered_ids.to(
'cpu', non_blocking=True)
new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
'cpu', non_blocking=True)
finished_sum = self.store["decoder_state"].finished_sum.to(
@@ -698,7 +718,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
finish_reasons=finish_reasons,
sequence_lengths=sequence_lengths,
log_probs=log_probs,
cum_log_probs=cum_log_probs)
cum_log_probs=cum_log_probs,
gathered_ids=gathered_ids)

sampler_event = torch.cuda.Event()
sampler_event.record()
@@ -709,7 +730,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
device=device,
host=host,
sampler_event=sampler_event)
sampler_event=sampler_event,
finalize_events=finalize_events)

@torch.inference_mode()
def update_requests(self, state: SampleStateTRTLLM):
@@ -797,7 +819,7 @@ def update_requests_multiple_beams_or_drafting(self,
) if state.host.cum_log_probs is not None else None
log_probs_host = state.host.log_probs.tolist(
) if state.host.log_probs is not None else None
finalize_events = {}
finalize_events = state.finalize_events

reqs = [
r for r in state.scheduled_requests.context_requests
@@ -865,19 +887,9 @@ def update_requests_multiple_beams_or_drafting(self,

if finished_sum_host[seq_slot] == beam_width:
request.state = LlmRequestState.GENERATION_COMPLETE
if beam_width > 1:
finalize_events[
request.request_id] = self._finalize_request(
request, False)
elif request.streaming and beam_width > 1:
finalize_events[request.request_id] = self._finalize_request(
request, True)
# post process all requests if necessary
if beam_width > 1:
for request in reqs:
if request.request_id in finalize_events:
self._post_process_request(
request, finalize_events[request.request_id])
for request in reqs:
if request.request_id in finalize_events:
self._post_process_request(request, state)

def _finalize_request(self, request: LlmRequest, streaming: bool):
""" Finalizes the request. This is necessary for beam search. """
@@ -888,25 +900,24 @@ def _finalize_request(self, request: LlmRequest, streaming: bool):
return event

def _post_process_request(self, request: LlmRequest,
finalize_event: CudaEvent):
state: SampleStateTRTLLM):
""" Post Process the request. Updates the sequence according to the beam search results.
request: LlmRequest which shall be post processed
state: SampleStateTRTLLM holding the finalize event to wait for before reading the gathered results
"""
seq_slot = request.py_seq_slot
beam_width = request.sampling_config.beam_width
# synchronize on the finalize event before continuing the post processing.
finalize_event.synchronize()
# Should be unnecessary, as update_requests already waits for the sampler event
state.finalize_events[request.request_id].synchronize()

# Get these values again, as they might have changed during the finalize step
output_ids_host = self.store["decoder_state"].gathered_ids.to('cpu')
sequence_lengths_host = self.store["decoder_state"].sequence_lengths.to(
'cpu')
output_ids_host = state.host.gathered_ids
sequence_lengths_host = state.host.sequence_lengths

if request.py_return_log_probs:
log_probs_host = self.store["decoder_state"].log_probs.to('cpu')
cum_log_probs_host = self.store["decoder_state"].cum_log_probs.to(
'cpu')
log_probs_host = state.host.log_probs
cum_log_probs_host = state.host.cum_log_probs

generated_tokens = [[0]] * beam_width
log_probs = [[] for _ in range(beam_width)]
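The sampler changes above hinge on one pattern: device tensors are copied to the host with non_blocking=True inside sample_async, a CUDA event is recorded, and the consumer (update_requests / _post_process_request) only synchronizes on that event when it actually reads the host data. Below is a minimal, standalone sketch of that pattern in plain PyTorch, not TensorRT-LLM code; the explicit pinned buffer is an assumption made here so the copy can truly overlap.

import torch

if torch.cuda.is_available():
    device_tokens = torch.randint(0, 32000, (4, 8), device="cuda")

    # Pinned host buffer so the device-to-host copy can overlap with other work.
    host_tokens = torch.empty(device_tokens.shape, dtype=device_tokens.dtype,
                              device="cpu", pin_memory=True)
    host_tokens.copy_(device_tokens, non_blocking=True)

    copy_done = torch.cuda.Event()
    copy_done.record()        # recorded on the current stream, after the copy

    # ... more device work can be enqueued here without waiting ...

    copy_done.synchronize()   # the consumer blocks only when it needs the host data
    assert torch.equal(host_tokens, device_tokens.cpu())
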
72 changes: 72 additions & 0 deletions tests/unittest/_torch/test_beam_search.py
@@ -51,6 +51,24 @@ def llm(fixed_params, input_prompts):
)


@pytest.fixture(scope="module")
def llm_overlap(fixed_params, input_prompts):
return LLM(
model=os.path.join(llm_models_root(), "llama-models-v2",
"TinyLlama-1.1B-Chat-v1.0"),
kv_cache_config=KvCacheConfig(max_tokens=10000),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
), # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
max_seq_len=32,
enable_trtllm_sampler=True,
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=False,
#TODO: remove this once we have a proper fix for CUDA graph in beam search
cuda_graph_config=None,
)


@force_ampere # Save H100 resource
@pytest.mark.parametrize("return_log_probs", [True, False])
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@@ -105,3 +123,57 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
assert similar(
beam.text,
expected_outputs[input_prompts[output_idx]][beam_idx])


@force_ampere # Save H100 resource
@pytest.mark.parametrize("return_log_probs", [True, False])
@pytest.mark.parametrize("gather_generation_logits", [True, False])
@pytest.mark.parametrize("gather_context_logits", [True, False])
@pytest.mark.parametrize("num_output_beams", [1, 2])
@pytest.mark.parametrize("num_prompts", [1, 2])
@pytest.mark.threadleak(enabled=False)
def test_beam_search_output_shapes_overlap(
gather_context_logits: bool, gather_generation_logits: bool,
return_log_probs: bool, num_output_beams: int, num_prompts: int,
llm_overlap, fixed_params, input_prompts, expected_outputs):
if return_log_probs and num_prompts > 1:
pytest.skip(
"Beam search currently does not support return_log_probs with multiple prompts"
)
sampling_params = SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=num_output_beams,
best_of=fixed_params["max_beam_width"],
use_beam_search=True,
return_context_logits=gather_context_logits,
return_generation_logits=gather_generation_logits,
logprobs=return_log_probs,
)
outputs = llm_overlap.generate(input_prompts[:num_prompts],
sampling_params=sampling_params)
assert len(outputs) == num_prompts
for output_idx, output in enumerate(outputs):
if gather_context_logits:
assert output.context_logits is not None
assert len(
output.prompt_token_ids) == output.context_logits.shape[0]
else:
assert output.context_logits is None
assert len(output.outputs) == num_output_beams
for beam_idx, beam in enumerate(output.outputs):
if gather_generation_logits:
gen_logits = beam.generation_logits
assert gen_logits is not None
assert gen_logits.ndim == 2
assert gen_logits.shape[0] == sampling_params.max_tokens
else:
assert beam.generation_logits is None

if return_log_probs:
assert len(beam.logprobs) == sampling_params.max_tokens
else:
assert len(beam.logprobs) == 0
# Check output similarity
assert similar(
beam.text,
expected_outputs[input_prompts[output_idx]][beam_idx])
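
For reference, a hedged end-to-end usage sketch mirroring the llm_overlap fixture and SamplingParams above: beam search with the overlap scheduler left enabled. The import paths, model name, and prompt are assumptions for illustration; the numeric values are taken from the test.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",   # placeholder checkpoint path
    kv_cache_config=KvCacheConfig(max_tokens=10000),
    max_seq_len=32,
    enable_trtllm_sampler=True,
    max_beam_width=2,
    disable_overlap_scheduler=False,   # beam search no longer requires disabling overlap
    cuda_graph_config=None,            # mirrors the fixture's CUDA graph workaround
)

sampling_params = SamplingParams(
    max_tokens=8,
    n=2,                  # beams returned
    best_of=2,            # beam width used during the search
    use_beam_search=True,
)

for output in llm.generate(["Hello, my name is"], sampling_params=sampling_params):
    for beam in output.outputs:
        print(beam.text)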