Source code for netket.jax._utils_random

# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import jax

from netket.utils import random_seed, mpi
from netket.utils.mpi import MPI_jax_comm
from netket.utils.types import PRNGKeyT, SeedT

def PRNGKey(
    seed: Optional[SeedT] = None, *, root: int = 0, comm=MPI_jax_comm
) -> PRNGKeyT:
    """
    Build a PRNG key from an optional seed and share it across processes.

    Arguments:
        seed: `None` to draw a seed from `random_seed()`, an `int` to seed
            `jax.random.PRNGKey` directly, or any other object, which is
            taken as-is as the key.
        root: (default=0) The rank passed as `root` to the broadcast.
        comm: (default=MPI_jax_comm) The communicator passed to the broadcast.

    Returns:
        The key obtained after every leaf of the local key has gone through
        `mpi.mpi_bcast_jax`, so that all processes end up with the same key.
    """
    if isinstance(seed, int):
        local_key = jax.random.PRNGKey(seed)
    elif seed is None:
        local_key = jax.random.PRNGKey(random_seed())
    else:
        # Not None and not an int: treat the argument as the key itself.
        local_key = seed

    def _bcast_leaf(leaf):
        # Only the first element of the broadcast's result is kept.
        return mpi.mpi_bcast_jax(leaf, root=root, comm=comm)[0]

    return jax.tree_util.tree_map(_bcast_leaf, local_key)
class PRNGSeq:
    """
    An endless iterator of PRNG keys derived from a starting key.
    """

    def __init__(self, base_key: Optional[SeedT] = None):
        """
        Arguments:
            base_key: `None` or an `int` are converted to a key through
                `PRNGKey`; any other value is stored unchanged as the
                starting state.
        """
        if isinstance(base_key, int):
            start = PRNGKey(base_key)
        elif base_key is None:
            start = PRNGKey()
        else:
            start = base_key
        self._current = start

    def __iter__(self):
        return self

    def __next__(self):
        # NOTE: the key handed back is also kept as the internal state from
        # which the following key will be derived.
        advanced = jax.random.split(self._current, num=1)
        self._current = advanced[0]
        return self._current

    def next(self):
        """
        Returns the next key of the sequence (same as `next(seq)`).
        """
        return next(self)

    def take(self, num: int):
        """
        Returns an array of `num` PRNG keys and advances the iterator
        accordingly.
        """
        # Split into one key more than requested: the extra (last) one
        # becomes the new internal state, the others are handed out.
        fresh = jax.random.split(self._current, num=num + 1)
        handed_out, self._current = fresh[:-1], fresh[-1]
        return handed_out
def mpi_split(key, *, root=0, comm=MPI_jax_comm) -> PRNGKeyT:
    """
    Split a key across MPI nodes in the communicator.
    Only the input key on the root process matters.

    Arguments:
        key: The key to split. Only considered the one on the root process.
        root: (default=0) The root rank from which to take the input key.
        comm: (default=MPI_jax_comm) The MPI communicator used for the
            broadcast.

    Returns:
        A PRNGKey depending on rank number and key.
    """
    # Maybe add error/warning if in_key is not the same
    # on all MPI nodes?

    # One key per node. The count and the index used below come from the
    # module-level `mpi.n_nodes` / `mpi.rank`, not from `comm`.
    # NOTE(review): a custom `comm` is presumably expected to span the same
    # ranks as the default one -- confirm against callers.
    keys = jax.random.split(key, mpi.n_nodes)

    # Forward `comm` so that a caller-provided communicator is actually used
    # for the broadcast (it was previously accepted but ignored).
    keys = jax.tree_util.tree_map(
        lambda k: mpi.mpi_bcast_jax(k, root=root, comm=comm)[0], keys
    )

    return keys[mpi.rank]
def batch_choice(key, a, p):
    """
    Batched version of `jax.random.choice`.

    Arguments:
        key: a PRNGKey used as the random key.
        a: 1D array. Random samples are generated from its elements.
        p: 2D array of shape `(batch_size, a.size)`. Each slice `p[i, :]` is
            the probabilities associated with entries in `a` to generate a
            sample at the index `i` of the output. Can be unnormalized.

    Returns:
        The generated samples as an 1D array of shape `(batch_size,)`.
    """
    batch_size = p.shape[0]

    # Running sums along each row; the last column holds the row total.
    cumulative = p.cumsum(axis=1)
    row_total = cumulative[:, -1:]

    # One uniform draw per row, scaled by the row total so that `p` does not
    # need to be normalized.
    draws = jax.random.uniform(key, shape=(batch_size, 1))
    thresholds = row_total * draws

    # The selected index is the number of cumulative entries lying strictly
    # below the threshold of that row.
    picked = (thresholds > cumulative).sum(axis=1)
    return a[picked]