HPC Cluster setup examples#

This guide collects installation and setup instructions for NetKet with its two distributed computing modes (MPI and/or sharding). It assumes that you want to run on GPUs.

The guide is based on our own experience getting NetKet running on France’s Jean-Zay HPC cluster. If you struggle to install it on your own machine, do get in touch on GitHub.

Setting up MPI with CUDA-aware MPI#

We assume here that you want to use CUDA-aware MPI, because it allows for maximal performance. If your cluster does not provide CUDA-aware MPI, the installation is usually simpler: you do not need to load all the CUDA modules and can just pip-install mpi4jax directly, without worrying about build-time isolation (see the sketch below).
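
As a rough sketch, the non-CUDA-aware setup can look like the following (the module names here are hypothetical placeholders; adapt them to whatever your cluster provides):

# Minimal sketch for the non-CUDA-aware case
module load python openmpi

pip install --upgrade pip
pip install --upgrade "jax[cuda12]"   # with this extra, jax ships its own CUDA libraries
pip install --upgrade mpi4py mpi4jax
pip install --upgrade netket

For the CUDA-aware setup we use on Jean-Zay, the full sequence is the following: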

# Pick your environment name
ENV_NAME=jax_gpu_mpi_amd

module load anaconda-py3 # To create the environment. You might just load python or some equivalent
module load gcc/12.2.0   # Compilers necessary to compile mpi4py/jax. You might have to load something equivalent
                         # prefer recent compilers
# You must load the CUDA-aware MPI version, and all related CUDA and cuDNN libraries
module load cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda 

# Create the environment
conda create -y --name $ENV_NAME python=3.11 
conda activate $ENV_NAME

# Always update pip
pip install --upgrade pip

# Remove mpi4py and mpi4jax from build cache, to make sure you are building them anew
pip cache remove mpi4py
pip cache remove mpi4jax

# Install a jax build that uses the CUDA version installed locally on the cluster
pip install --upgrade "jax[cuda12_local]"
pip install --upgrade mpi4py cython
pip install --upgrade --no-build-isolation mpi4jax
pip install --upgrade netket 

Note that the pip cache remove steps are useful because pip caches previously built versions of mpi4py and mpi4jax, and reusing a stale build can cause issues. Also, if you run into errors, always start again from a brand-new environment.
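
Before submitting a full job, it is worth quickly checking that everything imports and that jax sees the GPUs, for example from an interactive session on a GPU node (these commands are only a suggestion):

python -c "import jax; print(jax.devices())"
python -c "from mpi4py import MPI; import mpi4jax; print(MPI.Get_library_version())"
python -c "import netket as nk; print(nk.__version__)"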

Running simulations with MPI on GPUs#

#!/bin/bash
#SBATCH --job-name=test_mpi
#SBATCH --output=test_mpi_%j.txt
#SBATCH --hint=nomultithread  # Disable Hyperthreading

#SBATCH --ntasks=4
#SBATCH --cpus-per-task=5
#SBATCH --gres=gpu:4          # here you should insert the total number of GPUs per node
#SBATCH --ntasks-per-node=4   # maximum number of tasks per node, should match the maximum
                              # number of GPUs per node.
#SBATCH --time=01:00:00

ENV_NAME=jax_gpu_mpi_amd

# Load the same packages you used during installation. In our case that is
module purge
module load gcc/12.2.0 anaconda-py3 
module load cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda

# Load the conda environment or equivalent
conda activate $ENV_NAME

# This is to use fast, direct GPU-to-GPU communication
# If you do not have CUDA-aware MPI, set this to 0 instead
export MPI4JAX_USE_CUDA_MPI=1
# This is not strictly needed; it simply tells NetKet to forcefully use MPI
export NETKET_MPI=1
# This automatically assigns only 1 GPU per rank (MPI cannot use more than 1)
# If you do not use this, you should make sure that every rank sees only 1 GPU.
export NETKET_MPI_AUTODETECT_LOCAL_GPU=1
# Tell JAX that we want to use GPUs. This is generally not needed but can't hurt
export JAX_PLATFORM_NAME=gpu

srun python yourscript.py
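
Assuming you saved this script as, say, job_mpi.slurm (the file name is arbitrary), it is submitted as usual with:

sbatch job_mpi.slurm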

Running simulations with sharding on GPUs#

To run simulations with sharding on GPUs, just install the following packages in the environment:

netket
jax[cuda]

Nothing else is needed: no MPI.
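
For example, inside a fresh environment something along these lines should be enough (as in the MPI section, use jax[cuda12_local] instead if you need jax to rely on the cluster's CUDA modules):

pip install --upgrade pip
pip install --upgrade "jax[cuda]" netket

An example submission script then looks as follows: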

#!/bin/bash
#SBATCH --job-name=test_sharding
#SBATCH --output=test_sharding_%j.txt
#SBATCH --hint=nomultithread  # Disable Hyperthreading

#SBATCH --ntasks=4
#SBATCH --cpus-per-task=5
#SBATCH --gres=gpu:4          # here you should insert the total number of GPUs per node
#SBATCH --time=01:00:00

module purge

# Load the same packages you used during installation. In our case that is
module load gcc/12.2.0 anaconda-py3 

# Load the conda environment or equivalent (use the name of the environment
# in which you installed jax and netket)
ENV_NAME=jax_gpu_sharding
conda activate $ENV_NAME

# Tell NetKet to use experimental sharding mode.
export NETKET_EXPERIMENTAL_SHARDING=1
# Tell JAX that we want to use GPUs. This is generally not needed but can't hurt
export JAX_PLATFORM_NAME=gpu

srun python yourscript.py

The script itself is structured as follows:

import jax

# Initialize the distributed runtime before any other jax call. On SLURM
# clusters this automatically detects the coordinator address and the
# number/index of the processes.
jax.distributed.initialize()

# All GPUs, across every node of the job
print(jax.devices())
# Only the GPUs attached to this process
print(jax.local_devices())

# Import netket only after jax.distributed.initialize() has been called
import netket as nk
Cluster-specific information#

This is a sparse collection of instructions written for some specific HPC clusters, which might also help users get started elsewhere.