How to Install Deep Learning Frameworks (PyTorch and JAX)
Sources:
- PyTorch GPU Setup Guide
For CUDA installation, see my How to Install and Setup CUDA post.
PyTorch (CUDA build)
The following command often installs a CPU-only build of PyTorch:
conda install pytorch
If you want GPU support, install a PyTorch build compiled against a CUDA version that is compatible with your NVIDIA driver.
Choose the correct CUDA build
Go to the official PyTorch install page and select the command for your platform and package manager:
https://pytorch.org/get-started/locally/
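For example, the selector currently produces pip commands of roughly this form (the cu121 tag below assumes a CUDA 12.1 build; copy the exact command the page gives you for your setup):
# Example only: PyTorch built against CUDA 12.1, installed from the official wheel index
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121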
A common source of confusion is the CUDA version. In practice, what matters is:
- the CUDA version supported by your NVIDIA driver;
- the CUDA version expected by the PyTorch binary you install.
To inspect the driver-side CUDA capability, run:
nvidia-smi
You can also use:
watch -n 0.1 nvidia-smi
if you want a live view, but a single nvidia-smi is usually enough for installation.
Check whether PyTorch can see your GPU
import torch
print(torch.cuda.is_available())  # True if PyTorch can see a CUDA-capable GPU
If PyTorch still cannot access the GPU, try forcing a CUDA allocation to get a more explicit error message:
torch.zeros(1).cuda()
This often reveals whether the problem is caused by a driver mismatch, a missing CUDA-enabled build, or an environment issue.
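A minimal sketch of such a probe (the comments are rough pointers, not an exhaustive diagnosis):
import torch

try:
    torch.zeros(1).cuda()  # force a real CUDA allocation
    print("GPU allocation succeeded on:", torch.cuda.get_device_name(0))
except Exception as e:
    # Typical causes: a CPU-only PyTorch build, a driver/CUDA version mismatch,
    # or no GPU visible to the process (e.g. CUDA_VISIBLE_DEVICES is empty).
    print("GPU allocation failed:", e)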
Useful CUDA queries in PyTorch
torch.cuda.current_device()  # index of the current GPU
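A few related queries that are often useful when debugging a setup (all standard torch.cuda / torch.version attributes):
import torch

print(torch.cuda.is_available())      # True only for a CUDA build with a working driver
print(torch.cuda.device_count())      # number of visible GPUs
print(torch.cuda.get_device_name(0))  # name of GPU 0
print(torch.version.cuda)             # CUDA version the PyTorch binary was built against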
JAX (CUDA build)
Official installation guide:
https://docs.jax.dev/en/latest/installation.html
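For a recent JAX release with CUDA 12 wheels, the pip-based install typically looks like the line below (the extra name has changed between releases, so confirm it against the guide above):
# Example only: recent JAX releases ship CUDA 12 wheels under the "cuda12" extra
pip install -U "jax[cuda12]"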
JAX installation is more fragile than PyTorch installation. In practice, version compatibility matters at three levels:
- your codebase may depend on a specific jax/jaxlib version (the JAX API changes frequently, and some APIs used by your codebase may be deprecated in later versions);
- that jaxlib version may require a specific CUDA stack;
- some CUDA-related compatibility constraints are effectively minor-version-sensitive.
So do not assume that “latest JAX” is the correct choice for an existing project.
Before installation: clean the environment
Remove existing JAX / CUDA-related Python packages from the current environment first.
Check what is installed:
pip list | grep -E 'jax|cuda|cudnn|nvidia'
If the output is non-empty, remove them:
pip uninstall -y $(pip list | grep -E 'jax|cuda|cudnn|nvidia' | awk '{print $1}')
Avoid LD_LIBRARY_PATH interference
Recent JAX documentation recommends avoiding a pre-set LD_LIBRARY_PATH during installation and runtime, because it can cause XLA to pick up unexpected libraries.
echo $LD_LIBRARY_PATH
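If this prints anything, consider clearing the variable before installing and running JAX (this assumes a bash-like shell and only affects the current session):
unset LD_LIBRARY_PATH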
Be careful with PTX / CUDA package mismatches
A frequent failure mode is that one of the pip-installed CUDA components is too new for the driver stack on the machine.
In particular, watch out for packages such as:
nvidia-cuda-nvcc-cu12
This package provides the PTX compiler. If its version is too new relative to the driver-supported CUDA stack, JAX/XLA compilation may fail.
For example, something like this is usually reasonable (the package versions are illustrative):
driver CUDA version: 12.2
nvidia-cuda-nvcc-cu12: 12.2.x
Whereas this kind of combination can fail, because the pip-installed PTX compiler is much newer than the driver stack:
driver CUDA version: 12.2
nvidia-cuda-nvcc-cu12: 12.4.x
The general rule is simple: do not mix an older driver stack with much newer pip-installed CUDA compiler components.
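To see which CUDA components pip has installed, and optionally pin the PTX compiler package closer to the driver-supported CUDA version (the version specifier below is only an illustration):
# Inspect the pip-installed CUDA components
pip list | grep -E 'nvidia-cuda|cudnn|cublas'
# Example: keep the PTX compiler in the 12.2 series to match a 12.2 driver stack
pip install "nvidia-cuda-nvcc-cu12==12.2.*"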
Check whether JAX runs correctly
python - <<'PY'
import jax
import jax.numpy as jnp
print(jax.devices())           # should list a GPU device, e.g. [CudaDevice(id=0)]
print(jnp.ones((2, 2)).sum())  # runs a small computation on the default device
PY
If this script prints a GPU device and finishes successfully, the installation is likely usable.
Practical advice
- For PyTorch, installation is usually straightforward once you choose the correct CUDA build.
- For JAX, always treat the environment as version-sensitive.
- For old projects, first determine the exact jax/jaxlib versions expected by the codebase, then work backward to a compatible CUDA setup.
- If installation fails, inspect both Python packages and the driver/toolchain boundary; many JAX failures are version skew problems rather than code problems.
- If you want to install PyTorch and JAX together in one environment, you must install the CPU build of one of them, since the current releases of PyTorch and JAX have incompatible CUDA version dependencies (->Source); a sketch of such a mixed environment follows below.
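A minimal sketch of that mixed setup, assuming GPU PyTorch plus CPU-only JAX (swap the roles if your workload is JAX-heavy; the CUDA tag is only an example):
# GPU PyTorch (example CUDA 12.1 build; take the exact command from pytorch.org)
pip install torch --index-url https://download.pytorch.org/whl/cu121
# CPU-only JAX: installing plain "jax" with no CUDA extra pulls the CPU wheel
pip install -U jax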
Common Problems
Error:
NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.
This often happens after a kernel update.
Solution:
Remove the current NVIDIA driver packages:
# First of all, we need to remove all of the previous dependencies
sudo apt-get remove --purge '^nvidia-.*'
sudo apt-get purge nvidia-*
sudo apt-get update
sudo apt-get autoremove
Then reinstall CUDA.
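If you mainly need the driver back rather than the full toolkit, one common route on Ubuntu is the ubuntu-drivers tool (assuming it is available); otherwise follow the CUDA post linked at the top:
# Let Ubuntu pick and install the recommended NVIDIA driver, then reboot
sudo ubuntu-drivers autoinstall
sudo reboot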
Error:
Could not load library libcudnn_cnn_infer.so.8. Error: libcuda.so: cannot open shared object file: No such file or directory
Solution:
First, check the system to see whether libcudnn_cnn_infer.so.8 exists:
- ldconfig -p | grep libcudnn_cnn returned nothing.
- ldconfig -p | grep libcuda returned:
libcudart.so.12 (libc6,x86-64) => /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12
libcudart.so.12 (libc6,x86-64) => /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudart.so.12
libcudart.so.12 (libc6,x86-64) => /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudart.so.12
libcudart.so (libc6,x86-64) => /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so
libcudart.so (libc6,x86-64) => /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudart.so
libcudart.so (libc6,x86-64) => /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudart.so
libcudadebugger.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcudadebugger.so.1
libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
libcuda.so.1 (libc6) => /lib/i386-linux-gnu/libcuda.so.1
libcuda.so (libc6) => /lib/i386-linux-gnu/libcuda.so
So libcudnn_cnn_infer.so.8 is still missing.
Make sure that the CUDA driver and the CUDA toolkit are installed. Check the output of
nvidia-smi
and
nvcc --version
According to this reply, we just need to:
conda activate <your env>
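The usual continuation of this kind of fix is to make the cuDNN libraries inside the environment visible to the dynamic loader. A sketch, assuming cuDNN was installed into the environment through the pip nvidia-cudnn-* packages (adjust the Python version in the path to your environment):
# Assumption: cuDNN lives in the env's site-packages under nvidia/cudnn/lib
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib:$LD_LIBRARY_PATH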
Error:
/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Source:
It seems that the libtinfo shared library shipped with conda does not provide its version information. So it's a problem on their end.
I was able to work around the problem by replacing the conda-shipped libtinfo.so with another copy of the same version that does contain the version information. For instance, in my case:
rm ${CONDA_PREFIX}/lib/libtinfo*
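The removed file then needs a replacement, typically a symlink to the system copy that does carry the version information (the system path below is the usual Debian/Ubuntu location; adjust it if yours differs):
# Point the conda environment at the system libtinfo instead
ln -s /lib/x86_64-linux-gnu/libtinfo.so.6 ${CONDA_PREFIX}/lib/libtinfo.so.6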
Error:
AttributeError: 'NoneType' object has no attribute 'glGetError'
You need to:
pip install --upgrade pyrender
https://github.com/50ButtonsEach/fliclib-linux-dist/issues/44
Error:
version `GLIBCXX_3.4.30' not found (raised from a library loaded inside a conda environment)
Source: https://askubuntu.com/questions/1418016/glibcxx-3-4-30-not-found-in-conda-environment
Solution:
First, see which GLIBCXX versions the system libstdc++ provides:
strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX
strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX_3.4.
Then see what the conda environment's copy provides:
strings /home/lyk/miniconda3/envs/recall2imagine/bin/../lib/libstdc++.so.6 | grep GLIBCXX_3.4.
If the system library has the required version but the environment's copy does not, replace the latter with a symlink to the system library:
ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /home/lyk/miniconda3/envs/recall2imagine/bin/../lib/libstdc++.so.6
Error:
2024-04-11 06:16:01.747343: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
This usually indicates that the GPU is out of memory.
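Two things that often help: free memory held by other processes (check nvidia-smi), or stop JAX/XLA from preallocating most of the GPU memory up front. These are the standard JAX GPU memory environment variables:
# Disable XLA's default preallocation of most of the GPU memory,
# or cap the fraction it is allowed to grab.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5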
Error:
libGL error: failed to load driver: swrast
Solution:
One option is simply to reboot and wait for a while, after which it recovers on its own. I don't know why...
One solution is: (-->Source)
From the latter link:
According to online information, the problem is with the libstdc++.so file shipped with Anaconda (I use this Python distribution): it cannot be matched with the system's driver, so we remove it and create a soft link there to the libstdc++.so that comes with Linux.
To solve this problem, run this in bash:
cd /home/$USER/miniconda3/envs/{your env}/lib
where $USER should be your own username and {your env} the name of your conda environment.
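The rest of the fix, per the quoted description, is to move the Anaconda copy out of the way and link the system libstdc++ in its place; a sketch, assuming the usual Ubuntu system path:
# Run inside the environment's lib directory
mkdir -p backup
mv libstdc++.so* backup/   # keep the originals, just in case
ln -s /usr/lib/x86_64-linux-gnu/libstdc++.so.6 libstdc++.so.6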
If you have
xcb_connection_has_error() returned true
This won't affect your training. But you can use xvfb-run to avoid it.
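For example, wrapping the training command in xvfb-run gives the process a virtual X display (train.py is a placeholder for your own entry point):
# -a picks a free display number automatically
xvfb-run -a python train.py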