netket.models.FastARNNConv2D

class netket.models.FastARNNConv2D

Bases: FastARNNSequential

Fast autoregressive neural network with 2D convolution layers.

See netket.nn.FastMaskedConv1D for a brief explanation of fast autoregressive sampling.

Attributes
kernel_dilation: Tuple[int, int] = (1, 1)

a sequence of 2 integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: (1, 1)).
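The effect of kernel_dilation on the region a kernel covers follows the standard dilated-convolution arithmetic: along each axis, a kernel of size k with dilation d spans (k - 1) * d + 1 sites. The helper below is a hypothetical sketch of that rule, not part of NetKet.

```python
def effective_kernel_extent(kernel_size, kernel_dilation=(1, 1)):
    # Hypothetical helper: along each axis, a kernel of size k with
    # dilation d spans (k - 1) * d + 1 input sites.
    return tuple((k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation))


print(effective_kernel_extent((2, 3)))          # (2, 3): no dilation
print(effective_kernel_extent((2, 3), (2, 2)))  # (3, 5)
```

Dilation therefore enlarges the receptive field without adding parameters: the kernel still has h * w weights.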

machine_pow: int = 2

exponent to normalize the outputs of __call__.
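One way to read machine_pow: if the network produces normalized conditional probabilities p(σ_i | σ_<i), then setting log ψ(σ) = (1 / machine_pow) Σ_i log p(σ_i | σ_<i) makes |ψ(σ)|^machine_pow a normalized distribution. The sketch below checks this by brute-force enumeration. It assumes real, non-negative amplitudes (no phase), and toy_conditional_probs is a hypothetical stand-in for the network.

```python
import itertools

import numpy as np


def toy_conditional_probs(prefix, i):
    # Hypothetical stand-in for the network: p(s_i | s_<i) for a spin-1/2
    # site, returned as [p(down), p(up)] and normalized by construction.
    logit = 0.3 * i - 0.8 * sum(prefix)
    p_up = 1.0 / (1.0 + np.exp(-logit))
    return np.array([1.0 - p_up, p_up])


def log_psi(config, machine_pow=2):
    # log psi(s) = (1 / machine_pow) * sum_i log p(s_i | s_<i),
    # assuming real, non-negative amplitudes (no phase).
    total = 0.0
    for i, s in enumerate(config):
        total += np.log(toy_conditional_probs(config[:i], i)[s])
    return total / machine_pow


n_sites, machine_pow = 4, 2
# Enumerate all 2**n_sites configurations (local state indices 0 and 1)
norm = sum(
    np.exp(machine_pow * log_psi(c, machine_pow))
    for c in itertools.product([0, 1], repeat=n_sites)
)
print(norm)  # equals 1 up to floating-point rounding
```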

path

precision: Any = None

numerical precision of the computation, see jax.lax.Precision for details.

use_bias: bool = True

whether to add a bias to the output (default: True).

variables

Returns the variables in this module.

layers: int

number of layers.

features: Union[Tuple[int, ...], int]

number of output features in each layer. If a single number is given, all layers except the last one will have that number of features.
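How a single integer might expand into per-layer widths can be sketched as follows. This assumes the last layer outputs one channel per local state (Hilbert.local_size), which is consistent with the (batch, Hilbert.size, Hilbert.local_size) shape returned by conditionals; expand_features is a hypothetical helper, not NetKet's code.

```python
def expand_features(features, layers, local_size):
    # Hypothetical helper. An int gives every hidden layer the same width;
    # the last layer is assumed to output one channel per local state.
    if isinstance(features, int):
        return (features,) * (layers - 1) + (local_size,)
    features = tuple(features)
    if len(features) != layers:
        raise ValueError("expected one feature count per layer")
    return features


print(expand_features(16, layers=3, local_size=2))  # (16, 16, 2)
```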

kernel_size: Tuple[int, int]

shape of the convolutional kernel (h, w). Typically, h = w // 2 + 1.
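The choice h = w // 2 + 1 is easier to see from the causal mask such a kernel needs. The sketch below builds a PixelCNN-style mask under an assumed layout: the current site sits in the last kernel row at column w // 2, every row above is visible, and in the last row only the sites to its left are visible (plus the site itself when exclusive=False). It is an illustration of the idea, not NetKet's implementation. With h = w // 2 + 1, the kernel reaches as many rows upward as columns to each side.

```python
import numpy as np


def causal_mask_2d(kernel_size, exclusive=True):
    # Assumed layout: current site at (h - 1, w // 2). Rows above it are
    # fully visible; in the last row, only sites to the left are visible,
    # plus the current site itself when exclusive is False.
    h, w = kernel_size
    mask = np.ones((h, w), dtype=int)
    mask[-1, w // 2 + (not exclusive):] = 0
    return mask


print(causal_mask_2d((2, 3)))
```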

hilbert: HomogeneousHilbert

the Hilbert space. Only homogeneous unconstrained Hilbert spaces are supported.

Methods
activation()

selu applied separately to the real and imaginary parts of its input.

The docstring of the original function follows.

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Args:

x : input array

Returns:

An array.

See also:

elu()

Return type:

Any
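A NumPy sketch of the split-complex activation described above, using the λ and α constants from the selu formula. The names selu and reim_selu here are local illustrations, not imports from NetKet or JAX. Since α e^x − α = α (e^x − 1), the negative branch is computed with expm1 for accuracy near zero.

```python
import numpy as np

# Constants from the selu definition above
LAMBDA = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717


def selu(x):
    # lambda * x for x > 0, lambda * (alpha * e^x - alpha) otherwise;
    # np.minimum avoids overflow in the unused branch for large x.
    x = np.asarray(x, dtype=float)
    return LAMBDA * np.where(x > 0, x, ALPHA * np.expm1(np.minimum(x, 0.0)))


def reim_selu(z):
    # Apply selu to the real and imaginary parts independently
    z = np.asarray(z)
    if np.iscomplexobj(z):
        return selu(z.real) + 1j * selu(z.imag)
    return selu(z)


print(reim_selu(np.array([1.0 - 1.0j, -0.5 + 2.0j])))
```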

bias_init(shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

Return type:

Any

conditional(inputs, index)

Computes the conditional probabilities for one site to take each value. See AbstractARNN.conditional.

Return type:

Union[ndarray, Array]

conditionals(inputs)

Computes the conditional probabilities for each site to take each value.

Parameters:

inputs (Union[ndarray, Array]) – configurations with dimensions (batch, Hilbert.size).

Return type:

Union[ndarray, Array]

Returns:

The probabilities with dimensions (batch, Hilbert.size, Hilbert.local_size).

Examples

>>> p = model.apply(variables, σ, method=model.conditionals)
>>> print(p[2, 3, :])
[0.3 0.7]
# For the 4th spin of the 3rd sample in the batch (zero-based indices 3 and 2),
# the probability is 0.3 to be spin down (local state index 0)
# and 0.7 to be spin up (local state index 1).
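The conditionals are what make autoregressive sampling possible: each site is drawn from its conditional distribution given the sites already drawn. The loop below is a plain NumPy sketch of that idea. toy_conditional is a hypothetical stand-in for model.conditional, and local states -1/+1 are assumed. This naive version recomputes the conditional from scratch at every step, which is the repeated work the Fast* models are designed to avoid (see netket.nn.FastMaskedConv1D).

```python
import numpy as np


def toy_conditional(inputs, index):
    # Hypothetical stand-in for model.conditional(inputs, index): returns
    # p(s_index | s_<index) with shape (batch, local_size). This toy
    # version only looks at the previous site.
    if index == 0:
        p_up = np.full(len(inputs), 0.5)
    else:
        p_up = np.where(inputs[:, index - 1] == 1, 0.8, 0.3)
    return np.stack([1.0 - p_up, p_up], axis=1)


def sample(n_samples, n_sites, seed=0):
    rng = np.random.default_rng(seed)
    local_states = np.array([-1, 1])  # assumed values of local indices 0 and 1
    sigma = np.zeros((n_samples, n_sites), dtype=int)
    for i in range(n_sites):  # sites are drawn one at a time, in order
        p = toy_conditional(sigma, i)  # only already-drawn sites are read
        cdf = np.cumsum(p, axis=1)
        u = rng.random((n_samples, 1))
        # Inverse-CDF draw; the clip guards against cdf[:, -1] < 1 from rounding
        idx = np.minimum((u > cdf).sum(axis=1), p.shape[1] - 1)
        sigma[:, i] = local_states[idx]
    return sigma


print(sample(n_samples=4, n_sites=6))
```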

conditionals_log_psi(inputs)

Computes the log of the conditional wave-functions for each site to take each value.

Parameters:

inputs (Union[ndarray, Array]) – configurations with dimensions (batch, Hilbert.size).

Return type:

Union[ndarray, Array]

Returns:

The log psi with dimensions (batch, Hilbert.size, Hilbert.local_size).
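The per-site output of conditionals_log_psi can be combined into the log-amplitude of a whole configuration by selecting, at each site, the entry for the state actually taken and summing over sites. The NumPy sketch below does this with random, hypothetical data of the documented shape, assuming real amplitudes and machine_pow = 2 so that the conditional probabilities equal exp(machine_pow * log ψ).

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_sites, local_size, machine_pow = 3, 4, 2, 2

# Hypothetical conditionals p with shape (batch, Hilbert.size, Hilbert.local_size),
# normalized over the last axis
p = rng.random((batch, n_sites, local_size))
p /= p.sum(axis=-1, keepdims=True)
log_psi_cond = np.log(p) / machine_pow  # real amplitudes assumed (no phase)

# Local-state index of every site in each configuration, shape (batch, Hilbert.size)
idx = rng.integers(0, local_size, size=(batch, n_sites))

# Pick the entry of the occupied state at every site, then sum over sites
picked = np.take_along_axis(log_psi_cond, idx[..., None], axis=-1)[..., 0]
log_psi = picked.sum(axis=-1)  # shape (batch,)
print(log_psi)
```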

has_rng(name)

Returns True if a PRNGSequence with the given name exists.

Return type:

bool

Parameters:

name (str) –

is_initializing()

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method for the common case of simple initialization, where setup logic should run only when called under module.init or nn.init. For more complicated multi-phase initialization scenarios, it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

Return type:

bool

kernel_init(shape, dtype=<class 'jax.numpy.float64'>)

Return type:

Any

lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)

Initializes a module without computing on an actual input.

lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a jax.ShapeDtypeStruct which specifies the shape and dtype of the input but no concrete data.

Example:

import jax, jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=256)
variables = model.lazy_init(
    jax.random.PRNGKey(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))

The args and kwargs passed to lazy_init can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed; otherwise lazy_init cannot infer which variables should be initialized.

Parameters:
  • rngs (Union[Array, PRNGKeyArray, Dict[str, Union[Array, PRNGKeyArray]]]) – The rngs for the variable collections.

  • *args – arguments passed to the init function.

  • method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • **kwargs – Keyword arguments passed to the init function.

Return type:

FrozenDict[str, Mapping[str, Any]]

Returns:

The initialized variable dict.

perturb(name, value, collection='perturbations')

Adds a zero-valued variable (‘perturbation’) to the intermediate value.

The gradient of value would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note: this is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory space. Use it only to debug gradients in training.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense3', x)
        return nn.Dense(2)(x)

def loss(params, perturbations, inputs, targets):
  variables = {'params': params, 'perturbations': perturbations}
  preds = model.apply(variables, inputs)
  return jnp.square(preds - targets).mean()

x = jnp.ones((2, 9))
y = jnp.ones((2, 2))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
intm_grads = jax.grad(loss, argnums=1)(
    variables['params'], variables['perturbations'], x, y)
print(intm_grads['dense3']) # ==> [[-1.456924   -0.44332537  0.02422847]
                            #      [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op, so you can easily disable the behavior when it is not needed:

params, perturbations = variables['params'], variables['perturbations']
model.apply(
    {'params': params, 'perturbations': perturbations},
    x) # works as expected
model.apply({'params': params}, x) # behaves like a no-op

Return type:

TypeVar(T)

Parameters:
  • name (str) –

  • value (T) –

  • collection (str) –

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

Parameters:
  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • value (Any) – the new value of the variable.

reshape_inputs(inputs)

Reshapes the inputs from (batch_size, hilbert_size) to (batch_size, spatial_dims…) before sending them to the ARNN layers.

Return type:

Union[ndarray, Array]

Parameters:

inputs (Union[ndarray, Array]) –
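A minimal NumPy sketch of the reshape described above, assuming the sites of a square L × L lattice are numbered in row-major order so that Hilbert.size = L * L. The function name reshape_inputs_2d is hypothetical, and the actual method may handle the lattice geometry differently.

```python
import math

import numpy as np


def reshape_inputs_2d(inputs):
    # Hypothetical sketch: (batch, N) -> (batch, L, L), assuming N = L * L
    # (a square lattice) and row-major site ordering.
    batch, n = inputs.shape
    L = math.isqrt(n)
    if L * L != n:
        raise ValueError(f"Hilbert size {n} is not a perfect square")
    return inputs.reshape(batch, L, L)


x = np.arange(32).reshape(2, 16)  # 2 configurations of 16 sites
print(reshape_inputs_2d(x).shape)  # (2, 4, 4)
```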

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=True, console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), **kwargs)

Creates a summary of the Module represented as a table.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {'width': 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))

print(Foo().tabulate(jax.random.PRNGKey(0), x))

This gives the following output:

                                Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ params              ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │                     │
├─────────┼────────┼───────────────┼───────────────┼─────────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ bias: float32[4]    │
│         │        │               │               │ kernel: float32[9,4]│
│         │        │               │               │                     │
│         │        │               │               │ 40 (160 B)          │
├─────────┼────────┼───────────────┼───────────────┼─────────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ bias: float32[2]    │
│         │        │               │               │ kernel: float32[4,2]│
│         │        │               │               │                     │
│         │        │               │               │ 10 (40 B)           │
├─────────┼────────┼───────────────┼───────────────┼─────────────────────┤
│         │        │               │         Total │ 50 (200 B)          │
└─────────┴────────┴───────────────┴───────────────┴─────────────────────┘

                      Total Parameters: 50 (200 B)

Note: the order of rows in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.
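The parameter counts in the table follow directly from the layer shapes: a Dense layer mapping n_in to n_out features has an (n_in, n_out) kernel and an n_out bias, and each float32 parameter takes 4 bytes. A quick check of the numbers for Foo (dense_params is a local helper for illustration):

```python
def dense_params(n_in, n_out):
    # Dense layer: an (n_in, n_out) kernel plus an n_out bias
    return n_in * n_out + n_out


counts = {"Dense_0": dense_params(9, 4), "Dense_1": dense_params(4, 2)}
total = sum(counts.values())
print(counts, total, total * 4)  # float32 = 4 bytes per parameter
# {'Dense_0': 40, 'Dense_1': 10} 50 200
```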

Parameters:
  • rngs (Union[Array, PRNGKeyArray, Dict[str, Union[Array, PRNGKeyArray]]]) – The rngs for the variable collections as passed to Module.init.

  • *args – The arguments to the forward computation.

  • depth (Optional[int]) – controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor, such that the sum of all rows always adds up to the total number of parameters of the Module.

  • show_repeated (bool) – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except 'intermediates' are mutable.

  • console_kwargs (Optional[Mapping[str, Any]]) – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.

  • table_kwargs (Mapping[str, Any]) – An optional dictionary with additional keyword arguments that are passed to rich.table.Table constructor.

  • column_kwargs (Mapping[str, Any]) – An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.

  • **kwargs – keyword arguments to pass to the forward computation.

Return type:

str

Returns:

A string summarizing the Module.

unbind()

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

An example of a common use case: to extract a sub-Module defined inside setup() and its corresponding variables: 1) temporarily bind the parent Module; and then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.):

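import jax, jax.numpy as jnp
import flax.linen as nn

class Encoder(nn.Module):  # minimal stand-in for illustration
  @nn.compact
  def __call__(self, x):
    return nn.Dense(32)(x)

class Decoder(nn.Module):  # minimal stand-in for illustration
  @nn.compact
  def __call__(self, x):
    return nn.Dense(784)(x)

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x):
    return self.decoder(self.encoder(x))

module = AutoEncoder()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
# Extract the Encoder sub-Module and its variables
encoder, encoder_vars = module.bind(variables).encoder.unbind()

Return type: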

Tuple[TypeVar(M, bound= Module), Mapping[str, Mapping[str, Any]]]

Returns:

A tuple with an unbound copy of this Module and its variables.

Parameters:

self (M) –