Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Add end-to-end execution support in colocated Python API #25154

Merged
merged 1 commit into from
Dec 5, 2024

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Nov 27, 2024

[JAX] Add end-to-end execution support in colocated Python API

This change adds a capability to run colocated Python function calls through
PyLoadedExecutable. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

  • Colocated Python is compiled into PyLoadedExeutable and uses the JAX C++
    dispatch path.

  • CustomCallProgram for a colocated Python compilation nows includes
    specialization (input/output specs, devices). This information allows a
    colocated Python backend to transform input/outputs and validate
    PyTree/dtype/shape/sharding.

  • out_specs_fn now receives jax.ShapeDTypeStructs instead of concrete values.

  • Deserialization of devices now prefers the default backend. This improves the
    compatibility with an environment using both multi-platform backend as well as
    the standard "cpu" backend at the same time.

  • Several bugs have been fixed (e.g., correctly using {} for kwargs).

@copybara-service copybara-service bot force-pushed the test_700809795 branch 10 times, most recently from 646c9bc to 74f86e5 Compare December 4, 2024 21:59
@copybara-service copybara-service bot force-pushed the test_700809795 branch 5 times, most recently from a3cd21a to d3a5681 Compare December 5, 2024 18:15
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++
dispatch path.

* `CustomCallProgram` for a colocated Python compilation nows includes
specialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.

* `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values.

* Deserialization of devices now prefers the default backend. This improves the
compatibility with an environment using both multi-platform backend as well as
the standard "cpu" backend at the same time.

* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
@copybara-service copybara-service bot merged commit e20a483 into main Dec 5, 2024
@copybara-service copybara-service bot deleted the test_700809795 branch December 5, 2024 18:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant