Skip to content

Conversation

@meenchen
Copy link
Contributor

@meenchen meenchen commented Oct 9, 2025

What does this PR do?

Type of change: ?

Overview:

This PR and NVIDIA/TensorRT-LLM#8698 enable NVFP4 AWQ deployment for TRT-LLM. Specifically, this PR fuses pre_quant_scale in following two cases:

  • For MLP, pre_quant_scale of gate_proj layer is fused into up_proj's weight, so we don't need an extra handle in downstream fused moe kernels.
  • For attention, we will try to fuse the pre_quant_scale of o_proj to v_proj if their dimensions match, which means we will skip fusion for MQA/GQA models.

Usage

# Add a code snippet demonstrating how to use this

Testing

unit test, e2e test for Qwen3 dense and moe models.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

@copy-pr-bot
Copy link

copy-pr-bot bot commented Oct 9, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 9, 2025

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

✨ Finishing touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch weimingc/fuse_pqs

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov
Copy link

codecov bot commented Oct 9, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.43%. Comparing base (c02de17) to head (f55baad).
⚠️ Report is 2 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #421   +/-   ##
=======================================
  Coverage   74.43%   74.43%           
=======================================
  Files         182      182           
  Lines       18234    18234           
=======================================
  Hits        13572    13572           
  Misses       4662     4662           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@meenchen meenchen self-assigned this Oct 14, 2025
@meenchen meenchen requested a review from cjluo-nv October 27, 2025 19:29
@meenchen meenchen changed the title Pattern-based fusion for pre_quant_scale Fusing pre_quant_scale for NVFP4 AWQ Nov 3, 2025
@meenchen meenchen changed the title Fusing pre_quant_scale for NVFP4 AWQ [OMNIML-2932] Fusing pre_quant_scale for NVFP4 AWQ Nov 3, 2025
@meenchen meenchen marked this pull request as ready for review November 3, 2025 23:39
@meenchen meenchen requested a review from a team as a code owner November 3, 2025 23:39
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that test will still pass.

from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
fuse_prequant_to_linear,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can use_prequant_to_linear and fuse_prequant_layernorm be combined or they are mutual exclusive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are quite different. use_prequant_to_linear is rule-based fusion and doesn't need graph tracing.

layernorm_module.weight = torch.nn.Parameter(
layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
)
if hasattr(layernorm_module, "bias"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to handle bias now (not before) because of some new model support or it's nvfp4 awq related?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is just for future proof

mtq.NVFP4_AWQ_LITE_CFG,
],
)
def test_pattern_fuse_prequant_moe(quant_config):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also cover a test case for BMM style MoE like in llama4 or gpt-oss?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation does not work for BMM style Moe, but we can add the support later.

@meenchen meenchen requested a review from a team as a code owner November 8, 2025 00:23
.expand(num_kv_heads, n_rep, kv_head_dim)
.reshape(-1)
)
# Update o_proj's pre_quant_scale
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this update is regards to update o_proj's PQS so we can just take the first head and apply to v right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this updates the o_proj's PQS, so input channels of o_proj associated with the same query group (output channel) of v have the same prequant scale.

Copy link
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for implementing this.

Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
Signed-off-by: weimingc <[email protected]>
@meenchen meenchen enabled auto-merge (squash) November 19, 2025 19:18
@meenchen meenchen merged commit 1d0ee04 into main Nov 19, 2025
27 checks passed
@meenchen meenchen deleted the weimingc/fuse_pqs branch November 19, 2025 21:16
yeyu-nvidia pushed a commit that referenced this pull request Dec 8, 2025
## What does this PR do?

**Type of change:** ? <!-- Use one of the following: Bug fix, new
feature, new example, new tests, documentation. -->

**Overview:**

This PR and NVIDIA/TensorRT-LLM#8698 enable
NVFP4 AWQ deployment for TRT-LLM. Specifically, this PR fuses
pre_quant_scale in following two cases:
* For MLP, pre_quant_scale of gate_proj layer is fused into up_proj's
weight, so we don't need an extra handle in downstream fused moe
kernels.
* For attention, we will try to fuse the pre_quant_scale of o_proj to
v_proj if their dimensions match, which means we will skip fusion for
MQA/GQA models.

## Usage
<!-- You can potentially add a usage example below. -->

```python
# Add a code snippet demonstrating how to use this
```

## Testing
<!-- Mention how have you tested your change if applicable. -->
unit test, e2e test for Qwen3 dense and moe models.

## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain
why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No <!--- Only for new features, API changes, critical bug fixes or
bw breaking changes. -->

## Additional Information
<!-- E.g. related issue. -->

---------

Signed-off-by: weimingc <[email protected]>
soodoshll pushed a commit to soodoshll/TensorRT-Model-Optimizer that referenced this pull request Dec 8, 2025
## What does this PR do?

**Type of change:** ? <!-- Use one of the following: Bug fix, new
feature, new example, new tests, documentation. -->

**Overview:** 

This PR and NVIDIA/TensorRT-LLM#8698 enable
NVFP4 AWQ deployment for TRT-LLM. Specifically, this PR fuses
pre_quant_scale in following two cases:
* For MLP, pre_quant_scale of gate_proj layer is fused into up_proj's
weight, so we don't need an extra handle in downstream fused moe
kernels.
* For attention, we will try to fuse the pre_quant_scale of o_proj to
v_proj if their dimensions match, which means we will skip fusion for
MQA/GQA models.

## Usage
<!-- You can potentially add a usage example below. -->

```python
# Add a code snippet demonstrating how to use this
```

## Testing
<!-- Mention how have you tested your change if applicable. -->
unit test, e2e test for Qwen3 dense and moe models.

## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->

- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain
why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No <!--- Only for new features, API changes, critical bug fixes or
bw breaking changes. -->

## Additional Information
<!-- E.g. related issue. -->

---------

Signed-off-by: weimingc <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants