Skip to content
127 changes: 127 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/moe_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/providers/common.h"
#include "core/framework/tensor_shape.h"
#include "core/util/shape_checker.h"

namespace onnxruntime {
namespace contrib {

enum class MoEParallelType {
None = 0,
EP = 1,
TP = 2,
EPAndTP = 3,
};

struct MoEParameters {
MoEParameters() {}
explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
int64_t num_rows;
int64_t num_experts;
int64_t local_num_experts;
int64_t hidden_size;
int64_t inter_size;

MoEParallelType parallel_type;
int64_t tensor_shards{1};
};

namespace moe_helper {

template <typename Tensor>
Status CheckInputs(MoEParameters& parameters,
const Tensor* input, // required
const Tensor* router_probs, // required
const Tensor* fc1_experts_weights, // required
const Tensor* fc1_experts_bias, // optional
const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc2_experts_weights, // required
const Tensor* fc2_experts_bias, // optional
const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc3_experts_weights, // optional
const Tensor* fc3_experts_bias, // optional
const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE
const int pack_size, // number of weights packed together (like 2 for uint4 packed to uint8)
const bool is_fused_swiglu) {
ASSERT_TENSOR_2D_OR_3D(input);
ASSERT_TENSOR_3D(fc1_experts_weights);
ASSERT_TENSOR_3D(fc2_experts_weights);
ASSERT_TENSOR_2D(router_probs);
ASSERT_TENSOR_2D(router_probs);

const auto& input_dims = input->Shape().GetDims();
const auto& router_probs_dims = router_probs->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();

int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size;
const bool legacy_shape = hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size;

// Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one.
const int fc1_inter_size = is_fused_swiglu ? 2 * inter_size : inter_size;

if (legacy_shape) {
// legacy shape does not match the memory layout. This is for backward compatible
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
} else {
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size);
}

CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts);

CHECK_TENSOR_SHAPE(fc1_experts_bias, num_experts, fc1_inter_size);
CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size);
CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size);

CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);

if (fc3_experts_weights == nullptr) {
ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr);
} else { // fc3 exists
ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales
}

parameters.num_rows = num_rows;
parameters.num_experts = num_experts;
parameters.local_num_experts = local_num_experts;
parameters.hidden_size = hidden_size;
parameters.inter_size = inter_size;
if (num_experts == local_num_experts) {
if (parameters.tensor_shards == 1) {
parameters.parallel_type = MoEParallelType::None;
} else {
parameters.parallel_type = MoEParallelType::TP;
}
} else if (num_experts > local_num_experts) {
if (parameters.tensor_shards == 1) {
parameters.parallel_type = MoEParallelType::EP;
} else {
parameters.parallel_type = MoEParallelType::EPAndTP;
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_experts must be greater than or equal to local_num_experts, got ", num_experts,
" and ", local_num_experts);
}

return Status::OK();
}

} // namespace moe_helper
} // namespace contrib
} // namespace onnxruntime
12 changes: 8 additions & 4 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,14 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);

MoEParameters moe_params(tensor_shards_);
MoEQuantType quant_type = MoEQuantType::None;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));
MoEParameters moe_params;
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
moe_params, input, router_probs,
fc1_experts_weights, fc1_experts_bias_optional, nullptr,
fc2_experts_weights, fc2_experts_bias_optional, nullptr,
fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr,
1, // no quantization so pack size is 1
activation_type_ == ort_fastertransformer::ActivationType::SwiGLU));

ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/moe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);

MoEParameters moe_params;
MoEQuantType quant_type = MoEQuantType::None;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
moe_params, input, router_probs,
fc1_experts_weights, fc1_experts_bias_optional, nullptr,
fc2_experts_weights, fc2_experts_bias_optional, nullptr,
fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr,
1, // no quantization so pack size is 1
activation_type_ == ort_fastertransformer::ActivationType::SwiGLU));

using CudaT = typename OrtToCudaType<T>::type;
auto stream = context->GetComputeStream();
Expand Down
200 changes: 1 addition & 199 deletions onnxruntime/contrib_ops/cuda/moe/moe_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,211 +7,13 @@
#include "core/framework/tensor_shape.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h"
#include "contrib_ops/cpu/quantization/moe_helper.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

enum class MoEParallelType {
None = 0,
EP = 1,
TP = 2,
EPAndTP = 3,
};

enum class MoEQuantType {
None = 0,
UINT4 = 1,
UINT8 = 2,
};

struct MoEParameters {
MoEParameters() {}
explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
int64_t num_rows;
int64_t num_experts;
int64_t local_num_experts;
int64_t hidden_size;
int64_t inter_size;

MoEParallelType parallel_type;
int64_t tensor_shards{1};
};

class MoEBase {
public:
Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input,
const Tensor* router_probs, const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional,
const Tensor* fc3_experts_bias_optional) const {
const auto& input_dims = input->Shape().GetDims();
const auto& router_probs_dims = router_probs->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();

int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = fc2_experts_weights_dims[1];

if (fc1_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
fc1_experts_weights_dims.size());
}
if (fc2_experts_weights_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ",
fc2_experts_weights_dims.size());
}
if (fc1_experts_weights_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_weights_dims[1] must be equal to hidden_size, got ",
fc1_experts_weights_dims[1], " and ", hidden_size);
}
if (fc2_experts_weights_dims[1] != inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_weights_dims[1] must be equal to inter_size, got ",
fc2_experts_weights_dims[1], " and ", inter_size);
}

const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1;
const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1;
if (fc1_experts_weights_dims[2] != act * inter_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_weights_dims[2] is ",
fc1_experts_weights_dims[2], " expected ", act * inter_size / coe);
}
if (fc2_experts_weights_dims[2] != hidden_size / coe) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_weights_dims[2] is ",
fc2_experts_weights_dims[2], " expected ", hidden_size / coe);
}

if (router_probs_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
router_probs_dims.size());
}
if (router_probs_dims[0] != num_rows) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
router_probs_dims[0], " and ", num_rows);
}
if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) {
const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
if (fc1_experts_bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ",
fc1_experts_bias_dims.size());
}
if (fc2_experts_bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ",
fc2_experts_bias_dims.size());
}
if (fc1_experts_bias_dims[0] != local_num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_bias_dims[0] must be equal to local_num_experts, got ",
fc1_experts_bias_dims[0], " and ", local_num_experts);
}
if (fc2_experts_bias_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0],
" and ", num_experts);
}
if (fc1_experts_bias_dims[1] != act * inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_bias_dims[1] is ", fc1_experts_bias_dims[1],
", expected ", act * inter_size);
}
if (fc2_experts_bias_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1],
" and ", hidden_size);
}
}

if (fc3_experts_weights_optional != nullptr &&
fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ",
fc3_experts_weights_optional->Shape(), " and ", TensorShape(fc1_experts_weights_dims));
}

if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr &&
fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT, "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ",
fc3_experts_bias_optional->Shape(), " and ", fc1_experts_bias_optional->Shape());
}

parameters.num_rows = num_rows;
parameters.num_experts = num_experts;
parameters.local_num_experts = local_num_experts;
parameters.hidden_size = hidden_size;
parameters.inter_size = inter_size;
if (num_experts == local_num_experts) {
if (parameters.tensor_shards == 1) {
parameters.parallel_type = MoEParallelType::None;
} else {
parameters.parallel_type = MoEParallelType::TP;
}
} else if (num_experts > local_num_experts) {
if (parameters.tensor_shards == 1) {
parameters.parallel_type = MoEParallelType::EP;
} else {
parameters.parallel_type = MoEParallelType::EPAndTP;
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_experts must be greater than or equal to local_num_experts, got ", num_experts,
" and ", local_num_experts);
}

return Status::OK();
}

Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales,
const Tensor* fc3_experts_scales, int64_t num_experts, int64_t hidden_size,
int64_t inter_size) const {
const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims();
const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims();

if (fc1_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales must be 2D, got ",
fc1_experts_scales->Shape().GetDims().size());
}
if (fc1_experts_scales_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ",
fc1_experts_scales_dims[0], " and ", num_experts);
}

// The activation type affects the output dimension of the first FC layer.
const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1;
if (fc1_experts_scales_dims[1] != act * inter_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to act * inter_size, got ",
fc1_experts_scales_dims[1], " and ", act * inter_size);
}

if (fc2_experts_scales_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ",
fc2_experts_scales->Shape().GetDims().size());
}
if (fc2_experts_scales_dims[0] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ",
fc2_experts_scales_dims[0], " and ", num_experts);
}
if (fc2_experts_scales_dims[1] != hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ",
fc2_experts_scales_dims[1], " and ", hidden_size);
}
if (fc3_experts_scales != nullptr && fc1_experts_scales_dims != fc3_experts_scales->Shape().GetDims()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc3_experts_scales must be equal to fc1_experts_scales, got ",
fc3_experts_scales->Shape(), " and ", TensorShape(fc1_experts_scales_dims));
}

return Status::OK();
}

protected:
MoEBase(const OpKernelInfo& op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
Expand Down
Loading
Loading