Source code for netket.experimental.sampler.metropolis_pmap

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from typing import Callable, Union
from functools import partial

import jax
from jax import numpy as jnp

import numpy as np

from flax import linen as nn

from netket.utils import mpi, struct
from netket.utils.types import PyTree

from netket.sampler import SamplerState
from netket.sampler import MetropolisSamplerState, MetropolisSampler


class MetropolisSamplerPmapState(MetropolisSamplerState):
    """
    Sampler State for the `MetropolisSamplerPmap` sampler.

    Wraps `MetropolisSamplerState` and overrides a few properties to work
    with the ShardedDeviceArrays.
    """

    @property
    def n_steps(self) -> int:
        return self.n_steps_proc.sum() * mpi.n_nodes

    @property
    def n_accepted(self) -> int:
        """Total number of moves accepted across all processes since the last reset."""
        res, _ = mpi.mpi_sum_jax(self.n_accepted_proc.sum())
        return res

    def _repr_pretty_(self, p, cycle):
        super()._repr_pretty_(p, cycle)


[docs] class MetropolisSamplerPmap(MetropolisSampler): r""" Metropolis-Hastings sampler for an Hilbert space according to a specific transition rule where chains are split among the available devices (`jax.devices()`). This sampler is experimental. It's API might change without warnings in future NetKet releases. To parallelize on CPU, you should set the following environment variable before loading jax/NetKet, XLA_FLAGS="--xla_force_host_platform_device_count=XX", where XX is the number of desired cpu devices. The transition rule is used to generate a proposed state :math:`s^\prime`, starting from the current state :math:`s`. The move is accepted with probability .. math:: A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right) , where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)` is a user-provided function (the machine), :math:`p` is also user-provided with default value :math:`p=2`, and :math:`L(s,s^\prime)` is a suitable correcting factor computed by the transition kernel. The dtype of the sampled states can be chosen. """ _sampler_device: MetropolisSampler = struct.static_field()
[docs] def __init__(self, *args, n_chains_per_device=None, **kwargs): """ Constructs a Metropolis Sampler. Args: hilbert: The hilbert space to sample rule: A `MetropolisRule` to generate random transitions from a given state as well as uniform random states. sweep_size: The number of exchanges that compose a single sweep. If None, sweep_size is equal to the number of degrees of freedom being sampled (the size of the input vector s to the machine). reset_chains: If False the state configuration is not reset when reset() is called. n_chains: The total number of Markov Chain to be run in parallel on a the available devices. This will be rounded to the nearest multiple of `len(jax.devices())` n_chains_per_device: The number of chains to be run in parallel on one device. Cannot be specified if n_chains is also specified. machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2). dtype: The dtype of the states sampled (default = np.float32). """ n_devices = len(jax.devices()) n_chains_undef = "n_chains" not in kwargs and "n_chains_per_rank" not in kwargs if not n_chains_undef and n_chains_per_device is not None: raise ValueError( "Cannot specify both n_chains/n_chains_per_rank and n_chains_per_device" ) args, kwargs = super().__pre_init__(*args, **kwargs) if n_chains_per_device is None: n_chains_per_device = int( max(np.ceil(kwargs["n_chains_per_rank"] / n_devices), 1) ) kwargs["n_chains_per_rank"] = n_chains_per_device * n_devices if kwargs["n_chains_per_rank"] != n_chains_per_device * n_devices: warnings.warn( f"Using {n_chains_per_device*n_devices} chains " f"({n_chains_per_device} chains on each of {n_devices} devices).", category=UserWarning, stacklevel=2, ) super().__init__(*args, **kwargs) n_chains_per_device = self.n_chains_per_rank // len(jax.devices()) self._sampler_device = MetropolisSampler( self.hilbert, n_chains_per_rank=n_chains_per_device, rule=self.rule, sweep_size=self.sweep_size, reset_chains=self.reset_chains, machine_pow=self.machine_pow, )
@property def n_chains_per_device(self): return self._sampler_device.n_chains_per_rank def _init_state(self, machine, parameters, key): key = jax.random.split(key, len(jax.devices())) state = _init_state_pmap(self._sampler_device, machine, parameters, key) state = MetropolisSamplerPmapState( σ=state.σ, rng=state.rng, rule_state=state.rule_state, n_steps_proc=state.n_steps_proc, n_accepted_proc=state.n_accepted_proc, ) return state def _reset(self, machine, parameters, state): state = _reset_pmap(self._sampler_device, machine, parameters, state) state = MetropolisSamplerPmapState( σ=state.σ, rng=state.rng, rule_state=state.rule_state, n_steps_proc=state.n_steps_proc, n_accepted_proc=state.n_accepted_proc, ) return state def _sample_next(self, machine, parameters, state): state, conf = _sample_next_pmap( self._sampler_device, machine, parameters, state ) state = MetropolisSamplerPmapState( σ=state.σ, rng=state.rng, rule_state=state.rule_state, n_steps_proc=state.n_steps_proc, n_accepted_proc=state.n_accepted_proc, ) return state, conf.reshape(-1, conf.shape[-1]) def _sample_chain( self, machine: Union[Callable, nn.Module], parameters: PyTree, state: SamplerState, chain_length: int, ) -> tuple[jnp.ndarray, SamplerState]: samples, state = _sample_chain_pmap( self._sampler_device, machine, parameters, state, chain_length ) state = MetropolisSamplerPmapState( σ=state.σ, rng=state.rng, rule_state=state.rule_state, n_steps_proc=state.n_steps_proc, n_accepted_proc=state.n_accepted_proc, ) return samples.reshape((-1,) + samples.shape[-2:]), state
@partial( jax.pmap, in_axes=(None, None, None, 0), out_axes=0, static_broadcasted_argnums=1 ) def _init_state_pmap(sampler, machine, parameters, key): return sampler._init_state(machine, parameters, key) @partial( jax.pmap, in_axes=(None, None, None, 0), out_axes=0, static_broadcasted_argnums=1 ) def _reset_pmap(sampler, machine, parameters, state): return sampler._reset(machine, parameters, state) @partial( jax.pmap, in_axes=(None, None, None, 0), out_axes=(0, 0), static_broadcasted_argnums=1, ) def _sample_next_pmap(sampler, machine, parameters, state): return sampler._sample_next(machine, parameters, state) @partial( jax.pmap, in_axes=(None, None, None, 0, None), out_axes=(0, 0), static_broadcasted_argnums=(1, 4), ) def _sample_chain_pmap(sampler, machine, parameters, state, chain_length): return sampler._sample_chain(machine, parameters, state, chain_length)