netket.nn.DenseSymm

netket.nn.DenseSymm(symmetries, point_group=None, mode='auto', shape=None, **kwargs)

Implements a projection onto a symmetry group. The output is equivariant with respect to the symmetry operations in the group and can be averaged over the symmetry dimension to produce an invariant model.

This layer maps an input of shape (…, in_features, n_sites) to an output of shape (…, features, num_symm).
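The shape mapping and the equivariance claim can be made concrete with a minimal NumPy sketch of the matrix-multiplication route. It assumes that the permutation table indexes the kernel's site axis (`kernel[f, i, perms[g, j]]`). The helper name `dense_symm_matrix` is hypothetical, and this is an illustration rather than NetKet's implementation, which is a Flax module written in JAX. It uses the translation group of a periodic 6-site chain as the symmetry group.

```python
import numpy as np

def dense_symm_matrix(x, kernel, perms):
    """Hypothetical sketch of a symmetric dense projection (matrix route).

    x:      input of shape (..., in_features, n_sites)
    kernel: weights of shape (features, in_features, n_sites)
    perms:  permutation table of shape (n_symm, n_sites)
    returns an array of shape (..., features, n_symm)
    """
    # One permuted copy of the kernel per symmetry element:
    # full[f, i, g, j] = kernel[f, i, perms[g, j]]
    full = kernel[:, :, perms]
    # Contract over input features (i) and sites (j).
    return np.einsum("...ij,figj->...fg", x, full)

# Translation group of a periodic chain: perms[g, j] = (j + g) % n_sites.
n_sites = 6
perms = (np.arange(n_sites)[None, :] + np.arange(n_sites)[:, None]) % n_sites

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, n_sites))       # (batch, in_features, n_sites)
kernel = rng.normal(size=(4, 3, n_sites))  # (features, in_features, n_sites)

out = dense_symm_matrix(x, kernel, perms)
print(out.shape)  # (2, 4, 6): (batch, features, n_symm)

# Translating the input by one site cyclically shifts the symmetry axis...
out_shifted = dense_symm_matrix(np.roll(x, -1, axis=-1), kernel, perms)
print(np.allclose(out_shifted, np.roll(out, 1, axis=-1)))  # True
# ...so averaging over that axis gives a translation-invariant result.
print(np.allclose(out_shifted.mean(axis=-1), out.mean(axis=-1)))  # True
```

Averaging the last axis is the simplest way to turn the equivariant output into an invariant one. A real model would typically insert further equivariant layers and nonlinearities before that final average.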

Note: The output shape has changed to separate the feature and symmetry dimensions. The previous shape was [num_samples, num_symm*features]; the new shape is [num_samples, features, num_symm].

Parameters
  • symmetries – A specification of the symmetry group. Can be given by a netket.graph.Graph, a netket.utils.group.PermutationGroup, or an array of shape (n_symm, n_sites) specifying the permutations corresponding to symmetry transformations of the lattice. A netket.utils.HashableArray may also be passed.

  • point_group – The point group from which the space group is built. If symmetries is a graph, this overrides the graph's default point group.

  • mode – A string, one of "fft", "matrix" or "auto", specifying whether to use a fast Fourier transform, matrix multiplication, or a sensible default chosen based on the symmetry group.

  • shape – A tuple specifying the dimensions of the translation group.

  • features – The number of output features. The full output shape is [n_batch, features, n_symm].

  • use_bias – A bool specifying whether to add a bias to the output (default: True).

  • mask – An optional array of shape [n_sites] consisting of ones and zeros that can be used to give the kernel a particular shape.

  • dtype – The datatype of the weights. Defaults to a 64-bit float.

  • precision – Optional argument specifying the numerical precision of the computation. See `jax.lax.Precision` for details.

  • kernel_init – Optional kernel initialization function. Defaults to variance scaling.

  • bias_init – Optional bias initialization function. Defaults to zero initialization.
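The `symmetries`, `mode` and `mask` options can be illustrated on the simplest case: the translation group of a periodic 1D chain, passed as an (n_symm, n_sites) permutation array. For that group the projection is a circular cross-correlation, so a matrix route and an FFT route must give the same result. The NumPy sketch below is an illustration under stated assumptions, not NetKet's code. The helpers `dense_symm_matrix` and `dense_symm_fft` are hypothetical. The mask is assumed to be multiplied elementwise into the kernel's site axis, and the inputs are assumed to be real. In this 1D case the translation-group dimensions described under `shape` reduce to (n_sites,).

```python
import numpy as np

def dense_symm_matrix(x, kernel, perms):
    # Hypothetical matrix route: gather one permuted kernel per symmetry, then contract.
    full = kernel[:, :, perms]                      # (features, in_features, n_symm, n_sites)
    return np.einsum("...ij,figj->...fg", x, full)  # (..., features, n_symm)

def dense_symm_fft(x, kernel):
    # Hypothetical FFT route for cyclic translations of a 1D chain, where
    # out[..., f, g] = sum_{i,j} x[..., i, j] * kernel[f, i, (j + g) % n_sites]
    # is a circular cross-correlation. Assumes real-valued x and kernel.
    n_sites = x.shape[-1]
    x_hat = np.fft.ifft(x, axis=-1) * n_sites       # sum_j x[j] * exp(+2*pi*1j*q*j/N)
    k_hat = np.fft.fft(kernel, axis=-1)             # sum_m k[m] * exp(-2*pi*1j*q*m/N)
    out_hat = np.einsum("...iq,fiq->...fq", x_hat, k_hat)
    return np.fft.ifft(out_hat, axis=-1).real       # (..., features, n_symm)

n_sites = 8
# symmetries as an (n_symm, n_sites) array: row g maps site j to (j + g) % n_sites.
perms = (np.arange(n_sites)[None, :] + np.arange(n_sites)[:, None]) % n_sites

# Hypothetical mask of ones and zeros keeping only sites 0, 1 and n_sites - 1.
mask = np.zeros(n_sites)
mask[[0, 1, n_sites - 1]] = 1.0

rng = np.random.default_rng(1)
x = rng.normal(size=(5, 2, n_sites))              # (batch, in_features, n_sites)
kernel = rng.normal(size=(3, 2, n_sites)) * mask  # masked (features, in_features, n_sites)

out_matrix = dense_symm_matrix(x, kernel, perms)
out_fft = dense_symm_fft(x, kernel)
print(out_matrix.shape)                  # (5, 3, 8)
print(np.allclose(out_matrix, out_fft))  # True
```

The matrix route materialises an n_symm × n_sites kernel for each feature pair. The FFT route only needs transforms of length n_sites, which is why an FFT is attractive for translation groups. Which route `mode='auto'` actually selects for a given group is decided by NetKet and is not reproduced here.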