class netket.models.DeepSetRelDistance

Bases: flax.linen.module.Module

Implements an equivariant version of the DeepSets architecture, given by

\[f(x_1,...,x_N) = \rho\left(\sum_i \phi(x_i)\right)\]

that is suitable for simulating periodic systems. Additionally, a cusp condition can be added by specifying its asymptotic exponent. For helium, the Ansatz reads

\[\psi(x_1,...,x_N) = \rho\left(\sum_{i<j} \phi(d_{\sin}(x_i,x_j))\right) \cdot \exp\left[-\frac{1}{2}\sum_{i<j}\left(b/d_{\sin}(x_i,x_j)\right)^5\right]\]

cusp_exponent: Optional[int] = None

Exponent of Kato's cusp condition.

use_bias: bool = True

If True, a bias is used in all layers.


property variables

Returns the variables in this module.

Return type

Mapping[str, Mapping[str, Any]]

hilbert: netket.hilbert.ContinuousHilbert

The Hilbert space defining the periodic box in which this Ansatz is defined.

layers_phi: int

Number of layers in the phi network.

layers_rho: int

Number of layers in the rho network.

features_phi: Union[Tuple, int]

Number of features in each layer of the phi network.

features_rho: Union[Tuple, int]

Number of features in each layer of the rho network. If specified as a list, the last layer must have 1 feature.


activation(approximate=True)

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters

  • x (Any) – input array

  • approximate (bool) – whether to use the approximate or exact formulation.
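Both formulas can be checked numerically against jax.nn.gelu. The helpers gelu_exact and gelu_tanh are illustrative re-implementations of the two equations above, not part of any library.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

def gelu_exact(x):
    # x/2 * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + erf(x / jnp.sqrt(2.0)))

def gelu_tanh(x):
    # x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

x = jnp.linspace(-4.0, 4.0, 9)
# Deviation of each hand-written formula from the library implementation
err_exact = jnp.max(jnp.abs(gelu_exact(x) - jax.nn.gelu(x, approximate=False)))
err_tanh = jnp.max(jnp.abs(gelu_tanh(x) - jax.nn.gelu(x, approximate=True)))
# Gap between the exact and approximate formulations
gap = jnp.max(jnp.abs(gelu_exact(x) - gelu_tanh(x)))
print(err_exact, err_tanh, gap)
```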


bias_init(shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)

Parameters

dtype (Any) –

distance(x, sdim, L)

has_rng(name)

Returns True if a PRNGSequence with the given name exists.

Parameters

name (str) –

Return type

bool

kernel_init(shape, dtype=<class 'jax.numpy.float64'>)

params_init(shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
DeviceArray([[1., 1.],
             [1., 1.],
             [1., 1.]], dtype=float32)

Parameters

dtype (Any) –

pooling(axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)

Sum of array elements over a given axis.

LAX-backend implementation of numpy.sum().

Original docstring below.

Parameters

  • a (array_like) – Elements to sum.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.

  • dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the sum method of sub-classes of ndarray, however any non-default value will be. If the sub-class’ method does not implement keepdims any exceptions will be raised.

  • initial (scalar, optional) – Starting value for the sum. See numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to include in the sum. See numpy.ufunc.reduce for details.


Returns

sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.
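Since pooling is documented as a sum, here is a short sketch of how the axis and keepdims arguments behave on a batch of per-pair features. The shapes, and the idea that this model pools over axis -2, are assumptions for illustration; the last check shows the permutation invariance that sum pooling provides.

```python
import jax.numpy as jnp

# Hypothetical per-pair features: 2 configurations, 3 pairs, 4 features
y = jnp.arange(24.0).reshape(2, 3, 4)

pooled = jnp.sum(y, axis=-2)                     # reduce over the pair axis
pooled_keep = jnp.sum(y, axis=-2, keepdims=True)  # keep a size-1 axis instead
print(pooled.shape, pooled_keep.shape)            # (2, 4) (2, 1, 4)

# Reordering the reduced axis does not change the sum
perm = jnp.array([2, 0, 1])
pooled_perm = jnp.sum(y[:, perm, :], axis=-2)
print(jnp.array_equal(pooled, pooled_perm))       # True
```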


put_variable(col, name, value)

Sets the value of a Variable.

Parameters

  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • value (Any) – the new value of the variable.


tabulate(rngs, *args, method=None, mutable=True, depth=None, exclude_methods=(), **kwargs)

Creates a summary of the Module represented as a table.

This method has the same signature as init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.


import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact  # required to define submodules inline in __call__
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))

print(Foo().tabulate(jax.random.PRNGKey(0), x))

This gives the following output:

                   Foo Summary
┃ path    ┃ outputs       ┃ params               ┃
│ Inputs  │ float32[16,9] │                      │
│ Dense_0 │ float32[16,4] │ bias: float32[4]     │
│         │               │ kernel: float32[9,4] │
│         │               │                      │
│         │               │ 40 (160 B)           │
│ Dense_1 │ float32[16,2] │ bias: float32[2]     │
│         │               │ kernel: float32[4,2] │
│         │               │                      │
│         │               │ 10 (40 B)            │
│ Foo     │ float32[16,2] │                      │
│         │         Total │ 50 (200 B)           │

          Total Parameters: 50 (200 B)

Note: the row order in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Parameters

  • rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections.

  • *args – The arguments to the forward computation.

  • method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

  • depth (Optional[int]) – Controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes are added to the row of its first shown ancestor, so that the sum of all rows always adds up to the total number of parameters of the Module.

  • exclude_methods (Sequence[str]) – A sequence of strings that specifies which methods should be ignored. In case a module calls a helper method from its main method, use this argument to exclude the helper method from the summary to avoid ambiguity.

  • **kwargs – keyword arguments to pass to the forward computation.

Returns

A string summarizing the Module.

Return type

str