diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e5b302310fc..857e6d06ecf 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -296,9 +296,6 @@ def __init__(self,
             self.event_loop = self._executor_loop_pp
         else:
             self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
-        if not disable_overlap_scheduler and model_engine.max_beam_width > 1:
-            raise NotImplementedError(
-                "Overlap scheduler is not supported for beam search.")

         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)
diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index cd2c1ded390..f6f4a7420dd 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -473,10 +473,12 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
     finish_reasons: torch.Tensor
     sequence_lengths: torch.Tensor
     cum_log_probs: torch.Tensor | None = None
+    gathered_ids: torch.Tensor | None = None


 @dataclass(kw_only=True)
 class SampleStateTRTLLM(SampleState):
+    finalize_events: dict[str, CudaEvent]
     host: SampleStateTensorsHostTRTLLM


@@ -672,6 +674,24 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
             self.store["decoder_state"],
             self.store["decoding_input"][self.micro_batch_idx])
+        finalize_events = {}
+        gathered_ids = None
+        if beam_width > 1:
+            finished_sum_device = self.store["decoder_state"].finished_sum
+
+            for request in scheduled_requests.all_requests():
+                if request.is_context_init_state:
+                    continue
+                if finished_sum_device[request.seq_slot] == beam_width:
+                    finalize_events[
+                        request.request_id] = self._finalize_request(
+                            request, False)
+                elif request.streaming:
+                    finalize_events[
+                        request.request_id] = self._finalize_request(
+                            request, True)
+            gathered_ids = self.store["decoder_state"].gathered_ids.to(
+                'cpu', non_blocking=True)
         new_output_tokens = self.store["decoder_state"].all_new_tokens.to(
             'cpu', non_blocking=True)
         finished_sum = self.store["decoder_state"].finished_sum.to(
             'cpu', non_blocking=True)
@@ -698,7 +718,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
             finish_reasons=finish_reasons,
             sequence_lengths=sequence_lengths,
             log_probs=log_probs,
-            cum_log_probs=cum_log_probs)
+            cum_log_probs=cum_log_probs,
+            gathered_ids=gathered_ids)

         sampler_event = torch.cuda.Event()
         sampler_event.record()
@@ -709,7 +730,8 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         return SampleStateTRTLLM(scheduled_requests=scheduled_requests,
                                  device=device,
                                  host=host,
-                                 sampler_event=sampler_event)
+                                 sampler_event=sampler_event,
+                                 finalize_events=finalize_events)

     @torch.inference_mode()
     def update_requests(self, state: SampleStateTRTLLM):
@@ -797,7 +819,7 @@ def update_requests_multiple_beams_or_drafting(self,
         ) if state.host.cum_log_probs is not None else None
         log_probs_host = state.host.log_probs.tolist(
         ) if state.host.log_probs is not None else None
-        finalize_events = {}
+        finalize_events = state.finalize_events

         reqs = [
             r for r in state.scheduled_requests.context_requests
@@ -865,19 +887,9 @@ def update_requests_multiple_beams_or_drafting(self,

             if finished_sum_host[seq_slot] == beam_width:
                 request.state = LlmRequestState.GENERATION_COMPLETE
-                if beam_width > 1:
-                    finalize_events[
-                        request.request_id] = self._finalize_request(
-                            request, False)
-            elif request.streaming and beam_width > 1:
-                finalize_events[request.request_id] = self._finalize_request(
-                    request, True)
-        # post process all requests if necessary
-        if beam_width > 1:
-            for request in reqs:
-                if request.request_id in finalize_events:
-                    self._post_process_request(
-                        request, finalize_events[request.request_id])
+        for request in reqs:
+            if request.request_id in finalize_events:
+                self._post_process_request(request, state)

     def _finalize_request(self, request: LlmRequest, streaming: bool):
         """ Finalizes the request. This is necessary for beam search. """
@@ -888,7 +900,7 @@ def _finalize_request(self, request: LlmRequest, streaming: bool):
         return event

     def _post_process_request(self, request: LlmRequest,
-                              finalize_event: CudaEvent):
+                              state: SampleStateTRTLLM):
         """ Post Process the request. Updates the sequence according to the beam search results.
             request: LlmRequest which shall be post processed
             finalize_event: CudaEvent to wait for the finalize step to finish
@@ -896,17 +908,16 @@ def _post_process_request(self, request: LlmRequest,
         seq_slot = request.py_seq_slot
         beam_width = request.sampling_config.beam_width
         # synchronize on the finalize event before continuing the post processing.
-        finalize_event.synchronize()
+        # This should be unnecessary, as update_requests already waits for the sampler event.
+        state.finalize_events[request.request_id].synchronize()

         # Get these values again, as they might have changed during the finalize step
-        output_ids_host = self.store["decoder_state"].gathered_ids.to('cpu')
-        sequence_lengths_host = self.store["decoder_state"].sequence_lengths.to(
-            'cpu')
+        output_ids_host = state.host.gathered_ids
+        sequence_lengths_host = state.host.sequence_lengths

         if request.py_return_log_probs:
-            log_probs_host = self.store["decoder_state"].log_probs.to('cpu')
-            cum_log_probs_host = self.store["decoder_state"].cum_log_probs.to(
-                'cpu')
+            log_probs_host = state.host.log_probs
+            cum_log_probs_host = state.host.cum_log_probs

         generated_tokens = [[0]] * beam_width
         log_probs = [[] for _ in range(beam_width)]
diff --git a/tests/unittest/_torch/test_beam_search.py b/tests/unittest/_torch/test_beam_search.py
index b5562ee9c22..25107924c2e 100644
--- a/tests/unittest/_torch/test_beam_search.py
+++ b/tests/unittest/_torch/test_beam_search.py
@@ -51,6 +51,24 @@ def llm(fixed_params, input_prompts):
     )


+@pytest.fixture(scope="module")
+def llm_overlap(fixed_params, input_prompts):
+    return LLM(
+        model=os.path.join(llm_models_root(), "llama-models-v2",
+                           "TinyLlama-1.1B-Chat-v1.0"),
+        kv_cache_config=KvCacheConfig(max_tokens=10000),
+        max_batch_size=fixed_params["max_beam_width"] * len(
+            input_prompts
+        ),  # use small batch size to prevent large buffers from possibly hiding wrong data accesses.
+        max_seq_len=32,
+        enable_trtllm_sampler=True,
+        max_beam_width=fixed_params["max_beam_width"],
+        disable_overlap_scheduler=False,
+        #TODO: remove this once we have a proper fix for CUDA graph in beam search
+        cuda_graph_config=None,
+    )
+
+
 @force_ampere  # Save H100 resource
 @pytest.mark.parametrize("return_log_probs", [True, False])
 @pytest.mark.parametrize("gather_generation_logits", [True, False])
@@ -105,3 +123,57 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
         assert similar(
             beam.text,
             expected_outputs[input_prompts[output_idx]][beam_idx])
+
+
+@force_ampere  # Save H100 resource
+@pytest.mark.parametrize("return_log_probs", [True, False])
+@pytest.mark.parametrize("gather_generation_logits", [True, False])
+@pytest.mark.parametrize("gather_context_logits", [True, False])
+@pytest.mark.parametrize("num_output_beams", [1, 2])
+@pytest.mark.parametrize("num_prompts", [1, 2])
+@pytest.mark.threadleak(enabled=False)
+def test_beam_search_output_shapes_overlap(
+        gather_context_logits: bool, gather_generation_logits: bool,
+        return_log_probs: bool, num_output_beams: int, num_prompts: int,
+        llm_overlap, fixed_params, input_prompts, expected_outputs):
+    if return_log_probs and num_prompts > 1:
+        pytest.skip(
+            "Beam search currently does not support return_log_probs with multiple prompts"
+        )
+    sampling_params = SamplingParams(
+        max_tokens=fixed_params["max_tokens"],
+        n=num_output_beams,
+        best_of=fixed_params["max_beam_width"],
+        use_beam_search=True,
+        return_context_logits=gather_context_logits,
+        return_generation_logits=gather_generation_logits,
+        logprobs=return_log_probs,
+    )
+    outputs = llm_overlap.generate(input_prompts[:num_prompts],
+                                   sampling_params=sampling_params)
+    assert len(outputs) == num_prompts
+    for output_idx, output in enumerate(outputs):
+        if gather_context_logits:
+            assert output.context_logits is not None
+            assert len(
+                output.prompt_token_ids) == output.context_logits.shape[0]
+        else:
+            assert output.context_logits is None
+        assert len(output.outputs) == num_output_beams
+        for beam_idx, beam in enumerate(output.outputs):
+            if gather_generation_logits:
+                gen_logits = beam.generation_logits
+                assert gen_logits is not None
+                assert gen_logits.ndim == 2
+                assert gen_logits.shape[0] == sampling_params.max_tokens
+            else:
+                assert beam.generation_logits is None
+
+            if return_log_probs:
+                assert len(beam.logprobs) == sampling_params.max_tokens
+            else:
+                assert len(beam.logprobs) == 0
+            # Check output similarity
+            assert similar(
+                beam.text,
+                expected_outputs[input_prompts[output_idx]][beam_idx])
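Usage sketch (not part of the patch): the snippet below mirrors the llm_overlap fixture and the new test, showing beam search via the LLM API with the overlap scheduler left enabled, which this change now allows. The model identifier and the concrete sizes are placeholders; only arguments that appear in the diff above are used.

# Minimal sketch, assuming the tensorrt_llm LLM API as used in the test above.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model id/path
    kv_cache_config=KvCacheConfig(max_tokens=10000),
    max_batch_size=4,  # small batch, as in the fixture, to surface bad data accesses
    max_seq_len=32,
    enable_trtllm_sampler=True,  # the TRTLLM sampler implements beam search
    max_beam_width=2,
    disable_overlap_scheduler=False,  # overlap scheduler no longer raises for beam search
    cuda_graph_config=None,  # CUDA graphs still disabled for beam search (see TODO in the fixture)
)

sampling_params = SamplingParams(
    max_tokens=8,
    n=2,  # beams to return
    best_of=2,  # beams to search
    use_beam_search=True,
)

for output in llm.generate(["Hello, my name is"], sampling_params=sampling_params):
    for beam in output.outputs:
        print(beam.text)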