netket.sampler.ARDirectSampler

class netket.sampler.ARDirectSampler

Bases: Sampler

Direct sampler for autoregressive neural networks.

This sampler only works with Flax models. The model must expose a specific method, model.conditional. Given a batch of (partial) samples and a site index i, with 0 ≤ i < self.hilbert.size, this method must return the conditional probabilities of every possible local value at site i for each sample in the batch.

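In short, if your model can be sampled according to a probability $p(x) = p_1(x_1)\,p_2(x_2 \mid x_1)\cdots p_N(x_N \mid x_{N-1}, \dots, x_1)$, then model.conditional(x, i) should return the conditional distribution $p_i(x_i \mid x_{i-1}, \dots, x_1)$, evaluated for every possible value of $x_i$.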
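
To make this contract concrete, here is a minimal NumPy sketch of direct (ancestral) sampling. It assumes a hypothetical ToyARModel whose conditional(x, i) reads only the prefix x[:, :i] and returns, for each sample, the probabilities of every local value at site i. The functions direct_sample and log_prob are also illustrative only. This is not NetKet's implementation, which is written in JAX and works with Flax modules.

```python
import numpy as np


class ToyARModel:
    """Hypothetical autoregressive model (not a NetKet class) over n_sites
    sites with local_size values per site. The conditional at site i depends
    on the already-sampled prefix through a fixed random logit table."""

    def __init__(self, n_sites, local_size, seed=0):
        rng = np.random.default_rng(seed)
        self.n_sites = n_sites
        self.local_size = local_size
        # logits[i, s, :] are the unnormalized log-probabilities at site i,
        # given that the sum of the earlier values (mod local_size) is s.
        self.logits = rng.normal(size=(n_sites, local_size, local_size))

    def conditional(self, x, i):
        # x has shape (batch, n_sites); only the prefix x[:, :i] is read.
        s = x[:, :i].sum(axis=1) % self.local_size
        logits = self.logits[i, s]  # shape (batch, local_size)
        p = np.exp(logits - logits.max(axis=1, keepdims=True))
        return p / p.sum(axis=1, keepdims=True)


def direct_sample(model, n_samples, seed=0):
    """Ancestral sampling: draw x_1, then x_2 given x_1, and so on."""
    rng = np.random.default_rng(seed)
    x = np.zeros((n_samples, model.n_sites), dtype=np.int64)
    for i in range(model.n_sites):
        p = model.conditional(x, i)  # (n_samples, local_size)
        # Inverse-CDF draw of one local value per sample.
        u = rng.random((n_samples, 1))
        idx = (u > np.cumsum(p, axis=1)).sum(axis=1)
        x[:, i] = np.minimum(idx, model.local_size - 1)  # guard against round-off
    return x


def log_prob(model, x):
    """log p(x) = sum_i log p_i(x_i | x_<i), read off the conditionals."""
    rows = np.arange(x.shape[0])
    return sum(np.log(model.conditional(x, i)[rows, x[:, i]])
               for i in range(model.n_sites))
```

Each row is drawn independently from the exact distribution, with no Markov chain involved. This is why such a sampler can report itself as exact.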

NetKet implements some autoregressive networks that can be used together with this sampler.

__init__(hilbert, machine_pow=None, dtype=float)

Construct an autoregressive direct sampler.

Parameters:
  • hilbert (DiscreteHilbert) – The Hilbert space to sample.

  • dtype (Any) – The dtype of the states sampled (default = np.float64).

  • machine_pow (None) – Has no effect; set the model’s machine_pow instead (see the note below).

Note

ARDirectSampler.machine_pow has no effect. Please set the model’s machine_pow instead.
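
Because the exponent has to live in the model, the model's conditionals must be built from |ψ|^machine_pow. Below is a minimal NumPy sketch of that normalization step. It assumes the model first computes unnormalized conditional log-amplitudes log ψ_i of shape (batch, local_size). The helper normalized_conditional is hypothetical, and NetKet's own autoregressive networks may organize this step differently.

```python
import numpy as np


def normalized_conditional(log_psi_i, machine_pow=2):
    """Hypothetical helper (not a NetKet function): map unnormalized
    conditional log-amplitudes of shape (batch, local_size) to probabilities
    p_i = |psi_i|**machine_pow / sum_j |psi_i(j)|**machine_pow."""
    # log |psi_i|**machine_pow = machine_pow * Re(log psi_i)
    log_p = machine_pow * np.real(log_psi_i)
    # Normalize over the local values with a stable log-sum-exp.
    log_p = log_p - log_p.max(axis=1, keepdims=True)
    log_p = log_p - np.log(np.exp(log_p).sum(axis=1, keepdims=True))
    return np.exp(log_p)
```

For example, amplitudes with |ψ| = 1 and |ψ| = 2 give probabilities 0.2 and 0.8 for machine_pow = 2, and 1/3 and 2/3 for machine_pow = 1.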

Attributes
is_exact

Returns True because the sampler is exact.

The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them.

n_batches

The batch size of the configurations $\sigma$ used by this sampler on this jax process.

This determines the shape of the batches generated in a single process. It is needed because, with MPI, every process creates a batch of n_chains_per_rank chains, whereas in the experimental sharding mode the full shape must be declared on every jax process, so in that mode this returns n_chains.

Using this property is required to support both MPI and sharding.

Samplers may override this to have a larger batch size, for example to propagate multiple replicas (in the case of parallel tempering).

n_chains

The total number of independent chains.

This is at least equal to the total number of MPI ranks/jax devices that are used to distribute the calculation.

n_chains_per_rank

The total number of independent chains per MPI rank (or jax device if you set NETKET_EXPERIMENTAL_SHARDING=1).

If you are not distributing the calculation among different MPI ranks or jax devices, this is equal to n_chains.

In general this is equal to

from netket.jax import sharding
sampler.n_chains // sharding.device_count()
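
The relations between n_chains, n_chains_per_rank and n_batches described above can be sketched as plain arithmetic. The helper chain_bookkeeping is hypothetical and is not part of NetKet. It also assumes, as a simplification of this sketch, that n_chains is a positive multiple of the number of MPI ranks or jax devices.

```python
def chain_bookkeeping(n_chains, n_devices, sharding):
    """Hypothetical helper (not part of NetKet) returning
    (n_chains_per_rank, n_batches) for n_devices MPI ranks or jax devices."""
    # Assumption of this sketch: n_chains is a positive multiple of n_devices.
    if n_chains < n_devices or n_chains % n_devices:
        raise ValueError("n_chains must be a positive multiple of n_devices")
    n_chains_per_rank = n_chains // n_devices
    # MPI: each process only builds its own slice of chains.
    # Sharding: every jax process declares the full, global batch shape.
    n_batches = n_chains if sharding else n_chains_per_rank
    return n_chains_per_rank, n_batches
```

With a single rank or device, all three quantities coincide.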

hilbert: AbstractHilbert

The Hilbert space to sample.

machine_pow: int

The power to which the machine should be exponentiated to generate the pdf.

dtype: DType

The dtype of the states sampled.

Methods
init_state(machine, parameters, seed=None)

Creates the structure holding the state of the sampler.

If you want reproducible samples, you should specify seed; otherwise the state will be initialised randomly.

If running across several MPI processes, the sampler states are guaranteed to be different (but deterministic) on every rank. This is achieved by first reducing (summing) the seed provided on every MPI rank, then generating n_rank seeds from the reduced one; every rank is initialized with one of those seeds.

The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods.
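
The seeding scheme described above can be simulated in a single process, with the MPI ranks replaced by a plain list holding one seed per rank. This is a minimal sketch with a hypothetical helper, per_rank_seeds. NumPy's SeedSequence.spawn stands in for whatever key derivation NetKet actually uses, and seeds are assumed to be non-negative integers.

```python
import numpy as np


def per_rank_seeds(seed_on_each_rank):
    """Hypothetical, single-process simulation of the seeding scheme: the
    list holds the (non-negative) seed passed on each MPI rank."""
    n_rank = len(seed_on_each_rank)
    # 1. Reduce (sum) the seeds provided on every rank.
    reduced = sum(seed_on_each_rank)
    # 2. Deterministically derive n_rank child seeds from the reduced seed.
    children = np.random.SeedSequence(reduced).spawn(n_rank)
    # 3. Rank r is initialized with the r-th derived seed.
    return [int(child.generate_state(1)[0]) for child in children]
```

Because only the sum enters, passing the same seed on every rank (the usual case) gives reproducible per-rank seeds that still differ from rank to rank.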

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • seed (Union[int, Any, None]) – An optional seed or jax PRNGKey. If not specified, a random seed will be used.

Return type:

SamplerState

Returns:

The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use.

log_pdf(model)

Returns a closure with the log-pdf function encoded by this sampler.

Parameters:

model (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

Return type:

Callable

Returns:

The log-probability density function.

Note

The result is returned as a HashablePartial so that the closure does not trigger recompilation.
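
The note is about caching. A compilation cache keyed on a static argument only hits if the argument compares and hashes equal across calls, and two separately built lambdas never do. Below is a minimal pure-Python sketch of the idea. HashablePartial here is a stand-in written for this example, not NetKet's class, and functools.lru_cache stands in for JAX's compilation cache. The formula machine_pow * Re[log ψ] is an assumed convention used only for illustration.

```python
import functools

import numpy as np


class HashablePartial(functools.partial):
    """Stand-in written for this sketch (not NetKet's implementation): a
    partial whose equality and hash depend only on the wrapped function and
    the bound arguments, so equivalent closures are interchangeable as keys."""

    def __eq__(self, other):
        return (type(other) is type(self)
                and self.func == other.func
                and self.args == other.args
                and self.keywords == other.keywords)

    def __hash__(self):
        return hash((self.func, self.args, tuple(sorted(self.keywords.items()))))


def _log_pdf(apply_fun, machine_pow, pars, x):
    # Assumed convention for this sketch: log p(x) = machine_pow * Re[log psi(x)].
    return machine_pow * np.real(apply_fun(pars, x))


def make_log_pdf(apply_fun, machine_pow):
    # Builds a new object on every call, but equivalent ones compare and hash equal.
    return HashablePartial(_log_pdf, apply_fun, machine_pow)


@functools.lru_cache(maxsize=None)
def fake_compile(fn):
    # Stand-in for a compilation cache keyed on a static, hashable argument.
    return fn
```

Calling make_log_pdf twice with the same arguments yields two distinct but equal objects, so the second fake_compile call is a cache hit instead of a recompilation.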

replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P) – The object to copy.

  • kwargs (Any) – The fields to replace, with their new values.
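
For the dataclass case, replace behaves like the standard library's dataclasses.replace. It never mutates the original and instead returns a copy with the given fields changed. Below is a minimal sketch using a hypothetical frozen ToySampler class in place of a real sampler.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToySampler:
    """Hypothetical immutable stand-in for a sampler object."""
    n_chains: int = 16
    dtype: type = float


base = ToySampler()
# replace() never mutates: it returns a copy with the given fields changed.
bigger = dataclasses.replace(base, n_chains=32)
```

Because the object is frozen, assigning to base.n_chains directly raises dataclasses.FrozenInstanceError. Deriving a modified copy is the only way to change a field.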

reset(machine, parameters, state=None)

Resets the state of the sampler. Call this every time the parameters change.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If not specified, it will be constructed by calling sampler.init_state(machine, parameters) with a random seed.

Return type:

SamplerState

Returns:

A valid sampler state.

sample(machine, parameters, *, state=None, chain_length=1)

Samples chain_length batches of samples along the chains.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If not specified, a new state is initialized and reset.

  • chain_length (int) – The length of the chains (default = 1).

Returns:

A tuple (σ, state), where σ contains the generated batches of samples and state is the new state of the sampler.

Return type:

tuple[Array, SamplerState]
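
Because sample returns a new state alongside σ, the caller has to thread that state into the next call. Below is a minimal NumPy sketch of this functional pattern. It uses a hypothetical ToyDirectSampler that draws independent ±1 spins, with a frozen ToyState (a seed plus a step counter) standing in for a PRNG key. The (n_chains, chain_length, n_sites) output layout is a choice of this sketch, not a statement about NetKet.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyState:
    # Hypothetical sampler state: a seed plus a step counter stand in for a PRNG key.
    seed: int
    n_steps: int = 0


class ToyDirectSampler:
    """Hypothetical exact sampler for independent +/-1 spins, following the
    documented protocol: sample() returns (sigma, new_state) and never
    mutates the state it receives."""

    def __init__(self, n_sites, n_chains):
        self.n_sites = n_sites
        self.n_chains = n_chains

    def init_state(self, seed=0):
        return ToyState(seed=seed)

    def sample(self, p_up, *, state=None, chain_length=1):
        if state is None:
            state = self.init_state()
        # Derive this call's randomness from (seed, step) so it is reproducible.
        rng = np.random.default_rng([state.seed, state.n_steps])
        # The (n_chains, chain_length, n_sites) layout is this sketch's choice.
        shape = (self.n_chains, chain_length, self.n_sites)
        sigma = np.where(rng.random(shape) < p_up, 1, -1)
        return sigma, replace(state, n_steps=state.n_steps + 1)
```

Passing an old state again reproduces the same batch. To get fresh samples, keep the state that comes back with σ.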

samples(machine, parameters, *, state=None, chain_length=1)

Returns a generator sampling chain_length batches of samples along the chains.

Parameters:
  • machine (Union[Callable, Module]) – A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature f(parameters, σ) -> jnp.ndarray.

  • parameters (Any) – The PyTree of parameters of the model.

  • state (Optional[SamplerState]) – The current state of the sampler. If not specified, a new state is initialized and reset.

  • chain_length (int) – The length of the chains (default = 1).

Return type:

Iterator[Array]
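
Unlike sample, which returns all batches at once, samples yields them one at a time. Below is a minimal sketch of such a generator. It assumes a hypothetical one-step function, toy_sample_next, that returns (σ, new_state) and uses an integer counter as its state. It mirrors the description above rather than NetKet's source.

```python
import numpy as np


def toy_sample_next(state, n_chains=4, n_sites=3):
    # Hypothetical single step: the integer state seeds a fresh generator and
    # is incremented, so every step is reproducible.
    rng = np.random.default_rng(state)
    sigma = rng.choice(np.array([-1, 1]), size=(n_chains, n_sites))
    return sigma, state + 1


def samples(sample_next, state, chain_length=1):
    """Hypothetical generator: yields chain_length batches lazily, threading
    the updated state from one step into the next."""
    for _ in range(chain_length):
        sigma, state = sample_next(state)
        yield sigma
```

No work is done until the generator is iterated, so batches can be consumed one at a time without holding the whole chain in memory.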