netket.experimental.nn.rnn.GRU1DCell
- class netket.experimental.nn.rnn.GRU1DCell
Bases: RNNCell
Gated recurrent unit cell.
Only supports one previous neighbor at each site.
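For orientation, the textbook GRU update is written out below. Here x is the input at the current site, h' is the hidden state of the single previous neighbor, [a; b] denotes concatenation, σ is the logistic sigmoid and ⊙ is the elementwise product. The weight and bias names (W_r, W_u, W_h, b_r, b_u, b_h) are illustrative and not taken from NetKet's source. Conventions also differ on whether u or 1 - u multiplies the previous state, so treat this as the standard form rather than a transcription of this class.

```latex
\begin{aligned}
r &= \sigma\left(W_r\,[x;\, h'] + b_r\right) && \text{reset gate} \\
u &= \sigma\left(W_u\,[x;\, h'] + b_u\right) && \text{update gate} \\
\tilde{h} &= \tanh\left(W_h\,[x;\, r \odot h'] + b_h\right) && \text{candidate state} \\
h &= (1 - u) \odot h' + u \odot \tilde{h} && \text{new hidden state}
\end{aligned}
```

Because a GRU has no separate cell state, the new h is a natural candidate for both return values of __call__. This is an assumption about the design, not something stated on this page.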
- Attributes
- kernel_init: Callable[[Any, Sequence[int], Union[None, str, type[Any], dtype, _SupportsDType]], Array]
Initializer for the weight matrix.
- bias_init: Callable[[Any, Sequence[int], Union[None, str, type[Any], dtype, _SupportsDType]], Array]
Initializer for the bias.
- features: int
Number of output features, i.e. the size of the last dimension of the cell memory and the hidden state.
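Both initializer attributes are functions of the form init(key, shape, dtype) -> array. The NumPy sketch below illustrates that calling convention. The helpers zeros_init and lecun_normal_like are hypothetical stand-ins, and key is a plain integer seed rather than a JAX random key. lecun_normal_like also uses a plain normal draw, whereas JAX's lecun_normal uses a truncated normal.

```python
import numpy as np

# Hypothetical NumPy stand-ins for JAX initializers. They follow the
# (key, shape, dtype) -> array convention of kernel_init and bias_init,
# but `key` is an integer seed here, not a jax.random key.


def zeros_init(key, shape, dtype=np.float32):
    """All-zero array; the key is ignored (a common choice for biases)."""
    return np.zeros(shape, dtype=dtype)


def lecun_normal_like(key, shape, dtype=np.float32):
    """Plain normal draw with variance 1 / fan_in, taking fan_in = shape[0].

    Simplified: jax.nn.initializers.lecun_normal draws from a truncated normal.
    """
    rng = np.random.default_rng(key)
    std = np.sqrt(1.0 / shape[0])
    return (rng.standard_normal(shape) * std).astype(dtype)


# Kernel for an 8 -> 4 dense map and its matching bias.
kernel = lecun_normal_like(0, (8, 4), np.float32)
bias = zeros_init(0, (4,), np.float32)
print(kernel.shape, kernel.dtype)  # (8, 4) float32
print(bias.shape, bias.dtype)      # (4,) float32
```

With JAX available, real initializers such as jax.nn.initializers.lecun_normal() or jax.nn.initializers.zeros have this same (key, shape, dtype) signature. Passing them as the kernel_init and bias_init attributes when constructing the cell is the assumed usage pattern.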
- Methods
- __call__(inputs, cell_mem, hidden)
Applies the RNN cell to a batch of inputs at a single site index.
- Parameters:
inputs – input data with dimensions (batch, in_features).
cell_mem – cell memory from the previous site with dimensions (batch, features).
hidden – hidden memories from the previous neighbors with dimensions (batch, n_neighbors, features). Since this cell supports only one previous neighbor, n_neighbors is 1.
- Returns:
cell_mem – the updated cell memory with dimensions (batch, self.features).
outputs – the updated hidden memory with dimensions (batch, self.features). It also serves as the output data at the current site for the netket.experimental.nn.rnn.RNNLayer layer.
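To make the shape contract of __call__ concrete, here is a NumPy sketch of one GRU step that accepts and returns arrays with the documented dimensions. It is not NetKet's implementation. The helper gru1d_step, the params dictionary and its keys W_ru, b_ru, W_h, b_h are hypothetical, and the gates follow the textbook GRU. Returning the new hidden state as both cell_mem and outputs is an assumption.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru1d_step(params, inputs, cell_mem, hidden):
    """Hypothetical one-step GRU sketch using GRU1DCell's documented shapes.

    inputs:   (batch, in_features)
    cell_mem: (batch, features); accepted for interface parity, unused here
    hidden:   (batch, n_neighbors, features) with n_neighbors == 1
    Returns (cell_mem, outputs), both (batch, features).
    """
    batch = inputs.shape[0]
    if hidden.shape[1] != 1:
        raise ValueError("this sketch supports exactly one previous neighbor")
    # Drop the singleton neighbor axis: (batch, 1, features) -> (batch, features).
    h_prev = hidden.reshape(batch, -1)

    # Fused reset (r) and update (u) gates from [inputs; h_prev].
    xh = np.concatenate([inputs, h_prev], axis=-1)
    ru = sigmoid(xh @ params["W_ru"] + params["b_ru"])
    r, u = np.split(ru, 2, axis=-1)

    # Candidate state sees the reset-gated previous hidden state.
    xrh = np.concatenate([inputs, r * h_prev], axis=-1)
    h_tilde = np.tanh(xrh @ params["W_h"] + params["b_h"])

    h_new = (1.0 - u) * h_prev + u * h_tilde
    # Assumption: a GRU has no separate cell state, so both outputs are h_new.
    return h_new, h_new


batch, in_features, features = 3, 2, 4
rng = np.random.default_rng(0)
params = {
    "W_ru": 0.1 * rng.standard_normal((in_features + features, 2 * features)),
    "b_ru": np.zeros(2 * features),
    "W_h": 0.1 * rng.standard_normal((in_features + features, features)),
    "b_h": np.zeros(features),
}
inputs = rng.standard_normal((batch, in_features))
cell_mem = np.zeros((batch, features))
hidden = np.zeros((batch, 1, features))  # one previous neighbor

new_cell_mem, outputs = gru1d_step(params, inputs, cell_mem, hidden)
print(new_cell_mem.shape, outputs.shape)  # (3, 4) (3, 4)
```

Fusing the reset and update gates into a single matrix product and then splitting the result is a common way to reduce the number of dense layers. The same result would follow from two separate products.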