Iterative Symmetrization: Coset Filters for Translation Groups#
Projecting a variational state onto a momentum sector is a common procedure that relies on the Bloch projector
\[ P_k = \frac{1}{|G|}\sum_{n} e^{-ikn}\, T^n, \]
where the sum runs over the translations \(T^n\) in the symmetry group \(G\).
Applying this projector to a state increases its cost by a factor \(|G|\): the variational state has to be re-evaluated \(|G|\) times every time we compute an amplitude. This becomes expensive for large systems, where \(|G|\) grows with the number of sites.
This tutorial discusses two situations where this cost can be avoided or attenuated:
Incremental symmetrization: start with a coarse subgroup \(H \leq G\) containing fewer elements, and gradually add the remaining symmetry as training progresses.
Equivariant networks: the architecture is already equivariant under a subgroup \(H\). Only a small number of symmetries needs to be summed explicitly, namely one representative per coset in \(G/H\) (\(|G|/|H|\) in total).
In this notebook we discuss translation symmetry, but the same ideas apply to any Abelian group.
This tutorial assumes that you are familiar with NetKet's handling of symmetries and their representations. If you are not, see the NetKet symmetry tutorial for a general introduction to symmetry-adapted variational states.
import numpy as np
import netket as nk
np.random.seed(42)
L = 8
lattice = nk.graph.Chain(L, pbc=True)
hi = nk.hilbert.Spin(0.5, L)
Translation groups and momentum projectors#
The translation group \(\mathbb{Z}_L\) of a 1D periodic chain contains cyclic shifts \(\{T^0, T^1, \ldots, T^{L-1}\}\), where \(T^n\) moves every spin \(n\) sites to the right. Its irreducible representations are labelled by the Bloch momentum \(k = 2\pi m / L\), \(m \in \{0, \ldots, L-1\}\).
NetKet represents the translation symmetry as a group object built from the lattice with lattice.translation_group(). For the chain above, this group has \(L\) elements.
It is also possible to define the (smaller) strided subgroup, which contains only every \(s\)-th translation, via lattice.translation_group(strides=s). This gives a group of size \(L/s\) with \(L/s\) distinct momenta.
A larger stride gives a smaller group (a coarser symmetry), whose projector is correspondingly cheaper.
NetKet encodes the operator representation of the translation symmetry in a TranslationRepresentation, built via nk.symmetry.canonical_representation:
T = nk.symmetry.canonical_representation(hi, lattice.translation_group())
The translation representation class exposes, among other things, the following attribute and method:
.k_points: array of valid Bloch momenta for this group.
.projector(k): builds \(P_k = \frac{1}{|G|}\sum_{n} e^{-ikn}\, T^n\) as a SumOperator with \(|G|\) terms; its cost is \(\propto |G|\).
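The projector formula can be checked without any NetKet-specific classes. The following NumPy sketch builds \(P_k\) as a dense \(2^L \times 2^L\) matrix from explicit translation permutations, then checks that it is an orthogonal projector. The helper names translation_matrix and bloch_projector are assumptions of this sketch and not part of NetKet's API. So is the shift convention, in which \(T^n\) maps a configuration x to np.roll(x, n).

```python
import itertools
import numpy as np

def translation_matrix(L, n):
    """Permutation matrix of T^n on the 2^L spin-1/2 basis.

    Assumed convention: T^n sends the configuration x to np.roll(x, n),
    i.e. every spin moves n sites to the right (periodic boundaries).
    """
    states = list(itertools.product([0, 1], repeat=L))
    index = {s: a for a, s in enumerate(states)}
    T = np.zeros((2**L, 2**L))
    for a, s in enumerate(states):
        T[index[tuple(int(v) for v in np.roll(s, n))], a] = 1.0
    return T

def bloch_projector(L, k, stride=1):
    """P_k = (1/|G|) sum_{n in G} e^{-ikn} T^n for G = {0, stride, 2*stride, ...}."""
    shifts = range(0, L, stride)
    P = sum(np.exp(-1j * k * n) * translation_matrix(L, n) for n in shifts)
    return P / len(shifts)

L = 8
P0 = bloch_projector(L, 0.0)
print(np.allclose(P0 @ P0, P0))      # idempotent: True
print(np.allclose(P0, P0.conj().T))  # Hermitian: True
print(int(round(np.trace(P0).real))) # dimension of the k = 0 sector: 36
```

Summing \(P_k\) over all \(L\) valid momenta gives the identity, because every state decomposes uniquely into momentum sectors.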
In the following, we build the representation for the full translation group T1 and two subgroups T2 and T4 and print some information about those objects.
T1 = nk.symmetry.canonical_representation(hi, lattice.translation_group()) # all 8
T2 = nk.symmetry.canonical_representation(hi, lattice.translation_group(strides=2)) # every 2nd
T4 = nk.symmetry.canonical_representation(hi, lattice.translation_group(strides=4)) # every 4th
print(f"T1 |G|={len(T1.group.elems)} k/π: {np.round(T1.k_points.ravel() / np.pi, 2)}")
print(f"T2 |G|={len(T2.group.elems)} k/π: {np.round(T2.k_points.ravel() / np.pi, 2)}")
print(f"T4 |G|={len(T4.group.elems)} k/π: {np.round(T4.k_points.ravel() / np.pi, 2)}")
T1 |G|=8 k/π: [ 0. 0.25 0.5 0.75 -1. -0.75 -0.5 -0.25]
T2 |G|=4 k/π: [ 0. 0.25 -0.5 -0.25]
T4 |G|=2 k/π: [ 0. -0.25]
The projectors are built by calling Representation.projector(k=value), and it is possible to verify that larger groups have more terms:
P1 = T1.projector(k=0.0)
print(f"P_T1(k=0): {len(P1.operators)} terms")
P2 = T2.projector(k=0.0)
print(f"P_T2(k=0): {len(P2.operators)} terms")
P4 = T4.projector(k=0.0)
print(f"P_T4(k=0): {len(P4.operators)} terms")
P_T1(k=0): 8 terms
P_T2(k=0): 4 terms
P_T4(k=0): 2 terms
The k-points for T2 and T4 are a subset of those for T1. A group with stride \(s\) has a reduced Brillouin zone of width \(2\pi/s\), so it can only resolve momenta modulo \(2\pi/s\).
For example, T4 (stride 4, 2 elements) has only \(k = 0\) and \(k = -\pi/4\).
A strided projector \(P_{T_s}(k=0)\) therefore does not uniquely select the \(k=0\) sector. Because \(T_s\) contains only every \(s\)-th translation, its projector only enforces \(e^{iks}=1\), and \(s\) different momenta of the full group satisfy this condition. For example, \(P_{T_4}(k=0)\) projects onto the direct sum of the four momentum sectors \(k \in \{0, \pi/2, \pi, -\pi/2\}\). Applying coset refinement filters progressively removes the unwanted sectors until only the \(k=0\) sector remains.
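The following NumPy sketch checks this statement numerically on the 8-site chain: the stride-4 projector at \(k=0\) equals the sum of the four full-group projectors listed above. It rebuilds the dense translation matrices by hand, so it does not rely on NetKet's representation classes. The helpers T and projector are hypothetical, and the shift convention (np.roll) is an assumption.

```python
import itertools
import numpy as np

L = 8
# Spin configurations of the chain and a lookup table from configuration to index.
states = list(itertools.product([0, 1], repeat=L))
index = {s: a for a, s in enumerate(states)}

def T(n):
    """Permutation matrix of the shift x -> np.roll(x, n) (assumed convention)."""
    M = np.zeros((len(states), len(states)))
    for a, s in enumerate(states):
        M[index[tuple(int(v) for v in np.roll(s, n))], a] = 1.0
    return M

def projector(k, stride=1):
    """(1/|G|) * sum_n e^{-ikn} T^n over the strided group n = 0, stride, 2*stride, ..."""
    shifts = range(0, L, stride)
    return sum(np.exp(-1j * k * n) * T(n) for n in shifts) / len(shifts)

# P_{T_4}(k=0) only enforces e^{4ik} = 1, so it keeps four momentum sectors of T_1.
P4 = projector(0.0, stride=4)
sectors = [0.0, np.pi / 2, np.pi, -np.pi / 2]
print(np.allclose(P4, sum(projector(k) for k in sectors)))  # True

# Its range is much larger than the k = 0 sector of the full group.
print(int(round(np.trace(P4).real)), int(round(np.trace(projector(0.0)).real)))  # 136 36
```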
Coset filters#
Given a subgroup \(H \leq G\), the left cosets \(cH\) partition \(G\) into \(|G|/|H|\) disjoint classes. Picking one representative \(c\) per class defines a set \(C \subset G\) with \(|C| = |G|/|H|\), so that \(G = \bigsqcup_{c \in C} c\,H\).
The coset Fourier filter is
\[ F_{G/H}(k) = \frac{1}{|C|}\sum_{c \in C} e^{-ikc}\, T^c, \]
where only the \(|C|\) translations \(T^c\) appear instead of all \(|G|\) group elements. This filter satisfies the identity
\[ P_G(k) = F_{G/H}(k)\, P_H(k). \]
In words, the projector over the full group \(G\) factors into a projector onto the subgroup \(H\), left-multiplied by a coset Fourier filter that contains fewer operators. The \(|G|\)-term projector is thus split into two cheaper steps: \(P_H\) with \(|H|\) terms, then \(F_{G/H}\) with \(|G/H|\) terms. The total cost is unchanged (\(|H| \cdot |G/H| = |G|\) evaluations). However, the factorization makes it possible to add Fourier filters iteratively, or to "refine" an NQS model that is only partially symmetrized, such as a patched transformer.
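The identity, including the \(1/|C|\) normalization of the filter, follows from two facts. First, each translation can be written uniquely as \(n = c + h\) with \(c \in C\) and \(h \in H\); in additive notation, \(T^{c+h} = T^c T^h\). Second, \(|G| = |C|\,|H|\). Together these give

\[
P_G(k) = \frac{1}{|G|}\sum_{c\in C}\sum_{h\in H} e^{-ik(c+h)}\,T^{c}T^{h}
= \Big(\frac{1}{|C|}\sum_{c\in C} e^{-ikc}\,T^{c}\Big)\Big(\frac{1}{|H|}\sum_{h\in H} e^{-ikh}\,T^{h}\Big)
= F_{G/H}(k)\,P_H(k).
\]

For \(G = T_1\) and \(H = T_2\) on the 8-site chain, \(C = \{0, 1\}\) and the filter reduces to \(F_{T_1/T_2}(k) = \tfrac{1}{2}\big(T^0 + e^{-ik}T^1\big)\).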
In NetKet, you first build a TranslationCosetFilter from the full-group and subgroup representations, and then build the refinement operator from it:
C = full_rep.coset_filter(sub_rep)
F = C.projector_refinement(k=k)
In the following, we consider the full group T1 with 8 translations and its subgroup T2 with only 4 translations.
We build the coset filter for T1/T2:
print(f"T1 has {len(T1.group)} elements")
print(f"T2 has {len(T2.group)} elements")
C = T1.coset_filter(T2)
print(C)
print(f"C has {C.n_coset_reps} coset representatives")
T1 has 8 elements
T2 has 4 elements
TranslationCosetFilter(size=2 (8/4), full_group=8 (strides=[1]), sub_group=4 (strides=[2]))
C has 2 coset representatives
We can now verify the factorization. The T1 projector at wave vector k coincides with the product of the T2 projector and the coset filter:
# k = -pi/4 is a valid k-point for T1, T2, and T4
k = -np.pi / 4
P_full = T1.projector(k=k)
P_sub = T2.projector(k=k)
F = C.projector_refinement(k=k)
P_composed = F @ P_sub
print(f"\nAt k = -π/4:")
print(f"P_T1 terms: {len(P_full.operators)}")
print(f"P_T2 terms: {len(P_sub.operators)}")
print(f"Filter terms: {len(F.operators)}")
print(f"F @ P_T2 terms: {len(P_composed.operators)}")
np.testing.assert_allclose(P_full.to_dense(), P_composed.to_dense(), atol=1e-15)
At k = -π/4:
P_T1 terms: 8
P_T2 terms: 4
Filter terms: 2
F @ P_T2 terms: 8
The two routes give identical results. Applying \(P_{T_2}\) first (4 terms) and then the coset filter (2 terms) is equivalent to the direct 8-term projector. In particular, if a state already lies in the corresponding \(T_2\) momentum sector, only the 2-term coset filter needs to be applied to obtain full \(T_1\) symmetry.
Iterative refinement: \(T_1 \supset T_2 \supset T_4\)#
The coset identity can be applied recursively to decompose a projector into several coset refinement operators. These operators are not projectors in general. Each one acts as a projector only on states that already lie in the momentum sector of the corresponding subgroup. With the sequence of translation groups \(T_1 \supset T_2 \supset T_4\) (strides 1, 2, 4):
| Factor | Translations summed | Terms | Role |
|---|---|---|---|
| \(P_{T_4}\) | \(\{T^0, T^4\}\) | 2 | base projector |
| \(F_{T_2/T_4}\) | \(\{T^0, T^2\}\) | 2 | refinement step 1 |
| \(F_{T_1/T_2}\) | \(\{T^0, T^1\}\) | 2 | refinement step 2 |
Each factor is a 2-term operator. Sequentially applying three 2-term operators is equivalent to the single 8-term \(P_{T_1}\).
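Write the filters with the normalization \(F_{G/H}(k) = \frac{1}{|C|}\sum_{c\in C} e^{-ikc}T^c\). The three 2-term factors then multiply to the full projector, because every \(n \in \{0,\dots,7\}\) has a unique binary expansion \(n = n_0 + 2n_1 + 4n_2\):

\[
\begin{aligned}
F_{T_1/T_2}(k)\,F_{T_2/T_4}(k)\,P_{T_4}(k)
&= \tfrac{1}{2}\big(T^0 + e^{-ik}T^1\big)\,\tfrac{1}{2}\big(T^0 + e^{-2ik}T^2\big)\,\tfrac{1}{2}\big(T^0 + e^{-4ik}T^4\big)\\
&= \frac{1}{8}\sum_{n_0,n_1,n_2\in\{0,1\}} e^{-ik(n_0+2n_1+4n_2)}\,T^{\,n_0+2n_1+4n_2}
= \frac{1}{8}\sum_{n=0}^{7} e^{-ikn}\,T^n = P_{T_1}(k).
\end{aligned}
\]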
A common recipe for NQS training is the following:
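Apply only \(P_{T_4}\) (2 terms): the cheapest option, with the weakest symmetry. Train.
Prepend \(F_{T_2/T_4}\) (a 2-term filter, 4 evaluations per amplitude in total): the state now has \(T_2\) symmetry. Train.
Prepend \(F_{T_1/T_2}\) (a 2-term filter, 8 evaluations per amplitude in total): the state now has full \(T_1\) symmetry. Fine-tune.
Each step only adds a 2-term factor on top of the already-symmetrized state. The full 8-evaluation cost is therefore paid only in the final fine-tuning phase.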
C12 = T1.coset_filter(T2) # T1 / T2: reps {T^0, T^1}
C24 = T2.coset_filter(T4) # T2 / T4: reps {T^0, T^2}
for name, C in [("T1/T2", C12), ("T2/T4", C24)]:
print(f"{name}: {C.n_coset_reps} coset representatives")
# Apply the three-step chain at k=0 (Γ-point, valid for all three groups)
psi = np.random.randn(hi.n_states) + 1j * np.random.randn(hi.n_states)
k = 0.0
step0 = T4.projector(k=k).to_dense() @ psi # project onto T4 sector (2 terms)
step1 = C24.projector_refinement(k=k).to_dense() @ step0 # refine to T2 (2 terms)
step2 = C12.projector_refinement(k=k).to_dense() @ step1 # refine to T1 (2 terms)
direct = T1.projector(k=k).to_dense() @ psi
print(f"\nIterative == direct: {np.allclose(step2, direct, atol=1e-12)}")
print(f"Direct P_T1: {len(T1.projector(k=k).operators)} terms")
Iterative == direct: True
Direct P_T1: 8 terms
The three-step training recipe in code. Each apply_operator call wraps the existing state with one more symmetry level:
hamiltonian = nk.operator.Heisenberg(hi, lattice)
sampler = nk.sampler.MetropolisLocal(hi)
vstate = nk.vqs.MCState(sampler, nk.models.RBM(), n_samples=512)
vstate = nk.vqs.apply_operator(T4.projector(k=0.0), vstate) # T4 symmetry (2 terms)
# ... train ...
vstate = nk.vqs.apply_operator(C24.projector_refinement(k=0.0), vstate) # T2 symmetry (4 terms)
# ... train ...
vstate = nk.vqs.apply_operator(C12.projector_refinement(k=0.0), vstate) # T1 symmetry (8 terms)
# ... fine-tune ...
Application: patch-based neural networks#
It is common for NN architectures to split the input into patches and process each patch with a shared encoder before aggregating across patches. See for example the patched transformer tutorial. For a chain of \(L\) sites with patch size \(s\), there are \(P = L/s\) patches.
Built-in symmetry: ViT-based NQS are usually built to be equivariant under permutations of the patch positions, which include the inter-patch translations. The model therefore already has \(T_s\)-symmetry, i.e. symmetry under translations by multiples of \(s\); this is \(T_2\) for the patch size \(s = 2\) used below. Enforcing full \(T_1\)-symmetry then only requires summing over the coset \(T_1/T_s\): just \(s\) forward passes instead of \(L\).
The cell below shows a minimal patch NQS that uses nk.nn.DenseSymm to enforce equivariance over the \(P\) patch positions.
DenseSymm maps an input of shape \((\ldots, d_\mathrm{in}, P)\) to an output of shape \((\ldots, d_\mathrm{out}, |G|)\), equivariantly under a permutation group \(G\) acting on the \(P\) patch indices. For the cyclic group of patch translations, \(|G| = P\).
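For a cyclic group of patch translations, a translation-equivariant dense layer amounts to a circular cross-correlation over the patch positions. The NumPy sketch below shows the property that PatchNQS relies on: shifting the input by one patch shifts the output by one patch, so summing the output over its last two axes gives a \(T_2\)-invariant number. The function cyclic_equivariant_dense is hypothetical and is not NetKet's DenseSymm implementation.

```python
import numpy as np

def cyclic_equivariant_dense(x, W):
    """Translation-equivariant dense layer over P cyclic positions (sketch).

    x: array of shape (..., d_in, P); W: array of shape (d_out, d_in, P).
    Returns y of shape (..., d_out, P) with
        y[..., a, g] = sum_{b, j} W[a, b, (j - g) % P] * x[..., b, j],
    i.e. one shared kernel evaluated relative to each output position g.
    """
    P = x.shape[-1]
    pos = np.arange(P)
    # kernel[a, b, g, j] = W[a, b, (j - g) % P]
    kernel = W[:, :, (pos[None, :] - pos[:, None]) % P]
    return np.einsum("abgj,...bj->...ag", kernel, x)

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 16, 4))  # (batch, features, n_patches), as in PatchNQS
W = rng.normal(size=(1, 16, 4))  # a single output feature
y = cyclic_equivariant_dense(x, W)
y_shifted = cyclic_equivariant_dense(np.roll(x, 1, axis=-1), W)
print(np.allclose(y_shifted, np.roll(y, 1, axis=-1)))                    # equivariance: True
print(np.allclose(y.sum(axis=(-1, -2)), y_shifted.sum(axis=(-1, -2))))  # invariance: True
```

The final sum over the output axes, as in the last line of PatchNQS.__call__, is what turns this equivariance into invariance.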
import flax.linen as nn
import jax.numpy as jnp
patch_size = 2
n_patches = L // patch_size # 4 patches for L=8
patch_lattice = nk.graph.Chain(n_patches, pbc=True)
class PatchNQS(nn.Module):
"""Patch NQS equivariant under inter-patch translations (T2, stride=patch_size)."""
features: int = 16
@nn.compact
def __call__(self, x):
# Shared patch embedding
x = x.reshape(*x.shape[:-1], n_patches, patch_size) # (..., 4, 2)
h = nn.Dense(self.features)(x)
h = jnp.tanh(h)
# Equivariant aggregation over patch positions via DenseSymm
h = jnp.swapaxes(h, -1, -2) # (..., features, 4)
h = nk.nn.DenseSymm(patch_lattice.translation_group(), features=1)(h) # (..., 1, 4)
return jnp.sum(h, axis=(-1, -2)) # (batch,)
# DenseSymm enforces equivariance under the 4 inter-patch translations (T2, stride 2).
# For full T1 symmetry, only the coset T1/T2 = {T^0, T^1} needs explicit summation.
C = T1.coset_filter(T2)
print(f"Built-in T2 symmetry : {len(T2.group.elems)} inter-patch translations")
print(f"Coset T1/T2 : {C.n_coset_reps} intra-patch passes for full T1")
print(f"Speed-up vs direct T1: ×{len(T1.group.elems) // C.n_coset_reps}")
sampler = nk.sampler.MetropolisLocal(hi)
vstate = nk.vqs.MCState(sampler, PatchNQS(), n_samples=512)
print(f"\n{vstate}")
Built-in T2 symmetry : 4 inter-patch translations
Coset T1/T2 : 2 intra-patch passes for full T1
Speed-up vs direct T1: ×4
MCState(hilbert = Spin(s=1/2, N=8, ordering=new), sampler = MetropolisSampler(rule = LocalRule(), n_chains = 64, sweep_size = 8, reset_chains = False, machine_power = 2, dtype = int8), n_samples = 512)
To make the variational state fully \(T_1\)-symmetric at a fixed momentum \(k\), one wraps PatchNQS with a model that sums over the two configurations \(\{x, T^1 x\}\), related by a one-site (intra-patch) shift, with Bloch phase weights \(e^{-ikn}\), \(n \in \{0, 1\}\).
This is exactly what a SymmetrizedNQS(group=C, characters=...) would do, at a cost of 2 forward passes compared to 8 for a naive full-group sum.
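The coset wrapper can be sketched for a single configuration in plain NumPy. Here psi_patch is a hypothetical toy amplitude that mimics the \(T_2\) invariance of PatchNQS. The helper coset_symmetrize is likewise a hypothetical stand-in for a SymmetrizedNQS-style wrapper. The convention that \(T^{c}\) acts on amplitudes as \(\psi \mapsto \psi(\)np.roll(x, -c)\()\) is an assumption. The sketch compares the 2-term coset sum with the naive 8-term sum over the whole group.

```python
import numpy as np

rng = np.random.default_rng(1)
L, s = 8, 2                      # chain length and patch size
w, b = rng.normal(size=s), 0.3   # hypothetical parameters of the toy model

def psi_patch(x):
    """Toy amplitude built from one shared function per patch of s sites.

    Shifting x by s sites only permutes the patches, so psi_patch is
    invariant under the subgroup T_s (here T_2), like PatchNQS.
    """
    patches = np.asarray(x).reshape(L // s, s)
    return np.exp(np.sum(np.tanh(patches @ w + b)))

def coset_symmetrize(psi, k, coset_reps):
    """Return x -> (1/|C|) sum_{c in C} e^{-ikc} psi(np.roll(x, -c))."""
    def psi_sym(x):
        terms = [np.exp(-1j * k * c) * psi(np.roll(x, -c)) for c in coset_reps]
        return sum(terms) / len(coset_reps)
    return psi_sym

x = rng.choice([-1.0, 1.0], size=L)
psi_k0 = coset_symmetrize(psi_patch, 0.0, range(s))     # s = 2 evaluations
psi_naive = coset_symmetrize(psi_patch, 0.0, range(L))  # full group: L = 8 evaluations
print(np.isclose(psi_k0(x), psi_naive(x)))                                   # True
print(all(np.isclose(psi_k0(np.roll(x, n)), psi_k0(x)) for n in range(L)))  # True
```

Because psi_patch is already invariant under shifts by two sites, the eight terms of the naive sum take only two distinct values. This is why the cheap coset sum gives the same amplitude.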
The coset structure thus provides a principled way to split the symmetrization cost between the model architecture (equivariance, cheap) and explicit sum-symmetrization (coset filter, expensive). Typically, you'd use a model that is equivariant under a group whose size grows with the system, and you'd symmetrize explicitly only over a coset of small, constant size.
The typical example is the patched transformer. It is equivariant under translations among patches, a group whose size grows with \(L\), so only the coset of size \(s\) must be summed explicitly; for \(s = 2\) its representatives are \(T^0\) and \(T^1\).
A practical recipe for ViT-style models is as follows:
Simple training: train PatchNQS freely, since the architecture already enforces T2 symmetry.
Apply the coset filter: wrap the trained model so that every forward pass sums over the T1/T2 coset representatives with Bloch phases, projecting onto k=0.
Fine-tune the wrapped model with a smaller learning rate. Each amplitude now requires two forward passes, so this phase costs about twice as much per sample.
In the second step the coset filter is applied on top of the pre-trained state, so the trained weights are kept and fine-tuning starts from them.
import jax
hamiltonian = nk.operator.Heisenberg(hi, lattice)
sampler = nk.sampler.MetropolisLocal(hi)
# --- Phase 1: pre-train without explicit T1 projection ---
vstate = nk.vqs.MCState(sampler, PatchNQS(), n_samples=512)
gs1 = nk.VMC(hamiltonian, nk.optimizer.Adam(0.01), variational_state=vstate)
log1 = gs1.run(n_iter=300, out="run_phase1")
print("Phase 1 energy:", vstate.expect(hamiltonian))
# --- Phase 2: wrap with the T1/T2 coset filter and fine-tune ---
# C = T1.coset_filter(T2) was built in the previous section
k = 0.0
vstate2 = nk.vqs.apply_operator(C.projector_refinement(k=k), vstate)  # apply the k=0 coset filter to the pre-trained state
gs2 = nk.VMC(hamiltonian, nk.optimizer.Adam(0.001), variational_state=vstate2)
log2 = gs2.run(n_iter=100, out="run_phase2")
print("Phase 2 energy:", vstate2.expect(hamiltonian))