netket.models.AbstractARNN

class netket.models.AbstractARNN

Bases: Module

Base class for autoregressive neural networks.

Subclasses must implement the method conditionals_log_psi, or override the methods __call__ and conditional if desired.

Subclasses can override conditional to implement caching for fast autoregressive sampling. See netket.models.FastARNNConv1D for an example.

Subclasses must also define the field machine_pow, which specifies the exponent used to normalize the outputs of __call__.
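To make the role of machine_pow concrete, here is a minimal NumPy sketch (NumPy is used in place of jax.numpy, which provides the same functions used here). It assumes that the conditional log amplitudes returned by conditionals_log_psi should be normalized so that, for every sample and site, the sum over local states k of |psi(k)|**machine_pow equals 1. The helper normalize_conditionals and the random logits are hypothetical and not part of NetKet.

```python
import numpy as np


def normalize_conditionals(logits, machine_pow=2):
    """Hypothetical helper: shift raw log amplitudes so that, for every sample
    and site, sum_k |psi(k)|**machine_pow == 1 over the local states k.

    logits: real or complex array with shape (batch, size, local_size).
    """
    scaled = machine_pow * np.real(logits)
    # Numerically stable log-sum-exp over the local-state axis.
    m = scaled.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(scaled - m).sum(axis=-1, keepdims=True))
    # Subtracting a real number rescales the modulus and leaves the phase intact.
    return logits - lse / machine_pow


rng = np.random.default_rng(0)
logits = rng.normal(size=(3, 4, 2)) + 1j * rng.normal(size=(3, 4, 2))
log_psi = normalize_conditionals(logits, machine_pow=2)

# exp(machine_pow * Re(log psi)) = |psi|**machine_pow is now a probability
# distribution over the local states of each site.
probs = np.exp(2 * np.real(log_psi))
print(np.allclose(probs.sum(axis=-1), 1.0))  # True
```

Normalizing each site separately also normalizes the full wave function, since |psi(sigma)|**machine_pow is the product of these per-site probabilities.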

Attributes
hilbert: HomogeneousHilbert

The Hilbert space. Only homogeneous, unconstrained Hilbert spaces are supported.

Methods
__call__(inputs)

Computes the log wave-functions for input configurations.

Parameters:

inputs (Union[ndarray, Array]) – configurations with dimensions (batch, Hilbert.size).

Return type:

Union[ndarray, Array]

Returns:

The log psi with dimensions (batch,).
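When a subclass only implements conditionals_log_psi, the log wave function of a full configuration follows from the chain rule: psi(sigma) is the product over sites i of the conditional amplitude psi_i(sigma_i | sigma_0, ..., sigma_{i-1}), so log psi is a sum over sites. The NumPy sketch below assumes the configurations have already been converted to local-state indices (0 to local_size - 1). The function name log_psi_from_conditionals and the toy numbers are hypothetical, and this is not necessarily how NetKet implements __call__.

```python
import numpy as np


def log_psi_from_conditionals(cond_log_psi, local_indices):
    """Combine per-site conditional log amplitudes into log psi(sigma).

    cond_log_psi: shape (batch, size, local_size), e.g. conditionals_log_psi output.
    local_indices: int array, shape (batch, size), local-state index of each site.
    Returns an array with shape (batch,).
    """
    # For every site, pick the conditional log amplitude of the occupied state.
    picked = np.take_along_axis(cond_log_psi, local_indices[..., None], axis=-1)
    # psi is a product of conditionals, so log psi is their sum over sites.
    return picked[..., 0].sum(axis=-1)


# Toy conditional probabilities: 2 samples, 3 sites, 2 local states.
probs = np.array([
    [[0.5, 0.5], [0.2, 0.8], [0.9, 0.1]],
    [[0.5, 0.5], [0.7, 0.3], [0.4, 0.6]],
])
cond_log_psi = 0.5 * np.log(probs)  # log amplitudes with |psi|**2 == probs
local_indices = np.array([[1, 1, 0], [0, 0, 1]])

log_psi = log_psi_from_conditionals(cond_log_psi, local_indices)
print(np.exp(2 * log_psi))  # approximately [0.36, 0.21]
```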

conditional(inputs, index)

Computes the conditional probabilities for one site to take each value.

It should only be called successively with indices 0, 1, 2, …, as in the autoregressive sampling procedure.

Parameters:
  • inputs (Union[ndarray, Array]) – configurations of partially sampled sites with dimensions (batch, Hilbert.size), where the sites that index depends on must already be sampled.

  • index (int) – index of the site being queried.

Return type:

Union[ndarray, Array]

Returns:

The probabilities with dimensions (batch, Hilbert.local_size).
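The autoregressive sampling procedure mentioned above can be sketched as a loop that queries one site at a time, in the order 0, 1, 2 and so on, filling in each site before the next one is queried. In this NumPy sketch, sample_autoregressive and toy_conditional are hypothetical stand-ins: toy_conditional plays the role of conditional for a 1D spin chain in which each spin copies its left neighbour with probability 0.9. This is not NetKet's sampler.

```python
import numpy as np


def sample_autoregressive(conditional_fn, batch, size, local_states, rng):
    """Draw `batch` configurations site by site.

    conditional_fn(inputs, index) must return probabilities with shape
    (batch, len(local_states)) for site `index`, reading only sites < index.
    """
    inputs = np.zeros((batch, size))  # placeholder values for unsampled sites
    for index in range(size):  # successive indices 0, 1, 2, ...
        p = conditional_fn(inputs, index)
        # Inverse-CDF sampling of one local-state index per sample.
        u = rng.random((batch, 1))
        k = (np.cumsum(p, axis=-1) < u).sum(axis=-1)
        k = np.minimum(k, p.shape[-1] - 1)  # guard against round-off
        inputs[:, index] = local_states[k]
    return inputs


local_states = np.array([-1.0, 1.0])  # spin down, spin up


def toy_conditional(inputs, index):
    # Site 0 is uniform; later sites copy their left neighbour with prob. 0.9.
    batch = inputs.shape[0]
    if index == 0:
        return np.full((batch, 2), 0.5)
    p_up = np.where(inputs[:, index - 1] > 0, 0.9, 0.1)
    return np.stack([1.0 - p_up, p_up], axis=-1)


rng = np.random.default_rng(1)
samples = sample_autoregressive(toy_conditional, 1000, 6, local_states, rng)
print(samples.shape)  # (1000, 6)
```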

conditionals(inputs)

Computes the conditional probabilities for each site to take each value.

Parameters:

inputs (Union[ndarray, Array]) – configurations with dimensions (batch, Hilbert.size).

Return type:

Union[ndarray, Array]

Returns:

The probabilities with dimensions (batch, Hilbert.size, Hilbert.local_size).

Examples

>>> import pytest; pytest.skip("skip automated test of this docstring")
>>>
>>> p = model.apply(variables, σ, method=model.conditionals)
>>> print(p[2, 3, :])
[0.3 0.7]
# For the 4th site (index 3) of the 3rd sample (index 2) in the batch,
# the probability of spin down (local state index 0) is 0.3,
# and the probability of spin up (local state index 1) is 0.7.
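The example above presupposes an initialized model. The relation between conditionals and conditionals_log_psi can be illustrated without one, assuming (as the machine_pow field suggests) that the probabilities are |psi|**machine_pow = exp(machine_pow * Re(log psi)). In this NumPy sketch the array log_psi is made up so that entry [2, 3] reproduces the numbers printed above; the random phases show that only the real part of log psi affects the probabilities.

```python
import numpy as np

machine_pow = 2
batch, size, local_size = 4, 5, 2

# Made-up conditional probabilities; entry [2, 3] is sample 2, site 3 (0-based).
target = np.full((batch, size, local_size), 0.5)
target[2, 3] = [0.3, 0.7]

# Log amplitudes with arbitrary phases and |psi|**machine_pow == target.
rng = np.random.default_rng(0)
phases = rng.uniform(0, 2 * np.pi, size=target.shape)
log_psi = np.log(target) / machine_pow + 1j * phases

# Probabilities as conditionals would return them: only Re(log psi) matters.
p = np.exp(machine_pow * np.real(log_psi))
print(p[2, 3, :])  # approximately [0.3, 0.7]
```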
abstract conditionals_log_psi(inputs)

Computes the log of the conditional wave-functions for each site to take each value.

Parameters:

inputs (Union[ndarray, Array]) – configurations with dimensions (batch, Hilbert.size).

Return type:

Union[ndarray, Array]

Returns:

The log psi with dimensions (batch, Hilbert.size, Hilbert.local_size).
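An implementation of conditionals_log_psi has to respect the autoregressive ordering: the conditionals of site i may depend only on sites with smaller indices, and for each site they should be normalized over the local states. The hypothetical ToyARNN below (plain NumPy, neither a flax Module nor a NetKet class) enforces the first property with a strictly lower-triangular mask on a single linear layer and the second with a log-sum-exp over the local states, using machine_pow = 2.

```python
import numpy as np


class ToyARNN:
    """Hypothetical one-layer autoregressive model (not a NetKet class).

    The logits of site i are a linear function of the inputs at sites j < i
    only, enforced by a strictly lower-triangular mask on the weights.
    """

    def __init__(self, size, local_size=2, machine_pow=2, seed=0):
        rng = np.random.default_rng(seed)
        self.machine_pow = machine_pow
        # mask[i, j] = 1 only for j < i, so site i never sees sites >= i.
        self.mask = np.tril(np.ones((size, size)), k=-1)
        self.weights = rng.normal(size=(size, size, local_size))
        self.bias = rng.normal(size=(size, local_size))

    def conditionals_log_psi(self, inputs):
        # logits[b, i, k] = bias[i, k] + sum_j mask[i, j] * W[i, j, k] * inputs[b, j]
        w = self.weights * self.mask[..., None]
        logits = self.bias + np.einsum("ijk,bj->bik", w, inputs)
        # Normalize each site so that sum_k |psi|**machine_pow == 1.
        scaled = self.machine_pow * logits
        m = scaled.max(axis=-1, keepdims=True)
        lse = m + np.log(np.exp(scaled - m).sum(axis=-1, keepdims=True))
        return logits - lse / self.machine_pow


model = ToyARNN(size=5)
rng = np.random.default_rng(1)
x = rng.choice([-1.0, 1.0], size=(4, 5))
out = model.conditionals_log_psi(x)
print(out.shape)  # (4, 5, 2)

# Flipping site 3 can change the conditionals of site 4, never those of sites 0 to 3.
x_flipped = x.copy()
x_flipped[:, 3] *= -1
same = np.allclose(out[:, :4], model.conditionals_log_psi(x_flipped)[:, :4])
print(same)  # True
```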