Conversation

pathorn (Contributor) commented Jun 30, 2025

[feat] Implement pytorch sampler for MTP

Description

Previously, speculative decoding was always greedy and did not support sampling parameters.

This PR implements the temperature, top-p, top-k, and min-p sampling parameters in Python when using MTP speculative decoding (for DeepSeek).

It also adds a log-probs return value from the PyTorch sampler (intended to be used together with #5620).

Future work: support this PyTorch sampler for non-MTP models, and figure out why it is faster than the C++ sampler.
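
For context, here is a minimal sketch of what these sampling parameters do when implemented directly in PyTorch. This is illustrative only, not the code added by this PR; the `sample` function name and its structure are assumptions.

```python
import torch

def sample(logits: torch.Tensor, temperature: float = 1.0,
           top_k: int = 0, top_p: float = 1.0, min_p: float = 0.0) -> torch.Tensor:
    # Temperature scaling (temperature -> 0 approaches greedy decoding).
    logits = logits / max(temperature, 1e-5)
    # Top-k: keep only the k largest logits.
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    # Min-p: drop tokens whose probability is below min_p * (max probability).
    if min_p > 0.0:
        cutoff = min_p * probs.max(dim=-1, keepdim=True).values
        probs = probs.masked_fill(probs < cutoff, 0.0)
    # Top-p (nucleus): keep the smallest set of tokens whose mass reaches top_p.
    if top_p < 1.0:
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        sorted_probs[cum - sorted_probs > top_p] = 0.0
        probs = torch.zeros_like(probs).scatter_(-1, idx, sorted_probs)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```

Setting top_k = 1 (or a very small temperature) reduces this to the greedy behavior that speculative decoding was limited to before this change.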

Test Coverage

None

pathorn requested review from a team as code owners June 30, 2025 22:30
pathorn requested review from HuiGao-NV and yuxianq June 30, 2025 22:30
juney-nvidia requested review from dcampora, lfr-0531, mikeiovine and netanel-haber and removed the request for HuiGao-NV and yuxianq June 30, 2025 22:36
juney-nvidia added the "Community want to contribute" (PRs initiated from Community) and "Community Engagement" (help/insights needed from community) labels Jun 30, 2025
    lora_config: Optional[LoraConfig] = None,
    is_draft_model: bool = False,
):
    torch.manual_seed(0)
Collaborator

Is this intentional?

Contributor Author

Without this line, TRTLLM would malfunction because each GPU would produce a different answer.

This was necessary to prevent divergent calculations on different GPUs with different seeds.

There's a larger problem: currently the user cannot pass in a seed for individual requests. Ideally we would not depend on the global torch seed, but that is difficult within PyTorch because CUDA graph capture seems to make assumptions that break the ability to use a custom RNG.
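
As a minimal illustration of the divergence concern (not code from this PR): with tensor parallelism, every rank samples the next token independently, so their RNG states must match or the ranks stop agreeing on the generated sequence.

```python
import torch

def sample_next_token(probs: torch.Tensor, seed: int) -> torch.Tensor:
    torch.manual_seed(seed)  # same seed on every rank keeps multinomial in lockstep
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
rank0 = sample_next_token(probs, seed=0)
rank1 = sample_next_token(probs, seed=0)   # pretend this runs on another GPU
assert torch.equal(rank0, rank1)           # identical seeds -> identical tokens
# With different seeds the two "ranks" could pick different tokens, and the
# tensor-parallel shards would no longer agree on the sequence.
```

A per-request torch.Generator would be the cleaner fix alluded to above, but as noted it currently conflicts with CUDA graph capture.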

# Strict acceptance
else:
    if self.is_thop:
        if False:
Collaborator

What is happening here?

Contributor Author

We did not implement the C++ MTP sampling kernel for thop, so this is a hack to force the PyTorch implementation of the sampler. I'll revert this line for now, but I'm not sure whether the code added by this PR will work properly without forcing this case.
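
To make the intended control flow concrete, here is a hypothetical sketch of the dispatch being described; `needs_sampling`, `sample_and_accept_thop`, and `sample_and_accept_python` are illustrative names, not real TensorRT-LLM APIs or the PR's code.

```python
# Hypothetical sketch of the dispatch the comment describes.
if self.is_thop and not needs_sampling:
    # Greedy path: the existing C++ (thop) kernel handles acceptance.
    accepted = self.sample_and_accept_thop(logits)
else:
    # No C++ MTP sampling kernel exists, so sampling requests must take the
    # Python sampler; the `if False:` above forces this branch unconditionally.
    accepted = self.sample_and_accept_python(logits)
```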

Signed-off-by: Patrick Reiter Horn <[email protected]>
pathorn force-pushed the python-mtp-sampler branch from bee8f0b to d4a928c, July 1, 2025 23:51

import psutil
import safetensors
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
Collaborator

This line seems to duplicate line 24.

QiJune (Collaborator) commented Jul 2, 2025

@pathorn Thanks for your contribution. Could you please run the code formatting tool? https://github.com/NVIDIA/TensorRT-LLM/blob/main/CONTRIBUTING.md#coding-style

jhaotingc (Collaborator) commented:

Hi, thanks @pathorn for the contribution, and @QiJune and @netanel-haber for reviewing. @nvxuanyuc and I will help get this merged in the coming weeks.

    next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
    return next_tokens, softmax

def flashinfer_sample(
Collaborator

flashinfer_sample is not invoked anywhere; it can be deleted.

# generator = torch.Generator(device="cuda")
# generator.manual_seed(0)
# next_tokens = flashinfer_sample(adjusted_logits, top_k, top_p, generator)
# logits = apply_top_k_top_p(logits, top_k, top_p)
Collaborator

Clean up the commented-out code blocks (here and in other files).
