netket.hilbert.constraint.DiscreteHilbertConstraint
- class netket.hilbert.constraint.DiscreteHilbertConstraint
Bases: `Pytree`
Protocol to define an abstract constraint for a discrete Hilbert space.
To define a custom constraint, you must subclass this class and implement at least the `__call__` method. The `__call__` method should take as input a matrix encoding a batch of configurations and return a vector of booleans specifying whether each configuration is valid.

The `__call__` method must be `jax.jit`-able. If you cannot make it jax-jittable, you can implement it in numba or plain Python and wrap it in a `jax.pure_callback()` to make it compatible with jax.

The constraint should be hashable and comparable with itself, which means it must implement `__hash__` and `__eq__`. By default, `__hash__` is based on the `id` of the object, which is unique for each object. This works, but may lead to more recompilations in jax, because two identical constraints created separately hash differently. If you can, you should implement a custom `__hash__` method.
Example
The following example shows a class that implements a simple constraint checking that the total sum of the elements in the configuration is equal to a given value. It shows how to implement the `__call__` method as well as the `__hash__` and `__eq__` methods.

```python
import jax
import jax.numpy as jnp

import netket as nk
from netket.utils import struct


class SumConstraint(nk.hilbert.constraint.DiscreteHilbertConstraint):
    # A simple constraint checking that the total sum of the elements
    # in the configuration is equal to a given value.

    # The value must be set as a pytree_node=False field, meaning
    # that it is a constant and changes to this value represent different
    # constraints.
    total_sum: float = struct.field(pytree_node=False)

    def __init__(self, total_sum):
        self.total_sum = total_sum

    def __call__(self, x):
        # Makes it jax-compatible
        return jnp.sum(x, axis=-1) == self.total_sum

    def __hash__(self):
        return hash(("SumConstraint", self.total_sum))

    def __eq__(self, other):
        if isinstance(other, SumConstraint):
            return self.total_sum == other.total_sum
        return False
```
Example
The following example shows how to implement the same constraint as above, but using a pure Python function wrapped in a `jax.pure_callback()` to make it compatible with jax.

```python
import jax
import jax.numpy as jnp
import numpy as np

import netket as nk
from netket.utils import struct


class SumConstraintPy(nk.hilbert.constraint.DiscreteHilbertConstraint):
    # A simple constraint checking that the total sum of the elements
    # in the configuration is equal to a given value.
    total_sum: float = struct.field(pytree_node=False)

    def __init__(self, total_sum):
        self.total_sum = total_sum

    def __call__(self, x):
        return jax.pure_callback(
            self._call_py,
            jax.ShapeDtypeStruct(x.shape[:-1], bool),
            x,
            vmap_method="expand_dims",
        )

    def _call_py(self, x):
        # Not Jax compatible
        return np.sum(x, axis=-1) == self.total_sum

    def __hash__(self):
        return hash(("SumConstraintPy", self.total_sum))

    def __eq__(self, other):
        if isinstance(other, SumConstraintPy):
            return self.total_sum == other.total_sum
        return False
```
- __init__()
- Methods
- abstract __call__(x)
This function should take as input a matrix encoding a batch of configurations, and return a vector of booleans specifying whether they are valid configurations of the Hilbert space or not.
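As a rough illustration of this contract, the sketch below enumerates every configuration of a small local space, evaluates a constraint on the whole batch at once, and keeps only the rows marked `True`. It is standalone JAX, not NetKet's own indexing code; the site count, the local states `(0, 1, 2)` and the `sum_constraint` function are assumptions chosen for the example. With 4 sites and a target sum of 2, 10 of the 81 configurations are valid.

```python
import itertools

import jax
import jax.numpy as jnp

N_SITES = 4               # hypothetical number of sites
LOCAL_STATES = (0, 1, 2)  # hypothetical local values of each site
TOTAL_SUM = 2             # hypothetical target sum


@jax.jit
def sum_constraint(x):
    # x: (n_configs, N_SITES) matrix -> (n_configs,) vector of booleans
    return jnp.sum(x, axis=-1) == TOTAL_SUM


# Every configuration of the unconstrained space, one per row.
all_configs = jnp.array(list(itertools.product(LOCAL_STATES, repeat=N_SITES)))
mask = sum_constraint(all_configs)   # one boolean per row
valid_configs = all_configs[mask]    # boolean masking with a concrete mask, outside jit

print(all_configs.shape, mask.shape)  # (81, 4) (81,)
print(valid_configs.shape[0])         # 10
```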