# Source code for netket.nn.blocks.mlp

# Copyright 2022 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union, Optional

import jax
import jax.numpy as jnp

from flax import linen as nn
from jax.nn.initializers import lecun_normal, zeros

from netket.utils.types import NNInitFunc, DType
from netket import nn as nknn

# Default initializer for the Dense kernels of MLP: the initializer returned by
# jax.nn.initializers.lecun_normal() called with its default arguments.
default_kernel_init = lecun_normal()
# Default initializer for the Dense biases of MLP: jax.nn.initializers.zeros.
default_bias_init = zeros

class MLP(nn.Module):
    r"""A Multi-Layer Perceptron with hidden layers.

    This combines multiple dense layers and activation functions into a single
    object. It separates the output layer from the hidden layers, since it
    typically has a different form. One can specify the specific activation
    functions per layer. The size of the hidden dimensions can be provided as
    a number, or as a factor relative to the input size (similarly to RBM).
    The default model is a single linear layer without activations.

    Forms a common building block for models such as PauliNet (continuous).
    """

    output_dim: int = 1
    """The output dimension"""
    hidden_dims: Optional[Union[int, tuple[int, ...]]] = None
    """The size of the hidden layers, excluding the output layer.
    A single number is interpreted as one hidden layer of that size."""
    hidden_dims_alpha: Optional[Union[int, tuple[int, ...]]] = None
    """The size of the hidden layers provided as number of times the input size.
    A single number is interpreted as one hidden layer.
    One must choose to either specify this or the hidden_dims keyword argument"""
    param_dtype: DType = jnp.float64
    """The dtype of the weights."""
    hidden_activations: Optional[Union[Callable, tuple[Callable, ...]]] = nknn.gelu
    """The nonlinear activation function after each hidden layer.
    Can be provided as a single activation, where the same activation will be
    used for every layer."""
    output_activation: Optional[Callable] = None
    """The nonlinear activation at the output layer.
    If None is provided, the output layer will be essentially linear."""
    use_hidden_bias: bool = True
    """If True uses a bias in the hidden layer."""
    use_output_bias: bool = True
    """If True adds a bias to the output layer."""
    precision: Optional[jax.lax.Precision] = None
    """Numerical precision of the computation see :class:`jax.lax.Precision`
    for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """Initializer for the Dense layer matrix."""
    bias_init: NNInitFunc = default_bias_init
    """Initializer for the biases."""

    @nn.compact
    def __call__(self, input):
        """Apply the hidden layers and then the output layer to ``input``.

        Args:
            input: The array fed to the first dense layer. The size of its
                last axis (``input.shape[-1]``) is the reference size when
                ``hidden_dims_alpha`` is used.

        Returns:
            The result of the output dense layer, passed through
            ``output_activation`` when one is set.

        Raises:
            ValueError: If both ``hidden_dims`` and ``hidden_dims_alpha`` are
                specified, or if the number of hidden activations differs
                from the number of hidden layers.
        """
        if self.hidden_dims is not None and self.hidden_dims_alpha is not None:
            raise ValueError(
                "Cannot specify both hidden_dims and hidden_dims_alpha, "
                "choose one way to provide the hidden dimensions"
            )

        # Resolve the hidden layer sizes. Both attributes may be given as a
        # single number, which is treated as a one-element sequence.
        if self.hidden_dims is not None:
            hidden_dims = self.hidden_dims
            if not hasattr(hidden_dims, "__len__"):
                hidden_dims = (hidden_dims,)
        elif self.hidden_dims_alpha is not None:
            alphas = self.hidden_dims_alpha
            if not hasattr(alphas, "__len__"):
                alphas = (alphas,)
            # Sizes are expressed as multiples of the input feature size.
            hidden_dims = [int(alpha * input.shape[-1]) for alpha in alphas]
        else:
            # No hidden layers: the model reduces to the output layer only.
            hidden_dims = []

        # Resolve one activation (or None) per hidden layer.
        if self.hidden_activations is None:
            hidden_activations = [None] * len(hidden_dims)
        elif hasattr(self.hidden_activations, "__len__"):
            hidden_activations = self.hidden_activations
        else:
            # A single callable is shared by every hidden layer.
            hidden_activations = [self.hidden_activations] * len(hidden_dims)

        if len(hidden_activations) != len(hidden_dims):
            raise ValueError(
                "number of hidden activations must be the same "
                "as the length of the hidden dimensions list"
            )

        x = input

        # hidden layers
        for nh, act_h in zip(hidden_dims, hidden_activations):
            x = nn.Dense(
                features=nh,
                param_dtype=self.param_dtype,
                precision=self.precision,
                use_bias=self.use_hidden_bias,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )(x)
            # A None (falsy) entry means no activation after this layer.
            if act_h:
                x = act_h(x)

        # output layer
        x = nn.Dense(
            features=self.output_dim,
            param_dtype=self.param_dtype,
            precision=self.precision,
            use_bias=self.use_output_bias,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        if self.output_activation:
            x = self.output_activation(x)

        return x