netket.nn.apply_operator.ApplyOperatorModuleNNX
- class netket.nn.apply_operator.ApplyOperatorModuleNNX
A Flax NNX module that wraps another NNX module and applies an operator transformation.
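This module wraps a base neural network module representing \(\lvert\psi\rangle\) and applies an operator \(O\) in front of it. For a configuration \(x\), it computes the log-amplitude of the transformed state \(O\lvert\psi\rangle\), that is \(\log\langle x\rvert O\lvert\psi\rangle = \log\sum_{x'}\langle x\rvert O\lvert x'\rangle\,\psi(x')\).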
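The sum over connected configurations \(x'\) can be evaluated stably in log space. The sketch below is a pure-JAX illustration of that formula, not NetKet's implementation: `log_apply_operator` is a hypothetical name, and it assumes the connected configurations `xp` and matrix elements `mels` are already available in padded form (for NetKet discrete operators they would typically come from `get_conn_padded`). Padded entries with a zero matrix element contribute nothing to the sum.

```python
import jax.numpy as jnp


def log_apply_operator(logpsi_fn, xp, mels):
    """Hypothetical sketch of log <x|O|psi> = log sum_{x'} O_{x x'} psi(x').

    logpsi_fn: maps configurations of shape (n, n_sites) to log psi, shape (n,)
    xp:        connected configurations, shape (batch, n_conn, n_sites)
    mels:      matrix elements O_{x x'}, shape (batch, n_conn)
    """
    batch, n_conn, n_sites = xp.shape
    # Evaluate log psi on every connected configuration in one batched call
    logpsi_xp = logpsi_fn(xp.reshape(-1, n_sites)).reshape(batch, n_conn)
    # Factor out the largest real part per row to avoid overflow in exp
    shift = jnp.max(logpsi_xp.real, axis=-1, keepdims=True)
    total = jnp.sum(mels * jnp.exp(logpsi_xp - shift), axis=-1)
    # Cast to complex so a negative total yields an imaginary part of pi
    return shift[:, 0] + jnp.log(total.astype(jnp.complex64))
```

Working in log space mirrors how the wrapped module returns \(\log\psi\) rather than \(\psi\), so large amplitudes do not overflow before the weighted sum is formed.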
The operator is stored as a regular NNX Variable with collection='operator', which makes it non-trainable by default, since optimizers typically update only the 'params' collection.
Unlike Linen modules, NNX modules are stateful and hold their parameters directly. This makes the implementation more straightforward: the base module and the operator are simply stored as attributes.
- Parameters:
  - base_module (Module) – The NNX module to wrap.
  - operator – The operator to apply.
Example:
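```python
import jax
import netket as nk
from flax import nnx
from netket.models import RBM
from netket.nn.apply_operator import ApplyOperatorModuleNNX

# Illustrative setup: 10 spin-1/2 sites, a small batch of configurations, a sampler
hilbert = nk.hilbert.Spin(s=1 / 2, N=10)
x = hilbert.random_state(jax.random.PRNGKey(0), 4)
sampler = nk.sampler.MetropolisLocal(hilbert)

# Create base NNX module (already initialized with parameters)
base_module = RBM(N=10, alpha=2, rngs=nnx.Rngs(0))
operator = nk.operator.spin.sigmax(hilbert, 0)

# Create transformed module
transformed = ApplyOperatorModuleNNX(base_module, operator)

# Use it directly (NNX style)
logpsi = transformed(x)

# Or use with MCState
vstate = nk.vqs.MCState(sampler, transformed, n_samples=1000)

# The operator can be updated
new_operator = nk.operator.spin.sigmaz(hilbert, 0)
transformed.operator = new_operator
logpsi = transformed(x)
```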