# Source code for netket.nn.blocks.deepset

from typing import Callable, Union, Optional

import jax
from jax import numpy as jnp
from jax.nn.initializers import (
    lecun_normal,
    zeros,
)
from flax import linen as nn

from netket.utils.types import NNInitFunc, DType

from .mlp import MLP

def _process_features(features) -> tuple[Optional[tuple[int, ...]], Optional[int]]:
    Convert some inputs to a consistent format of features.
    Returns hidden dimensions and output dimensions of the MLP separately.
    if features is None:
        feat, out = None, None
    elif isinstance(features, int):
        feat, out = None, features
    elif len(features) == 0:
        feat, out = None, None
    elif len(features) == 1:
        feat, out = None, features[0]
        feat, out = tuple(features[:-1]), features[-1]
    return feat, out

class DeepSetMLP(nn.Module):
    r"""Implements the DeepSets architecture, which is permutation invariant
    and is suitable for the encoding of bosonic systems.

    .. math::

        f(x_1,...,x_N) = \rho\left(\sum_i \phi(x_i)\right)

    The input shape must have an axis that is reshaped to `(..., N, D)`,
    where we pool over N.
    """

    features_phi: Optional[Union[int, tuple[int, ...]]] = None
    """
    Number of features in each layer for phi network.
    When features_phi is None (or empty), no phi network is created.
    """
    features_rho: Optional[Union[int, tuple[int, ...]]] = None
    """
    Number of features in each layer for rho network.
    Should include final dimension of the network.
    When features_rho is None (or empty), no rho network is created.
    """
    param_dtype: DType = jnp.float64
    """The dtype of the weights."""
    hidden_activation: Optional[Callable] = jax.nn.gelu
    """The nonlinear activation function between hidden layers.
    It is also applied at the output of the phi network, before pooling."""
    output_activation: Optional[Callable] = None
    """The nonlinear activation function at the output layer of rho."""
    pooling: Callable = jnp.sum
    """The pooling operation to be used after the phi-transformation.
    It is called as ``pooling(x, axis=-2)``."""
    use_bias: bool = True
    """if True uses a bias in all layers."""
    kernel_init: NNInitFunc = lecun_normal()
    """Initializer for the Dense layer matrix"""
    bias_init: NNInitFunc = zeros
    """Initializer for the hidden bias"""
    precision: Optional[jax.lax.Precision] = None
    """Numerical precision of the computation, see :class:`jax.lax.Precision` for details."""

    def setup(self):
        # Validate first so a misconfigured module fails before building layers.
        if self.pooling is None:
            raise ValueError("Must specify pooling function for a DeepSet")

        def _create_mlp(features, output_activation, name):
            """Build an MLP from ``features``, or return None if it is empty."""
            hidden_dims, out_dim = _process_features(features)
            if out_dim is None:
                return None
            return MLP(
                output_dim=out_dim,
                hidden_dims=hidden_dims,
                param_dtype=self.param_dtype,
                hidden_activations=self.hidden_activation,
                output_activation=output_activation,
                use_hidden_bias=self.use_bias,
                use_output_bias=self.use_bias,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
                name=name,
            )

        # phi ends with the hidden activation so the pooled features are
        # nonlinear; rho ends with the user-selected output activation.
        self.phi = _create_mlp(self.features_phi, self.hidden_activation, "ds_phi")
        self.rho = _create_mlp(self.features_rho, self.output_activation, "ds_rho")

    @nn.compact
    def __call__(self, x):
        """Apply phi per element, pool over the set axis, then apply rho.

        The input shape must have an axis that is reshaped to (..., N, D),
        where we pool over N (axis -2).

        Args:
            x: input array with at least two dimensions.

        Returns:
            The pooled output, transformed by rho when a rho network exists.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"input of deepset should have shape (..., N, D), but got {x.shape}"
            )
        if self.phi is not None:
            x = self.phi(x)
        x = self.pooling(x, axis=-2)
        if self.rho is not None:
            x = self.rho(x)
        return x