refactoring
tianleiwu committed Aug 2, 2025
commit 37abf5d5754ee001e78d7f619b6ac5b35c744c31
24 changes: 13 additions & 11 deletions onnxruntime/contrib_ops/cpu/quantization/moe_helper.h
@@ -19,18 +19,20 @@ enum class MoEParallelType {
 };
 
 struct MoEParameters {
-  MoEParameters() {}
-  explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
-  int64_t num_rows;
-  int64_t num_experts;
-  int64_t local_num_experts;
-  int64_t hidden_size;
-  int64_t inter_size;
-
-  MoEParallelType parallel_type;
+  MoEParameters() = default;
+
+  explicit MoEParameters(int64_t tensor_shards)
+      : tensor_shards(tensor_shards) {}
+
+  int64_t num_rows{0};
+  int64_t num_experts{0};
+  int64_t local_num_experts{0};
+  int64_t hidden_size{0};
+  int64_t inter_size{0};
+
+  MoEParallelType parallel_type{MoEParallelType::None};
   int64_t tensor_shards{1};
 };
 
 namespace moe_helper {
 
 template <typename Tensor>
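
Worth noting about this hunk: the old struct left its size fields uninitialized, so a default-constructed MoEParameters held indeterminate values until CheckInputs filled them in, and reading such a field first is undefined behavior. The rewrite uses in-class default member initializers, so both constructors yield fully initialized objects. A minimal sketch of the difference (illustrative only, not PR code):

#include <cstdint>
#include <iostream>

// Illustrative only: mirrors the before/after of the MoEParameters change.
struct OldParams {
  int64_t hidden_size;  // no initializer: indeterminate after "OldParams p;"
};

struct NewParams {
  int64_t hidden_size{0};  // in-class initializer: always starts at 0
};

int main() {
  NewParams p;                         // default construction is now safe
  std::cout << p.hidden_size << "\n";  // prints 0
  // OldParams q; std::cout << q.hidden_size;  // would read an indeterminate value (UB)
  return 0;
}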
@@ -94,7 +96,7 @@ Status CheckInputs(MoEParameters& parameters,

   if (fc3_experts_weights == nullptr) {
     ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr);
-  } else {  // fc3 exists
+  } else {
     ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr);  // MOE no scale, or qMOE need scales
   }

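The enforced invariant, restated: when fc3 is absent, its bias and scales must be absent too; when fc3 is present, the model is either an unquantized MoE (fc1 carries no scales) or a quantized MoE (qMoE), in which case fc3 must carry scales as well. A self-contained sketch of that predicate (my own restatement, not code from the PR):

#include <cassert>

// Returns true when the fc3 inputs satisfy the same consistency rule
// that CheckInputs enforces with ORT_ENFORCE.
bool Fc3InputsConsistent(bool has_fc3_weights, bool has_fc3_bias,
                         bool has_fc1_scales, bool has_fc3_scales) {
  if (!has_fc3_weights) {
    return !has_fc3_bias && !has_fc3_scales;  // no fc3: nothing else fc3-related
  }
  return !has_fc1_scales || has_fc3_scales;   // unquantized MoE, or qMoE with fc3 scales
}

int main() {
  assert(Fc3InputsConsistent(false, false, false, false));  // plain MoE without fc3
  assert(Fc3InputsConsistent(true, true, true, true));      // qMoE with fc3 scales
  assert(!Fc3InputsConsistent(true, false, true, false));   // qMoE missing fc3 scales: rejected
  return 0;
}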
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -71,7 +71,6 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);
 
   MoEParameters moe_params(tensor_shards_);
-  MoEParameters moe_params;
   ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
       moe_params, input, router_probs,
       fc1_experts_weights, fc1_experts_bias_optional, nullptr,
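
With the duplicate declaration gone, the sharded kernel keeps the one-argument constructor while the plain MoE kernel can rely on the default. A self-contained sketch of the two construction paths (the shard count 4 is illustrative, and the struct is abridged from the hunk above):

#include <cstdint>

enum class MoEParallelType { None /* other enumerators elided */ };

// Abridged copy of the refactored struct so this sketch compiles on its own.
struct MoEParameters {
  MoEParameters() = default;
  explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {}
  int64_t hidden_size{0};
  MoEParallelType parallel_type{MoEParallelType::None};
  int64_t tensor_shards{1};
};

int main() {
  MoEParameters single_node;  // plain MoE kernel: defaults, tensor_shards == 1
  MoEParameters sharded(4);   // ShardedMoE passing tensor_shards_ (4 is illustrative)
  return (single_node.tensor_shards == 1 && sharded.tensor_shards == 4) ? 0 : 1;
}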