Multiple CPUs (MPI)#
Warning
Experimental feature: this MPI backend is experimental and CPU-only, so it should not be used with GPUs. For GPU-based distributed computing, use the standard JAX distributed setup described in the multi-node multi-GPU guide. Please let us know whether it works for your use case.
The MPI backend can be used as an alternative to Gloo for multi-process CPU communication, either on a single node or across multiple nodes, and it may offer better performance than the default CPU collective implementations. This is useful when you need to run distributed NetKet calculations on CPU-only clusters.
Use the MPI backend when:
You need distributed computing across CPU-only nodes (single or multiple nodes)
You want MPI-level performance for CPU communication
GPUs are not available or not desired
Installation#
MPI must be installed on your system. Note: Open MPI 5 is not supported, so on macOS use MPICH instead.
For detailed installation instructions, see the mpibackend4jax repository.
# Using uv (recommended)
uv add git+https://github.com/mpi4jax/mpibackend4jax
# Or using pip
pip install git+https://github.com/mpi4jax/mpibackend4jax
Detailed Installation Instructions
If you need help with MPI installation:
On Ubuntu/Debian:
sudo apt-get install libmpich-dev mpich
On macOS:
brew install mpich
Then install mpibackend4jax as shown above.
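Before installing, it can help to check which MPI implementation the launcher on your `PATH` belongs to, since Open MPI 5 is not supported. The following is a minimal sketch, not part of mpibackend4jax: the helper names (`classify_mpi`, `is_supported`, `check_mpi`) are hypothetical, and the parsing assumes the usual `mpirun --version` banners ("mpirun (Open MPI) X.Y.Z" for Open MPI, "HYDRA build details" with a "Version:" line for MPICH), which may differ on your system.

```python
import re
import shutil
import subprocess


def classify_mpi(version_text):
    """Guess (implementation, major_version) from `mpirun --version` output.

    The banner formats matched here are assumptions about typical output.
    """
    match = re.search(r"Open MPI\)?\s+(\d+)", version_text)
    if match:
        return "openmpi", int(match.group(1))
    upper = version_text.upper()
    if "HYDRA" in upper or "MPICH" in upper:
        match = re.search(r"Version:\s+(\d+)", version_text)
        major = int(match.group(1)) if match else None
        return "mpich", major
    return "unknown", None


def is_supported(implementation, major):
    """Flag only the case the documentation names: Open MPI 5."""
    return not (implementation == "openmpi" and major == 5)


def check_mpi(launcher="mpirun"):
    """Locate the launcher, run `--version`, and report what was found."""
    path = shutil.which(launcher)
    if path is None:
        return f"No '{launcher}' found on PATH; install MPI first."
    result = subprocess.run([path, "--version"], capture_output=True, text=True)
    implementation, major = classify_mpi(result.stdout + result.stderr)
    status = "ok" if is_supported(implementation, major) else "NOT supported"
    return f"{path}: {implementation} (major version {major}): {status}"


if __name__ == "__main__":
    print(check_mpi())
```

If the report shows Open MPI 5, switching to MPICH as described above is the simplest fix.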
Usage with NetKet#
# Import mpibackend4jax BEFORE importing JAX
import mpibackend4jax as _mpi4jax # noqa: F401
import jax
import netket as nk
# Initialize distributed computing
jax.distributed.initialize()
# Verify setup
print(f"[{jax.process_index()}/{jax.process_count()}] MPI setup complete", flush=True)
print(f"[{jax.process_index()}/{jax.process_count()}] Devices: {jax.local_devices()}", flush=True)
# Define quantum system
L = 20
g = nk.graph.Hypercube(length=L, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=1/2, N=g.n_nodes)
# Create Hamiltonian
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
# Define neural network model
model = nk.models.RBM(alpha=1)
# Create variational state
vs = nk.vqs.MCState(
    sampler=nk.sampler.MetropolisLocal(hi),
    model=model,
    n_samples=1024,
)
# Set up optimizer and driver
opt = nk.optimizer.Sgd(learning_rate=0.01)
gs = nk.driver.VMC(ha, opt, variational_state=vs)
# Run optimization
if jax.process_index() == 0:
    print("Starting MPI-based VMC optimization...", flush=True)
# Every process runs the optimization; only the messages are limited to process 0
gs.run(n_iter=100, out='mpi_result')
if jax.process_index() == 0:
    print("Optimization complete!", flush=True)
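The script repeats the `if jax.process_index() == 0:` guard around each status message. One way to keep that pattern in a single place is a small helper such as the sketch below. `log_root` is a hypothetical name, not a NetKet or JAX function, and it takes the process index as an argument so it can be tried without a distributed setup.

```python
def log_root(message, process_index, root=0):
    """Print `message` only on the process whose index equals `root`.

    Returns True if the message was printed, False otherwise.
    """
    if process_index == root:
        print(message, flush=True)
        return True
    return False


# Intended use inside the script above (needs the JAX setup shown there):
#   log_root("Starting MPI-based VMC optimization...", jax.process_index())
#   gs.run(n_iter=100, out='mpi_result')  # still called on every process
#   log_root("Optimization complete!", jax.process_index())
```

Only the printing is restricted to one process. The call to `gs.run` stays outside the guard, as in the script above.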
Running with MPI#
Launch your script using the MPI launcher:
# Single node with 4 processes
mpirun -n 4 python your_netket_script.py
# Multiple nodes (example with 2 nodes, 4 processes per node)
mpirun -np 8 --hostfile hostfile python your_netket_script.py
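If you launch jobs from a Python wrapper (for example inside a batch script generator), the same command lines can be assembled programmatically. The sketch below is illustrative only: `build_mpirun_command` is a hypothetical helper, and it uses just the flags that appear in the commands above (`-n` for the process count and `--hostfile`). Whether your MPI launcher accepts these spellings, and what the hostfile must contain, depends on the MPI implementation.

```python
import shlex


def build_mpirun_command(script, n_procs, hostfile=None, launcher="mpirun"):
    """Return the launcher invocation as a list of arguments."""
    if n_procs < 1:
        raise ValueError("n_procs must be at least 1")
    command = [launcher, "-n", str(n_procs)]
    if hostfile is not None:
        # Multi-node run: the hostfile lists the nodes to use
        command += ["--hostfile", str(hostfile)]
    command += ["python", str(script)]
    return command


if __name__ == "__main__":
    # Single node with 4 processes
    print(shlex.join(build_mpirun_command("your_netket_script.py", 4)))
    # Two nodes with 4 processes each (8 in total)
    print(shlex.join(build_mpirun_command("your_netket_script.py", 8, hostfile="hostfile")))
```

The returned list can be passed directly to `subprocess.run(command, check=True)`, which avoids shell quoting issues with unusual script paths.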