Skip to content

xarray_jax does not support jax.jit().lower #98

@csubich

Description

@csubich

The JAX API now offers finer-grained control over the compilation process through `jax.stages`, but the `xarray_jax` wrapper in GraphCast does not seem to support `jax.jit(...).lower(...)`:

import graphcast.xarray_jax as xarray_jax
import jax.numpy as jnp
import jax

def ident(a):  # Trivial test function
    """Return ``a`` unchanged; this is the function handed to ``jax.jit`` below."""
    result = a
    return result

# Sample variables: a raw JAX array, and the same data wrapped as an
# xarray_jax DataArray (one dimension, auto-named dim_0 per the output below).
foo = jnp.ones(3)
foo_xr = xarray_jax.DataArray(foo)

# Calling the jitted function directly works for both the raw array...
print(jax.jit(ident)(foo)) # Works
# [1. 1. 1.]

# ...and for the DataArray.
print(jax.jit(ident)(foo_xr)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0

# Ahead-of-time lowering (jax.stages) works for the raw array...
jax.jit(ident).lower(foo) # Works
# <jax._src.stages.Lowered at 0x151bb5e04830>

# ...but fails for the DataArray: the traceback below shows a ValueError
# raised while xarray_jax's _unflatten_variable rebuilds the Variable.
jax.jit(ident).lower(foo_xr) # Fails
Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 1
----> 1 jax.jit(ident).lower(foo_xr) # Fails

    [... skipping hidden 5 frame]

File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:668, in _unflatten_variable(aux, children)
    666 dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
    667 if dims_change_fn: dims = dims_change_fn(dims)
--> 668 return Variable(dims=dims, data=children[0])

File /fs/site5/eccc/mrd/rpnatm/csu001/ppp5/graphcast_dev/graphcast/xarray_jax.py:113, in Variable(dims, data, **kwargs)
    111 def Variable(dims, data, **kwargs) -> xarray.Variable:  # pylint:disable=invalid-name
    112   """Like xarray.Variable, but can wrap JAX arrays."""
--> 113   return xarray.Variable(dims, wrap(data), **kwargs)

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/core/variable.py:365, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    338 def __init__(
    339     self,
    340     dims,
   (...)
    344     fastpath=False,
    345 ):
    346     """
    347     Parameters
    348     ----------
   (...)
    363         unrecognized encoding items.
    364     """
--> 365     super().__init__(
    366         dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs
    367     )
    369     self._encoding = None
    370     if encoding is not None:

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:253, in NamedArray.__init__(self, dims, data, attrs)
    246 def __init__(
    247     self,
    248     dims: _DimsLike,
    249     data: duckarray[Any, _DType_co],
    250     attrs: _AttrsLike = None,
    251 ):
    252     self._data = data
--> 253     self._dims = self._parse_dimensions(dims)
    254     self._attrs = dict(attrs) if attrs else None

File ~/data/ppp5/conda_env/gforecast_test/lib/python3.11/site-packages/xarray/namedarray/core.py:481, in NamedArray._parse_dimensions(self, dims)
    479 dims = (dims,) if isinstance(dims, str) else tuple(dims)
    480 if len(dims) != self.ndim:
--> 481     raise ValueError(
    482         f"dimensions {dims} must have the same length as the "
    483         f"number of data dimensions, ndim={self.ndim}"
    484     )
    485 if len(set(dims)) < len(dims):
    486     repeated_dims = set([d for d in dims if dims.count(d) > 1])

ValueError: dimensions ('dim_0',) must have the same length as the number of data dimensions, ndim=0

If the xarray DataArray is instead created inside the JIT-compiled function, things seem to work:

def make_xr(a):
    """Wrap ``a`` in an ``xarray_jax.DataArray`` and return it."""
    wrapped = xarray_jax.DataArray(a)
    return wrapped

def compose(a):
    """Build a DataArray from ``a`` via ``make_xr``, then pass it through ``ident``."""
    as_xr = make_xr(a)
    return ident(as_xr)

# Here the DataArray is constructed inside the jitted function (via make_xr),
# so only the raw array `foo` is passed to .lower(); this path works.
print(jax.jit(compose).lower(foo).compile()(foo)) # Works
# <xarray.DataArray (dim_0: 3)>
# xarray_jax.JaxArrayWrapper(Array([1., 1., 1.], dtype=float32))
# Dimensions without coordinates: dim_0

I'm not yet sure whether decomposing xarray arguments into a more pytree-friendly form, only to reconstruct them inside a wrapper, is a general solution, or whether doing so with GraphCast would just reveal another error further in.

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions