Skip to content

Commit

Permalink
neural_tangents: migrate from deprecated jax.linear_util to jax.exten…
Browse files Browse the repository at this point in the history
…d.linear_util

PiperOrigin-RevId: 577220260
  • Loading branch information
Jake VanderPlas authored and romanngg committed Nov 21, 2023
1 parent 109796e commit 8d948dc
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@
from jax import jvp
from jax import lax
from jax import linear_transpose
from jax import linear_util as lu
from jax import vjp
from jax import vmap
from jax.core import Jaxpr
Expand Down Expand Up @@ -147,6 +146,12 @@
from .utils.typing import VMapAxes
from .utils.typing import VMapAxisTriple

try:
# jax >=0.4.16
from jax.extend import linear_util as lu
except ImportError:
from jax import linear_util as lu


# LINEARIZATION AND TAYLOR EXPANSION

Expand Down

0 comments on commit 8d948dc

Please sign in to comment.