# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Callable, Optional, Union
from textwrap import dedent
import numpy as np
import jax
from flax import linen as nn
from flax import serialization
from jax import numpy as jnp
from netket.hilbert import AbstractHilbert, ContinuousHilbert
from netket.utils import mpi, wrap_afun
from netket.utils.types import PyTree, DType
from netket.utils.deprecation import warn_deprecation
from netket.utils import struct
from netket.utils.config_flags import config
from netket.jax.sharding import (
extract_replicated,
gather,
distribute_to_devices_along_axis,
device_count,
with_samples_sharding_constraint,
)
from .base import Sampler, SamplerState
from .rules import MetropolisRule
[docs]
class MetropolisSamplerState(SamplerState):
    """
    State for a Metropolis sampler.

    Contains the current configuration, the RNG state and the (optional)
    state of the transition rule, together with per-process counters of the
    number of moves performed and accepted since the last reset.
    """

    σ: jnp.ndarray
    """Current batch of configurations in the Markov chain."""
    rng: jnp.ndarray
    """State of the random number generator (key, in jax terms)."""
    rule_state: Optional[Any]
    """Optional state of the transition rule."""
    n_steps_proc: int = struct.field(default_factory=lambda: jnp.zeros((), dtype=int))
    """Number of moves performed along the chains in this process since the last reset."""
    n_accepted_proc: jnp.ndarray
    """Number of accepted transitions among the chains in this process since the last reset."""

    def __init__(self, σ: jnp.ndarray, rng: jnp.ndarray, rule_state: Optional[Any]):
        """
        Construct a sampler state with zeroed step and acceptance counters.

        Args:
            σ: Initial batch of configurations. Its leading dimension
                (``σ.shape[0]``) sets the number of per-chain acceptance counters.
            rng: The jax PRNG key stored in the state.
            rule_state: State of the transition rule (may be None).
        """
        self.σ = σ
        self.rng = rng
        self.rule_state = rule_state
        # One integer acceptance counter per row of σ, constrained with the
        # same sharding helper that is applied to the samples themselves.
        self.n_accepted_proc = with_samples_sharding_constraint(
            jnp.zeros(σ.shape[0], dtype=int)
        )
        # Scalar counter of moves performed in this process.
        self.n_steps_proc = jnp.zeros((), dtype=int)
        super().__init__()

    @property
    def acceptance(self) -> Optional[float]:
        """The fraction of accepted moves across all chains and MPI processes.

        The rate is computed since the last reset of the sampler.
        Will return None if no sampling has been performed since then.
        """
        if self.n_steps == 0:
            return None
        return self.n_accepted / self.n_steps

    @property
    def n_steps(self) -> int:
        """Total number of moves performed across all processes since the last reset."""
        # The per-process count is scaled by the number of MPI nodes rather than
        # summed, i.e. every process is assumed to have performed the same
        # number of moves.
        return self.n_steps_proc * mpi.n_nodes

    @property
    def n_accepted(self) -> int:
        """Total number of moves accepted across all processes since the last reset."""
        # jit sum for gda
        # The sum is done under jax.jit (presumably so that it also works when
        # n_accepted_proc is a sharded global array), then reduced over MPI;
        # mpi_sum_jax returns a (result, token) pair and the token is discarded.
        res, _ = mpi.mpi_sum_jax(jax.jit(jnp.sum)(self.n_accepted_proc))
        return res

    def __repr__(self):
        # The acceptance summary is only shown once at least one move was made.
        if self.n_steps > 0:
            acc_string = f"# accepted = {self.n_accepted}/{self.n_steps} ({self.acceptance * 100}%), "
        else:
            acc_string = ""
        return f"{type(self).__name__}({acc_string}rng state={self.rng})"
# serialization when sharded
def serialize_MetropolisSamplerState_sharding(sampler_state):
    """
    Convert a :class:`MetropolisSamplerState` to a flax state dict, gathering
    the sharded arrays.

    The samples (``σ``) and the per-chain acceptance counters
    (``n_accepted_proc``) are gathered with :func:`gather` whenever they are
    jax arrays living on more than one device. The resulting dictionary is
    then passed through :func:`extract_replicated`.

    Args:
        sampler_state: The sampler state to serialize.

    Returns:
        The serialized state dictionary.
    """
    state_dict = MetropolisSamplerState._to_flax_state_dict(
        MetropolisSamplerState._pytree__static_fields, sampler_state
    )
    for name in ("σ", "n_accepted_proc"):
        value = state_dict.get(name, None)
        # None is never a jax.Array, so this also skips missing entries.
        spans_several_devices = (
            isinstance(value, jax.Array) and len(value.devices()) > 1
        )
        if spans_several_devices:
            state_dict[name] = gather(value)
    return extract_replicated(state_dict)
def deserialize_MetropolisSamplerState_sharding(sampler_state, state_dict):
    """
    Restore a :class:`MetropolisSamplerState` from a flax state dict,
    redistributing the samples across the available devices.

    This is the inverse of ``serialize_MetropolisSamplerState_sharding``: the
    samples (``σ``) and the per-chain acceptance counters (``n_accepted_proc``)
    are passed through :func:`distribute_to_devices_along_axis` before the
    state is rebuilt.

    Args:
        sampler_state: Template state the values are restored into.
        state_dict: The serialized state dictionary. It is not modified.

    Returns:
        The restored sampler state.
    """
    # Work on a shallow copy so the caller's dictionary is left untouched
    # (previously the distributed arrays were written back into it).
    state_dict = dict(state_dict)
    for prop in ("σ", "n_accepted_proc"):
        # Use .get, like the serializer, so a missing entry is left for
        # _from_flax_state_dict to handle instead of raising a bare KeyError here.
        x = state_dict.get(prop, None)
        if x is not None:
            state_dict[prop] = distribute_to_devices_along_axis(x)
    return MetropolisSamplerState._from_flax_state_dict(
        MetropolisSamplerState._pytree__static_fields, sampler_state, state_dict
    )
if config.netket_experimental_sharding:
    # When running on multiple jax processes, σ and n_accepted_proc are not
    # fully addressable from a single process. Serialization needs the full
    # data, so here we register custom handlers which gather all of it to every
    # process.
    # When deserializing, the samples are distributed again to all available
    # devices. This way it is enough to serialize on process 0, and the
    # simulation can also be restarted on a different number of devices,
    # provided the number of samples is still divisible by the new number of
    # devices.
    serialization.register_serialization_state(
        MetropolisSamplerState,
        serialize_MetropolisSamplerState_sharding,
        deserialize_MetropolisSamplerState_sharding,
        override=True,
    )
def _assert_good_sample_shape(samples, shape, dtype, obj=""):
canonical_dtype = jax.dtypes.canonicalize_dtype(dtype)
if samples.shape != shape or samples.dtype != canonical_dtype:
raise ValueError(
dedent(
f"""
The samples returned by the {obj} have `shape={samples.shape}` and
`dtype={samples.dtype}`, but the sampler requires `shape={shape} and
`dtype={canonical_dtype}` (canonicalized from {dtype}).
If you are using a custom transition rule, check that it returns the
correct shape and dtype.
If you are using a built-in transition rule, there might be a mismatch
between hilbert spaces, or it's a bug in NetKet.
"""
)
)
def _assert_good_log_prob_shape(log_prob, n_chains_per_rank, machine):
if log_prob.shape != (n_chains_per_rank,):
raise ValueError(
dedent(
f"""
The output of the model {machine} has `shape={log_prob.shape}`, but
`shape=({n_chains_per_rank},)` was expected.
This might be because of an hilbert space mismatch or because your
model is ill-configured.
"""
)
)
def _round_n_chains_to_next_multiple(
n_chains, n_chains_per_whatever, n_whatever, whatever_str
):
# small helper function to round the number of chains to the next multiple of [whatever]
# here [whatever] can be e.g. mpi ranks or jax devices
# if n_chains is None and n_chains_per_whatever is None:
# n_chains_per_whatever = default
if n_chains is not None and n_chains_per_whatever is not None:
raise ValueError(
f"Cannot specify both `n_chains` and `n_chains_per_{whatever_str}`"
)
elif n_chains is not None:
n_chains_per_whatever = max(int(np.ceil(n_chains / n_whatever)), 1)
if n_chains_per_whatever * n_whatever != n_chains:
if mpi.rank == 0:
import warnings
warnings.warn(
f"Using {n_chains_per_whatever} chains per {whatever_str} among {n_whatever} {whatever_str}s "
f"(total={n_chains_per_whatever * n_whatever} instead of n_chains={n_chains}). "
f"To directly control the number of chains on every {whatever_str}, specify "
f"`n_chains_per_{whatever_str}` when constructing the sampler. "
f"To silence this warning, either use `n_chains_per_{whatever_str}` or use `n_chains` "
f"that is a multiple of the number of {whatever_str}s",
category=UserWarning,
stacklevel=2,
)
return n_chains_per_whatever * n_whatever
[docs]
class MetropolisSampler(Sampler):
    r"""
    Metropolis-Hastings sampler for a Hilbert space according to a specific transition rule.

    The transition rule is used to generate a proposed state :math:`s^\prime`, starting from the
    current state :math:`s`. The move is accepted with probability

    .. math::

        A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right) ,

    where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)` is a
    user-provided function (the machine), :math:`p` is also user-provided with default value :math:`p=2`,
    and :math:`L(s,s^\prime)` is a suitable correcting factor computed by the transition kernel.

    The dtype of the sampled states can be chosen.
    """

    rule: MetropolisRule = None
    """The Metropolis transition rule."""
    sweep_size: int = struct.field(pytree_node=False, default=None)
    """Number of Metropolis moves performed between two consecutive samples along
    the chain. Defaults to the number of sites in the Hilbert space."""
    n_chains: int = struct.field(pytree_node=False)
    """Total number of independent chains across all MPI ranks and/or devices."""
    reset_chains: bool = struct.field(pytree_node=False, default=False)
    """If True, resets the chain state when `reset` is called on every new sampling."""

    def __init__(
        self,
        hilbert: AbstractHilbert,
        rule: MetropolisRule,
        *,
        n_sweeps: Optional[int] = None,
        sweep_size: Optional[int] = None,
        reset_chains: bool = False,
        n_chains: Optional[int] = None,
        n_chains_per_rank: Optional[int] = None,
        machine_pow: int = 2,
        dtype: DType = float,
    ):
        """
        Constructs a Metropolis Sampler.

        Args:
            hilbert: The Hilbert space to sample.
            rule: A `MetropolisRule` to generate random transitions from a given state as
                well as uniform random states.
            n_sweeps: Deprecated alias of `sweep_size`.
            sweep_size: Number of Metropolis moves performed between two consecutive
                samples along the chain. This is equivalent to subsampling the Markov
                chain. (Defaults to the number of sites in the Hilbert space.)
            reset_chains: If True, resets the chain state when `reset` is called on every
                new sampling (default = False).
            n_chains: The total number of independent Markov chains across all MPI ranks.
                Either specify this or `n_chains_per_rank`. If MPI is disabled, the two are equivalent;
                if MPI is enabled and `n_chains` is specified, then every MPI rank will run
                `n_chains/mpi.n_nodes` chains. In general, we recommend specifying `n_chains_per_rank`
                as it is more portable.
            n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
                If netket_experimental_sharding is enabled this is interpreted as the number
                of independent chains on every jax device, and the n_chains_per_rank
                property of the sampler will return the total number of chains on all devices.
            machine_pow: The power to which the machine should be exponentiated to generate
                the pdf (default = 2).
            dtype: The dtype of the states sampled (default = float).

        Raises:
            TypeError: if `rule` is not a `MetropolisRule` or `reset_chains` is not a bool.
            ValueError: if both `sweep_size` and `n_sweeps`, or both `n_chains` and
                `n_chains_per_rank`, are specified.
        """
        # Validate the inputs
        if not isinstance(rule, MetropolisRule):
            raise TypeError(
                f"The second positional argument, rule, must be a MetropolisRule but "
                f"`type(rule)={type(rule)}`."
            )
        if not isinstance(reset_chains, bool):
            raise TypeError("reset_chains must be a boolean.")

        if n_sweeps is not None:
            warn_deprecation(
                "Specifying `n_sweeps` when constructing sampler is deprecated. Please use `sweep_size` instead."
            )
            if sweep_size is not None:
                raise ValueError("Cannot specify both `sweep_size` and `n_sweeps`")
            sweep_size = n_sweeps
        if sweep_size is None:
            sweep_size = hilbert.size

        # Default n_chains per rank, if unset
        if n_chains is None and n_chains_per_rank is None:
            # TODO set it to a few hundred if on GPU?
            n_chains_per_rank = 16
        n_chains = _round_n_chains_to_next_multiple(
            n_chains,
            n_chains_per_rank,
            device_count(),
            "rank",
        )

        super().__init__(
            hilbert=hilbert,
            machine_pow=machine_pow,
            dtype=dtype,
        )
        self.n_chains = n_chains
        self.reset_chains = reset_chains
        self.rule = rule
        self.sweep_size = sweep_size

    @property
    def n_sweeps(self):
        """Deprecated alias of :attr:`sweep_size`."""
        warn_deprecation(
            "`MetropolisSampler.n_sweeps` is deprecated. Please use `MetropolisSampler.sweep_size` instead."
        )
        return self.sweep_size

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the Markov chain.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature
                :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, then initialize and reset it.

        Returns:
            state: The new state of the sampler.
            σ: The next batch of samples.

        Note:
            The return order is inverted wrt `sample` because when called inside of
            a scan function the first returned argument should be the state.
        """
        if state is None:
            state = sampler.reset(machine, parameters)
        return sampler._sample_next(wrap_afun(machine), parameters, state)

    @partial(jax.jit, static_argnums=1)
    def _init_state(sampler, machine, parameters, key):
        # Independent keys for the sampler state and for the rule state.
        key_state, key_rule = jax.random.split(key)
        rule_state = sampler.rule.init_state(sampler, machine, parameters, key_rule)
        # Placeholder configurations with the final shape, dtype and sharding.
        σ = jnp.zeros((sampler.n_batches, sampler.hilbert.size), dtype=sampler.dtype)
        σ = with_samples_sharding_constraint(σ)
        state = MetropolisSamplerState(σ=σ, rng=key_state, rule_state=rule_state)
        # If we don't reset the chain at every sampling iteration, then reset it
        # now.
        if not sampler.reset_chains:
            # Already inside jax.jit, so no need to jit the split again.
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)
            _assert_good_sample_shape(
                σ,
                (sampler.n_batches, sampler.hilbert.size),
                sampler.dtype,
                f"{sampler.rule}.random_state",
            )
            σ = with_samples_sharding_constraint(σ)
            state = state.replace(σ=σ, rng=key_state)
        return state

    @partial(jax.jit, static_argnums=1)
    def _reset(sampler, machine, parameters, state):
        rng = state.rng
        if sampler.reset_chains:
            # `rng` is kept in the state for later sampling, `key` is consumed
            # here. Previously `rng` itself was used to draw the random state,
            # so the same key was reused by the subsequent Markov chain.
            rng, key = jax.random.split(state.rng)
            σ = sampler.rule.random_state(sampler, machine, parameters, state, key)
            _assert_good_sample_shape(
                σ,
                (sampler.n_batches, sampler.hilbert.size),
                sampler.dtype,
                f"{sampler.rule}.random_state",
            )
            σ = with_samples_sharding_constraint(σ)
        else:
            σ = state.σ
        rule_state = sampler.rule.reset(sampler, machine, parameters, state)
        # Counters restart from zero at every reset.
        return state.replace(
            σ=σ,
            rng=rng,
            rule_state=rule_state,
            n_steps_proc=jnp.zeros_like(state.n_steps_proc),
            n_accepted_proc=jnp.zeros_like(state.n_accepted_proc),
        )

    def _sample_next(sampler, machine, parameters, state):
        """
        Implementation of `sample_next` for subclasses of `MetropolisSampler`.

        If you subclass `MetropolisSampler`, you should override this and not `sample_next`
        itself, because `sample_next` contains some common logic.

        Performs `sweep_size` Metropolis moves on every chain and returns the
        updated state together with the final configurations.
        """

        def loop_body(i, s):
            # 3 keys: 1 to propagate to the next iteration, 1 for the transition
            # kernel and 1 for the uniform numbers of the acceptance test.
            s["key"], key1, key2 = jax.random.split(s["key"], 3)

            # Note: the rule receives the state as it was at the start of the
            # sweep; only the current configurations s["σ"] change in the loop.
            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"]
            )
            _assert_good_sample_shape(
                σp,
                (sampler.n_batches, sampler.hilbert.size),
                sampler.dtype,
                f"{sampler.rule}.transition",
            )
            proposal_log_prob = sampler.machine_pow * machine.apply(parameters, σp).real
            _assert_good_log_prob_shape(proposal_log_prob, sampler.n_batches, machine)

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches,))
            log_ratio = proposal_log_prob - s["log_prob"]
            if log_prob_correction is not None:
                log_ratio = log_ratio + log_prob_correction
            do_accept = uniform < jnp.exp(log_ratio)

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            s["accepted"] += do_accept
            s["log_prob"] = jnp.where(
                do_accept.reshape(-1), proposal_log_prob, s["log_prob"]
            )
            return s

        new_rng, rng = jax.random.split(state.rng)
        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            # for logging
            "accepted": state.n_accepted_proc,
        }
        s = jax.lax.fori_loop(0, sampler.sweep_size, loop_body, s)
        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_accepted_proc=s["accepted"],
            n_steps_proc=state.n_steps_proc + sampler.sweep_size * sampler.n_batches,
        )
        return new_state, new_state.σ

    @partial(jax.jit, static_argnums=(1, 4))
    def _sample_chain(sampler, machine, parameters, state, chain_length):
        """
        Samples `chain_length` batches of samples along the chains.

        Internal method used for jitting calls.

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: A Flax module with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler.
            chain_length: The length of the chains.

        Returns:
            σ: The next batch of samples.
            state: The new state of the sampler
        """
        state, samples = jax.lax.scan(
            lambda state, _: sampler.sample_next(machine, parameters, state),
            state,
            xs=None,
            length=chain_length,
        )
        # make it (n_chains, n_samples_per_chain) as expected by netket.stats.statistics
        samples = jnp.swapaxes(samples, 0, 1)
        return samples, state

    def __repr__(sampler):
        return (
            f"{type(sampler).__name__}("
            + f"\n  hilbert = {sampler.hilbert},"
            + f"\n  rule = {sampler.rule},"
            + f"\n  n_chains = {sampler.n_chains},"
            + f"\n  sweep_size = {sampler.sweep_size},"
            + f"\n  reset_chains = {sampler.reset_chains},"
            + f"\n  machine_power = {sampler.machine_pow},"
            + f"\n  dtype = {sampler.dtype}"
            + ")"
        )

    def __str__(sampler):
        return (
            f"{type(sampler).__name__}("
            + f"rule = {sampler.rule}, "
            + f"n_chains = {sampler.n_chains}, "
            + f"sweep_size = {sampler.sweep_size}, "
            + f"reset_chains = {sampler.reset_chains}, "
            + f"machine_power = {sampler.machine_pow}, "
            + f"dtype = {sampler.dtype})"
        )
[docs]
def MetropolisLocal(hilbert, **kwargs) -> MetropolisSampler:
    r"""
    Sampler acting on one local degree of freedom.

    This sampler acts locally only on one local degree of freedom :math:`s_i`,
    and proposes a new state: :math:`s_1 \dots s^\prime_i \dots s_N`,
    where :math:`s^\prime_i \neq s_i`.

    The transition probability associated to this
    sampler can be decomposed into two steps:

    1. One of the site indices :math:`i = 1\dots N` is chosen with uniform probability.
    2. Among all the possible (:math:`m - 1`) values that :math:`s^\prime_i` can take,
       one of them is chosen with uniform probability.

    For example, in the case of spin :math:`1/2` particles, :math:`m=2`
    and the possible local values are :math:`s_i = -1,+1`.
    In this case then :class:`MetropolisLocal` is equivalent to flipping a random spin.

    In the case of bosons, with occupation numbers
    :math:`s_i = 0, 1, \dots n_{\mathrm{max}}`, :class:`MetropolisLocal`
    would pick a random local occupation number uniformly between :math:`0`
    and :math:`n_{\mathrm{max}}` except the current :math:`s_i`.

    Args:
        hilbert: The Hilbert space to sample.
        n_chains: The total number of independent Markov chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
        n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
        sweep_size: Number of Metropolis moves between two consecutive samples along the chain.
            Defaults to the number of sites in the Hilbert space.
            This is equivalent to subsampling the Markov chain.
        reset_chains: If True, resets the chain state when `reset` is called on every new sampling (default = False).
        machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = float).

    All keyword arguments are forwarded to :class:`MetropolisSampler`.
    """
    from .rules import LocalRule

    local_rule = LocalRule()
    return MetropolisSampler(hilbert, local_rule, **kwargs)
[docs]
def MetropolisExchange(
    hilbert, *, clusters=None, graph=None, d_max=1, **kwargs
) -> MetropolisSampler:
    r"""
    This sampler acts locally only on two local degrees of freedom :math:`s_i` and :math:`s_j`,
    and proposes a new state: :math:`s_1 \dots s^\prime_i \dots s^\prime_j \dots s_N`,
    where in general :math:`s^\prime_i \neq s_i` and :math:`s^\prime_j \neq s_j`.
    The sites :math:`i` and :math:`j` are also chosen to be within a maximum graph
    distance of :math:`d_{\mathrm{max}}`.

    The transition probability associated to this sampler can
    be decomposed into two steps:

    1. A pair of indices :math:`i,j = 1\dots N`, and such that
       :math:`\mathrm{dist}(i,j) \leq d_{\mathrm{max}}`,
       is chosen with uniform probability.
    2. The sites are exchanged, i.e. :math:`s^\prime_i = s_j` and :math:`s^\prime_j = s_i`.

    Notice that this sampling method generates random permutations of the quantum
    numbers, thus global quantities such as the sum of the local quantum numbers
    are conserved during the sampling.
    This scheme should be used then only when sampling in a
    region where :math:`\sum_i s_i = \mathrm{constant}` is needed,
    otherwise the sampling would be strongly not ergodic.

    Args:
        hilbert: The Hilbert space to sample.
        clusters: The clusters (pairs of sites) that may be exchanged. Forwarded
            to :class:`~netket.sampler.rules.ExchangeRule`.
        graph: A graph from which the exchangeable pairs are built, together with
            `d_max`. Forwarded to :class:`~netket.sampler.rules.ExchangeRule`.
        d_max: The maximum graph distance allowed for exchanges.
        n_chains: The total number of independent Markov chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
        n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
        sweep_size: Number of Metropolis moves between two consecutive samples along the chain.
            Defaults to the number of sites in the Hilbert space.
            This is equivalent to subsampling the Markov chain.
        reset_chains: If True, resets the chain state when `reset` is called on every new sampling (default = False).
        machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = float).

    Examples:
        Sampling spins 1/2 on a 2D periodic lattice, using
        nearest-neighbor exchanges.

        >>> import netket as nk
        >>>
        >>> g=nk.graph.Hypercube(length=10,n_dim=2,pbc=True)
        >>> hi=nk.hilbert.Spin(s=0.5, N=g.n_nodes)
        >>>
        >>> # Construct a MetropolisExchange Sampler
        >>> sa = nk.sampler.MetropolisExchange(hi, graph=g)
        >>> print(sa)
        MetropolisSampler(rule = ExchangeRule(# of clusters: 200), n_chains = 16, sweep_size = 100, reset_chains = False, machine_power = 2, dtype = <class 'float'>)
    """
    from .rules import ExchangeRule

    return MetropolisSampler(
        hilbert,
        ExchangeRule(clusters=clusters, graph=graph, d_max=d_max),
        **kwargs,
    )
[docs]
def MetropolisHamiltonian(hilbert, hamiltonian, **kwargs) -> MetropolisSampler:
    r"""
    Sampling based on the off-diagonal elements of a Hamiltonian (or a generic Operator).

    In this case, the transition matrix is taken to be:

    .. math::

       T( \mathbf{s} \rightarrow \mathbf{s}^\prime) = \frac{1}{\mathcal{N}(\mathbf{s})}\theta(|H_{\mathbf{s},\mathbf{s}^\prime}|),

    where :math:`\theta(x)` is the Heaviside step function, and :math:`\mathcal{N}(\mathbf{s})`
    is a state-dependent normalization.
    The effect of this transition probability is then to connect (with uniform probability)
    a given state :math:`\mathbf{s}` to all those states :math:`\mathbf{s}^\prime` for which the Hamiltonian has
    finite matrix elements.
    Notice that this sampler preserves by construction all the symmetries
    of the Hamiltonian. This is in general not true for the local samplers instead.

    This sampler only works on the CPU. To use the Hamiltonian sampler with GPUs,
    you should use :class:`netket.sampler.MetropolisHamiltonianNumpy`

    Args:
        hilbert: The Hilbert space to sample.
        hamiltonian: The operator used to perform off-diagonal transition.
        n_chains: The total number of independent Markov chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
        n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
        sweep_size: Number of Metropolis moves between two consecutive samples along the chain.
            Defaults to the number of sites in the Hilbert space.
            This is equivalent to subsampling the Markov chain.
        reset_chains: If True, resets the chain state when `reset` is called on every new sampling (default = False).
        machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = float).

    Examples:
        Sampling spins 1/2 on a 2D periodic lattice, with transitions
        generated by a transverse-field Ising Hamiltonian.

        >>> import netket as nk
        >>>
        >>> g=nk.graph.Hypercube(length=10,n_dim=2,pbc=True)
        >>> hi=nk.hilbert.Spin(s=0.5, N=g.n_nodes)
        >>>
        >>> # Transverse-field Ising Hamiltonian
        >>> ha = nk.operator.Ising(hilbert=hi, h=1.0, graph=g)
        >>>
        >>> # Construct a MetropolisHamiltonian Sampler
        >>> sa = nk.sampler.MetropolisHamiltonian(hi, hamiltonian=ha)
        >>> print(sa)
        MetropolisSampler(rule = HamiltonianRuleNumba(operator=Ising(J=1.0, h=1.0; dim=100)), n_chains = 16, sweep_size = 100, reset_chains = False, machine_power = 2, dtype = <class 'float'>)
    """
    from .rules import HamiltonianRule

    return MetropolisSampler(hilbert, HamiltonianRule(hamiltonian), **kwargs)
[docs]
def MetropolisGaussian(hilbert, sigma=1.0, **kwargs) -> MetropolisSampler:
    """This sampler acts on all particle positions simultaneously
    and proposes a new state according to a Gaussian distribution
    with width `sigma`.

    Args:
        hilbert: The continuous Hilbert space to sample.
        sigma: The width of the Gaussian proposal distribution (default = 1.0).
        n_chains: The total number of independent Markov chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
        n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
        sweep_size: Number of Metropolis moves between two consecutive samples along the chain.
            Defaults to the number of sites in the Hilbert space.
            This is equivalent to subsampling the Markov chain.
        reset_chains: If True, resets the chain state when `reset` is called on every new sampling (default = False).
        machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = float).

    Raises:
        ValueError: if `hilbert` is not a :class:`ContinuousHilbert`.
    """
    if isinstance(hilbert, ContinuousHilbert):
        from .rules import GaussianRule

        return MetropolisSampler(hilbert, GaussianRule(sigma), **kwargs)
    raise ValueError("This sampler only works for Continuous Hilbert spaces.")
[docs]
def MetropolisAdjustedLangevin(
    hilbert, dt=0.001, chunk_size=None, **kwargs
) -> MetropolisSampler:
    r"""This sampler acts on all particle positions simultaneously
    and takes a Langevin step [1]:

    .. math::
       x_{t+dt} = x_t + dt \nabla_x \log p(x) \vert_{x=x_t} + \sqrt{2 dt}\eta,

    where :math:`\eta` is normal distributed noise :math:`\eta \sim \mathcal{N}(0,1)`.
    This sampler only works for continuous Hilbert spaces.

    [1]: https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm

    Args:
        hilbert: The continuous Hilbert space to sample.
        dt: Time step size for the Langevin dynamics (noise with variance 2*dt).
        chunk_size: Chunk size to compute the gradients of the log probability.
        n_chains: The total number of independent Markov chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
        n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
        sweep_size: Number of Metropolis moves between two consecutive samples along the chain.
            Defaults to the number of sites in the Hilbert space.
            This is equivalent to subsampling the Markov chain.
        reset_chains: If True, resets the chain state when `reset` is called on every new sampling (default = False).
        machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
        dtype: The dtype of the states sampled (default = float).

    Raises:
        ValueError: if `hilbert` is not a :class:`ContinuousHilbert`.
    """
    is_continuous = isinstance(hilbert, ContinuousHilbert)
    if not is_continuous:
        raise ValueError("This sampler only works for Continuous Hilbert spaces.")

    from .rules import LangevinRule

    langevin_rule = LangevinRule(dt=dt, chunk_size=chunk_size)
    return MetropolisSampler(hilbert, langevin_rule, **kwargs)