Skip to content

Commit

Permalink
Bump min JAX version to 0.4.16 + minor no-op linter fixes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589663637
  • Loading branch information
romanngg committed Dec 11, 2023
1 parent ad3d524 commit ad47437
Show file tree
Hide file tree
Showing 16 changed files with 136 additions and 51 deletions.
4 changes: 4 additions & 0 deletions neural_tangents/_src/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,20 @@
import warnings

import jax

from jax import device_put
from jax import devices
from jax import jit
from jax import pmap
from jax import random

import jax.numpy as jnp

from jax.tree_util import tree_all
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten

import numpy as np

from .utils import utils
Expand Down
13 changes: 7 additions & 6 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,22 @@
from jax import linear_transpose
from jax import vjp
from jax import vmap

from jax.core import Jaxpr
from jax.core import JaxprEqn
from jax.core import Literal
from jax.core import ShapedArray
from jax.core import Value
from jax.core import Var

from jax.extend import linear_util as lu

from jax.interpreters import ad
from jax.interpreters.ad import UndefinedPrimal
from jax.interpreters.ad import Zero

import jax.numpy as jnp

from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_reduce
Expand All @@ -134,6 +140,7 @@
from jax.tree_util import tree_unflatten
from jax.util import safe_map as map
from jax.util import safe_zip as zip

import numpy as np

from .utils import rules
Expand All @@ -146,12 +153,6 @@
from .utils.typing import VMapAxes
from .utils.typing import VMapAxisTriple

try:
# jax >=0.4.16
from jax.extend import linear_util as lu
except ImportError:
from jax import linear_util as lu


# LINEARIZATION AND TAYLOR EXPANSION

Expand Down
23 changes: 19 additions & 4 deletions neural_tangents/_src/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,34 @@
kernel function is JITted internally.
"""


from functools import partial
import operator
from typing import Generator, Iterable, Optional, Union

from .batching import batch
from .empirical import empirical_kernel_fn, NtkImplementation, DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_S_RULES, _DEFAULT_NTK_J_RULES
import jax
from jax import random
import jax.numpy as jnp
from jax.tree_util import tree_map

from .batching import batch

from .empirical import _DEFAULT_NTK_FWD
from .empirical import _DEFAULT_NTK_J_RULES
from .empirical import _DEFAULT_NTK_S_RULES
from .empirical import DEFAULT_NTK_IMPLEMENTATION
from .empirical import empirical_kernel_fn
from .empirical import NtkImplementation

from .utils import utils
from .utils.typing import ApplyFn, Axes, EmpiricalGetKernelFn, Get, InitFn, MonteCarloKernelFn, NTTree, PyTree, VMapAxes
from .utils.typing import ApplyFn
from .utils.typing import Axes
from .utils.typing import EmpiricalGetKernelFn
from .utils.typing import Get
from .utils.typing import InitFn
from .utils.typing import MonteCarloKernelFn
from .utils.typing import NTTree
from .utils.typing import PyTree
from .utils.typing import VMapAxes


def _sample_once_kernel_fn(
Expand Down
14 changes: 9 additions & 5 deletions neural_tangents/_src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,25 @@
note that closed-form kernels currently only support a single `channel_axis`).
"""


import collections
from functools import lru_cache
from typing import Callable, Generator, Iterable, NamedTuple, Optional, Any, Union, Protocol
from typing import Any, Callable, Generator, Iterable, NamedTuple, Optional, Protocol, Union

import jax
from jax import grad
from jax.experimental import ode
import jax.numpy as jnp
import jax.scipy as jsp
from jax.tree_util import tree_all, tree_map
from jax.tree_util import tree_all
from jax.tree_util import tree_map
import numpy as np
import scipy as sp
from .utils import dataclasses, utils
from .utils.typing import Axes, Get, KernelFn

from .utils import dataclasses
from .utils import utils
from .utils.typing import Axes
from .utils.typing import Get
from .utils.typing import KernelFn


PyTree = Any
Expand Down
9 changes: 6 additions & 3 deletions neural_tangents/_src/stax/branching.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
several branches into one.
"""


import functools
from typing import Callable, Iterable, Optional, Sequence
import warnings

from jax import numpy as jnp
import jax.example_libraries.stax as ostax
from .requirements import layer, supports_masking

from ..utils.kernel import Kernel
from ..utils.typing import InternalLayer, InternalLayerMasked, Kernels
from ..utils.typing import InternalLayer
from ..utils.typing import InternalLayerMasked
from ..utils.typing import Kernels
from .requirements import layer
from .requirements import supports_masking


@layer
Expand Down
16 changes: 13 additions & 3 deletions neural_tangents/_src/stax/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,21 @@

import frozendict
import jax
from jax import random, lax
from jax import lax
from jax import random
import jax.example_libraries.stax as ostax
from .requirements import Diagonal, get_req, layer, requires

from ..utils.kernel import Kernel
from ..utils.typing import InternalLayer, Layer, LayerKernelFn, NTTree, NTTrees, Shapes
from ..utils.typing import InternalLayer
from ..utils.typing import Layer
from ..utils.typing import LayerKernelFn
from ..utils.typing import NTTree
from ..utils.typing import NTTrees
from ..utils.typing import Shapes
from .requirements import Diagonal
from .requirements import get_req
from .requirements import layer
from .requirements import requires


@layer
Expand Down
15 changes: 12 additions & 3 deletions neural_tangents/_src/stax/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,24 @@
import warnings

import jax
from jax import custom_jvp, grad, vmap
from jax import custom_jvp
from jax import grad
from jax import numpy as jnp
from jax import vmap
from jax.scipy.special import erf
import numpy as np
from .requirements import Diagonal, get_diagonal, get_diagonal_outer_prods, layer, requires, supports_masking
import scipy as sp

from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import InternalLayer, LayerKernelFn
from ..utils.typing import InternalLayer
from ..utils.typing import LayerKernelFn
from .requirements import Diagonal
from .requirements import get_diagonal
from .requirements import get_diagonal_outer_prods
from .requirements import layer
from .requirements import requires
from .requirements import supports_masking


@layer
Expand Down
18 changes: 15 additions & 3 deletions neural_tangents/_src/stax/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,30 @@
import warnings

import jax
from jax import eval_shape
from jax import lax
from jax import numpy as jnp
from jax import ops
from jax import random
from jax import ShapeDtypeStruct, eval_shape, vmap
from jax import ShapeDtypeStruct
from jax import vmap
from jax.core import ShapedArray
import jax.example_libraries.stax as ostax
import numpy as np
from .requirements import Bool, Diagonal, get_diagonal_outer_prods, layer, mean_and_var, requires, supports_masking

from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import Axes, InternalLayer, InternalLayerMasked, PyTree
from ..utils.typing import Axes
from ..utils.typing import InternalLayer
from ..utils.typing import InternalLayerMasked
from ..utils.typing import PyTree
from .requirements import Bool
from .requirements import Diagonal
from .requirements import get_diagonal_outer_prods
from .requirements import layer
from .requirements import mean_and_var
from .requirements import requires
from .requirements import supports_masking


# Enums
Expand Down
23 changes: 17 additions & 6 deletions neural_tangents/_src/stax/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,34 @@

"""Requirement management for :obj:`~neural_tangents.stax` layers."""

import dataclasses
import enum
from typing import Callable, Optional, Sequence, Union
import warnings

import frozendict
import jax
from jax import eval_shape
from jax import lax
from jax import numpy as jnp
from jax import eval_shape
from jax.core import ShapedArray
from jax.tree_util import tree_map, tree_all
from ..utils import utils
import dataclasses
from jax.tree_util import tree_all
from jax.tree_util import tree_map
import numpy as np

from ..utils import dataclasses as nt_dataclasses
from ..utils import utils
from ..utils.kernel import Kernel
from ..utils.typing import AnalyticKernelFn, Axes, Get, InitFn, ApplyFn, InternalLayer, Layer, LayerKernelFn, NTTree, PyTree
import numpy as np
from ..utils.typing import AnalyticKernelFn
from ..utils.typing import ApplyFn
from ..utils.typing import Axes
from ..utils.typing import Get
from ..utils.typing import InitFn
from ..utils.typing import InternalLayer
from ..utils.typing import Layer
from ..utils.typing import LayerKernelFn
from ..utils.typing import NTTree
from ..utils.typing import PyTree


# Public decorators
Expand Down
16 changes: 12 additions & 4 deletions neural_tangents/_src/utils/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,26 @@

"""Structured derivatives rules."""

from .dataclasses import dataclass, field
import functools
from typing import Callable, Optional, Any, Union
from typing import Any, Callable, Optional, Union

from . import utils
import jax
from jax import lax
from jax.core import JaxprEqn, ShapedArray, Primitive, Jaxpr, Var, AbstractValue, Literal
from jax.core import AbstractValue
from jax.core import Jaxpr
from jax.core import JaxprEqn
from jax.core import Literal
from jax.core import Primitive
from jax.core import ShapedArray
from jax.core import Var
from jax.interpreters import ad
import jax.numpy as jnp
import numpy as np

from . import utils
from .dataclasses import dataclass
from .dataclasses import field


# pytype: disable=wrong-keyword-args

Expand Down
2 changes: 1 addition & 1 deletion neural_tangents/_src/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Common Type Definitions."""

from typing import Any, Generator, Optional, Sequence, TYPE_CHECKING, TypeVar, Union, Protocol
from typing import Any, Generator, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, Union

import jax
import jax.numpy as jnp
Expand Down
3 changes: 2 additions & 1 deletion neural_tangents/_src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from jax import core
from jax import random
import jax.numpy as jnp
from jax.tree_util import tree_all, tree_map
from jax.tree_util import tree_all
from jax.tree_util import tree_map
import numpy as np


Expand Down
11 changes: 9 additions & 2 deletions neural_tangents/experimental/empirical_tf/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,15 @@
import warnings

from jax.experimental import jax2tf
from neural_tangents._src.empirical import NtkImplementation, empirical_ntk_fn, DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_J_RULES, _DEFAULT_NTK_S_RULES
from neural_tangents._src.utils.typing import Axes, PyTree, VMapAxes
from neural_tangents._src.empirical import _DEFAULT_NTK_FWD
from neural_tangents._src.empirical import _DEFAULT_NTK_J_RULES
from neural_tangents._src.empirical import _DEFAULT_NTK_S_RULES
from neural_tangents._src.empirical import DEFAULT_NTK_IMPLEMENTATION
from neural_tangents._src.empirical import empirical_ntk_fn
from neural_tangents._src.empirical import NtkImplementation
from neural_tangents._src.utils.typing import Axes
from neural_tangents._src.utils.typing import PyTree
from neural_tangents._src.utils.typing import VMapAxes
import tensorflow as tf
import tf2jax

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


INSTALL_REQUIRES = [
'jax>=0.4.14',
'jax>=0.4.16',
'frozendict>=2.3.8',
'tensorflow>=2.15.0',
'tf2jax>=0.3.5',
Expand Down
Loading

0 comments on commit ad47437

Please sign in to comment.