netket.nn.DenseEquivariant
netket.nn.DenseEquivariant(symmetries, features=None, mode='auto', shape=None, point_group=None, in_features=None, mask=None, **kwargs)
A group convolution operation that is equivariant over a symmetry group.
Acts on a feature map of symmetry poses of shape [num_samples, in_features, num_symm] and returns a feature map of poses of shape [num_samples, features, num_symm].
G-convolutions are described in Cohen et al. and applied to quantum many-body problems in Roth et al.
The G-convolution generalizes the convolution to non-commuting groups:
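\[C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h\]
Here \(i\) labels the output feature and the dot contracts over input features. Pairs of group elements that differ by the same symmetry operation (i.e. \(h = gx\) and \(h' = g'x\), so that \(g^{-1}h = g'^{-1}h' = x\)) are connected by the same filter \({\bf W}_x\).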
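The formula above can be checked numerically with a small NumPy sketch (a stand-in for NetKet's own JAX code, not its implementation). It assumes group elements are permutations of sites composed as (a∘b)[i] = a[b[i]], builds a product table P[g, h] = index of g⁻¹h, gathers the kernel into a dense matrix (roughly what a "matrix" strategy amounts to) and checks that translating the input by any group element u translates the output by u. The helpers `compose`, `inverse`, `product_table`, `group_conv` and `left_translate` are hypothetical illustrations, not NetKet API; the group used is S_3, which is non-abelian.

```python
import itertools

import numpy as np


def compose(a, b):
    """Compose permutations: (a o b)[i] = a[b[i]]."""
    return a[b]


def inverse(a):
    """Inverse permutation, so that compose(inverse(a), a) is the identity."""
    return np.argsort(a)


def index_map(elements):
    """Map each permutation (as a tuple) to its position in the element list."""
    return {tuple(e): k for k, e in enumerate(elements)}


def product_table(elements):
    """P[g, h] = index of g^{-1} h; the list must be closed under composition."""
    index = index_map(elements)
    n = len(elements)
    table = np.empty((n, n), dtype=int)
    for g, eg in enumerate(elements):
        eg_inv = inverse(eg)
        for h, eh in enumerate(elements):
            table[g, h] = index[tuple(compose(eg_inv, eh))]
    return table


def group_conv(f, kernel, table):
    """C[..., i, g] = sum_h sum_j W[i, j, P[g, h]] * f[..., j, h].

    f:      (..., in_features, n_symm)
    kernel: (features, in_features, n_symm), one filter per group element
    """
    w = kernel[:, :, table]  # dense (features, in_features, n_symm, n_symm) matrix
    return np.einsum("ijgh,...jh->...ig", w, f)


def left_translate(x, u, elements):
    """(L_u x)[..., g] = x[..., index of u^{-1} g]."""
    index = index_map(elements)
    u_inv = inverse(elements[u])
    perm = [index[tuple(compose(u_inv, e))] for e in elements]
    return x[..., perm]


# S_3: all permutations of 3 sites, a small non-abelian group.
elements = [np.array(p) for p in itertools.permutations(range(3))]
table = product_table(elements)
n_symm = len(elements)

rng = np.random.default_rng(0)
num_samples, in_features, features = 4, 2, 3
f = rng.normal(size=(num_samples, in_features, n_symm))
kernel = rng.normal(size=(features, in_features, n_symm))

out = group_conv(f, kernel, table)
print(out.shape)  # (4, 3, 6)

# Equivariance: translating the input by u translates the output by u.
for u in range(n_symm):
    lhs = group_conv(left_translate(f, u, elements), kernel, table)
    rhs = left_translate(out, u, elements)
    assert np.allclose(lhs, rhs)
```

The equivariance holds because \((u^{-1}g)^{-1}h' = g^{-1}(uh')\): relabelling the summation index absorbs the shift, which is exactly why the filter depends only on \(g^{-1}h\).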
This layer maps an input of shape (..., in_features, num_symm) to an output of shape (..., features, num_symm).

Parameters:
symmetries: A specification of the symmetry group. Can be given by a nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible representations or a product table.
point_group: The point group, from which the space group is built. If symmetries is a graph, the default point group is overwritten.
mode: String "fft", "irreps", "matrix" or "auto" specifying whether to use a fast Fourier transform over the translation group, a Fourier transform using the irreducible representations, or to construct the full kernel matrix.
shape: A tuple specifying the dimensions of the translation group.
features (Optional[int]): The number of output features. The full output shape is [n_batch, features, n_symm].
use_bias: A bool specifying whether to add a bias to the output (default: True).
mask: Optional array of shape (n_symm,), where n_symm = len(graph.automorphisms()), used to restrict the convolutional kernel. Only parameters with mask \(\ne 0\) are used. For best performance a boolean mask should be used.
param_dtype: The datatype of the weights. Defaults to a 64-bit float.
precision: Optional argument specifying the numerical precision of the computation. See jax.lax.Precision for details.
kernel_init: Optional kernel initialization function. Defaults to variance scaling.
bias_init: Optional bias initialization function. Defaults to zero initialization.
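To illustrate why an "fft" mode is possible for translations, here is a rough NumPy sketch restricted to a 1D cyclic translation group Z_L, with no point group. For Z_L the product table is P[g, h] = (h - g) mod L, so the G-convolution reduces to a circular cross-correlation, which the DFT diagonalises. The sketch also applies a hypothetical nearest-neighbour `mask` by zeroing the kernel entries where the mask is false; that is an assumption about the effect of `mask`, not necessarily how NetKet stores the masked parameters. All function names here are illustrative, not NetKet API, and NetKet's own "fft" mode (multi-dimensional `shape`, combined with a point group) is not reproduced.

```python
import numpy as np


def translation_product_table(L):
    """P[g, h] = index of t_g^{-1} t_h = t_{(h - g) mod L} for the cyclic group Z_L."""
    g = np.arange(L)[:, None]
    h = np.arange(L)[None, :]
    return (h - g) % L


def conv_matrix(f, kernel, table):
    """Dense route: build the (features, in_features, L, L) kernel matrix and contract."""
    w = kernel[:, :, table]
    return np.einsum("ijgh,...jh->...ig", w, f)


def conv_fft(f, kernel):
    """FFT route: C_g = sum_h W_{(h - g) mod L} f_h is a circular cross-correlation."""
    f_hat = np.fft.fft(f, axis=-1)       # (..., in_features, L)
    w_hat = np.fft.fft(kernel, axis=-1)  # (features, in_features, L)
    # For a real kernel, cross-correlation becomes conj(W_k) * F_k in Fourier space.
    out_hat = np.einsum("ijk,...jk->...ik", np.conj(w_hat), f_hat)
    return np.fft.ifft(out_hat, axis=-1).real


L, in_features, features = 8, 2, 3
rng = np.random.default_rng(1)
f = rng.normal(size=(5, in_features, L))
kernel = rng.normal(size=(features, in_features, L))

# Hypothetical mask: keep only the identity and the two nearest-neighbour translations.
mask = np.zeros(L, dtype=bool)
mask[[0, 1, L - 1]] = True
masked_kernel = kernel * mask  # broadcasts over the last (n_symm) axis

out_matrix = conv_matrix(f, masked_kernel, translation_product_table(L))
out_fft = conv_fft(f, masked_kernel)
print(np.allclose(out_matrix, out_fft))  # True
```

The dense route costs O(L²) per pair of features, while the FFT route costs O(L log L); this sketch does not say how NetKet picks a strategy when mode is "auto".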