[JAX] Add end-to-end execution support in colocated Python API #25154
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
[JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through
PyLoadedExecutable
. This capability is not yet used for McJAX, but is testedwith a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.
Key highlights:
Colocated Python is compiled into
PyLoadedExeutable
and uses the JAX C++dispatch path.
CustomCallProgram
for a colocated Python compilation nows includesspecialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.
out_specs_fn
now receivesjax.ShapeDTypeStruct
s instead of concrete values.Deserialization of devices now prefers the default backend. This improves the
compatibility with an environment using both multi-platform backend as well as
the standard "cpu" backend at the same time.
Several bugs have been fixed (e.g., correctly using
{}
for kwargs).