netket.jax.mpi_split
- netket.jax.mpi_split(key, *, root=0, comm=None)
Split a key across MPI nodes in the communicator. Only the input key on the root process matters.
- Parameters:
key – The key to split. Only the key on the root process is considered.
root – (default=0) The root rank from which to take the input key.
comm – (default=MPI.COMM_WORLD) The MPI communicator.
- Return type:
- Returns:
A PRNGKey depending on rank number and key.
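The semantics described above can be illustrated with a single-process sketch that does not use MPI at all. The helper `simulated_mpi_split` below is hypothetical and is not part of NetKet; it assumes that the root key is split into one subkey per rank with `jax.random.split`, which is one plausible way to obtain "a PRNGKey depending on rank number and key". The actual implementation may differ.

```python
import jax
import jax.numpy as jnp


def simulated_mpi_split(keys_per_rank, root=0):
    """Hypothetical single-process emulation of the documented behaviour.

    keys_per_rank[r] stands in for the key that rank r would pass in.
    Returns a list whose r-th entry is the key rank r would get back.
    """
    n_ranks = len(keys_per_rank)
    # Only the key held by the root rank is used; all others are ignored.
    root_key = keys_per_rank[root]
    # Assumption: one subkey per rank, derived from the root key.
    subkeys = jax.random.split(root_key, n_ranks)
    return [subkeys[r] for r in range(n_ranks)]


# Four pretend ranks, each starting from a different key.
keys_in = [jax.random.PRNGKey(seed) for seed in (0, 1, 2, 3)]
keys_out = simulated_mpi_split(keys_in, root=0)

# Each rank can now draw its own independent random numbers.
samples = [jax.random.normal(k, (2,)) for k in keys_out]
```

In a real MPI run, every rank would instead call `netket.jax.mpi_split(key)` and receive only its own key, so that ranks do not generate identical random streams.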