9 changes: 7 additions & 2 deletions transformer_engine/common/cast/core/common.cuh
@@ -30,11 +30,16 @@ inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_bl
 
 inline bool dimensions_supported_by_TMA(const Tensor *const t) {
   const size_t cols = t->flat_last_dim();
-  constexpr size_t TMA_bytes = 16;
-  const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
+  const size_t alignment_requirement = (TMA_GMEM_ALIGNMENT * 8) / typeToNumBits(t->dtype());
   return cols % alignment_requirement == 0;
 }
 
+__device__ __forceinline__ unsigned char *align_smem_ptr_per_TMA_requirements(unsigned char *p) {
+  size_t addr = reinterpret_cast<size_t>(p);
+  addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
+  return reinterpret_cast<unsigned char *>(addr);
+}
+
 namespace kernel {
 
 constexpr size_t THREADS_PER_BLOCK = 256;
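A note on the first hunk: the local magic constant TMA_bytes is replaced by the shared TMA_GMEM_ALIGNMENT constant, and the divisibility requirement on the last dimension follows from the element width. Below is a minimal standalone sketch of that arithmetic, assuming TMA_GMEM_ALIGNMENT is 16 bytes (the value of the removed TMA_bytes); it is illustrative, not TE code.

    // Sketch of the alignment arithmetic in dimensions_supported_by_TMA.
    // Assumes TMA_GMEM_ALIGNMENT == 16 bytes, matching the removed TMA_bytes.
    #include <cstddef>
    #include <cstdio>

    constexpr size_t TMA_GMEM_ALIGNMENT = 16;  // bytes (assumed value)

    // Elements per TMA_GMEM_ALIGNMENT-byte chunk; the flat last dimension
    // must be a multiple of this for TMA to be usable.
    constexpr size_t alignment_requirement(size_t bits_per_elem) {
      return (TMA_GMEM_ALIGNMENT * 8) / bits_per_elem;
    }

    int main() {
      std::printf("fp32 (32 bits): cols %% %zu == 0\n", alignment_requirement(32));  // 4
      std::printf("bf16 (16 bits): cols %% %zu == 0\n", alignment_requirement(16));  // 8
      std::printf("fp8   (8 bits): cols %% %zu == 0\n", alignment_requirement(8));   // 16
    }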
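The new align_smem_ptr_per_TMA_requirements helper rounds a shared-memory pointer up to the next TMA_SHMEM_ALIGNMENT boundary with the standard add-and-mask idiom, which is valid only when the alignment is a power of two. A host-side sketch of the same arithmetic, assuming TMA_SHMEM_ALIGNMENT is 128 bytes:

    // Host-side sketch of the round-up idiom used by the new helper.
    // Assumes TMA_SHMEM_ALIGNMENT == 128 (a power of two, so masking works).
    #include <cassert>
    #include <cstddef>

    constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // bytes (assumed value)

    size_t align_up(size_t addr) {
      // Adding (alignment - 1) and clearing the low bits rounds up to the
      // next multiple; already-aligned addresses are left unchanged.
      return (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
    }

    int main() {
      assert(align_up(0) == 0);
      assert(align_up(1) == 128);
      assert(align_up(128) == 128);
      assert(align_up(129) == 256);
    }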
14 changes: 14 additions & 0 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
@@ -20,6 +20,7 @@
 #include "../../util/math.h"
 #include "../../util/ptx.cuh"
 #include "../../utils.cuh"
+#include "./specialized/gated_mxfp8_rowwise_swiglu.cuh"
 
 namespace transformer_engine {
 namespace dispatch {
@@ -696,6 +697,19 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
     scaling_type = ScalingType::BIDIMENSIONAL;
   }
 
+  // Optimized BWD/FWD SwiGLU MXFP8 Rowwise kernels for BF16/FP16 inputs
+  if constexpr (!std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+    if constexpr ((!IS_BWD && (ActOP == &silu<fp32, fp32>)) ||
+                  (IS_BWD && (ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>))) {
+      if (((gated_input.dtype() == DType::kFloat16) || (gated_input.dtype() == DType::kBFloat16)) &&
+          (scaling_type == ScalingType::ROWWISE)) {
+        quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
+                                                               stream);
+        return;
+      }
+    }
+  }
+
   const size_t rows = gated_input.flat_first_dim();
   const size_t cols = gated_input.flat_last_dim() / 2;
   const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
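The new fast path is gated three ways: the parameter type must not be ClampedSwiGLUParam, the activation must be silu (plus dsilu for the backward pass), and at runtime the input must be FP16/BF16 with rowwise scaling; anything else falls through to the generic path below. A minimal sketch of the compile-time function-pointer dispatch pattern used here (the names quantize, fast path, and the activations are illustrative, not TE's API):

    // Compile-time dispatch on a function-pointer template parameter:
    // the specialized branch is compiled only when the activation is silu,
    // mirroring the `if constexpr (ActOP == &silu<fp32, fp32>)` check above.
    #include <cmath>
    #include <cstdio>

    float silu(float x) { return x / (1.0f + std::exp(-x)); }
    float relu(float x) { return x > 0.0f ? x : 0.0f; }

    template <float (*ActOP)(float)>
    void quantize(float x) {
      if constexpr (ActOP == &silu) {
        // Specialized path, selected entirely at compile time.
        std::printf("fast silu path: %f\n", ActOP(x));
      } else {
        std::printf("generic path:   %f\n", ActOP(x));
      }
    }

    int main() {
      quantize<&silu>(1.0f);  // takes the specialized branch
      quantize<&relu>(1.0f);  // falls through to the generic branch
    }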