netket.nn.FastMaskedDense1D

class netket.nn.FastMaskedDense1D

Bases: Module

1D masked linear transformation module for fast autoregressive neural networks.

See netket.models.FastARNNSequential for a brief explanation of fast autoregressive sampling.

TODO: FastMaskedDense1D does not support JIT yet, because it involves slicing the cached inputs and the weights with a dynamic shape.

Attributes
precision: Any = None

numerical precision of the computation, see jax.lax.Precision for details.

use_bias: bool = True

whether to add a bias to the output (default: True).

size: int

number of sites.

features: int

output feature density, should be the last dimension.

exclusive: bool

True if an output element does not depend on the input element at the same index.
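As a rough illustration of what exclusive changes, the sketch below builds a site-level dependency mask in plain NumPy. The helper autoregressive_mask and the convention that entry [i, j] links input site i to output site j are assumptions made for this example, not NetKet's actual code.

```python
import numpy as np


def autoregressive_mask(size, exclusive):
    """Hypothetical helper: mask[i, j] == 1 means output site j may depend on input site i."""
    # k=0 keeps the diagonal (output j sees inputs 0..j);
    # k=1 drops it (output j sees only inputs 0..j-1).
    return np.triu(np.ones((size, size), dtype=int), k=int(exclusive))


print(autoregressive_mask(3, exclusive=False))
print(autoregressive_mask(3, exclusive=True))
```

With exclusive=True the diagonal is zero, so the first output site depends on no input site at all.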

kernel_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]]

initializer for the weight matrix.

bias_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]]

initializer for the bias.

Methods
__call__(inputs)

Applies the masked linear transformation to all input sites.

Parameters:

inputs (Union[ndarray, Array]) – input data with dimensions (batch, size, features).

Return type:

Union[ndarray, Array]

Returns:

The transformed data.
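The following is a minimal NumPy sketch of a masked linear transformation over all sites, intended only to show the autoregressive property. The function masked_dense_1d, the kernel layout (size, in_features, size, features), and the separate name in_features for the input feature dimension are assumptions for illustration; NetKet's own parameter layout may differ.

```python
import numpy as np


def masked_dense_1d(inputs, kernel, bias, exclusive):
    """Hypothetical stand-in for a masked dense layer.

    inputs: (batch, size, in_features)
    kernel: (size, in_features, size, features)  # assumed layout
    bias:   (size, features)
    """
    size = inputs.shape[1]
    # mask[i, j] == 1 if output site j may depend on input site i.
    mask = np.triu(np.ones((size, size)), k=int(exclusive))
    masked_kernel = kernel * mask[:, None, :, None]
    # Contract over input sites i and input features f.
    return np.einsum("bif,ifjo->bjo", inputs, masked_kernel) + bias


rng = np.random.default_rng(0)
batch, size, in_features, features = 2, 4, 3, 5
x = rng.normal(size=(batch, size, in_features))
kernel = rng.normal(size=(size, in_features, size, features))
bias = rng.normal(size=(size, features))

y = masked_dense_1d(x, kernel, bias, exclusive=True)

# Perturb input site 2. With exclusive=True, output sites 0..2 must not change.
x_perturbed = x.copy()
x_perturbed[:, 2, :] += 1.0
y_perturbed = masked_dense_1d(x_perturbed, kernel, bias, exclusive=True)
unchanged = np.allclose(y[:, :3], y_perturbed[:, :3])
print(y.shape, unchanged)
```

The perturbation check is a quick way to confirm that no output site sees an input site it should not depend on.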

update_site(inputs, index)

Adds an input site into the cache, and applies the masked linear transformation to the cache.

Parameters:
  • inputs (Union[ndarray, Array]) – an input site to be added into the cache with dimensions (batch, features).

  • index (int) – the index of the output site. The index of the input site should be index - self.exclusive.

Return type:

Union[ndarray, Array]

Returns:

The output site with dimensions (batch, features).
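The caching idea can be sketched in plain NumPy as follows. The class CachedMaskedDense, its kernel layout, and the zero placeholder passed for the first site are hypothetical choices for this example and are not NetKet's implementation. The slices taken in update_site have a length that depends on index, which is the kind of dynamic shape that makes JIT compilation difficult.

```python
import numpy as np


class CachedMaskedDense:
    """Hypothetical NumPy stand-in illustrating site-by-site evaluation with a cache."""

    def __init__(self, kernel, bias, exclusive):
        self.kernel = kernel  # assumed layout: (size, in_features, size, features)
        self.bias = bias      # (size, features)
        self.exclusive = exclusive
        self.cache = None     # filled lazily with shape (batch, size, in_features)

    def update_site(self, inputs, index):
        batch = inputs.shape[0]
        size, in_features = self.kernel.shape[:2]
        features = self.kernel.shape[-1]
        if self.cache is None:
            self.cache = np.zeros((batch, size, in_features))
        # The input site that belongs to this output site.
        in_index = index - int(self.exclusive)
        if in_index >= 0:
            self.cache[:, in_index, :] = inputs
        n = in_index + 1  # number of input sites visible to output site `index`
        if n == 0:
            # Exclusive layer, first site: no input is visible, only the bias remains.
            return np.broadcast_to(self.bias[index], (batch, features)).copy()
        visible = self.cache[:, :n, :]          # slice length depends on `index`
        weights = self.kernel[:n, :, index, :]  # same dynamic slice on the weights
        return np.einsum("bif,ifo->bo", visible, weights) + self.bias[index]


rng = np.random.default_rng(0)
batch, size, in_features, features = 2, 4, 3, 5
x = rng.normal(size=(batch, size, in_features))
kernel = rng.normal(size=(size, in_features, size, features))
bias = rng.normal(size=(size, features))

layer = CachedMaskedDense(kernel, bias, exclusive=True)
outputs = []
for index in range(size):
    # For an exclusive layer the input site is index - 1; site 0 gets a placeholder.
    site = x[:, index - 1, :] if index > 0 else np.zeros((batch, in_features))
    outputs.append(layer.update_site(site, index))
y_fast = np.stack(outputs, axis=1)

# Reference: the full masked transformation applied to all sites at once.
mask = np.triu(np.ones((size, size)), k=1)
y_full = np.einsum("bif,ifjo->bjo", x, kernel * mask[:, None, :, None]) + bias
matches = np.allclose(y_fast, y_full)
print(y_fast.shape, matches)
```

Each update_site call touches only the weights feeding one output site, so a full pass over all sites avoids recomputing earlier outputs while reproducing the result of the all-sites transformation.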