netket.experimental.nn.rnn.GRU1DCell

class netket.experimental.nn.rnn.GRU1DCell

Bases: RNNCell

Gated recurrent unit cell.

Supports only one previous neighbor at each site, so n_neighbors in the hidden argument of __call__ must be 1.
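For orientation, the kind of update a gated recurrent unit performs can be sketched in NumPy. This follows one common GRU formulation; the weight names (`W_r`, `W_z`, `W_n`, `U_n`), the helper `gru_step`, and the placement of the reset gate are assumptions for illustration and may differ from NetKet's implementation. With a single previous neighbor, a `hidden` array of shape `(batch, 1, features)` reduces to one `(batch, features)` state.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(inputs, hidden, params):
    """Hypothetical GRU update for a single previous neighbor.

    inputs: (batch, in_features); hidden: (batch, 1, features).
    Returns the new hidden state with shape (batch, features).
    """
    h = hidden[:, 0, :]  # one previous neighbor: drop the neighbor axis
    x_h = np.concatenate([inputs, h], axis=-1)  # (batch, in_features + features)
    r = sigmoid(x_h @ params["W_r"] + params["b_r"])  # reset gate
    z = sigmoid(x_h @ params["W_z"] + params["b_z"])  # update gate
    # candidate state: the reset gate scales the contribution of the old state
    n = np.tanh(inputs @ params["W_n"] + r * (h @ params["U_n"]) + params["b_n"])
    return (1.0 - z) * n + z * h  # blend the old state with the candidate


rng = np.random.default_rng(0)
batch, in_features, features = 2, 3, 4
params = {
    "W_r": rng.normal(size=(in_features + features, features)),
    "W_z": rng.normal(size=(in_features + features, features)),
    "W_n": rng.normal(size=(in_features, features)),
    "U_n": rng.normal(size=(features, features)),
    "b_r": np.zeros(features),
    "b_z": np.zeros(features),
    "b_n": np.zeros(features),
}
inputs = rng.normal(size=(batch, in_features))
hidden = np.zeros((batch, 1, features))
new_h = gru_step(inputs, hidden, params)
print(new_h.shape)  # (2, 4)
```

Because `z` lies in (0, 1), each entry of the new state is a convex combination of the previous state and the candidate, which is what lets a GRU carry information across many sites. The sketch ignores `cell_mem`, since the textbook GRU keeps a single state vector; how `GRU1DCell` fills its `cell_mem` return value is not shown here.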

Attributes
kernel_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]]

initializer for the weight matrix.

bias_init: Callable[[Any, Sequence[int], Any], Union[ndarray, Array]]

initializer for the bias.

features: int

number of output features, i.e. the size of the last dimension of the cell memory and hidden memory.

Methods
__call__(inputs, cell_mem, hidden)

Applies the RNN cell to a batch of input sites at a given index.

Parameters:
  • inputs – input data with dimensions (batch, in_features).

  • cell_mem – cell memory from the previous site with dimensions (batch, features).

  • hidden – hidden memories from the previous neighbors with dimensions (batch, n_neighbors, features).

Returns:

  • cell_mem – the updated cell memory with dimensions (batch, self.features).

  • outputs – the updated hidden memory with dimensions (batch, self.features); it also serves as the output data at the current site for netket.experimental.nn.rnn.RNNLayer.
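To illustrate how the two return values are threaded from site to site, here is a simplified 1D loop in NumPy. `toy_cell` is a hypothetical stand-in with the same `(inputs, cell_mem, hidden)` call signature (a plain tanh update, not a GRU), and the loop only illustrates the data flow; it is not the implementation of `netket.experimental.nn.rnn.RNNLayer`.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, n_sites, in_features, features = 2, 5, 3, 4
W_in = rng.normal(size=(in_features, features))
W_h = rng.normal(size=(features, features))


def toy_cell(inputs, cell_mem, hidden):
    """Hypothetical cell with the (inputs, cell_mem, hidden) signature.

    inputs: (batch, in_features); cell_mem: (batch, features);
    hidden: (batch, n_neighbors, features) with n_neighbors == 1.
    Returns (cell_mem, outputs), each of shape (batch, features).
    """
    h = hidden[:, 0, :]  # single previous neighbor
    new_h = np.tanh(inputs @ W_in + h @ W_h)
    return new_h, new_h  # this toy cell keeps no separate memory


x = rng.normal(size=(batch, n_sites, in_features))
cell_mem = np.zeros((batch, features))
hidden = np.zeros((batch, features))
outputs = []
for i in range(n_sites):
    # the previous site's hidden memory is passed as the single neighbor
    cell_mem, hidden = toy_cell(x[:, i, :], cell_mem, hidden[:, None, :])
    outputs.append(hidden)  # per-site output collected by the layer
outputs = np.stack(outputs, axis=1)
print(outputs.shape)  # (2, 5, 4)
```

Feeding each site's `outputs` back in as the next site's `hidden`, while also collecting it as that site's output, is the dual role described in the Returns entry above; `cell_mem` is carried from one site to the next alongside it.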