diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
index 35e64afe4c2..95428061279 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py
@@ -171,27 +171,47 @@ def top_k_top_p_sampling_batch(
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
 
         # compute cumulative probability distribution of each sample
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        probs_sorted = torch.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs_sorted, dim=-1)
 
         # get the location of top_p
-        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
+        # NB: Currently selecting the smallest index with cumulative_probs >= top_p.
         #     Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-        sorted_indices_to_remove[:, 0] = 0
-
-        # set the logits to -inf for token indices outside top_p
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
+        mask_to_remove = cumulative_probs >= top_p  # at least one 'True' per row
+        last_index_to_keep = torch.searchsorted(
+            mask_to_remove.to(torch.int8, non_blocking=True),
+            torch.ones((1,), dtype=torch.int8, device=mask_to_remove.device).expand(
+                (mask_to_remove.size(0), 1)
+            ),
+            right=False,
+            out_int32=True,
+        )
+        mask_to_remove.scatter_(
+            1,
+            last_index_to_keep,
+            torch.zeros((1,), dtype=torch.bool, device=mask_to_remove.device).expand_as(
+                last_index_to_keep
+            ),
         )
-        logits = logits.masked_fill(indices_to_remove, float("-inf"))
 
-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)
+        # mask not selected probs
+        probs_sorted.masked_fill_(mask_to_remove, 0.0)
+        probs = torch.empty_like(probs_sorted)
+        probs.scatter_(1, sorted_indices, probs_sorted)
+        probs /= cumulative_probs[  # renormalize probs
+            torch.arange(
+                cumulative_probs.size(0), dtype=torch.int32, device=cumulative_probs.device
+            ),  # needed for advanced indexing
+            last_index_to_keep.squeeze(-1),
+        ].unsqueeze(-1)
+        del logits  # do not use, inconsistent with probs
+    else:
+        # compute probability distribution
+        probs = torch.softmax(logits, dim=-1)
 
     # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
-    return next_tokens, softmax
+    next_tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
+    return next_tokens, probs
 
 
 def greedy_search_sampling_batch(
diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py
index a8475927f95..2c731e9110b 100644
--- a/tests/unittest/utils/util.py
+++ b/tests/unittest/utils/util.py
@@ -503,6 +503,11 @@ def device_sleep(
 def assert_no_cuda_sync(
     sync_timeout_s: float = 5,
 ) -> Generator[None, None, None]:
     """Check that the function does not stream synchronize."""
+    if int(os.environ.get("CUDA_LAUNCH_BLOCKING", 0)):
+        print("CUDA_LAUNCH_BLOCKING set, skipping 'assert_no_cuda_sync'")
+        yield None
+        return
+
     # NB: This implementation only assumes that the CUDA operations performed
     #     in the guarded scope use the currently selected CUDA stream. This
     #     should also cover custom Torch ops as well as non-Torch kernels.
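For context, a minimal standalone sketch of the new top-p path (the function name `top_p_renormalized_probs` and the example values are illustrative only, not part of the change): instead of masking logits to `-inf` and recomputing a softmax, the rewrite reuses the already-sorted softmax, locates the per-row cut-off with a batched `torch.searchsorted` over the monotone boolean mask, zeroes the tail, scatters back to vocabulary order, and renormalizes by the kept cumulative mass. As the diff's comment notes, this assumes every row's cumulative probability reaches `top_p` (true for `0 < top_p <= 1` up to rounding).

```python
import torch


def top_p_renormalized_probs(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Sketch of the top-p masking/renormalization scheme used in the diff.

    Assumes 0 < top_p <= 1 so that every row's cumulative probability
    reaches top_p (the "at least one 'True' per row" condition).
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    probs_sorted = torch.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(probs_sorted, dim=-1)

    # Each row of mask_to_remove is monotone (False..., True...), so a batched
    # searchsorted over its int8 view yields the first True per row, i.e. the
    # index of the last token to keep.
    mask_to_remove = cumulative_probs >= top_p
    last_index_to_keep = torch.searchsorted(
        mask_to_remove.to(torch.int8),
        torch.ones((mask_to_remove.size(0), 1), dtype=torch.int8, device=logits.device),
    )
    # The boundary token itself stays in the kept set.
    mask_to_remove.scatter_(1, last_index_to_keep, False)

    # Zero the tail, scatter back to vocabulary order, and renormalize by the
    # kept mass (the cumulative probability at the boundary index).
    probs_sorted = probs_sorted.masked_fill(mask_to_remove, 0.0)
    probs = torch.empty_like(probs_sorted).scatter_(1, sorted_indices, probs_sorted)
    kept_mass = cumulative_probs.gather(1, last_index_to_keep)
    return probs / kept_mass


# Example: with top_p=0.6 only the two most likely tokens survive.
logits = torch.log(torch.tensor([[0.5, 0.2, 0.2, 0.1]]))
print(top_p_renormalized_probs(logits, top_p=0.6))
# tensor([[0.7143, 0.2857, 0.0000, 0.0000]]) (up to rounding)
```

Compared to the old path, this keeps the whole selection on-device and avoids a second full-vocabulary softmax; the `assert_no_cuda_sync` change in the test utilities only adds an early exit when `CUDA_LAUNCH_BLOCKING` forces synchronous launches, which would otherwise make the no-sync check meaningless.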