Closed
Description
Hi,
I want to run JAX on my GPU, but I get the error "No visible GPU devices".
The exact output of running the Python file that begins with
from jax import jit, random, config
config.update('jax_enable_x64', True)
import jax
import jax.numpy as jnp
print(jax.devices())
...
is
E external/xla/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
Traceback (most recent call last):
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 879, in backends
    backend = _init_backend(platform)
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 970, in _init_backend
    backend = registration.factory()
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 668, in factory
    return xla_client.make_c_api_client(plugin_name, updated_options, None)
  File ".venv/lib/python3.10/site-packages/jaxlib/xla_client.py", line 200, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "src/systems.py", line 5, in <module>
    print(jax.devices())
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1082, in devices
    return get_backend(backend).devices()
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 1016, in get_backend
    return _get_backend_uncached(platform)
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 995, in _get_backend_uncached
    bs = backends()
  File ".venv/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 895, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
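The last line of the error suggests setting JAX_PLATFORMS=cpu to skip the failing backend. A minimal sketch of how that could be automated is shown here: it probes GPU initialization in a separate Python process and only falls back to CPU when the probe fails. The names `PROBE` and `probe_succeeds` are hypothetical helpers, not part of JAX, and this only works around the failure; it does not address why cuInit fails.

```python
import os
import subprocess
import sys

# Hypothetical probe command: try to list GPU devices in a child process,
# so that a failed CUDA initialization cannot affect this process.
PROBE = [sys.executable, "-c", "import jax; jax.devices('gpu')"]


def probe_succeeds(cmd, timeout=120):
    """Return True if the command exits with status 0 within the timeout."""
    try:
        result = subprocess.run(cmd, capture_output=True, timeout=timeout)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0


# Must run before the first `import jax` in this process.
if not probe_succeeds(PROBE):
    # Same setting the error message recommends: skip the GPU backend.
    os.environ["JAX_PLATFORMS"] = "cpu"
```

After this block, `import jax` followed by `print(jax.devices())` should report CPU devices instead of raising when the GPU backend cannot be initialized.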
System info (python version, jaxlib version, accelerator, etc.)
OS
Ubuntu 22.04.4 LTS.
JAX version
>> pip list | grep jax
jax 0.4.30
jax-cuda12-pjrt 0.4.30
jax-cuda12-plugin 0.4.30
jaxlib 0.4.30
Nvidia packages
>> pip list | grep nvidia
nvidia-cublas-cu12 12.5.2.13
nvidia-cuda-cupti-cu12 12.5.39
nvidia-cuda-nvcc-cu12 12.5.40
nvidia-cuda-nvrtc-cu12 12.2.140
nvidia-cuda-runtime-cu12 12.5.39
nvidia-cudnn-cu12 9.1.1.17
nvidia-cufft-cu12 11.2.3.18
nvidia-curand-cu12 10.3.3.141
nvidia-cusolver-cu12 11.6.2.40
nvidia-cusparse-cu12 12.4.1.24
nvidia-nccl-cu12 2.22.3
nvidia-nvjitlink-cu12 12.5.40
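As a quick consistency check on the list above, the following sketch groups the nvidia-cuda-* wheels by their CUDA major.minor version. It assumes that only the nvidia-cuda-* packages track the toolkit version (cuBLAS, cuDNN, NCCL and the others use their own numbering), and the function name `toolkit_minor_versions` is made up for illustration. Whether a mismatch found this way is related to the cuInit failure is not established here.

```python
from collections import defaultdict

# Output of `pip list | grep nvidia`, pasted from the list above.
PIP_LIST = """\
nvidia-cublas-cu12 12.5.2.13
nvidia-cuda-cupti-cu12 12.5.39
nvidia-cuda-nvcc-cu12 12.5.40
nvidia-cuda-nvrtc-cu12 12.2.140
nvidia-cuda-runtime-cu12 12.5.39
nvidia-cudnn-cu12 9.1.1.17
nvidia-cufft-cu12 11.2.3.18
nvidia-curand-cu12 10.3.3.141
nvidia-cusolver-cu12 11.6.2.40
nvidia-cusparse-cu12 12.4.1.24
nvidia-nccl-cu12 2.22.3
nvidia-nvjitlink-cu12 12.5.40
"""


def toolkit_minor_versions(pip_output):
    """Group nvidia-cuda-* packages by their CUDA major.minor version."""
    groups = defaultdict(list)
    for line in pip_output.splitlines():
        parts = line.split()
        if len(parts) != 2:
            continue  # skip blank or malformed lines
        name, version = parts
        # Assumption: only nvidia-cuda-* wheels follow the toolkit version.
        if name.startswith("nvidia-cuda-"):
            major_minor = ".".join(version.split(".")[:2])
            groups[major_minor].append(name)
    return dict(groups)


groups = toolkit_minor_versions(PIP_LIST)
for version, names in sorted(groups.items()):
    print(version, names)
```

On the list above this puts nvidia-cuda-nvrtc-cu12 alone under 12.2, while cupti, nvcc and runtime fall under 12.5.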
GPU
>> nvidia-smi
Thu Jun 20 19:20:45 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 2000 Ada Gene...    Off |   00000000:01:00.0 Off |                  N/A |
| N/A   48C    P3             11W /   35W |      15MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2306      G   /usr/lib/xorg/Xorg                              4MiB |
+-----------------------------------------------------------------------------------------+
CUDA
>> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Apr_17_19:19:55_PDT_2024
Cuda compilation tools, release 12.5, V12.5.40