diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index a3a0e1b15c4..72cddacd58d 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -373,6 +373,7 @@ def _group_requests_by_sampling_strategy(
         requests: Iterable[LlmRequest],
         *,
         pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
+    # NB: Client code relies on request indices in returned torch.Tensor being sorted.
     strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
     for req_index, req in enumerate(requests):
         strategy_dict[_request_strategy(req)].append(req_index)
@@ -1176,12 +1177,20 @@ def _sample_batched_by_strategy(
                 len(speculation_group_indices), dtype=torch.int32)
             group_logits_cuda_indices = logits_cuda_indexer[group_req_indices]
-            if group_logits_cuda_indices.numel() != logits_cuda.size(0):
+            # NB: Assuming that group_req_indices are sorted
+            group_req_1st_index, group_req_last_index = group_req_indices[
+                0], group_req_indices[-1]
+            if group_req_last_index - group_req_1st_index + 1 == len(
+                    group_req_indices):
+                # Avoid data movement if indices are contiguous
+                group_logits_cuda = logits_cuda[
+                    req_offsets[group_req_1st_index]:(
+                        req_offsets[group_req_last_index] +
+                        req_num_steps[group_req_last_index])]
+            else:
                 group_logits_cuda_indices_cuda = group_logits_cuda_indices.to(
                     device=logits_cuda.device, non_blocking=True)
                 group_logits_cuda = logits_cuda[group_logits_cuda_indices_cuda]
-            else:
-                group_logits_cuda = logits_cuda
 
             # Indexer for accessing tokens in 'group_logits_cuda' (and 'group_next_tokens_cuda')
             # corresponding to the requests in 'group_req_indices'.
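
For context, a minimal standalone sketch (not part of the patch) of the pattern the second hunk introduces: when the sorted request indices of a sampling group form a contiguous run, the group's logits rows can be taken with a plain slice (a view, no copy) instead of an index-gather that materializes a new tensor. The names `req_offsets` and `req_num_steps` mirror the diff; the helper `gather_group_logits` and the toy sizes are hypothetical.

```python
# Sketch only: slice when the (sorted) request indices are contiguous, otherwise
# fall back to an index-gather. Assumes request i owns rows
# [req_offsets[i], req_offsets[i] + req_num_steps[i]) of the logits tensor.
import torch


def gather_group_logits(logits: torch.Tensor, group_req_indices: list[int],
                        req_offsets: torch.Tensor,
                        req_num_steps: torch.Tensor) -> torch.Tensor:
    """Return the logits rows owned by the requests in group_req_indices."""
    first, last = group_req_indices[0], group_req_indices[-1]
    if last - first + 1 == len(group_req_indices):
        # Contiguous requests own a contiguous block of rows: slicing returns a
        # view and avoids any data movement.
        return logits[int(req_offsets[first]):
                      int(req_offsets[last] + req_num_steps[last])]
    # Non-contiguous: build row indices per request and gather (copies rows).
    row_indices = torch.cat([
        torch.arange(int(req_offsets[i]),
                     int(req_offsets[i] + req_num_steps[i]))
        for i in group_req_indices
    ]).to(logits.device)
    return logits[row_indices]


if __name__ == "__main__":
    # Toy batch: 4 requests producing 1, 2, 1 and 3 logits rows respectively.
    req_num_steps = torch.tensor([1, 2, 1, 3])
    req_offsets = torch.cumsum(req_num_steps, dim=0) - req_num_steps
    logits = torch.randn(int(req_num_steps.sum()), 32)

    contiguous = gather_group_logits(logits, [1, 2], req_offsets, req_num_steps)
    scattered = gather_group_logits(logits, [0, 2], req_offsets, req_num_steps)
    print(contiguous.shape, scattered.shape)  # torch.Size([3, 32]) torch.Size([2, 32])
```

In the patched code the contiguous path additionally skips the host-to-device transfer of the index tensor (the `non_blocking=True` copy in the fallback branch), not just the gather itself.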