# Source code for netket.nn.utils

# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from math import prod
from typing import Callable, Optional

import jax
from jax import numpy as jnp
import numpy as np
from flax.traverse_util import flatten_dict, unflatten_dict
from flax.core import unfreeze

from netket import jax as nkjax
from netket.utils import get_afun_if_module, mpi
from netket.utils.types import Array, PyTree
from netket.hilbert import DiscreteHilbert

from netket.utils import config
from netket.jax.sharding import (
    distribute_to_devices_along_axis,
    extract_replicated,
    gather,
)

def split_array_mpi(array: Array) -> Array:
    """
    Splits the first dimension of the input array among mpi processes.

    Works like `mpi.scatter`, but assumes that the input array is available and
    identical on all ranks. When running on a single rank the input is returned
    unchanged.

    !!! Warn
         The output is a numpy array.

    !!! Warn
         This should not be used with sharding (netket.netket_experimental_sharding=True)

    Args:
        array: A nd-array

    Returns:
        A numpy array, of potentially different state on every mpi rank.
    """
    if mpi.n_nodes <= 1:
        # Nothing to split: the single rank keeps the whole array.
        return array

    # Split the row *indices* into n_nodes contiguous chunks. np.array_split
    # tolerates uneven division: the first `len % n_nodes` chunks get one
    # extra row.
    rows_per_rank = np.array_split(np.arange(array.shape[0]), mpi.n_nodes)
    return array[rows_per_rank[mpi.rank]]

def to_array(
    hilbert: DiscreteHilbert,
    apply_fun: Callable[[PyTree, Array], Array],
    variables: PyTree,
    *,
    normalize: bool = True,
    allgather: bool = True,
    chunk_size: Optional[int] = None,
) -> Array:
    """
    Computes `apply_fun(variables, states)` on all states of `hilbert` and returns
    the results as a vector.

    The outputs of `apply_fun` are interpreted as log-amplitudes: the returned
    vector holds their exponential, one entry per basis state of `hilbert`.

    Args:
        hilbert: The Hilbert space whose basis states are enumerated. It must be
            indexable.
        apply_fun: A function `apply_fun(variables, states)` returning one
            log-amplitude per state. It is first passed through
            `get_afun_if_module`, so presumably a module may also be given.
        variables: The variables passed as first argument to `apply_fun`.
        normalize: If True, the vector is normalized to have L2-norm 1.
        allgather:
            When running with MPI:
                If True, the final wave function is stored in full at all MPI ranks.
                If False, every rank only holds its own chunk (the last rank's
                chunk may end with zero-valued padding entries).
            When running with netket_experimental_sharding=True:
                If allgather=True, the final wave function is a fully replicated array
                If allgather=False, the final wave function is a sharded array, padded
                with zeros to the next multiple of the number of devices
        chunk_size: Optional integer to specify the largest chunks of samples that
            the model will be evaluated upon. By default it is `None`, and when specified
            samples are split into chunks of at most `chunk_size`.

    Returns:
        The vector `exp(apply_fun(variables, x))` over the basis states `x`
        (rescaled to unit L2 norm if `normalize=True`). With sharding and
        `allgather=True` this is a local numpy array of length `hilbert.n_states`.

    Raises:
        RuntimeError: If `hilbert` is not indexable.
    """
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    apply_fun = get_afun_if_module(apply_fun)

    if config.netket_experimental_sharding:
        # for now assume no mpi (no hybrid)
        x = hilbert.all_states()
        # Pad with a valid state (x[0]); the padded entries are zeroed by `mask`.
        xs, mask = distribute_to_devices_along_axis(x, pad=True, pad_value=x[0])
        n_states = xs.shape[0]
    elif mpi.n_nodes == 1:
        xs = hilbert.all_states()
        mask = None
        n_states = xs.shape[0]
    else:
        # mpi4jax does not have (yet) allgatherv so we need to be creative
        # could be made easier if we update mpi4jax
        n_states = hilbert.n_states
        n_states_padded = int(np.ceil(n_states / mpi.n_nodes)) * mpi.n_nodes
        states_n = np.arange(n_states)
        # Padding indices reuse valid state numbers; _to_array_rank discards
        # them on the last rank.
        fake_states_n = np.arange(n_states_padded - n_states)

        # divide the hilbert space in chunks for each node
        states_per_rank = np.split(
            np.concatenate([states_n, fake_states_n]), mpi.n_nodes
        )

        xs = hilbert.numbers_to_states(states_per_rank[mpi.rank])
        mask = None

    psi = _to_array_rank(
        apply_fun,
        variables,
        xs,
        n_states,
        normalize,
        allgather,
        chunk_size,
        mask,
    )

    if allgather and config.netket_experimental_sharding:
        # for simplicity we gather here outside of jit
        # alternatively we could use a sharding constraint in _to_array_rank
        psi = gather(psi)
        # make it a local numpy array, so that we can operate with e.g.
        # a sparse scipy array on it and jax thinks its replicated next time we pass it to jit
        psi = np.asarray(extract_replicated(psi))
        # drop the device padding
        psi = psi[: hilbert.n_states]

    return psi
@partial(jax.jit, static_argnums=(0, 3, 4, 5, 6))
def _to_array_rank(
    apply_fun,
    variables,
    σ_rank,
    n_states,
    normalize,
    allgather,
    chunk_size,
    mask=None,
):
    """
    Computes apply_fun(variables, σ_rank) and gathers all results across all ranks.

    The input σ_rank should be a slice of all states in the hilbert space of equal
    length across all ranks because mpi4jax does not support allgatherv (yet).

    Args:
        apply_fun: Log-amplitude function `apply_fun(variables, states)` (static).
        variables: Variables passed as first argument to `apply_fun`.
        σ_rank: The states evaluated on this rank.
        n_states: total number of elements in the hilbert space (static).
        normalize: If True, shift by the global max log-amplitude and rescale the
            result to unit L2 norm across all ranks (static).
        allgather: If True, gather the per-rank results on every rank (static).
        chunk_size: If not None, evaluate `apply_fun` in chunks of at most this
            many states (static).
        mask: Optional array multiplied elementwise with the amplitudes, used
            under sharding to zero out padding entries.

    Returns:
        The flattened amplitudes `exp(log_psi)`, truncated to `n_states` entries.
    """
    if chunk_size is not None:
        apply_fun = nkjax.apply_chunked(
            apply_fun, in_axes=(None, 0), chunk_size=chunk_size
        )

    # number of 'fake' states, in the last rank.
    n_fake_states = σ_rank.shape[0] * mpi.n_nodes - n_states

    log_psi_local = apply_fun(variables, σ_rank)

    # last rank, get rid of fake elements: exp(-inf) == 0, so they vanish below.
    if mpi.rank == mpi.n_nodes - 1 and n_fake_states > 0:
        log_psi_local = log_psi_local.at[-n_fake_states:].set(-jnp.inf)

    if normalize:
        # subtract logmax for better numerical stability
        logmax, _ = mpi.mpi_max_jax(log_psi_local.real.max())
        log_psi_local -= logmax

    psi_local = jnp.exp(log_psi_local)

    if mask is not None:
        # when running under netket_experimental_sharding,
        # we pad the Hilbert space with extra fake entries,
        # which in here we mask out to 0
        psi_local = psi_local * mask

    if normalize:
        # compute normalization
        norm2 = jnp.linalg.norm(psi_local) ** 2
        norm2, _ = mpi.mpi_sum_jax(norm2)
        psi_local /= jnp.sqrt(norm2)

    if allgather:
        psi, _ = mpi.mpi_allgather_jax(psi_local)
    else:
        psi = psi_local

    psi = psi.reshape(-1)

    # remove fake states
    psi = psi[0:n_states]
    return psi
def to_matrix(
    hilbert: DiscreteHilbert,
    machine: Callable[[PyTree, Array], Array],
    params: PyTree,
    *,
    normalize: bool = True,
    chunk_size: Optional[int] = None,
) -> Array:
    """
    Computes the exponentiated outputs of `machine` on all states of `hilbert`
    and returns them reshaped as a square matrix.

    `hilbert` must expose a `physical` space such that
    `hilbert.n_states == hilbert.physical.n_states ** 2` (e.g. a doubled Hilbert
    space); otherwise the reshape fails.

    Args:
        hilbert: The indexable Hilbert space on which `machine` is defined.
        machine: Log-amplitude function `machine(params, states)` (or a module),
            forwarded to `to_array`.
        params: The variables passed to `machine`.
        normalize: If True, the matrix is divided by its trace.
        chunk_size: Optional maximal number of states evaluated at once,
            forwarded to `to_array`.

    Returns:
        An `(L, L)` matrix with `L = hilbert.physical.n_states`.

    Raises:
        RuntimeError: If `hilbert` is not indexable.
    """
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    psi = to_array(hilbert, machine, params, normalize=False, chunk_size=chunk_size)

    L = hilbert.physical.n_states
    rho = psi.reshape((L, L))
    if normalize:
        # Out-of-place division: under sharding `to_array` returns a read-only
        # numpy view of a JAX array, on which an in-place `/=` raises.
        rho = rho / rho.trace()

    return rho
# TODO: Deprecate: remove def update_dense_symm(params, names=("dense_symm", "Dense")): """Updates DenseSymm kernels in pre-PR#1030 parameter pytrees to the new 3D convention. Args: params: a parameter pytree names: layer names search for, default: those used in RBMSymm and GCNN* """ params = unfreeze(params) # just in case, doesn't break with a plain dict def fix_one_kernel(args): path, array = args if ( len(path) > 1 and path[-2] in names and path[-1] == "kernel" and array.ndim == 2 ): array = jnp.expand_dims(array, 1) return (path, array) return unflatten_dict(dict(map(fix_one_kernel, flatten_dict(params).items()))) def _get_output_idx( shape: tuple[int, ...], max_bits: Optional[int] = None ) -> tuple[tuple[int, ...], int]: bits_per_local_occupation = tuple(np.ceil(np.log2(shape)).astype(int)) if max_bits is None: max_bits = max(bits_per_local_occupation) output_idx = [] offset = 0 for b in bits_per_local_occupation: output_idx.extend([i + offset for i in range(b)][::-1]) offset += max_bits output_idx = tuple(output_idx) return output_idx, max_bits def _separate_binary_indices( shape: tuple[int, ...] ) -> tuple[tuple[int, ...], tuple[int, ...]]: binary_indices = tuple([i for i in range(len(shape)) if shape[i] == 2]) non_binary_indices = tuple([i for i in range(len(shape)) if shape[i] != 2]) return binary_indices, non_binary_indices
@partial(jax.jit, static_argnames=("hilbert", "max_bits"))
def binary_encoding(
    hilbert: DiscreteHilbert,
    x: Array,
    *,
    max_bits: Optional[int] = None,
) -> Array:
    """
    Encodes the array `x` into a set of binary-encoded variables described by
    the shape of a Hilbert space. The i-th element of x will be encoded in
    {code}`ceil(log2(shape[i]))` bits.

    Within each site the bits are ordered from most to least significant. Sites
    of local dimension 2 are encoded directly by their local index. Sites of
    local dimension 1 contribute no bits.

    Args:
        hilbert: Hilbert space of the samples that are to be encoded.
        x: The array to encode.
        max_bits: The maximum number of bits to use for each element of `x`.

    Returns:
        An array of shape `(*x.shape[:-1], sum_i ceil(log2(shape[i])))` with the
        same dtype as the local indices of `x`.
    """
    x = hilbert.states_to_local_indices(x)
    shape = tuple(hilbert.shape)
    jax.core.concrete_or_error(None, shape, "Shape must be known statically")

    output_idx, max_bits = _get_output_idx(shape, max_bits)

    # One row of `max_bits` bit-slots per site; slot j holds bit j (LSB first).
    binarised_states = jnp.zeros((*x.shape, max_bits), dtype=x.dtype)

    binary_indices, non_binary_indices = _separate_binary_indices(shape)

    # Loop-invariant powers of two used to extract the individual bits.
    bit_weights = 2 ** jnp.arange(max_bits, dtype=int)

    for i in non_binary_indices:
        substates = x[..., i].astype(int)[..., jnp.newaxis]
        binarised_states = (
            binarised_states.at[..., i, :]
            .set((substates & bit_weights) != 0)
            .astype(x.dtype)
        )

    for i in binary_indices:
        # A two-level site needs a single bit: its local index itself.
        binarised_states = binarised_states.at[..., i, 0].set(x[..., i])

    # Flatten the (site, bit) table and keep only the used slots, MSB first.
    flat_states = binarised_states.reshape(
        *binarised_states.shape[:-2], prod(binarised_states.shape[-2:])
    )
    return flat_states[..., output_idx]
def states_to_numbers(hilbert: DiscreteHilbert, σ: Array) -> Array:
    """
    Map configurations to their 64-bit integer index in the full Hilbert space.

    The conversion is delegated to `hilbert.states_to_numbers` through a JAX
    pure callback, so this function can be used inside `jax.jit`.

    Args:
        hilbert: The Hilbert space
        σ: A single or a batch of configurations

    Returns:
        a single integer or a batch of integer indices, with shape `σ.shape[:-1]`.

    Raises:
        ValueError: if `hilbert` is not indexable.
    """
    if hilbert.is_indexable:
        # The result drops the trailing (site) axis of σ.
        result_spec = jax.ShapeDtypeStruct(σ.shape[:-1], jnp.int64)
        # Host callback; vectorized=True declares that the callback handles
        # batched inputs itself (relevant under vmap).
        return jax.pure_callback(
            hilbert.states_to_numbers,
            result_spec,
            σ,
            vectorized=True,
        )

    raise ValueError(
        f"Hilbert space {hilbert} is too large to be indexed or "
        f"cannot be indexed at all."
    )