netket.nn.blocks.DeepSetMLP#

class netket.nn.blocks.DeepSetMLP[source]#

Bases: flax.linen.module.Module

Implements the DeepSets architecture, which is permutation invariant.

\[f(x_1,...,x_N) = \rho\left(\sum_i \phi(x_i)\right)\]

which makes it suitable for simulating bosonic systems.

The input shape must have an axis that is reshaped to (…, N, D), where we pool over N.
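
A minimal usage sketch follows; the layer sizes and input shape below are illustrative assumptions, not defaults.

import jax
import jax.numpy as jnp
from netket.nn.blocks import DeepSetMLP

# Hypothetical sizes: phi embeds every particle independently, rho maps the
# pooled embedding down to a single output feature.
ds = DeepSetMLP(features_phi=(16, 16), features_rho=(16, 1))

x = jnp.ones((4, 8, 2))                    # (batch, N particles, D coordinates)
params = ds.init(jax.random.PRNGKey(0), x)
out = ds.apply(params, x)                  # expected shape (4, 1); permuting the N axis leaves it unchanged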

Attributes
features_phi: Optional[Union[int, Tuple[int, ...]]] = None#

Number of features in each layer for phi network. When features_phi is None, no phi network is created.

features_rho: Optional[Union[int, Tuple[int, ...]]] = None#

Number of features in each layer for rho network. Should include final dimension of the network. When features_rho is None, no rho network is created.

name: Optional[str] = None#
output_activation: Optional[Callable] = None#

The nonlinear activation function at the output layer.

parent: Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]] = None#
precision: Optional[jax._src.lax.lax.Precision] = None#

Numerical precision of the computation; see `jax.lax.Precision` for details.

scope: Optional[Scope] = None#
use_bias: bool = True#

If True, uses a bias in all layers.

variables#

Returns the variables in this module.

Return type

Mapping[str, Mapping[str, Any]]
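
This property is only populated once the module is bound; a short sketch reusing the hypothetical ds and params from the example at the top of this page:

bound = ds.bind(params)       # bind variables to obtain an interactive instance
bound.variables['params']     # the same parameter pytree that was passed to bind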

Methods
__call__(x)[source]#

The input shape must have an axis that is reshaped to (…, N, D), where we pool over N.

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)#

Applies a module method to variables and returns output and modified variables.

Note that method should be set if one would like to call apply on a different class method than __call__. For instance, suppose a Transformer module has a method called encode, then the following calls apply on that method:

model = Transformer()
encoded = model.apply({'params': params}, x, method=Transformer.encode)

If a function instance is provided, the unbound function is used. For instance, the example below is equivalent to the one above:

encoded = model.apply({'params': params}, x, method=model.encode)

You can also pass a string to a callable attribute of the module. For example, the previous can be written as:

encoded = model.apply({'params': params}, x, method='encode')

Note method can also be a function that is not defined in Transformer. In that case, the function should have at least one argument representing an instance of the Module class:

def other_fn(instance, ...):
  instance.some_module_attr(...)
  ...

model.apply({'params': params}, x, method=other_fn)
Parameters
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments passed to the specified apply method.

  • rngs (Optional[Dict[str, Union[Array, PRNGKeyArray]]]) – a dict of PRNGKeys to initialize the PRNG sequences. The β€œparams” PRNG sequence is used to initialize parameters.

  • method (Union[Callable[..., Any], str, None]) – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the β€œintermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the specified apply method.

Return type

Union[Any, Tuple[Any, Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]]]

Returns

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.
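
As a hedged illustration with the hypothetical ds, params and x from the example at the top of this page:

out = ds.apply(params, x)                                   # mutable=False (default): output only
out, state = ds.apply(params, x, mutable='intermediates')   # any other mutable spec: (output, vars)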

bias_init(shape, dtype=<class 'jax.numpy.float64'>)#

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
Return type

Any

bind(variables, *args, rngs=None, mutable=False)#

Creates an interactive Module instance by binding variables and RNGs.

bind provides an β€œinteractive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation, and in all other cases we strongly encourage users to use apply() instead.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(3)
    self.decoder = nn.Dense(5)

  def __call__(self, x):
    return self.decoder(self.encoder(x))

x = jnp.ones((16, 9))
ae = AutoEncoder()
variables = ae.init(jax.random.PRNGKey(0), x)
model = ae.bind(variables)
z = model.encoder(x)
x_reconstructed = model.decoder(z)
Parameters
  • variables (Mapping[str, Mapping[str, Any]]) – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments (not used).

  • rngs (Optional[Dict[str, Union[Array, PRNGKeyArray]]]) – a dict of PRNGKeys to initialize the PRNG sequences.

  • mutable (Union[bool, str, Collection[str], DenyList]) –

    Can be bool, str, or list. Specifies which collections should be treated as mutable:

    bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • self (flax.linen.module.M) –

Return type

TypeVar(M, bound= Module)

Returns

A copy of this instance with bound variables and RNGs.

clone(*, parent=None, _deep_clone=False, **updates)#

Creates a clone of this Module, with optionally updated arguments.

Parameters
  • parent (Union[Scope, Module, None]) – The parent of the clone. The clone will have no parent if no explicit parent is specified.

  • _deep_clone (Union[bool, WeakValueDictionary]) – A boolean or a weak value dictionary to control deep cloning of submodules. If True, submodules will be cloned recursively. If a weak value dictionary is passed, it will be used to cache cloned submodules. This flag is used by init/apply/bind to avoid scope leakage.

  • **updates – Attribute updates.

  • self (flax.linen.module.M) –

Return type

TypeVar(M, bound= Module)

Returns

A clone of this Module with the updated attributes and parent.
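
For instance, continuing the hypothetical ds example from the top of this page, a variant with one attribute changed could be derived as a sketch:

ds_no_bias = ds.clone(use_bias=False)   # same architecture, bias terms disabled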

get_variable(col, name, default=None)#

Retrieves the value of a Variable.

Parameters
  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • default (Optional[TypeVar(T)]) – the default value to return if the variable does not exist in this scope.

Return type

TypeVar(T)

Returns

The value of the input variable, or the default value if the variable doesn’t exist in this scope.

has_rng(name)#

Returns true if a PRNGSequence with name name exists.

Return type

bool

Parameters

name (str) –
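
A small illustrative pattern; the module and the 'noise' stream name below are assumptions for the sketch, not part of NetKet.

import jax
import jax.numpy as jnp
import flax.linen as nn

class MaybeNoisy(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Add noise only when a 'noise' PRNG stream was provided to init/apply.
    if self.has_rng('noise'):
      x = x + jax.random.normal(self.make_rng('noise'), x.shape)
    return x

m = MaybeNoisy()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((2,)))
y_clean = m.apply(variables, jnp.ones((2,)))                                         # no noise stream
y_noisy = m.apply(variables, jnp.ones((2,)), rngs={'noise': jax.random.PRNGKey(1)})  # noise added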

has_variable(col, name)#

Checks if a variable of given collection and name exists in this Module.

See flax.core.variables for more explanation on variables and collections.

Parameters
  • col (str) – The variable collection name.

  • name (str) – The name of the variable.

Return type

bool

Returns

True if the variable exists.

hidden_activation(approximate=True)#

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.

Parameters
  • x (Any) – input array

  • approximate (bool) – whether to use the approximate or exact formulation.

Return type

Any
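
This attribute defaults to the GELU documented above; a quick sketch comparing the two formulations directly via jax.nn.gelu:

import jax.numpy as jnp
from jax.nn import gelu

x = jnp.linspace(-2.0, 2.0, 5)
print(gelu(x, approximate=True))   # tanh-based approximation (the default)
print(gelu(x, approximate=False))  # exact erf-based form; the two agree closely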

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)#

Initializes a module method with variables and returns modified variables.

init takes as first argument either a single PRNGKey, or a dictionary mapping variable collections names to their PRNGKeys, and will call method (which is the module’s __call__ function by default) passing *args and **kwargs, and returns a dictionary of initialized variables.

Example:

>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> import jax
...
>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...     return nn.Dense(1)(x)
...
>>> module = Foo()
>>> key = jax.random.PRNGKey(0)
>>> variables = module.init(key, jnp.empty((1, 7)), train=False)

If you pass a single PRNGKey, Flax will use it to feed the 'params' RNG stream. If you want to use a different RNG stream or need to use multiple streams, you must pass a dictionary mapping each RNG stream name to its corresponding PRNGKey to init.

Example:

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(16)(x)
...     x = nn.BatchNorm(use_running_average=not train)(x)
...     x = nn.relu(x)
...
...     # Add gaussian noise
...     noise_key = self.make_rng('noise')
...     x = x + jax.random.normal(noise_key, x.shape)
...
...     return nn.Dense(1)(x)
...
>>> module = Foo()
>>> rngs = {'params': jax.random.PRNGKey(0), 'noise': jax.random.PRNGKey(1)}
>>> variables = module.init(rngs, jnp.empty((1, 7)), train=False)

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

>>> module = nn.Dense(1)
>>> init_jit = jax.jit(module.init)
>>> variables = init_jit(jax.random.PRNGKey(0), jnp.empty((1, 7)))

init is a light wrapper over apply, so other apply arguments like method, mutable, and capture_intermediates are also available.

Parameters
  • rngs (Union[Array, PRNGKeyArray, Dict[str, Union[Array, PRNGKeyArray]]]) – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method (Union[Callable[..., Any], str, None]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except β€œintermediates” are mutable.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the β€œintermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Return type

Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]

Returns

The initialized variable dict.

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)#

Initializes a module method with variables and returns output and modified variables.

Parameters
  • rngs (Union[Array, PRNGKeyArray, Dict[str, Union[Array, PRNGKeyArray]]]) – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method (Union[Callable[..., Any], str, None]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except β€œintermediates” are mutable.

  • capture_intermediates (Union[bool, Callable[[Module, str], bool]]) – If True, captures intermediate return values of all Modules inside the β€œintermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Return type

Tuple[Any, Union[FrozenDict[str, Mapping[str, Any]], Dict[str, Any]]]

Returns

(output, vars), where vars is a dict of the modified collections.

is_initializing()#

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

Return type

bool

is_mutable_collection(col)#

Returns true if the collection col is mutable.

Return type

bool

Parameters

col (str) –

kernel_init(shape, dtype=<class 'jax.numpy.float64'>)#
Return type

Any

lazy_init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), **kwargs)#

Initializes a module without computing on an actual input.

lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a jax.ShapeDtypeStruct which specifies the shape and dtype of the input but no concrete data.

Example:

model = nn.Dense(features=256)
variables = model.lazy_init(rng, jax.ShapeDtypeStruct((1, 128), jnp.float32))

The args and kwargs passed to lazy_init can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed, otherwise lazy_init cannot infer which variables should be initialized.

Parameters
  • rngs (Union[Array, PRNGKeyArray, Dict[str, Union[Array, PRNGKeyArray]]]) – The rngs for the variable collections.

  • *args – arguments passed to the init function.

  • method (Optional[Callable[..., Any]]) – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except β€œintermediates” are mutable.

  • **kwargs – Keyword arguments passed to the init function.

Return type

FrozenDict[str, Mapping[str, Any]]

Returns

The initialized variable dict.

make_rng(name)#

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.


Parameters

name (str) – The RNG sequence name.

Return type

Union[Array, PRNGKeyArray]

Returns

The newly generated RNG key.

param(name, init_fn, *init_args, unbox=True)#

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named β€œparams”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args:

mean = self.param('mean', lecun_normal(), (2, 2))

In the example above, the function lecun_normal expects two arguments: key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

Parameters
  • name (str) – The parameter name.

  • init_fn (Callable[..., TypeVar(T)]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args – The arguments to pass to init_fn.

  • unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

Return type

TypeVar(T)

Returns

The value of the initialized parameter. Throws an error if the parameter exists already.
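
A minimal sketch of declaring a parameter inside a compact method; the Scale module is illustrative, not part of NetKet.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Scale(nn.Module):
  @nn.compact
  def __call__(self, x):
    # The PRNG key is supplied automatically; only the shape is passed here.
    scale = self.param('scale', nn.initializers.ones, (x.shape[-1],))
    return x * scale

variables = Scale().init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
# variables['params']['scale'] has shape (3,)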

perturb(name, value, collection='perturbations')#

Add a zero-value variable (β€˜perturbation’) to the intermediate value.

The gradient of value would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation argument.

Note: this is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory space. Use it only to debug gradients in training.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(3)(x)
        x = self.perturb('dense3', x)
        return nn.Dense(2)(x)

def loss(params, perturbations, inputs, targets):
  variables = {'params': params, 'perturbations': perturbations}
  preds = model.apply(variables, inputs)
  return jnp.square(preds - targets).mean()

x = jnp.ones((2, 9))
y = jnp.ones((2, 2))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
intm_grads = jax.grad(loss, argnums=1)(variables['params'], variables['perturbations'], x, y)
print(intm_grads['dense3']) # ==> [[-1.456924   -0.44332537  0.02422847]
                            #      [-1.456924   -0.44332537  0.02422847]]

If perturbations are not passed to apply, perturb behaves like a no-op so you can easily disable the behavior when not needed:

model.apply({'params': params, 'perturbations': perturbations}, x) # works as expected
model.apply({'params': params}, x) # behaves like a no-op
Return type

TypeVar(T)

Parameters
  • name (str) –

  • value (flax.linen.module.T) –

  • collection (str) –

pooling(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)#

Sum of array elements over a given axis.

LAX-backend implementation of numpy.sum().

Original docstring below.

Parameters
  • a (array_like) – Elements to sum.

  • axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.

  • dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.

    If the default value is passed, then keepdims will not be passed through to the sum method of sub-classes of ndarray, however any non-default value will be. If the sub-class’ method does not implement keepdims any exceptions will be raised.

  • initial (scalar, optional) – Starting value for the sum. See ~numpy.ufunc.reduce for details.

  • where (array_like of bool, optional) – Elements to include in the sum. See ~numpy.ufunc.reduce for details.

  • promote_integers (bool, default True) – If True, then integer inputs will be promoted to the widest available integer dtype, following numpy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.

  • out (None) –

Returns

sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.

Return type

ndarray
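
In the DeepSets forward pass this sum is the permutation-invariant pooling step. A small sketch of the idea; the axis choice below assumes the (…, N, D) layout described at the top of this page and is not a statement about the internal implementation.

import jax.numpy as jnp

phi_out = jnp.ones((4, 8, 16))       # (batch, N, features) after the phi network
pooled = jnp.sum(phi_out, axis=-2)   # pool over N -> (4, 16); invariant under permutations along N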

put_variable(col, name, value)#

Updates the value of the given variable if it is mutable, or raises an error otherwise.

Parameters
  • col (str) – the variable collection.

  • name (str) – the name of the variable.

  • value (Any) – the new value of the variable.

setup()[source]#

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_with_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    class MyModule(nn.Module):
      def setup(self):
        submodule = Conv(...)
    
        # Accessing `submodule` attributes does not yet work here.
    
        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule
    
        # Accessing `submodule` attributes or methods is now safe and
        # eventually causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)#

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    self.sow('intermediates', 'h', h)
    return nn.Dense(2)(h)

x = jnp.ones((16, 9))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, x, mutable=['intermediates'])
print(state['intermediates'])  # {'h': (...,)}

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:

class Foo2(nn.Module):
  @nn.compact
  def __call__(self, x):
    init_fn = lambda: 0
    reduce_fn = lambda a, b: a + b
    self.sow('intermediates', 'h', x,
             init_fn=init_fn, reduce_fn=reduce_fn)
    self.sow('intermediates', 'h', x * 2,
             init_fn=init_fn, reduce_fn=reduce_fn)
    return x

model = Foo2()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, jnp.ones((1, 1)), mutable=['intermediates'])
print(state['intermediates'])  # ==> {'h': [[3.]]}
Parameters
  • col (str) – The name of the variable collection.

  • name (str) – The name of the variable.

  • value (TypeVar(T)) – The value of the variable.

  • reduce_fn (Callable[[TypeVar(K), TypeVar(T)], TypeVar(K)]) – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn (Callable[[], TypeVar(K)]) – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Return type

bool

Returns

True if the value has been stored successfully, False otherwise.

tabulate(rngs, *args, depth=None, show_repeated=False, mutable=True, console_kwargs=None, **kwargs)#

Creates a summary of the Module represented as a table.

This method has the same signature and internally calls Module.init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments can be passed into the console_kwargs argument, for example, {β€˜width’: 120}. For a full list of console_kwargs arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))

print(Foo().tabulate(jax.random.PRNGKey(0), x))

This gives the following output:

                                Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ params               ┃
┑━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
β”‚         β”‚ Foo    β”‚ float32[16,9] β”‚ float32[16,2] β”‚                      β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Dense_0 β”‚ Dense  β”‚ float32[16,9] β”‚ float32[16,4] β”‚ bias: float32[4]     β”‚
β”‚         β”‚        β”‚               β”‚               β”‚ kernel: float32[9,4] β”‚
β”‚         β”‚        β”‚               β”‚               β”‚                      β”‚
β”‚         β”‚        β”‚               β”‚               β”‚ 40 (160 B)           β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Dense_1 β”‚ Dense  β”‚ float32[16,4] β”‚ float32[16,2] β”‚ bias: float32[2]     β”‚
β”‚         β”‚        β”‚               β”‚               β”‚ kernel: float32[4,2] β”‚
β”‚         β”‚        β”‚               β”‚               β”‚                      β”‚
β”‚         β”‚        β”‚               β”‚               β”‚ 10 (40 B)            β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚         β”‚        β”‚               β”‚         Total β”‚ 50 (200 B)           β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

                      Total Parameters: 50 (200 B)

Note: the row order in the table does not represent execution order; instead it aligns with the order of keys in variables, which are sorted alphabetically.

Parameters
  • rngs (Union[Array, PRNGKeyArray, Dict[str, Union[Array, PRNGKeyArray]]]) – The rngs for the variable collections as passed to Module.init.

  • *args – The arguments to the forward computation.

  • depth (Optional[int]) – Controls how many submodules deep the summary can go. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.

  • show_repeated (bool) – If True, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is False.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except β€˜intermediates’ are mutable.

  • console_kwargs (Optional[Mapping[str, Any]]) – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {β€˜force_terminal’: True, β€˜force_jupyter’: False}.

  • **kwargs – keyword arguments to pass to the forward computation.

Return type

str

Returns

A string summarizing the Module.

unbind()#

Returns an unbound copy of a Module and its variables.

unbind helps create a stateless version of a bound Module.

An example of a common use case: to extract a sub-Module defined inside setup() and its corresponding variables: 1) temporarily bind the parent Module; and then 2) unbind the desired sub-Module. (Recall that setup() is only called when the Module is bound.):

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x):
    return self.decoder(self.encoder(x))

module = AutoEncoder()
variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
...
# Extract the Encoder sub-Module and its variables
encoder, encoder_vars = module.bind(variables).encoder.unbind()
Return type

Tuple[TypeVar(M, bound= Module), Mapping[str, Mapping[str, Any]]]

Returns

A tuple with an unbound copy of this Module and its variables.

Parameters

self (flax.linen.module.M) –

variable(col, name, init_fn=None, *init_args, unbox=True)#

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the β€œparams” collection.

Contrary to param(), all arguments passed to init_fn must be provided explicitly:

key = self.make_rng('stats')
mean = self.variable('stats', 'mean', lecun_normal(), key, (2, 2))

In the example above, the function lecun_normal expects two arguments: key and shape, and both have to be passed on. The PRNG for stats has to be provided explicitly when calling init() and apply().

Parameters
  • col (str) – The variable collection name.

  • name (str) – The variable name.

  • init_fn (Optional[Callable[..., Any]]) – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized otherwise an error is raised.

  • *init_args – The arguments to pass to init_fn.

  • unbox (bool) – If True, AxisMetadata instances are replaced by their unboxed value, see flax.nn.meta.unbox (default: True).

Return type

Variable

Returns

A flax.core.variables.Variable that can be read or set via β€œ.value” attribute. Throws an error if the variable exists already.
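
A hedged sketch of a mutable (non-parameter) variable: an exponential moving average stored in an assumed 'stats' collection. The RunningMean module is illustrative, not part of NetKet.

import jax
import jax.numpy as jnp
import flax.linen as nn

class RunningMean(nn.Module):
  @nn.compact
  def __call__(self, x):
    # init_fn takes no arguments here, so no extra init_args are passed.
    mean = self.variable('stats', 'mean', lambda: jnp.zeros(x.shape[-1]))
    mean.value = 0.9 * mean.value + 0.1 * x.mean(axis=0)
    return x - mean.value

m = RunningMean()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
y, updates = m.apply(variables, jnp.ones((2, 3)), mutable=['stats'])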