netket.sampler.MetropolisSampler
- class netket.sampler.MetropolisSampler
Bases: netket.sampler.Sampler
Metropolis-Hastings sampler for a Hilbert space according to a specific transition rule.
The transition rule is used to generate a proposed state \(s^\prime\), starting from the current state \(s\). The move is accepted with probability
\[A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right),\]
where the probability being sampled from is \(P(s)=|M(s)|^p\). Here \(M(s)\) is a user-provided function (the machine), \(p\) is also user-provided with default value \(p=2\), and \(L(s,s^\prime)\) is a suitable correcting factor computed by the transition kernel.
The dtype of the sampled states can be chosen.
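The acceptance rule above can be sketched in plain Python. This is an illustration only, not NetKet's implementation: the function name `acceptance_probability` and its arguments are hypothetical, and it assumes the machine is evaluated as a (possibly complex) log-amplitude \(\log M(s)\), so that \(\log P(s) = p\,\mathrm{Re}\,\log M(s)\).

```python
import math


def acceptance_probability(log_m_new, log_m_old, log_correction=0.0, machine_pow=2):
    """Return A(s -> s') = min(1, P(s')/P(s) * exp(L(s, s'))) with P(s) = |M(s)|^p.

    Assumption: log_m_new and log_m_old are log-amplitudes log M(s') and log M(s).
    """
    # log of P(s')/P(s) * exp(L): only the real part of log M enters |M|^p.
    log_ratio = machine_pow * (log_m_new.real - log_m_old.real) + log_correction
    # min(1, exp(x)) == exp(min(0, x)); this form cannot overflow.
    return math.exp(min(0.0, log_ratio))


# A move towards higher probability is always accepted.
print(acceptance_probability(1.0, 0.0))
# A move towards lower probability is accepted with probability exp(-2).
print(acceptance_probability(0.0, 1.0))
```

Working with the logarithm of the ratio avoids overflow when the two amplitudes differ by many orders of magnitude.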
- __init__(*args, __precompute_cached_properties=False, __skip_preprocess=False, **kwargs)
Constructs a Metropolis Sampler.
- Parameters
hilbert – The Hilbert space to sample.
rule – A MetropolisRule to generate random transitions from a given state as well as uniform random states.
n_chains – The total number of independent Markov chains across all MPI ranks. Either specify this or n_chains_per_rank. If MPI is disabled, the two are equivalent; if MPI is enabled and n_chains is specified, then every MPI rank will run n_chains/mpi.n_nodes chains. In general, we recommend specifying n_chains_per_rank as it is more portable.
n_chains_per_rank – Number of independent chains on every MPI rank (default = 16).
n_sweeps – Number of sweeps for each step along the chain. This is equivalent to subsampling the Markov chain. (Defaults to the number of sites in the Hilbert space.)
reset_chains – If True, resets the chain state whenever reset is called, i.e. at every new sampling (default = False).
machine_pow – The power to which the machine should be exponentiated to generate the pdf (default = 2).
dtype – The dtype of the states sampled (default = np.float64).
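To make the role of n_sweeps concrete, here is a toy single-chain Metropolis loop in pure Python. It is a sketch and not NetKet code: `sample_chain` and `log_prob` are hypothetical names, the states are spins in {-1, +1}, and one "sweep" is assumed to mean one proposed single-site flip. Only the state reached after n_sweeps proposals is recorded, which is the subsampling described for n_sweeps.

```python
import math
import random


def sample_chain(log_prob, n_sites, chain_length, n_sweeps=None, seed=0):
    """Toy Metropolis chain over spins in {-1, +1} (illustrative only)."""
    if n_sweeps is None:
        n_sweeps = n_sites  # mirrors the documented default
    rng = random.Random(seed)
    state = [rng.choice((-1, 1)) for _ in range(n_sites)]
    samples = []
    for _ in range(chain_length):
        # Perform n_sweeps proposals, but record only the final state.
        for _ in range(n_sweeps):
            site = rng.randrange(n_sites)
            proposal = list(state)
            proposal[site] = -proposal[site]
            # Symmetric proposal, so no correcting factor L is needed.
            log_ratio = log_prob(proposal) - log_prob(state)
            if rng.random() < math.exp(min(0.0, log_ratio)):
                state = proposal
        samples.append(tuple(state))
    return samples


# log P(s) = 0.5 * sum(s): configurations with more +1 spins are favoured.
samples = sample_chain(lambda s: 0.5 * sum(s), n_sites=4, chain_length=5)
print(len(samples))
```

Larger values of n_sweeps reduce the correlation between consecutive recorded samples at the cost of more evaluations per sample.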
- Attributes
- is_exact
Returns True if the sampler is exact.
The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them.
- Return type
- n_batches
The batch size of the configuration \(\sigma\) used by this sampler.
In general, it is equivalent to n_chains_per_rank.
- Return type
- n_chains
The total number of independent chains across all MPI ranks.
If you are not using MPI, this is equal to n_chains_per_rank.
- Return type
- rule: netket.sampler.rules.MetropolisRule = None
- hilbert: netket.hilbert.AbstractHilbert
- Methods
- init_state(machine, parameters, seed=None)
Creates the structure holding the state of the sampler.
If you want reproducible samples, you should specify seed; otherwise the state will be initialised randomly.
If running across several MPI processes, all sampler states are guaranteed to be in a different (but deterministic) state. This is achieved by first reducing (summing) the seed provided to every MPI rank, then generating n_rank seeds starting from the reduced one; every rank is then initialized with one of those seeds.
The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.
- Parameters
machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
seed (Union[int, Any, None]) – An optional seed or jax PRNGKey. If not specified, a random seed will be used.
- Return type
- Returns
The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.
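The seeding scheme described for init_state (sum the seeds of all ranks, then derive one seed per rank from the reduced value) can be illustrated as follows. This is a sketch under stated assumptions: `derive_rank_seeds` is a hypothetical helper, Python's `random` module stands in for the JAX PRNG, and the MPI reduction is replaced by a plain `sum` over a list.

```python
import random


def derive_rank_seeds(seeds_from_ranks):
    """Derive one deterministic seed per rank from the seeds given to all ranks."""
    # Step 1: reduce (sum) the seeds provided to every rank.
    reduced = sum(seeds_from_ranks)
    # Step 2: generate one seed per rank, starting from the reduced seed.
    rng = random.Random(reduced)
    return [rng.getrandbits(32) for _ in seeds_from_ranks]


seeds = derive_rank_seeds([1, 2, 3])
# Rank r would initialise its sampler state with seeds[r].
print(seeds)
```

Because every rank computes the same reduced value, all ranks agree on the full list of seeds without further communication. Unlike the scheme described above, this stand-in makes the derived seeds distinct only with high probability, not by construction.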
- log_pdf(model)
Returns a closure with the log-pdf function encoded by this sampler.
- Parameters
model (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
- Return type
- Returns
The log-probability density function.
Note
The result is returned as a HashablePartial so that the closure does not trigger recompilation.
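A minimal sketch of what such a closure might look like, assuming the model returns the log-amplitude \(\log M(\sigma)\) so that the log-pdf is machine_pow times its real part. The names `make_log_pdf`, `_log_pdf` and `toy_machine` are hypothetical, and `functools.partial` is used as a stand-in for HashablePartial; it does not reproduce the hashing behaviour mentioned in the note.

```python
from functools import partial


def _log_pdf(machine, machine_pow, parameters, sigma):
    # log P(sigma) = p * Re(log M(sigma)), up to normalisation.
    return machine_pow * machine(parameters, sigma).real


def make_log_pdf(machine, machine_pow=2):
    """Bind the machine and the power, leaving (parameters, sigma) free."""
    return partial(_log_pdf, machine, machine_pow)


def toy_machine(parameters, sigma):
    # Hypothetical model: a complex log-amplitude with a constant phase.
    return complex(parameters["w"] * sum(sigma), 0.3)


log_pdf = make_log_pdf(toy_machine)
print(log_pdf({"w": 0.5}, (1, 1, -1)))
```

Binding the arguments with a partial object, rather than defining a new inner function on every call, keeps the callable built from a fixed top-level function.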
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- reset(machine, parameters, state=None)
Resets the state of the sampler. To be used every time the parameters are changed.
- Parameters
machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If not specified, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.
- Return type
- Returns
A valid sampler state.
- sample(machine, parameters, *, state=None, chain_length=1)
Samples chain_length batches of samples along the chains.
- Parameters
machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If not specified, then initialize and reset it.
chain_length (int) – The length of the chains (default = 1).
- Returns
σ: The generated batches of samples.
state: The new state of the sampler.
- sample_next(machine, parameters, state=None)
Samples the next state in the Markov chain.
- Parameters
machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If not specified, then initialize and reset it.
- Returns
state: The new state of the sampler.
σ: The next batch of samples.
Note
The return order is inverted with respect to sample because, when sample_next is called inside a scan function, the first returned value should be the state.
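The state-first convention can be seen in a minimal pure-Python analogue of a scan, where the step function maps (carry, x) to (carry, y), as in jax.lax.scan. Everything here is illustrative: `scan` is a simplified re-implementation, and `toy_sample_next` is a hypothetical step that uses an integer counter in place of a real sampler state.

```python
def scan(step, carry, xs):
    """Pure-Python analogue of a scan: step(carry, x) -> (carry, y)."""
    ys = []
    for x in xs:
        carry, y = step(carry, x)
        ys.append(y)
    return carry, ys


def toy_sample_next(state, _):
    # Returns (new_state, sample): the state comes first, as the scan expects.
    new_state = state + 1
    sample = new_state * 10
    return new_state, sample


final_state, samples = scan(toy_sample_next, 0, range(3))
print(final_state, samples)  # 3 [10, 20, 30]
```

The scan threads the first returned value through the loop as the carry and stacks the second one, so a step that returned the sample first would have its samples treated as the state.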
- samples(machine, parameters, *, state=None, chain_length=1)
Returns a generator sampling chain_length batches of samples along the chains.
- Parameters
machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.
parameters (Any) – The PyTree of parameters of the model.
state (Optional[SamplerState]) – The current state of the sampler. If not specified, then initialize and reset it.
chain_length (int) – The length of the chains (default = 1).
- Return type