CANN: support fused MUL_MAT + ADD operator
TianHao324 committed Dec 3, 2025
commit 046d2a00eac53e8fcd88c011774d4c760e368dcf
22 changes: 22 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -75,6 +75,7 @@
#include <aclnnop/aclnn_upsample_nearest_2d.h>
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
#include <aclnnop/aclnn_zero.h>
#include <aclnnop/aclnn_addmm.h>
#include <float.h>

#include <cmath>
@@ -3484,3 +3485,24 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
break;
}
}

void ggml_cann_op_add_mul_fused(ggml_backend_cann_context & ctx, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
ggml_tensor * weight = mul_tensor->src[0];
ggml_tensor * input = mul_tensor->src[1];
ggml_tensor * add = add_tensor->src[1];

acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(input, input->ne, input->nb, 2, ACL_FORMAT_ND);
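// ggml stores the weight as [k, m] (ne0 = k); swapping its first two dims and
// strides yields a transposed 2-D view, so Addmm's mat1 x mat2 reproduces the matmul.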
int64_t transpose_ne[] = { weight->ne[1], weight->ne[0], weight->ne[2], weight->ne[3]};
size_t transpose_nb[] = { weight->nb[1], weight->nb[0], weight->nb[2], weight->nb[3]};

acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, 2, ACL_FORMAT_ND);

acl_tensor_ptr acl_add = ggml_cann_create_tensor(add, add->ne, add->nb, 2, ACL_FORMAT_ND);
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(add_tensor, add_tensor->ne, add_tensor->nb, 2, ACL_FORMAT_ND);

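// With alpha == beta == 1.0f, Addmm reduces to dst = bias + input x weight^T.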
float value = 1.0f;
acl_scalar_ptr alpha = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
acl_scalar_ptr beta = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);

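// aclnnAddmm computes dst = beta * self + alpha * (mat1 x mat2); the trailing 0
// selects the default cube math type.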
GGML_CANN_CALL_ACLNN_OP(ctx, Addmm, acl_add.get(), acl_input_tensor.get(), acl_weight_tensor.get(), beta.get(), alpha.get(), acl_dst.get(), 0);
}
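For context, here is a minimal CPU sketch of the arithmetic the fused call above must reproduce: with alpha = beta = 1, aclnnAddmm yields dst = bias + input · weight^T, where ggml stores the weight row-major as [k, m]. The helper name addmm_ref is hypothetical and the code is illustrative, not part of this change:

#include <cstddef>
#include <vector>

// Naive reference for dst = bias + input * weight^T (all buffers row-major float32).
static std::vector<float> addmm_ref(const std::vector<float> & weight, // m rows of k values
                                    const std::vector<float> & input,  // n rows of k values
                                    const std::vector<float> & bias,   // n rows of m values
                                    size_t m, size_t n, size_t k) {
    std::vector<float> dst(n * m);
    for (size_t j = 0; j < n; ++j) {
        for (size_t i = 0; i < m; ++i) {
            float acc = bias[j * m + i];                     // beta == 1.0f
            for (size_t l = 0; l < k; ++l) {
                acc += input[j * k + l] * weight[i * k + l]; // alpha == 1.0f, weight transposed
            }
            dst[j * m + i] = acc;
        }
    }
    return dst;
}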
12 changes: 12 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
@@ -1145,3 +1145,15 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
* @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
*/
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);

/**
 * @brief Performs a fused MUL_MAT + ADD operation using the CANN backend.
 *
 * This function lowers a MUL_MAT node followed by an ADD node into a single
 * aclnnAddmm kernel call (dst = bias + input x weight^T), saving one kernel
 * launch and one intermediate tensor compared to running the two ops separately.
 *
 * @param ctx The context for the CANN backend operations.
 * @param mul_tensor The MUL_MAT node; src[0] holds the weight and src[1] the input.
 * @param add_tensor The ADD node; src[1] holds the bias, and its buffer receives the fused result.
 */
void ggml_cann_op_add_mul_fused(ggml_backend_cann_context & ctx, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
43 changes: 43 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
@@ -2231,6 +2231,40 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
}
#endif // USE_ACL_GRAPH

static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
ggml_tensor * mul_node = cgraph->nodes[node_idx];
ggml_tensor * add_node = cgraph->nodes[node_idx + 1];

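// The fused path builds strictly 2-D ACL views, so reject any batched dims.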
if (mul_node->src[0]->ne[2] != 1 || mul_node->src[0]->ne[3] != 1 ||
mul_node->src[1]->ne[2] != 1 || mul_node->src[1]->ne[3] != 1 ||
add_node->src[1]->ne[2] != 1 || add_node->src[1]->ne[3] != 1) {
return false;
}

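// The bias must match the matmul output shape exactly; the fused path does no broadcasting.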
if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] ||
add_node->src[0]->ne[1] != add_node->src[1]->ne[1]) {
return false;
}

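// The ADD must consume the MUL_MAT result directly for the pair to fuse.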
if (add_node->src[0] != mul_node) {
return false;
}

return true;
}

return false;
}
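For reference, a sketch of a graph fragment that passes every check above: a strictly 2-D MUL_MAT whose result feeds an ADD with a bias of identical shape. The helper name build_fusable_graph is hypothetical; the ggml calls are the standard public API:

#include "ggml.h"

// Builds mul_mat(weight, input) + bias, so nodes[i] is the MUL_MAT and
// nodes[i + 1] the ADD that consumes it, i.e. the pattern accepted above.
static ggml_cgraph * build_fusable_graph(ggml_context * ctx, int64_t m, int64_t n, int64_t k) {
    ggml_tensor * weight = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, m); // [k, m]
    ggml_tensor * input  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k, n); // [k, n]
    ggml_tensor * bias   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, n); // same shape as the matmul output
    ggml_tensor * mm     = ggml_mul_mat(ctx, weight, input);             // [m, n], no batch dims
    ggml_tensor * out    = ggml_add(ctx, mm, bias);                      // src[0] is the MUL_MAT node
    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    return gf;
}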

/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
@@ -2255,9 +2289,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
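// Operator fusion is opt-in, gated by the GGML_CANN_OPERATOR_FUSION environment variable.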
static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
if (!use_cann_graph || cann_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

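// Replace a matching MUL_MAT + ADD pair with one fused call and skip the consumed ADD node.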
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
ggml_cann_op_add_mul_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
i++;
continue;
}
}

if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
78 changes: 78 additions & 0 deletions tests/test-backend-ops.cpp
@@ -3721,6 +3721,71 @@ struct test_mul_mat_id : public test_case {
}
};

// GGML_OP_MUL_MAT + GGML_OP_ADD (fused operation)
struct test_mul_mat_add : public test_case {
const ggml_type type_a;
const ggml_type type_b;
const int64_t m;
const int64_t n;
const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const bool broadcast_bias; // true: bias is [1, n] (a single value per row); false: bias is [m, 1] (the usual per-output-channel bias)

std::string vars() override {
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, broadcast_bias);
}

std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "MUL_MAT_ADD";
}

double max_nmse_err() override {
return 5e-4;
}

uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1] + m * n * bs[0] * nr[0] * bs[1] * nr[1];
}

test_mul_mat_add(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
bool broadcast_bias = false)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), broadcast_bias(broadcast_bias) {}

ggml_tensor * build_graph(ggml_context * ctx) override {

ggml_tensor * a = ::ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
ggml_set_name(a, "a");

ggml_tensor * b = ::ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_name(b, "b");

ggml_tensor * mul_mat_result = ggml_mul_mat(ctx, a, b);
ggml_set_name(mul_mat_result, "mul_mat_result");

ggml_tensor * bias;
if (broadcast_bias) {
bias = ::ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, n, bs[0]*nr[0], bs[1]*nr[1]);
} else {
bias = ::ggml_new_tensor_4d(ctx, GGML_TYPE_F32, m, 1, bs[0]*nr[0], bs[1]*nr[1]);
}
ggml_set_name(bias, "bias");
ggml_tensor * out = ggml_add(ctx, mul_mat_result, bias);
ggml_set_name(out, "out");

return out;
}

bool run_whole_graph() override { return true; }
};

// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
struct test_mul_mat_id_fusion : public test_case {
const ggml_type type_a;
@@ -7434,6 +7499,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}

for (int m : {128, 896}) {
// shapes chosen to exercise the fused MUL_MAT + ADD path
for (int n : {1, 14, 20, 21}) {
for (int k : {128, 896}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat_add(type_a, type_b, m, n, k, {1, 1}, {1, 1}, false));
}
}
}
}
}

// add_id
for (ggml_type type_a : {GGML_TYPE_F32}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {