# Source code for netket.models.gaussian

import flax.linen as nn
import jax.numpy as jnp

from flax.linen.dtypes import promote_dtype
from flax.linen.initializers import normal

from netket.utils.types import DType, Array, NNInitFunc


class Gaussian(nn.Module):
    r"""
    Multivariate Gaussian with zero mean and a parametrised covariance
    matrix :math:`\Sigma_{ij}`.

    The module returns the logarithm of the wavefunction amplitude,

    .. math::

        \log\Psi(x) = -\frac{1}{2} \sum_{ij} x_i \Sigma_{ij} x_j,

    that is, :math:`\Psi(x) = \exp\left(-\frac{1}{2}\sum_{ij} x_i \Sigma_{ij} x_j\right)`.

    :math:`\Sigma` is not stored directly. The trainable parameter ``kernel``
    is an unconstrained square matrix :math:`A`, and :math:`\Sigma = A^T A`
    is rebuilt from it on every call. For real :math:`A` this makes
    :math:`\Sigma` symmetric positive semi-definite by construction.
    """

    param_dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the ``(N, N)`` weight matrix :math:`A`, with ``N = x.shape[-1]``."""

    @nn.compact
    def __call__(self, x_in: Array) -> Array:
        r"""Evaluate :math:`-\frac{1}{2} x^T \Sigma x` for every sample in ``x_in``.

        Args:
            x_in: Input array whose last axis holds the ``N`` coordinates of
                one sample. Any leading axes are treated as batch axes.

        Returns:
            Array of shape ``x_in.shape[:-1]`` with the log-amplitude of each
            sample.
        """
        nv = x_in.shape[-1]

        # Unconstrained (nv, nv) matrix A; the covariance is derived from it.
        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.param_dtype)
        # Sigma = A^T A. NOTE(review): this is a plain transpose with no
        # conjugation, so for a complex param_dtype Sigma is complex-symmetric
        # rather than Hermitian. Confirm whether that is intended.
        kernel = jnp.dot(kernel.T, kernel)

        # Cast Sigma and the input to a common dtype before contracting.
        kernel, x_in = promote_dtype(kernel, x_in, dtype=None)
        # Quadratic form per batch entry. In implicit einsum mode the
        # ellipsis (batch) axes are kept in the output.
        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y