How to Install Deep Learning Frameworks (PyTorch and JAX)

Sources:

  1. PyTorch GPU Setup Guide

For CUDA installation, see my How to Install and Setup CUDA post.

PyTorch (CUDA build)

The following command often installs a CPU-only build of PyTorch:

conda install pytorch

If you want GPU support, install a PyTorch build compiled against a CUDA version that is compatible with your NVIDIA driver.

Choose the correct CUDA build

Go to the official PyTorch install page and select the command for your platform and package manager:

https://pytorch.org/get-started/locally/

A common source of confusion is the CUDA version. In practice, what matters is:

  • the CUDA version supported by your NVIDIA driver;
  • the CUDA version expected by the PyTorch binary you install.

To inspect the driver-side CUDA capability, run:

nvidia-smi

You can also use:

watch -n 0.1 nvidia-smi

if you want a live view, but a single nvidia-smi is usually enough for installation.
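The driver-side version can also be read from a script. The sketch below is only a minimal helper, assuming the nvidia-smi banner contains a "CUDA Version: X.Y" field (as current drivers print it). The function names parse_cuda_version and driver_cuda_version are my own, not part of any library.

```python
import re
import shutil
import subprocess


def parse_cuda_version(smi_output: str):
    """Return (major, minor) from an nvidia-smi banner, or None if absent."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def driver_cuda_version():
    """Run nvidia-smi if it is on PATH and parse its banner."""
    if shutil.which("nvidia-smi") is None:
        return None  # NVIDIA driver tools are not installed
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    return parse_cuda_version(result.stdout)


if __name__ == "__main__":
    print("Driver supports CUDA up to:", driver_cuda_version())
```

The returned tuple is the highest CUDA version the driver supports, so the PyTorch build you pick should target a CUDA version that is not newer than it.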

Check whether PyTorch can see your GPU

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count(), "GPU(s) available.")
    print("Current GPU:", torch.cuda.get_device_name(0))
else:
    print("No GPU available. Falling back to CPU.")

If PyTorch still cannot access the GPU, try forcing a CUDA allocation to get a more explicit error message:

torch.zeros(1).cuda()

This often reveals whether the problem is caused by a driver mismatch, a missing CUDA-enabled build, or an environment issue.
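To tell a "missing CUDA-enabled build" apart from a "driver problem" without reading a traceback, something like the following sketch may help. It relies on torch.version.cuda, which is None when the installed binary has no CUDA runtime compiled in. The helper name describe_torch_build is hypothetical.

```python
def describe_torch_build() -> str:
    """Summarize whether the installed PyTorch is a CUDA build and sees a GPU."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed in this environment."

    built_for = torch.version.cuda  # None when no CUDA runtime is compiled in
    if built_for is None:
        return (f"PyTorch {torch.__version__} has no CUDA support compiled in "
                "(CPU-only or non-CUDA build).")
    if not torch.cuda.is_available():
        return (f"PyTorch {torch.__version__} was built for CUDA {built_for}, "
                "but no usable GPU was found (check the driver).")
    return (f"PyTorch {torch.__version__} (CUDA {built_for}) sees "
            f"{torch.cuda.device_count()} GPU(s).")


print(describe_torch_build())
```

If the first message appears, reinstalling with the command from the official install page is the likely fix. If the second appears, the build is fine and the driver side deserves a closer look.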

Useful CUDA queries in PyTorch

torch.cuda.current_device()    # index of the current GPU
torch.cuda.get_device_name(0)  # name of GPU 0
torch.cuda.device_count()      # number of visible GPUs

JAX (CUDA build)

Official installation guide:

https://docs.jax.dev/en/latest/installation.html

JAX installation is more fragile than PyTorch installation. In practice, version compatibility matters at three levels:

  1. your codebase may depend on a specific jax / jaxlib version (the JAX API changes frequently, and some APIs your codebase uses may be deprecated or removed in later versions);
  2. that jaxlib version may require a specific CUDA stack;
  3. some CUDA-related compatibility constraints are effectively minor-version-sensitive.

So do not assume that “latest JAX” is the correct choice for an existing project.
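One way to make the pinned versions explicit is to compare what is installed against what the project expects. This is a rough sketch using only the standard library. The pins shown are hypothetical placeholders, and the helper names are my own.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins: dict, lookup=installed_version) -> dict:
    """Map each package whose installed version differs from its pin
    to a (wanted, found) tuple."""
    mismatches = {}
    for package, wanted in pins.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    # Hypothetical pins; replace them with the versions your project expects.
    pins = {"jax": "0.4.23", "jaxlib": "0.4.23"}
    print(find_mismatches(pins))
```

An empty dictionary means the environment matches the pins. Anything else lists the expected version next to what is actually installed (None if the package is missing).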

Before installation: clean the environment

Remove existing JAX / CUDA-related Python packages from the current environment first.

Check what is installed:

pip list | grep -E 'jax|cuda|cudnn|nvidia'

If the output is non-empty, remove the listed packages:

pip uninstall -y $(pip list | grep -E 'jax|cuda|cudnn|nvidia' | awk '{print $1}')
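Since the uninstall command removes everything the pattern matches, it may be worth previewing the matches first. The sketch below lists matching distributions from Python, using the same pattern as the grep above but applied to package names only. The function names are illustrative.

```python
import re
from importlib import metadata

# Same alternatives as the grep pattern, matched against package names only.
PATTERN = re.compile(r"jax|cuda|cudnn|nvidia")


def select_matching(names):
    """Keep only the package names that match the pattern, sorted."""
    return sorted(name for name in names if PATTERN.search(name))


def installed_matches():
    """List matching distributions installed in the current environment."""
    names = {dist.metadata["Name"] for dist in metadata.distributions()}
    return select_matching(name for name in names if name)


if __name__ == "__main__":
    for name in installed_matches():
        print(name)
```

Review the printed names before uninstalling, because a broad pattern such as "cuda" can also match packages unrelated to JAX.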

Avoid LD_LIBRARY_PATH interference

Recent JAX documentation recommends avoiding a pre-set LD_LIBRARY_PATH during installation and runtime, because it can cause XLA to pick up unexpected libraries.

echo $LD_LIBRARY_PATH

# If it is non-empty:
unset LD_LIBRARY_PATH
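If you would rather not change the variable for the whole shell, the same effect can be limited to one child process. This is a minimal sketch with a helper name of my own. The child command is just a placeholder that prints what it sees.

```python
import os
import subprocess
import sys


def env_without_ld_library_path(env=None) -> dict:
    """Return a copy of the environment with LD_LIBRARY_PATH removed."""
    source = os.environ if env is None else env
    return {k: v for k, v in source.items() if k != "LD_LIBRARY_PATH"}


if __name__ == "__main__":
    # Placeholder child process: prints the variable as the child sees it.
    child = [sys.executable, "-c",
             "import os; print(os.environ.get('LD_LIBRARY_PATH'))"]
    subprocess.run(child, env=env_without_ld_library_path(), check=True)
```

Replacing the placeholder with your real training command runs it with a clean library path while leaving the parent shell untouched.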

Be careful with PTX / CUDA package mismatches

A frequent failure mode is that one of the pip-installed CUDA components is too new for the driver stack on the machine.

In particular, watch out for packages such as:

nvidia-cuda-nvcc-cu12

This package provides the PTX compiler. If its version is too new relative to the driver-supported CUDA stack, JAX/XLA compilation may fail.

For example, something like this is usually reasonable:

driver CUDA version: 12.2
PTX compiler version: 12.1.105

Whereas this kind of combination can fail:

driver CUDA version: 12.2
PTX compiler version: 12.9.86

The general rule is simple: do not mix an older driver stack with much newer pip-installed CUDA compiler components.
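The rule can be turned into a quick sanity check. Note that this sketch is a deliberately strict simplification: it flags any PTX compiler whose major.minor version is newer than the driver's CUDA version, while the real compatibility rules are more nuanced. The function names are hypothetical.

```python
def major_minor(version: str):
    """Parse '12.9.86' or '12.2' into a (major, minor) tuple of ints."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def ptx_compiler_too_new(driver_cuda: str, ptx_compiler: str) -> bool:
    """Heuristic: flag a PTX compiler newer than the driver's CUDA version."""
    return major_minor(ptx_compiler) > major_minor(driver_cuda)


print(ptx_compiler_too_new("12.2", "12.1.105"))  # False
print(ptx_compiler_too_new("12.2", "12.9.86"))   # True
```

The two calls reproduce the two combinations shown above.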

Check whether JAX runs correctly

python - <<'PY'
import jax
import jax.numpy as jnp

print("jax", jax.__version__)
print("jaxlib", jax.lib.__version__)
print("devices", jax.devices())

# Compile and run a small matrix product to exercise the XLA toolchain.
@jax.jit
def f(x):
    return (x @ x).sum()

x = jnp.ones((1024, 1024), dtype=jnp.float32)
print(f(x).block_until_ready())
PY

If this script prints a GPU device and finishes successfully, the installation is likely usable.
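JAX silently falls back to the CPU when the GPU backend cannot be initialized, so it can be useful to check the selected backend explicitly instead of reading the device list by eye. This is a small sketch built on jax.default_backend() and jax.devices(). The helper name and the set of accepted backend names are my own assumptions.

```python
def jax_gpu_summary() -> str:
    """Report which backend JAX selected, without raising if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "JAX is not installed in this environment."

    backend = jax.default_backend()  # e.g. "cpu", "gpu", or "tpu"
    platforms = sorted({device.platform for device in jax.devices()})
    if backend in ("gpu", "cuda", "rocm"):
        return f"JAX {jax.__version__} is using the GPU backend: {platforms}"
    return f"JAX {jax.__version__} fell back to backend '{backend}': {platforms}"


print(jax_gpu_summary())
```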


Practical advice

  • For PyTorch, installation is usually straightforward once you choose the correct CUDA build.
  • For JAX, always treat the environment as version-sensitive.
  • For old projects, first determine the exact jax / jaxlib versions expected by the codebase, then work backward to a compatible CUDA setup.
  • If installation fails, inspect both Python packages and the driver/toolchain boundary; many JAX failures are version skew problems rather than code problems.
  • If you want to install PyTorch and JAX together in one environment, you must install the CPU version of one of them, since the current releases of PyTorch and JAX have incompatible CUDA version dependencies. (Source)

Common Problems

Error:

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.

This often happens after a kernel update.

Solution:

  1. Remove the current NVIDIA driver:

    # Remove all previously installed NVIDIA packages and their dependencies
    sudo apt-get remove --purge '^nvidia-.*'
    sudo apt-get purge 'nvidia-*'
    sudo apt-get update
    sudo apt-get autoremove
  2. Reinstall CUDA.
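Before purging anything, it may be worth confirming that the kernel module really is the missing piece. The sketch below assumes a Linux system where the NVIDIA driver exposes /proc/driver/nvidia/version once its kernel module is loaded. The function name is hypothetical.

```python
from pathlib import Path


def nvidia_module_loaded(proc_file="/proc/driver/nvidia/version") -> bool:
    """True if the NVIDIA kernel module has registered its /proc entry."""
    return Path(proc_file).is_file()


if __name__ == "__main__":
    if nvidia_module_loaded():
        print(Path("/proc/driver/nvidia/version").read_text().strip())
    else:
        print("NVIDIA kernel module is not loaded; "
              "the driver probably needs to be reinstalled or rebuilt.")
```

If the file is missing right after a kernel update, that is consistent with the driver module not having been rebuilt for the new kernel.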


Error:

Could not load library libcudnn_cnn_infer.so.8. Error: libcuda.so: cannot open shared object file: No such file or directory

Solution:

First, check whether libcudnn_cnn_infer.so.8 exists on the system:

  1. ldconfig -p | grep libcudnn_cnn returned nothing.

  2. ldconfig -p | grep libcuda returned:

    libcudart.so.12 (libc6,x86-64) => /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12
    libcudart.so.12 (libc6,x86-64) => /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudart.so.12
    libcudart.so.12 (libc6,x86-64) => /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudart.so.12
    libcudart.so (libc6,x86-64) => /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so
    libcudart.so (libc6,x86-64) => /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudart.so
    libcudart.so (libc6,x86-64) => /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudart.so
    libcudadebugger.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcudadebugger.so.1
    libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    libcuda.so.1 (libc6) => /lib/i386-linux-gnu/libcuda.so.1
    libcuda.so (libc6) => /lib/i386-linux-gnu/libcuda.so

    Neither query finds libcudnn_cnn_infer.so.8.

Make sure that the NVIDIA driver and the CUDA toolkit are both installed. Check the output of

nvidia-smi

and

nvcc --version

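According to this reply, the fix is to add the cuDNN libraries installed inside the conda environment to LD_LIBRARY_PATH whenever the environment is activated: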

conda activate <your env>
mkdir -p $CONDA_PREFIX/etc/conda/activate.d
echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
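To check that the path computed by the activation script actually exists, the lookup can be repeated from Python. This is a sketch under the assumption that cuDNN was installed as the pip package that provides the nvidia.cudnn module, as the commands above imply. The function name is my own.

```python
import importlib.util
import os


def find_cudnn_lib_dir():
    """Return the lib directory of a pip-installed nvidia.cudnn, or None."""
    try:
        spec = importlib.util.find_spec("nvidia.cudnn")
    except ModuleNotFoundError:
        return None  # the parent 'nvidia' package is not installed
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = list(spec.submodule_search_locations)[0]
    lib_dir = os.path.join(package_dir, "lib")
    return lib_dir if os.path.isdir(lib_dir) else None


print("cuDNN lib directory:", find_cudnn_lib_dir())
```

If this prints None, the activation script will add a path that does not exist, and the library will still not be found.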

Error:

/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)

Solution (quoted from the source):

It seems that the libtinfo shared library shipped with conda does not provide its version information, so the problem is on conda's side.

I was able to work around the problem by replacing it with another libtinfo.so of the same version that does contain the version information. In my case:

rm ${CONDA_PREFIX}/lib/libtinfo*
ln -s /lib/x86_64-linux-gnu/libtinfo.so.6 ${CONDA_PREFIX}/lib/libtinfo.so.6

Error:

AttributeError: 'NoneType' object has no attribute 'glGetError'

Solution: upgrade pyrender and install the libgcc and OpenGL packages:

pip install --upgrade pyrender
conda install libgcc
apt-get install python3-opengl

https://github.com/50ButtonsEach/fliclib-linux-dist/issues/44

Error: a GLIBCXX_3.4.x version is reported as not found inside a conda environment (the link below discusses GLIBCXX_3.4.30).

https://askubuntu.com/questions/1418016/glibcxx-3-4-30-not-found-in-conda-environment

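First, list the GLIBCXX versions provided by the system libstdc++:

strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX

Then compare the GLIBCXX_3.4.x entries of the system library

strings /usr/lib/x86_64-linux-gnu/libstdc++.so.6 | grep GLIBCXX_3.4.

with those of the copy inside the conda environment (adjust the path to your own environment):

strings /home/lyk/miniconda3/envs/recall2imagine/bin/../lib/libstdc++.so.6 | grep GLIBCXX_3.4.

If the environment's copy lacks the required version, replace it with a symlink to the system library:

ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /home/lyk/miniconda3/envs/recall2imagine/bin/../lib/libstdc++.so.6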
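Comparing the two strings outputs by eye is error-prone, so here is a rough Python equivalent that reports which GLIBCXX_3.4.N versions the system library has and the environment's copy lacks. It scans the raw bytes of each file, much like strings piped into grep. The function names are my own.

```python
import re

GLIBCXX_RE = re.compile(rb"GLIBCXX_3\.4\.(\d+)")


def glibcxx_versions(data: bytes) -> set:
    """Collect the N of every GLIBCXX_3.4.N string found in a binary blob."""
    return {int(m.group(1)) for m in GLIBCXX_RE.finditer(data)}


def missing_in_env(system_lib: str, env_lib: str) -> list:
    """GLIBCXX_3.4.N versions the system library has but the env copy lacks."""
    with open(system_lib, "rb") as f:
        system = glibcxx_versions(f.read())
    with open(env_lib, "rb") as f:
        env = glibcxx_versions(f.read())
    return [f"GLIBCXX_3.4.{n}" for n in sorted(system - env)]
```

Call missing_in_env with the two paths used in the strings commands. A non-empty result before creating the symlink, and an empty one afterwards, would indicate that the fix took effect.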

Error:

2024-04-11 06:16:01.747343: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-04-11 06:16:01.747442: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 8877244416 bytes free, 51033931776 bytes total.

The log shows only about 8.9 GB free out of 51 GB, so the GPU is most likely out of memory.
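The following sketch does two things. It computes the free fraction from the "Memory usage" log line, and it sets XLA_PYTHON_CLIENT_PREALLOCATE=false, the JAX setting that turns off up-front GPU memory preallocation. Whether disabling preallocation resolves this particular error is my assumption, not something confirmed above. The variable must be set before jax is imported.

```python
import os
import re

# Assumption: turning off JAX's up-front GPU memory preallocation may help
# when other processes already hold most of the memory. This has to be set
# before `import jax` to take effect.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


def free_fraction(log_line: str):
    """Parse 'N bytes free, M bytes total' and return N / M, or None."""
    match = re.search(r"(\d+) bytes free, (\d+) bytes total", log_line)
    if match is None:
        return None
    free, total = int(match.group(1)), int(match.group(2))
    return free / total


line = "Memory usage: 8877244416 bytes free, 51033931776 bytes total."
print(f"{free_fraction(line):.1%} of GPU memory was free")
```

A low free fraction at startup usually means another process is holding the memory, which nvidia-smi can confirm.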


Error:

libGL error: failed to load driver: swrast
X Error of failed request: BadValue (integer parameter out of range for operation)
Major opcode of failed request: 152 (GLX)
Minor opcode of failed request: 3 (X_GLXCreateContext)
Value in failed request: 0x0
Serial number of failed request: 890
Current serial number in output stream: 891

Solution:

Rebooting and waiting for a while made the error go away for me, although I do not know why.

Another solution, quoted from the source:

According to online information, there is a problem with the libstdc++.so file in Anaconda (I use this commercial Python distribution): it cannot be associated with the system's driver. So we move it aside and use the libstdc++ that comes with Linux, creating a soft link to it in the environment.

To apply this fix, run the following in bash:

cd ~/miniconda3/envs/{your env}/lib             # lib directory of the conda environment; adjust if conda lives elsewhere
mkdir backup                                     # folder to keep the original libstdc++
mv libstd* backup                                # move all libstdc++ files there, including soft links
cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ./   # copy the system's C++ runtime library here
ln -s libstdc++.so.6 libstdc++.so
ln -s libstdc++.so.6 libstdc++.so.6.0.19

where {your env} is the name of your conda environment.


If you see

xcb_connection_has_error() returned true
xcb_connection_has_error() returned true

it will not affect your training, but you can run the program under xvfb-run to avoid the message.
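If the training job is launched from a Python script, the xvfb-run wrapper can be applied conditionally. This is a minimal sketch that prefixes the command with xvfb-run -a (automatic display number selection) only when the tool is installed. The helper name and the train.py entry point are placeholders.

```python
import shutil
import sys


def with_xvfb(command, xvfb_path):
    """Prefix a command with `xvfb-run -a` when a path to xvfb-run is given."""
    if not xvfb_path:
        return list(command)  # fall back to running the command directly
    return [xvfb_path, "-a", *command]


if __name__ == "__main__":
    # Placeholder entry point; substitute your real training script.
    command = with_xvfb([sys.executable, "train.py"], shutil.which("xvfb-run"))
    print("Would run:", command)
```

Passing the resulting list to subprocess.run starts the job inside a virtual X display when one is available.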