netket.nn#

This sub-module wraps and re-exports flax.nn. Read more about the design goals of this module in the Flax README.

Linear Modules#

Dense

A linear transformation applied over the last dimension of the input.
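All linear modules below follow the usual Flax init/apply pattern. A minimal sketch for Dense (the shapes shown follow standard Flax behaviour; treat this as a sketch, not authoritative NetKet documentation):

```python
import jax
import jax.numpy as jnp
import netket as nk

layer = nk.nn.Dense(features=4)
x = jnp.ones((2, 3))                           # batch of 2 vectors with 3 entries each
params = layer.init(jax.random.PRNGKey(0), x)  # initialise kernel and bias
y = layer.apply(params, x)                     # y.shape == (2, 4)
```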

DenseGeneral

A linear transformation with flexible axes.
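For example, DenseGeneral can contract several input axes at once. A short sketch, assuming the standard Flax DenseGeneral semantics:

```python
import jax
import jax.numpy as jnp
import netket as nk

layer = nk.nn.DenseGeneral(features=(4, 5), axis=(-2, -1))
x = jnp.ones((2, 3, 6))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # contracts axes (-2, -1); y.shape == (2, 4, 5)
```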

DenseSymm

Implements a projection onto a symmetry group.

DenseEquivariant

A group convolution operation that is equivariant over a symmetry group.
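These two layers are typically chained: DenseSymm embeds site-local inputs into features defined over the symmetry group, and DenseEquivariant maps group-valued features to group-valued features. A minimal sketch, assuming a recent NetKet where DenseSymm expects inputs of shape (batch, in_features, n_sites); older releases used (batch, n_sites), so check your version:

```python
import jax
import jax.numpy as jnp
import netket as nk

graph = nk.graph.Chain(length=4)         # translations of the chain form the group
group = graph.translation_group()

symm = nk.nn.DenseSymm(symmetries=group, features=2)
equi = nk.nn.DenseEquivariant(symmetries=group, features=3)

x = jnp.ones((1, 1, graph.n_nodes))      # (batch, in_features, n_sites)
p1 = symm.init(jax.random.PRNGKey(0), x)
h = symm.apply(p1, x)                    # (batch, 2, |group|)

p2 = equi.init(jax.random.PRNGKey(1), h)
y = equi.apply(p2, h)                    # (batch, 3, |group|)
```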

Conv

Convolution Module wrapping lax.conv_general_dilated.
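A sketch of a 1D convolution over inputs laid out as (batch, length, channels), following the usual Flax conventions:

```python
import jax
import jax.numpy as jnp
import netket as nk

conv = nk.nn.Conv(features=8, kernel_size=(3,))
x = jnp.ones((1, 16, 4))
params = conv.init(jax.random.PRNGKey(0), x)
y = conv.apply(params, x)  # (1, 16, 8) with the default 'SAME' padding
```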

Embed

Embedding Module.
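Embed maps integer indices to learned dense vectors; a short sketch using the standard Flax signature:

```python
import jax
import jax.numpy as jnp
import netket as nk

emb = nk.nn.Embed(num_embeddings=10, features=4)
ids = jnp.array([[1, 2, 3]])                   # integer token/site indices
params = emb.init(jax.random.PRNGKey(0), ids)
vecs = emb.apply(params, ids)                  # (1, 3, 4)
```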

MaskedDense1D

1D linear transformation module with mask for autoregressive NN.

MaskedConv1D

1D convolution module with mask for autoregressive NN.

MaskedConv2D

2D convolution module with mask for autoregressive NN.
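The three masked modules above share one idea: a triangular mask zeroes out kernel entries so that output i depends only on inputs earlier in the autoregressive ordering. A toy pure-JAX sketch of that idea (masked_dense_1d is a hypothetical helper, not the NetKet implementation):

```python
import jax
import jax.numpy as jnp

def masked_dense_1d(x, kernel, exclusive=True):
    """Toy autoregressive dense layer: output i sees inputs j < i
    (j <= i when exclusive=False). Illustrative only."""
    L = x.shape[-1]
    mask = jnp.triu(jnp.ones((L, L)), k=1 if exclusive else 0)
    return x @ (kernel * mask)

key = jax.random.PRNGKey(0)
L = 5
kernel = jax.random.normal(key, (L, L))
x = jnp.ones((2, L))
y = masked_dense_1d(x, kernel)  # y[:, i] depends only on x[:, :i]
```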

Activation functions#

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis, where, initial])

Log-Softmax function.

relu(x)

Rectified linear unit activation function.

sigmoid(x)

Sigmoid activation function.

soft_sign(x)

Soft-sign activation function.

softmax(x[, axis, where, initial])

Softmax function.
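These activations are plain functions applied directly to arrays, following the corresponding jax.nn semantics. For instance, softmax and log_softmax normalise over a chosen axis:

```python
import jax.numpy as jnp
import netket as nk

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 4.0, 4.0]])
p = nk.nn.softmax(x, axis=-1)         # each row sums to 1
logp = nk.nn.log_softmax(x, axis=-1)  # numerically stable log-probabilities
```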

softplus(x)

Softplus activation function.

swish(x)

SiLU activation function.

log_cosh(x)

Logarithm of the hyperbolic cosine, implemented in a more stable way.

reim_relu(x)

relu applied separately to the real and imaginary parts of its argument.

reim_selu(x)

selu applied separately to the real and imaginary parts of its argument.
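These last three are NetKet-specific activations for complex-valued networks. A short sketch (the expected output shown in the comment assumes relu acts independently on the real and imaginary parts):

```python
import jax.numpy as jnp
import netket as nk

z = jnp.array([1.0 - 2.0j, -0.5 + 0.3j])

nk.nn.log_cosh(z)   # stable log(cosh(z)), a common RBM-style activation
nk.nn.reim_relu(z)  # relu on re/im parts separately: [1.+0.j, 0.+0.3j]
```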