Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Another minor improvement
  • Loading branch information
Iwan Kawrakow committed Aug 22, 2023
commit e9f1340c20952713895028c9d76b71abbf999735
28 changes: 17 additions & 11 deletions k_quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
}

static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
uint8_t * restrict Laux) {
uint8_t * restrict Laux, bool use_mad) {
float min = x[0];
float max = x[0];
float sum_x = 0, sum_x2 = 0;
Expand All @@ -288,13 +288,17 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t
float num = sum_x2 * n - sum_x * sum_x * n / (n-1);
float iscale = nmax/(max - min);
float scale = 1/iscale;
float best_mse = 0;
float best_mad = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
L[i] = MAX(0, MIN(nmax, l));
float diff = scale * L[i] + min - x[i];
float w = x[i] * x[i];
best_mse += w * diff * diff;
if (use_mad) {
best_mad += w * fabsf(diff);
} else {
best_mad += w * diff * diff;
}
}
if (num <= 0) {
*the_min = -min;
Expand All @@ -318,17 +322,21 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, uint8_t
this_min = 0;
this_scale = sqrtf(sum_x2 / sum_l2);
}
float mse = 0;
float mad = 0;
for (int i = 0; i < n; ++i) {
float diff = this_scale * Laux[i] + this_min - x[i];
float w = x[i] * x[i];
mse += w * diff * diff;
if (use_mad) {
mad += w * fabsf(diff);
} else {
mad += w * diff * diff;
}
}
if (mse < best_mse) {
if (mad < best_mad) {
for (int i = 0; i < n; ++i) {
L[i] = Laux[i];
}
best_mse = mse;
best_mad = mad;
scale = this_scale;
min = this_min;
}
Expand Down Expand Up @@ -368,7 +376,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
float max_min = 0;
for (int j = 0; j < QK_K/16; ++j) {
//scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5, 0.f);
scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux);
scales[j] = make_qkx2_quants(16, 3, x + 16*j, L + 16*j, &mins[j], Laux, false);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
Expand Down Expand Up @@ -724,7 +732,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
//scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux);
scales[j] = make_qkx2_quants(32, 15, x + 32*j, L + 32*j, &mins[j], Laux, true);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
Expand Down Expand Up @@ -875,7 +883,6 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict

#if QK_K == 256
uint8_t L[QK_K];
//uint8_t Laux[32];
float mins[QK_K/32];
float scales[QK_K/32];
#else
Expand All @@ -891,7 +898,6 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
//scales[j] = make_qkx2_quants(32, 31, x + 32*j, L + 32*j, &mins[j], Laux);
float scale = scales[j];
if (scale > max_scale) {
max_scale = scale;
Expand Down