netket.nn.apply_operator.ApplyOperatorModuleLinen

class netket.nn.apply_operator.ApplyOperatorModuleLinen

Bases: Module

A Flax Linen module that wraps another module and applies an operator transformation.

This module wraps a base neural-network module and applies an operator in front of it: for a configuration \(x\) it computes \(\log\langle x\rvert O\lvert\psi\rangle\), where \(O\) is the operator and \(\lvert\psi\rangle\) is the state represented by the base module.
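To make the formula concrete, here is a minimal dense-matrix sketch on a 2-spin system, using \(\langle x\rvert O\lvert\psi\rangle = \sum_{x'} O_{xx'}\,\psi(x')\) evaluated in log space. The toy `log_psi` and the helper `log_O_psi` are hypothetical illustrations, not NetKet API; a practical implementation would typically sum only over the configurations \(x'\) connected to \(x\) rather than build the full matrix. The sketch also assumes nonnegative matrix elements, so plain `logsumexp` suffices.

```python
import itertools

import jax
import jax.numpy as jnp

# All configurations of 2 spins with values in {-1, +1}; the row order matches
# the basis order of jnp.kron below (site 0 is the most significant index).
states = jnp.array(list(itertools.product([-1.0, 1.0], repeat=2)))

def log_psi(s):
    # Hypothetical toy log-amplitude (real-valued for simplicity)
    return 0.3 * s[..., 0] - 0.7 * s[..., 1] + 0.2 * s[..., 0] * s[..., 1]

# Dense matrix of sigma^x acting on site 0 of a 2-site system
sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]])
O = jnp.kron(sigma_x, jnp.eye(2))

def log_O_psi(O, log_psi, states):
    # log <x|O|psi> = log sum_{x'} O[x, x'] * exp(log psi(x')),
    # computed with logsumexp for numerical stability.
    logp = jnp.broadcast_to(log_psi(states), O.shape)  # logp[x, x'] = log psi(x')
    return jax.nn.logsumexp(logp, axis=1, b=O)

result = log_O_psi(O, log_psi, states)

# sigma^x on site 0 flips the first spin, so <x|O|psi> = psi(x with spin 0 flipped)
flipped = states.at[:, 0].multiply(-1.0)
print(jnp.allclose(result, log_psi(flipped)))  # True
```

For operators with negative or complex matrix elements the log-amplitudes become complex, and the real-valued `logsumexp` above no longer applies directly.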

The operator is stored in flattened form. Its static structure (treedef) is stored as a module attribute, while its dynamic data (leaves) are stored in the non-trainable ‘operator’ variable collection. Because the treedef is fixed when the module is built and the leaves enter only as array data, the operator’s arrays can be updated without triggering recompilation, provided the new leaves keep the same shapes and dtypes.
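The mechanism behind this separation can be sketched in plain JAX: the treedef is captured as static Python state, and only the leaves are passed as traced arguments to a jitted function. `ToyOperator` and `make_apply` are hypothetical stand-ins for illustration, not part of NetKet; the trace counter shows that swapping in new leaves of the same structure, shapes and dtypes reuses the compiled function.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyOperator(NamedTuple):
    # Hypothetical stand-in for an operator pytree: a coefficient times a matrix
    coeff: jax.Array
    matrix: jax.Array

traces = []  # records each time the jitted function is traced (compiled)

def make_apply(treedef):
    # The treedef is captured by closure, so JAX treats it as static structure;
    # only the leaves are traced array arguments.
    @jax.jit
    def apply(leaves, psi):
        traces.append(1)  # Python side effect: runs only while tracing
        op = jax.tree.unflatten(treedef, leaves)
        return op.coeff * (op.matrix @ psi)
    return apply

f32 = jnp.float32
op1 = ToyOperator(jnp.asarray(1.0, dtype=f32), jnp.eye(2, dtype=f32))
leaves1, treedef = jax.tree.flatten(op1)
apply = make_apply(treedef)

psi = jnp.array([1.0, 2.0], dtype=f32)
out1 = apply(leaves1, psi)  # first call: traces and compiles

# New operator data with the same structure, shapes and dtypes
op2 = ToyOperator(jnp.asarray(0.5, dtype=f32),
                  jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=f32))
leaves2, treedef2 = jax.tree.flatten(op2)
assert treedef2 == treedef
out2 = apply(leaves2, psi)  # reuses the compiled function

print(len(traces))  # 1
```

An operator with a different structure would need a new treedef, and therefore a new wrapper and a fresh compilation; leaves with different shapes or dtypes would also trigger a retrace.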

Parameters:
  • base_module (Module) – The Flax module to wrap

  • operator_treedef (Any) – The pytree structure of the operator (obtained from jax.tree.flatten)

Example:

import jax
import netket as nk
from netket.nn.apply_operator import ApplyOperatorModuleLinen

hilbert = nk.hilbert.Spin(0.5, 4)
base_module = nk.models.RBM(alpha=1)
operator = nk.operator.spin.sigmax(hilbert, 0)
x = hilbert.all_states()  # configurations at which to evaluate the state

# Flatten the operator to separate static structure from dynamic data
leaves, treedef = jax.tree.flatten(operator)
transformed = ApplyOperatorModuleLinen(base_module=base_module, operator_treedef=treedef)

# Initialize: first init the base module to get its params
base_params = base_module.init(jax.random.key(1), x)
# Then add only the operator leaves to the variables dict
variables = {**base_params, 'operator': {'leaves': leaves}}

# Apply the transformed module: log <x|O|psi> for each row of x
logpsi = transformed.apply(variables, x)

# The operator can be updated without recompilation.
# Only the leaves are replaced (treedef is fixed in the module), so
# new_operator must have the same pytree structure as operator.
new_operator = 0.5 * operator  # assumed to preserve the treedef; checked below
new_leaves, new_treedef = jax.tree.flatten(new_operator)
assert new_treedef == treedef
variables['operator']['leaves'] = new_leaves
logpsi = transformed.apply(variables, x)
Attributes
operator

Reconstruct the operator from its flattened representation.

base_module: Module

The base module to wrap

operator_treedef: Any

The static pytree structure of the operator

Methods
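__call__(x, *args, **kwargs)

Evaluate the wrapped module with the operator applied, returning the log-amplitude of the transformed state for the configurations x.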

classmethod from_module_and_variables(bare_module, operator, bare_variables)

Create an ApplyOperatorModuleLinen from a bare module, an operator, and the bare module’s variables.

Parameters:
  • bare_module (Module) – The bare Flax module to wrap

  • operator – The operator to apply (will be flattened)

  • bare_variables (dict) – The variables dictionary from the bare module

Returns:

A tuple (wrapped_module, wrapped_variables), where

  • wrapped_module is the new ApplyOperatorModuleLinen instance

  • wrapped_variables is the properly structured variables dict

Return type:

tuple
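The variable packing that from_module_and_variables is described as performing can be sketched with a hypothetical helper, `pack_operator_variables`, that mirrors the manual steps of the class example above. This is an illustration under that assumption, not the classmethod’s actual code; the real method also constructs the wrapped module, and its details may differ. The last lines show, conceptually, how the `operator` attribute can be rebuilt from the stored leaves.

```python
import jax
import jax.numpy as jnp

def pack_operator_variables(bare_variables, operator):
    """Hypothetical helper mirroring the manual steps of the class example.

    Returns (treedef, wrapped_variables): the static structure to pass as
    operator_treedef, and the bare variables extended with the non-trainable
    'operator' collection that holds only the operator's leaves.
    """
    leaves, treedef = jax.tree.flatten(operator)
    wrapped_variables = {**bare_variables, "operator": {"leaves": leaves}}
    return treedef, wrapped_variables

# Toy stand-ins for a bare module's variables and for an operator pytree
bare_variables = {"params": {"w": jnp.ones(3)}}
operator = {"coeff": jnp.asarray(2.0), "matrix": jnp.eye(2)}

treedef, wrapped_variables = pack_operator_variables(bare_variables, operator)

# Conceptual reconstruction of the operator from its flattened representation
rebuilt = jax.tree.unflatten(treedef, wrapped_variables["operator"]["leaves"])
print(sorted(wrapped_variables))  # ['operator', 'params']
print(float(rebuilt["coeff"]))    # 2.0
```

Keeping the bare module’s collections untouched and adding a separate ‘operator’ collection means an optimizer that updates only ‘params’ never modifies the operator data.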