Bases: Module
1D masked linear transformation module for fast autoregressive neural networks.
See netket.models.FastARNNSequential for a brief explanation of fast autoregressive sampling.
TODO: FastMaskedDense1D does not support JIT yet, because it slices the cached inputs and the weights to shapes that depend on the index.
- Attributes
-
precision: Any = None
numerical precision of the computation, see jax.lax.Precision for details.
-
use_bias: bool = True
whether to add a bias to the output (default True).
-
size: int
number of sites.
-
features: int
number of output features per site (the size of the last output dimension).
-
exclusive: bool
True if an output element does not depend on the input element at the same index.
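To make the `exclusive` rule concrete, the sketch below builds the autoregressive mask in plain NumPy. It is an illustration, not NetKet's code: the helper names `site_mask` and `feature_mask` are hypothetical, and the flattened kernel layout `(size * in_features, size * features)` is an assumption about how such a mask would be applied.

```python
import numpy as np


def site_mask(size, exclusive):
    # mask[j, i] == 1 iff output site i may read input site j:
    # j <= i when exclusive is False, j < i when exclusive is True.
    # np.triu with offset k keeps the entries where i - j >= k.
    return np.triu(np.ones((size, size)), k=int(exclusive))


def feature_mask(size, in_features, features, exclusive):
    # Hypothetical helper: expand each site entry into an
    # (in_features, features) block so the mask matches a flattened
    # kernel of shape (size * in_features, size * features).
    return np.kron(site_mask(size, exclusive), np.ones((in_features, features)))


print(site_mask(3, exclusive=False))
print(site_mask(3, exclusive=True))
```

With `exclusive=True` the diagonal is zero, so no output site sees the input at its own index.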
-
kernel_init: Callable[[Any, Sequence[int], None | str | type[Any] | dtype | _SupportsDType], Array]
initializer for the weight matrix.
-
bias_init: Callable[[Any, Sequence[int], None | str | type[Any] | dtype | _SupportsDType], Array]
initializer for the bias.
- Methods
-
__call__(inputs)
Applies the masked linear transformation to all input sites.
- Parameters:
inputs (Union[ndarray, Array]) – input data with dimensions (batch, size, features).
- Return type:
Union[ndarray, Array]
- Returns:
The transformed data.
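The full (non-incremental) transformation can be sketched in plain NumPy as a dense contraction with a masked kernel. This is a reference sketch under stated assumptions, not the module's implementation: the function name `masked_dense_1d` is hypothetical, the kernel is assumed to have the 4D layout `(size, in_features, size, features)`, and `bias=None` stands in for `use_bias=False`.

```python
import numpy as np


def masked_dense_1d(inputs, kernel, bias, exclusive):
    """Hypothetical reference for the masked transformation.

    inputs: (batch, size, in_features)
    kernel: (size, in_features, size, features)  # assumed layout
    bias:   (size, features), or None to skip the bias
    """
    size = inputs.shape[1]
    # mask[j, i] == 1 iff output site i may depend on input site j.
    mask = np.triu(np.ones((size, size)), k=int(exclusive))
    masked_kernel = kernel * mask[:, None, :, None]
    # y[b, i, f] = sum_{j, c} x[b, j, c] * W[j, c, i, f]
    y = np.einsum("bjc,jcif->bif", inputs, masked_kernel)
    if bias is not None:
        y = y + bias
    return y
```

Perturbing input site `k` changes only outputs at sites `i >= k + exclusive`, which is the autoregressive property the mask is meant to enforce.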
-
update_site(inputs, index)
Adds an input site to the cache, and applies the masked linear transformation to the cache.
- Parameters:
inputs (Union[ndarray, Array]) – an input site to be added into the cache with dimensions (batch, features).
index (int) – the index of the output site. The input site passed in must be the one at index - self.exclusive (i.e. index - 1 when exclusive is True).
- Return type:
Union[ndarray, Array]
- Returns:
The output site with dimensions (batch, features).
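The caching scheme behind `update_site` can be sketched in plain NumPy as below. The class `MaskedDense1DCache` is hypothetical and is not NetKet's API; it assumes the same 4D kernel layout `(size, in_features, size, features)` as a full masked transformation, and it silently ignores the input at `index = 0` when `exclusive` is True, since no input site precedes site 0.

```python
import numpy as np


class MaskedDense1DCache:
    """Hypothetical NumPy stand-in for the cached, site-by-site path."""

    def __init__(self, kernel, bias, exclusive, batch):
        self.kernel = kernel  # (size, in_features, size, features), assumed layout
        self.bias = bias      # (size, features)
        self.exclusive = int(exclusive)
        size, in_features = kernel.shape[:2]
        # Cache of input sites seen so far, initialised with zeros.
        self.cache = np.zeros((batch, size, in_features))

    def update_site(self, inputs, index):
        # The new input belongs to site index - exclusive; skip it if negative.
        j = index - self.exclusive
        if j >= 0:
            self.cache[:, j, :] = inputs
        # Output site `index` reads input sites 0 .. index - exclusive.
        n = index + 1 - self.exclusive
        # These slices have an index-dependent shape, which is what
        # makes this pattern hard to trace under jax.jit.
        cache_i = self.cache[:, :n, :]           # (batch, n, in_features)
        kernel_i = self.kernel[:n, :, index, :]  # (n, in_features, features)
        y = np.einsum("bjc,jcf->bf", cache_i, kernel_i)
        return y + self.bias[index]              # (batch, features)
```

During autoregressive sampling, the caller would invoke `update_site` once per site in order 0, 1, ..., size - 1. Each call costs work proportional to the number of cached sites rather than to the full input, and the outputs match a full masked transformation applied to the complete configuration.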