99import array
1010import bitarray
1111from bitstring .luts import binary8_luts_compressed
12+ import math
1213
1314
1415class Binary8Format :
1516 """8-bit floating point formats based on draft IEEE binary8"""
1617
1718 def __init__ (self , exp_bits : int , bias : int ):
18- # We use look up tables to go from an IEEE float16 to the best float8 representation.
19- # For startup efficiency they've been precalculated and zipped up
2019 self .exp_bits = exp_bits
2120 self .bias = bias
21+ self .pos_clamp_value = 0b01111111
22+ self .neg_clamp_value = 0b11111111
23+
24+ def __str__ (self ):
25+ return f"Binary8Format(exp_bits={ self .exp_bits } , bias={ self .bias } )"
2226
2327 def decompress_luts (self ):
2428 binary8_to_float_compressed , float16_to_binary8_compressed = binary8_luts_compressed [(self .exp_bits , self .bias )]
@@ -37,18 +41,24 @@ def float_to_int8(self, f: float) -> int:
3741 b = struct .pack ('>e' , f )
3842 except (OverflowError , struct .error ):
3943 # Return the largest representable positive or negative value
40- return 0b01111111 if f > 0 else 0b11111111
44+ return self . pos_clamp_value if f > 0 else self . neg_clamp_value
4145 f16_int = int .from_bytes (b , byteorder = 'big' )
4246 # Then use this as an index to our large LUT
4347 return self .lut_float16_to_binary8 [f16_int ]
4448
4549 def createLUT_for_float16_to_binary8 (self ) -> bytes :
4650 # Used to create the LUT that was compressed and stored for the fp8 code
51+ import gfloat
52+ fi = gfloat .formats .format_info_p3109 (8 - self .exp_bits )
4753 fp16_to_fp8 = bytearray (1 << 16 )
4854 for i in range (1 << 16 ):
4955 b = struct .pack ('>H' , i )
5056 f , = struct .unpack ('>e' , b )
51- fp8_i = self .slow_float_to_int8 (f )
57+ fp = gfloat .round_float (fi , f )
58+ if math .isnan (fp ):
59+ fp8_i = 0b10000000
60+ else :
61+ fp8_i = self .lut_binary8_to_float .index (fp )
5262 fp16_to_fp8 [i ] = fp8_i
5363 return bytes (fp16_to_fp8 )
5464
@@ -76,29 +86,12 @@ def createLUT_for_binary8_to_float(self):
7686 i2f [0b11111111 ] = float ('-inf' )
7787 return array .array ('f' , i2f )
7888
79- def slow_float_to_int8 (self , f : float ) -> int :
80- # Slow, but easier to follow than the faster version. Used only for validation.
81- if f >= 0 :
82- for i in range (128 ):
83- if f < self .lut_binary8_to_float [i ]:
84- return i - 1
85- # Clip to positive max
86- return 0b01111111
87- if f < 0 :
88- if f > self .lut_binary8_to_float [129 ]:
89- # Rounding upwards to zero
90- return 0b00000000 # There's no negative zero so this is a special case
91- for i in range (130 , 256 ):
92- if f > self .lut_binary8_to_float [i ]:
93- return i - 1
94- # Clip to negative max
95- return 0b11111111
96- # We only have one nan value
97- return 0b10000000
9889
9990# We create the 1.5.2 and 1.4.3 formats.
10091p4binary_fmt = Binary8Format (exp_bits = 4 , bias = 8 )
10192p3binary_fmt = Binary8Format (exp_bits = 5 , bias = 16 )
10293
103- p4binary_fmt .decompress_luts ()
104- p3binary_fmt .decompress_luts ()
94+
95+ def decompress_luts ():
96+ p4binary_fmt .decompress_luts ()
97+ p3binary_fmt .decompress_luts ()
0 commit comments