@tianleiwu tianleiwu commented Aug 1, 2025

Weight Shape Update

Make sure the shape reflects the actual memory layout. The weight is stored in column-major order.
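
For illustration, here is a minimal NumPy sketch of what "the shape reflects the actual memory layout" means for a column-major weight. The sizes (hidden, inter) are made up and not taken from the op spec.

```python
import numpy as np

# Hypothetical sizes, for illustration only.
hidden, inter = 4, 8
W = np.arange(hidden * inter, dtype=np.float32).reshape(hidden, inter)  # logical (hidden, inter)

# Column-major storage of W holds the same bytes as row-major storage of W.T,
# so the shape that reflects the actual memory layout is (inter, hidden).
stored = np.ascontiguousarray(W.T)                                       # shape (inter, hidden)
assert stored.tobytes() == np.asfortranarray(W).tobytes(order="A")

x = np.random.rand(2, hidden).astype(np.float32)
assert np.allclose(x @ stored.T, x @ W)  # a consumer transposes (or reads column-major) to recover W
```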

Add support for SwiGLU activation attributes

Add spec for the new activation type SwiGLU (Swish-Gated Linear Unit) by introducing a few new attributes. For reference, see the Triton kernel implementation.

New Attributes for SwiGLU

  • swiglu_fusion:

    • 0: Not fused — two separate GEMMs (FC1 and FC3).
    • 1: Fused GEMMs using interleaved format (g and l are interleaved per row).
    • 2: Fused GEMMs using non-interleaved (concatenated) format.
  • swiglu_limit: Clamp threshold applied to g (from above) and to l (symmetrically); see the definition below.

  • activation_alpha: Scalar multiplier applied to the clamped g before the sigmoid.

  • activation_beta: Added to the clamped l before the final output computation.


SwiGLU Activation Function

The SwiGLU function is defined as follows; a NumPy reference sketch appears after the definition:

```
g = xW + b
l = xV + c
G = min(g, limit)
L = max(min(l, limit), -limit)
swiglu = G * sigmoid(alpha * G) * (L + beta)
```
  • x: Input
  • W, V: Weight matrices
  • b, c: Bias vectors
  • alpha, beta, limit: Float constants
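
For reference, here is a minimal NumPy sketch of the activation exactly as defined above. The attribute values used in the example call are illustrative only, not operator defaults.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swiglu(g, l, alpha, beta, limit):
    # Clamp as defined above: g is clamped from above, l symmetrically.
    G = np.minimum(g, limit)
    L = np.clip(l, -limit, limit)
    return G * sigmoid(alpha * G) * (L + beta)

# g = x @ W + b and l = x @ V + c come from FC1 and FC3 (or a single fused FC1).
g = np.array([0.5, 9.0, -2.0], dtype=np.float32)
l = np.array([1.0, -9.0, 3.0], dtype=np.float32)
print(swiglu(g, l, alpha=1.0, beta=0.0, limit=7.0))  # illustrative attribute values
```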

Fusion Behavior

  • When swiglu_fusion = 0:

    • Two GEMMs are computed independently.
    • FC1 → computes g, FC3 → computes l.
  • When swiglu_fusion = 1:

    • g and l are computed in a single fused GEMM (FC1).
    • Output is interleaved per row as: gate, linear, gate, linear, ....
  • When swiglu_fusion = 2:

    • g and l are computed in a single GEMM (FC1).
    • Output is concatenated per row: [g | l] (both fused layouts are illustrated in the sketch below).
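
To make the two fused layouts concrete, here is a small sketch of how g and l would be recovered from the FC1 output. The helper name split_gate_linear is made up for illustration.

```python
import numpy as np

def split_gate_linear(fc1_out, swiglu_fusion):
    """Recover (g, l) from the FC1 output for the fused layouts described above."""
    if swiglu_fusion == 1:   # interleaved per row: gate, linear, gate, linear, ...
        return fc1_out[..., 0::2], fc1_out[..., 1::2]
    if swiglu_fusion == 2:   # concatenated per row: [g | l]
        half = fc1_out.shape[-1] // 2
        return fc1_out[..., :half], fc1_out[..., half:]
    raise ValueError("swiglu_fusion == 0 uses two separate GEMMs: FC1 -> g, FC3 -> l")

fused = np.arange(8, dtype=np.float32).reshape(1, 8)
g, l = split_gate_linear(fused, swiglu_fusion=1)  # g = [[0, 2, 4, 6]], l = [[1, 3, 5, 7]]
```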

Implement swiglu_limit for CUDA

Update the CUDA kernel to use the default swiglu limit.
Update test_moe_cuda.py so the reference implementation uses the same logic.

Remaining Work

The main purpose of this PR is to update the spec rather than implement it.
Note that the MoE/qMoE ops and tests still use hard-coded parameters; they will be changed later to read from these attributes.

Column-wise symmetric quantization is used for qMoE. We will add more quantization details soon, when we add support for block-wise quantization.
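
As a rough illustration of column-wise symmetric quantization (a generic sketch with one scale per column and a zero point fixed at 0; it does not show the exact qMoE weight packing or 4-bit storage):

```python
import numpy as np

def quantize_columnwise_symmetric(W, bits=8):
    """Generic column-wise symmetric quantization sketch: one scale per column, zero point = 0."""
    qmax = 2 ** (bits - 1) - 1                     # 127 for 8-bit, 7 for 4-bit
    scales = np.max(np.abs(W), axis=0) / qmax      # one scale per column
    scales = np.where(scales == 0.0, 1.0, scales)  # avoid division by zero for all-zero columns
    q = np.clip(np.round(W / scales), -qmax, qmax).astype(np.int8)
    return q, scales.astype(np.float32)

W = np.random.randn(16, 8).astype(np.float32)
q, scales = quantize_columnwise_symmetric(W)
W_hat = q.astype(np.float32) * scales              # dequantized approximation of W
```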

@tianleiwu tianleiwu marked this pull request as draft August 1, 2025 00:21

@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.


@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

@tianleiwu tianleiwu marked this pull request as ready for review August 2, 2025 07:14
@tianleiwu tianleiwu merged commit 562760a into main Aug 2, 2025
92 checks passed
@tianleiwu tianleiwu deleted the tlwu/moe_spec branch August 2, 2025 23:31
sophies927 pushed a commit that referenced this pull request Aug 3, 2025
apsonawane pushed a commit that referenced this pull request Aug 3, 2025
sanketkaleoss pushed a commit to sanketkaleoss/onnxruntime that referenced this pull request Aug 11, 2025
gedoensmax pushed a commit to gedoensmax/onnxruntime that referenced this pull request Sep 2, 2025
tianleiwu added a commit that referenced this pull request Sep 4, 2025
@tianleiwu tianleiwu added cherry-picked Cherry-picked for a cherrypicks branch and removed release:1.23.0 labels Sep 4, 2025
jywu-msft pushed a commit that referenced this pull request Sep 5, 2025
### Description
Cherry-pick the following PRs:
#25943
#25937 
#25917
#25909
#25898
#25897
#25888
#25881
#25830
#25619
#25575
#25572
#25558
#25530
#25474
#25455
#25110

Also two dependent PRs for qMoE CPU:
#25877
#25822

---------

Co-authored-by: xiaomsft <[email protected]>
Co-authored-by: Xiaoyan Hu <[email protected]>
Co-authored-by: Akshay Sonawane <[email protected]>
Co-authored-by: Kunal Vaishnavi <[email protected]>
Co-authored-by: Pradeep Sakhamoori <[email protected]>
Co-authored-by: mingyue <[email protected]>
Co-authored-by: Maximilian Müller <[email protected]>
Co-authored-by: Adrian Lizarraga <[email protected]>
Co-authored-by: Dmitri Smirnov <[email protected]>
Co-authored-by: Emmanuel <[email protected]>
Co-authored-by: Emmanuel Assumang <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: praneshgo <[email protected]>
Co-authored-by: Hariharan Seshadri <[email protected]>
Co-authored-by: Jing Fang <[email protected]>
Co-authored-by: Ishwar Raut <[email protected]>