tensorrt_llm/_torch/pyexecutor/sampling_utils.py (34 additions, 14 deletions)

@@ -171,27 +171,47 @@ def top_k_top_p_sampling_batch(
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)

         # compute cumulative probability distribution of each sample
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        probs_sorted = torch.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(probs_sorted, dim=-1)

         # get the location of top_p
-        # NB: Currently selecting the smallest index with cumulative_probs > top_p.
+        # NB: Currently selecting the smallest index with cumulative_probs >= top_p.
         # Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True).
-        sorted_indices_to_remove = cumulative_probs > top_p
-        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-        sorted_indices_to_remove[:, 0] = 0
-
-        # set the logits to -inf for token indices outside top_p
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
+        mask_to_remove = cumulative_probs >= top_p  # at least one 'True' per row
+        last_index_to_keep = torch.searchsorted(
+            mask_to_remove.to(torch.int8, non_blocking=True),
+            torch.ones((1,), dtype=torch.int8, device=mask_to_remove.device).expand(
+                (mask_to_remove.size(0), 1)
+            ),
+            right=False,
+            out_int32=True,
         )
-        logits = logits.masked_fill(indices_to_remove, float("-inf"))
+        mask_to_remove.scatter_(
+            1,
+            last_index_to_keep,
+            torch.zeros((1,), dtype=torch.bool, device=mask_to_remove.device).expand_as(
+                last_index_to_keep
+            ),
+        )

-    # compute probability distribution
-    softmax = torch.softmax(logits, dim=-1)
+        # mask not selected probs
+        probs_sorted.masked_fill_(mask_to_remove, 0.0)
+        probs = torch.empty_like(probs_sorted)
+        probs.scatter_(1, sorted_indices, probs_sorted)
+        probs /= cumulative_probs[  # renormalize probs
+            torch.arange(
+                cumulative_probs.size(0), dtype=torch.int32, device=cumulative_probs.device
+            ),  # needed for advanced indexing
+            last_index_to_keep.squeeze(-1),
+        ].unsqueeze(-1)
+        del logits  # do not use, inconsistent with probs
+    else:
+        # compute probability distribution
+        probs = torch.softmax(logits, dim=-1)

     # sample from the distribution and generate result of [batch_size, 1]
-    next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
-    return next_tokens, softmax
+    next_tokens = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(-1)
+    return next_tokens, probs


 def greedy_search_sampling_batch(
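A non-obvious step in the added (+) lines above is calling torch.searchsorted on a boolean mask. Because cumulative_probs is non-decreasing along the vocabulary axis, each row of the mask looks like [False, ..., False, True, ..., True], so viewing it as int8 and searching for the value 1 returns the first index at which the row flips to True. A small standalone illustration (toy tensors, not code from this PR):

import torch

# Each row of a "cumulative_probs >= top_p" mask is monotone: False then True.
mask = torch.tensor([[False, False, True, True],
                     [False, True, True, True]])
# Searching for 1 in the int8 view gives the first True per row.
ones = torch.ones((mask.size(0), 1), dtype=torch.int8)
first_true = torch.searchsorted(mask.to(torch.int8), ones, right=False, out_int32=True)
print(first_true)  # tensor([[2], [1]], dtype=torch.int32)

That index is the single position flipped back to False via scatter_, so the token that crosses the top_p threshold is always kept.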
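For anyone who wants to try the overall approach outside the executor, here is a minimal standalone sketch of the mask-and-renormalize idea, assuming a 2D logits tensor and a scalar top_p. It counts the below-threshold entries instead of using torch.searchsorted, and the function name and structure are illustrative rather than the module's API:

import torch

def top_p_probs_sketch(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort token probabilities in descending order and accumulate them.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    probs_sorted = torch.softmax(sorted_logits, dim=-1)
    cumulative_probs = torch.cumsum(probs_sorted, dim=-1)

    # Remove tokens once the cumulative mass reaches top_p, but keep the first
    # token that crosses the threshold so at least one candidate survives.
    # (The clamp guards the degenerate case where rounding keeps every
    # cumulative value below top_p, e.g. top_p == 1.0.)
    mask_to_remove = cumulative_probs >= top_p
    last_index_to_keep = (~mask_to_remove).sum(dim=-1, keepdim=True)
    last_index_to_keep = last_index_to_keep.clamp(max=logits.size(-1) - 1)
    rows = torch.arange(logits.size(0), device=logits.device)
    mask_to_remove[rows, last_index_to_keep.squeeze(-1)] = False

    # Zero the removed tokens, scatter back to vocabulary order, and renormalize
    # by the retained cumulative mass instead of re-running softmax.
    probs_sorted = probs_sorted.masked_fill(mask_to_remove, 0.0)
    probs = torch.zeros_like(probs_sorted).scatter_(-1, sorted_indices, probs_sorted)
    kept_mass = cumulative_probs.gather(-1, last_index_to_keep)
    return probs / kept_mass

# Example: with top_p = 0.9 the least likely token is zeroed and the remaining
# probabilities are rescaled to sum to 1 again.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_p_probs_sketch(logits, top_p=0.9))

Renormalizing by the kept cumulative mass is what lets the new code skip the second softmax over -inf-masked logits that the old code performed.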
tests/unittest/utils/util.py (5 additions, 0 deletions)

@@ -503,6 +503,11 @@ def device_sleep(
 def assert_no_cuda_sync(
         sync_timeout_s: float = 5, ) -> Generator[None, None, None]:
     """Check that the function does not stream synchronize."""
+    if int(os.environ.get("CUDA_LAUNCH_BLOCKING", 0)):
+        print("CUDA_LAUNCH_BLOCKING set, skipping 'assert_no_cuda_sync'")
+        yield None
+        return
+
     # NB: This implementation only assumes that the CUDA operations performed
     # in the guarded scope use the currently selected CUDA stream. This
     # should also cover custom Torch ops as well as non-Torch kernels.
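For context on why the early exit is useful: CUDA_LAUNCH_BLOCKING=1 forces every kernel launch to synchronize, so the no-sync assertion would otherwise fail on essentially any guarded code during such debugging runs. A rough usage sketch of the context manager follows; the import path and test body are assumptions for illustration, not part of this PR:

import torch

from utils.util import assert_no_cuda_sync  # assumed import path within tests/unittest

def test_top_p_sampling_does_not_sync():
    logits = torch.randn(8, 128, device="cuda")
    with assert_no_cuda_sync(sync_timeout_s=5):
        # Everything inside the block must stay asynchronous on the current stream.
        probs = torch.softmax(logits, dim=-1)
        tokens = torch.argmax(probs, dim=-1)
    # Synchronizing after the block is fine; only the guarded region is checked.
    torch.cuda.synchronize()
    assert tokens.shape == (8,)

With the new guard, the same block still runs under CUDA_LAUNCH_BLOCKING; it simply is not validated, which keeps launch-blocking debug sessions from producing false failures.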