netket.errors.UnoptimisedCustomConstraintRandomStateMethodWarning

exception netket.errors.UnoptimisedCustomConstraintRandomStateMethodWarning

Warning thrown when calling random_state on a Hilbert space with a custom constraint.

This warning is thrown when no custom random_state method has been implemented for the custom Hilbert space constraint. NetKet then falls back to a slow default method for generating random states, which may loop forever.

The default fallback repeatedly generates random states in the unconstrained Hilbert space until it finds one that satisfies the constraint. This can be very slow, and may never terminate if the constraint is too restrictive.
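The fallback described above is a form of rejection sampling. The following plain-Python sketch illustrates the idea; it is not NetKet's actual implementation. The spin-1/2 sites, the zero-magnetisation constraint, the helper names and the max_tries guard are all hypothetical and were chosen only to keep the example self-contained. The guard marks the point where an unbounded loop could otherwise run forever.

```python
import random


def sample_unconstrained(n_sites, rng):
    # Hypothetical unconstrained Hilbert space: n_sites spins with local states -1 and +1.
    return [rng.choice((-1, 1)) for _ in range(n_sites)]


def zero_magnetisation(state):
    # Hypothetical constraint: the total magnetisation must vanish.
    return sum(state) == 0


def fallback_random_state(n_sites, constraint, rng, max_tries=100_000):
    # Rejection sampling: draw unconstrained states one at a time until one
    # satisfies the constraint. Without max_tries this loop might never end.
    for _ in range(max_tries):
        candidate = sample_unconstrained(n_sites, rng)
        if constraint(candidate):
            return candidate
    raise RuntimeError("no state satisfying the constraint was found")


rng = random.Random(0)
state = fallback_random_state(8, zero_magnetisation, rng)
print(sum(state))  # always 0, since only constrained states are returned
```

With an odd number of sites no state has zero magnetisation, so the same call would exhaust max_tries and raise instead of hanging.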

Note

Generating random states is only required when initialising a Markov chain Monte Carlo sampler and is generally not needed during optimisation. Therefore, even if this warning is thrown, it might not be a problem if you do not reset your chains very often.

In general, if a reasonable fraction of the unconstrained states satisfies your constraint (that is, the constraint is not exponentially restrictive), you may not need to worry about this warning.
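To make "exponentially restrictive" concrete, suppose a fraction p of the unconstrained states satisfies the constraint. Each draw of the fallback then succeeds with probability p, so on average 1/p draws are needed per state.

The sketch below computes 1/p for two hypothetical constraints on N spin-1/2 sites. The first is zero total magnetisation, where p = C(N, N/2) / 2^N decays only like sqrt(2 / (pi N)). The second is exactly one spin up, where p = N / 2^N decays exponentially.

```python
from math import comb


def expected_draws_zero_magnetisation(n_sites):
    # Fraction of the 2**n_sites configurations with equally many up and down spins.
    p = comb(n_sites, n_sites // 2) / 2**n_sites
    return 1 / p


def expected_draws_single_up(n_sites):
    # Fraction of the 2**n_sites configurations with exactly one spin up.
    p = n_sites / 2**n_sites
    return 1 / p


for n in (10, 20, 40):
    print(n, expected_draws_zero_magnetisation(n), expected_draws_single_up(n))
```

The first constraint stays cheap as N grows, while the second quickly makes the fallback impractical.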

Note

The fallback implementation is not optimised and performs especially poorly on GPUs. It generates one random state at a time until it finds one that satisfies the constraint, and repeats this for every chain.

We welcome contributions to improve this method to perform the loop in batches.
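One way to perform the loop in batches is sketched below in plain Python. It assumes a hypothetical space of spin-1/2 sites (local states -1 and +1) and a hypothetical zero-magnetisation constraint. Each round draws a whole batch of candidates, keeps those that satisfy the constraint, and stops once every chain has a valid state.

This only illustrates the control flow. An actual contribution to NetKet would presumably vectorise candidate generation and the constraint check with JAX, so that each round runs as a single array operation.

```python
import random


def zero_magnetisation(state):
    # Hypothetical constraint: equal numbers of +1 and -1 spins.
    return sum(state) == 0


def batched_random_states(n_chains, n_sites, constraint, rng, batch_size=64, max_rounds=1_000):
    # Collect one valid state per chain, drawing candidates in batches
    # instead of one at a time.
    accepted = []
    for _ in range(max_rounds):
        # Draw a whole batch of unconstrained candidates in one round.
        batch = [[rng.choice((-1, 1)) for _ in range(n_sites)] for _ in range(batch_size)]
        # Keep only the candidates that satisfy the constraint.
        accepted.extend(s for s in batch if constraint(s))
        if len(accepted) >= n_chains:
            # Discard surplus valid states beyond one per chain.
            return accepted[:n_chains]
    raise RuntimeError("not enough states satisfying the constraint were found")


rng = random.Random(1)
states = batched_random_states(16, 10, zero_magnetisation, rng)
```

The batch size trades wasted draws, since surplus valid states are discarded, against the number of rounds needed to fill every chain.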

Note

You can silence this warning by setting the environment variable NETKET_RANDOM_STATE_FALLBACK_WARNING=0 or by setting nk.config.netket_random_state_fallback_warning = 0 in your code.
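As a minimal sketch, the environment variable can also be set from Python itself. To be safe, this should happen before netket is imported, on the assumption that NetKet reads the variable when it is first imported.

```python
import os

# Set the variable before `import netket` so that the setting is picked up
# (assumption: NetKet reads it from the environment at import time).
os.environ["NETKET_RANDOM_STATE_FALLBACK_WARNING"] = "0"

# import netket as nk  # import NetKet only after setting the variable
```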

Example implementation of a custom random_state method

Implementations of netket.hilbert.random_state() are dispatched on the Hilbert space type and the constraint type using Plum’s multiple dispatch (see https://beartype.github.io/plum/intro.html). To support a new constraint, you register an additional method for the corresponding pair of types.

A custom method has the following form:

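from netket.utils import dispatch  # assumed import path for NetKet's dispatch utilities

# HilbertType and ConstraintType are placeholders for your Hilbert space class
# and your custom constraint class.
@dispatch.dispatch
def random_state(hilb: HilbertType, constraint: ConstraintType, key, batches: int, *, dtype=None):
    # Your custom implementation goes here, drawing randomness from `key`.
    # It must return a jax.Array of `batches` random states,
    # with shape (batches, hilb.size) and the given dtype.
    ...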
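What the body of such a method computes depends on the constraint. Here is a hedged illustration of the idea, independent of NetKet's API. For a hypothetical zero-magnetisation constraint on N spin-1/2 sites, valid states can be sampled directly, with no rejection, by shuffling a configuration that already has N/2 spins up and N/2 spins down.

The sketch uses plain Python lists. Inside a real random_state method the same logic would have to produce a jax.Array of shape (batches, hilb.size) with the requested dtype, for example via jax.random.permutation.

```python
import random


def zero_magnetisation_states(n_sites, batches, rng):
    # Sample `batches` states of n_sites spins (+1/-1) with zero total magnetisation.
    if n_sites % 2:
        raise ValueError("zero magnetisation needs an even number of sites")
    balanced = [1] * (n_sites // 2) + [-1] * (n_sites // 2)
    states = []
    for _ in range(batches):
        state = balanced.copy()
        # A uniform shuffle makes every valid arrangement equally likely.
        rng.shuffle(state)
        states.append(state)
    return states


rng = random.Random(0)
states = zero_magnetisation_states(n_sites=6, batches=4, rng=rng)
```

Because every draw is already valid, the cost no longer depends on how small the constrained fraction of the Hilbert space is.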

Note

We welcome contributions to improve the documentation here.