CUBE SUGAR CONTAINER

技術系のこと書きます。

Python: Pandas で np.float16 はサポートされていない

まったく知らなかったんだけど、Pandas はカラムの型として NumPy の float16 (16 ビット浮動小数点型) をサポートしていない。 これは、以下の Issue で説明されている。 どうやら、プラットフォームによっては float16 を利用できないため対応が難しいらしい 1

github.com

サポートされていなくても、カラムの型としては指定できる。 そして、なんとなく動いているようにも見えてしまうので知らないとハマる。 メモリを節約するために、高い精度が必要ないカラムには指定したくなる場面もあるだろうから。

今回は、どんな場面でこの問題に気づいたのか述べる。 使った環境は次のとおり。

$ sw_vers
ProductName:    macOS
ProductVersion: 12.6
BuildVersion:   21G115
$ uname -srm
Darwin 21.6.0 arm64
$ python -V
Python 3.10.6
$ pip list | grep pandas
pandas          1.5.0

もくじ

下準備

あらかじめ Pandas をインストールしておく。

$ pip install pandas

そして、Python のインタプリタを起動しておく。

$ python

サンプルのデータとして次のようなデータフレームを用意する。 月ごとのフルーツの値段が縦持ち (Long Data) で表されているイメージ。

>>> import pandas as pd
>>> data = {
...     "yyyymm": [
...         "2021-09",
...         "2021-09",
...         "2021-09",
...         "2021-09",
...         "2021-10",
...         "2021-10",
...         "2021-10",
...     ],
...     "name": [
...         "apple",
...         "banana",
...         "cherry",
...         "dates",
...         "apple",
...         "banana",
...         "cherry",
...     ],
...     "price": [
...         100,
...         120,
...         200,
...         180,
...         110,
...         130,
...         210,
...     ]
... }
>>> df = pd.DataFrame(data)
>>> df.set_index(["yyyymm", "name"], inplace=True)
>>> df
                price
yyyymm  name         
2021-09 apple     100
        banana    120
        cherry    200
        dates     180
2021-10 apple     110
        banana    130
        cherry    210

問題を再現する

先ほどのデータを横持ち (Wide Data) にするために unstack() する。 現状では price カラムの型は int64 なので、この操作は成功する。

>>> df.dtypes
price    int64
dtype: object
>>> df.unstack()
         price                     
name     apple banana cherry  dates
yyyymm                             
2021-09  100.0  120.0  200.0  180.0
2021-10  110.0  130.0  210.0    NaN

ただし、2021-10dates の値が存在しないため NaN が挿入される。 NaN は整数型には無いので、自動的に型がキャストされている。

>>> df.unstack().dtypes
       name  
price  apple     float64
       banana    float64
       cherry    float64
       dates     float64
dtype: object

では、次にカラムの型を float16 にするとどうなるだろうか。

>>> import numpy as np
>>> df = df.astype({
...     "price": np.float16,
... })

先ほどと同じように unstack() してみよう。 すると、今度は例外になってしまった。

>>> df.dtypes
price    float16
dtype: object
>>> df.unstack()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/pandas/core/frame.py", line 9102, in unstack
    result = unstack(self, level, fill_value)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/pandas/core/reshape/reshape.py", line 477, in unstack
    return _unstack_frame(obj, level, fill_value=fill_value)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/pandas/core/reshape/reshape.py", line 506, in _unstack_frame
    return unstacker.get_result(
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/pandas/core/reshape/reshape.py", line 216, in get_result
    values, _ = self.get_new_values(values, fill_value)
  File "/Users/amedama/.virtualenvs/py310/lib/python3.10/site-packages/pandas/core/reshape/reshape.py", line 284, in get_new_values
    libreshape.unstack(
  File "pandas/_libs/reshape.pyx", line 21, in pandas._libs.reshape.__pyx_fused_cpdef
TypeError: No matching signature found

このように、カラムの型が float16 だと失敗する操作があることがわかる。

ちなみに、当たり前だけどカラムの型を float32 にした場合は上手くいく。 これは、サポートされている型なので。

>>> df = df.astype({
...     "price": np.float32,
... })
>>> df.dtypes
price    float32
dtype: object
>>> df.unstack()
         price                     
name     apple banana cherry  dates
yyyymm                             
2021-09  100.0  120.0  200.0  180.0
2021-10  110.0  130.0  210.0    NaN

また、欠損値がない場合にも上手くいく。 これは、どうやら問題は欠損値を埋める処理に起因しているようなので。

>>> df = df.astype({
...     "price": np.float16,
... })
>>> df.drop(index=("2021-09", "dates"))
                price
yyyymm  name         
2021-09 apple   100.0
        banana  120.0
        cherry  200.0
2021-10 apple   110.0
        banana  130.0
        cherry  210.0
>>> df.drop(index=("2021-09", "dates")).unstack()
         price              
name     apple banana cherry
yyyymm                      
2021-09  100.0  120.0  200.0
2021-10  110.0  130.0  210.0

このように、なんとなく動いているようにも見えてしまうので注意が必要になる。

まとめ

Pandas は NumPy の float16 をサポートしていないので、カラムの型に使うのはやめよう。


  1. 英語だと「モグラたたき」のことを「a rabbit hole」と表現するんだね