Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
q4_0c: disable prefetching on M1
  • Loading branch information
unbounded committed May 4, 2023
commit d53f76760d7b067fd0cef67a994a90e662bdfb50
19 changes: 14 additions & 5 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1154,13 +1154,17 @@ static void quantize_row_q4_0c_reference(const float * restrict x, uint8_t * res
float id[2];
for (int j = 0; j < 2; j++) {
float amax = 0.0f; // absolute max
float max = 0.0f;

for (int l = 0; l < QK4_0; l++) {
const float v = xb[j][l];
amax = MAX(amax, fabsf(v));
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}

d[j] = amax / ((1 << 3) - 1);
d[j] = max / -8;
id[j] = d[j] ? 1.0f/d[j] : 0.0f;
}

Expand All @@ -1169,10 +1173,10 @@ static void quantize_row_q4_0c_reference(const float * restrict x, uint8_t * res

for (int l = 0; l < QK4_0; l++) {
const float v0 = xb[0][l]*id[0];
const uint8_t vi0 = (int8_t)roundf(v0) + 8;
const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);

const float v1 = xb[1][l]*id[1];
const uint8_t vi1 = (int8_t)roundf(v1) + 8;
const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);

assert(vi0 < 16);
assert(vi1 < 16);
Expand Down Expand Up @@ -3126,16 +3130,19 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
float sumf = 0.0;

#if defined(__ARM_NEON)
const int ahead=80;
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

for (int i = 0; i < nb/2; i++) {
// Disable prefetching on all Apple targets for now (motivated by M1; the guard is __APPLE__, not M1-specific).
#ifndef __APPLE__
const int ahead=80;
__builtin_prefetch(&xqs[i*QK4_0 + 64*ahead]);
__builtin_prefetch(&yqs[2*i*QK8_0C + 64*ahead]);
__builtin_prefetch(&yqs[2*i*QK8_0C + 64*ahead + 64]);
__builtin_prefetch(&xds[2*i + 64/4*ahead]);
__builtin_prefetch(&yds[2*i + 64/4*ahead]);
#endif

const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11 ...
Expand Down Expand Up @@ -9738,11 +9745,13 @@ static void ggml_compute_forward_alibi(
ggml_compute_forward_alibi_f32(params, src0, src1, dst);
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_0C:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_0C:
case GGML_TYPE_Q8_1:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
Expand Down