netket.jax.mpi_split

netket.jax.mpi_split(key, *, root=0, comm=None)

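Split a PRNG key across the MPI processes in the communicator, so that each rank receives its own key. Only the key passed on the root process is used; the keys passed on all other ranks are ignored.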

Parameters:
  • key – The key to split. Only the value on the root process is used.

  • root – (default=0) The root rank from which to take the input key.

  • comm – (default=None, meaning MPI.COMM_WORLD) The MPI communicator.

Return type:

Any

Returns:

A PRNGKey that depends on both the root process's input key and the rank of the calling process.
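The semantics described above (only the root key matters, and each rank gets a distinct, rank-dependent key) can be illustrated with a minimal single-process sketch. It assumes the behaviour amounts to "split the root key into one subkey per rank, then have each rank take the subkey at its own index"; this is an illustration of the idea, not NetKet's actual implementation, and the helper `simulated_mpi_split` is hypothetical. Ranks are simulated as list positions instead of real MPI processes.

```python
import jax
import jax.numpy as jnp


def simulated_mpi_split(keys_per_rank, root=0):
    """Hypothetical single-process illustration of mpi_split semantics.

    keys_per_rank[r] is the key that rank r would pass in. As with
    mpi_split, only the key on the root rank is used.
    Returns the list of keys each simulated rank would receive.
    """
    n_ranks = len(keys_per_rank)
    # Only the root's key matters: split it into one subkey per rank.
    subkeys = jax.random.split(keys_per_rank[root], n_ranks)
    # In a real MPI run the root's result would be shared with every
    # process; here each simulated rank just indexes with its own rank.
    return [subkeys[rank] for rank in range(n_ranks)]


root_key = jax.random.PRNGKey(0)
# Non-root ranks pass arbitrary keys, which are ignored.
rank_keys = simulated_mpi_split(
    [root_key, jax.random.PRNGKey(1), jax.random.PRNGKey(2)]
)
# Each simulated rank can now draw independent random numbers.
samples = [jax.random.normal(k, (3,)) for k in rank_keys]
```

Splitting a single root key, rather than seeding every rank separately, keeps the whole run reproducible from one seed while still giving each process a statistically independent stream.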