netket.models.AbstractARNN

class netket.models.AbstractARNN

Bases: flax.linen.module.Module

Base class for autoregressive neural networks.

Subclasses must implement the methods __call__ and conditionals. They can also override _conditional to implement caching for fast autoregressive sampling; see netket.nn.FastARNNConv1D for an example.

They must also implement the field machine_pow, which specifies the exponent used to normalize the outputs of __call__.
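The relationship between conditionals, __call__ and machine_pow can be illustrated with a small NumPy sketch. It assumes, in line with the description above, that |ψ(σ)|^machine_pow equals the product of the conditional probabilities p(σ_i | σ_<i), so the log-amplitude is the sum of the log-conditionals divided by machine_pow. The toy model, its weights W and b, and the mapping of spins -1/+1 to local indices 0/1 are hypothetical; this is not NetKet's implementation.

```python
import itertools
import numpy as np

N, LOCAL_SIZE = 3, 2                        # hypothetical: 3 sites, 2 local states
rng = np.random.default_rng(0)
W = np.tril(rng.normal(size=(N, N)), k=-1)  # strictly lower triangular: site i sees only j < i
b = rng.normal(size=(N, LOCAL_SIZE))        # per-site biases for each local state

def conditionals(inputs):
    """Return p[batch, i, k] = p(sigma_i = k | sigma_<i), shape (batch, N, LOCAL_SIZE)."""
    h = inputs @ W.T                                    # h[:, i] depends only on sites j < i
    logits = h[:, :, None] * np.array([-1.0, 1.0]) + b  # (batch, N, LOCAL_SIZE)
    logits -= logits.max(axis=-1, keepdims=True)        # numerically stable softmax
    p = np.exp(logits)
    return p / p.sum(axis=-1, keepdims=True)

def log_psi(inputs, machine_pow=2):
    """Log-amplitude with |psi(sigma)|**machine_pow = prod_i p(sigma_i | sigma_<i)."""
    p = conditionals(inputs)
    idx = (inputs > 0).astype(int)                      # assumed: spin -1 -> 0, +1 -> 1
    p_sel = np.take_along_axis(p, idx[:, :, None], axis=-1)[:, :, 0]
    return np.log(p_sel).sum(axis=-1) / machine_pow

# Check normalization by enumerating all 2**N configurations.
configs = np.array(list(itertools.product([-1, 1], repeat=N)), dtype=float)
total = np.exp(2 * log_psi(configs, machine_pow=2)).sum()
print(total)  # equals 1 up to floating-point rounding
```

Because each conditional is normalized over the local states and site i never sees sites j ≥ i, the sum of |ψ|^machine_pow over all configurations is 1 for any choice of machine_pow.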

Attributes
variables

Returns the variables in this module.

Return type

Mapping[str, Mapping[str, Any]]

hilbert: netket.hilbert.HomogeneousHilbert

The Hilbert space. Only homogeneous unconstrained Hilbert spaces are supported.

Methods
abstract conditionals(inputs)

Computes the conditional probabilities for each site to take each value.

Parameters

inputs (Union[ndarray, DeviceArray, Tracer]) – configurations with dimensions (batch, Hilbert.size).

Return type

Union[ndarray, DeviceArray, Tracer]

Returns

The probabilities with dimensions (batch, Hilbert.size, Hilbert.local_size).

Examples

>>> p = model.apply(variables, σ, method=model.conditionals)
>>> print(p[2, 3, :])
[0.3 0.7]
# For the 4th spin (index 3) of the 3rd sample (index 2) in the batch,
# the probability of being spin down (local state index 0) is 0.3,
# and the probability of being spin up (local state index 1) is 0.7.
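One way to consume the output of conditionals is ancestral sampling: site i is drawn from p[:, i, :] once sites 0 to i-1 have been fixed. The NumPy sketch below assumes a hypothetical toy conditionals function and the local-state ordering from the example above (index 0 = spin down). It recomputes all conditionals at every step; per the class description, subclasses can override _conditional to cache intermediate results and make this faster. It is not NetKet's sampler.

```python
import numpy as np

N = 4                                   # number of sites (hypothetical)
LOCAL_STATES = np.array([-1.0, 1.0])    # assumed: index 0 = spin down, index 1 = spin up

def conditionals(inputs):
    """Hypothetical toy model returning p[b, i, k] = p(sigma_i = k | sigma_<i).

    The probability of spin up at site i depends only on the sum of the spins
    at sites j < i, so the model is autoregressive by construction.
    """
    prev_sum = np.cumsum(inputs, axis=1) - inputs   # sum over j < i, shape (batch, N)
    p_up = 1.0 / (1.0 + np.exp(-0.5 * prev_sum))
    return np.stack([1.0 - p_up, p_up], axis=-1)    # shape (batch, N, 2)

def sample(n_samples, rng):
    """Draw configurations site by site (ancestral sampling), without any caching."""
    sigma = np.zeros((n_samples, N))    # sites not yet sampled hold a placeholder 0
    for i in range(N):
        p_i = conditionals(sigma)[:, i, :]          # depends only on sites < i
        u = rng.random((n_samples, 1))
        cdf = np.cumsum(p_i, axis=-1)
        # Inverse-CDF draw; the clip guards against cdf[-1] rounding slightly below 1.
        k = np.minimum((u > cdf).sum(axis=-1), len(LOCAL_STATES) - 1)
        sigma[:, i] = LOCAL_STATES[k]
    return sigma

samples = sample(1000, np.random.default_rng(0))
print(samples[:3])
```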
has_rng(name)

Returns True if a PRNGSequence with the given name exists.

Return type

bool

Parameters

name (str) – the name of the PRNG sequence.

is_initializing()

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method to handle the common case of simple initialization, where we want setup logic to run only when called under module.init or nn.init. For more complicated multi-phase initialization scenarios, it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

Return type

bool

put_variable(col, name, value)

Sets the value of a Variable.

Parameters
  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • value (Any) – the new value of the variable.

Returns

None.

tabulate(rngs, *args, method=None, mutable=True, depth=None, exclude_methods=(), **kwargs)

Creates a summary of the Module represented as a table.

This method has the same signature as init, but instead of returning the variables, it returns a string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))

print(Foo().tabulate(jax.random.PRNGKey(0), x))

This gives the following output:

                   Foo Summary
┏━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ outputs       ┃ params               ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ Inputs  │ float32[16,9] │                      │
├─────────┼───────────────┼──────────────────────┤
│ Dense_0 │ float32[16,4] │ bias: float32[4]     │
│         │               │ kernel: float32[9,4] │
│         │               │                      │
│         │               │ 40 (160 B)           │
├─────────┼───────────────┼──────────────────────┤
│ Dense_1 │ float32[16,2] │ bias: float32[2]     │
│         │               │ kernel: float32[4,2] │
│         │               │                      │
│         │               │ 10 (40 B)            │
├─────────┼───────────────┼──────────────────────┤
│ Foo     │ float32[16,2] │                      │
├─────────┼───────────────┼──────────────────────┤
│         │         Total │ 50 (200 B)           │
└─────────┴───────────────┴──────────────────────┘

          Total Parameters: 50 (200 B)

Note: the order of rows in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Parameters
  • rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections.

  • *args – The arguments to the forward computation.

  • method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except ‘intermediates’ are mutable.

  • depth (Optional[int]) – controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes are added to the row of its first shown ancestor, so that the sum of all rows always equals the total number of parameters of the Module.

  • exclude_methods (Sequence[str]) – A sequence of strings that specifies which methods should be ignored. In case a module calls a helper method from its main method, use this argument to exclude the helper method from the summary to avoid ambiguity.

  • **kwargs – keyword arguments to pass to the forward computation.

Return type

str

Returns

A string summarizing the Module.