Parallelization#
NetKet normally performs all calculations on the jax default device, `jax.local_devices()[0]`, and ignores any others.
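As a quick check (plain jax API, nothing NetKet-specific), you can list the devices jax has detected:

```python
import jax

# All devices attached to this process; by default NetKet only computes
# on the first entry, jax.local_devices()[0].
print(jax.local_devices())

# Total number of devices jax can address (across all processes, if any).
print(jax.device_count())
```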
To fully exploit your many CPU cores or several GPUs, you must therefore resort to one of two parallelization strategies: MPI or Sharding.

- **MPI**: explicit parallelization, distributing the Markov chains and samples across multiple nodes/devices by means of MPI (with mpi4jax). When using MPI, netket/jax will still only use the jax default device `jax.local_devices()[0]` on every rank, and you must ensure that this corresponds to a different device (core or GPU) on each rank.
- **Sharding**: the collective communication built natively into jax, which is jax's preferred mode of distributing calculations and is discussed in the Jax Distributed Computation tutorial. This mode can be used both on a single node with many GPUs and across many nodes with many GPUs. A minimal way to enable it in NetKet is sketched below.
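In recent NetKet versions sharding is opt-in through an environment variable (`NETKET_EXPERIMENTAL_SHARDING`; check the release notes of your version, as the exact flag may evolve), which must be set before netket is imported. A minimal sketch:

```python
import os

# Enable NetKet's sharding mode; must happen before importing netket.
os.environ["NETKET_EXPERIMENTAL_SHARDING"] = "1"

import jax
import netket as nk

# With several visible GPUs, states and samplers are sharded across all
# of jax.devices() automatically; the rest of the script is unchanged.
print(jax.devices())
```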
Note
What should you use?
Getting MPI up and running on a SLURM HPC cluster can be complicated, while sharding is incredibly easy to set up: install jax and you are done!
However, sharding works well only on GPUs; CPU support is an afterthought that performs terribly. Generally speaking, if you want to parallelize over many CPUs you should use MPI, while if you want to use GPUs you should stick to sharding. We only use CPU-based sharding to test locally that a script runs before sending it to the cluster, never in production.
Sharding code is also much simpler for us to write and maintain, so in the future it will become the preferred mode. Be careful that some operators based on Numba do not work with sharding, but they can all be converted to a version that works well with it.
NetKet is written such that code that runs with sharding will also work with MPI, and vice versa. The main thing to be careful about is saving files: do so only on the master rank, as sketched below.
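A guard that works in both modes might look like the following; a sketch, assuming `nk.utils.mpi.rank` (NetKet's MPI rank helper) and jax's `jax.process_index()`:

```python
import jax
import netket as nk

# Rank 0 under MPI, process 0 under multi-process sharding; on a single
# process both are 0, so the guard is harmless there too.
is_master = nk.utils.mpi.rank == 0 and jax.process_index() == 0

if is_master:
    with open("results.json", "w") as f:
        f.write("{}")  # placeholder: write your actual results here
```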
Chef’s suggestion:
| | Default | MPI | Sharding | Sharding + distributed |
|---|---|---|---|---|
| 1 CPU / 1 GPU | ✔️ | | | |
| 1 Node: MultiCPU | | ✔️ | 🐢 | |
| 1 Node: MultiGPU | | 🤯 | ✔️ | |
| Distributed: CPU | | ✔️ | | |
| Distributed: GPU | | 🤯 | | ✔️ |
Legend:
- ✔️ Recommended method
- 🐢 Sharding is slow
- 🤯 Hard to set up
MPI (mpi4jax)#
Requires that mpi4py and mpi4jax are installed; please refer to Installation#MPI.
Warning
Citing mpi4jax
mpi4jax is developed by academic researchers. If you use it, you should cite the relevant publication; see Citing NetKet.
When using netket it is crucial to run Python with the same implementation and version of MPI that the mpi4py module is compiled against.
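You can inspect which MPI implementation mpi4py was built against directly from Python, using mpi4py's standard API:

```python
from mpi4py import MPI

# MPI standard version, e.g. (3, 1), and the banner of the MPI library
# mpi4py was compiled against; it must match the mpirun you launch with.
print(MPI.Get_version())
print(MPI.Get_library_version())
```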
If you encounter issues, you can check whether your MPI environment is set up properly by running:
```
$ mpirun -np 2 python3 -m netket.tools.check_mpi
mpi4py_available : True
mpi4jax_available : True
available_cpus (rank 0) : 12
n_nodes : 1
mpi4py | MPI version : (3, 1)
mpi4py | MPI library_version : Open MPI v4.1.0, package: Open MPI brew@BigSur Distribution, ident: 4.1.0, repo rev: v4.1.0, Dec 18, 2020
```
This should print some basic information about the MPI installation and, in particular, pick up the correct `n_nodes`. If you get the same output multiple times, each with `n_nodes : 1`, this is a clear sign that your MPI setup is broken.
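You can also verify this from inside a script with NetKet's MPI utilities (assuming the `nk.utils.mpi` module, which exposes `rank` and `n_nodes`):

```python
import netket as nk

# Launched as e.g. `mpirun -np 4 python3 thisscript.py`, this should print
# "rank r of 4" once per rank; "of 1" everywhere means MPI is broken.
print(f"rank {nk.utils.mpi.rank} of {nk.utils.mpi.n_nodes}")
```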
The check_mpi tool also reports the number of (logical) CPUs that might be subscribed by Jax on every independent MPI rank during linear algebra operations. Be mindful that Jax, in general, is like an invasive plant and tends to use all the resources it can access, and environment variables might not prevent it from making use of all the `available_cpus`.
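The flags commonly suggested to rein in jax's CPU threading are XLA options passed through `XLA_FLAGS`; a sketch, with the caveat from above that jax may not fully honour them:

```python
import os

# Ask XLA to run single-threaded on CPU; must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
)

import jax
```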
On macOS it is not possible to control this number. On Linux it can be controlled using `taskset`, or `--bind-to core` when using `mpirun` (for example, `mpirun -np 4 --bind-to core python3 yourscript.py`).
Note
In the Cluster section of the documentation you can find example setup instructions for MPI+NetKet on some clusters. Those setups are intended for GPUs; the CPU-only setup is much simpler, as it does not need to include CUDA.