# Source code for netket.nn.blocks.symmetry_sum
# Copyright 2023 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import jax
import numpy as np
from flax import linen as nn
from netket.jax import logsumexp_cplx
from netket.utils.group import PermutationGroup
from netket.utils.types import Array
class SymmExpSum(nn.Module):
    r"""
    A flax module symmetrizing the log-wavefunction :math:`\log\psi_\theta(\sigma)`
    encoded into another flax module (:class:`flax.linen.Module`) by summing over
    all possible symmetries :math:`g` in a certain discrete permutation
    group :math:`G`.

    .. math::

        \log\psi_\theta(\sigma) = \log\frac{1}{|G|}\sum_{g\in G}
            \chi_g\exp[\log\psi_\theta(T_{g}\sigma)]

    For the ground-state, it is usually found that :math:`\chi_g=1 \forall g\in G`.

    To construct this network, one has to specify the module, the symmetry group
    and (optionally) the id of the character to consider.

    The module's :code:`.__call__` will be called on the symmetry-transformed
    inputs. The :code:`symm_group` attribute must be a
    :class:`~netket.utils.group.PermutationGroup`; it is applied to the input
    as :code:`symm_group @ x`.

    Examples:

       Constructs a :ref:`netket.nn.blocks.SymmExpSum` for a bare
       :ref:`netket.models.RBM`, summing over all translations of a
       2D Square lattice

       >>> import netket as nk
       >>> graph = nk.graph.Square(3)
       >>> print("number of translational symmetries: ", len(graph.translation_group()))
       number of translational symmetries:  9
       >>> # Construct the bare unsymmetrized machine
       >>> machine_no_symm = nk.models.RBM(alpha=2)
       >>> # Symmetrize the RBM over all translations
       >>> ma = nk.nn.blocks.SymmExpSum(module = machine_no_symm, symm_group=graph.translation_group())

       If you have a Convolutional NN that is already invariant under translations, you might
       want to only symmetrize over the point-group (mirror symmetry and rotations).

       >>> import netket as nk
       >>> graph = nk.graph.Square(3)
       >>> print("number of point-group symmetries: ", len(graph.point_group()))
       number of point-group symmetries:  8
       >>> # Construct the bare unsymmetrized machine
       >>> machine_no_symm = nk.models.RBM(alpha=2)
       >>> # Symmetrize the RBM over the point group
       >>> ma = nk.nn.blocks.SymmExpSum(module = machine_no_symm, symm_group=graph.point_group())

    """

    module: nn.Module
    """The neural network architecture encoding the log-wavefunction
    to symmetrize in the :code:`.__call__` function."""

    symm_group: PermutationGroup
    """The symmetry group to use. It should be a valid
    :ref:`netket.utils.group.PermutationGroup` object.

    Can be extracted from a :ref:`netket.graph.Lattice` object by calling
    :meth:`~netket.graph.Lattice.point_group` or
    :meth:`~netket.graph.Lattice.translation_group`.

    Alternatively, if you have a :class:`netket.graph.Graph` object you
    can build it from :meth:`~netket.graph.Lattice.automorphisms`.

    .. code::

        graph = nk.graph.Square(3)
        symm_group = graph.point_group()

    """

    character_id: Optional[int] = None
    """The # identifying the target character in the character table of
    the symmetry group. By default the characters are taken to be all
    `1`, giving the homogeneous state.

    The characters are accessed as:

    .. code::

        symm_group.character_table()[character_id]

    """

    @nn.compact
    def __call__(self, x: Array):
        """
        Evaluate the symmetrized log-wavefunction.

        Accepts a single input or arbitrary batch of inputs.
        The last dimension of x must match the shape of the permutation
        group.

        Args:
            x: input configuration(s); the last axis is the one permuted by
                the group, any leading axes are treated as batch axes.

        Returns:
            The log of the character-weighted average of
            ``exp(module(g x))`` over the group elements, with the group
            axis reduced away (one value per batch entry).
        """
        # Size of the group, used both for the default characters and for
        # the 1/|G| normalisation below.
        n_symm = len(self.symm_group)

        # apply the group and obtain a x_symm of shape (N_symm, ..., N_sites)
        x_symm = self.symm_group @ x
        x_symm_shape = x_symm.shape

        # flatten to (-1, N_sites) so the wrapped module sees one flat batch
        x_symm = x_symm.reshape(-1, x.shape[-1])

        # Compute the log-wavefunction obtaining (-1,) and reshape to (N_symm, ...)
        psi_symm = self.module(x_symm).reshape(*x_symm_shape[:-1])

        # Extract the characters. Those are compile-time constant (a numpy array).
        if self.character_id is None:
            characters = np.ones(n_symm)
        else:
            characters = self.symm_group.character_table()[self.character_id]

        # Shape (N_symm, 1, ..., 1) so the characters broadcast against psi_symm.
        characters = characters.reshape((-1,) + (1,) * (x.ndim - 1))

        # Use the standard logsumexp only when every character is real and
        # non-negative. The check is explicit about the imaginary part because
        # NumPy orders complex numbers lexicographically, so a bare
        # `characters >= 0` is not a positivity test for complex characters.
        all_real_nonneg = np.all(
            np.isreal(characters) & (np.real(characters) >= 0)
        )
        logsumexp_fun = (
            jax.scipy.special.logsumexp if all_real_nonneg else logsumexp_cplx
        )

        # log (sum_i ( c_i/Nsymm* exp(psi[sigma_i])))
        psi = logsumexp_fun(psi_symm, axis=0, b=characters / n_symm)
        return psi