Symmetries, Representations and Projectors for arbitrary NQS#

Author: Louis Sharma (College de France), September 2025

In this tutorial, we discuss how to use the symmetries of a quantum Hamiltonian to improve the accuracy of a variational Monte Carlo (VMC) calculation. More specifically, you will learn:

  • How to symmetrize an MCState with respect to lattice symmetries.

  • How to implement custom symmetry representations.

We will consider the Heisenberg antiferromagnet on a 2D \(L \times L\) square lattice:

\[ \hat H = J \sum_{⟨ij⟩} \hat{\vec S}_i \cdot \hat{\vec S}_j \]

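where \(J>0\) is the antiferromagnetic exchange coupling, \(⟨ij⟩\) runs over pairs of nearest-neighbor sites on the lattice, and \(\hat{\vec S}_i = \frac 12 ( \hat \sigma_i^x, \hat \sigma_i^y, \hat \sigma_i^z)\) is the spin operator on site \(i\). Note that NetKet's nk.operator.Heisenberg is written directly in terms of Pauli matrices, \(J \sum_{⟨ij⟩} \hat{\vec \sigma}_i \cdot \hat{\vec \sigma}_j\), so the energies reported below are four times those of \(\hat H\) as written above.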
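On a single bond, the spin form and the Pauli form of the exchange differ by a factor of 4, since \(\hat{\vec S}_i = \frac12 \hat{\vec\sigma}_i\). The short numpy sketch below (independent of NetKet) checks this on two spins: the singlet and triplet energies are \(-3\) and \(+1\) in the Pauli convention, and \(-3/4\) and \(+1/4\) with spin operators.

```python
import numpy as np

# Pauli matrices
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]])
sz = np.array([[1, 0], [0, -1]], dtype=complex)

# Two-site exchange in both conventions
sigma_dot_sigma = sum(np.kron(s, s) for s in (sx, sy, sz))
S_dot_S = sum(np.kron(s / 2, s / 2) for s in (sx, sy, sz))

print(np.linalg.eigvalsh(sigma_dot_sigma))  # singlet -3, triplet +1 (three times)
print(np.linalg.eigvalsh(S_dot_S))          # singlet -3/4, triplet +1/4 (three times)
print(np.allclose(sigma_dot_sigma, 4 * S_dot_S))  # True
```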

Throughout this tutorial we will consider translation symmetries and the spin-flip symmetry, and show how to combine them.

Definition#

Below we set up the lattice and the Hamiltonian, and compute a reference ground-state energy with exact diagonalization (ED).

import netket as nk
import numpy as np
import jax
seed = jax.random.PRNGKey(1234) # For reproducibility
square_lattice = nk.graph.Square(length=4, pbc=True) # 4x4 square lattice with periodic boundary conditions
hilbert = nk.hilbert.Spin(s=0.5, N=square_lattice.n_nodes)  # 16 spin-1/2 degrees of freedom
H = nk.operator.Heisenberg(hilbert=hilbert, graph=square_lattice, J=1.0, sign_rule=False)

square_lattice.draw();


# We can perform ED because the system is small (2^16 basis states)
from scipy.sparse.linalg import eigsh

H_sp = H.to_sparse()
evals, evecs = eigsh(H_sp, k=1, which='SA')  # 'SA' means smallest algebraic eigenvalue
print("Ground state energy (exact diagonalization): ", evals[0])
[Figure: the 4×4 square lattice with periodic boundary conditions, as drawn by square_lattice.draw()]

Inspecting symmetries#

The Hamiltonian commutes with a set of operators that form a representation of the space group of the lattice. By Schur's lemma, \(\hat H\) is then block-diagonal in a basis adapted to the irreducible representations (irreps) of the group, so we can restrict the search for the ground state to a single irrep sector.
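To see the block structure concretely, here is a small, self-contained numpy sketch on a toy problem rather than the 4×4 lattice: a periodic 4-site Heisenberg chain in the Pauli convention, with its translation operator \(\hat T\) built by hand. It checks \([\hat H, \hat T] = 0\), splits the Hilbert space with momentum projectors \(\hat P_k = \frac1N \sum_R e^{-ikR} \hat T^R\) (the sign of the exponent is a convention assumed here), and verifies that the spectra of the blocks together reproduce the full spectrum. None of the names below are NetKet APIs.

```python
import numpy as np
from functools import reduce

N = 4  # toy model: periodic Heisenberg chain with 4 sites
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]])
sz = np.array([[1, 0], [0, -1]], dtype=complex)

def site_op(op, i):
    """Embed a single-site operator on site i (site 0 = most significant bit)."""
    return reduce(np.kron, [op if j == i else np.eye(2) for j in range(N)])

# Pauli convention: H = sum_i sigma_i . sigma_{i+1}, periodic boundary conditions
H = sum(site_op(s, i) @ site_op(s, (i + 1) % N)
        for i in range(N) for s in (sx, sy, sz))

# Translation by one site, as a permutation matrix on the 2^N basis states
dim = 2 ** N
T = np.zeros((dim, dim))
for b in range(dim):
    bits = [(b >> (N - 1 - j)) & 1 for j in range(N)]  # bits[j] = state of site j
    shifted = [bits[(j - 1) % N] for j in range(N)]      # site j takes the state of site j-1
    T[int("".join(map(str, shifted)), 2), b] = 1.0

assert np.allclose(H @ T, T @ H)  # [H, T] = 0

# Restrict H to each momentum sector k = 2*pi*m/N and collect the block spectra
block_spectra = []
for m in range(N):
    k = 2 * np.pi * m / N
    P_k = sum(np.exp(-1j * k * R) * np.linalg.matrix_power(T, R) for R in range(N)) / N
    p_vals, p_vecs = np.linalg.eigh(P_k)
    basis = p_vecs[:, p_vals > 0.5]  # orthonormal basis of the sector (eigenvalue 1)
    block_spectra.append(np.linalg.eigvalsh(basis.conj().T @ H @ basis))
    print(f"k = 2*pi*{m}/{N}: block of size {basis.shape[1]}")

full_spectrum = np.linalg.eigvalsh(H)
print(np.allclose(np.sort(np.concatenate(block_spectra)), full_spectrum))  # True
```

Each block is much smaller than the full matrix, which is what makes restricting the search to one sector worthwhile.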

The first step in doing this is to select the relevant group. Here, we will consider the translation group of the lattice.

translation_group = square_lattice.translation_group()
for g in translation_group: 
    print(g)
Id()
Translation([0, 1])
Translation([0, 2])
Translation([0, 3])
Translation([1, 0])
Translation([1, 1])
Translation([1, 2])
Translation([1, 3])
Translation([2, 0])
Translation([2, 1])
Translation([2, 2])
Translation([2, 3])
Translation([3, 0])
Translation([3, 1])
Translation([3, 2])
Translation([3, 3])

The elements of translation_group correspond to permutations of the lattice sites.

print("Permutation corresponding to a translation by R= [0,1]: ", 
      translation_group[1].permutation_array)
Permutation corresponding to a translation by R= [0,1]:  [ 1  2  3  0  5  6  7  4  9 10 11  8 13 14 15 12]
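The printed array can be reproduced by hand if we assume that site \((a, b)\) of the \(4 \times 4\) lattice carries the row-major index \(i = 4a + b\). This numbering is inferred from the output above rather than taken from the NetKet documentation, and translation_permutation is a hypothetical helper, not a NetKet function.

```python
import numpy as np

L = 4

def translation_permutation(shift, L=L):
    """Entry i = L*a + b holds the index of site (a + da, b + db) mod L.

    Hypothetical helper; the row-major numbering i = L*a + b is an assumption
    that matches the array printed by NetKet above.
    """
    da, db = shift
    return np.array([L * ((a + da) % L) + (b + db) % L
                     for a in range(L) for b in range(L)])

print(translation_permutation((0, 1)))
# [ 1  2  3  0  5  6  7  4  9 10 11  8 13 14 15 12]
print(translation_permutation((1, 0)))
```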

We can also inspect the character table, whose rows contain the characters of the different irreducible representations evaluated on each group element.

print("Second row of the character table:", translation_group.character_table()[1])
Second row of the character table: [ 1.+0.j  1.+0.j  1.+0.j  1.+0.j  0.+1.j  0.+1.j  0.+1.j  0.+1.j -1.+0.j
 -1.+0.j -1.+0.j -1.+0.j  0.-1.j  0.-1.j  0.-1.j  0.-1.j]

Because the translation group is abelian, all of its irreps are one-dimensional, and their characters can all be written in the form:

\[\chi_{\vec k}(\vec R) = e^{i \vec k \cdot \vec R}\]

where \(\chi_{\vec k}(\vec R)\) is the character corresponding to the translation by a lattice vector \(\vec R\) and \(\vec k\) is a vector in the first Brillouin zone.
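As a sanity check, a short numpy sketch can evaluate \(\chi_{\vec k}(\vec R) = e^{i \vec k \cdot \vec R}\) with \(\vec R = (a, b)\) enumerated in the same order as the group elements listed above. For \(\vec k = (\pi/2, 0)\) it reproduces the second row of the character table; which row of NetKet's table corresponds to which \(\vec k\), and the sign of the exponent, are assumptions here. On the \(4 \times 4\) lattice the allowed momenta are \(\vec k = \frac{2\pi}{4}(m, n)\), and the resulting 16 × 16 table has orthogonal rows.

```python
import numpy as np

L = 4
# Translation vectors R = (a, b), ordered as Translation([a, b]) in the listing above
R = np.array([(a, b) for a in range(L) for b in range(L)])

def characters(k):
    """chi_k(R) = exp(i k . R) for all translations R."""
    return np.exp(1j * (R @ np.asarray(k, dtype=float)))

row = characters((np.pi / 2, 0.0))
expected = np.array([1] * 4 + [1j] * 4 + [-1] * 4 + [-1j] * 4)
print(np.allclose(row, expected))  # True

# Full character table: one row per allowed momentum k = 2*pi*(m, n)/L
table = np.array([characters(2 * np.pi * np.array([m, n]) / L)
                  for m in range(L) for n in range(L)])
# Row orthogonality: sum_R chi_k(R) conj(chi_k'(R)) = |G| delta_{k,k'}
print(np.allclose(table @ table.conj().T, L * L * np.eye(L * L)))  # True
```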

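In principle, the ground state may lie at any of the allowed values of \(\vec k\). However, the point-group symmetries of the square lattice (rotations and reflections) map momentum sectors onto one another, and sectors related in this way have identical spectra. We can therefore restrict the search for the ground state to the irreducible Brillouin zone (the red triangle in the plot below).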

import matplotlib.pyplot as plt

# Allowed momenta of the 4x4 lattice: multiples of 2*pi/4 along each axis
# (k = -pi and k = pi are equivalent, but both are drawn)
kx = np.linspace(-np.pi, np.pi, 4+1, endpoint=True)
ky = np.linspace(-np.pi, np.pi, 4+1, endpoint=True)

Kx, Ky = np.meshgrid(kx, ky)
plt.scatter(Kx.flatten(), Ky.flatten(), s=10, color='black')
# Outline of the irreducible wedge 0 <= ky <= kx <= pi
plt.plot([0,np.pi], [0, np.pi], color='r', lw=0.5)
plt.plot([0,np.pi], [0, 0], color='r', lw=0.5)
plt.plot([np.pi, np.pi], [0, np.pi], color='r', lw=0.5)
plt.gca().set_aspect('equal', adjustable='box')
plt.xlabel(r'$k_x$')
plt.ylabel(r'$k_y$')
plt.title('Brillouin zone of the square lattice')
plt.show()
[Figure: allowed momenta of the 4×4 lattice in the Brillouin zone, with the irreducible wedge outlined in red]
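The reduction to the irreducible wedge can be made explicit for the 16 allowed momenta of the 4×4 lattice. The plain-numpy sketch below (helper names are hypothetical) folds each \(\vec k\) into \((-\pi, \pi]^2\), maps it into the wedge \(0 \le k_y \le k_x \le \pi\) using the sign flips and the \(k_x \leftrightarrow k_y\) exchange generated by the point group \(C_{4v}\), and groups equivalent momenta into stars. Only one representative per star needs to be considered.

```python
import numpy as np
from itertools import product

L = 4

def fold(q):
    """Fold a momentum component into the interval (-pi, pi]."""
    q = np.mod(q + np.pi, 2 * np.pi) - np.pi
    return np.pi if np.isclose(q, -np.pi) else q

def wedge_representative(k):
    """Map k into 0 <= ky <= kx <= pi using the C4v operations
    kx -> -kx, ky -> -ky and kx <-> ky (rounded so it can be a dict key)."""
    ax, ay = abs(k[0]), abs(k[1])
    return (round(max(ax, ay), 10), round(min(ax, ay), 10))

# The 16 allowed momenta k = 2*pi*(m, n)/L of the L x L periodic lattice
ks = [(fold(2 * np.pi * m / L), fold(2 * np.pi * n / L))
      for m, n in product(range(L), repeat=2)]

stars = {}
for k in ks:
    stars.setdefault(wedge_representative(k), []).append(k)

for (kx, ky), star in sorted(stars.items()):
    print(f"k = ({kx / np.pi:.2f}, {ky / np.pi:.2f}) pi : {len(star)} equivalent momenta")
```

Six representatives remain, namely \((0,0)\), \((\pi/2,0)\), \((\pi,0)\), \((\pi/2,\pi/2)\), \((\pi,\pi/2)\) and \((\pi,\pi)\), instead of 16 momentum sectors.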

Representations#

The next step is to turn the group elements into netket.operator objects that act on the states of the Hilbert space. A map that assigns to each group element \(g\) an operator \(\hat U_g\), such that \(\hat U_g \hat U_h = \hat U_{gh}\), is called a representation. For lattice symmetries, NetKet has built-in methods to construct representations through the nk.symmetry.canonical_representation() function, which takes a Hilbert space and a permutation group as arguments.

Note

This function is experimental and constructs the canonical representation of lattice symmetry groups as permutation groups acting on the lattice sites. For other types of symmetry groups (e.g., continuous symmetries), this representation is not defined.

translation_group_representation = nk.symmetry.canonical_representation(
    hilbert=hilbert,
    group=square_lattice.translation_group()
)

The result is a TranslationRepresentation, which extends the base Representation with momentum-aware features: the available k-points, human-readable irrep labels, and a projector(k=...) method to directly select a momentum sector.

For more details on the mathematical foundations of permutation operators, see the symmetry documentation.

In the custom symmetries section below, we will show how to construct a Representation manually for a group that is not a lattice symmetry.

Custom symmetries: spin flip#

In this part of the tutorial, we will build by hand a representation of the spin-flip group \(\mathbb Z_2\) on our spin Hilbert space.

Note

The spin-flip representation is already available as spin_flip_representation(), but here we show how to build it by hand to illustrate the general procedure for defining custom symmetry representations.

The Heisenberg Hamiltonian also commutes with the global spin-flip operator:

\[\hat \sigma^x = \bigotimes_i \hat \sigma_i^x\]

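The operators \(\{\hat I, \hat \sigma^x\}\) form a representation of the group \(\mathbb Z_2\): since \((\hat \sigma^x)^2 = \hat I\), they obey the same multiplication rule as the two group elements.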
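Both properties can be checked by brute force on a small example. The sketch below (plain numpy, not NetKet) uses a periodic 4-site Heisenberg chain in the Pauli convention as a stand-in for the 4×4 lattice, builds \(\hat\sigma^x = \bigotimes_i \hat\sigma^x_i\) as a dense matrix, and verifies that it squares to the identity and commutes with \(\hat H\).

```python
import numpy as np
from functools import reduce

N = 4  # toy model: periodic 4-site Heisenberg chain
sx = np.array([[0, 1], [1, 0]], dtype=complex)
sy = np.array([[0, -1j], [1j, 0]])
sz = np.array([[1, 0], [0, -1]], dtype=complex)

def site_op(op, i):
    """Embed a single-site operator on site i of the chain."""
    return reduce(np.kron, [op if j == i else np.eye(2) for j in range(N)])

H = sum(site_op(s, i) @ site_op(s, (i + 1) % N)
        for i in range(N) for s in (sx, sy, sz))
flip = reduce(np.kron, [sx] * N)  # tensor product of sigma^x over all sites

print(np.allclose(flip @ flip, np.eye(2 ** N)))  # True: (sigma^x)^2 = I
print(np.allclose(H @ flip, flip @ H))           # True: [H, sigma^x] = 0
```

The commutation holds because conjugating by \(\hat\sigma^x\) flips the sign of \(\hat\sigma^y_i\) and \(\hat\sigma^z_i\) on every site, so each product \(\hat\sigma^y_i\hat\sigma^y_j\) and \(\hat\sigma^z_i\hat\sigma^z_j\) is unchanged.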

The LabeledRepresentation class needs two fundamental things to function:

  • A FiniteGroup object — the abstract group.

  • A dict mapping group elements to operators on the Hilbert space.

The group \(\mathbb Z_2\) is just a set of two elements \(\{e, g\}\) with a single rule: \(g^2 = e\). A concrete realization is the symmetric group \(\mathcal S_2\) of permutations of two objects, which is also the cyclic group of order 2. In NetKet, it can be constructed with cyclic_group(2).

from netket.utils.group import Permutation
from netket.symmetry import LabeledRepresentation
group = nk.symmetry.group.cyclic_group(2)

cyclic_group(2) returns a PermutationGroup with two elements: the identity (E) and the swap (C2), which correspond to the two elements of \(\mathbb{Z}_2\).

From the group object we can extract the character table and other group-theoretic properties:

group.character_table()
array([[ 1.,  1.],
       [ 1., -1.]])

Next, we define the operators \(\hat I\) and \(\hat \sigma^x\), which furnish the representation of our group on the spin Hilbert space.

spin_flip = nk.operator.PauliStrings(hilbert, "X"*square_lattice.n_nodes)

identity = nk.operator.spin.identity(hilbert=hilbert) #identity operator

To check that this all works as expected, recall that the states of the computational basis of our spin \(1/2\) system are encoded as jax.Array objects:

  • array[i] refers to the spin on the \(i\)th site.

  • array[i] = -1 for spin down and +1 for spin up.

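We can therefore draw a random basis state and apply the spin-flip operator to it with the get_conn_padded method, which returns the connected configurations together with the corresponding matrix elements. The flip should map every \(-1\) entry to \(+1\) and every \(+1\) entry to \(-1\), so the element-wise sum of the original and flipped arrays should vanish.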

# Generate a random state of the basis
state = hilbert.random_state(key=seed, size=1)
new_state, matrix_element = spin_flip.get_conn_padded(state)

print("Original state: ", state)
print("State after spin flip: ", new_state)
print("Sum of element-wise entries:", new_state + state)
Original state:  [[ 1  1  1  1 -1 -1 -1 -1  1  1  1 -1 -1 -1 -1 -1]]
State after spin flip:  [[[-1 -1 -1 -1  1  1  1  1 -1 -1 -1  1  1  1  1  1]]]
Sum of element-wise entries: [[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]]

Now we create the dictionary and pass it to instantiate a LabeledRepresentation object. Note that the keys of the dictionary must be identical to the elements of the group argument.

The labels are automatically derived from the character table: "+1" for the trivial (even-parity) irrep and "-1" for the alternating (odd-parity) irrep.

representation_dict = {group[0]: identity, group[1]: spin_flip}
spin_flip_representation = LabeledRepresentation(
    group=group,
    representation_dict=representation_dict,
)
print("Irrep labels:", spin_flip_representation.irrep_labels)
Irrep labels: ['+1', '-1']
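The character table is all that is needed to build the projector onto each irrep. For a finite group with one-dimensional irreps, the textbook projector is \(\hat P_\lambda = \frac{1}{|G|}\sum_{g} \chi_\lambda(g)^* \hat U_g\); for \(\mathbb Z_2\) this gives \(\hat P_{\pm} = (\hat I \pm \hat\sigma^x)/2\). The numpy sketch below builds these matrices for 3 spins and checks that they are orthogonal projectors. It illustrates the underlying math only; we do not claim that this is how LabeledRepresentation.project is implemented internally.

```python
import numpy as np
from functools import reduce

N = 3  # small enough to use dense 2^N x 2^N matrices
sx = np.array([[0.0, 1.0], [1.0, 0.0]])
identity = np.eye(2 ** N)
flip = reduce(np.kron, [sx] * N)  # global spin flip

chars = np.array([[1.0, 1.0], [1.0, -1.0]])  # rows: "+1" and "-1" irreps
ops = [identity, flip]  # representation: e -> I, g -> sigma^x

def projector(row):
    """P = (1/|G|) sum_g conj(chi(g)) U_g for a one-dimensional irrep."""
    return sum(np.conj(c) * U for c, U in zip(row, ops)) / len(ops)

P_plus, P_minus = projector(chars[0]), projector(chars[1])

assert np.allclose(P_plus @ P_plus, P_plus)     # idempotent
assert np.allclose(P_plus @ P_minus, 0)         # orthogonal sectors
assert np.allclose(P_plus + P_minus, identity)  # complete
assert np.allclose(P_plus, (identity + flip) / 2)

# The basis state with bits 010 is paired with its flipped partner 101
state = np.zeros(2 ** N)
state[0b010] = 1.0
print(P_plus @ state)  # weight 0.5 on indices 2 (010) and 5 (101)
```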

Now we can optimize a new variational state, projected onto the even-parity ("+1") sector of the spin-flip group.

# Reinitialize the variational state using the same seed
# (sampler, model, optimizer, solver and diag_shift are defined earlier in the tutorial)
vstate = nk.vqs.MCState(sampler, model, n_samples=512, seed=seed)
# Project onto the even-parity sector
spin_flip_symmetric_vstate = spin_flip_representation.project(
    state=vstate,
    label="+1",
)

driver = VMC_SR(
    hamiltonian=H,
    optimizer=optimizer,
    linear_solver=solver,
    variational_state=spin_flip_symmetric_vstate,
    diag_shift=diag_shift,
)
driver.run(n_iter=300, out=nk.logging.RuntimeLog());
Automatic SR implementation choice:  NTK
online_statistics: chain_length=8, exponential moving average window: 50, decay=0.840
E_spin_flip = float(spin_flip_symmetric_vstate.expect(H).mean.real)
err_spin = 100 * relative_error(E_spin_flip, evals[0])

print(f"Energy with spin-flip symmetry (trivial irrep): {E_spin_flip:.4f}")
print(f"  Relative error: {err_spin:.4f} %")
Energy with spin-flip symmetry (trivial irrep): -40.6671
  Relative error: 9.4556 %

Combining representations#

When two groups \(G_1\) and \(G_2\) are represented on the same Hilbert space \(\mathcal{H}\) by operators \(\hat U_{g_1}\) and \(\hat V_{g_2}\) that commute with each other, \(\hat U_{g_1} \hat V_{g_2} = \hat V_{g_2} \hat U_{g_1}\) for all \(g_1 \in G_1\) and \(g_2 \in G_2\), we can define a product representation \(\hat \Gamma\) of the direct product group \(G_1 \times G_2\) on \(\mathcal{H}\) by \(\hat \Gamma(g_1, g_2) = \hat U_{g_1} \hat V_{g_2}\). The irreps of \(G_1 \times G_2\) are labeled by pairs \((\mu, \nu)\), and their characters factorize: \(\chi_{\mu, \nu}(g_1, g_2) = \chi_\mu(g_1) \chi_\nu(g_2)\), where \(\chi_\mu\) (resp. \(\chi_\nu\)) is the character of irrep \(\mu\) of \(G_1\) (resp. irrep \(\nu\) of \(G_2\)).
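The factorization of the characters implies that the projector onto a joint sector is the product of the two individual projectors, \(\hat P_{k, \pm} = \hat P_\pm \hat P_k = \hat P_k \hat P_\pm\). The numpy sketch below checks this on a toy model, a periodic 4-site chain with hand-built translation and spin-flip matrices (the sign convention in \(e^{-ikR}\) is an assumption, and all names are hypothetical), and counts the dimension of each of the 4 × 2 = 8 joint sectors.

```python
import numpy as np
from functools import reduce

N = 4  # toy model: periodic chain of 4 spins
dim = 2 ** N
sx = np.array([[0.0, 1.0], [1.0, 0.0]])

# Translation by one site, as a permutation matrix on the 2^N basis states
T = np.zeros((dim, dim))
for b in range(dim):
    bits = [(b >> (N - 1 - j)) & 1 for j in range(N)]
    shifted = [bits[(j - 1) % N] for j in range(N)]
    T[int("".join(map(str, shifted)), 2), b] = 1.0
X = reduce(np.kron, [sx] * N)  # global spin flip
assert np.allclose(T @ X, X @ T)  # the two representations commute

def P_trans(k):
    """Momentum projector (1/N) sum_R exp(-i k R) T^R (sign convention assumed)."""
    return sum(np.exp(-1j * k * R) * np.linalg.matrix_power(T, R) for R in range(N)) / N

def P_flip(parity):
    """Spin-flip projector (I + parity * X) / 2 for parity = +1 or -1."""
    return (np.eye(dim) + parity * X) / 2

def P_product(k, parity):
    """Projector of Z_N x Z_2 built from the product characters chi_k(R) * parity**s."""
    total = np.zeros((dim, dim), dtype=complex)
    for R in range(N):
        for s in range(2):
            chi = np.exp(1j * k * R) * parity ** s
            total += np.conj(chi) * (np.linalg.matrix_power(T, R) @ np.linalg.matrix_power(X, s))
    return total / (2 * N)

sector_dims = {}
for m in range(N):
    k = 2 * np.pi * m / N
    for parity in (+1, -1):
        P = P_product(k, parity)
        assert np.allclose(P, P_flip(parity) @ P_trans(k))
        assert np.allclose(P, P_trans(k) @ P_flip(parity))
        sector_dims[(m, parity)] = int(round(np.trace(P).real))

print(sector_dims)
print(sum(sector_dims.values()))  # 16: the 8 joint sectors cover the whole space
```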

This applies to the translation group and the spin-flip group, because permuting the lattice sites commutes with flipping every spin. We can therefore classify the eigenstates simultaneously by their momentum and their spin-flip parity. To do this in NetKet, we project the state onto an irrep of one group and then onto an irrep of the other; since the two projectors commute, the order does not matter.

vstate = nk.vqs.MCState(sampler, model, n_samples=512, seed=seed)
# Projector onto the Γ-point sector (k=(0,0))
projector_T = translation_group_representation.projector(k=(0.0, 0.0))
# Projector onto the even-parity sector
projector_S = spin_flip_representation.projector(label="+1")

# Combine the two projectors
projector_ST = projector_S @ projector_T
# Apply the combined projector to the vstate
trans_and_spin_symmetric_vstate = nk.vqs.apply_operator(projector_ST, vstate)
driver = VMC_SR(
    hamiltonian=H,
    optimizer=optimizer,
    linear_solver=solver,
    variational_state=trans_and_spin_symmetric_vstate,
    diag_shift=diag_shift,
)
# Only 150 iterations, since each step with the combined projector is more expensive
driver.run(n_iter=150, out=nk.logging.RuntimeLog())
Automatic SR implementation choice:  NTK
online_statistics: chain_length=8, exponential moving average window: 50, decay=0.840
(RuntimeLog():
  keys = ['Energy', 'acceptance', 'wallclock', 'Energy_ema'],)
E_product = float(trans_and_spin_symmetric_vstate.expect(H).mean.real)

results = [
    ("without symmetrization", E_no_symm),
    ("with spin-flip symmetry (trivial irrep)", E_spin_flip),
    ("with translation symmetry (trivial irrep)", E_symm),
    ("with product symmetry (trivial irrep)", E_product),
]

for label, energy in results:
    err = 100 * relative_error(energy, evals[0])
    print(f"Energy {label:35s}: {energy:8.4f}   Relative error: {err:6.4f} %")
Energy without symmetrization             : -42.5803   Relative error: 5.1958 %
Energy with spin-flip symmetry (trivial irrep): -40.6671   Relative error: 9.4556 %
Energy with translation symmetry (trivial irrep): -44.7707   Relative error: 0.3189 %
Energy with product symmetry (trivial irrep): -44.8681   Relative error: 0.1020 %

Extension to fermionic systems#

The concepts we’ve discussed so far naturally extend to fermionic systems, but there are two subtleties to bear in mind:

  • On top of “spatial” degrees of freedom, fermionic Hilbert spaces may have additional degrees of freedom, like spin.

  • Fermionic states are antisymmetric with respect to particle exchange.

In NetKet, fermionic Hilbert spaces are handled by the SpinOrbitalFermions class. Let’s define a Hilbert space of spin-\(1/2\) fermions on the square lattice.

fermion_hilbert = nk.hilbert.SpinOrbitalFermions(
    n_orbitals=square_lattice.n_nodes, 
    s=1/2,
    n_fermions_per_spin=(8, 8),  # half filling: 16 fermions on 16 sites, S^z = 0
)
print("Size of basis states: ", fermion_hilbert.size)
Size of basis states:  32

So we need 32 occupation numbers to specify a state on the 16-site lattice: one for each site in each of the 2 spin sectors.

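This is because SpinOrbitalFermions orders the occupation numbers by spin sector: the first n_orbitals entries are the occupations of spin sector 0, the next n_orbitals those of spin sector 1, and so on. A permutation of the physical sites must therefore be extended to act on all hilbert.size entries, by applying the same site permutation within each spin sector.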

Let’s do a quick example with the translation operator.

T1_perm = translation_group[1].permutation_array
T1_perm_extended = np.concatenate(
    [T1_perm, T1_perm + fermion_hilbert.n_orbitals]
) #extend the permutation to the second spin sector
print("Extended permutation corresponding to a translation by R= [0,1]: ", T1_perm_extended)
Extended permutation corresponding to a translation by R= [0,1]:  [ 1  2  3  0  5  6  7  4  9 10 11  8 13 14 15 12 17 18 19 16 21 22 23 20
 25 26 27 24 29 30 31 28]
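The same construction works for any number of spin sectors. Below is a hypothetical helper (not a NetKet function) that, assuming the sector-by-sector ordering of SpinOrbitalFermions, repeats a site permutation in every spin sector. For two sectors it reproduces the extended array printed above; for spin-\(3/2\) fermions it would produce four shifted copies.

```python
import numpy as np

def extend_site_permutation(site_perm, n_orbitals, n_spin_sectors):
    """Repeat a site permutation in every spin sector.

    Assumes entries [s * n_orbitals, (s + 1) * n_orbitals) of the occupation
    vector belong to spin sector s, as in SpinOrbitalFermions.
    """
    site_perm = np.asarray(site_perm)
    return np.concatenate([site_perm + s * n_orbitals for s in range(n_spin_sectors)])

# Translation by [0, 1] on the 4x4 lattice, as printed earlier in this tutorial
T1_site_perm = np.array([1, 2, 3, 0, 5, 6, 7, 4, 9, 10, 11, 8, 13, 14, 15, 12])

ext_spin_half = extend_site_permutation(T1_site_perm, n_orbitals=16, n_spin_sectors=2)
ext_spin_three_halves = extend_site_permutation(T1_site_perm, n_orbitals=16, n_spin_sectors=4)
print(ext_spin_half)
print(ext_spin_three_halves.shape)  # (64,)
```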

On a fermionic Hilbert space, the representation \(\hat U\) of a permutation group \(G\) must include a sign: for all \(g\in G\) and occupation-number basis states \(|n\rangle\),

\[\hat U_g |n\rangle = \xi_g(n) |n \circ g^{-1} \rangle \]

where \(\xi_g(n) = \pm 1\) is the sign picked up when the permuted creation operators that build \(|n\rangle\) are reordered back into canonical order. This sign is ultimately a consequence of the anticommutation relations obeyed by the fermionic operators. More details can be found in the symmetry documentation.
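To make the sign concrete, here is a minimal sketch for spinless occupation vectors. It assumes that \(|n\rangle\) is built by applying creation operators in ascending orbital order and that the permutation moves the fermion on orbital \(i\) to orbital perm[i], i.e. \(\hat c^\dagger_i \to \hat c^\dagger_{\text{perm}[i]}\); the sign is then the parity of the number of inversions needed to restore ascending order. NetKet's internal convention (for example, whether it uses perm or its inverse) may differ, so this hypothetical helper is not guaranteed to agree element by element with PermutationOperatorFermion.

```python
import numpy as np

def permute_fermion_state(n, perm):
    """Apply a site permutation to an occupation vector; return (new_n, sign).

    Hypothetical helper. Assumed convention: c^dag_i -> c^dag_{perm[i]}, with
    basis states built from creation operators in ascending orbital order.
    """
    n = np.asarray(n)
    perm = np.asarray(perm)
    occupied = np.flatnonzero(n)   # i_1 < i_2 < ... < i_k
    targets = perm[occupied]       # perm[i_1], ..., perm[i_k]
    # Each transposition needed to sort the creation operators flips the sign
    n_inversions = sum(
        1
        for a in range(len(targets))
        for b in range(a + 1, len(targets))
        if targets[a] > targets[b]
    )
    new_n = np.zeros_like(n)
    new_n[targets] = 1
    return new_n, (-1) ** n_inversions

translation = np.array([1, 2, 3, 0])  # i -> i + 1 (mod 4) on a 4-site ring
new_n, sign = permute_fermion_state([1, 0, 0, 1], translation)
print(new_n, sign)  # [1 1 0 0] -1: the fermion on site 3 wraps around to site 0
new_n, sign = permute_fermion_state([1, 1, 0, 0], translation)
print(new_n, sign)  # [0 1 1 0] 1: the order of the fermions is preserved
```

The first case shows the typical origin of the \(-1\) matrix elements under translations: a fermion crossing the periodic boundary passes an odd number of other fermions.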

The PermutationOperatorFermion class is a good tool to handle representations of permutation groups on fermionic spaces. This class works similarly to PermutationOperator but its get_conn_padded method also computes the sign.

from netket.operator.permutation import PermutationOperatorFermion

T1_fermion_perm = Permutation(permutation_array=T1_perm_extended, name="T[0,1]")
T1_fermion_op = PermutationOperatorFermion(fermion_hilbert, T1_fermion_perm)

# Apply the operator to a batch of 5 random basis states
states = fermion_hilbert.random_state(key=seed, size=5)
new_states, matrix_elements = T1_fermion_op.get_conn_padded(states)
print("Matrix elements: ", matrix_elements)
Matrix elements:  [[ 1.]
 [-1.]
 [ 1.]
 [-1.]
 [-1.]]

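As you can see, some of the matrix elements are \(-1\): these are the fermionic signs \(\xi_g(n)\), which do not appear for the spin permutations used earlier.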

Fortunately, for lattice permutation symmetries, NetKet can automatically construct Representation objects on fermionic spaces, without needing to manually extend the permutations.

This is done by passing a SpinOrbitalFermions Hilbert space into nk.symmetry.canonical_representation(hilbert, group) where group can be obtained from methods like translation_group(), rotation_group(), point_group() or space_group().

fermi_translation_group_rep = nk.symmetry.canonical_representation(
    hilbert=fermion_hilbert,
    group=square_lattice.translation_group()
)
T01 = fermi_translation_group_rep.operators[1]

states_2, matrix_elements_2 = T01.get_conn_padded(states)

if np.all(matrix_elements == matrix_elements_2):
    print("The matrix elements are the same!")

if np.all(new_states == states_2):
    print("The new states are the same!")
The matrix elements are the same!
The new states are the same!