Skip to content

Commit d15cf56

Browse files
committed
ggml-quants : use k_sort with IQ2_S, IQ2_XS, and IQ2_XXS
* ggml-quants : use a better distance function in k_sort iq2 neighbour search In practice it seems like the previously-used formula works quite well.
1 parent d13319b commit d15cf56

File tree

1 file changed

+73
-70
lines changed

1 file changed

+73
-70
lines changed

ggml/src/ggml-quants.c

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4324,7 +4324,7 @@ void iq2xs_init_impl(enum ggml_type type) {
43244324
};
43254325

43264326
static const int8_t kvalues_iq2[3] = { 0x08, 0x19, 0x2b };
4327-
static const int8_t kvalues_iq1[3] = { -1, 0, 1 };
4327+
static const int8_t kvalues_iq1[3] = { -8, 0, 8 };
43284328

43294329
// alternatively, this could be 0xAAAA = 43690, but that would be much bigger unnecessarily.
43304330
const int kmap_size = pow3[8]; // 3**8 = 6561
@@ -4391,7 +4391,7 @@ void iq2xs_init_impl(enum ggml_type type) {
43914391
const int8_t * pg = (const int8_t *)(grid + j);
43924392
int32_t d2 = 0;
43934393
for (int k = 0; k < 8; ++k) {
4394-
const int32_t d = p[k] - pg[k];
4394+
const int32_t d = pg[k] - p[k];
43954395
d2 += d * d;
43964396
}
43974397
if (d2 < min_d2) {
@@ -4424,7 +4424,7 @@ void iq2xs_init_impl(enum ggml_type type) {
44244424
const int8_t * pg = (const int8_t *)(grid + j);
44254425
int32_t d2 = 0;
44264426
for (int k = 0; k < 8; ++k) {
4427-
const int32_t d = p[k] - pg[k];
4427+
const int32_t d = pg[k] - p[k];
44284428
d2 += d * d;
44294429
}
44304430
if (d2 < min_d2) {
@@ -4438,6 +4438,7 @@ void iq2xs_init_impl(enum ggml_type type) {
44384438
}
44394439
}
44404440
GGML_ASSERT(kmap[i] < 0);
4441+
// reserve -1 for when there is no neighbour
44414442
kmap[i] = -(offset) - 2;
44424443
offset += min_count + 1;
44434444
GGML_ASSERT(min_count == kneighbour_counts[i]);
@@ -4470,7 +4471,8 @@ static int iq2_find_relative_neighbour(const struct k_sort * k_sort,
44704471
const float * GGML_RESTRICT weight,
44714472
const int8_t * GGML_RESTRICT L,
44724473
float * GGML_RESTRICT sumqx,
4473-
float * GGML_RESTRICT sumq2) {
4474+
float * GGML_RESTRICT sumq2,
4475+
int grid_offset) {
44744476
const int pow3[9] = { 1, 3, 9, 27, 81, 243, 729, 2187, 6561 };
44754477
int index = 0;
44764478
int8_t p[8];
@@ -4484,30 +4486,38 @@ static int iq2_find_relative_neighbour(const struct k_sort * k_sort,
44844486

44854487
float sumqx_new = 0.0f;
44864488
float sumq2_new = 0.0f;
4487-
float best = -1.0f;
4488-
float best_denom = 1.0f;
4489+
float best_d2 = FLT_MAX;
44894490

44904491
if (grid_index < -1) {
44914492
const uint16_t * neighbours = kneighbours - (grid_index + 2);
44924493
const int num_neighbours = neighbours[0];
44934494

4495+
float prev_sumqx = 0.0f;
4496+
float prev_sumq2 = 0.0f;
4497+
float waux[8];
4498+
for (int k = 0; k < 8; ++k) {
4499+
prev_sumqx += weight[k] * (xval[k] * p[k]);
4500+
prev_sumq2 += weight[k] * (p[k] * p[k]);
4501+
waux[k] = sqrtf(weight[k]);
4502+
}
4503+
const float prev_scale = prev_sumq2 > 0.0f ? prev_sumqx / prev_sumq2 : 0.0f;
4504+
44944505
for (int i = 1; i <= num_neighbours; ++i) {
44954506
const int8_t * pg = (const int8_t *)(grid + neighbours[i]);
44964507
float this_sumqx = 0.0f;
44974508
float this_sumq2 = 0.0f;
4509+
float d2 = 0.0f;
44984510
for (int k = 0; k < 8; ++k) {
4499-
const float odd = pg[k] + p[k];
4500-
const float step = pg[k] - p[k];
4511+
const float odd = (grid_offset + pg[k]) + p[k];
4512+
const float step = (grid_offset + pg[k]) - p[k];
4513+
const float diff = prev_scale * (grid_offset + pg[k]) - xval[k];
45014514
this_sumqx += weight[k] * (xval[k] * step);
45024515
this_sumq2 += weight[k] * (odd * step);
4516+
d2 += waux[k] * diff * diff;
45034517
}
45044518

4505-
const float total_sumqx = this_sumqx + (*sumqx);
4506-
const float total_sumq2 = this_sumq2 + (*sumq2);
4507-
const float current = total_sumqx * total_sumq2;
4508-
if (total_sumq2 > 0.0f && current * best_denom > best * total_sumq2) {
4509-
best = current;
4510-
best_denom = total_sumq2;
4519+
if (d2 < best_d2) {
4520+
best_d2 = d2;
45114521
sumqx_new = this_sumqx;
45124522
sumq2_new = this_sumq2;
45134523
grid_index = neighbours[i];
@@ -4589,7 +4599,7 @@ static float make_iq2_quants(int n, struct k_sort * k_sort, const uint64_t * gri
45894599
best_sumqx = sumqx;
45904600
best_sumq2 = sumq2;
45914601
} else {
4592-
best = -1.0f;
4602+
best = 0.0f;
45934603
best_sumqx = 0.0f;
45944604
best_sumq2 = 1.0f;
45954605
}
@@ -4604,27 +4614,20 @@ static float make_iq2_quants(int n, struct k_sort * k_sort, const uint64_t * gri
46044614
sumq2 += w * (odd * step);
46054615
Laux[ii] = k_i;
46064616

4617+
const int grid_index = iq2_find_relative_neighbour(k_sort, grid, kmap, kneighbours, xval + 8*g_i, weight + 8*g_i, Laux + 8*g_i, sumqx_aux + g_i, sumq2_aux + g_i, 0);
4618+
4619+
if (grid_index == grid_idx_aux[g_i]) { continue; }
4620+
if (grid_index < 0) { break; }
4621+
4622+
grid_idx_aux[g_i] = grid_index;
4623+
46074624
// avoid subtraction numerical instabilities by having relative sumqx and sumq2 per grid index
46084625
float sumqx_cur = sumqx;
46094626
float sumq2_cur = sumq2;
4610-
sumqx_aux[g_i] = 0.0f;
4611-
sumq2_aux[g_i] = 0.0f;
46124627
for (int j = 0; j < n_idx; ++j) {
46134628
sumqx_cur += sumqx_aux[j];
46144629
sumq2_cur += sumq2_aux[j];
46154630
}
4616-
sumqx_aux[g_i] = sumqx_cur;
4617-
sumq2_aux[g_i] = sumq2_cur;
4618-
4619-
const int grid_index = iq2_find_relative_neighbour(k_sort, grid, kmap, kneighbours, xval + 8*g_i, weight + 8*g_i, Laux + 8*g_i, sumqx_aux + g_i, sumq2_aux + g_i);
4620-
4621-
sumqx_cur += sumqx_aux[g_i];
4622-
sumq2_cur += sumq2_aux[g_i];
4623-
4624-
if (grid_index == grid_idx_aux[g_i]) { continue; }
4625-
if (grid_index < 0) { break; }
4626-
4627-
grid_idx_aux[g_i] = grid_index;
46284631

46294632
const float current = sumqx_cur * sumqx_cur;
46304633
if (sumq2_cur > 0.0f && current * best_sumq2 > best * sumq2_cur) {
@@ -4655,12 +4658,18 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
46554658
GGML_ASSERT(n%QK_K == 0);
46564659

46574660
const int8_t k_values_iq2xxs[3] = { 0x08, 0x19, 0x2b };
4661+
// the quantized scales are in {0.125, 0.375, 0.625, ... }
4662+
// which are the odd numbers divided by 8
4663+
const int8_t k_values_iq2xxs_s[16] = { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 };
46584664

46594665
const int64_t nbl = n/QK_K;
46604666

46614667
block_iq2_xxs * y = vy;
46624668

46634669
float scales[QK_K/32];
4670+
float sw[QK_K/32];
4671+
int8_t Ls[QK_K/32];
4672+
int8_t Lsaux[QK_K/32];
46644673
float weight[32];
46654674
float xval[32];
46664675
int8_t Laux[32];
@@ -4672,29 +4681,37 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
46724681
uint32_t q2[2*(QK_K/32)];
46734682
struct k_sort k_sort;
46744683
uint8_t buf[K_SORT_BUF_SIZE_NL(32, 3, 2)];
4684+
struct k_sort k_sort_s;
4685+
uint8_t buf_s[K_SORT_BUF_SIZE_NL(QK_K/32, 16, 15)];
46754686

46764687
k_sort_init(&k_sort, 32, 3, k_values_iq2xxs, buf);
4688+
k_sort_init(&k_sort_s, QK_K/32, 16, k_values_iq2xxs_s, buf_s);
46774689

46784690
for (int ibl = 0; ibl < nbl; ++ibl) {
46794691

46804692
y[ibl].d = GGML_FP32_TO_FP16(0.f);
46814693
memset(q2, 0, QK_K/4);
46824694

4683-
float max_scale = 0;
4684-
46854695
const float * xbl = x + QK_K*ibl;
46864696
float sumx2 = 0;
46874697
for (int i = 0; i < QK_K; ++i) {
46884698
sumx2 += xbl[i]*xbl[i];
46894699
}
4690-
float sigma2 = sumx2/QK_K;
4700+
float sigma2 = 2*sumx2/QK_K;
46914701

46924702
for (int ib = 0; ib < QK_K/32; ++ib) {
46934703
const float * xb = xbl + 32*ib;
46944704
const float * qw = quant_weights + QK_K*ibl + 32*ib;
46954705
for (int i = 0; i < 32; ++i) {
46964706
weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
46974707
}
4708+
{
4709+
float sumw = 0.0f;
4710+
for (int i = 0; i < 32; ++i) {
4711+
sumw += weight[i];
4712+
}
4713+
sw[ib] = sumw;
4714+
}
46984715
for (int k = 0; k < 4; ++k) {
46994716
int nflip = 0;
47004717
uint8_t s = 0;
@@ -4734,21 +4751,12 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
47344751
}
47354752
GGML_ASSERT(scale >= 0);
47364753
scales[ib] = scale;
4737-
max_scale = MAX(max_scale, scale);
47384754
}
47394755

4740-
if (!max_scale) {
4741-
memset(y[ibl].qs, 0, QK_K/4);
4742-
continue;
4743-
}
4744-
4745-
// TODO: use make_qkxs_quants here
4746-
float d = max_scale/31;
4747-
y[ibl].d = GGML_FP32_TO_FP16(d);
4748-
float id = 1/d;
4756+
const float d = make_qkxs_nl_quants(QK_K/32, scales, sw, Ls, Lsaux, &k_sort_s, false, true);
4757+
y[ibl].d = GGML_FP32_TO_FP16(d * 8.0f);
47494758
for (int ib = 0; ib < QK_K/32; ++ib) {
4750-
int l = nearest_int(0.5f*(id*scales[ib]-1));
4751-
l = MAX(0, MIN(15, l));
4759+
const uint8_t l = Ls[ib];
47524760
q2[2*ib+1] |= ((uint32_t)l << 28);
47534761
}
47544762
memcpy(y[ibl].qs, q2, QK_K/4);
@@ -4770,12 +4778,18 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
47704778
GGML_ASSERT(n%QK_K == 0);
47714779

47724780
const int8_t k_values_iq2xs[3] = { 0x08, 0x19, 0x2b };
4781+
// the quantized scales are in {0.125, 0.375, 0.625, ... }
4782+
// which are the odd numbers divided by 8
4783+
const int8_t k_values_iq2xs_s[16] = { 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 };
47734784

47744785
const int64_t nbl = n/QK_K;
47754786

47764787
block_iq2_xs * y = vy;
47774788

47784789
float scales[QK_K/16];
4790+
float sw[QK_K/16];
4791+
int8_t Ls[QK_K/16];
4792+
int8_t Lsaux[QK_K/16];
47794793
float weight[16];
47804794
float xval[16];
47814795
int8_t Laux[16];
@@ -4787,30 +4801,38 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
47874801
uint16_t q2[2*(QK_K/16)];
47884802
struct k_sort k_sort;
47894803
uint8_t buf[K_SORT_BUF_SIZE_NL(16, 3, 2)];
4804+
struct k_sort k_sort_s;
4805+
uint8_t buf_s[K_SORT_BUF_SIZE_NL(QK_K/16, 16, 15)];
47904806

47914807
k_sort_init(&k_sort, 16, 3, k_values_iq2xs, buf);
4808+
k_sort_init(&k_sort_s, QK_K/16, 16, k_values_iq2xs_s, buf_s);
47924809

47934810
for (int ibl = 0; ibl < nbl; ++ibl) {
47944811

47954812
y[ibl].d = GGML_FP32_TO_FP16(0.f);
47964813
memset(q2, 0, QK_K/4);
47974814
memset(y[ibl].scales, 0, QK_K/32);
47984815

4799-
float max_scale = 0;
4800-
48014816
const float * xbl = x + QK_K*ibl;
48024817
float sumx2 = 0;
48034818
for (int i = 0; i < QK_K; ++i) {
48044819
sumx2 += xbl[i]*xbl[i];
48054820
}
4806-
float sigma2 = sumx2/QK_K;
4821+
float sigma2 = 2*sumx2/QK_K;
48074822

48084823
for (int ib = 0; ib < QK_K/16; ++ib) {
48094824
const float * xb = xbl + 16*ib;
48104825
const float * qw = quant_weights + QK_K*ibl + 16*ib;
48114826
for (int i = 0; i < 16; ++i) {
48124827
weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
48134828
}
4829+
{
4830+
float sumw = 0.0f;
4831+
for (int i = 0; i < 16; ++i) {
4832+
sumw += weight[i];
4833+
}
4834+
sw[ib] = sumw;
4835+
}
48144836
for (int k = 0; k < 2; ++k) {
48154837
int nflip = 0;
48164838
uint8_t s = 0;
@@ -4845,23 +4867,14 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
48454867
const int grid_index = grid_idx[k];
48464868
q2[2*ib+k] = grid_index | (block_signs[k] << 9);
48474869
}
4848-
GGML_ASSERT(scale >= 0);
4870+
GGML_ASSERT(scale >= 0.0f);
48494871
scales[ib] = scale;
4850-
max_scale = MAX(max_scale, scale);
4851-
}
4852-
4853-
if (!max_scale) {
4854-
memset(y[ibl].qs, 0, QK_K/4);
4855-
continue;
48564872
}
48574873

4858-
// TODO: maybe use make_qkxs_quants here?
4859-
float d = max_scale/31;
4860-
y[ibl].d = GGML_FP32_TO_FP16(d);
4861-
float id = 1/d;
4874+
const float d = make_qkxs_nl_quants(QK_K/16, scales, sw, Ls, Lsaux, &k_sort_s, false, true);
4875+
y[ibl].d = GGML_FP32_TO_FP16(d * 8.0f);
48624876
for (int ib = 0; ib < QK_K/16; ++ib) {
4863-
int l = nearest_int(0.5f*(id*scales[ib]-1));
4864-
l = MAX(0, MIN(15, l));
4877+
const uint8_t l = Ls[ib];
48654878
if (ib % 2 == 0) {
48664879
y[ibl].scales[ib / 2] = l;
48674880
} else {
@@ -6289,8 +6302,6 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R
62896302
memset(&y[ibl], 0, sizeof(block_iq2_s));
62906303
y[ibl].d = GGML_FP32_TO_FP16(0.f);
62916304

6292-
float max_scale = 0;
6293-
62946305
const float * xbl = x + QK_K*ibl;
62956306
float sumx2 = 0;
62966307
for (int i = 0; i < QK_K; ++i) {
@@ -6341,19 +6352,11 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R
63416352

63426353
GGML_ASSERT(scale >= 0);
63436354
scales[ib] = scale;
6344-
max_scale = MAX(max_scale, scale);
6345-
}
6346-
6347-
if (!max_scale) {
6348-
continue;
63496355
}
63506356

63516357
const float d = make_qkxs_nl_quants(QK_K/16, scales, sw, Ls, Lsaux, &k_sort_s, false, true);
63526358
y[ibl].d = GGML_FP32_TO_FP16(d * 8.0f);
6353-
// float id = 1/d;
63546359
for (int ib = 0; ib < QK_K/16; ++ib) {
6355-
// int l = nearest_int(0.5f*(8*id*scales[ib]-1));
6356-
// l = MAX(0, MIN(15, l));
63576360
const uint8_t l = Ls[ib];
63586361
if (ib % 2 == 0) {
63596362
y[ibl].scales[ib / 2] = l;

0 commit comments

Comments
 (0)