netket.nn.blocks.SymmExpSum

class netket.nn.blocks.SymmExpSum

Bases: Module

A flax module symmetrizing the log-wavefunction \(\log\psi_\theta(\sigma)\) encoded into another flax module (flax.linen.Module) by summing over all possible symmetries \(g\) in a certain discrete permutation group \(G\).

\[\log\psi_\theta(\sigma) = \log\left[\frac{1}{|G|}\sum_{g\in G} \chi_g\exp[\log\psi_\theta(T_{g}\sigma)]\right]\]

For the ground-state, it is usually found that \(\chi_g=1 \forall g\in G\).
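The symmetrization can be illustrated with a small, dependency-free sketch. It averages \(\exp[\log\psi(T_g\sigma)]\) over the cyclic translations of a 5-site chain with \(\chi_g = 1\) and then takes the logarithm. The names log_psi, translate and symm_log_psi are hypothetical helpers invented for this example and are not part of NetKet; the code makes no claim about how SymmExpSum is implemented internally.

```python
import math


def log_psi(sigma):
    # Hypothetical toy log-amplitude: site i is weighted by (i + 1),
    # so the function is deliberately NOT translation invariant.
    return 0.1 * sum((i + 1) * s for i, s in enumerate(sigma))


def translate(sigma, g):
    # T_g: cyclic shift of the configuration by g sites.
    g %= len(sigma)
    return sigma[g:] + sigma[:g]


def symm_log_psi(sigma):
    # log[(1/|G|) * sum_g exp(log_psi(T_g sigma))] with chi_g = 1 for all g.
    # The maximum is subtracted before exponentiating for numerical stability.
    n = len(sigma)
    vals = [log_psi(translate(sigma, g)) for g in range(n)]
    shift = max(vals)
    total = sum(math.exp(v - shift) for v in vals)
    return shift + math.log(total / n)


sigma = (1, -1, -1, 1, -1)
shifted = translate(sigma, 1)

# The bare toy amplitude changes under a translation ...
print(log_psi(sigma), log_psi(shifted))
# ... while the symmetrized one does not (up to rounding).
print(symm_log_psi(sigma), symm_log_psi(shifted))
```

The symmetrized value is the same for sigma and for every translated copy of it, because translating the input only reorders the terms of the sum.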

To construct this network, one has to specify the module, the symmetry group and (optionally) the id of the character to consider.

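The module’s .__call__ method is evaluated on every symmetry-transformed input \(T_g\sigma\). The symm_group attribute must be a valid netket.utils.group.PermutationGroup.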

Examples

Construct a netket.nn.blocks.SymmExpSum for a bare netket.models.RBM, summing over all translations of a 2D square lattice:

>>> import netket as nk
>>> graph = nk.graph.Square(3)
>>> print("number of translational symmetries: ", len(graph.translation_group()))
number of translational symmetries:  9
>>> # Construct the bare unsymmetrized machine
>>> machine_no_symm = nk.models.RBM(alpha=2)
>>> # Symmetrize the RBM over all translations
>>> ma = nk.nn.blocks.SymmExpSum(module = machine_no_symm, symm_group=graph.translation_group())

If you have a Convolutional NN that is already invariant under translations, you might want to only symmetrize over the point-group (mirror symmetry and rotations).

>>> import netket as nk
>>> graph = nk.graph.Square(3)
>>> print("number of point-group symmetries: ", len(graph.point_group()))
number of point-group symmetries:  8
>>> # Construct the bare unsymmetrized machine
>>> machine_no_symm = nk.models.RBM(alpha=2)
>>> # Symmetrize the RBM over the point group (rotations and mirror symmetries)
>>> ma = nk.nn.blocks.SymmExpSum(module = machine_no_symm, symm_group=graph.point_group())
Attributes

character_id: Optional[int] = None

The index identifying the target character in the character table of the symmetry group. By default the characters are all taken to be 1, giving the homogeneous (fully symmetric) state.

The characters are accessed as:

symm_group.character_table()[character_id]
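To see what a non-trivial character does, the sketch below projects a toy amplitude with the characters of the cyclic group \(Z_N\), \(\chi_g = e^{2\pi i k g/N}\), and checks that the projected amplitude picks up the phase \(\chi_h^*\) when the input is translated by \(h\) sites. It works with amplitudes rather than log-amplitudes to avoid complex logarithms. Everything here (psi, translate, characters, project, and the choice of labelling rows by k) is a hypothetical illustration using only the standard library; it does not assert how NetKet orders the rows of character_table().

```python
import cmath

N = 4  # number of sites, equal to the order of the cyclic translation group


def psi(sigma):
    # Hypothetical toy amplitude, deliberately not translation invariant.
    return cmath.exp(0.3 * sum((i + 1) * s for i, s in enumerate(sigma)))


def translate(sigma, g):
    # T_g: cyclic shift by g sites, so that T_g T_h = T_{g+h}.
    g %= len(sigma)
    return sigma[g:] + sigma[:g]


def characters(k):
    # Row k of the character table of Z_N: chi_g = exp(2*pi*i*k*g/N).
    return [cmath.exp(2j * cmath.pi * k * g / N) for g in range(N)]


def project(sigma, chi):
    # (1/|G|) * sum_g chi_g * psi(T_g sigma)
    return sum(c * psi(translate(sigma, g)) for g, c in enumerate(chi)) / N


sigma = (1, 1, -1, 1)
chi = characters(1)

lhs = project(translate(sigma, 1), chi)          # projected amplitude at T_1 sigma
rhs = chi[1].conjugate() * project(sigma, chi)   # conj(chi_1) times amplitude at sigma
print(abs(lhs - rhs))  # difference is at rounding level
```

With characters(0), where every character equals 1, the phase is trivial and the projected amplitude is strictly invariant, which is the default behaviour described above.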
module: Module

The neural network architecture encoding the log-wavefunction to symmetrize in the .__call__ function.

symm_group: PermutationGroup

The symmetry group to use. It should be a valid netket.utils.group.PermutationGroup object.

Can be extracted from a netket.graph.Lattice object by calling point_group() or translation_group().

Alternatively, if you have a netket.graph.Graph object you can build it from automorphisms().

import netket as nk
graph = nk.graph.Square(3)
symm_group = graph.point_group()
Methods

__call__(x)

Accepts a single input or arbitrary batch of inputs.

The last dimension of x must match the number of sites the permutation group acts on (its degree).

Parameters:

x (ndarray | Array)