netket.sampler.rules.TensorRule

class netket.sampler.rules.TensorRule

Bases: MetropolisRule

A Metropolis sampling rule that combines different rules, each acting on a different subspace of the same tensor Hilbert space.

__init__(hilbert, rules)

Construct the composition of rules.

The rule is built by passing a TensorHilbert space as the first argument and a tuple of rules as the second argument. Each rules[i] is used to generate transitions for the i-th subspace of the tensor Hilbert space.

Parameters:
  • hilbert (TensorHilbert) – The tensor Hilbert space on which the rule acts.

  • rules (tuple[MetropolisRule, ...]) – A tuple of rules, one for each subspace of the tensor Hilbert space.
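The one-to-one pairing between rules and subspaces can be sketched in plain Python. This is not NetKet's implementation: `TensorRuleSketch` is a hypothetical class, and its `subspace_sizes` field is an assumed stand-in for the subspaces of a TensorHilbert.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class TensorRuleSketch:
    """Hypothetical stand-in for a tensor rule: one rule per subspace."""

    subspace_sizes: tuple[int, ...]  # sites in each subspace (assumption)
    rules: tuple[Any, ...]  # rules[i] acts on the i-th subspace

    def __post_init__(self):
        # Store the rules as a tuple so the object stays immutable.
        object.__setattr__(self, "rules", tuple(self.rules))
        # This sketch rejects a mismatch between rules and subspaces.
        if len(self.rules) != len(self.subspace_sizes):
            raise ValueError(
                f"Expected {len(self.subspace_sizes)} rules, one per subspace, "
                f"got {len(self.rules)}."
            )


rule = TensorRuleSketch(subspace_sizes=(4, 2), rules=["spin_flip", "fock_hop"])
print(rule.rules[1])  # → fock_hop
```

The rules are plain strings here only to keep the example self-contained; in practice each entry would be a MetropolisRule.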

Return type:

TensorRule

Attributes
hilbert: TensorHilbert

The Hilbert space on which this rule is defined. This must be a nk.hilbert.TensorHilbert whose size matches the number of sites in each input sample (σ.shape[-1]).

rules: tuple[MetropolisRule, ...]

Tuple of rules, one for each subspace (partition) of the Hilbert space; rules[i] acts on the i-th subspace.
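Because each rule is tied to one partition of the Hilbert space, every sample must be split column-wise into one block per subspace. A minimal NumPy sketch, assuming the subspaces occupy consecutive, contiguous column ranges in the order of the tensor product; `split_by_subspace` is a hypothetical helper, not part of NetKet.

```python
import numpy as np


def split_by_subspace(sigma, subspace_sizes):
    """Split samples of shape (n_chains, n_sites) into one block per subspace."""
    if sigma.shape[-1] != sum(subspace_sizes):
        raise ValueError("Samples do not match the total size of the Hilbert space.")
    # Boundaries between consecutive subspaces, e.g. sizes (4, 2) -> [4].
    boundaries = np.cumsum(subspace_sizes)[:-1]
    return np.split(sigma, boundaries, axis=-1)


sigma = np.arange(12).reshape(2, 6)  # 2 chains, 4 + 2 sites
blocks = split_by_subspace(sigma, (4, 2))
print([b.shape for b in blocks])  # → [(2, 4), (2, 2)]
```

Concatenating the blocks along the last axis recovers the original samples, which is what lets per-subspace proposals be reassembled into a full configuration.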

Methods
init_state(sampler, machine, params, key)

Initialises the optional internal state of the Metropolis sampler transition rule.

The provided key is unique and does not need to be split.

It should return an immutable data structure.

Parameters:
  • sampler (MetropolisSampler) – The Metropolis sampler.

  • machine (Module) – A Flax module with the forward pass of the log-pdf.

  • params (Any) – The PyTree of parameters of the model.

  • key (Any) – A JAX PRNGKey.

Return type:

Optional[Any]

Returns:

An optional state.
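A composite rule receives a single unique key, but each of its sub-rules needs its own. One plausible approach is to derive one independent key per sub-rule and collect the sub-states in a tuple, which is immutable. The sketch below uses NumPy's SeedSequence.spawn as a stand-in for splitting a JAX PRNGKey; the toy rule classes and `composite_init_state` are hypothetical and far simpler than the init_state signature documented above.

```python
import numpy as np


class StatelessRule:
    """Toy rule without internal state (hypothetical)."""

    def init_state(self, seed_seq):
        return None


class OffsetRule:
    """Toy rule whose state is a random integer offset (hypothetical)."""

    def init_state(self, seed_seq):
        rng = np.random.default_rng(seed_seq)
        return int(rng.integers(0, 10))


def composite_init_state(rules, seed):
    # One independent child seed per sub-rule, analogous to jax.random.split.
    children = np.random.SeedSequence(seed).spawn(len(rules))
    # A tuple keeps the composite state immutable.
    return tuple(rule.init_state(child) for rule, child in zip(rules, children))


state = composite_init_state((StatelessRule(), OffsetRule()), seed=0)
print(state[0] is None, 0 <= state[1] < 10)  # → True True
```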

random_state(sampler, machine, params, sampler_state, key)

Generates a random state compatible with this rule.

By default this calls netket.hilbert.random.random_state().

Parameters:
  • sampler (MetropolisSampler) – The Metropolis sampler.

  • machine (Module) – A Flax module with the forward pass of the log-pdf.

  • params (Any) – The PyTree of parameters of the model.

  • sampler_state (SamplerState) – The current state of the sampler; it must not be modified.

  • key (Any) – The PRNGKey to use to generate the random state.
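On a tensor-product space, a random state can be assembled subspace by subspace: draw a random configuration for each factor and concatenate the results along the site axis. The sketch assumes two toy subspaces with known local values (spins ±1 and bosonic occupations 0 to 3); `random_tensor_state` and its `local_values` argument are hypothetical and do not reproduce netket.hilbert.random.random_state().

```python
import numpy as np


def random_tensor_state(local_values, subspace_sizes, n_chains, seed):
    """Draw uniformly random configurations per subspace and concatenate them."""
    rng = np.random.default_rng(seed)
    blocks = [
        rng.choice(np.asarray(values), size=(n_chains, size))
        for values, size in zip(local_values, subspace_sizes)
    ]
    return np.concatenate(blocks, axis=-1)


# 4 spin-1/2 sites (values ±1) followed by 2 bosonic modes with at most 3 particles.
sigma = random_tensor_state(
    local_values=([-1.0, 1.0], [0.0, 1.0, 2.0, 3.0]),
    subspace_sizes=(4, 2),
    n_chains=3,
    seed=1,
)
print(sigma.shape)  # → (3, 6)
```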

replace(**kwargs)

Returns a copy of the object in which the fields named by the keyword arguments are set to the given values. If the object is a dataclass, dataclasses.replace is used; otherwise, a new object of the same type as the original is created.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)
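For a dataclass, this behaves like the standard-library dataclasses.replace: a new object is returned with the selected fields changed, and the original is left untouched, which suits immutable rules. A minimal illustration with a hypothetical frozen dataclass (not a NetKet class):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyRule:
    """Hypothetical immutable rule with two fields."""

    n_flips: int = 1
    name: str = "local"


rule = ToyRule()
new_rule = dataclasses.replace(rule, n_flips=2)  # new object, original unchanged
print(rule.n_flips, new_rule.n_flips, new_rule.name)  # → 1 2 local
```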

reset(sampler, machine, params, sampler_state)

Resets the internal state of the Metropolis sampler transition rule.

The default implementation returns the current rule_state without modifying it.

Parameters:
  • sampler (MetropolisSampler) – The Metropolis sampler.

  • machine (Module) – A Flax module with the forward pass of the log-pdf.

  • params (Any) – The PyTree of parameters of the model.

  • sampler_state (SamplerState) – The current state of the sampler; it must not be modified.

Return type:

Optional[Any]

Returns:

A reset state of the rule. It has the same type as rule_state and may be None.
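The default behaviour (returning the current rule state unchanged) composes naturally: a composite rule can reset each sub-rule's state independently and rebuild the collection. The sketch assumes the composite rule state is a tuple with one entry per sub-rule; the toy classes and `composite_reset` are hypothetical, with signatures reduced to the rule state alone.

```python
class DefaultResetRule:
    """Toy rule keeping the default behaviour: return the state unchanged."""

    def reset(self, rule_state):
        return rule_state


class ZeroingRule:
    """Toy rule that clears an integer counter on reset (hypothetical)."""

    def reset(self, rule_state):
        return 0


def composite_reset(rules, rule_states):
    # Reset each sub-rule with its own state; the result is again a tuple.
    return tuple(rule.reset(s) for rule, s in zip(rules, rule_states))


print(composite_reset((DefaultResetRule(), ZeroingRule()), (None, 7)))  # → (None, 0)
```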

transition(sampler, machine, parameters, state, key, σ)

Proposes a new set of configurations \(\sigma'\) starting from the current chain configurations \(\sigma\).

The new configurations \(\sigma'\) should be a matrix with the same shape as \(\sigma\).

This function should return a tuple whose first element contains the new configurations \(\sigma'\) and whose second element is either None or an array of length σ.shape[0] containing an optional log-correction factor. The correction factor is needed (and is generally non-zero) when the transition rule is not symmetric.
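The role of the log-correction factor shows up in the Metropolis-Hastings acceptance step. Assuming the correction is defined as \(\log T(\sigma'\to\sigma) - \log T(\sigma\to\sigma')\), where \(T\) is the proposal probability, and that the target distribution is \(|\psi(\sigma)|^2\), the log acceptance probability is \(\min(0, 2\,\mathrm{Re}[\log\psi(\sigma') - \log\psi(\sigma)] + \text{correction})\). The function below is an illustrative sketch of that standard formula, not NetKet's internal code; `acceptance_probability` is a hypothetical name.

```python
import numpy as np


def acceptance_probability(log_psi_old, log_psi_new, log_corr=None):
    """Metropolis-Hastings acceptance probability for the target |psi|^2.

    log_corr is assumed to be log T(new -> old) - log T(old -> new);
    None means the proposal is symmetric and no correction is applied.
    """
    log_ratio = 2.0 * np.real(np.asarray(log_psi_new) - np.asarray(log_psi_old))
    if log_corr is not None:
        log_ratio = log_ratio + log_corr
    return np.exp(np.minimum(0.0, log_ratio))


# Symmetric proposal: accept with probability |psi(new)|^2 / |psi(old)|^2, capped at 1.
print(acceptance_probability(np.log(2.0), np.log(1.0)))  # ≈ 0.25
```

For a symmetric rule the correction vanishes and the ratio of amplitudes alone decides acceptance, which is why returning None is allowed in that case.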

Parameters:
  • sampler – The Metropolis sampler.

  • machine – A Flax module with the forward pass of the log-pdf.

  • parameters – The PyTree of parameters of the model.

  • state – The current state of the sampler; it must not be modified.

  • key – A JAX PRNGKey to use to generate new random configurations.

  • σ – The current configurations stored in a 2D matrix.

Returns:

A tuple containing the new configurations \(\sigma'\) and the optional vector of log corrections to the transition probability.
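Putting the pieces together, a tensor rule can propose a move by letting each sub-rule act on its own block of columns, concatenating the proposed blocks, and adding up the per-rule log corrections (None meaning no correction). The NumPy sketch below illustrates this idea under the assumption that the subspaces occupy consecutive column blocks; `flip_rule`, `resample_rule` and `tensor_transition` are hypothetical, with signatures much simpler than the transition method documented above.

```python
import numpy as np


def flip_rule(rng, sigma):
    """Toy symmetric rule: flip one random spin (±1) per chain; no correction."""
    n_chains = sigma.shape[0]
    new = sigma.copy()
    site = rng.integers(0, sigma.shape[1], size=n_chains)
    new[np.arange(n_chains), site] *= -1
    return new, None


def resample_rule(rng, sigma, n_max=3):
    """Toy symmetric rule: set one random mode per chain to a value in 0..n_max.

    Returns an explicit zero log correction, valid because the move is symmetric.
    """
    n_chains = sigma.shape[0]
    new = sigma.copy()
    site = rng.integers(0, sigma.shape[1], size=n_chains)
    new[np.arange(n_chains), site] = rng.integers(0, n_max + 1, size=n_chains)
    return new, np.zeros(n_chains)


def tensor_transition(rules, subspace_sizes, sigma, seed):
    # One independent random stream per sub-rule (stand-in for jax.random.split).
    children = np.random.SeedSequence(seed).spawn(len(rules))
    # Column blocks, one per subspace, in tensor-product order.
    blocks = np.split(sigma, np.cumsum(subspace_sizes)[:-1], axis=-1)
    new_blocks, log_corr = [], None
    for rule, child, block in zip(rules, children, blocks):
        new_block, corr = rule(np.random.default_rng(child), block)
        new_blocks.append(new_block)
        # Sum the corrections, skipping rules that returned None.
        if corr is not None:
            log_corr = corr if log_corr is None else log_corr + corr
    return np.concatenate(new_blocks, axis=-1), log_corr


sigma = np.array([[1.0, -1.0, 1.0, 1.0, 0.0, 2.0],
                  [-1.0, -1.0, 1.0, -1.0, 3.0, 1.0]])
new_sigma, log_corr = tensor_transition((flip_rule, resample_rule), (4, 2), sigma, seed=0)
print(new_sigma.shape, log_corr)  # → (2, 6) [0. 0.]
```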