netket.models.MLP#
- class netket.models.MLP[source]#
Bases: flax.linen.module.Module
A Multi-Layer Perceptron with hidden layers.
This combines multiple dense layers and activation functions into a single object. It separates the output layer from the hidden layers, since it typically has a different form. One can specify the activation function for each layer. The size of the hidden dimensions can be provided as a number, or as a factor relative to the input size (similarly to the RBM). The default model is a single linear layer without activations.
Forms a common building block for models such as PauliNet (continuous) https://www.nature.com/articles/s41557-020-0544-y
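As a rough usage sketch (assuming a standard NetKet/JAX installation, and that hidden_dims accepts a tuple of layer sizes as suggested by the attributes below), the model is constructed and evaluated like any other Flax module:

import jax
import jax.numpy as jnp
import netket as nk

# Two hidden layers of 16 units each; the output layer stays linear by default.
model = nk.models.MLP(hidden_dims=(16, 16))
x = jnp.ones((4, 8))                            # hypothetical batch of 4 inputs of size 8
params = model.init(jax.random.PRNGKey(0), x)   # standard Flax parameter initialization
out = model.apply(params, x)                    # forward pass through hidden + output layers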
- Attributes
- hidden_dims#
The size of the hidden layers, excluding the output layer.
- hidden_dims_alpha#
The size of the hidden layers provided as a number of times the input size. One must choose to either specify this or the hidden_dims keyword argument.
- output_activation: Callable = None#
The nonlinear activation at the output layer. If None is provided, the output layer will be essentially linear.
- squeeze_output: bool = False#
Whether to remove output dimension 1 if it is present. This is typically useful if we want to use the MLP as an NQS directly, where we do not need the final dimension 1.
- use_hidden_bias: bool = True#
If True, uses a bias in the hidden layers.
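A short sketch contrasting the two ways of sizing the hidden layers (assuming, as the attribute descriptions above suggest, that hidden_dims takes absolute sizes while hidden_dims_alpha takes multipliers of the input size; only one of the two may be given):

import jax
import jax.numpy as jnp
import netket as nk

x = jnp.ones((2, 10))                               # input size 10
absolute = nk.models.MLP(hidden_dims=(20, 10))      # hidden sizes given directly
relative = nk.models.MLP(hidden_dims_alpha=(2, 1))  # hidden sizes 2*10 and 1*10
params = relative.init(jax.random.PRNGKey(0), x)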
- Methods
- bias_init(key, shape, dtype=<class 'jax.numpy.float64'>)#
An initializer that returns a constant array full of zeros.
The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
- has_rng(name)#
Returns true if a PRNGSequence with the given name exists.
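For illustration, a minimal Flax sketch (the 'noise' stream name is hypothetical) where the module only draws randomness when the corresponding PRNG stream was actually supplied:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Noisy(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Only perturb the input if a 'noise' PRNG stream was passed in.
        if self.has_rng('noise'):
            x = x + jax.random.normal(self.make_rng('noise'), x.shape)
        return nn.Dense(3)(x)

x = jnp.ones((2, 3))
params = Noisy().init(jax.random.PRNGKey(0), x)                      # no 'noise' stream here
y = Noisy().apply(params, x, rngs={'noise': jax.random.PRNGKey(1)})  # has_rng('noise') is True here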
- hidden_activations(x, approximate=True)#
Gaussian error linear unit activation function.
If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.
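A quick numerical comparison of the two formulations using jax.nn.gelu (a sketch assuming a standard JAX installation):

import jax.numpy as jnp
from jax import nn as jnn

x = jnp.linspace(-2.0, 2.0, 5)
exact = jnn.gelu(x, approximate=False)   # erf-based, exact formulation
approx = jnn.gelu(x, approximate=True)   # tanh-based approximation
print(jnp.max(jnp.abs(exact - approx)))  # the two differ only slightly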
- is_initializing()#
Returns True if running under self.init(...) or nn.init(...)().
This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.
- Return type
bool
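A minimal Flax sketch of the simple-initialization case described above, where some logic should only run under init (module and shapes are hypothetical):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Example(nn.Module):
    @nn.compact
    def __call__(self, x):
        if self.is_initializing():
            # Executed only under .init()/nn.init(), not under .apply()
            pass
        return nn.Dense(2)(x)

x = jnp.ones((1, 4))
params = Example().init(jax.random.PRNGKey(0), x)  # is_initializing() returns True here
y = Example().apply(params, x)                     # and False here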
- kernel_init(shape, dtype=<class 'jax.numpy.float64'>)#
- put_variable(col, name, value)#
Sets the value of a Variable.
- Parameters
col (str) – the variable collection.
name (str) – the name of the variable.
value (Any) – the new value of the variable.
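As an illustration (a sketch with a hypothetical 'stats' collection; the target collection must be mutable in the call for the write to succeed):

import jax
import jax.numpy as jnp
import flax.linen as nn

class StoreHidden(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        # Write the hidden activation into the custom 'stats' collection.
        self.put_variable('stats', 'h', h)
        return nn.Dense(2)(h)

x = jnp.ones((1, 3))
variables = StoreHidden().init(jax.random.PRNGKey(0), x)           # 'stats' is created at init
y, mutated = StoreHidden().apply(variables, x, mutable=['stats'])  # returns the updated collection
print(mutated['stats']['h'].shape)                                 # (1, 4)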
- tabulate(rngs, *args, method=None, mutable=True, depth=None, exclude_methods=(), **kwargs)#
Creates a summary of the Module represented as a table.
This method has the same signature as init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.
Example:
import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))
print(Foo().tabulate(jax.random.PRNGKey(0), x))
This gives the following output:
Foo Summary

path      outputs         params
------------------------------------------------
Inputs    float32[16,9]
Dense_0   float32[16,4]   bias: float32[4]
                          kernel: float32[9,4]
                          40 (160 B)
Dense_1   float32[16,2]   bias: float32[2]
                          kernel: float32[4,2]
                          10 (40 B)
Foo       float32[16,2]
                  Total   50 (200 B)

Total Parameters: 50 (200 B)
Note: the order of rows in the table does not represent execution order; instead it aligns with the order of keys in variables, which are sorted alphabetically.
- Parameters
rngs (Union[Any, Dict[str, Any]]) – The rngs for the variable collections.
*args – The arguments to the forward computation.
method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.
mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. By default all collections except "intermediates" are mutable.
depth (Optional[int]) – Controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor, such that the sum of all rows always adds up to the total number of parameters of the Module.
exclude_methods (Sequence[str]) – A sequence of strings that specifies which methods should be ignored. In case a module calls a helper method from its main method, use this argument to exclude the helper method from the summary to avoid ambiguity.
**kwargs – Keyword arguments to pass to the forward computation.
- Return type
str
- Returns
A string summarizing the Module.
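As a follow-up sketch (the Net and Block modules are hypothetical), the depth argument documented above limits how far the summary descends into submodules:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(nn.Dense(4)(x))

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        return Block()(Block()(x))

x = jnp.ones((16, 9))
# With depth=1 the inner Dense layers are not listed separately; their parameter
# counts are folded into the row of the enclosing Block.
print(Net().tabulate(jax.random.PRNGKey(0), x, depth=1))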