ggml : alternative Q4_3 implementation using modified Q8_0
ggerganov committed Apr 22, 2023
commit 5425e06006846a7e81d58d34a407fa8ea7a2149f
87 changes: 53 additions & 34 deletions ggml.c
@@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
#define QK8_0 32
typedef struct {
float d; // delta
float s; // d * sum(qs[i])
float s0; // d * sum(qs[i]) low
float s1; // d * sum(qs[i]) high
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


// reference implementation for deterministic creation of model files
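To see why the per-half sums pay off, here is a minimal scalar sketch of the dot product between one affine sub-block (16 elements, as in q4_3) and the matching half of a q8_0 block. This is a hypothetical helper for illustration, not code from this commit. With x_i = dx*qx[i] + m and y_i = dy*qy[i], the offset m contributes m*dy*sum(qy[i]), which is exactly the precomputed s0 (or s1) of that half, so the inner loop stays integer-only:

#include <stdint.h>

// Hypothetical illustration, not part of the diff: dot product of one affine
// sub-block (x_i = dx*qx[i] + m) with one half of a q8_0 block
// (y_i = dy*qy[i]). The offset term m*dy*sum(qy) is already available as
// s0 (low half) or s1 (high half), so no extra pass over y is needed.
static float dot_subblock(float dx, float m, const int8_t * qx,
                          float dy, const int8_t * qy, float s_half) {
    int sumi = 0;
    for (int l = 0; l < 16; ++l) {
        sumi += qx[l] * qy[l]; // integer multiply-accumulate only
    }
    return dx*dy*sumi + m*s_half; // s_half == dy * sum(qy) over this half
}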
@@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

y[i].d = d;

int sum = 0;
for (int l = 0; l < QK8_0; ++l) {
const float v = x[i*QK8_0 + l]*id;
y[i].qs[l] = roundf(v);
sum += y[i].qs[l];
int sum0 = 0;
int sum1 = 0;

for (int l = 0; l < QK8_0/2; ++l) {
const float v0 = x[i*QK8_0 + l]*id;
const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id;

y[i].qs[ l] = roundf(v0);
y[i].qs[QK8_0/2 + l] = roundf(v1);

sum0 += y[i].qs[ l];
sum1 += y[i].qs[QK8_0/2 + l];
}
y[i].s = d * sum;

y[i].s0 = d * sum0;
y[i].s1 = d * sum1;
}
}
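A consistency check worth keeping in mind (sketched here under the assumption QK8_0 == 32, as above; not part of the commit): the two half-sums must recombine to the old single field, s0 + s1 == d * sum(qs[0..QK8_0-1]), up to float rounding.

#include <assert.h>
#include <math.h>

// Sanity sketch, not from the diff: verify the split sums recombine to the
// previous single-field value within a small relative tolerance.
static void check_split_sums(const block_q8_0 * y) {
    int sum = 0;
    for (int l = 0; l < QK8_0; ++l) {
        sum += y->qs[l];
    }
    const float s = y->d * (float) sum;
    assert(fabsf((y->s0 + y->s1) - s) <= 1e-4f*(1.0f + fabsf(s)));
}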

@@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

y[i].d = d;

int32x4_t accv = vdupq_n_s32(0);
int32x4_t accv0 = vdupq_n_s32(0);
int32x4_t accv1 = vdupq_n_s32(0);

for (int l = 0; l < 8; l++) {
// low half
for (int l = 0; l < 4; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const int32x4_t vi = vcvtnq_s32_f32(v);
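// (added note) vcvtnq_s32_f32 rounds to nearest with ties to even, while the
// scalar reference uses roundf (ties away from zero), so the two paths can
// differ on exact .5 ties.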

y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);

accv0 = vaddq_s32(accv0, vi);
}

// high half
for (int l = 4; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const int32x4_t vi = vcvtnq_s32_f32(v);

@@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);

accv = vaddq_s32(accv, vi);
accv1 = vaddq_s32(accv1, vi);
}
int32_t sum = vaddvq_s32(accv);
y[i].s = d * sum;

const int32_t sum0 = vaddvq_s32(accv0);
const int32_t sum1 = vaddvq_s32(accv1);

y[i].s0 = d * sum0;
y[i].s1 = d * sum1;
}
#elif defined(__AVX2__) || defined(__AVX__)
// TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
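// (added note, not in this commit) The AVX path below still stores only the
// quants and leaves s0/s1 unset, hence the TODO. A hypothetical scalar
// stopgap, assuming the block scale is held in a local `d` as in the
// reference path, would be one extra pass per block after the stores:
//
//     int sum0 = 0, sum1 = 0;
//     for (int l = 0; l < QK8_0/2; ++l) {
//         sum0 += y[i].qs[l];
//         sum1 += y[i].qs[QK8_0/2 + l];
//     }
//     y[i].s0 = d * sum0;
//     y[i].s1 = d * sum1;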
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@@ -2395,7 +2425,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];

sum8 += x0->d * y0->s + x1->d * y1->s;
sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
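// (added note) q4_0 dequantizes as x_i = d * (nibble_i - 8), so the -8 offset
// only ever multiplies the whole-block sum of the q8 quants; splitting s into
// s0/s1 changes nothing here beyond re-adding the two halves.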

const uint8x16_t m4b = vdupq_n_u8(0xf);

@@ -2562,7 +2592,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];

summs += x0->m * y0->s + x1->m * y1->s;
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
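// (added note) q4_1 dequantizes as x_i = d * nibble_i + m, so the offset m
// contributes m * dy * sum(q8 quants) per block, i.e. m * (s0 + s1).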

const uint8x16_t m4b = vdupq_n_u8(0xf);

@@ -2589,8 +2619,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *

#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
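// (added note) the dot products now pair the masked nibbles (v0_0lz, v0_0hz)
// directly with the q8 quants as loaded (v1_0l, v1_0h); the constant offset
// is carried entirely by the summs term above.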

sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
@@ -2845,6 +2875,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);

float summs = 0.0f;

for (int i = 0; i < nb; i += 2) {
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
@@ -2854,18 +2886,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = &y[i + 0];
const block_q8_0 * restrict y1 = &y[i + 1];

summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1;
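// (added note) a q4_3 block covers 16 values while a q8_0 block covers 32,
// so x0_0 pairs with the low half of y0 (offset sum y0->s0) and x0_1 with
// the high half (y0->s1); likewise x1_0/x1_1 against y1.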

const uint8x16_t m4b = vdupq_n_u8(0xf);

const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);

const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);

const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));

@@ -2887,17 +2917,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);

const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));

const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));

sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);

#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
@@ -2926,7 +2945,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
#endif
}

*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
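// (added note) the vector accumulators hold only the d*d integer dot part;
// the affine offset contributions are added once at the end via summs.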
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();