Draft
28 commits
407b408
fix test failure
chraac Nov 27, 2025
4ddb8a4
fix: correct scaling calculations in rope_cache_init
chraac Nov 27, 2025
cfca78b
wip
chraac Nov 27, 2025
e9a02fd
wip
chraac Nov 28, 2025
e324bb0
fix: optimize element copying in rope_hex_f32 using memcpy
chraac Nov 28, 2025
0121291
fix: optimize loop boundaries in rope_hex_f32 for better performance
chraac Nov 28, 2025
010039a
rename
chraac Nov 28, 2025
a6ef41f
wip
chraac Nov 28, 2025
0376146
Merge branch 'master' into dev-fix-rope
chraac Nov 28, 2025
8abecfa
Merge tag 'b7207' into dev-fix-rope
chraac Nov 30, 2025
b567413
feat: add profiling macros for performance measurement in operations
chraac Nov 30, 2025
7c8f101
refactor: replace manual timing with profiling macros in matmul opera…
chraac Dec 3, 2025
3a70465
Merge branch 'master' into dev-fix-rope
chraac Dec 4, 2025
3b0cef4
Revert "refactor: replace manual timing with profiling macros in matm…
chraac Dec 5, 2025
121e656
Revert "feat: add profiling macros for performance measurement in ope…
chraac Dec 5, 2025
401fd3e
refactor: optimize vector operations in vec_dot_q4x4x2_q8x4x2_rx2 fun…
chraac Dec 5, 2025
cf491f2
wip
chraac Dec 5, 2025
3a01d82
feat: enhance vec_dot_q4x4x2_q8x4x2_rx2 function with optimized data …
chraac Dec 7, 2025
87ad8b2
Merge branch 'master' into dev-mulmat-opt
chraac Dec 8, 2025
421d031
feat: add hvx_vec_load_d_and_mpy function for optimized data loading …
chraac Dec 8, 2025
bd43860
wip
chraac Dec 8, 2025
b197464
feat: add hvx_vec_load_d_and_mpy_r2x2 function for optimized vector l…
chraac Dec 8, 2025
309d782
feat: optimize vec_dot functions with improved data handling and loading
chraac Dec 8, 2025
dbe9309
wip
chraac Dec 9, 2025
00d5fb3
feat: add build information and update vector loading functions for o…
chraac Dec 9, 2025
b54ff18
revert rope changes
chraac Dec 10, 2025
f757245
Merge tag 'b7345' into dev-mulmat-opt
chraac Dec 10, 2025
09c4899
fix: revert HVX_Vector back to HVX_UVector
chraac Dec 11, 2025
192 changes: 133 additions & 59 deletions ggml/src/ggml-hexagon/htp/matmul-ops.c
@@ -307,6 +307,50 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y,
return hvx_vec_rmpy_x8_n(x, y, 1024);
}

static inline HVX_Vector_x2 hvx_vec_load_and_mul_d_rx2(const uint8_t * restrict r0_x_d,
const uint8_t * restrict r1_x_d,
const uint8_t * restrict y_d,
const HVX_Vector rd_mask) {
HVX_Vector vy_d = *(const HVX_UVector *) y_d;
HVX_Vector r0_d = *(const HVX_UVector *) r0_x_d;
HVX_Vector r1_d = *(const HVX_UVector *) r1_x_d;

vy_d = Q6_Vh_vshuff_Vh(vy_d);
HVX_Vector r01_d = Q6_V_vmux_QVV(rd_mask, r0_d, r1_d);

vy_d = Q6_Vh_vshuffe_VhVh(vy_d, vy_d);
r01_d = Q6_Vh_vshuff_Vh(r01_d);

HVX_VectorPair r01_dd = Q6_Wqf32_vmpy_VhfVhf(r01_d, vy_d);

HVX_Vector_x2 r;
r.v[0] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r01_dd));
r.v[1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r01_dd));
return r;
}

static inline HVX_Vector_x4 hvx_vec_load_and_mul_d_r2x2(const uint8_t * restrict r0_x_d,
const uint8_t * restrict r1_x_d,
const uint8_t * restrict y_d) {
HVX_Vector vy_d = *(const HVX_UVector *) y_d;
HVX_Vector r0_d = *(const HVX_UVector *) r0_x_d;
HVX_Vector r1_d = *(const HVX_UVector *) r1_x_d;

vy_d = Q6_Vh_vshuff_Vh(vy_d);
r0_d = Q6_Vh_vshuff_Vh(r0_d);
r1_d = Q6_Vh_vshuff_Vh(r1_d);

HVX_VectorPair r0_dd = Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d);
HVX_VectorPair r1_dd = Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d);

HVX_Vector_x4 r;
r.v[0] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r0_dd));
r.v[1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r0_dd));
r.v[2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(r1_dd));
r.v[3] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(r1_dd));
return r;
}

static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % 32 == 0); // min sub-block size
assert((unsigned long) vx % 128 == 0);
@@ -393,11 +437,11 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,

const uint32_t qk = QK_Q4_0x4x2 * 4;

const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t x_dblk_size = 8 * 4 * sizeof(uint16_t); // 32x __fp16
const uint32_t x_qblk_size = qk / 2; // int4
const uint32_t x_qrow_size = n / 2; // int4 (not padded)

const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_dblk_size = 8 * 4 * sizeof(uint16_t); // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)

@@ -422,26 +466,65 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
const uint32_t nloe = n % qk; // num leftover elements

uint32_t i = 0;
for (; i < nb; i++) {
for (; i + 1 < nb; i += 2) {
HVX_Vector r00_ia;
HVX_Vector r10_ia;
{
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);

r00_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
r10_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
}

HVX_Vector r01_ia;
HVX_Vector r11_ia;
{
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + (i + 1) * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + (i + 1) * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + (i + 1) * x_qblk_size);

r01_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
r11_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
}

HVX_Vector_x4 r_dd =
hvx_vec_load_and_mul_d_r2x2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size, y_d + i * y_dblk_size);
chraac (Contributor, Author) commented Dec 10, 2025:
Optimized the scale multiplication step. The previous implementation only processed 32xf16 elements (half the vector width). This change enables 64xf16 multiplication to fully utilize the HVX vector capacity.
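A simplified before/after contrast (using the same intrinsics as the diff; not a drop-in snippet):

// Before: each Q6_Wqf32_vmpy_VhfVhf produces a 64-lane qf32 pair, but only
// the low half was kept, discarding half of every scale multiply:
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));

// After: hvx_vec_load_and_mul_d_r2x2() keeps both halves of each pair
// (r.v[0..3]), so the same two multiplies yield the scales for two adjacent
// blocks at once, which is why the main loop now advances by i += 2.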

Collaborator replied:
I'm getting garbled output for all models.
Also, ultimately we end up with the INT32 accumulator for each block (32 elements).
In order to multiply it with the FP16 scale we need to convert both (accumulator and scale) into FP32 (QF32). This means that we still need to do the same number of multiplies and use the same number of HVX registers either way.

chraac (Contributor, Author) replied:

> I'm getting garbled output for all models.

Reverted the scale loading to handle unaligned scales, since alignment cannot be guaranteed for all tensor shapes (sketched after the test logs below).

I think this resolves the garbled output issues. Tested on:

  • llama3-1b: log
  • qwen3-1.7b: log
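A minimal sketch of the revert (assuming it matches commit 09c4899 in this PR; illustrative only):

// HVX_UVector marks the access as potentially unaligned, so the compiler
// emits unaligned vector loads:
HVX_Vector vy_d = *(const HVX_UVector *) y_d; // correct for any y_d
// A plain HVX_Vector load assumes full vector alignment and can read the
// wrong data when a tensor's scale row does not land on a 128-byte boundary.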

chraac (Contributor, Author) replied:

> Also, ultimately we end up with the INT32 accumulator for each block (32 elements).
> In order to multiply it with the FP16 scale we need to convert both (accumulator and scale) into FP32 (QF32).

  • Regarding the scales utilization: The original source uses 2 Q6_Wqf32_vmpy_VhfVhf instructions for 2 rows but ignores the upper half. This PR aims to fully utilize the results of both multiplications.

  • As for the accumulator width: For Q4_0, an INT32 accumulator is likely excessive. Since src0 (4-bit) * src1 (8-bit) fits in 12 bits, accumulating 32 elements only requires 17 bits total. A 32-bit accumulator is far larger than what is strictly required (see the check below).
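A small standalone check of that bit-width arithmetic (illustrative; not part of the diff):

#include <assert.h>
#include <stdint.h>

// Q4_0 weights are 4-bit signed in [-8, 7]; q8 activations are int8.
int main(void) {
    const int32_t max_q4   = 8;                // largest |int4| magnitude
    const int32_t max_q8   = 128;              // largest |int8| magnitude
    const int32_t max_prod = max_q4 * max_q8;  // 1024, fits in 12 signed bits
    const int32_t max_acc  = 32 * max_prod;    // 32768 for a 32-element block
    assert(max_acc == 1 << 15);                // 16 magnitude bits + sign = 17 bits
    return 0;
}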


HVX_Vector r00_fa = Q6_Vqf32_vmpy_VsfVsf(r00_ia, r_dd.v[0]);
HVX_Vector r01_fa = Q6_Vqf32_vmpy_VsfVsf(r01_ia, r_dd.v[1]);

HVX_Vector r10_fa = Q6_Vqf32_vmpy_VsfVsf(r10_ia, r_dd.v[2]);
HVX_Vector r11_fa = Q6_Vqf32_vmpy_VsfVsf(r11_ia, r_dd.v[3]);

r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r00_fa);
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r10_fa);

r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r01_fa);
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r11_fa);
}

const HVX_VectorPred rd_mask = Q6_Q_vsetq_R(VLEN / 2);
r1_x_d -= VLEN / 2; // make sure r1 is in the high half of the vector

if (i < nb) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);

HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));

HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
y_d + i * y_dblk_size, rd_mask);

HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));

HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r_dd.v[0]);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r_dd.v[1]);

r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
i++;
}

// Process leftovers: we still load the full 4x4x2 block but zero out unused scales/blocks
@@ -453,17 +536,13 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));

HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));

HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
y_d + i * y_dblk_size, rd_mask);

// Zero out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
HVX_Vector r0_dd = Q6_V_vand_QV(bmask, r_dd.v[0]);
HVX_Vector r1_dd = Q6_V_vand_QV(bmask, r_dd.v[1]);

HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
@@ -473,8 +552,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
}

// Convert into fp32 and reduce
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);

hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -594,6 +673,9 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
const uint32_t nb = n / qk; // num full blocks
int32_t nloe = n % qk; // num leftover elements (must be signed)

const HVX_VectorPred rd_mask = Q6_Q_vsetq_R(VLEN / 2);
r1_x_d -= VLEN / 2; // make sure r1 is in the high half of the vector

uint32_t i = 0;
for (; i < nb; i++) {
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
@@ -603,15 +685,11 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));

HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
y_d + i * y_dblk_size, rd_mask);

HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));

HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r_dd.v[0]);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r_dd.v[1]);

r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
@@ -626,17 +704,13 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));

HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));

HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
HVX_Vector_x2 r_dd = hvx_vec_load_and_mul_d_rx2(r0_x_d + i * x_dblk_size, r1_x_d + i * x_dblk_size,
y_d + i * y_dblk_size, rd_mask);

// Zero out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
HVX_Vector r0_dd = Q6_V_vand_QV(bmask, r_dd.v[0]);
HVX_Vector r1_dd = Q6_V_vand_QV(bmask, r_dd.v[1]);

HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
@@ -646,8 +720,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
}

// Convert into fp32 and reduce
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);

hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -684,8 +758,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
// Compute combined scale (fp32).
// Apply scale to acc and accumulate into the row sum (qf32).

const uint32_t nb = n / qk; // num full blocks
int32_t nloe = n % qk; // num leftover elements (must be signed)
const uint32_t nb = n / qk; // num full blocks
int32_t nloe = n % qk; // num leftover elements (must be signed)

const HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
const HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
const HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);

uint32_t i = 0;
for (; i < nb; i++) {
@@ -698,19 +776,16 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);

// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
vy_d = Q6_Vsf_equals_Vqf32(vy_d);
vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
vy_d = Q6_Vsf_equals_Vqf32(vy_d);

// Convert rX_d scales from e8m0 to fp32
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
// Left shift with zero fill to create FP32
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
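// e.g. an e8m0 byte E encodes 2^(E-127), and E << 23 is exactly that value
// as fp32 bits: 0x7F << 23 = 0x3F800000 = 1.0f, 0x80 << 23 = 0x40000000 = 2.0f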
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
r0_d = Q6_V_vdelta_VV(r0_d, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
r0_d = Q6_V_vdelta_VV(r0_d, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);

HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));

@@ -738,11 +813,9 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
// Left shift with zero fill to create FP32
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
r0_d = Q6_V_vdelta_VV(r0_d, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
r0_d = Q6_V_vdelta_VV(r0_d, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);

HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));

@@ -888,8 +961,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
}

// Convert into fp32 and reduce
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);

hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
@@ -917,14 +990,15 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri

// for some reason we need volatile here so that the compiler doesn't try anything funky
volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
const HVX_Vector kOne = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16

uint32_t i = 0;

for (i = 0; i < nv0; i++) {
HVX_VectorPair yp = vy[i];

HVX_Vector x = vx[i];
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), kOne); // mul by 1.0

HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
@@ -937,7 +1011,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
HVX_VectorPair yp = vy[i];

HVX_Vector x = vx[i];
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), kOne); // mul by 1.0

if (nv1 >= 32) {
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
@@ -2199,7 +2273,7 @@ int op_matmul_id(struct htp_ops_context * octx) {

assert(i02 >= 0 && i02 < n_as);

MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping){ id, iid1 };
matrix_row_counts[i02] += 1;
}
}