Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d2f12ac
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
9fe2a2b
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
1f6195c
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
aebd547
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
2b2ab31
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
bcf8c5c
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
c6c3536
k_quants: WIP super-blocks with 64 weights
Jun 21, 2023
5aae4b8
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
41e46ec
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
460dd84
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
3bd9ae7
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
03f30c8
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
cda47a6
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
80c75fe
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
2b2a13c
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
9d27d8d
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
2ff543c
k_quants: WIP super-blocks with 64 weights
Jun 22, 2023
d92c5a9
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
fae24af
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
e1bbcfc
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
167a0bb
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
6081a65
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
ff83e32
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
285eeb1
k_quants: WIP super-blocks with 64 weights
Jun 23, 2023
8b98d01
k_quants: call them _K, not _k, also on Metal
Jun 23, 2023
558a194
k_quants: correctly define QK_K in llama.cpp
Jun 23, 2023
333ffcc
Fixed bug in q4_K quantization added with the 64-block addition
Jun 23, 2023
88412a1
Simplify via lambda
Jun 23, 2023
aeefd4e
k_quants: switch Q3_K to 4-bit scales when QK_K = 64
Jun 24, 2023
ce19b96
k_quants: switch Q4_K to 4-bit scales when QK_K = 64
Jun 24, 2023
4f61506
k_quants: forgot to add the Metal changes in last commit
Jun 24, 2023
ccf4901
k_quants: change Q5_K to be type 0 when QK_K = 64
Jun 24, 2023
2da3a59
k_quants: AVX2 implementation for new 64-weight Q5_K
Jun 24, 2023
53e81ca
k_quants: 10% faster ARM_NEON Q5_K dot product
Jun 24, 2023
5fd8337
k_quants: fixed issue caused by merging with master
Jun 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
k_quants: WIP super-blocks with 64 weights
Q3_K works on Metal and is slightly faster
than QK_K = 256 (26.6 ms vs 28.3 ms).
  • Loading branch information
Iwan Kawrakow committed Jun 26, 2023
commit ff83e32c6ab4b6500625df347750c896cffe4bb5
88 changes: 77 additions & 11 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -798,10 +798,13 @@ typedef struct {
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
half d; // super-block scale
#if QK_K == 64
int8_t scales[K_SCALE_SIZE];
#else
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
#endif
half d; // super-block scale
} block_q3_k;
// 110 bytes / block

#if QK_K == 64
typedef struct {
Expand Down Expand Up @@ -903,6 +906,8 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
assert(k % QK_K == 0);
const int nb = k / QK_K;

#if QK_K == 256

const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;

Expand Down Expand Up @@ -948,8 +953,34 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
}
q += 32;
}
}
#else
for (int i = 0; i < nb; i++) {

const float d_all = (float)(x[i].d);

device const uint8_t * q = x[i].qs;
device const uint8_t * hm = x[i].hmask;

const float d1 = d_all * x[i].scales[0];
const float d2 = d_all * x[i].scales[1];
const float d3 = d_all * x[i].scales[2];
const float d4 = d_all * x[i].scales[3];

for (int l = 0; l < 8; ++l) {
uint8_t h = hm[l];
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
}
y += QK_K;
}
#endif

}

Expand Down Expand Up @@ -1286,6 +1317,8 @@ kernel void kernel_mul_mat_q3_k_f32(
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;

#if QK_K == 256

const int tid = tpitg.y; // expecting 16
const int ip = tid/8; // 0 or 1
const int il = tid/2 - 4*ip; // 0...3
Expand Down Expand Up @@ -1339,6 +1372,39 @@ kernel void kernel_mul_mat_q3_k_f32(

//sum[ith] = sumf;
sum[ith] = sumf1 - 32.f*sumf2;
#else
const int il = 4 * tpitg.x; // 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1
const int in = il%8; // 0, 4, 0, 4

float sumf = 0;

for (int i = tpitg.y; i < nb; i += tptg.y) {

const float d_all = (float)(x[i].d);

device const uint8_t * q = x[i].qs + il;
device const uint8_t * h = x[i].hmask + in;
device const float * y = yy + i * QK_K + il;

const float d1 = d_all * x[i].scales[0];
const float d2 = d_all * x[i].scales[1];
const float d3 = d_all * x[i].scales[2];
const float d4 = d_all * x[i].scales[3];

for (int l = 0; l < 4; ++l) {
const uint8_t hm = h[l] >> im;
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
}

}

sum[ith] = sumf;

#endif

//
// Accumulate the sum from all threads in the threadgroup
Expand Down Expand Up @@ -1371,10 +1437,6 @@ kernel void kernel_mul_mat_q4_k_f32(
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {

const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;

const int nb = ne00/QK_K;

const int64_t r0 = tgpig.x;
Expand All @@ -1390,6 +1452,10 @@ kernel void kernel_mul_mat_q4_k_f32(

#if QK_K == 256

const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;

const int tid = tpitg.y; // 0...16
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
Expand Down Expand Up @@ -1660,10 +1726,10 @@ kernel void kernel_mul_mat_q6_k_f32(

float4 sums = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 4; ++l) {
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & 0x03) << 4)) - 32);
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & 0x0c) << 2)) - 32);
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & 0x30) >> 0)) - 32);
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & 0xc0) >> 2)) - 32);
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
}
sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
}
Expand Down