netket.nn.DenseEquivariant

netket.nn.DenseEquivariant(symmetries, features=None, mode='auto', shape=None, point_group=None, in_features=None, mask=None, **kwargs)

A group convolution operation that is equivariant over a symmetry group.

Acts on a feature map of symmetry poses with shape [num_samples, in_features, num_symm] and returns a feature map of poses with shape [num_samples, features, num_symm].

G-convolutions are described in Cohen et al. and applied to quantum many-body problems in Roth et al.

The G-convolution generalizes the ordinary convolution to non-commutative groups:

\[C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h\]

Pairs of group elements related by the same symmetry operation (i.e. \(g = hx\) and \(g' = h'x\), so that \(g^{-1}h = {g'}^{-1}h' = x^{-1}\)) are connected by the same filter.
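To make the formula concrete, the following NumPy sketch evaluates \(C_g = \sum_h W_{g^{-1} h} f_h\) for a single input and output feature on the permutation group S3, the smallest non-commutative group. It builds a table of \(g^{-1}h\) indices by hand and checks two properties numerically: translating the input by a group element translates the output in the same way (equivariance), and every pair \((hx, h)\) uses the same filter entry \(W_{x^{-1}}\). The helper names (`compose`, `inverse`, `g_conv`, `left_translate`) are hypothetical and not part of NetKet; this illustrates the math only, not the library's implementation.

```python
import itertools
import numpy as np

# S3: all permutations of (0, 1, 2), the smallest non-commutative group.
elems = [tuple(p) for p in itertools.permutations(range(3))]
index = {p: i for i, p in enumerate(elems)}
n = len(elems)

def compose(p, q):
    """Group product p*q: apply q first, then p."""
    return tuple(p[q[i]] for i in range(len(q)))

def inverse(p):
    """Group inverse of the permutation p."""
    inv = [0] * len(p)
    for i, pi in enumerate(p):
        inv[pi] = i
    return tuple(inv)

# table[g, h] = index of g^{-1} h, the filter entry used for the pair (g, h).
table = np.array([[index[compose(inverse(g), h)] for h in elems] for g in elems])

def g_conv(W, f):
    """Hypothetical helper: C_g = sum_h W_{g^{-1} h} f_h, one channel in and out."""
    return W[table] @ f  # W[table] is the full (n, n) kernel matrix

def left_translate(y, f):
    """Hypothetical helper: (L_y f)_h = f_{y^{-1} h}."""
    yinv = inverse(y)
    return f[[index[compose(yinv, h)] for h in elems]]

rng = np.random.default_rng(0)
W = rng.normal(size=n)
f = rng.normal(size=n)

# Equivariance: convolving a translated input equals translating the output.
for y in elems:
    assert np.allclose(g_conv(W, left_translate(y, f)), left_translate(y, g_conv(W, f)))

# All pairs (g, h) = (hx, h) share a single filter entry, W_{x^{-1}}.
for x in elems:
    shared = {int(table[index[compose(h, x)], index[h]]) for h in elems}
    assert shared == {index[inverse(x)]}
```

The equivariance check follows from substituting \(h \to y h\) in the sum, which turns \(W_{g^{-1} y h}\) into \(W_{(y^{-1} g)^{-1} h}\).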

This layer maps an input of shape (..., in_features, num_symm) to an output of shape (..., features, num_symm).
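A hedged sketch of this shape contract with several features: each output feature \(i\) sums over input features \(j\) and group elements \(h\), with filters stored in an array of shape (features, in_features, num_symm) and looked up through a \(g^{-1}h\) table. The function name `dense_equivariant_sketch`, the weight layout, and the bias being one scalar per output feature are assumptions for illustration, not NetKet's internal parameter layout. The cyclic group Z_6 (translations on a 6-site ring) is used because its table is simply \((h - g) \bmod 6\).

```python
import numpy as np

def dense_equivariant_sketch(f, W, b, table):
    """Hypothetical reference forward pass.

    f:     (..., in_features, n_symm) input feature map
    W:     (features, in_features, n_symm) filters, one per relative group element
    b:     (features,) bias, assumed constant over group elements
    table: (n_symm, n_symm) with table[g, h] = index of g^{-1} h
    returns (..., features, n_symm)
    """
    K = W[:, :, table]  # (features, in_features, n_symm, n_symm), K[i, j, g, h]
    out = np.einsum("ijgh,...jh->...ig", K, f)
    return out + b[:, None]

# Example group: cyclic translations Z_6, where g^{-1} h = (h - g) mod 6.
n_symm = 6
idx = np.arange(n_symm)
table = (idx[None, :] - idx[:, None]) % n_symm

rng = np.random.default_rng(2)
f = rng.normal(size=(5, 2, n_symm))  # 5 samples, 2 input features
W = rng.normal(size=(3, 2, n_symm))  # 3 output features
b = rng.normal(size=3)

out = dense_equivariant_sketch(f, W, b, table)
print(out.shape)  # (5, 3, 6)
```

Rolling the input along the last axis (a translation of Z_6) rolls the output by the same amount; a bias that is constant over group elements does not break this.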

Parameters:
  • symmetries – A specification of the symmetry group. Can be given as an nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible representations, or a product table.

  • point_group – The point group from which the space group is built. If symmetries is a graph, this overrides the graph's default point group.

  • mode – string "fft", "irreps", "matrix" or "auto" specifying whether to use a fast Fourier transform over the translation group, a Fourier transform using the irreducible representations, or a construction of the full kernel matrix.

  • shape – A tuple specifying the dimensions of the translation group.

  • features (Optional[int]) – The number of output features. The full output shape is [n_batch, features, n_symm].

  • use_bias – A bool specifying whether to add a bias to the output (default: True).

  • mask – Optional array of shape (n_symm,), where n_symm = len(graph.automorphisms()), used to restrict the convolutional kernel. Only parameters for which the mask is \(\neq 0\) are used. For best performance, a boolean mask should be used.

  • param_dtype – The datatype of the weights. Defaults to a 64-bit float.

  • precision – Optional argument specifying the numerical precision of the computation. See jax.lax.Precision for details.

  • kernel_init – Optional kernel initialization function. Defaults to variance scaling.

  • bias_init – Optional bias initialization function. Defaults to zero initialization.
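The `mode`, `shape` and `mask` options can be illustrated with a pure translation group, where \(g^{-1}h\) is just the displacement \(h - g\) and the G-convolution becomes a circular cross-correlation. The NumPy sketch below uses a hypothetical 3 x 4 torus, corresponding to `shape=(3, 4)`, and compares a matrix-style evaluation (full kernel built from a product table) with an FFT-style evaluation; the two should agree. A boolean mask of shape (n_symm,) keeps only the identity and the nearest-neighbour translations. Zeroing the masked-out filter entries is an assumption about how the restriction acts, and `conv_matrix` and `conv_fft` are hypothetical helpers, not NetKet APIs.

```python
import numpy as np

# Hypothetical example: translation group of a 3 x 4 torus, Z_3 x Z_4 (12 elements).
shape = (3, 4)
n = shape[0] * shape[1]
coords = np.stack(np.unravel_index(np.arange(n), shape), axis=-1)  # (n, 2)

# table[g, h] = flat index of g^{-1} h, i.e. the translation h - g (mod shape).
diff = (coords[None, :, :] - coords[:, None, :]) % np.array(shape)
table = np.ravel_multi_index((diff[..., 0], diff[..., 1]), shape)

def conv_matrix(W, f):
    """Matrix-style: build the full (n, n) kernel and multiply."""
    return W[table] @ f

def conv_fft(W, f):
    """FFT-style: C_g = sum_h W_{h-g} f_h is a circular cross-correlation,
    so its DFT equals conj(DFT(W)) * DFT(f) for a real filter W."""
    Wk = np.fft.fft2(W.reshape(shape))
    fk = np.fft.fft2(f.reshape(shape))
    return np.fft.ifft2(np.conj(Wk) * fk).real.ravel()

rng = np.random.default_rng(1)
W = rng.normal(size=n)
f = rng.normal(size=n)

# Boolean mask over the n group elements: keep the identity and the
# nearest-neighbour translations (torus distance <= 1), drop the rest.
dist = np.minimum(coords, np.array(shape) - coords).sum(axis=1)
mask = dist <= 1
W_masked = np.where(mask, W, 0.0)

assert np.allclose(conv_matrix(W_masked, f), conv_fft(W_masked, f))
```

The matrix route costs \(O(n^2)\) per channel pair, while the FFT route costs \(O(n \log n)\), which is why an FFT over the translation group is attractive when `shape` is known.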