# Source code for netket.models.mlp

# Copyright 2022 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Callable, Union

import jax
import jax.numpy as jnp

from flax import linen as nn
from jax.nn.initializers import lecun_normal, zeros

from netket.utils.types import NNInitFunc, DType
from netket import nn as nknn

default_kernel_init = lecun_normal()
default_bias_init = zeros


class MLP(nn.Module):
    r"""A Multi-Layer Perceptron with hidden layers.

    This model uses the MLP block with output dimension 1, which is squeezed.

    This combines multiple dense layers and activation functions into a single
    object. It separates the output layer from the hidden layers, since it
    typically has a different form. One can specify the specific activation
    functions per layer. The size of the hidden dimensions can be provided as
    a number, or as a factor relative to the input size (similar to the RBM).
    The default model is a single linear layer without activations.

    Forms a common building block for models such as
    `PauliNet (continuous) <https://www.nature.com/articles/s41557-020-0544-y>`_
    """

    hidden_dims: Optional[Union[int, tuple[int, ...]]] = None
    """The size of the hidden layers, excluding the output layer."""
    hidden_dims_alpha: Optional[Union[int, tuple[int, ...]]] = None
    """The size of the hidden layers provided as number of times the input size.
    One must choose to either specify this or the hidden_dims keyword argument."""
    param_dtype: DType = jnp.float64
    """The dtype of the weights."""
    hidden_activations: Optional[Union[Callable, tuple[Callable, ...]]] = nknn.gelu
    """The nonlinear activation function after each hidden layer.
    Can be provided as a single activation, in which case the same activation
    is used for every hidden layer."""
    output_activation: Optional[Callable] = None
    """The nonlinear activation at the output layer.
    If None is provided, the output layer will be essentially linear."""
    use_hidden_bias: bool = True
    """If True, uses a bias in the hidden layers."""
    use_output_bias: bool = False
    """If True, adds a bias to the output layer."""
    precision: Optional[jax.lax.Precision] = None
    """Numerical precision of the computation; see :class:`jax.lax.Precision`
    for details."""
    kernel_init: NNInitFunc = default_kernel_init
    """Initializer for the Dense layer matrices."""
    bias_init: NNInitFunc = default_bias_init
    """Initializer for the biases."""

    @nn.compact
    def __call__(self, input):  # name kept for keyword-argument compatibility
        """Evaluate the network.

        All hyper-parameters of this module are forwarded unchanged to
        :class:`netket.nn.blocks.MLP`, which is built with a single output unit.

        Args:
            input: The input array passed to the MLP block. Presumably a batch
                of configurations with features on the last axis; confirm
                against ``nknn.blocks.MLP``.

        Returns:
            The block output with its trailing size-1 output axis removed.
        """
        x = nknn.blocks.MLP(
            output_dim=1,  # a netket model has a single output
            hidden_dims=self.hidden_dims,
            hidden_dims_alpha=self.hidden_dims_alpha,
            param_dtype=self.param_dtype,
            hidden_activations=self.hidden_activations,
            output_activation=self.output_activation,
            use_hidden_bias=self.use_hidden_bias,
            use_output_bias=self.use_output_bias,
            precision=self.precision,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(input)
        # With output_dim=1 the last axis has size 1; drop it so the model
        # yields one value per input rather than a length-1 vector.
        x = x.squeeze(-1)
        return x