Bases: Module
1D masked linear transformation module for fast autoregressive neural networks.
See netket.models.FastARNNSequential
for a brief explanation of fast autoregressive sampling.
TODO: FastMaskedDense1D does not support JIT yet, because it involves slicing the cached inputs
and the weights with a dynamic shape.
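The masking rule can be illustrated with a minimal sketch. It assumes, for illustration only and not as a description of NetKet's code, that the mask is a (size, size) 0/1 matrix M with M[j, i] = 1 when output site i may depend on input site j. The helper `autoregressive_mask` is hypothetical. With exclusive=False the diagonal is kept (j <= i); with exclusive=True it is dropped (j < i), matching the meaning of the `exclusive` attribute.

```python
import jax.numpy as jnp


def autoregressive_mask(size: int, exclusive: bool) -> jnp.ndarray:
    """Hypothetical helper: M[j, i] = 1 iff output site i may see input site j."""
    # jnp.triu keeps entries (j, i) with i - j >= k, so k=0 allows j <= i
    # and k=1 allows only j < i (the exclusive case).
    return jnp.triu(jnp.ones((size, size)), k=int(exclusive))


print(autoregressive_mask(3, exclusive=False))
print(autoregressive_mask(3, exclusive=True))
```

In the exclusive case the first column is all zeros, so the first output site sees no inputs and a dense layer's output there is set by the bias alone.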
- Attributes
- precision: Any = None
  numerical precision of the computation; see jax.lax.Precision for details.
- use_bias: bool = True
  whether to add a bias to the output (default True).
- size: int
  number of sites.
- features: int
  number of output features per site, i.e. the size of the last dimension of the output.
- exclusive: bool
  True if an output element does not depend on the input element at the same index.
- kernel_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]]
  initializer for the weight matrix.
- bias_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]]
  initializer for the bias.
- Methods
- __call__(inputs)
  Applies the masked linear transformation to all input sites.
  - Parameters: inputs (Union[ndarray, Array]) – input data with dimensions (batch, size, features).
  - Return type: Union[ndarray, Array]
  - Returns: The transformed data, with dimensions (batch, size, features).
- update_site(inputs, index)
  Adds an input site to the cache and applies the masked linear transformation to the cache.
  - Parameters:
    inputs (Union[ndarray, Array]) – an input site to be added to the cache, with dimensions (batch, features).
    index (int) – the index of the output site; the corresponding input site has index index - self.exclusive.
  - Return type: Union[ndarray, Array]
  - Returns: The output site with dimensions (batch, features).
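The relation between `__call__` and `update_site` can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not NetKet's implementation: the functions `masked_dense_1d` and `update_site_sketch` are hypothetical, the kernel is assumed to have layout (size, in_features, size, features), and the cache is assumed to be a zero-initialized (batch, size, in_features) array. Feeding the sites one at a time through the cached update should reproduce the full masked pass.

```python
import jax
import jax.numpy as jnp


def masked_dense_1d(x, kernel, bias, exclusive):
    """Hypothetical full masked pass (what __call__ is sketched to compute).

    x: (batch, size, in_features); kernel: (size, in_features, size, features)
    indexed [j, a, i, b]; bias: (size, features). Returns (batch, size, features).
    """
    size = x.shape[1]
    # mask[j, i] = 1 iff output site i may depend on input site j
    mask = jnp.triu(jnp.ones((size, size), dtype=kernel.dtype), k=int(exclusive))
    return jnp.einsum("nja,jaib->nib", x, kernel * mask[:, None, :, None]) + bias


def update_site_sketch(cache, x_site, index, kernel, bias, exclusive):
    """Hypothetical cached single-site pass (what update_site is sketched to do).

    cache: (batch, size, in_features) holding the inputs added so far;
    x_site: (batch, in_features), the input for site index - exclusive.
    Returns (new_cache, y_site) with y_site of shape (batch, features).
    """
    in_index = index - int(exclusive)
    if in_index >= 0:
        # JAX arrays are immutable, so .at[].set() returns an updated copy.
        cache = cache.at[:, in_index].set(x_site)
    n_visible = in_index + 1  # input sites visible to output site `index`
    if n_visible == 0:
        # Exclusive layer at index 0: no visible inputs, only the bias remains.
        y_site = jnp.zeros((cache.shape[0], kernel.shape[-1]), dtype=kernel.dtype)
    else:
        # The slice length depends on `index`; this dynamic shape is why a
        # traced `index` cannot be used under jax.jit.
        y_site = jnp.einsum(
            "nja,jab->nb", cache[:, :n_visible], kernel[:n_visible, :, index, :]
        )
    return cache, y_site + bias[index]


# Compare site-by-site evaluation with the full pass on random data.
batch, size, in_f, feat = 2, 4, 3, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (batch, size, in_f))
kernel = jax.random.normal(k2, (size, in_f, size, feat))
bias = jax.random.normal(k3, (size, feat))

results = {}
for exclusive in (False, True):
    cache = jnp.zeros((batch, size, in_f))
    outputs = []
    for i in range(size):
        j = i - int(exclusive)
        x_site = x[:, j] if j >= 0 else jnp.zeros((batch, in_f))
        cache, y_i = update_site_sketch(cache, x_site, i, kernel, bias, exclusive)
        outputs.append(y_i)
    y_sites = jnp.stack(outputs, axis=1)
    results[exclusive] = (y_sites, masked_dense_1d(x, kernel, bias, exclusive))
```

Each cached step contracts only the input sites visible to the current output, so a step costs work proportional to the index rather than a full pass over all sites; the price is the index-dependent slice shape that the TODO above refers to.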