Bases: Module
1D convolution module with a mask, for fast autoregressive neural networks.
The fast autoregressive sampling scheme is described in Ramachandran et al.
To generate one sample with an autoregressive network, the network must be evaluated N times, where N is
the number of input sites. Since only one input site changes at each step, unchanged intermediate results can be cached
to avoid repeated computation.
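The caching argument can be checked on a toy model. The sketch below samples binary sites from a single exclusive masked convolution followed by a sigmoid. It runs the sampler twice:

- once by re-evaluating the whole layer at every step;
- once by keeping only the last `len(w)` sampled sites in a small window.

It is written in NumPy rather than JAX and fixes the dilation to 1. All function names are hypothetical and are not part of the library. Given the same uniform random numbers, both loops produce identical samples. The cached loop does O(kernel_size) work per step instead of O(N · kernel_size).

```python
from collections import deque

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def full_masked_conv(s, w, b):
    """Exclusive causal conv, dilation 1: out[i] = b + sum_j w[j] * s[i - 1 - j]."""
    n, k = len(s), len(w)
    out = np.full(n, b, dtype=float)
    for i in range(n):
        for j in range(k):
            if i - 1 - j >= 0:  # sites before the chain start act as zero padding
                out[i] += w[j] * s[i - 1 - j]
    return out


def sample_naive(u, w, b):
    """Re-evaluates the whole layer at every step: N evaluations of cost O(N * len(w))."""
    s = np.zeros(len(u))
    for i in range(len(u)):
        p = sigmoid(full_masked_conv(s, w, b)[i])
        s[i] = float(u[i] < p)  # Bernoulli sample driven by the pre-drawn uniform u[i]
    return s


def sample_cached(u, w, b):
    """Keeps only the last len(w) sampled sites: N evaluations of cost O(len(w))."""
    window = deque([0.0] * len(w), maxlen=len(w))  # window[-1] is the newest site
    s = np.zeros(len(u))
    for i in range(len(u)):
        logit = b
        for j in range(len(w)):
            logit += w[j] * window[-1 - j]  # window[-1 - j] holds site i - 1 - j
        s[i] = float(u[i] < sigmoid(logit))
        window.append(s[i])  # the oldest site drops out automatically
    return s


rng = np.random.default_rng(0)
u = rng.uniform(size=12)
w = rng.normal(size=3)
print(np.array_equal(sample_naive(u, w, 0.1), sample_cached(u, w, 0.1)))  # True
```

In a real sampler, each new site is drawn from the conditional computed at the previous step. That is why only one input site changes per evaluation.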
- Attributes
- feature_group_count: int = 1 – if specified, divides the input features into groups.
- precision: Any = None – numerical precision of the computation; see jax.lax.Precision for details.
- use_bias: bool = True – whether to add a bias to the output.
- features: int – number of convolution filters.
- kernel_size: int – length of the convolutional kernel.
- kernel_dilation: int – dilation factor of the convolution kernel.
- exclusive: bool – True if an output element does not depend on the input element at the same index.
- kernel_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]] – initializer for the convolutional kernel.
- bias_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]] – initializer for the bias.
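Together, `kernel_size`, `kernel_dilation` and `exclusive` determine which input sites each output site reads. The helper below, `receptive_field`, is hypothetical. It assumes that tap `j` of output site `index` reads input site `index - exclusive - j * kernel_dilation`, with negative sites treated as zero padding. This convention was chosen for illustration only. The library's exact mask, for example how many taps it keeps when `exclusive` is True, may differ.

```python
def receptive_field(index, kernel_size, kernel_dilation, exclusive):
    """Input sites read by output site `index` (hypothetical helper).

    Assumed convention: tap j reads input site
    index - exclusive - j * kernel_dilation; negative sites are zero padding.
    """
    offset = int(exclusive)
    sites = [index - offset - j * kernel_dilation for j in range(kernel_size)]
    return [s for s in sites if s >= 0]


# Non-exclusive: output 5 sees input 5 and two earlier, dilated sites.
print(receptive_field(5, kernel_size=3, kernel_dilation=2, exclusive=False))  # [5, 3, 1]
# Exclusive: every tap shifts back by one site, so input 5 is excluded.
print(receptive_field(5, kernel_size=3, kernel_dilation=2, exclusive=True))   # [4, 2, 0]
```

With `exclusive=True`, every output site depends only on strictly earlier input sites. This is the property that lets the first layer of an autoregressive network define valid conditional distributions.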
- Methods
- __call__(inputs)
Applies the masked convolution to all input sites.
- Parameters:
inputs (Union[ndarray, Array]) – input data with dimensions (batch, size, features).
- Return type:
Union[ndarray, Array]
- Returns:
The convolved data.
- update_site(inputs, index)
Adds an input site to the cache and applies the masked convolution to the cache.
- Parameters:
inputs (Union[ndarray, Array]) – an input site to be added to the cache, with dimensions (batch, features).
index (int) – the index of the output site. The index of the input site should be index - self.exclusive.
- Return type:
Union[ndarray, Array]
- Returns:
The next output site with dimensions (batch, features).
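To illustrate what `__call__` computes over all sites, here is a minimal NumPy sketch of a masked convolution on an input with dimensions (batch, size, features). `masked_conv1d` is a hypothetical helper, not the library's implementation. It makes three assumptions:

- tap `j` of output site `i` reads input site `i - exclusive - j * dilation`;
- the kernel has layout (kernel_size, in_features, features);
- `feature_group_count` is 1.

Left zero padding ensures that no output site reads a later input site.

```python
import numpy as np


def masked_conv1d(inputs, kernel, bias=None, dilation=1, exclusive=False):
    """Masked 1D convolution over all sites (hypothetical sketch).

    inputs: (batch, size, in_features)
    kernel: (kernel_size, in_features, features)  # assumed layout
    bias:   (features,) or None
    Returns an array of shape (batch, size, features).
    """
    batch, size, _ = inputs.shape
    kernel_size, _, features = kernel.shape
    shift = int(exclusive)
    # Left-pad with zeros so every tap of every output has a valid index.
    pad = shift + (kernel_size - 1) * dilation
    padded = np.pad(inputs, ((0, 0), (pad, 0), (0, 0)))
    out = np.zeros((batch, size, features), dtype=np.result_type(inputs, kernel))
    for j in range(kernel_size):
        # Output i reads original site i - shift - j * dilation,
        # which is padded index i + start with start = pad - shift - j * dilation.
        start = pad - shift - j * dilation
        out += padded[:, start:start + size, :] @ kernel[j]
    if bias is not None:
        out += bias
    return out


rng = np.random.default_rng(1)
x = rng.normal(size=(2, 8, 3))
k = rng.normal(size=(3, 3, 4))
b = rng.normal(size=4)
y = masked_conv1d(x, k, b, dilation=2, exclusive=True)
print(y.shape)  # (2, 8, 4)
```

In JAX, the same computation would typically be expressed with `jax.lax.conv_general_dilated` applied to the left-padded input. Its `rhs_dilation` argument applies the kernel dilation.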
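To illustrate the `update_site` contract, the sketch below keeps a sliding cache of the last `(kernel_size - 1) * dilation + 1` input sites. Each call then costs O(kernel_size) rather than a convolution over the whole chain. The class `FastMaskedConv1DSketch` is hypothetical and written in NumPy. It makes three assumptions:

- the same tap convention: output `index` reads input site `index - exclusive - j * dilation`;
- a kernel layout of (kernel_size, in_features, features);
- sites are fed in order `index = 0, 1, 2, ...`.

As in the description above, the input passed at `index` is stored as site `index - exclusive`. In this sketch, an exclusive layer therefore ignores the input passed at index 0, because no earlier site exists.

```python
import numpy as np


class FastMaskedConv1DSketch:
    """Hypothetical stand-in for the cached update_site path (not the library class)."""

    def __init__(self, kernel, bias, batch, dilation=1, exclusive=False):
        self.kernel = np.asarray(kernel)  # (kernel_size, in_features, features), assumed layout
        self.bias = np.asarray(bias)      # (features,)
        self.dilation = dilation
        self.exclusive = int(exclusive)
        span = (self.kernel.shape[0] - 1) * dilation + 1
        # Sliding window of the most recent input sites; unseen sites stay zero.
        self.cache = np.zeros((batch, span, self.kernel.shape[1]))

    def update_site(self, inputs, index):
        """Store `inputs` as site index - exclusive (if it exists); return output site `index`."""
        if index - self.exclusive >= 0:
            # Shift the window left by one site and write the new site at the end.
            self.cache = np.roll(self.cache, -1, axis=1)
            self.cache[:, -1, :] = inputs
        out = np.zeros((self.cache.shape[0], self.kernel.shape[2])) + self.bias
        for j in range(self.kernel.shape[0]):
            # cache[:, -1 - j * dilation] holds input site index - exclusive - j * dilation.
            out += self.cache[:, -1 - j * self.dilation, :] @ self.kernel[j]
        return out


rng = np.random.default_rng(2)
k = rng.normal(size=(2, 3, 4))
b = rng.normal(size=4)
x = rng.normal(size=(5, 1, 3))  # 5 sites, each with batch 1 and 3 features
layer = FastMaskedConv1DSketch(k, b, batch=1, dilation=1, exclusive=True)
y0 = layer.update_site(np.zeros((1, 3)), 0)  # no earlier site: bias only
y1 = layer.update_site(x[0], 1)              # reads site 0
y2 = layer.update_site(x[1], 2)              # reads sites 1 and 0
```

Storing only a window bounds the per-step work and memory by the receptive field. The trade-off is that sites must arrive in order, which is exactly the situation during autoregressive sampling.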