Source code for netket.models.gaussian
import flax.linen as nn
import jax.numpy as jnp
from flax.linen.dtypes import promote_dtype
from flax.linen.initializers import normal
from netket.utils.types import DType, Array, NNInitFunc
[docs]
class Gaussian(nn.Module):
    r"""
    Zero-mean multivariate Gaussian with a learnable quadratic form
    :math:`\Sigma_{ij}`.

    Calling the module returns the exponent of the Gaussian,
    :math:`y(x) = -\frac{1}{2}\sum_{ij} x_i \Sigma_{ij} x_j`, so that the
    wavefunction reads :math:`\Psi(x) = \exp(y(x))`.

    The matrix :math:`\Sigma = A^T A` is never stored directly: the only
    parameter is the unconstrained square matrix :math:`A` (registered under
    the name ``"kernel"``). Building :math:`\Sigma` this way makes it
    symmetric and, for real :math:`A`, positive semi-definite by construction.
    """

    param_dtype: DType = jnp.float64
    """Dtype used to create the ``kernel`` parameter :math:`A`."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer used to draw the entries of the ``kernel`` parameter."""

    @nn.compact
    def __call__(self, x_in: Array):
        """
        Evaluate the Gaussian exponent on ``x_in``.

        Args:
            x_in: Array whose last axis holds the ``nv`` variables; any
                leading axes are carried through unchanged by the einsum.

        Returns:
            Array of shape ``x_in.shape[:-1]`` holding
            ``-0.5 * x^T (A^T A) x`` for every vector along the last axis.
        """
        # The size of the last axis fixes the shape of the square parameter A.
        n_vars = x_in.shape[-1]
        factor = self.param(
            "kernel", self.kernel_init, (n_vars, n_vars), self.param_dtype
        )

        # Sigma = A^T A, formed before any dtype promotion takes place.
        sigma = jnp.dot(factor.T, factor)

        # dtype=None: let flax infer the common dtype of matrix and input.
        sigma, inputs = promote_dtype(sigma, x_in, dtype=None)

        # Quadratic form x_i Sigma_ij x_j contracted over the last axis only.
        return -0.5 * jnp.einsum("...i,ij,...j", inputs, sigma, inputs)