@@ -692,13 +692,17 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
692692
693693 for (int i = 0 ; i < nb ; i ++ ) {
694694 float amax = 0.0f ; // absolute max
695+ float max = 0.0f ;
695696
696697 for (int l = 0 ; l < QK4_0 ; l ++ ) {
697698 const float v = x [i * QK4_0 + l ];
698- amax = MAX (amax , fabsf (v ));
699+ if (amax < fabsf (v )) {
700+ amax = fabsf (v );
701+ max = v ;
702+ }
699703 }
700704
701- const float d = amax / (( 1 << 3 ) - 1 ) ;
705+ const float d = max / -8 ;
702706 const float id = d ? 1.0f /d : 0.0f ;
703707
704708 y [i ].d = d ;
@@ -707,8 +711,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
707711 const float v0 = x [i * QK4_0 + l + 0 ]* id ;
708712 const float v1 = x [i * QK4_0 + l + 1 ]* id ;
709713
710- const uint8_t vi0 = ( int8_t )roundf (v0 ) + 8 ;
711- const uint8_t vi1 = ( int8_t )roundf (v1 ) + 8 ;
714+ const uint8_t vi0 = MIN ( 15 , ( int8_t )roundf (v0 ) + 8 ) ;
715+ const uint8_t vi1 = MIN ( 15 , ( int8_t )roundf (v1 ) + 8 ) ;
712716
713717 assert (vi0 < 16 );
714718 assert (vi1 < 16 );
@@ -728,28 +732,42 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
728732
729733#if defined(__POWER9_VECTOR__ )
730734 const vector float v85 = vec_splats (8.5f );
735+ const vector signed int v15 = vec_splats (15 );
731736 for (int i = 0 ; i < nb ; i ++ ) {
732- float amax = 0.0f ; // absolute max
737+ float max = 0.0f ;
738+ float min = 0.0f ;
733739
734740 vector float srcv [8 ];
735- vector float asrcv [8 ];
736- vector float amaxv [8 ];
741+ vector float maxv [8 ];
742+ vector float minv [8 ];
737743
738744 for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = * (vector float * )(x + i * 32 + 4 * l );
739- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vec_abs (srcv [l ]);
740-
741- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
742- //for (int l = 0; l < 2; l++) amaxv[4*l] = vec_max(amaxv[4*l], amaxv[4*l+2]);
743- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [2 ]);
744- amaxv [4 ] = vec_max (amaxv [4 ], amaxv [6 ]);
745- //for (int l = 0; l < 1; l++) amaxv[8*l] = vec_max(amaxv[8*l], amaxv[8*l+4]);
746- amaxv [0 ] = vec_max (amaxv [0 ], amaxv [4 ]);
747-
748- amax = MAX (
749- MAX (vec_extract (amaxv [0 ], 0 ), vec_extract (amaxv [0 ], 1 )),
750- MAX (vec_extract (amaxv [0 ], 2 ), vec_extract (amaxv [0 ], 3 )));
751-
752- const float d = amax / ((1 << 3 ) - 1 );
745+ //for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
746+
747+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vec_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
748+ //for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
749+ maxv [0 ] = vec_max (maxv [0 ], maxv [2 ]);
750+ maxv [4 ] = vec_max (maxv [4 ], maxv [6 ]);
751+ //for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
752+ maxv [0 ] = vec_max (maxv [0 ], maxv [4 ]);
753+
754+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vec_min (asrcv [2 * l ], asrcv [2 * l + 1 ]);
755+ //for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
756+ minv [0 ] = vec_min (minv [0 ], minv [2 ]);
757+ minv [4 ] = vec_min (minv [4 ], minv [6 ]);
758+ //for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
759+ minv [0 ] = vec_min (minv [0 ], minv [4 ]);
760+
761+
762+ max = MAX (
763+ MAX (vec_extract (maxv [0 ], 0 ), vec_extract (maxv [0 ], 1 )),
764+ MAX (vec_extract (maxv [0 ], 2 ), vec_extract (maxv [0 ], 3 )));
765+ min = MIN (
766+ MIN (vec_extract (minv [0 ], 0 ), vec_extract (minv [0 ], 1 )),
767+ MIN (vec_extract (minv [0 ], 2 ), vec_extract (minv [0 ], 3 )));
768+
769+ const float magnitude = max >= fabsf (min ) ? max : min ;
770+ const float d = magnitude / -8 ;
753771 const float id = d ? 1.0 /d : 0.0 ;
754772
755773 y [i ].d = d ;
@@ -759,27 +777,33 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
759777 for (int l = 0 ; l < 8 ; l ++ ) {
760778 const vector float vf = vec_madd (srcv [l ], vid , v85 );
761779 const vector signed int vi = vec_signed (vf );
780+ const vector signed int vc = vec_min (vi , v15 );
762781
763- pb [2 * l + 0 ] = vec_extract (vi , 0 ) | (vec_extract (vi , 1 ) << 4 );
764- pb [2 * l + 1 ] = vec_extract (vi , 2 ) | (vec_extract (vi , 3 ) << 4 );
782+ pb [2 * l + 0 ] = vec_extract (vc , 0 ) | (vec_extract (vc , 1 ) << 4 );
783+ pb [2 * l + 1 ] = vec_extract (vc , 2 ) | (vec_extract (vc , 3 ) << 4 );
765784 }
766785 }
767786#elif __ARM_NEON
768787 for (int i = 0 ; i < nb ; i ++ ) {
769788 float32x4_t srcv [8 ];
770- float32x4_t asrcv [8 ];
771- float32x4_t amaxv [8 ];
789+ float32x4_t maxv [8 ];
790+ float32x4_t minv [8 ];
772791
773792 for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = vld1q_f32 (x + i * 32 + 4 * l );
774- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = vabsq_f32 (srcv [l ]);
775793
776- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = vmaxq_f32 (asrcv [2 * l ], asrcv [2 * l + 1 ]);
777- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = vmaxq_f32 (amaxv [4 * l ], amaxv [4 * l + 2 ]);
778- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = vmaxq_f32 (amaxv [8 * l ], amaxv [8 * l + 4 ]);
794+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = vmaxq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
795+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = vmaxq_f32 (maxv [4 * l ], maxv [4 * l + 2 ]);
796+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = vmaxq_f32 (maxv [8 * l ], maxv [8 * l + 4 ]);
779797
780- const float amax = vmaxvq_f32 (amaxv [0 ]);
798+ for (int l = 0 ; l < 4 ; l ++ ) minv [2 * l ] = vminq_f32 (srcv [2 * l ], srcv [2 * l + 1 ]);
799+ for (int l = 0 ; l < 2 ; l ++ ) minv [4 * l ] = vminq_f32 (minv [4 * l ], minv [4 * l + 2 ]);
800+ for (int l = 0 ; l < 1 ; l ++ ) minv [8 * l ] = vminq_f32 (minv [8 * l ], minv [8 * l + 4 ]);
781801
782- const float d = amax / ((1 << 3 ) - 1 );
802+ const float max = vmaxvq_f32 (maxv [0 ]);
803+ const float min = vminvq_f32 (minv [0 ]);
804+
805+ const float magnitude = max >= fabsf (min ) ? max : min ;
806+ const float d = magnitude / -8 ;
783807 const float id = d ? 1.0f /d : 0.0f ;
784808
785809 y [i ].d = d ;
@@ -788,9 +812,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
788812 const float32x4_t v = vmulq_n_f32 (srcv [l ], id );
789813 const float32x4_t vf = vaddq_f32 (v , vdupq_n_f32 (8.5f ));
790814 const int32x4_t vi = vcvtq_s32_f32 (vf );
815+ const int32x4_t vc = vminq_s32 (vi , vdupq_n_s32 (15 ));
791816
792- y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vi , 0 ) | (vgetq_lane_s32 (vi , 1 ) << 4 );
793- y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vi , 2 ) | (vgetq_lane_s32 (vi , 3 ) << 4 );
817+ y [i ].qs [2 * l + 0 ] = vgetq_lane_s32 (vc , 0 ) | (vgetq_lane_s32 (vc , 1 ) << 4 );
818+ y [i ].qs [2 * l + 1 ] = vgetq_lane_s32 (vc , 2 ) | (vgetq_lane_s32 (vc , 3 ) << 4 );
794819 }
795820 }
796821#elif defined(__AVX2__ )
@@ -802,22 +827,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
802827 __m256 v3 = _mm256_loadu_ps ( x + 24 );
803828 x += 32 ;
804829
805- // Compute max(abs(e)) for the block
806- const __m256 signBit = _mm256_set1_ps ( -0.0f );
807- __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
808- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
809- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
810- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
830+ // Compute max for the block
831+ __m256 max = _mm256_max_ps ( v0 , v1 );
832+ __m256 maxTmp = _mm256_max_ps ( v2 , v3 );
833+ max = _mm256_max_ps ( max , maxTmp );
811834
812- __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
835+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( max , 1 ), _mm256_castps256_ps128 ( max ) );
813836 max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
814837 max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
815838 const float maxScalar = _mm_cvtss_f32 ( max4 );
816839
840+ // Compute min for the block
841+ __m256 min = _mm256_min_ps ( v0 , v1 );
842+ __m256 minTmp = _mm256_min_ps ( v2 , v3 );
843+ min = _mm256_min_ps ( min , minTmp );
844+
845+ __m128 min4 = _mm_min_ps ( _mm256_extractf128_ps ( min , 1 ), _mm256_castps256_ps128 ( min ) );
846+ min4 = _mm_min_ps ( min4 , _mm_movehl_ps ( min4 , min4 ) );
847+ min4 = _mm_min_ss ( min4 , _mm_movehdup_ps ( min4 ) );
848+ const float minScalar = _mm_cvtss_f32 ( min4 );
849+
817850 // Quantize these floats
818- const float d = maxScalar / 7.0f ;
851+ const float magnitude = maxScalar >= fabsf (minScalar ) ? maxScalar : minScalar ;
852+ const float d = magnitude / -8.0f ;
819853 y [i ].d = d ;
820- const float id = ( maxScalar != 0.0f ) ? 7 .0f / maxScalar : 0.0f ;
854+ const float id = ( magnitude != 0.0f ) ? -8 .0f / magnitude : 0.0f ;
821855 const __m256 mul = _mm256_set1_ps ( id );
822856
823857 // Apply the multiplier
@@ -850,9 +884,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
850884 const __m256i perm = _mm256_setr_epi32 ( 0 , 4 , 1 , 5 , 2 , 6 , 3 , 7 );
851885 i0 = _mm256_permutevar8x32_epi32 ( i0 , perm );
852886
853- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
887+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
854888 const __m256i off = _mm256_set1_epi8 ( 8 );
855889 i0 = _mm256_add_epi8 ( i0 , off );
890+ const __m256i maxNibble = _mm256_set1_epi8 ( 15 );
891+ i0 = _mm256_min_epi8 ( i0 , maxNibble );
856892
857893 // Compress the vector into 4 bit/value, and store
858894 __m128i res = packNibbles ( i0 );
@@ -867,22 +903,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
867903 __m256 v3 = _mm256_loadu_ps ( x + 24 );
868904 x += 32 ;
869905
870- // Compute max(abs(e)) for the block
871- const __m256 signBit = _mm256_set1_ps ( -0.0f );
872- __m256 maxAbs = _mm256_andnot_ps ( signBit , v0 );
873- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v1 ) );
874- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v2 ) );
875- maxAbs = _mm256_max_ps ( maxAbs , _mm256_andnot_ps ( signBit , v3 ) );
906+ // Compute max for the block
907+ __m256 max = _mm256_max_ps ( v0 , v1 );
908+ __m256 maxTmp = _mm256_max_ps ( v2 , v3 );
909+ max = _mm256_max_ps ( max , maxTmp );
876910
877- __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( maxAbs , 1 ), _mm256_castps256_ps128 ( maxAbs ) );
911+ __m128 max4 = _mm_max_ps ( _mm256_extractf128_ps ( max , 1 ), _mm256_castps256_ps128 ( max ) );
878912 max4 = _mm_max_ps ( max4 , _mm_movehl_ps ( max4 , max4 ) );
879913 max4 = _mm_max_ss ( max4 , _mm_movehdup_ps ( max4 ) );
880914 const float maxScalar = _mm_cvtss_f32 ( max4 );
881915
916+ // Compute min for the block
917+ __m256 min = _mm256_min_ps ( v0 , v1 );
918+ __m256 minTmp = _mm256_min_ps ( v2 , v3 );
919+ min = _mm256_min_ps ( min , minTmp );
920+
921+ __m128 min4 = _mm_min_ps ( _mm256_extractf128_ps ( min , 1 ), _mm256_castps256_ps128 ( min ) );
922+ min4 = _mm_min_ps ( min4 , _mm_movehl_ps ( min4 , min4 ) );
923+ min4 = _mm_min_ss ( min4 , _mm_movehdup_ps ( min4 ) );
924+ const float minScalar = _mm_cvtss_f32 ( min4 );
925+
882926 // Quantize these floats
883- const float d = maxScalar / 7.0f ;
927+ const float magnitude = maxScalar >= fabsf (minScalar ) ? maxScalar : minScalar ;
928+ const float d = magnitude / -8.0f ;
884929 y [i ].d = d ;
885- const float id = ( maxScalar != 0.0f ) ? 7 .0f / maxScalar : 0.0f ;
930+ const float id = ( magnitude != 0.0f ) ? -8 .0f / magnitude : 0.0f ;
886931 const __m256 mul = _mm256_set1_ps ( id );
887932
888933 // Apply the multiplier
@@ -923,35 +968,46 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
923968 ni0 = _mm_packs_epi16 ( ni0 , ni2 );
924969 ni4 = _mm_packs_epi16 ( ni4 , ni6 );
925970
926- // Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
927- const __m128i off = _mm_set1_epi8 ( 8 );
971+ // Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
972+ const __m128i off = _mm_set1_epi8 ( 8 );
928973 ni0 = _mm_add_epi8 ( ni0 , off );
929974 ni4 = _mm_add_epi8 ( ni4 , off );
975+ const __m128i maxNibble = _mm_set1_epi8 ( 15 );
976+ ni0 = _mm_min_epi8 ( ni0 , maxNibble );
977+ ni4 = _mm_min_epi8 ( ni4 , maxNibble );
930978
931979 // Compress the vector into 4 bit/value, and store
932980 __m128i res = packNibbles ( ni0 , ni4 );
933981 _mm_storeu_si128 ( ( __m128i * )y [i ].qs , res );
934982 }
935983#elif defined(__wasm_simd128__ )
936984 for (int i = 0 ; i < nb ; i ++ ) {
937- float amax = 0.0f ; // absolute max
985+ float max = 0.0f ;
986+ float min = 0.0f ;
938987
939988 v128_t srcv [8 ];
940- v128_t asrcv [8 ];
941- v128_t amaxv [8 ];
989+ v128_t maxv [8 ];
990+ v128_t minv [8 ];
942991
943992 for (int l = 0 ; l < 8 ; l ++ ) srcv [l ] = wasm_v128_load (x + i * 32 + 4 * l );
944- for (int l = 0 ; l < 8 ; l ++ ) asrcv [l ] = wasm_f32x4_abs (srcv [l ]);
945993
946- for (int l = 0 ; l < 4 ; l ++ ) amaxv [2 * l ] = wasm_f32x4_max (asrcv [2 * l ], asrcv [2 * l + 1 ]);
947- for (int l = 0 ; l < 2 ; l ++ ) amaxv [4 * l ] = wasm_f32x4_max (amaxv [4 * l ], amaxv [4 * l + 2 ]);
948- for (int l = 0 ; l < 1 ; l ++ ) amaxv [8 * l ] = wasm_f32x4_max (amaxv [8 * l ], amaxv [8 * l + 4 ]);
994+ for (int l = 0 ; l < 4 ; l ++ ) maxv [2 * l ] = wasm_f32x4_max (srcv [2 * l ], srcv [2 * l + 1 ]);
995+ for (int l = 0 ; l < 2 ; l ++ ) maxv [4 * l ] = wasm_f32x4_max (maxv [4 * l ], maxv [4 * l + 2 ]);
996+ for (int l = 0 ; l < 1 ; l ++ ) maxv [8 * l ] = wasm_f32x4_max (maxv [8 * l ], maxv [8 * l + 4 ]);
949997
950- amax = MAX (
951- MAX ( wasm_f32x4_extract_lane ( amaxv [ 0 ], 0 ), wasm_f32x4_extract_lane ( amaxv [ 0 ], 1 )),
952- MAX ( wasm_f32x4_extract_lane ( amaxv [ 0 ], 2 ), wasm_f32x4_extract_lane ( amaxv [ 0 ], 3 )) );
998+ for ( int l = 0 ; l < 4 ; l ++ ) minv [ 2 * l ] = wasm_f32x4_min ( srcv [ 2 * l ], srcv [ 2 * l + 1 ]);
999+ for ( int l = 0 ; l < 2 ; l ++ ) minv [ 4 * l ] = wasm_f32x4_min ( minv [ 4 * l ], minv [ 4 * l + 2 ]);
1000+ for ( int l = 0 ; l < 1 ; l ++ ) minv [ 8 * l ] = wasm_f32x4_min ( minv [ 8 * l ], minv [ 8 * l + 4 ] );
9531001
954- const float d = amax / ((1 << 3 ) - 1 );
1002+ max = MAX (
1003+ MAX (wasm_f32x4_extract_lane (maxv [0 ], 0 ), wasm_f32x4_extract_lane (maxv [0 ], 1 )),
1004+ MAX (wasm_f32x4_extract_lane (maxv [0 ], 2 ), wasm_f32x4_extract_lane (maxv [0 ], 3 )));
1005+ min = MIN (
1006+ MIN (wasm_f32x4_extract_lane (minv [0 ], 0 ), wasm_f32x4_extract_lane (minv [0 ], 1 )),
1007+ MIN (wasm_f32x4_extract_lane (minv [0 ], 2 ), wasm_f32x4_extract_lane (minv [0 ], 3 )));
1008+
1009+ const float magnitude = max >= fabsf (min ) ? max : min ;
1010+ const float d = magnitude / -8 ;
9551011 const float id = d ? 1.0 /d : 0.0 ;
9561012
9571013 y [i ].d = d ;
@@ -960,9 +1016,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
9601016 const v128_t v = wasm_f32x4_mul (srcv [l ], wasm_f32x4_splat (id ));
9611017 const v128_t vf = wasm_f32x4_add (v , wasm_f32x4_splat (8.5f ));
9621018 const v128_t vi = wasm_i32x4_trunc_sat_f32x4 (vf );
1019+ const v128_t vc = wasm_i32x4_min_u (vi , wasm_i32x4_splat (15 ));
9631020
964- y [i ].qs [2 * l + 0 ] = wasm_i32x4_extract_lane (vi , 0 ) | (wasm_i32x4_extract_lane (vi , 1 ) << 4 );
965- y [i ].qs [2 * l + 1 ] = wasm_i32x4_extract_lane (vi , 2 ) | (wasm_i32x4_extract_lane (vi , 3 ) << 4 );
1021+ y [i ].qs [2 * l + 0 ] = wasm_i32x4_extract_lane (vc , 0 ) | (wasm_i32x4_extract_lane (vc , 1 ) << 4 );
1022+ y [i ].qs [2 * l + 1 ] = wasm_i32x4_extract_lane (vc , 2 ) | (wasm_i32x4_extract_lane (vc , 3 ) << 4 );
9661023 }
9671024 }
9681025#else
@@ -1143,13 +1200,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
11431200
11441201 for (int i = 0 ; i < nb ; i ++ ) {
11451202 float amax = 0.0f ; // absolute max
1203+ float max = 0.0f ;
11461204
11471205 for (int l = 0 ; l < QK4_2 ; l ++ ) {
11481206 const float v = x [i * QK4_2 + l ];
1149- amax = MAX (amax , fabsf (v ));
1207+ if (amax < fabsf (v )) {
1208+ amax = fabsf (v );
1209+ max = v ;
1210+ }
11501211 }
11511212
1152- const float d = amax / (( 1 << 3 ) - 1 ) ;
1213+ const float d = max / -8 ;
11531214
11541215 const float id = d ? 1.0f /d : 0.0f ;
11551216
@@ -1159,8 +1220,8 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
11591220 const float v0 = x [i * QK4_2 + l + 0 ]* id ;
11601221 const float v1 = x [i * QK4_2 + l + 1 ]* id ;
11611222
1162- const uint8_t vi0 = ( uint8_t )(v0 + 8.5f );
1163- const uint8_t vi1 = ( uint8_t )(v1 + 8.5f );
1223+ const uint8_t vi0 = MIN ( 15 , ( uint8_t )(v0 + 8.5f ) );
1224+ const uint8_t vi1 = MIN ( 15 , ( uint8_t )(v1 + 8.5f ) );
11641225
11651226 assert (vi0 < 16 );
11661227 assert (vi1 < 16 );
@@ -1254,9 +1315,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
12541315
12551316 block_q4_2 * restrict y = vy ;
12561317
1257- //quantize_row_q4_2_reference(x, y, k);
1258- // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
1259- quantize_row_q4_2_rmse (x , y , k );
1318+ quantize_row_q4_2_reference (x , y , k );
12601319}
12611320
12621321static void quantize_row_q4_3_reference (const float * restrict x , block_q4_3 * restrict y , int k ) {
@@ -1807,7 +1866,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
18071866 [GGML_TYPE_Q4_2 ] = {
18081867 .dequantize_row_q = dequantize_row_q4_2 ,
18091868 .quantize_row_q = quantize_row_q4_2 ,
1810- .quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_2_rmse , // quantize_row_q4_2_reference,
1869+ .quantize_row_q_reference = (quantize_row_q_t ) quantize_row_q4_2_reference ,
18111870 .quantize_row_q_dot = quantize_row_q8_0 ,
18121871 .vec_dot_q = ggml_vec_dot_q4_2_q8_0 ,
18131872 },
@@ -12144,8 +12203,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
1214412203 for (int j = 0 ; j < n ; j += k ) {
1214512204 block_q4_2 * restrict y = (block_q4_2 * )dst + j /QK4_2 ;
1214612205
12147- //quantize_row_q4_2_reference(src + j, y, k);
12148- quantize_row_q4_2_rmse (src + j , y , k );
12206+ quantize_row_q4_2_reference (src + j , y , k );
1214912207
1215012208 for (int i = 0 ; i < nb ; i ++ ) {
1215112209 for (int l = 0 ; l < QK4_2 ; l += 2 ) {
0 commit comments