
Natten ShardTensor and ND support #1519

Open
pzharrington wants to merge 8 commits into NVIDIA:main from pzharrington:natten-migration

Conversation

@pzharrington
Collaborator

@pzharrington pzharrington commented Mar 19, 2026

PhysicsNeMo Pull Request

Description

Extends natten support to 1D and 3D in addition to 2D, and exposes all three in nn.functional. The lightweight wrappers in nn.functional more seamlessly support ShardTensor usage via the register_function_handler pattern (no longer any need to import partial_na2d from the natten patches file).

This PR also pulls the natten dependency out of the nn-extras group into its own dedicated optional dependency group, and uses the wheels natten releases against pinned torch+CUDA combos, which is much faster than building natten on each install.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness and is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington
Collaborator Author

@greptileai

@greptile-apps
Contributor

greptile-apps bot commented Mar 19, 2026

Greptile Summary

This PR extends natten (neighborhood attention) support from 2D-only to 1D, 2D, and 3D, exposes all three variants through physicsnemo.nn.functional, and replaces the wrapt-based monkey-patching approach with PyTorch's native __torch_function__ dispatch via ShardTensor.register_function_handler. It also moves the natten dependency into its own dedicated natten-cu12/natten-cu13 optional extras backed by pre-built wheels.

Key changes:

  • New physicsnemo/nn/functional/natten.py provides clean na1d, na2d, na3d wrappers that trigger __torch_function__ dispatch for tensor subclasses.
  • natten_patches.py drops wrapt entirely; the single _natten_wrapper function is registered for all three ops, with the correct type() is torch.Tensor (strict) check ordered before the isinstance(..., ShardTensor) check to prevent infinite recursion on mixed-type inputs.
  • Natten2DSelfAttention.forward is significantly simplified — the manual ShardTensor branch is removed since automatic dispatch now handles it transparently.
  • Comprehensive unit and distributed tests are added covering forward, backward, and dispatch for all three dimensionalities.
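The dispatch pattern described above can be sketched without torch. Below is a hypothetical, torch-free analogue in which `Tensor`, `ShardTensor`, `na2d`, and `register_function_handler` are stand-ins for the real torch/PhysicsNeMo objects, not the actual APIs; it illustrates why the strict `type() is Tensor` check must come before the `isinstance(..., ShardTensor)` check:

```python
# Torch-free sketch of the register_function_handler dispatch pattern.
# All names here are illustrative stand-ins, not the real implementations.

class Tensor:                      # stands in for torch.Tensor
    def __init__(self, data):
        self.data = data

class ShardTensor(Tensor):         # stands in for physicsnemo's ShardTensor
    pass

_handlers = {}

def register_function_handler(func, handler):
    """Associate a public functional op with a subclass-aware handler."""
    _handlers[func] = handler

def _raw_na2d(q, k, v, kernel_size):
    """Stand-in for the raw natten kernel."""
    return f"raw na2d(kernel_size={kernel_size})"

def na2d(q, k, v, kernel_size):
    """Thin public wrapper: dispatch to a handler if any input is a subclass."""
    if any(type(t) is not Tensor for t in (q, k, v)):
        return _handlers[na2d](na2d, q, k, v, kernel_size)
    return _raw_na2d(q, k, v, kernel_size)

def _natten_wrapper(func, q, k, v, kernel_size):
    # Strict `type() is Tensor` check FIRST: plain tensors fall straight
    # through to the raw op, so re-entry cannot recurse infinitely.
    if all(type(t) is Tensor for t in (q, k, v)):
        return _raw_na2d(q, k, v, kernel_size)
    elif all(isinstance(t, ShardTensor) for t in (q, k, v)):
        return f"sharded na2d(kernel_size={kernel_size})"
    raise TypeError("mixed Tensor/ShardTensor inputs are not supported")

register_function_handler(na2d, _natten_wrapper)
```

Calling `na2d` with three `ShardTensor` inputs routes through `_natten_wrapper` into the sharded branch, while plain `Tensor` inputs never touch the handler at all.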

Minor findings:

  • A stale "na2d" inline comment remains in _partial_natten (line 199 of natten_patches.py).
  • _raw_func_map[func]() will produce an unhelpful bare KeyError if func is not one of the three registered functions.
  • The removal of the eager natten availability guard in Natten2DSelfAttention.__init__ defers the "natten not installed" error to forward time instead of construction time.

Important Files Changed

  • physicsnemo/nn/functional/natten.py: New file providing na1d, na2d, na3d thin wrappers with __torch_function__ dispatch; logic is clean, docstrings are thorough, and the handle_torch_function call correctly places kernel_size positionally and dilation as a keyword so the wrapper always has args[3] = kernel_size.
  • physicsnemo/domain_parallel/shard_utils/natten_patches.py: Replaces the wrapt-based na2d_wrapper with a single _natten_wrapper registered for all three ops; the type() is torch.Tensor vs isinstance(..., ShardTensor) ordering correctly prevents the infinite-recursion bug from the previous implementation. A minor stale comment ("na2d") and an unhelpful KeyError path remain.
  • physicsnemo/models/dit/layers.py: Drops the manual ShardTensor branch in Natten2DSelfAttention.forward and the eager natten availability check in __init__; both simplifications are correct, but the removed __init__ guard means natten-missing errors are now deferred to forward time instead of construction time.
  • test/nn/functional/test_natten.py: New unit test file covering shape, backward, __torch_function__ dispatch, and full-window SDPA equivalence for all three flavors; well structured, with class-level @requires_module("natten") guards.
  • test/domain_parallel/ops/test_natten.py: New distributed test covering forward and backward passes for sharded 1D/2D/3D neighborhood attention across all spatial-shard dimensions; comprehensive parametrization.
  • pyproject.toml: Moves natten out of nn-extras into dedicated natten-cu12/natten-cu13 extras using pre-built wheel indexes; the conflict rules correctly prevent mixing CUDA variants.
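The full-window SDPA equivalence checked by the unit tests can be illustrated with a torch-free numpy sketch. Here `na1d_naive` is a hypothetical, naive 1D neighborhood attention (edge windows clamped to stay in bounds, no dilation, no heads), not natten's actual implementation; when kernel_size equals the sequence length it reduces to ordinary full attention:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(q, k, v):
    """Standard scaled dot-product attention over the whole sequence."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    return softmax(q @ k.T * scale) @ v

def na1d_naive(q, k, v, kernel_size):
    """Naive 1D neighborhood attention: each query attends to a window of
    kernel_size keys centered on it, with the window clamped at the edges."""
    L, d = q.shape
    scale = 1.0 / np.sqrt(d)
    half = kernel_size // 2
    out = np.empty_like(v)
    for i in range(L):
        lo = min(max(i - half, 0), L - kernel_size)
        window = slice(lo, lo + kernel_size)
        w = softmax(q[i] @ k[window].T * scale)
        out[i] = w @ v[window]
    return out
```

With kernel_size = L, every query's window clamps to the full sequence, so the output matches `full_attention` exactly. That is the property the tests exploit to validate the wrappers against SDPA.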

Comments Outside Diff (3)

  1. physicsnemo/domain_parallel/shard_utils/natten_patches.py, line 199 (link)

    Stale "na2d" in comment

    The function was renamed from partial_na2d to _partial_natten and now handles 1D, 2D, and 3D operations, but the inline comment still says "na2d":

  2. physicsnemo/domain_parallel/shard_utils/natten_patches.py, line 259-263 (link)

    KeyError on unexpected func gives a poor error message

    _raw_func_map[func]() will raise a bare KeyError if func is not one of na1d, na2d, or na3d (e.g. if a future caller registers _natten_wrapper for a new function but forgets to add it to _raw_func_map). A guard with an explicit message would make the failure much easier to diagnose:

        elif all(isinstance(_t, ShardTensor) for _t in (q, k, v)):
            if func not in _raw_func_map:
                raise UndeterminedShardingError(
                    f"No raw natten function registered for {func!r}. "
                    "Add an entry to `_raw_func_map`."
                )
            raw_func = _raw_func_map[func]()
            return _partial_natten(
                q, k, v, kernel_size, dilation, base_func=raw_func, **natten_kwargs
            )
  3. physicsnemo/models/dit/layers.py, line 402 (link)

    Natten availability check deferred from __init__ to forward time

    The previous code raised an ImportError eagerly in __init__ when natten was not installed. After this change, the error is deferred to the first call of forward (when _na2d_func reaches _natten.functional.na2d via OptionalImport). The OptionalImport error message is informative, but this is a mild UX regression — users constructing Natten2DSelfAttention in environments without natten installed will now only learn about the missing dependency at forward time instead of at construction time.

    Consider adding a lightweight check back in __init__:

    from physicsnemo.core.version_check import OptionalImport
    _natten_check = OptionalImport("natten")
    
    class Natten2DSelfAttention(AttentionModuleBase):
        def __init__(self, ...):
            super().__init__()
            if not _natten_check.available:
                raise ImportError(
                    "natten is required for Natten2DSelfAttention. "
                    'Install it with: pip install "nvidia-physicsnemo[cu12,natten-cu12]"'
                )
            ...

Last reviewed commit: "Merge branch 'main' ..."

@pzharrington
Collaborator Author

@greptileai

@peterdsharpe
Collaborator

Hey @pzharrington - would it be possible to streamline the natten install process as part of this PR (or in a sibling PR, if you prefer)? As natten becomes more widely used in nn and other core parts of PhysicsNeMo, a smoother install process would be very helpful, particularly for end-users who do not use natten functionality themselves but pick it up as a transitive dependency via our layered dependency stack. I know that install burden (e.g., building packages like transformer_engine, torch_geometric, etc.) was a major motivation for the refactor, so I want to be sure we keep things as easy as possible for end-users to uv sync or pip install nvidia-physicsnemo.

It looks like natten provides pre-built wheels at https://whl.natten.org, which could be embedded into pyproject.toml. What do you think?
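One possible shape for embedding that index (a sketch only; the extra name, index name, and uv-specific tables are assumptions, not the PR's actual pyproject.toml contents):

```toml
# Illustrative sketch -- names and structure are assumptions.
[project.optional-dependencies]
natten-cu12 = ["natten"]

# uv-specific: resolve natten from the pre-built wheel index instead of PyPI.
[[tool.uv.index]]
name = "natten-whl"
url = "https://whl.natten.org"
explicit = true

[tool.uv.sources]
natten = { index = "natten-whl" }
```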

@pzharrington
Collaborator Author

@greptileai

Collaborator

@coreyjadams coreyjadams left a comment


Looks good! Approving; I don't want to hold this up. I think this works. I did make a suggestion that could make it easier to maintain long term, but I will leave that to you. I appreciate the work to make pyproject.toml manageable. Such a pain with the installs...

I also like the addition of na1d and na3d in shard tensor (and in general!). We might advertise, in the docs, though not needed in this PR, that you could now use neighborhood attention layers and modules freely as building blocks with domain parallelism. Perhaps an update to the examples is in order! Let's do that next, though, not here, don't want to block you :)

@pzharrington pzharrington marked this pull request as ready for review March 20, 2026 18:21
@pzharrington
Collaborator Author

/blossom-ci

@pzharrington
Collaborator Author

We might advertise, in the docs, though not needed in this PR, that you could now use neighborhood attention layers and modules freely as building blocks with domain parallelism.

@coreyjadams yup, this PR is just to get the ShardTensor and functional interfaces in place; the next step is to move some of the attention layer modules out of models.dit into nn so they can be reused. In that PR we can shore up the documentation.

Collaborator

@loliverhennigh loliverhennigh left a comment


The natten tests are not in the FunctionSpec format, but I can clean this up in a cleanup PR I have for the functional unit tests.

@pzharrington
Collaborator Author

/blossom-ci
