Question: Compatibility Issues With Jax and TensorFlow #240

Open
MichaelCato opened this issue Jun 17, 2024 · 6 comments

@MichaelCato

I am installing colabfold on my Linux machine. After updating with ./update_linux.sh, I get a dependency conflict: Jax requires ml-dtypes 0.4.0, while tensorflow requires ml-dtypes 0.3.2.

If I prioritize Jax by running pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html after running ./update_linux.sh, everything seems to work, but I am not sure whether tensorflow still works during the run.
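One way to check this is to list what each installed package declares for ml-dtypes and then try importing tensorflow. The following is a minimal sketch using only the standard library, meant to be run with the Python interpreter of the colabfold-conda environment. The helper names `requirement_name` and `declared_requirements` are hypothetical, not part of colabfold or Jax.

```python
from __future__ import annotations

import re
from importlib import metadata


def requirement_name(requirement: str) -> str:
    """Return the normalized project name at the start of a requirement string."""
    match = re.match(r"[A-Za-z0-9._-]+", requirement.strip())
    name = match.group(0) if match else ""
    # Treat "ml_dtypes", "ml-dtypes" and "ml.dtypes" as the same project name.
    return re.sub(r"[-_.]+", "-", name).lower()


def declared_requirements(dist: str, dep: str) -> list[str]:
    """List the requirement strings of `dist` that refer to `dep`."""
    try:
        requires = metadata.requires(dist) or []
    except metadata.PackageNotFoundError:
        return []  # `dist` is not installed in this environment
    wanted = requirement_name(dep)
    return [r for r in requires if requirement_name(r) == wanted]


if __name__ == "__main__":
    for dist in ("jax", "jaxlib", "tensorflow"):
        found = declared_requirements(dist, "ml-dtypes")
        print(dist, "->", found or "not installed, or no ml-dtypes requirement")
    try:
        import tensorflow  # noqa: F401
        print("tensorflow imports fine")
    except Exception as exc:  # any import failure is worth reporting
        print("tensorflow failed to import:", exc)
```

If the import succeeds, tensorflow was not removed by the upgrade, although an import alone does not prove that every code path works with the newer ml-dtypes.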

Has anyone had similar issues?

@YoshitakaMo
Owner

The issue is caused by the latest jax update (0.4.29). Jax and jaxlib 0.4.29 require ml-dtypes 0.4.0, whereas tensorflow requires 0.3.2, as you pointed out.
Running /path/to/your/localcolabfold/colabfold-conda/bin/pip3 install --upgrade "jax[cuda12]"==0.4.28 (instead of pip3 install --upgrade "jax[cuda12]") will probably solve the issue for the time being.
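To confirm that the pin took effect, the installed versions can be compared against 0.4.28. This is a rough sketch under the assumption that a simple numeric comparison of the dotted version is enough; `version_tuple` and `exceeds_pin` are hypothetical helper names, and the script should be run with the Python of the colabfold-conda environment.

```python
from __future__ import annotations

from importlib import metadata


def version_tuple(version: str) -> tuple[int, ...]:
    """Turn '0.4.28' into (0, 4, 28), stopping at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def exceeds_pin(installed: str, pin: str = "0.4.28") -> bool:
    """Return True when the installed version is newer than the pinned one."""
    return version_tuple(installed) > version_tuple(pin)


if __name__ == "__main__":
    for dist in ("jax", "jaxlib"):
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            print(dist, "is not installed")
            continue
        status = "newer than the pin" if exceeds_pin(installed) else "within the pin"
        print(dist, installed, status)
```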

@jecorn

jecorn commented Jun 18, 2024

For posterity, and so that people can find this via search: the solution by @YoshitakaMo (/path/to/your/localcolabfold/colabfold-conda/bin/pip3 install --upgrade "jax[cuda12]"==0.4.28) fixes the following exceptions:

RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

and

jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51).

and

RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

Full error below:

2024-06-18 08:52:05,501 Running colabfold 1.5.5 (1648d2335943f9a483b6a803ebaea3e76162c788)
Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1281, in run
    jax.tools.colab_tpu.setup_tpu()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/tools/colab_tpu.py", line 20, in setup_tpu
    raise RuntimeError(
RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 657, in factory
    xla_client.initialize_pjrt_plugin(plugin_name)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jaxlib/xla_client.py", line 176, in initialize_pjrt_plugin
    _xla.initialize_pjrt_plugin(plugin_name)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2046, in main
    run(
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1286, in run
    if jax.local_devices()[0].platform == 'cpu':
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1135, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/home/cornlab/bin/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
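The "Mismatched PJRT plugin" message suggests that the CUDA plugin packages and jaxlib come from different releases. Below is a minimal sketch for spotting that, assuming the plugin is shipped as the PyPI packages jax-cuda12-plugin and jax-cuda12-pjrt and is expected to carry the same version as jaxlib. The function names are hypothetical.

```python
from __future__ import annotations

from importlib import metadata

# Assumption: these are the CUDA 12 plugin packages pulled in by "jax[cuda12]".
PLUGIN_PACKAGES = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names: tuple[str, ...]) -> dict[str, str | None]:
    """Map each package name to its installed version, or None if absent."""
    versions: dict[str, str | None] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def mismatched_with_jaxlib(versions: dict[str, str | None]) -> dict[str, str]:
    """Return the installed packages whose version differs from jaxlib's."""
    reference = versions.get("jaxlib")
    if reference is None:
        return {}
    base = reference.split("+")[0]  # ignore local suffixes such as "+cuda12"
    return {
        name: version
        for name, version in versions.items()
        if name != "jaxlib" and version is not None and version.split("+")[0] != base
    }


if __name__ == "__main__":
    found = installed_versions(("jaxlib",) + PLUGIN_PACKAGES)
    print("installed:", found)
    print("mismatched with jaxlib:", mismatched_with_jaxlib(found) or "none")
```

If a plugin package shows a newer version than jaxlib, reinstalling with the pinned "jax[cuda12]"==0.4.28 command above should bring them back in line.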

@MichaelCato
Author

> The issue is caused by the latest jax update (0.4.29). Jax and jaxlib 0.4.29 requires ml-dtypes 0.4.0, but tensorflow does 0.3.2, as you pointed out. Probably, /path/to/your/localcolabfold/colabfold-conda/bin/pip3 install --upgrade "jax[cuda12]"==0.4.28 (instead of pip3 install --upgrade "jax[cuda12]") will solve the issue for the time being.

This worked! Thank you!

@lealiaxiong

Hi! I am having what I believe is the same issue (error traceback below), even though I installed localcolabfold using the most recent install_colabbatch_linux.sh, which I see already has a line specifying "$COLABFOLDDIR/colabfold-conda/bin/pip" install --upgrade "jax[cuda12]"==0.4.28. Any suggestions?

Traceback (most recent call last):
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1281, in run
    jax.tools.colab_tpu.setup_tpu()
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/tools/colab_tpu.py", line 20, in setup_tpu
    raise RuntimeError(
RuntimeError: jax.tools.colab_tpu.setup_tpu() was required for older JAX versions running on older generations of TPUs, and should no longer be used.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 874, in backends
    backend = _init_backend(platform)
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 965, in _init_backend
    backend = registration.factory()
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 663, in factory
    return xla_client.make_c_api_client(plugin_name, updated_options, None)
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jaxlib/xla_client.py", line 199, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2046, in main
    run(
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1286, in run
    if jax.local_devices()[0].platform == 'cpu':
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1135, in local_devices
    process_index = get_backend(backend).process_index()
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1011, in get_backend
    return _get_backend_uncached(platform)
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 990, in _get_backend_uncached
    bs = backends()
  File "/opt/localcolabfold/colabfold-conda/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 890, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

@YoshitakaMo
Owner

> No visible GPU devices.

This means that your PC couldn't detect a GPU. There is probably a problem with the installation of the CUDA drivers.
Does the nvidia-smi command report the state of the GPU correctly? Check your system info (Python version, jaxlib version, accelerator, etc.; see also jax-ml/jax#21998).
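Both checks can be gathered in one short script. This is a sketch, not an official diagnostic: it calls `nvidia-smi -L` if the binary is on the PATH, then asks Jax for its devices. The function names are hypothetical, and it should be run with the Python of the colabfold-conda environment.

```python
from __future__ import annotations

import shutil
import subprocess


def nvidia_smi_gpus() -> list[str]:
    """Return the GPU lines printed by `nvidia-smi -L`, or [] if unavailable."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return []
    try:
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, check=False, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return []
    if result.returncode != 0:
        return []
    return [line for line in result.stdout.splitlines() if line.strip()]


def jax_devices() -> tuple[list[str], str | None]:
    """Return (device descriptions, error message) as seen by Jax."""
    try:
        import jax

        return [str(device) for device in jax.devices()], None
    except Exception as exc:  # import errors and backend init errors alike
        return [], str(exc)


if __name__ == "__main__":
    print("nvidia-smi sees:", nvidia_smi_gpus() or "no GPUs (or nvidia-smi missing)")
    devices, error = jax_devices()
    print("jax sees:", devices if error is None else f"error: {error}")
```

If nvidia-smi lists a GPU but Jax reports an error, the driver is probably fine and the problem is more likely in the CUDA libraries or plugin packages inside the Python environment.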

@lealiaxiong

@YoshitakaMo thank you for maintaining localcolabfold and for being active in the issues!

For posterity, and for people encountering a similar issue in the future:
I am running localcolabfold in a Docker container on an EC2 instance. I was able to solve my problem by starting the EC2 instance with the user data script here: https://repost.aws/articles/ARwfQMxiC-QMOgWykD9mco1w/how-do-i-install-nvidia-gpu-driver-cuda-toolkit-and-optionally-nvidia-container-toolkit-on-amazon-linux-2023-al2023

When I was encountering the problem described in my question above, nvidia-smi returned the GPU state correctly from within my Docker container, and nvcc --version looked fine to me as well. I am not sure what went wrong, but in the end I resolved the issue by starting a fresh instance with a fresh installation of the CUDA toolkit and container toolkit.
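For anyone debugging the same symptom inside a container, one thing worth ruling out is an environment variable that hides the GPU from CUDA applications while nvidia-smi still shows it. The sketch below is only a guess at possible causes, not an explanation of what happened in this case. It assumes the usual meanings of CUDA_VISIBLE_DEVICES, NVIDIA_VISIBLE_DEVICES and JAX_PLATFORMS, and `visibility_hints` is a hypothetical helper.

```python
from __future__ import annotations

import os
from collections.abc import Mapping


def visibility_hints(env: Mapping[str, str]) -> list[str]:
    """Return human-readable hints about variables that can hide GPUs."""
    hints = []
    cuda = env.get("CUDA_VISIBLE_DEVICES")
    if cuda is not None and cuda.strip() in ("", "-1"):
        # nvidia-smi does not honor this variable, so it can still list GPUs.
        hints.append("CUDA_VISIBLE_DEVICES hides all GPUs from CUDA applications")
    nvidia = env.get("NVIDIA_VISIBLE_DEVICES")
    if nvidia is not None and nvidia.strip().lower() in ("", "void", "none"):
        hints.append("NVIDIA_VISIBLE_DEVICES exposes no GPUs to the container")
    if env.get("JAX_PLATFORMS", "").strip().lower() == "cpu":
        hints.append("JAX_PLATFORMS=cpu restricts JAX to the CPU backend")
    return hints


if __name__ == "__main__":
    found = visibility_hints(os.environ)
    print("\n".join(found) if found else "no GPU-hiding variables detected")
```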
