netket.sampler.ParallelTemperingSampler
- class netket.sampler.ParallelTemperingSampler
Bases: MetropolisSampler
Metropolis-Hastings with Parallel Tempering sampler.
This sampler samples a Hilbert space, producing samples of a specific dtype. The samples are generated according to a transition rule that must be specified.
The Metropolis-Hastings acceptance rule is corrected with a temperature.
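The tempering couples the replicas through configuration swaps. As a sketch of the textbook replica-exchange move (an assumption about the general technique, not a transcription of NetKet's code), the NumPy snippet below computes the probability of swapping the configurations held by two replicas at inverse temperatures \(\beta_i\) and \(\beta_j\), each assumed to target \(|M(s)|^{\beta p}\). The function swap_acceptance is hypothetical, and which replica pairs NetKet proposes, and how often, is not documented on this page.

```python
import numpy as np


def swap_acceptance(beta_i, beta_j, log_amp_i, log_amp_j, machine_pow=2.0):
    """Hypothetical helper: textbook replica-exchange acceptance probability.

    Replica k is assumed to sample P_k(s) proportional to |M(s)|**(beta_k * machine_pow),
    and log_amp_k is log|M(s_k)| for the configuration s_k it currently holds.
    The swap s_i <-> s_j is accepted with
    min(1, P_i(s_j) * P_j(s_i) / (P_i(s_i) * P_j(s_j))).
    """
    # Logarithm of the ratio above; the normalisation constants cancel.
    log_ratio = (beta_i - beta_j) * machine_pow * (log_amp_j - log_amp_i)
    # Clip in log space so that exp never overflows.
    return float(np.exp(min(0.0, log_ratio)))


rng = np.random.default_rng(0)
# The physical replica (beta = 1) holds a configuration with a smaller |M(s)|
# than its hotter neighbour (beta = 0.5), so swapping is always accepted here.
p_swap = swap_acceptance(1.0, 0.5, log_amp_i=-1.0, log_amp_j=-0.5)
do_swap = rng.uniform() < p_swap
```

Swaps let configurations found by the hot, fast-mixing replicas migrate down to the physical replica.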
- __init__(*args, n_replicas=None, betas='linear', **kwargs)
ParallelTemperingSampler is a generic Metropolis-Hastings sampler using a transition rule to perform moves in the Markov chain. The transition kernel is used to generate a proposed state \(s^\prime\), starting from the current state \(s\). The move is accepted with probability
\[A(s\rightarrow s^\prime) = \mathrm{min}\left(1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right),\]
where, for a replica at inverse temperature \(\beta\), the probability being sampled from is \(P(s)\propto|M(s)|^{\beta p}\). Here \(M(s)\) is a user-provided function (the machine), \(p\) is also user-provided with default value \(p=2\), \(\beta\) is the inverse temperature of the Markov chain (\(\beta=1\) is the physical one), and \(L(s,s^\prime)\) is a suitable correcting factor computed by the transition kernel.
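The acceptance rule described above can be sketched in log space as follows, assuming that a replica at inverse temperature \(\beta\) targets \(P(s)\propto|M(s)|^{\beta p}\). The helper tempered_acceptance is hypothetical and is not NetKet's internal code; log_correction stands in for \(L(s,s^\prime)\).

```python
import numpy as np


def tempered_acceptance(log_amp_old, log_amp_new, beta, machine_pow=2.0, log_correction=0.0):
    """Hypothetical helper: Metropolis acceptance probability at inverse temperature beta.

    Assumes the replica targets P(s) proportional to |M(s)|**(beta * machine_pow), so that
    log[P(s') / P(s)] = beta * machine_pow * (log|M(s')| - log|M(s)|).
    log_correction stands in for the factor L(s, s') supplied by the transition rule.
    Works element-wise on arrays, e.g. one entry per chain.
    """
    log_ratio = beta * machine_pow * (np.asarray(log_amp_new) - np.asarray(log_amp_old))
    # min(1, x) evaluated in log space avoids overflow for large positive log ratios.
    return np.exp(np.minimum(0.0, log_ratio + log_correction))


drop = np.log(0.5)  # a proposed move that halves |M(s)|
p_physical = tempered_acceptance(0.0, drop, beta=1.0)  # about 0.25, i.e. (1/2)**2
p_hot = tempered_acceptance(0.0, drop, beta=0.5)  # about 0.5, i.e. (1/2)**1
```

Lowering \(\beta\) flattens the target distribution, so hotter replicas accept unfavourable moves more often than the physical one.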
- Parameters:
hilbert – The Hilbert space to sample.
rule – A MetropolisRule to generate random transitions from a given state as well as uniform random states.
n_replicas (int | None) – The number of different inverse temperatures β used for the sampling; must be even (default: 32).
betas (str | Array | None) – (Optional) Distribution or explicit list of values of the inverse temperatures β. For a distribution, choose between "linear" for a linear distribution and "log" for a logarithmic one. For an explicit list of values, the length must be even, β=1 must be one of the elements, and all other values must lie in (0,1] (default: "linear", i.e. a linear distribution in (0,1]).
n_chains – The number of Markov chains to be run in parallel on a single process (default = 8).
sweep_size – The number of exchanges that compose a single sweep. If None, sweep_size is equal to the number of degrees of freedom being sampled (the size of the input vector s to the machine).
machine_pow – The power to which the machine should be exponentiated to generate the pdf (default = 2).
dtype – The dtype of the states sampled (default = np.float32).
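The constraints on betas listed above can be expressed as a small validation routine. The sketch below is hypothetical: the exact spacings NetKet uses for "linear" and "log" are not given on this page, so the formulas here (evenly spaced values 1, 1 - 1/n, ..., 1/n, and logarithmically spaced values from 1 down to 1/n) are illustrative assumptions. Sorting in decreasing order puts β = 1 first, as the docs state for sorted_betas.

```python
import numpy as np


def make_betas(n_replicas=32, betas="linear"):
    """Hypothetical helper: build and validate a set of inverse temperatures."""
    if isinstance(betas, str):
        if betas == "linear":
            # Evenly spaced in (0, 1]: 1, 1 - 1/n, ..., 1/n (illustrative choice).
            values = 1.0 - np.arange(n_replicas) / n_replicas
        elif betas == "log":
            # Logarithmically spaced from exactly 1 down to 1/n (illustrative choice).
            values = np.logspace(0.0, -np.log10(n_replicas), n_replicas)
        else:
            raise ValueError(f"unknown distribution {betas!r}")
    else:
        values = np.asarray(betas, dtype=float)
    if len(values) % 2 != 0:
        raise ValueError("the number of replicas must be even")
    if not np.any(np.isclose(values, 1.0)):
        raise ValueError("beta = 1 (the physical temperature) must be included")
    if np.any(values <= 0.0) or np.any(values > 1.0):
        raise ValueError("all betas must lie in (0, 1]")
    # Decreasing order, so the physical beta = 1 comes first.
    return np.sort(values)[::-1]


linear_betas = make_betas(4, "linear")
explicit_betas = make_betas(betas=[0.5, 1.0])
```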
- Attributes
- is_exact
Returns True if the sampler is exact.
The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them.
- n_batches
The batch size of the configuration \(\sigma\) used by this sampler on this jax process.
If you are not using MPI, this is equal to n_chains * n_replicas, but if you are using MPI this is equal to n_chains_per_rank * n_replicas.
- n_chains: int
Total number of independent chains across all MPI ranks and/or devices.
- n_chains_per_rank
The number of independent chains per MPI rank (or per jax device if you set NETKET_EXPERIMENTAL_SHARDING=1).
If you are not distributing the calculation among different MPI ranks or jax devices, this is equal to n_chains. In general it is equal to sampler.n_chains // sharding.device_count(), with sharding imported via from netket.jax import sharding.
- n_sweeps
- sorted_betas
The sorted values of the inverse temperatures for each _physical_ Markov chain. The first value is β = 1, which corresponds to the _physical_ temperature.
- n_replicas: int
The number of replicas evolving with different temperatures for every _physical_ Markov chain.
The total number of chains evolved is n_chains * n_replicas.
- rule: MetropolisRule
The Metropolis transition rule.
- sweep_size: int
Number of sweeps for each step along the chain. Defaults to the number of sites in the Hilbert space.
- chunk_size: int | None
Chunk size for evaluating wave functions.
- reset_chains: bool
If True, the chain state is reset whenever reset is called, i.e. at every new sampling.
- hilbert: AbstractHilbert
The Hilbert space to sample.
- machine_pow: float
The power to which the machine should be exponentiated to generate the pdf.
- dtype: DType
The dtype of the states sampled.
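The relations between n_chains, n_chains_per_rank, n_replicas and n_batches stated in the attributes above can be checked with plain integer arithmetic. The helper below is hypothetical; it assumes n_chains splits evenly across ranks or devices, which this sketch enforces but NetKet may handle differently.

```python
def chain_bookkeeping(n_chains, n_replicas, n_ranks=1):
    """Hypothetical helper reproducing the chain counts described in the attributes.

    n_ranks is the number of MPI ranks or jax devices; this sketch assumes that
    n_chains splits evenly among them and rejects anything else.
    """
    if n_chains % n_ranks != 0:
        raise ValueError("n_chains must be divisible by the number of ranks/devices")
    n_chains_per_rank = n_chains // n_ranks
    # Every physical chain carries n_replicas tempered copies.
    n_batches = n_chains_per_rank * n_replicas
    total_chains_evolved = n_chains * n_replicas
    return n_chains_per_rank, n_batches, total_chains_evolved


single_process = chain_bookkeeping(16, 32)  # (16, 512, 512)
four_devices = chain_bookkeeping(16, 32, n_ranks=4)  # (4, 128, 512)
```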
- Methods
- init_state(machine, parameters, seed=None)
Creates the structure holding the state of the sampler.
If you want reproducible samples, you should specify seed, otherwise the state will be initialised randomly.
If running across several MPI processes, all sampler states are guaranteed to be different (but deterministic). This is achieved by first reducing (summing) the seeds provided to every MPI rank, then generating n_rank seeds starting from the reduced one; every rank is then initialized with one of those seeds.
The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.
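The seeding scheme described above can be mimicked in a single process with NumPy. This is an analogy, not NetKet's implementation (which works with jax PRNG keys and MPI): the list passed in stands for the seed each rank received, the sum plays the role of the MPI reduction, and numpy.random.SeedSequence.spawn derives one independent child seed per rank. The function rank_rng is hypothetical.

```python
import numpy as np


def rank_rng(seeds_given_to_each_rank, rank):
    """Hypothetical single-process model of the per-rank seeding scheme."""
    # Step 1: reduce (sum) the seeds provided to every rank, like an MPI allreduce.
    reduced = int(sum(seeds_given_to_each_rank))
    n_ranks = len(seeds_given_to_each_rank)
    # Step 2: deterministically derive n_ranks independent child seeds from the sum.
    children = np.random.SeedSequence(reduced).spawn(n_ranks)
    # Step 3: each rank initializes its generator with the child at its own index.
    return np.random.default_rng(children[rank])


# Same inputs give the same stream on a given rank; different ranks get different streams.
draw_rank0 = rank_rng([1, 2, 3], rank=0).integers(0, 2**62)
draw_rank1 = rank_rng([1, 2, 3], rank=1).integers(0, 2**62)
```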
- Parameters:
machine (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
seed (Union[int, Any, None]) – An optional seed or jax PRNGKey. If not specified, a random seed will be used.
- Return type:
SamplerState
- Returns:
The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.
- log_pdf(model)
Returns a closure with the log-pdf function encoded by this sampler.
- Parameters:
model (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
- Return type:
Callable
- Returns:
The log-probability density function.
Note
The result is returned as a HashablePartial so that the closure does not trigger recompilation.
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
- reset(machine, parameters, state=None)
Resets the state of the sampler. It should be called every time the parameters are changed.
- Parameters:
machine (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.
- Return type:
SamplerState
- Returns:
A valid sampler state.
- sample(machine, parameters, *, state=None, chain_length=1)
Samples chain_length batches of samples along the chains.
- Parameters:
machine (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, then initialize and reset it.
chain_length (int) – The length of the chains (default = 1).
- Returns:
σ – The generated batches of samples.
state – The new state of the sampler.
- Return type:
tuple[jnp.ndarray, SamplerState]
- sample_next(machine, parameters, state=None)
Samples the next state in the Markov chain.
- Parameters:
machine (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, then initialize and reset it.
- Returns:
state – The new state of the sampler.
σ – The next batch of samples.
- Return type:
tuple[SamplerState, jnp.ndarray]
Note
The return order is inverted with respect to sample because, when called inside a scan function, the first returned value must be the state.
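The note above follows the contract of jax.lax.scan, whose body function has the form f(carry, x) -> (carry, y): the carried value must come first. The pure-Python model of scan below avoids a JAX dependency, and toy_sample_next is a hypothetical stand-in that uses a linear congruential update as its "state"; it only illustrates why sample_next returns the state before the samples.

```python
def scan(f, init, xs):
    """Pure-Python model of jax.lax.scan: f(carry, x) -> (carry, y)."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)  # the carry must be the first returned value
        ys.append(y)
    return carry, ys


def toy_sample_next(state, _unused):
    """Hypothetical sampler step returning (new_state, sample), in that order."""
    new_state = (state * 1103515245 + 12345) % 2**31
    sigma = new_state % 2  # a toy binary "configuration"
    return new_state, sigma


# Draw five samples by threading the state through the loop, as a scan-based sampler would.
final_state, samples = scan(toy_sample_next, 42, range(5))
```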
- samples(machine, parameters, *, state=None, chain_length=1)
Returns a generator sampling chain_length batches of samples along the chains.
- Parameters:
machine (Callable | Module) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (SamplerState | None) – The current state of the sampler. If not specified, then initialize and reset it.
chain_length (int) – The length of the chains (default = 1).
- Return type: