diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 48f4b7db691..655c2f5526b 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -75,6 +75,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <aclnnop/aclnn_addmm.h>
 #include <...>
 #include <...>
@@ -3484,3 +3485,24 @@ void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
             break;
     }
 }
+
+void ggml_cann_op_add_mul_fused(ggml_backend_cann_context & ctx, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) {
+    ggml_tensor * weight = mul_tensor->src[0];
+    ggml_tensor * input  = mul_tensor->src[1];
+    ggml_tensor * add    = add_tensor->src[1];
+
+    acl_tensor_ptr acl_input_tensor = ggml_cann_create_tensor(input, input->ne, input->nb, 2, ACL_FORMAT_ND);
+    // view the weight transposed: ggml's MUL_MAT(weight, input) computes input @ weight^T,
+    // while aclnnAddmm multiplies mat1 @ mat2 directly
+    int64_t transpose_ne[] = { weight->ne[1], weight->ne[0], weight->ne[2], weight->ne[3] };
+    size_t  transpose_nb[] = { weight->nb[1], weight->nb[0], weight->nb[2], weight->nb[3] };
+
+    acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, 2, ACL_FORMAT_ND);
+
+    acl_tensor_ptr acl_add = ggml_cann_create_tensor(add, add->ne, add->nb, 2, ACL_FORMAT_ND);
+    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(add_tensor, add_tensor->ne, add_tensor->nb, 2, ACL_FORMAT_ND);
+
+    // dst = beta * add + alpha * (input @ weight^T), with alpha = beta = 1
+    float value = 1.0f;
+    acl_scalar_ptr alpha = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
+    acl_scalar_ptr beta  = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, Addmm, acl_add.get(), acl_input_tensor.get(), acl_weight_tensor.get(), beta.get(), alpha.get(), acl_dst.get(), 0);
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 1ebbc769c71..3ceeae9f581 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -1145,3 +1145,15 @@ void ggml_cann_op_unary_gated(std::function<...>
+
+/**
+ * @brief Fused GGML_OP_MUL_MAT + GGML_OP_ADD, mapped onto a single aclnnAddmm call.
+ *
+ * @param ctx        The CANN backend context.
+ * @param mul_tensor The GGML_OP_MUL_MAT node; src[0] is the weight, src[1] the input.
+ * @param add_tensor The GGML_OP_ADD node; src[1] is the bias, and the result is
+ *                   written to this tensor.
+ */
+void ggml_cann_op_add_mul_fused(ggml_backend_cann_context & ctx, ggml_tensor * mul_tensor, ggml_tensor * add_tensor);
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ ... @@
+static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
+        ggml_tensor * mul_node = cgraph->nodes[node_idx];
+        ggml_tensor * add_node = cgraph->nodes[node_idx + 1];
+
+        // only the plain 2-D case is fused: no batching in dims 2/3
+        if (mul_node->src[0]->ne[2] != 1 || mul_node->src[0]->ne[3] != 1 ||
+            mul_node->src[1]->ne[2] != 1 || mul_node->src[1]->ne[3] != 1 ||
+            add_node->src[1]->ne[2] != 1 || add_node->src[1]->ne[3] != 1) {
+            return false;
+        }
+
+        // the bias must match the mat-mul output shape exactly (no broadcasting)
+        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] ||
+            add_node->src[0]->ne[1] != add_node->src[1]->ne[1]) {
+            return false;
+        }
+
+        // the ADD must consume the MUL_MAT result directly
+        if (add_node->src[0] != mul_node) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
+}
+
 /**
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
  *
@@ -2255,9 +2289,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
 #endif  // USE_ACL_GRAPH
     // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
     // With the use of CANN graphs, the execution will be performed by the graph launch.
+    static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
     if (!use_cann_graph || cann_graph_update_required) {
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
+
+            if (opt_fusion) {
+                // dispatch MUL_MAT + ADD pairs to the fused aclnnAddmm path and
+                // skip the consumed ADD node
+                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
+                    ggml_cann_op_add_mul_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
+                    i++;
+                    continue;
+                }
+            }
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                 node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 9645d0b3909..4de3707cea5 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3721,6 +3721,71 @@ struct test_mul_mat_id : public test_case {
     }
 };
 
+// GGML_OP_MUL_MAT + GGML_OP_ADD (fused operation)
+struct test_mul_mat_add : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array<int64_t, 2> bs;  // dims 3 and 4
+    const std::array<int64_t, 2> nr;  // repeat in dims 3 and 4
+    const bool broadcast_bias;        // whether to broadcast bias (single value per output channel)
+
+    std::string vars() override {
+        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, broadcast_bias);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "MUL_MAT_ADD";
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // mat-mul FLOPs plus one add per output element
+        return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1] + m * n * bs[0] * nr[0] * bs[1] * nr[1];
+    }
+
+    test_mul_mat_add(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+                     int64_t m = 32, int64_t n = 32, int64_t k = 32,
+                     std::array<int64_t, 2> bs = {10, 10},
+                     std::array<int64_t, 2> nr = {2, 2},
+                     bool broadcast_bias = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), broadcast_bias(broadcast_bias) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+        ggml_set_name(b, "b");
+
+        ggml_tensor * mul_mat_result = ggml_mul_mat(ctx, a, b);
+        ggml_set_name(mul_mat_result, "mul_mat_result");
+
+        ggml_tensor * bias;
+        if (broadcast_bias) {
+            bias = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, n, bs[0]*nr[0], bs[1]*nr[1]);
+        } else {
+            bias = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, m, 1, bs[0]*nr[0], bs[1]*nr[1]);
+        }
+        ggml_set_name(bias, "bias");
+
+        ggml_tensor * out = ggml_add(ctx, mul_mat_result, bias);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    bool run_whole_graph() override { return true; }
+};
+
 // GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
 struct test_mul_mat_id_fusion : public test_case {
     const ggml_type type_a;
@@ -7434,6 +7499,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // for fused MUL_MAT + ADD pattern testing
+    for (int m : {128, 896}) {
+        for (int n : {1, 14, 20, 21}) {
+            for (int k : {128, 896}) {
+                for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                    for (ggml_type type_b : {GGML_TYPE_F32}) {
+                        test_cases.emplace_back(new test_mul_mat_add(type_a, type_b, m, n, k, {1, 1}, {1, 1}, false));
+                    }
+                }
+            }
+        }
+    }
+
     // add_id
     for (ggml_type type_a : {GGML_TYPE_F32}) {
         for (ggml_type type_b : {GGML_TYPE_F32}) {
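A note on the semantics for reviewers: ggml's `MUL_MAT(weight, input)` computes `input · weightᵀ` (which is why `acl_weight_tensor` is built from a view with `ne`/`nb` swapped), while `aclnnAddmm` evaluates `dst = beta * bias + alpha * (mat1 @ mat2)`, so with `alpha = beta = 1` the fused call reproduces `ADD(MUL_MAT(weight, input), bias)`. The sketch below is a minimal CPU reference of that contract, not backend code; the helper name `mul_mat_add_ref` is illustrative only, and it assumes the plain 2-D, non-broadcast case that `ggml_cann_can_fuse` admits.

```cpp
#include <cstdint>
#include <vector>

// dst[r][c] = bias[r][c] + sum_i input[r][i] * weight[c][i]
// i.e. dst = bias + input · weightᵀ, matching Addmm with alpha = beta = 1
static std::vector<float> mul_mat_add_ref(const std::vector<float> & weight, // m x k, row-major
                                          const std::vector<float> & input,  // n x k, row-major
                                          const std::vector<float> & bias,   // n x m, row-major
                                          int64_t m, int64_t n, int64_t k) {
    std::vector<float> dst(n * m);
    for (int64_t r = 0; r < n; ++r) {        // output rows = rows of input
        for (int64_t c = 0; c < m; ++c) {    // output cols = rows of weight
            float acc = bias[r * m + c];     // beta * bias with beta = 1
            for (int64_t i = 0; i < k; ++i) {
                acc += input[r * k + i] * weight[c * k + i]; // alpha = 1
            }
            dst[r * m + c] = acc;
        }
    }
    return dst;
}
```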
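Because the fusion is gated behind `GGML_CANN_OPERATOR_FUSION`, the new test cases exercise the unfused fallback unless that variable is set. Assuming a CANN build where the first device registers as `CANN0`, something like `GGML_CANN_OPERATOR_FUSION=1 ./bin/test-backend-ops test -b CANN0 -o MUL_MAT_ADD` should run only the fused cases, since `-o` matches the `MUL_MAT_ADD` string returned by `op_desc`.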