# Source code for netket.experimental.sampler.metropolis_pt

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union
from functools import partial

import numpy as np

import jax
from jax import numpy as jnp

from netket import config
from netket.utils.types import PyTree, PRNGKeyT, Array
from netket.utils import struct, mpi
from netket.jax.sharding import with_samples_sharding_constraint, sharding_decorator

from netket.sampler import MetropolisSamplerState, MetropolisSampler
from netket.sampler.rules import LocalRule, ExchangeRule, HamiltonianRule

# Original C++ Implementation
# https://github.com/netket/netket/blob/1e187ae2b9d2aa3f2e53b09fe743e50763d04c9a/Sources/Sampler/metropolis_hastings_pt.hpp
# https://github.com/netket/netket/blob/1e187ae2b9d2aa3f2e53b09fe743e50763d04c9a/Sources/Sampler/metropolis_hastings_pt.cc
# Python port
# https://github.com/netket/netket/blob/87d469aa8c23f71c4838cf09d7ed7b87ff2ea01f/netket/legacy/sampler/numpy/metropolis_hastings_pt.py


class MetropolisPtSamplerState(MetropolisSamplerState):
    """
    State for the Metropolis Parallel Tempering sampler.

    Contains the usual quantities, as well as statistics about the parallel tempering.
    """

    beta: jnp.ndarray = None
    """The inverse temperatures of the different replicas, with shape
    ``(n_chains, n_replicas)``."""

    n_accepted_per_beta: jnp.ndarray = None
    """Number of moves accepted per beta since the last reset, with shape
    ``(n_chains, n_replicas)``."""
    beta_0_index: jnp.ndarray = None
    r"""Index of the replica holding :math:`\beta=1`, for every chain."""
    beta_position: jnp.ndarray = None
    r"""Running average of the position of :math:`\beta=1`, for every chain."""
    beta_diffusion: jnp.ndarray = None
    r"""Accumulated (unnormalized) variance of the position of :math:`\beta = 1`."""
    exchange_steps: int = 0
    """Number of exchanges between the different temperatures."""

    def __init__(
        self,
        σ: jnp.ndarray,
        rng: jnp.ndarray,
        rule_state: Optional[Any],
        beta: jnp.ndarray,
    ):
        """
        Builds the state, initializing all parallel-tempering statistics to zero.

        Args:
            σ: The configurations of all replicas.
            rng: The PRNG key of the sampler.
            rule_state: The state of the transition rule (may be None).
            beta: The inverse temperatures, of shape ``(n_chains, n_replicas)``.
        """
        n_chains, n_replicas = beta.shape

        self.beta = beta
        self.n_accepted_per_beta = jnp.zeros((n_chains, n_replicas), dtype=int)
        self.beta_0_index = jnp.zeros((n_chains,), dtype=int)
        self.beta_position = jnp.zeros((n_chains,), dtype=float)
        self.beta_diffusion = jnp.zeros((n_chains,), dtype=float)
        self.exchange_steps = jnp.zeros((), dtype=int)
        super().__init__(σ, rng, rule_state)
        # Only the beta=1 replica of every chain is counted, so the correct shape
        # is (n_chains,) and not (n_batches,) as set by the parent class.
        self.n_accepted_proc = jnp.zeros(n_chains, dtype=int)

    def __repr__(self):
        if self.n_steps > 0:
            acc_string = f"# accepted = {self.n_accepted}/{self.n_steps} ({self.acceptance * 100}%), "
        else:
            acc_string = ""

        text = (
            f"MetropolisPtSamplerState(# replicas = {self.beta.shape[-1]}, "
            + acc_string
            + f"rng state={self.rng})"
        )
        return text

    @property
    def normalized_diffusion(self):
        r"""
        Standard deviation of the position of :math:`\beta = 1`, normalized by the
        number of exchange steps and by the number of replicas, averaged over all
        chains (and MPI ranks).

        In the ideal case, this quantity should be of order ~[0.2, 1.0].
        Before any exchange step has been performed this evaluates to NaN.
        """
        diffusion = jnp.sqrt(
            self.beta_diffusion / self.exchange_steps / self.beta.shape[-1]
        )
        out, _ = mpi.mpi_mean_jax(diffusion.mean())

        return out

    @property
    def normalized_position(self):
        r"""
        Average position of :math:`\beta = 1`, normalized to [0, 1] by the number
        of replicas and centered around 0, averaged over all chains (and MPI ranks).
        """
        position = self.beta_position / float(self.beta.shape[-1] - 1) - 0.5
        out, _ = mpi.mpi_mean_jax(position.mean())

        return out


class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples a Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be
    specified.
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """
    The number of replicas evolving with different temperatures for every
    _physical_ markov chain.

    The total number of chains evolved is :code:`n_chains * n_replicas`.
    """
    _beta_sorted: jax.Array = None
    """
    An internal cache for the user-specified betas, sorted.
    """
    _beta_distribution: str = struct.field(pytree_node=False, default="linear")
    """
    An internal label for the user-specified distribution of betas.
    """

    def __init__(
        self,
        *args,
        n_replicas: Optional[int] = None,
        betas: Optional[Union[str, jax.Array]] = "linear",
        **kwargs,
    ):
        r"""
        ``MetropolisPtSampler`` is a Metropolis-Hastings sampler with parallel
        tempering, using a transition rule to perform moves in the Markov Chain.

        Every _physical_ chain is accompanied by ``n_replicas`` replicas, each
        sampled at a different inverse temperature :math:`\beta`. The transition
        kernel is used to generate a proposed state :math:`s^\prime`, starting from
        the current state :math:`s`. For a replica at inverse temperature
        :math:`\beta` the move is accepted with probability

        .. math::

            A(s\rightarrow s^\prime) = \mathrm{min}\left(1,
            \left(\frac{P(s^\prime)}{P(s)}\right)^{\beta} e^{L(s,s^\prime)} \right),

        where the physical probability being sampled from (at :math:`\beta=1`) is
        :math:`P(s)=|M(s)|^p`. Here :math:`M(s)` is a user-provided function (the
        machine), :math:`p` is also user-provided with default value :math:`p=2`,
        and :math:`L(s,s^\prime)` is a suitable correcting factor computed by the
        transition kernel.

        After every Metropolis step, pairs of replicas with adjacent indices propose
        to exchange their inverse temperatures. Only the samples of the replica at
        :math:`\beta=1` are returned.

        Args:
            hilbert: The hilbert space to sample.
            rule: A `MetropolisRule` to generate random transitions from a given
                state as well as uniform random states.
            n_replicas: The number of different inverse temperatures β for the
                sampling, must be even. Defaults to 32 when `betas` is a string.
                Cannot be specified together with an explicit list of `betas`.
            betas: (Optional) Distribution or list of values of the inverse
                temperatures β. For the distribution, choose between "linear"
                (or "lin") for a linear distribution and "logarithmic" (or "log")
                for a logarithmic one. For the explicit list of values, the length
                must be even, the value β=1 must be an element of betas, and all
                other values must be in (0,1]. ``None`` is equivalent to "linear".
                (default : "linear", i.e. linear distribution in (0,1]).
            n_chains: The number of physical Markov Chains to be run in parallel.
                The total number of chains evolved is ``n_chains * n_replicas``.
            sweep_size: The number of exchanges that compose a single sweep.
                If None, sweep_size is equal to the number of degrees of freedom
                being sampled (the size of the input vector s to the machine).
            machine_pow: The power to which the machine should be exponentiated
                to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float32).

        Raises:
            ValueError: If the distribution name, the values of `betas` or the
                number of replicas are invalid.
            TypeError: If `betas` is neither a string nor a vector.
        """
        if betas is None:
            # The annotation allows None: treat it as the default distribution.
            betas = "linear"

        if isinstance(betas, str):
            betas = betas.lower()
            if betas not in ["linear", "lin", "logarithmic", "log"]:
                raise ValueError(
                    f"""
                    To initialize the temperatures with a string, you must choose
                    between "lin" for a linear distribution and "log" for a
                    logarithmic distribution. Instead got "{betas}".
                    """
                )
            if n_replicas is None:
                n_replicas = 32
            # keep the first 3 characters: "linear" -> "lin", "logarithmic" -> "log"
            beta_distribution = betas[:3]
            betas = None
        elif isinstance(betas, Array) or isinstance(betas, (list, tuple)):
            # TODO: this defaults to float32/64 depending on what is enabled.
            # Might need to think about a better default here.
            betas = jnp.array(betas, dtype=float)
            if betas.ndim != 1:
                raise ValueError("betas must have exactly 1 dimension.")
            if betas.shape[0] == 0:
                raise ValueError(
                    "n_replicas (or the length of `betas`) must be an even integer > 0."
                )
            if n_replicas is not None:
                raise ValueError(
                    """
                    Cannot specify the list of betas and n_replicas at the same time.
                    The number of replicas will be inferred automatically from the
                    length of the vector of inverse temperatures.
                    """
                )
            # we need beta[0] = 1, so we sort and check that the temperatures are valid
            betas = jnp.sort(betas, descending=True)
            if not (jnp.isclose(betas[0], 1) and betas[-1] > 0):
                raise ValueError(
                    rf"""The values for beta should be in (0,1] and obligatory
                    contain beta=1, instead got
                    [{jnp.min(betas):.2f},{jnp.max(betas):.8f}]."""
                )
            beta_distribution = "custom"
            n_replicas = betas.shape[-1]
        else:
            raise TypeError(
                "`betas` must be a string or a vector of inverse temperatures."
            )

        # verify the number of replicas
        if not (
            isinstance(n_replicas, (int, np.integer))
            and n_replicas > 0
            and np.mod(n_replicas, 2) == 0
        ):
            raise ValueError(
                "n_replicas (or the length of `betas`) must be an even integer > 0."
            )

        # cast to a python int so that the static field is a plain hashable int
        self.n_replicas = int(n_replicas)
        self._beta_sorted = betas
        self._beta_distribution = beta_distribution
        super().__init__(*args, **kwargs)

    @property
    def sorted_betas(self):
        r"""
        The sorted values of the inverse temperatures for each _physical_ markov
        chain. The first value is β = 1 and is the _physical_ temperature.

        - linear: :math:`\beta_k = 1 - k/n` for :math:`k = 0, \dots, n-1`;
        - logarithmic: :math:`\beta_k = -\log(k/(n+1))/\log(n+1)` for
          :math:`k = 1, \dots, n`;

        where :math:`n` is the number of replicas.
        """
        if self._beta_sorted is not None:
            return self._beta_sorted

        # `float` maps to the default jax float dtype (float64 only if x64 is
        # enabled), avoiding the truncation warning of an explicit float64.
        if self._beta_distribution in ("lin", "linear"):
            return 1.0 - jnp.arange(self.n_replicas, dtype=float) / self.n_replicas
        elif self._beta_distribution in ("log", "logarithmic"):
            return -jnp.log(
                jnp.arange(1, self.n_replicas + 1, dtype=float) / (self.n_replicas + 1)
            ) / jnp.log(self.n_replicas + 1)
        else:
            raise NotImplementedError(f"distribution: {self._beta_distribution}")

    def __repr__(sampler):
        return (
            f"{type(sampler).__name__}("
            + f"\n hilbert = {sampler.hilbert},"
            + f"\n rule = {sampler.rule},"
            + f"\n n_chains = {sampler.n_chains},"
            + f"\n n_replicas = {sampler.n_replicas},"
            + f"\n beta_distribution = {sampler._beta_distribution},"
            + f"\n sweep_size = {sampler.sweep_size},"
            + f"\n reset_chains = {sampler.reset_chains},"
            + f"\n machine_power = {sampler.machine_pow},"
            + f"\n dtype = {sampler.dtype}"
            + ")"
        )

    @property
    def n_batches(self) -> int:
        r"""
        The batch size of the configuration $\sigma$ used by this sampler on this
        jax process.

        If you are not using MPI, this is equal to `n_chains * n_replicas`, but if
        you are using MPI this is equal to `n_chains_per_rank * n_replicas`.
        """
        if config.netket_experimental_sharding:
            n_batches = self.n_chains
        else:
            n_batches, remainder = divmod(self.n_chains, mpi.n_nodes)
            if remainder != 0:
                raise RuntimeError(
                    "The number of chains is not a multiple of the number of mpi ranks"
                )
        return n_batches * self.n_replicas

    @partial(jax.jit, static_argnums=1)
    def _init_state(
        sampler, machine, parameters: PyTree, key: PRNGKeyT
    ) -> MetropolisPtSamplerState:
        key_state, key_rule, rng = jax.random.split(key, 3)
        rule_state = sampler.rule.init_state(sampler, machine, parameters, key_rule)
        σ = sampler.rule.random_state(sampler, machine, parameters, rule_state, rng)
        σ = with_samples_sharding_constraint(σ)

        # every physical chain starts with the sorted betas, so beta=1 is at index 0
        beta = jnp.tile(
            sampler.sorted_betas, (sampler.n_batches // sampler.n_replicas, 1)
        )

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            beta=beta,
        )

    @partial(jax.jit, static_argnums=1)
    def _reset(sampler, machine, parameters: PyTree, state: MetropolisPtSamplerState):
        state = super()._reset(machine, parameters, state)

        # NOTE: `beta` and `beta_0_index` are not reset, so the current assignment
        # of temperatures to replicas survives a reset.
        return state.replace(
            n_accepted_per_beta=jnp.zeros_like(state.n_accepted_per_beta),
            beta_position=jnp.zeros_like(state.beta_position),
            beta_diffusion=jnp.zeros_like(state.beta_diffusion),
            exchange_steps=jnp.zeros_like(state.exchange_steps),
        )

    def _sample_next(
        sampler, machine, parameters: PyTree, state: MetropolisPtSamplerState
    ):
        # number of physical chains on this process, and replicas per chain
        n_rows = sampler.n_batches // sampler.n_replicas
        n_replicas = sampler.n_replicas

        # for every row of the input, swap the elements at idxs with those at inn
        @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
        def swap_rows(beta_row, idxs, inn):
            proposed_beta = beta_row.at[idxs].set(
                beta_row[inn], unique_indices=True, indices_are_sorted=True
            )
            proposed_beta = proposed_beta.at[inn].set(
                beta_row[idxs], unique_indices=True, indices_are_sorted=False
            )
            return proposed_beta

        @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
        def pair_log_prob(prob, idxs, inn):
            # prob[idxs] = (beta_j - beta_i) log P(x_i)
            # prob[inn] = (beta_i - beta_j) log P(x_j)
            # so we have to add the log probabilities to get the right acceptance
            return prob[idxs] + prob[inn]

        def loop_body(i, s):
            # one key is carried to the next iteration, the other four are used for
            # the proposal, the Metropolis acceptance, the swap order and the swap
            # acceptance respectively
            s["key"], key1, key2, key3, key4 = jax.random.split(s["key"], 5)

            beta = s["beta"]

            ## Usual Metropolis sampling, each replica at its own inverse temperature
            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"]
            )
            proposal_log_prob = (
                sampler.machine_pow * machine.apply(parameters, σp).real
            )

            # A replica at inverse temperature β samples P(s)^β, so only the
            # log-ratio of the target distribution is rescaled by β. The correction
            # returned by the rule compensates for the asymmetry of the proposal
            # and does not depend on the target: it must NOT be rescaled by β.
            log_acceptance = beta.reshape((-1,)) * (proposal_log_prob - s["log_prob"])
            if log_prob_correction is not None:
                log_acceptance = log_acceptance + log_prob_correction

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches,))
            do_accept = uniform < jnp.exp(log_acceptance)

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            n_accepted_per_beta = s["n_accepted_per_beta"] + do_accept.reshape(
                (n_rows, n_replicas)
            )
            s["log_prob"] = jnp.where(
                do_accept.reshape(-1), proposal_log_prob, s["log_prob"]
            )

            ## Exchange betas
            # every row randomly pairs its replicas either as (0,1), (2,3), ...
            # (swap_order == 0) or as (1,2), (3,4), ..., (n-1,0) (swap_order == 1)
            swap_order = jax.random.randint(
                key3,
                minval=0,
                maxval=2,
                shape=(n_rows,),
            )
            # first element of every pair (per-row)
            idxs = jnp.arange(0, n_replicas, 2).reshape((1, -1)) + swap_order.reshape(
                (-1, 1)
            )
            # second element of every pair (per-row)
            inn = (idxs + 1) % n_replicas

            proposed_beta = swap_rows(beta, idxs, inn)

            # log of the acceptance probability of every proposed swap
            log_prob = (proposed_beta - beta) * s["log_prob"].reshape(
                (n_rows, n_replicas)
            )
            prob_rescaled = jnp.exp(pair_log_prob(log_prob, idxs, inn))

            uniform = jax.random.uniform(key4, shape=(n_rows, n_replicas // 2))

            # decide where to swap, then give both members of a pair the pair's
            # decision: [d0, d0, d1, d1, ...]
            do_swap = uniform < prob_rescaled
            do_swap = jnp.dstack((do_swap, do_swap)).reshape((-1, n_replicas))
            # with the odd ordering pairs start at index 1 (and the last pair wraps
            # around to index 0), so the decisions are rolled by one position
            do_swap = jax.vmap(jnp.where, in_axes=(0, 0, 0), out_axes=0)(
                swap_order == 0, do_swap, jnp.roll(do_swap, 1, axis=-1)
            )

            # Do the swap where it has to be done
            s["beta"] = jnp.where(do_swap, proposed_beta, beta)

            # flag saying if beta_0 should move.
            # we use shard_map to avoid the all-gather emitted by the batched
            # jnp.take / indexing
            beta_0_moved = sharding_decorator(jax.vmap(jnp.take), (True, True))(
                do_swap, s["beta_0_index"]
            )
            # even ordering: 2k <-> 2k+1; odd ordering: 2k+1 <-> 2k+2 (mod n_replicas)
            proposed_beta_0_index = jnp.mod(
                s["beta_0_index"]
                + (-2 * jnp.mod(swap_order, 2) + 1)
                * (-2 * jnp.mod(s["beta_0_index"], 2) + 1),
                n_replicas,
            )
            s["beta_0_index"] = jnp.where(
                beta_0_moved, proposed_beta_0_index, s["beta_0_index"]
            )

            # the acceptance counters follow their temperature
            swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs, inn)
            s["n_accepted_per_beta"] = jnp.where(
                do_swap,
                swapped_n_accepted_per_beta,
                n_accepted_per_beta,
            )

            # Update statistics to compute diffusion coefficient of replicas
            # (Welford's online mean/variance of the position of beta=1)
            s["exchange_steps"] += 1
            delta = s["beta_0_index"] - s["beta_position"]
            s["beta_position"] = s["beta_position"] + delta / s["exchange_steps"]
            delta2 = s["beta_0_index"] - s["beta_position"]
            s["beta_diffusion"] = s["beta_diffusion"] + delta * delta2

            return s

        new_rng, rng = jax.random.split(state.rng)

        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            "beta": state.beta,
            # for logging
            "beta_0_index": state.beta_0_index,
            "n_accepted_per_beta": state.n_accepted_per_beta,
            "beta_position": state.beta_position,
            "beta_diffusion": state.beta_diffusion,
            "exchange_steps": state.exchange_steps,
        }

        s = jax.lax.fori_loop(0, sampler.sweep_size, loop_body, s)

        # acceptances of the beta=1 replica of every chain.
        # we use shard_map to avoid the all-gather emitted by the batched
        # jnp.take / indexing
        n_accepted_proc = sharding_decorator(jax.vmap(jnp.take), (True, True))(
            s["n_accepted_per_beta"], s["beta_0_index"]
        )

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_steps_proc=state.n_steps_proc
            + sampler.sweep_size * sampler.n_batches // sampler.n_replicas,
            beta=s["beta"],
            beta_0_index=s["beta_0_index"],
            beta_position=s["beta_position"],
            beta_diffusion=s["beta_diffusion"],
            exchange_steps=s["exchange_steps"],
            n_accepted_per_beta=s["n_accepted_per_beta"],
            n_accepted_proc=n_accepted_proc,
        )

        # return only the configuration of the beta=1 replica of every chain
        σ_flat = new_state.σ
        σ = σ_flat.reshape((-1, n_replicas, σ_flat.shape[-1]))
        # we use shard_map to avoid the all-gather emitted by the batched
        # jnp.take / indexing
        σ_new = sharding_decorator(partial(jnp.take_along_axis, axis=1), (True, True))(
            σ, s["beta_0_index"][:, None, None]
        )
        σ_new = jax.lax.collapse(σ_new, 0, 2)  # remove dummy replica dim
        return new_state, σ_new
def MetropolisLocalPt(hilbert, *args, **kwargs):
    r"""
    Sampler acting on one local degree of freedom, with parallel tempering.

    This sampler acts locally only on one local degree of freedom :math:`s_i`,
    and proposes a new state: :math:`s_1 \dots s^\prime_i \dots s_N`,
    where :math:`s^\prime_i \neq s_i`.

    The transition probability associated to this
    sampler can be decomposed into two steps:

    1. One of the site indices :math:`i = 1\dots N` is chosen
    with uniform probability.
    2. Among all the possible (:math:`m`) values that :math:`s_i` can take,
    one of them is chosen with uniform probability.

    For example, in the case of spin :math:`1/2` particles, :math:`m=2`
    and the possible local values are :math:`s_i = -1,+1`.
    In this case then :class:`MetropolisLocal` is equivalent to flipping a random spin.

    In the case of bosons, with occupation numbers
    :math:`s_i = 0, 1, \dots n_{\mathrm{max}}`, :class:`MetropolisLocal`
    would pick a random local occupation number uniformly between :math:`0`
    and :math:`n_{\mathrm{max}}`.

    All arguments other than `hilbert` are forwarded to
    :class:`MetropolisPtSampler` (e.g. `n_replicas`, `betas`).

    Args:
        hilbert: The hilbert space to sample
        n_chains: The number of physical Markov Chains to be run in parallel.
        sweep_size: The number of exchanges that compose a single sweep.
            If None, sweep_size is equal to the number of degrees of freedom
            being sampled (the size of the input vector s to the machine).
        machine_pow: The power to which the machine should be exponentiated to
            generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = np.float32).
    """
    return MetropolisPtSampler(hilbert, LocalRule(), *args, **kwargs)
def MetropolisExchangePt(hilbert, *args, clusters=None, graph=None, d_max=1, **kwargs):
    r"""
    Sampler exchanging two local degrees of freedom, with parallel tempering.

    This sampler acts locally only on two local degree of freedom :math:`s_i` and
    :math:`s_j`, and proposes a new state:
    :math:`s_1 \dots s^\prime_i \dots s^\prime_j \dots s_N`,
    where in general :math:`s^\prime_i \neq s_i` and :math:`s^\prime_j \neq s_j`.
    The sites :math:`i` and :math:`j` are also chosen to be within a maximum graph
    distance of :math:`d_{\mathrm{max}}`.

    The transition probability associated to this sampler can
    be decomposed into two steps:

    1. A pair of indices :math:`i,j = 1\dots N`, and such
       that :math:`\mathrm{dist}(i,j) \leq d_{\mathrm{max}}`,
       is chosen with uniform probability.
    2. The sites are exchanged, i.e. :math:`s^\prime_i = s_j` and
       :math:`s^\prime_j = s_i`.

    Notice that this sampling method generates random permutations of the quantum
    numbers, thus global quantities such as the sum of the local quantum numbers
    are conserved during the sampling.
    This scheme should be used then only when sampling in a
    region where :math:`\sum_i s_i = \mathrm{constant}` is needed,
    otherwise the sampling would be strongly not ergodic.

    All remaining arguments are forwarded to :class:`MetropolisPtSampler`
    (e.g. `n_replicas`, `betas`).

    Args:
        hilbert: The hilbert space to sample
        clusters: (Optional) The explicit list of pairs of sites that can be
            exchanged, passed to :class:`ExchangeRule`.
        graph: (Optional) The graph used to build the pairs of sites, passed to
            :class:`ExchangeRule`.
        d_max: The maximum graph distance allowed for exchanges.
        n_chains: The number of physical Markov Chains to be run in parallel.
        sweep_size: The number of exchanges that compose a single sweep.
            If None, sweep_size is equal to the number of degrees of freedom
            being sampled (the size of the input vector s to the machine).
        machine_pow: The power to which the machine should be exponentiated to
            generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = np.float32).

    Examples:
        Sampling from an RBM machine in a 2D lattice of spin 1/2, using
        nearest-neighbours exchanges.

        >>> import pytest; pytest.skip("EXPERIMENTAL")
        >>> import netket as nk
        >>> import netket.experimental.sampler.metropolis_pt as mpt
        >>>
        >>> g=nk.graph.Hypercube(length=10,n_dim=2,pbc=True)
        >>> hi=nk.hilbert.Spin(s=0.5, N=g.n_nodes)
        >>>
        >>> # Construct a MetropolisExchangePt Sampler
        >>> sa = mpt.MetropolisExchangePt(hi, graph=g)
        >>> print(sa)  # doctest: +ELLIPSIS
        MetropolisPtSampler(
        ...
    """
    rule = ExchangeRule(clusters=clusters, graph=graph, d_max=d_max)
    return MetropolisPtSampler(hilbert, rule, *args, **kwargs)
def MetropolisHamiltonianPt(hilbert, hamiltonian, *args, **kwargs):
    r"""
    Sampling based on the off-diagonal elements of a Hamiltonian (or a generic
    Operator), with parallel tempering.

    In this case, the transition matrix is taken to be:

    .. math::

       T( \mathbf{s} \rightarrow \mathbf{s}^\prime) = \frac{1}{\mathcal{N}(\mathbf{s})}\theta(|H_{\mathbf{s},\mathbf{s}^\prime}|),

    where :math:`\theta(x)` is the Heaviside step function, and
    :math:`\mathcal{N}(\mathbf{s})` is a state-dependent normalization.
    The effect of this transition probability is then to connect (with uniform
    probability) a given state :math:`\mathbf{s}` to all those states
    :math:`\mathbf{s}^\prime` for which the Hamiltonian has finite matrix elements.
    Notice that this sampler preserves by construction all the symmetries
    of the Hamiltonian. This is in general not true for the local samplers
    instead.

    All remaining arguments are forwarded to :class:`MetropolisPtSampler`
    (e.g. `n_replicas`, `betas`).

    Args:
        hilbert: The hilbert space to sample.
        hamiltonian: The operator used to perform off-diagonal transition.
        n_chains: The number of physical Markov Chains to be run in parallel.
        sweep_size: The number of exchanges that compose a single sweep.
            If None, sweep_size is equal to the number of degrees of freedom
            (n_visible).

    Examples:
        Sampling from an RBM machine in a 2D lattice of spin 1/2

        >>> import pytest; pytest.skip("EXPERIMENTAL")
        >>> import netket as nk
        >>> import netket.experimental.sampler.metropolis_pt as mpt
        >>>
        >>> g=nk.graph.Hypercube(length=10,n_dim=2,pbc=True)
        >>> hi=nk.hilbert.Spin(s=0.5, N=g.n_nodes)
        >>>
        >>> # Transverse-field Ising Hamiltonian
        >>> ha = nk.operator.Ising(hilbert=hi, h=1.0, graph=g)
        >>>
        >>> # Construct a MetropolisHamiltonianPt Sampler
        >>> sa = mpt.MetropolisHamiltonianPt(hi, hamiltonian=ha)
        >>> print(sa)  # doctest: +ELLIPSIS
        MetropolisPtSampler(
        ...
    """
    rule = HamiltonianRule(hamiltonian)
    return MetropolisPtSampler(hilbert, rule, *args, **kwargs)