netket.models.DeepSetMLP#

class netket.models.DeepSetMLP#

Bases: flax.linen.module.Module

Implements the DeepSets architecture, which is permutation invariant.

\[f(x_1,...,x_N) = \rho\left(\sum_i \phi(x_i)\right)\]

This makes it suitable for the simulation of bosonic systems.

The input must have an axis that can be reshaped to (…, N, D); the pooling is performed over N.

See DeepSetRelDistance for the bosonic wave function ansatz in https://arxiv.org/abs/1703.06114
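The formula above can be sketched in plain JAX to make the permutation invariance concrete. This is only an illustration under stated assumptions, not NetKet's implementation: the helpers `init_dense`, `dense`, and `deep_set`, the single-layer phi and rho networks, and the sizes `N`, `D`, `H` are all hypothetical choices made for the example.

```python
import jax
import jax.numpy as jnp

def init_dense(key, n_in, n_out):
    """Hypothetical helper: (kernel, bias) for one dense layer."""
    kernel = jax.random.normal(key, (n_in, n_out)) / jnp.sqrt(n_in)
    bias = jnp.zeros((n_out,))
    return kernel, bias

def dense(params, x):
    kernel, bias = params
    return x @ kernel + bias

def deep_set(params, x):
    """f(x_1, ..., x_N) = rho(sum_i phi(x_i)) for x of shape (..., N, D)."""
    phi_params, rho_params = params
    h = jax.nn.gelu(dense(phi_params, x))      # phi acts on every x_i independently
    pooled = jnp.sum(h, axis=-2)               # pool (sum) over the N axis
    return dense(rho_params, pooled)[..., 0]   # rho ends in a single output

key_phi, key_rho, key_x, key_perm = jax.random.split(jax.random.PRNGKey(0), 4)
N, D, H = 5, 3, 8                              # illustrative sizes
params = (init_dense(key_phi, D, H), init_dense(key_rho, H, 1))

x = jax.random.normal(key_x, (4, N, D))        # batch of 4 inputs
perm = jax.random.permutation(key_perm, N)     # random reordering of the N elements

out = deep_set(params, x)
out_permuted = deep_set(params, x[:, perm, :])

print(out.shape)                               # (4,)
# Reordering the N elements should leave the output unchanged
# up to floating-point rounding.
print(jnp.allclose(out, out_permuted, atol=1e-5))
```

Because the sum over N does not depend on the order of its terms, any reordering of the N elements yields the same pooled vector and therefore the same output.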

Attributes
features_phi: Optional[Union[int, Tuple[int, ...]]] = None#

Number of features in each layer of the phi network. When features_phi is None, no phi network is created.

features_rho: Optional[Union[int, Tuple[int, ...]]] = None#

Number of features in each layer of the rho network. Should not include the final layer of dimension 1, which is added automatically. When features_rho is None, a single-layer MLP with output dimension 1 is created.

output_activation: Optional[Callable] = None#

The nonlinear activation function at the output layer.

precision: Optional[jax._src.lax.lax.Precision] = None#

Numerical precision of the computation; see jax.lax.Precision for details.

use_bias: bool = True#

If True, uses a bias in all layers.

variables#

Returns the variables in this module.

Return type

Mapping[str, Mapping[str, Any]]
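One way to read the description of features_rho above is as a rule for the rho layer widths: the hidden widths are taken from the attribute and a final layer of width 1 is appended. The helper below, `rho_layer_widths`, is hypothetical and not part of NetKet; it only spells out that reading, including the assumption that a bare int stands for a single hidden layer.

```python
from typing import Optional, Tuple, Union

Features = Optional[Union[int, Tuple[int, ...]]]

def rho_layer_widths(features_rho: Features) -> Tuple[int, ...]:
    """Hypothetical helper: rho layer widths with the final width-1 layer appended."""
    if features_rho is None:
        hidden = ()                      # no hidden layers: a single layer with output 1
    elif isinstance(features_rho, int):
        hidden = (features_rho,)         # assumption: an int means one hidden layer
    else:
        hidden = tuple(features_rho)
    return hidden + (1,)                 # the final layer of dimension 1

print(rho_layer_widths(None))            # (1,)
print(rho_layer_widths(16))              # (16, 1)
print(rho_layer_widths((16, 8)))         # (16, 8, 1)
```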

Methods
bias_init(shape, dtype=<class 'jax.numpy.float64'>)#

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
Return type

Any

has_rng(name)#

Returns true if a PRNGSequence with name name exists.

Return type

bool

Parameters

name (str) –

hidden_activation(approximate=True)#

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters
  • x (Any) – input array

  • approximate (bool) – whether to use the approximate or exact formulation.

Return type

Any
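The two GELU formulas above can be checked numerically against jax.nn.gelu. The snippet below is a small sketch: the sample points in `x` and the tolerances are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

x = jnp.linspace(-3.0, 3.0, 7)   # arbitrary sample points

# Exact formulation: x/2 * (1 + erf(x / sqrt(2)))
exact_manual = 0.5 * x * (1.0 + erf(x / jnp.sqrt(2.0)))

# Approximate formulation: x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
approx_manual = 0.5 * x * (
    1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3))
)

exact = jax.nn.gelu(x, approximate=False)
approx = jax.nn.gelu(x, approximate=True)

# Both comparisons should hold up to floating-point rounding.
print(jnp.allclose(exact, exact_manual, atol=1e-5))
print(jnp.allclose(approx, approx_manual, atol=1e-5))

# The two formulations differ only slightly on this range.
print(jnp.max(jnp.abs(exact - approx)))
```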

is_initializing()#

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

Return type

bool

kernel_init(shape, dtype=<class 'jax.numpy.float64'>)#
Return type

Any

perturb(name, value, collection='perturbations')#

Add a zero-value variable ('perturbation') to the intermediate value.

The gradient of value would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note: this is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory. Use it only to debug gradients in training.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense3', x)
        return nn.Dense(2)(x)

def loss(params, perturbations, inputs, targets):
  variables = {'params': params, 'perturbations': perturbations}
  preds = model.apply(variables, inputs)
  return jnp.square(preds - targets).mean()

x = jnp.ones((2, 9))
y = jnp.ones((2, 2))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
print(intm_grads['dense3']) # ==> [[-1.456924   -0.44332537  0.02422847]
                            #      [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

model.apply({'params': params, 'perturbations': perturbations}, x) # works as expected
model.apply({'params': params}, x) # behaves like a no-op
Return type

TypeVar(T)

Parameters
  • name (str) –

  • value (flax.linen.module.T) –

  • collection (str) –

pooling(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)#

Sum of array elements over a given axis.

LAX-backend implementation of numpy.sum().

Original docstring below.

Parameters
  • a (array_like) – Elements to sum.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.

  • dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the sum method of sub-classes of ndarray, however any non-default value will be. If the sub-class’ method does not implement keepdims any exceptions will be raised.

  • initial (scalar, optional) – Starting value for the sum. See ~numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to include in the sum. See ~numpy.ufunc.reduce for details.

  • promote_integers (bool, default True) – If True, then integer inputs will be promoted to the widest available integer dtype, following numpy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.

  • out (None) –

Returns

sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.

Return type

ndarray
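As a rough illustration of how a sum pooling acts on an array of shape (…, N, D), the snippet below sums over the second-to-last axis with jnp.sum. The choice of axis=-2 and the array `phi_out` are assumptions made for this example; they mirror the (…, N, D) layout described for the class rather than quoting NetKet's code.

```python
import jax.numpy as jnp

# Stand-in for the per-element features: shape (batch, N, features) = (2, 3, 4)
phi_out = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)

# Summing over the N axis removes it from the shape.
pooled = jnp.sum(phi_out, axis=-2)
print(pooled.shape)        # (2, 4)
print(pooled[0])           # [12. 15. 18. 21.]

# keepdims=True keeps the reduced axis with size one.
pooled_keep = jnp.sum(phi_out, axis=-2, keepdims=True)
print(pooled_keep.shape)   # (2, 1, 4)

# The default axis=None sums every element and returns a scalar.
total = jnp.sum(phi_out)
print(total)               # 276.0
```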

put_variable(col, name, value)#

Updates the value of the given variable if it is mutable, or raises an error otherwise.

Parameters
  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • value (Any) – the new value of the variable.

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=True, console_kwargs=None, **kwargs)#

Creates a summary of the Module represented as a table.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {'width': 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))

print(Foo().tabulate(jax.random.PRNGKey(0), x))

This gives the following output:

                                Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ params               ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │                      │
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ bias: float32[4]     │
│         │        │               │               │ kernel: float32[9,4] │
│         │        │               │               │                      │
│         │        │               │               │ 40 (160 B)           │
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ bias: float32[2]     │
│         │        │               │               │ kernel: float32[4,2] │
│         │        │               │               │                      │
│         │        │               │               │ 10 (40 B)            │
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│         │        │               │         Total │ 50 (200 B)           │
└─────────┴────────┴───────────────┴───────────────┴──────────────────────┘

                      Total Parameters: 50 (200 B)

Note: the order of rows in the table does not represent execution order; instead, it follows the order of the keys in variables, which are sorted alphabetically.

Parameters
  • rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections as passed to Module.init.

  • *args – The arguments to the forward computation.

  • depth (Optional[int]) – controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor, such that the sum of all rows always adds up to the total number of parameters of the Module.

  • show_repeated (bool) – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except 'intermediates' are mutable.

  • console_kwargs (Optional[Mapping[str, Any]]) – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.

  • **kwargs – keyword arguments to pass to the forward computation.

Return type

str

Returns

A string summarizing the Module.

unbind()#

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

An example of a common use case: to extract a sub-Module defined inside setup() and its corresponding variables: 1) temporarily bind the parent Module; and then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.):

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x):
    return self.decoder(self.encoder(x))

module = AutoEncoder()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
...
# Extract the Encoder sub-Module and its variables
encoder, encoder_vars = module.bind(variables).encoder.unbind()
Return type

Tuple[TypeVar(M, bound= Module), Mapping[str, Mapping[str, Any]]]

Returns

A tuple with an unbound copy of this Module and its variables.

Parameters

self (flax.linen.module.M) –