
class netket.hilbert.constraint.DiscreteHilbertConstraint

Bases: Pytree

Protocol to define an abstract constraint for a discrete Hilbert space.

To define a customized constraint, you must subclass this class and at least implement the __call__ method. The __call__ method should take as input a matrix encoding a batch of configurations, and return a vector of booleans specifying whether they are valid configurations or not.
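To illustrate this input/output contract, here is a body that a __call__ could have for a different, hypothetical constraint: no two neighbouring sites of an open chain may both be occupied. It is written as a free function so that it runs without NetKet. In a subclass it would become the __call__ method, with self as an extra first argument.

```python
import jax
import jax.numpy as jnp


def no_adjacent_occupied(x):
    # x has shape (n_batch, n_sites): one configuration per row.
    occupied = x > 0
    # Compare each site with its right neighbour (open chain, no wrap-around).
    clash = occupied[..., :-1] & occupied[..., 1:]
    # One boolean per configuration: True if no neighbouring pair clashes.
    return ~jnp.any(clash, axis=-1)


batch = jnp.array([
    [1, 0, 1, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 0],
])
result = no_adjacent_occupied(batch)  # shape (3,): [True, False, True]
jitted_result = jax.jit(no_adjacent_occupied)(batch)  # same values under jit
```

Because the body uses only vectorized jax.numpy operations and no Python branching on array values, it also satisfies the jax.jit requirement.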

The __call__ method must be compatible with jax.jit. If you cannot write it with jittable operations, you can implement it in NumPy, Numba or plain Python and wrap it in jax.pure_callback() so that JAX can call it from traced code.
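A minimal, self-contained sketch of why the wrapper is needed, independent of NetKet: NumPy code fails when traced by jax.jit, while the same code wrapped in jax.pure_callback() runs. The function names and the target sum of 2 are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_sum_check(x):
    # Plain NumPy, not traceable: np.asarray cannot convert a JAX tracer.
    x = np.asarray(x)
    return np.sum(x, axis=-1) == 2


def callback_sum_check(x):
    # Run the NumPy code on the host; declare the result shape/dtype up front
    # (one boolean per configuration).
    result_shape = jax.ShapeDtypeStruct(x.shape[:-1], jnp.bool_)
    return jax.pure_callback(numpy_sum_check, result_shape, x)


x = jnp.array([[1, 1, 0], [1, 0, 0]])

try:
    jax.jit(numpy_sum_check)(x)
    direct_jit_failed = False
except jax.errors.TracerArrayConversionError:
    direct_jit_failed = True  # expected: NumPy code cannot be traced

valid = jax.jit(callback_sum_check)(x)  # [True, False]
```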

The constraint object must be hashable and comparable with itself, which means it must implement __hash__ and __eq__. By default, __hash__ is based on the id of the object, which is unique for each instance. This works, but two instances describing the same constraint then hash differently, which can lead to more recompilations in JAX. If you can, implement a custom __hash__ (with a matching __eq__) that depends only on the values that define the constraint.
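The recompilation effect can be reproduced outside NetKet with jax.jit and a static argument. This sketch assumes that the constraint ends up as static (hashed) data of a jitted function, which is what makes its hash matter. The classes and count_valid are hypothetical; the trace counter is a Python side effect that only runs while JAX traces.

```python
from functools import partial

import jax
import jax.numpy as jnp


class ByIdConstraint:
    """Hypothetical constraint relying on the default, identity-based hash."""

    def __init__(self, total_sum):
        self.total_sum = total_sum

    def __call__(self, x):
        return jnp.sum(x, axis=-1) == self.total_sum


class ByValueConstraint(ByIdConstraint):
    """Same check, hashed and compared by the value defining the constraint."""

    def __hash__(self):
        return hash(("ByValueConstraint", self.total_sum))

    def __eq__(self, other):
        if isinstance(other, ByValueConstraint):
            return self.total_sum == other.total_sum
        return NotImplemented


n_traces = 0


@partial(jax.jit, static_argnums=0)
def count_valid(constraint, x):
    global n_traces
    n_traces += 1  # runs only when JAX (re)traces the function
    return jnp.sum(constraint(x))


x = jnp.array([[1, 1, 0], [1, 0, 0]])

# Three equivalent constraints, kept alive so their ids stay distinct.
by_id = [ByIdConstraint(2) for _ in range(3)]
for c in by_id:
    count_valid(c, x)
traces_by_id = n_traces  # 3: every instance is a new static argument

n_traces = 0
by_value = [ByValueConstraint(2) for _ in range(3)]
for c in by_value:
    count_valid(c, x)
traces_by_value = n_traces  # 1: equal hash and __eq__ reuse the compiled code
```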


The following example implements a simple constraint that checks whether the elements of each configuration sum to a given value. It shows how to implement the __call__ method together with __hash__ and __eq__.

import netket as nk
from netket.utils import struct

import jax
import jax.numpy as jnp

class SumConstraint(nk.hilbert.constraint.DiscreteHilbertConstraint):
    # A simple constraint checking that the total sum of the elements
    # in the configuration is equal to a given value.

    # The value must be set as a pytree_node=False field, meaning
    # that it is a constant and changes to this value represent different
    # constraints.
    total_sum : float = struct.field(pytree_node=False)

    def __init__(self, total_sum):
        self.total_sum = total_sum

    def __call__(self, x):
        # Only jax.numpy operations, so this can be traced by jax.jit
        return jnp.sum(x, axis=-1) == self.total_sum

    def __hash__(self):
        return hash(("SumConstraint", self.total_sum))

    def __eq__(self, other):
        if isinstance(other, SumConstraint):
            return self.total_sum == other.total_sum
        return False


The following example implements the same constraint, but evaluates it in plain NumPy on the host and uses jax.pure_callback() to make it callable from JAX.

import jax
import jax.numpy as jnp
import numpy as np

import netket as nk
from netket.utils import struct

class SumConstraintPy(nk.hilbert.constraint.DiscreteHilbertConstraint):
    # A simple constraint checking that the total sum of the elements
    # in the configuration is equal to a given value.

    total_sum : float = struct.field(pytree_node=False)

    def __init__(self, total_sum):
        self.total_sum = total_sum

    def __call__(self, x):
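        # Declare the shape and dtype of the result (one boolean per
        # configuration), then let JAX call the NumPy implementation on x.
        return jax.pure_callback(
            self._call_py,
            jax.ShapeDtypeStruct(x.shape[:-1], bool),
            x,
        )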

    def _call_py(self, x):
        # Plain NumPy: runs on the host and cannot be traced by JAX
        return np.sum(x, axis=-1) == self.total_sum

    def __hash__(self):
        return hash(("SumConstraintPy", self.total_sum))

    def __eq__(self, other):
        if isinstance(other, SumConstraintPy):
            return self.total_sum == other.total_sum
        return False
abstract __call__(x)

This function should take as input a matrix encoding a batch of configurations, and return a vector of booleans specifying whether they are valid configurations of the Hilbert space or not.

Parameters:

x (Array) – 2D matrix whose rows are the configurations to check.

Return type:

Array
replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Parameters:

  • self (P)

  • kwargs (Any)

Return type:

TypeVar(P, bound=Pytree)
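For the dataclass branch, the semantics are those of the standard library's dataclasses.replace. A minimal illustration with a hypothetical frozen dataclass (not a NetKet class): the call returns a new object, copies the fields that are not mentioned, and leaves the original untouched.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Params:
    # Hypothetical stand-in class, not part of NetKet.
    total_sum: int
    n_sites: int


p = Params(total_sum=2, n_sites=4)
q = dataclasses.replace(p, total_sum=3)  # new object; n_sites is copied over
```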