netket.nn.DenseEquivariant

netket.nn.DenseEquivariant(symmetries, features=None, mode='auto', shape=None, point_group=None, in_features=None, mask=None, **kwargs)

A group convolution operation that is equivariant over a symmetry group.

It acts on a feature map over symmetry poses with shape [num_samples, in_features, num_symm] and returns a feature map over poses with shape [num_samples, features, num_symm].

G-convolutions are described in Cohen et al. and applied to quantum many-body problems in Roth et al.

The G-convolution generalizes the convolution to non-commuting groups:

$C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h$
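To make the formula concrete, here is a minimal NumPy sketch for a small non-abelian group, S3 (the six permutations of three objects), encoded as a table of $g^{-1}h$ indices. It also checks the defining property of the layer: shifting the input by a group element shifts the output by the same element. The choice of group, the helper names (`compose`, `inverse`, `g_conv`, `left_shift_perm`) and the `(features, in_features, n_symm)` weight layout are illustrative assumptions, not NetKet internals.

```python
import itertools

import numpy as np

# Illustrative group: S3, the six permutations of (0, 1, 2).
# S3 is non-abelian, so the order in g^{-1} h matters.
elems = [tuple(p) for p in itertools.permutations(range(3))]
index = {p: i for i, p in enumerate(elems)}
n = len(elems)


def compose(a, b):
    """Group product a * b, acting as (a * b)(k) = a(b(k))."""
    return tuple(a[b[k]] for k in range(3))


def inverse(a):
    """Inverse permutation, so that compose(inverse(a), a) is the identity."""
    inv = [0] * 3
    for k, ak in enumerate(a):
        inv[ak] = k
    return tuple(inv)


# table[g, h] = index of g^{-1} h, the subscript of W in the formula.
table = np.array([[index[compose(inverse(g), h)] for h in elems] for g in elems])


def g_conv(f, W):
    """C[..., i, g] = sum_h sum_j W[i, j, g^{-1} h] * f[..., j, h].

    f: (..., in_features, n_symm); W: (features, in_features, n_symm).
    """
    K = W[:, :, table]  # K[i, j, g, h] = W[i, j, table[g, h]]
    return np.einsum("ijgh,...jh->...ig", K, f)


def left_shift_perm(y):
    """perm[h] = index of y^{-1} h, so f[..., perm] is the input shifted by y."""
    return [index[compose(inverse(y), h)] for h in elems]


# Example: 2 samples, 3 input features, 4 output features.
rng = np.random.default_rng(0)
f = rng.normal(size=(2, 3, n))
W = rng.normal(size=(4, 3, n))
C = g_conv(f, W)  # shape (2, 4, 6)

# Equivariance: shifting the input by y shifts the output by y.
perm = left_shift_perm(elems[3])
print(np.allclose(g_conv(f[..., perm], W), C[..., perm]))  # True
```

The `...` in the `einsum` signature lets any number of leading batch axes pass through unchanged, matching the `(..., in_features, num_symm)` input convention.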

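Pairs of group elements related by the same symmetry operation, i.e. $(g, h)$ and $(g', h')$ with $g = hx$ and $g' = h'x$ for a common $x$, are connected by the same filter ${\bf W}_{x^{-1}}$, since $g^{-1}h = g'^{-1}h' = x^{-1}$.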
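Because the kernel is indexed by $g^{-1}h$, the pairs of output and input elements that share a filter can be enumerated directly. The sketch below does this for S3 (an illustrative choice; the helper names are hypothetical): the filter $W_u$ connects exactly the pairs $(g, gu)$, and in the full kernel matrix built from the $g^{-1}h$ table every filter appears once in each row and once in each column.

```python
import itertools

import numpy as np

# Illustrative group: S3 as the permutations of (0, 1, 2).
elems = [tuple(p) for p in itertools.permutations(range(3))]
index = {p: i for i, p in enumerate(elems)}


def compose(a, b):
    """Group product a * b, acting as (a * b)(k) = a(b(k))."""
    return tuple(a[b[k]] for k in range(3))


def inverse(a):
    """Inverse permutation, so that compose(inverse(a), a) is the identity."""
    inv = [0] * 3
    for k, ak in enumerate(a):
        inv[ak] = k
    return tuple(inv)


def pairs_sharing_filter(u):
    """All (output g, input h) pairs connected through W_u, i.e. g^{-1} h == u."""
    return {(g, h) for g in elems for h in elems if compose(inverse(g), h) == u}


# Full kernel matrix indices: table[g, h] = index of g^{-1} h.
table = np.array([[index[compose(inverse(g), h)] for h in elems] for g in elems])

# The filter W_u is shared by the pairs (g, g * u), one for every g.
u = elems[2]
print(sorted(pairs_sharing_filter(u)) == sorted((g, compose(g, u)) for g in elems))  # True
```

Since each row and column of `table` is a permutation of the group indices, every output element sees every filter exactly once, which is what makes the operation a convolution rather than a generic dense map.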

Leading batch dimensions are arbitrary: the layer maps an input of shape (..., in_features, num_symm) to an output of shape (..., features, num_symm).

Parameters:
• symmetries β A specification of the symmetry group. Can be given by a nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible representations or a product table.

• point_group β The point group, from which the space group is built. If symmetries is a graph the default point group is overwritten.

• mode β string βfft, irreps, matrix, autoβ specifying whether to use a fast fourier transform over the translation group, a fourier transform using the irreducible representations or by constructing the full kernel matrix.

• shape β A tuple specifying the dimensions of the translation group.

• features () β The number of output features. The full output shape is [n_batch,features,n_symm].

• use_bias β A bool specifying whether to add a bias to the output (default: True).

• mask β Optional array of shape (n_symm,) where (n_symm,) = len(graph.automorphisms()) used to restrict the convolutional kernel. Only parameters with mask :math:βne 0β are used. For best performance a boolean mask should be used.

• param_dtype β The datatype of the weights. Defaults to a 64bit float.

• precision β Optional argument specifying numerical precision of the computation. see jax.lax.Precision for details.

• kernel_init β Optional kernel initialization function. Defaults to variance scaling.

• bias_init β Optional bias initialization function. Defaults to zero initialization.