netket.nn.apply_operator.ApplyOperatorModuleLinen
- class netket.nn.apply_operator.ApplyOperatorModuleLinen
Bases: Module
A Flax Linen module that wraps another module and applies an operator transformation.
This module wraps a base neural network module and applies an operator in front of it, computing \(\log(O\lvert\psi\rangle)\) where \(O\) is the operator and \(\lvert\psi\rangle\) is represented by the base module.
The operator is stored in flattened form. The static structure (treedef) is stored as a module attribute, while the dynamic data (leaves) is stored in the ‘operator’ variable collection, which is not trainable. This separation allows the operator’s arrays to be updated without triggering recompilation while keeping the structure static.
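- Parameters:
base_module – the base neural network module to wrap (keyword used in the example below).
operator_treedef – the static structure (treedef) of the flattened operator (keyword used in the example below).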
Example:

    import jax
    import netket as nk
    from netket.nn.apply_operator import ApplyOperatorModuleLinen

    hilbert = nk.hilbert.Spin(0.5, 4)
    base_module = nk.models.RBM(alpha=1)
    operator = nk.operator.spin.sigmax(hilbert, 0)

    # Flatten the operator to separate static structure from dynamic data
    leaves, treedef = jax.tree.flatten(operator)
    transformed = ApplyOperatorModuleLinen(base_module=base_module, operator_treedef=treedef)

    # Initialize: first init the base module to get its params
    x = hilbert.all_states()  # input configurations, reused below
    base_params = base_module.init(jax.random.key(1), x)

    # Then add only the operator leaves to the variables dict
    variables = {**base_params, 'operator': {'leaves': leaves}}

    # Apply the transformed module
    logpsi = transformed.apply(variables, x)

    # The operator can be updated without recompilation.
    # Only update the leaves (treedef is fixed in the module);
    # new_operator is any operator with the same treedef as `operator`.
    new_leaves, _ = jax.tree.flatten(new_operator)
    variables['operator']['leaves'] = new_leaves
    logpsi = transformed.apply(variables, x)
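The split between a static treedef and dynamic leaves described above can be illustrated in plain JAX, without NetKet. The sketch below is only an illustration of the general mechanism, not the module's actual implementation: the dictionaries `op_a` and `op_b` are hypothetical stand-ins for an operator, and `apply_flat` is a made-up function name. It passes the treedef as a static argument to `jax.jit` and counts how many times the function is traced.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical stand-ins for an operator: pytrees holding array data.
op_a = {"coeffs": jnp.array([1.0, 2.0]), "shift": jnp.array(0.5)}
op_b = {"coeffs": jnp.array([3.0, 4.0]), "shift": jnp.array(-1.0)}

# Separate static structure (treedef) from dynamic data (leaves).
leaves_a, treedef = tree_util.tree_flatten(op_a)
leaves_b, treedef_b = tree_util.tree_flatten(op_b)
same_structure = treedef == treedef_b  # both operators share one treedef

trace_log = []  # gets one entry every time JAX traces the function


@partial(jax.jit, static_argnums=0)
def apply_flat(treedef, leaves, x):
    trace_log.append("traced")  # Python side effect: runs only during tracing
    op = tree_util.tree_unflatten(treedef, leaves)  # rebuild the "operator"
    return op["coeffs"] @ x + op["shift"]


x = jnp.array([1.0, 1.0])
out_a = apply_flat(treedef, leaves_a, x)  # first call: traces and compiles
out_b = apply_flat(treedef, leaves_b, x)  # new leaves, same treedef: cache hit
```

Because only the array values change between the two calls, the second call reuses the compiled function. Changing the treedef, or the shapes or dtypes of the leaves, would trigger a new trace.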
- Attributes
- operator
Reconstruct the operator from its flattened representation.
- Methods
- classmethod from_module_and_variables(bare_module, operator, bare_variables)
Create a TransformedModule from a bare module, operator, and variables.
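- Parameters:
bare_module – the bare module to wrap.
operator – the operator to apply in front of the bare module.
bare_variables – the variables of the bare module.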
- Returns:
A tuple of (wrapped_module, wrapped_variables), where wrapped_module is the new TransformedModule instance and wrapped_variables is the properly structured variables dict.
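As a rough illustration of what a "properly structured variables dict" could look like, the sketch below mirrors the layout used in the class example (`'operator': {'leaves': ...}` added next to the bare module's collections). The helper `wrap_variables` is hypothetical and is not NetKet's implementation; the dictionary used as an operator and the `params` contents are placeholders.

```python
import jax.numpy as jnp
from jax import tree_util


def wrap_variables(operator, bare_variables):
    """Hypothetical helper: flatten `operator` and merge its leaves into the variables."""
    leaves, treedef = tree_util.tree_flatten(operator)
    # Copy the bare collections and add a separate, non-trainable 'operator' collection.
    wrapped_variables = {**bare_variables, "operator": {"leaves": leaves}}
    return treedef, wrapped_variables


# Placeholder inputs standing in for a real module's variables and a real operator.
bare_variables = {"params": {"kernel": jnp.ones((2, 2))}}
operator = {"coeffs": jnp.array([1.0, -1.0])}

treedef, wrapped_variables = wrap_variables(operator, bare_variables)

# The operator can be rebuilt at any time from the static treedef and the stored leaves.
restored = tree_util.tree_unflatten(treedef, wrapped_variables["operator"]["leaves"])
```

The treedef returned here plays the role of the static module attribute, while `wrapped_variables` is what would be passed to `apply`.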