Skip to content

MJX: add differentiable distance calculation#3138

Open
aarushk09 wants to merge 1 commit intogoogle-deepmind:mainfrom
aarushk09:mjx-differentiable-distance
Open

MJX: add differentiable distance calculation#3138
aarushk09 wants to merge 1 commit intogoogle-deepmind:mainfrom
aarushk09:mjx-differentiable-distance

Conversation

@aarushk09
Copy link

Summary

Replace hard-coded dist=1 sentinel values in MJX collision functions with actual signed distance values. This preserves the gradient chain through JAX even when geoms are not colliding, enabling differentiable distance computation.

Changes

All changes are in mjx/mujoco/mjx/_src/collision_convex.py (17 insertions, 15 deletions):

  • _sphere_convex: Remove jp.where(has_separating_axis, 1.0, ...) — the expression d - sphere.size[0] is already correct for both cases
  • _capsule_convex: Use actual face/edge distances instead of -1 sentinels
  • _create_contact_manifold: Use jp.abs(penetration) for non-contact points instead of jp.ones_like
  • _box_box_impl / _sat_gaussmap: Use jp.abs(dist) for inactive edge contact slots and face-separating cases
  • plane_convex / hfield_*: Use jp.abs(support).max() / jp.max(dist) for non-unique contacts

Testing

  • All 37 existing collision_driver_test.py tests pass
  • Verified that jax.grad produces finite, non-zero gradients through contact.dist for non-colliding sphere-box pair

Fixes #3131

@google-cla
Copy link

google-cla bot commented Mar 1, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

Replace hard-coded dist=1 sentinel values in collision functions with
actual signed distance values. This preserves the gradient chain through
JAX even when geoms are not colliding, enabling differentiable distance
computation for optimization and learning applications.

Fixes google-deepmind#3131
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MJX: add differentiable distance calculation

1 participant