Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed issues per the review
Signed-off-by: Oleg Goncharov <[email protected]>
  • Loading branch information
Oleg-Goncharov committed Oct 31, 2025
commit 1df9fed76ecb8c02be8b9744e4f8c681d569af78
18 changes: 10 additions & 8 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -699,14 +699,16 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu

// Optimized BWD/FWD SwiGLU MXFP8 Rowwise kernels for BF16/FP16 inputs
if constexpr (!std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
if constexpr ((!IS_BWD && (ActOP == &silu<fp32, fp32>)) ||
(IS_BWD && (ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>))) {
if (((gated_input.dtype() == DType::kFloat16) || (gated_input.dtype() == DType::kBFloat16)) &&
(scaling_type == ScalingType::ROWWISE)) {
quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
stream);
return;
}
const bool is_fwd_swiglu = !IS_BWD && (ActOP == &silu<fp32, fp32>);
const bool is_bwd_swiglu = IS_BWD && (ActOP == &silu<fp32, fp32>) &&
(DActOP == &dsilu<fp32, fp32>);
const bool is_supported_data_type = (gated_input.dtype() == DType::kFloat16) ||
(gated_input.dtype() == DType::kBFloat16);
const bool is_supported_scaling_type = scaling_type == ScalingType::ROWWISE;
if (is_supported_data_type && is_supported_scaling_type && (is_fwd_swiglu || is_bwd_swiglu)) {
quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
stream);
return;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,11 @@ __device__ __forceinline__ void compute_bwd_gated_activation(
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
const float x = act_elt;
float x = act_elt;
float act_x;
float dact_x;
if constexpr (IS_CLAMPED_SWIGLU) {
const float x = min(act_elt, p.limit);
x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
Expand Down
Loading