Skip to content

Commit bc8e1c0

Browse files
author
Scott Griffiths
committed
Some more refactoring of binary8 and mxfp format code.
This should break the CI due to changes in gfloat.
1 parent 4d40148 commit bc8e1c0

File tree

10 files changed

+294
-224
lines changed

10 files changed

+294
-224
lines changed

bitstring/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@
7171
from .dtypes import DtypeDefinition, dtype_register, Dtype
7272
import types
7373
from typing import List, Tuple, Literal
74+
from .mxfp import decompress_luts as mxfp_decompress_luts
75+
from .fp8 import decompress_luts as binary8_decompress_luts
76+
77+
# Decompress the LUTs for the exotic floating point formats
78+
mxfp_decompress_luts()
79+
binary8_decompress_luts()
7480

7581
# The Options class returns a singleton.
7682
options = Options()

bitstring/bits.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import io
1010
from collections import abc
1111
import functools
12-
from typing import Tuple, Union, List, Iterable, Any, Optional, \
13-
BinaryIO, TextIO, overload, Iterator, Type, TypeVar
12+
from typing import Tuple, Union, List, Iterable, Any, Optional, BinaryIO, TextIO, overload, Iterator, Type, TypeVar
1413
import bitarray
1514
import bitarray.util
1615
import bitstring

bitstring/bitstore_helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def p4binary2bitstore(f: Union[str, float]) -> BitStore:
121121
u = p4binary_fmt.float_to_int8(f)
122122
return int2bitstore(u, 8, False)
123123

124-
125124
def p3binary2bitstore(f: Union[str, float]) -> BitStore:
126125
f = float(f)
127126
u = p3binary_fmt.float_to_int8(f)

bitstring/bitstream.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import bitstring
44
from bitstring.bits import Bits, BitsType
5+
from bitstring.dtypes import Dtype
56
from typing import Union, List, Any, Optional, overload, TypeVar, Tuple
67
import copy
78
import numbers
@@ -354,7 +355,7 @@ def read(self, fmt: Union[int, str, Dtype]) -> Union[int, float, str, Bits, bool
354355
raise bitstring.ReadError(f"Reading off end of bitstring with fmt '{fmt}'. Only {len(self) - p} bits available.")
355356
return val
356357

357-
def readlist(self, fmt: Union[str, List[Union[int, str]]], **kwargs) \
358+
def readlist(self, fmt: Union[str, List[Union[int, str, Dtype]]], **kwargs) \
358359
-> List[Union[int, float, str, Bits, bool, bytes, None]]:
359360
"""Interpret next bits according to format string(s) and return list.
360361

bitstring/fp8.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@
99
import array
1010
import bitarray
1111
from bitstring.luts import binary8_luts_compressed
12+
import math
1213

1314

1415
class Binary8Format:
1516
"""8-bit floating point formats based on draft IEEE binary8"""
1617

1718
def __init__(self, exp_bits: int, bias: int):
18-
# We use look up tables to go from an IEEE float16 to the best float8 representation.
19-
# For startup efficiency they've been precalculated and zipped up
2019
self.exp_bits = exp_bits
2120
self.bias = bias
21+
self.pos_clamp_value = 0b01111111
22+
self.neg_clamp_value = 0b11111111
23+
24+
def __str__(self):
25+
return f"Binary8Format(exp_bits={self.exp_bits}, bias={self.bias})"
2226

2327
def decompress_luts(self):
2428
binary8_to_float_compressed, float16_to_binary8_compressed = binary8_luts_compressed[(self.exp_bits, self.bias)]
@@ -37,18 +41,24 @@ def float_to_int8(self, f: float) -> int:
3741
b = struct.pack('>e', f)
3842
except (OverflowError, struct.error):
3943
# Return the largest representable positive or negative value
40-
return 0b01111111 if f > 0 else 0b11111111
44+
return self.pos_clamp_value if f > 0 else self.neg_clamp_value
4145
f16_int = int.from_bytes(b, byteorder='big')
4246
# Then use this as an index to our large LUT
4347
return self.lut_float16_to_binary8[f16_int]
4448

4549
def createLUT_for_float16_to_binary8(self) -> bytes:
4650
# Used to create the LUT that was compressed and stored for the fp8 code
51+
import gfloat
52+
fi = gfloat.formats.format_info_p3109(8 - self.exp_bits)
4753
fp16_to_fp8 = bytearray(1 << 16)
4854
for i in range(1 << 16):
4955
b = struct.pack('>H', i)
5056
f, = struct.unpack('>e', b)
51-
fp8_i = self.slow_float_to_int8(f)
57+
fp = gfloat.round_float(fi, f)
58+
if math.isnan(fp):
59+
fp8_i = 0b10000000
60+
else:
61+
fp8_i = self.lut_binary8_to_float.index(fp)
5262
fp16_to_fp8[i] = fp8_i
5363
return bytes(fp16_to_fp8)
5464

@@ -76,29 +86,12 @@ def createLUT_for_binary8_to_float(self):
7686
i2f[0b11111111] = float('-inf')
7787
return array.array('f', i2f)
7888

79-
def slow_float_to_int8(self, f: float) -> int:
80-
# Slow, but easier to follow than the faster version. Used only for validation.
81-
if f >= 0:
82-
for i in range(128):
83-
if f < self.lut_binary8_to_float[i]:
84-
return i - 1
85-
# Clip to positive max
86-
return 0b01111111
87-
if f < 0:
88-
if f > self.lut_binary8_to_float[129]:
89-
# Rounding upwards to zero
90-
return 0b00000000 # There's no negative zero so this is a special case
91-
for i in range(130, 256):
92-
if f > self.lut_binary8_to_float[i]:
93-
return i - 1
94-
# Clip to negative max
95-
return 0b11111111
96-
# We only have one nan value
97-
return 0b10000000
9889

9990
# We create the 1.5.2 and 1.4.3 formats.
10091
p4binary_fmt = Binary8Format(exp_bits=4, bias=8)
10192
p3binary_fmt = Binary8Format(exp_bits=5, bias=16)
10293

103-
p4binary_fmt.decompress_luts()
104-
p3binary_fmt.decompress_luts()
94+
95+
def decompress_luts():
96+
p4binary_fmt.decompress_luts()
97+
p3binary_fmt.decompress_luts()

0 commit comments

Comments
 (0)