Multiple nodes (GPUs)#
Setting up JAX distributed#
Warning
Important: The JAX distributed setup details may change over time. This documentation was written in July 2025.
For multi-node calculations, initialize JAX’s distributed computing at the beginning of your script. The function jax.distributed.initialize() (see documentation) works automatically with SLURM, but if you’re not using SLURM, there are many parameters that need to be specified manually.
Always verify that JAX correctly detects all processes and devices by adding print statements as shown below. This is extremely useful to ensure that your code is being distributed correctly and to debug possible deadlocks in the initialize function. As a general suggestion, running the script below can be useful to verify that everything works correctly.
import jax
# Initialize distributed JAX (works automatically with SLURM)
jax.distributed.initialize()
# Always print this to verify correct setup
print(f"[{jax.process_index()}/{jax.process_count()}] devices:", jax.devices(), flush=True)
print(f"[{jax.process_index()}/{jax.process_count()}] local devices:", jax.local_devices(), flush=True)
print(f"I will be running a calculation among {jax.process_count()} tasks, "
f"using a total of {len(jax.devices())} devices "
f"({len(jax.local_devices())} per slurm task). "
f"If this does not match your expected number of total devices, "
f"something is misconfigured", flush=True)
import netket as nk
# ... rest of your code
Warning
Checking that the setup is correct
A common mistake is that different nodes do not communicate with each other, resulting in your calculation running independently on each node instead of as a single distributed computation.
To verify that your distributed setup is working correctly, check that:
The code above prints
[i/N]for each process, whereigoes from 0 to N-1 and N is the total number of processes across all nodesjax.local_devices()shows the GPU(s) visible to each individual processjax.devices()shows all GPUs across all nodes (total should be N_gpus_per_node × N_nodes)The diagnostic message shows the expected total number of devices for your cluster configuration
SLURM job script example#
When using jax.distributed.initialize() with multiple nodes, JAX assumes there is 1 task per GPU. Therefore, you must set --ntasks-per-node equal to the number of GPUs per node. In the example above, we assume nodes with 4 GPUs.
The --gres=gpu:N option works correctly with JAX’s automatic detection, but --gpus-per-task does not work and should be avoided.
Here’s a typical SLURM script for running NetKet across multiple nodes:
#!/bin/bash
#SBATCH --job-name=netket-distributed
#SBATCH --output=netket_%j.txt
#SBATCH --nodes=2 # Number of nodes
#SBATCH --ntasks-per-node=4 # Tasks per node (must equal GPUs per node for jax.distributed.initialize())
#SBATCH --cpus-per-task=5 # CPU cores per task
#SBATCH --gres=gpu:4 # GPUs per node
#SBATCH --time=02:00:00
module purge
# Force JAX to use GPUs, and fail if he cannot use them.
export JAX_PLATFORM_NAME=gpu
# Launch with srun (SLURM's parallel launcher)
srun uv run python your_netket_script.py
It is possible to use different SLURM configurations, but they must be configured manually. For more details, refer to jax.distributed.initialize() and the JAX clusters folder in the JAX repository.
Note that jax automatically installs a copy of CUDA in the python environment, so you don’t need to install your own or load the relevant modules. This way, CUDA is always updated and you are sure of using the correct version. It is best to avoid loading the cluster-provided cuda versions.
Working with distributed arrays#
When working with distributed computing, special care is needed for I/O operations and array printing.
Safe printing of distributed arrays#
Distributed arrays need to be replicated before printing to avoid errors:
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding
# Example distributed array
mesh = jax.distributed.get_abstract_mesh()
sharded_spec = NamedSharding(mesh, jax.sharding.PartitionSpec('S'))
replicated_spec = NamedSharding(mesh, jax.sharding.PartitionSpec())
distributed_array = jax.device_put(jnp.ones((100, 50)), sharded_spec)
# WRONG: This may cause errors or incomplete output
# print(distributed_array)
# CORRECT: Replicate before printing
replicated_array = jax.lax.with_sharding_constraint(
distributed_array, replicated_spec
)
if jax.process_index() == 0:
print("Array shape:", replicated_array.shape)
print("Array values:", replicated_array)
Safe file I/O#
Note
NetKet’s built-in loggers handle distributed I/O correctly automatically. The pattern below is needed only when writing custom I/O code.
When saving data, ensure only one process writes to avoid conflicts:
import jax
import jax.numpy as jnp
import numpy as np
def save_distributed_array(array, filename):
"""Safely save a distributed array to file"""
# First unshard the data by getting it from devices
array_local = jax.device_get(array)
# Then convert to numpy on all processes
array_np = np.array(array_local)
# Only process 0 writes to file
if jax.process_index() == 0:
np.save(filename, array_np)
print(f"Saved array to {filename}", flush=True)
# Example usage
distributed_result = some_netket_calculation()
save_distributed_array(distributed_result, "results.npy")
Testing distributed code locally#
NetKet provides the djaxrun utility for testing multi-process distributed code locally. This simulates multi-node execution and is strongly recommended for testing scripts before submitting to HPC clusters:
# Test with 2 processes (simulating 2 nodes)
djaxrun --simple -np 2 python3 your_script.py
# Test with 4 processes
djaxrun --simple -np 4 python3 your_script.py
The djaxrun utility works well with CPUs but requires careful GPU management for multi-GPU testing:
# For GPU testing, ensure each process sees only one GPU
import os
import jax
# Set GPU visibility based on process rank
if 'LOCAL_RANK' in os.environ:
local_rank = int(os.environ['LOCAL_RANK'])
os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
jax.distributed.initialize()
print(f"Process {jax.process_index()}: {jax.local_devices()}")
import netket as nk
# ... your NetKet code
Troubleshooting#
Common issues#
No output or deadlock: If you don’t see any messages and no prints of local devices,
jax.distributed.initialize()is likely deadlocked because nodes cannot communicate. This happens on some clusters due to security restrictions blocking communication among compute nodes. See the GRPC incompatibility with HTTP proxy wildcards section in 🔪 The Sharp Bits 🔪 for solutions.Initialization hangs: Check that your hostname resolves correctly (
ping $(hostname))GPU not detected: Verify
nvidia-smishows GPUs and CUDA is properly installedSlow communication: Sometimes communication might be very slow. For detailed communication debugging, set
export NCCL_DEBUG=INFO
Complete example#
Here’s a complete example showing distributed NetKet usage:
import jax
import netket as nk
# Initialize distributed computing
jax.distributed.initialize()
# Verify setup
print(f"[{jax.process_index()}/{jax.process_count()}] devices:", jax.devices(), flush=True)
print(f"[{jax.process_index()}/{jax.process_count()}] local devices:", jax.local_devices(), flush=True)
print(f"I will be running a calculation among {jax.process_count()} tasks, "
f"using a total of {len(jax.devices())} devices "
f"({len(jax.local_devices())} per slurm task). "
f"If this does not match your expected number of total devices, "
f"something is misconfigured", flush=True)
# Define your quantum system
L = 20
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=1/2, N=g.n_nodes)
# Create Hamiltonian
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
# Define neural network model
model = nk.models.RBM(alpha=1)
# Create variational state
vs = nk.vqs.MCState(
sampler=nk.sampler.MetropolisLocal(hi),
model=model,
n_samples=1024
)
# Set up optimizer and driver
opt = nk.optimizer.Sgd(learning_rate=0.01)
gs = nk.driver.VMC(ha, opt, variational_state=vs)
# Run optimization
if jax.process_index() == 0:
print("Starting VMC optimization...", flush=True)
gs.run(n_iter=300, out='distributed_result')
if jax.process_index() == 0:
print("Optimization complete!", flush=True)
print(f"Final energy: {gs.energy}", flush=True)
This setup will automatically distribute the computation across all available devices and handle the distributed sampling and optimization efficiently.