Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Rearrange Q4_1 quantization to work for multipart models. (Fix #152)
  • Loading branch information
blackhole89 committed Mar 16, 2023
commit a2e9d4951bbbdd180e8dbc60ce7b1bbfcf5a423f
67 changes: 37 additions & 30 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0);

const int nb = k / QK;
const size_t bs = 2*sizeof(float) + QK/2;

float * restrict pm = (float *) (y);
float * restrict pd = (float *) (pm + nb);
uint8_t * restrict pb = (uint8_t *) (pd + nb);
uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));

uint8_t pp[QK/2];

Expand All @@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;

pm[i] = min;
pd[i] = d;
*(float *)pm = min;
*(float *)pd = d;
pm += bs;
pd += bs;

for (int l = 0; l < QK; l += 2) {
const float v0 = (x[i*QK + l + 0] - min)*id;
Expand All @@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
pp[l/2] = vi0 | (vi1 << 4);
}

memcpy(pb + i*QK/2, pp, sizeof(pp));
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
}

Expand Down Expand Up @@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
assert(k % QK == 0);

const int nb = k / QK;
const size_t bs = 2*sizeof(float) + QK/2;

const float * restrict pm = (const float *) (x);
const float * restrict pd = (const float *) (pm + nb);
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));

for (int i = 0; i < nb; i++) {
const float m = pm[i];
const float d = pd[i];
const float d = *(const float *) (pd + i*bs);
const float m = *(const float *) (pm + i*bs);

const uint8_t * restrict pp = pb + i*QK/2;
const uint8_t * restrict pp = pb + i*bs;

for (int l = 0; l < QK; l += 2) {
const uint8_t vi = pp[l/2];
Expand Down Expand Up @@ -1584,14 +1589,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
const int nb = n / QK;

const float * restrict pm0 = (const float *) x;
const float * restrict pm1 = (const float *) y;
const size_t bs = 2*sizeof(float) + QK/2;

const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);

const float * restrict pd0 = (const float *) (pm0 + nb);
const float * restrict pd1 = (const float *) (pm1 + nb);
const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));

const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));

float sumf = 0.0;

Expand All @@ -1604,14 +1611,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void

// Main loop
for (int i = 0; i < nb; ++i) {
const float * m0 = (const float *) (pm0 + i);
const float * m1 = (const float *) (pm1 + i);
const float * m0 = (const float *) (pm0 + i*bs);
const float * m1 = (const float *) (pm1 + i*bs);

const float * d0 = (const float *) (pd0 + i);
const float * d1 = (const float *) (pd1 + i);
const float * d0 = (const float *) (pd0 + i*bs);
const float * d1 = (const float *) (pd1 + i*bs);

const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;

const __m256 d0v = _mm256_broadcast_ss( d0 );
const __m256 d1v = _mm256_broadcast_ss( d1 );
Expand Down Expand Up @@ -1677,14 +1684,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
#else
// scalar
for (int i = 0; i < nb; i++) {
const float m0 = pm0[i];
const float m1 = pm1[i];
const float * m0 = (const float *) (pm0 + i*bs);
const float * m1 = (const float *) (pm1 + i*bs);

const float d0 = pd0[i];
const float d1 = pd1[i];
const float * d0 = (const float *) (pd0 + i*bs);
const float * d1 = (const float *) (pd1 + i*bs);

const uint8_t * restrict p0 = pb0 + i*QK/2;
const uint8_t * restrict p1 = pb1 + i*QK/2;
const uint8_t * restrict p0 = pb0 + i*bs;
const uint8_t * restrict p1 = pb1 + i*bs;

for (int j = 0; j < QK/2; j++) {
const uint8_t v0 = p0[j];
Expand Down
20 changes: 12 additions & 8 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t

size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
const int nb = k / qk;
const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*qk/2);
const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2);
const size_t row_size = nb*bs;

assert(k % qk == 0);

Expand All @@ -498,10 +499,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t

char * pdst = (char *) dst;

for (int j = 0; j < n; j += k) {
float * pm = (float *) (pdst + (j/k)*row_size);
float * pd = (float *) (pm + nb);
uint8_t * pb = (uint8_t *) (pd + nb);
for (int j = 0; j < n; j += k) {
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));

//printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);

Expand All @@ -519,8 +520,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;

pm[i] = min;
pd[i] = d;
*(float *) pd = d;
*(float *) pm = min;
pd += bs;
pm += bs;

for (int l = 0; l < qk; l += 2) {
const float v0 = (src[j + i*qk + l + 0] - min)*id;
Expand All @@ -538,7 +541,8 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
pp[l/2] = vi0 | (vi1 << 4);
}

memcpy(pb + i*qk/2, pp, pp_size);
memcpy(pb, pp, pp_size);
pb += bs;
}
}
}
Expand Down