[feat] Implement pytorch sampler for MTP #5627
Conversation
```python
    lora_config: Optional[LoraConfig] = None,
    is_draft_model: bool = False,
):
    torch.manual_seed(0)
```
Is this intentional?
Without this line, TRTLLM would malfunction because each GPU would produce a different answer. The fixed seed was necessary to prevent divergent calculations across GPUs that would otherwise start from different seeds.
There's a larger problem: currently the user cannot pass in a seed for individual requests. Ideally we would not depend on the global torch seed, but this is difficult within PyTorch because its CUDA graph capture seems to make assumptions that break the ability to use a custom RNG.
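For illustration, a minimal sketch of the failure mode and the fix (hypothetical helper, not the PR's code): `torch.multinomial` only agrees across tensor-parallel ranks when every rank's RNG state matches, which the fixed global seed guarantees.

```python
import torch

def sample_on_rank(logits: torch.Tensor, seed: int = 0) -> torch.Tensor:
    # Seeding the global RNG identically on every rank keeps multinomial
    # sampling deterministic across GPUs; without it, each tensor-parallel
    # rank would draw a different token and the ranks would diverge.
    torch.manual_seed(seed)
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```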
```python
# Strict acceptance
else:
    if self.is_thop:
        if False:
```
What is happening here?
We did not implement the C++ MTP sampling kernel for thop, so this is a hack to force the PyTorch implementation of the sampler. I'll revert this line for now, but I'm not sure whether the code added by this PR will work properly without forcing this case.
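A hedged sketch of the control flow being described (names are assumptions, not the PR's actual code): the `if False:` pins execution to the PyTorch sampler, since the C++ thop branch has no kernel behind it.

```python
import torch

def mtp_sample(logits: torch.Tensor, is_thop: bool) -> torch.Tensor:
    # Sketch only: `is_thop` would normally select the C++ (thop) kernel,
    # but no C++ MTP sampling kernel exists yet, so `if False:` keeps both
    # branches on the pure-PyTorch path.
    if is_thop:
        if False:  # the hack: never enter the unimplemented C++ path
            raise NotImplementedError("C++ MTP sampling kernel for thop")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```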
Signed-off-by: Patrick Reiter Horn <[email protected]>
Force-pushed from bee8f0b to d4a928c.
```python
import psutil
import safetensors
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
```
This line seems to duplicate line 24.
@pathorn Thanks for your contribution. Could you please run the code formatting tool? https://github.com/NVIDIA/TensorRT-LLM/blob/main/CONTRIBUTING.md#coding-style
Hi, thanks @pathorn for the contribution, and @QiJune and @netanel-haber for reviewing. @nvxuanyuc and I will help get this merged in the coming weeks.
```python
    next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1)
    return next_tokens, softmax


def flashinfer_sample(
```
flashinfer_sample is not invoked anywhere; it can be deleted.
```python
# generator = torch.Generator(device="cuda")
# generator.manual_seed(0)
# next_tokens = flashinfer_sample(adjusted_logits, top_k, top_p, generator)
# logits = apply_top_k_top_p(logits, top_k, top_p)
```
Clean up the commented-out code blocks (here and in other files).
Description
Previously, speculative decoding was always greedy and did not support sampling parameters.
This PR implements the temperature, top-p, top-k, and min-p sampling parameters in Python when using MTP speculative decoding (for DeepSeek).
It also adds a log-probs return value from the PyTorch sampler (intended to be used together with #5620).
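For reference, here is a minimal sketch of how these four parameters can compose in pure PyTorch (illustrative only; the helper name and exact filtering order are assumptions, not this PR's code):

```python
import torch

def sample_with_params(
    logits: torch.Tensor,    # [batch, vocab]
    temperature: float = 1.0,
    top_k: int = 0,          # 0 disables top-k
    top_p: float = 1.0,      # 1.0 disables top-p
    min_p: float = 0.0,      # 0.0 disables min-p
):
    # Temperature scaling.
    if temperature != 1.0:
        logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)

    # Top-k: zero out everything below the k-th largest probability.
    if top_k > 0:
        kth = torch.topk(probs, top_k, dim=-1).values[..., -1, None]
        probs = torch.where(probs < kth, torch.zeros_like(probs), probs)

    # Top-p (nucleus): keep the smallest prefix whose cumulative mass
    # reaches top_p; the top token is always kept.
    if top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        drop = cumulative - sorted_probs > top_p
        sorted_probs = sorted_probs.masked_fill(drop, 0.0)
        probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)

    # Min-p: drop tokens below min_p times the max probability.
    if min_p > 0.0:
        threshold = min_p * probs.max(dim=-1, keepdim=True).values
        probs = torch.where(probs < threshold, torch.zeros_like(probs), probs)

    # Renormalize the surviving mass and sample.
    probs = probs / probs.sum(dim=-1, keepdim=True)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # Log-prob of the sampled token under the filtered distribution.
    log_probs = torch.log(probs.gather(-1, next_tokens.unsqueeze(-1))).squeeze(-1)
    return next_tokens, log_probs
```

Each filter zeroes out disallowed tokens and the remaining mass is renormalized before sampling; note that the returned log-prob is taken under the filtered distribution, not the raw softmax.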
Future work: support this PyTorch sampler for non-MTP models, and figure out why it is faster than the C++ sampler.
Test Coverage
None