netket.sampler.Sampler

class netket.sampler.Sampler

Bases: ABC

Abstract base class for all samplers.

It contains the fields that all samplers should possess, defining the common API. Note that fields marked with pytree_node=False are treated as static arguments when jitting.

Subclasses should be NetKet dataclasses and they should define the _init_state, _reset and _sample_chain methods which only accept positional arguments. See the respective method’s definition for its signature.

These methods are distinct from the public API entry points of the same name without the leading underscore. The separation lets samplers share pre-processing code and simplifies the definition of a new sampler.
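The split between public entry points and underscore hooks can be illustrated with a small pure-Python sketch. ToySamplerBase, UniformSpinSampler and ToyState are hypothetical names, not part of NetKet; the machine and parameters arguments are omitted for brevity, and plain lists stand in for JAX arrays. The positional-only marker (/) on the hooks mirrors the "positional arguments only" convention and requires Python 3.8 or later.

```python
import abc
import random
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    # Hypothetical sampler state: only the seed used by the next call.
    seed: int


class ToySamplerBase(abc.ABC):
    """Hypothetical stand-in mirroring the public/underscore method split."""

    def __init__(self, n_sites, n_chains=1):
        self.n_sites = n_sites
        self.n_chains = n_chains

    # Public entry points: shared pre-processing, then delegate positionally.
    def init_state(self, seed=None):
        if seed is None:
            seed = random.randrange(2**31)
        return self._init_state(seed)

    def reset(self, state=None):
        if state is None:
            state = self.init_state()
        return self._reset(state)

    def sample(self, *, state=None, chain_length=1):
        if state is None:
            state = self.reset()
        return self._sample_chain(state, chain_length)

    # Hooks that every concrete sampler must implement.
    @abc.abstractmethod
    def _init_state(self, seed, /):
        ...

    @abc.abstractmethod
    def _reset(self, state, /):
        ...

    @abc.abstractmethod
    def _sample_chain(self, state, chain_length, /):
        ...


class UniformSpinSampler(ToySamplerBase):
    """Hypothetical concrete sampler drawing +/-1 spins uniformly at random."""

    def _init_state(self, seed, /):
        return ToyState(seed=seed)

    def _reset(self, state, /):
        return state  # independent uniform draws need no reset work

    def _sample_chain(self, state, chain_length, /):
        rng = random.Random(state.seed)
        batches = [
            [[rng.choice((-1, 1)) for _ in range(self.n_sites)]
             for _ in range(self.n_chains)]
            for _ in range(chain_length)
        ]
        # Advance the seed so the next call yields fresh samples.
        return batches, replace(state, seed=rng.randrange(2**31))
```

Only the three hooks differ between concrete samplers; defaults such as building a state when none is passed are written once in the base class.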

__init__(*args, __precompute_cached_properties=False, __skip_preprocess=False, **kwargs)

Construct a Monte Carlo sampler.

Parameters:
  • hilbert – The Hilbert space to sample.

  • n_chains – The total number of independent chains across all MPI ranks. Either specify this or n_chains_per_rank.

  • n_chains_per_rank – Number of independent chains on every MPI rank (default = 1).

  • machine_pow – The power to which the modulus of the machine's amplitude is raised to generate the pdf, p(σ) ∝ |ψ(σ)|^machine_pow (default = 2).

  • dtype – The dtype of the states sampled (default = np.float64).
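The relation between n_chains and n_chains_per_rank amounts to simple bookkeeping. The helper below is a hypothetical sketch, not NetKet's implementation: it applies the documented default of one chain per rank and, as a simplifying assumption, rejects totals that do not divide evenly across ranks (the library may handle that case differently).

```python
def resolve_chains(n_ranks, n_chains=None, n_chains_per_rank=None):
    """Hypothetical helper returning (n_chains, n_chains_per_rank) for n_ranks ranks."""
    if n_chains is not None and n_chains_per_rank is not None:
        raise ValueError("Specify either n_chains or n_chains_per_rank, not both.")
    if n_chains is None and n_chains_per_rank is None:
        n_chains_per_rank = 1  # documented default
    if n_chains is not None:
        # Simplifying assumption of this sketch: require an even split.
        if n_chains % n_ranks != 0:
            raise ValueError(f"n_chains={n_chains} is not divisible by n_ranks={n_ranks}.")
        n_chains_per_rank = n_chains // n_ranks
    # The total is always per-rank chains times the number of ranks.
    return n_chains_per_rank * n_ranks, n_chains_per_rank
```

With a single process (n_ranks=1) the two counts coincide, which matches the description of n_chains below.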

Attributes
is_exact

Returns True if the sampler is exact.

The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them.
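For a Hilbert space small enough to enumerate, exact sampling can be sketched by computing every weight |ψ(σ)|^machine_pow and drawing independent samples from the resulting distribution. The toy log-amplitude log_psi and the helper exact_samples are hypothetical, and log ψ is assumed real-valued here; the cost grows as 2^N for N spins, so this only illustrates the definition.

```python
import itertools
import math
import random


def log_psi(sigma):
    # Hypothetical real log-amplitude favouring aligned neighbouring spins.
    return 0.5 * sum(a * b for a, b in zip(sigma, sigma[1:]))


def exact_samples(log_psi, n_sites, n_samples, machine_pow=2, seed=0):
    """Draw independent samples from p(σ) ∝ |ψ(σ)|**machine_pow by enumeration."""
    configs = list(itertools.product((-1, 1), repeat=n_sites))
    # log p(σ) up to a constant: machine_pow * log|ψ(σ)|
    log_w = [machine_pow * log_psi(s) for s in configs]
    m = max(log_w)
    weights = [math.exp(lw - m) for lw in log_w]  # subtract the max for stability
    rng = random.Random(seed)
    # Every draw is independent, so the samples carry no correlation.
    return rng.choices(configs, weights=weights, k=n_samples)
```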

machine_pow: int = 2
n_batches

The batch size of the configurations $\sigma$ used by this sampler.

In general, it is equivalent to n_chains_per_rank.

n_chains

The total number of independent chains across all MPI ranks.

If you are not using MPI, this is equal to n_chains_per_rank.

n_chains_per_rank: int = None
hilbert: AbstractHilbert
Methods
init_state(machine, parameters, seed=None)

Creates the structure holding the state of the sampler.

If you want reproducible samples, you should specify seed; otherwise, the state will be initialised randomly.

If running across several MPI processes, all sampler_states are guaranteed to be in different (but deterministic) states. This is achieved by first reducing (summing) the seed provided to every MPI rank, then generating n_rank seeds starting from the reduced one; every rank is then initialized with one of those seeds.
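The seed-distribution scheme described above can be simulated without MPI by passing a list with one seed per rank. This sketch uses Python's random module as an assumption (NetKet works with JAX PRNG keys), and per_rank_seeds is a hypothetical helper. In this sketch the derived seeds are distinct with overwhelming probability rather than by construction.

```python
import random


def per_rank_seeds(seeds_per_rank):
    """Simulate seed distribution across ranks; seeds_per_rank[i] is rank i's seed."""
    n_rank = len(seeds_per_rank)
    reduced = sum(seeds_per_rank)  # stands in for an MPI sum-reduction
    rng = random.Random(reduced)  # stdlib RNG instead of JAX keys (assumption)
    # Rank i would be initialized with derived[i].
    derived = [rng.getrandbits(32) for _ in range(n_rank)]
    return derived
```

Because only the sum enters, per_rank_seeds([1, 2]) and per_rank_seeds([2, 1]) produce the same per-rank seeds.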

The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • seed (Union[int, Any, None]) – An optional seed or jax PRNGKey. If not specified, a random seed will be used.

Return type:

SamplerState

Returns:

The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.

log_pdf(model)

Returns a closure with the log-pdf function encoded by this sampler.

Parameters:

model (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

Return type:

Callable

Returns:

The log-probability density function.

Note

The result is returned as a HashablePartial so that the closure does not trigger recompilation.
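The log-pdf closure can be sketched as log p(σ) = machine_pow · Re[log ψ(σ)], wrapped in a partial whose equality and hash depend on the wrapped function and its arguments. The HashablePartial class below is a minimal stand-in written for illustration, not NetKet's own class; toy_model is a hypothetical log-amplitude. Since JAX caches compiled functions on static arguments by hash and equality, value-based comparison lets two equivalent closures reuse the same compilation, whereas plain functools.partial objects compare by identity.

```python
import functools


class HashablePartial:
    """Minimal stand-in: a partial compared and hashed by (func, args)."""

    def __init__(self, func, *args):
        self.func = func
        self.args = args

    def __call__(self, *rest):
        return self.func(*self.args, *rest)

    def __eq__(self, other):
        return (
            isinstance(other, HashablePartial)
            and self.func is other.func
            and self.args == other.args
        )

    def __hash__(self):
        return hash((self.func, self.args))


def _log_pdf(machine_pow, model, parameters, sigma):
    # log p(σ) = machine_pow * Re[log ψ(σ)], up to normalization
    return machine_pow * complex(model(parameters, sigma)).real


def log_pdf(model, machine_pow=2):
    return HashablePartial(_log_pdf, machine_pow, model)


def toy_model(parameters, sigma):
    # Hypothetical log-amplitude; the constant phase 0.3j does not affect the pdf.
    return parameters * sum(sigma) + 0.3j


# functools.partial compares by identity, so two equivalent closures differ:
same_plain = functools.partial(_log_pdf, 2, toy_model) == functools.partial(_log_pdf, 2, toy_model)
```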

replace(**updates)

Returns a new object replacing the specified fields with new values.
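Because replace returns a new object rather than modifying the sampler in place, changing a setting reads as sampler.replace(machine_pow=1) and leaves the original untouched. The sketch below reproduces that behaviour with the standard-library dataclasses.replace on a hypothetical ToySamplerConfig; it is an analogy, not NetKet's implementation.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ToySamplerConfig:
    # Hypothetical immutable sampler configuration, for illustration only.
    n_sites: int
    n_chains_per_rank: int = 1
    machine_pow: int = 2


sampler = ToySamplerConfig(n_sites=10)
# Build a modified copy; the original object keeps its values.
sampler_p1 = replace(sampler, machine_pow=1)

try:
    sampler.machine_pow = 1  # in-place mutation is rejected on a frozen dataclass
except FrozenInstanceError:
    mutation_rejected = True
else:
    mutation_rejected = False
```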

reset(machine, parameters, state=None)

Resets the state of the sampler. It should be called every time the parameters change.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If not specified, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.

Return type:

SamplerState

Returns:

A valid sampler state.
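Why reset is needed after a parameter change can be illustrated with a state that caches the log-probability of its current configuration, as Metropolis-style samplers typically do (an assumption of this sketch). CachedState, log_psi and reset are hypothetical names; the cached value stays tied to the old parameters until reset recomputes it.

```python
from dataclasses import dataclass, replace


def log_psi(theta, sigma):
    # Hypothetical real log-amplitude with one parameter theta.
    return theta * sum(sigma)


@dataclass(frozen=True)
class CachedState:
    sigma: tuple
    log_prob: float  # cached machine_pow * log|ψ(σ)| for the current σ


def reset(theta, state, machine_pow=2):
    """Recompute the cached log-probability for the current parameters."""
    return replace(state, log_prob=machine_pow * log_psi(theta, state.sigma))


state = reset(0.1, CachedState(sigma=(1, 1, -1), log_prob=float("nan")))
# If the parameters change to 0.2, state.log_prob still refers to 0.1 ...
stale = state.log_prob
# ... until the state is reset with the new parameters.
state = reset(0.2, state)
```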

sample(machine, parameters, *, state=None, chain_length=1)

Samples chain_length batches of samples along the chains.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If not specified, a new state is initialized and reset.

  • chain_length (int) – The length of the chains (default = 1).

Returns:

A tuple (σ, state), where:
  • σ – The generated batches of samples.

  • state – The new state of the sampler.
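A call to sample can be pictured as running the chain for chain_length steps from the given state and returning both the samples and the advanced state. The single-chain local Metropolis sketch below is one such sampler, written in pure Python with hypothetical names (log_psi, ChainState, sample); it proposes single-spin flips, accepts them with probability min(1, p(σ')/p(σ)), and records one configuration per sweep. It is not NetKet's implementation.

```python
import math
import random
from dataclasses import dataclass


def log_psi(theta, sigma):
    # Hypothetical real log-amplitude: theta couples to the total magnetization.
    return theta * sum(sigma)


@dataclass(frozen=True)
class ChainState:
    sigma: tuple  # current configuration of a single chain
    seed: int     # seed used by the next call


def sample(theta, state, chain_length=1, machine_pow=2):
    """Single-chain local Metropolis sketch returning (samples, new_state)."""
    rng = random.Random(state.seed)
    sigma = list(state.sigma)
    log_p = machine_pow * log_psi(theta, sigma)
    samples = []
    for _ in range(chain_length):
        for _ in range(len(sigma)):  # one sweep: len(sigma) proposed spin flips
            i = rng.randrange(len(sigma))
            sigma[i] = -sigma[i]
            new_log_p = machine_pow * log_psi(theta, sigma)
            # Metropolis acceptance with probability min(1, p(σ') / p(σ)).
            if rng.random() < math.exp(min(0.0, new_log_p - log_p)):
                log_p = new_log_p
            else:
                sigma[i] = -sigma[i]  # rejected: undo the flip
        samples.append(tuple(sigma))
    return samples, ChainState(sigma=tuple(sigma), seed=rng.getrandbits(32))
```

Passing the returned state into the next call continues the chain where it stopped.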

samples(machine, parameters, *, state=None, chain_length=1)

Returns a generator sampling chain_length batches of samples along the chains.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If not specified, a new state is initialized and reset.

  • chain_length (int) – The length of the chains (default = 1).

Return type:

Iterator[Array]
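Judging from its return type, samples yields the batches one by one rather than returning them together with the final state. The generator below sketches that behaviour around a hypothetical stand-in sample function whose state is just an integer seed.

```python
import random


def sample(seed, chain_length=1, n_sites=3):
    """Hypothetical stand-in sampler returning (batches, new_seed)."""
    rng = random.Random(seed)
    batches = [[rng.choice((-1, 1)) for _ in range(n_sites)] for _ in range(chain_length)]
    return batches, rng.getrandbits(32)


def samples(seed, chain_length=1):
    """Yield chain_length batches one at a time, threading the state between calls."""
    state = seed
    for _ in range(chain_length):
        batch, state = sample(state, chain_length=1)
        yield batch[0]
```

Since the generator only yields batches, the caller does not receive the final state; sample is the method to use when that state is needed.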