netket.nn.apply_operator.ApplyOperatorModuleNNX

class netket.nn.apply_operator.ApplyOperatorModuleNNX

A Flax NNX module that wraps another NNX module and applies an operator transformation.

This module wraps a base neural network module and applies an operator in front of it, computing \(\log(O\lvert\psi\rangle)\) where \(O\) is the operator and \(\lvert\psi\rangle\) is represented by the base module.
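To make the quantity concrete: for a configuration \(x\), \(\log(O\lvert\psi\rangle)(x) = \log \sum_{x'} \langle x\rvert O\lvert x'\rangle\, \psi(x')\), which can be evaluated stably with a log-sum-exp over the configurations connected to \(x\). The following is a minimal, library-free sketch of that arithmetic only; it is not the module's actual implementation. The helper `log_apply_operator`, the toy `log_psi`, and the hand-written connected configurations and matrix elements are hypothetical names introduced for illustration.

```python
import cmath
import math


def log_apply_operator(log_psi, connected, mels):
    """Return log(sum_k mels[k] * psi(connected[k])) via a log-sum-exp.

    Assumes the weighted sum is nonzero; otherwise the logarithm is undefined.
    """
    log_vals = [complex(log_psi(xp)) for xp in connected]
    # Shift by the largest real part so the exponentials cannot overflow.
    shift = max(v.real for v in log_vals)
    total = sum(m * cmath.exp(v - shift) for m, v in zip(mels, log_vals))
    return shift + cmath.log(total)


# Hypothetical toy wavefunction on two spins: log psi(x) = 0.3 * sum(x).
def log_psi(x):
    return 0.3 * sum(x)


# sigma^x on site 0 connects x = (+1, -1) to (-1, -1) with matrix element 1.
result_sx = log_apply_operator(log_psi, connected=[(-1, -1)], mels=[1.0])

# sigma^x + sigma^z on site 0 adds the diagonal term <x|sigma^z|x> = +1.
result_sum = log_apply_operator(
    log_psi, connected=[(1, -1), (-1, -1)], mels=[1.0, 1.0]
)
```

In the first case the result is simply the log-amplitude of the flipped configuration; in the second it is the logarithm of the sum of two amplitudes.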

The operator is stored as a regular NNX Variable with collection='operator', which makes it non-trainable by default, since optimizers typically update only the 'params' collection.
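The effect of keeping the operator outside the trainable collection can be illustrated with a toy update step. This sketch uses plain dictionaries as a stand-in for variable collections and is not the Flax NNX API; the `sgd_step` helper, the collection contents, and the gradient values are all hypothetical.

```python
# Toy stand-in for variable collections; this is NOT the Flax NNX API.
state = {
    "params": {"w": 1.0, "b": -0.5},
    "operator": {"mel": 2.0},
}
grads = {
    "params": {"w": 0.1, "b": 0.2},
    "operator": {"mel": 5.0},
}


def sgd_step(state, grads, lr, trainable=("params",)):
    """Return a new state in which only the trainable collections are updated."""
    new_state = {}
    for collection, values in state.items():
        if collection in trainable:
            new_state[collection] = {
                name: value - lr * grads[collection][name]
                for name, value in values.items()
            }
        else:
            # Non-trainable collections are copied through unchanged.
            new_state[collection] = dict(values)
    return new_state


new_state = sgd_step(state, grads, lr=0.1)
```

Even though a gradient entry exists for the operator collection, the update leaves it untouched because only the collections listed in `trainable` are modified.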

Unlike the Linen version, NNX modules are stateful and contain their parameters directly. This makes the implementation more straightforward: the base module and the operator are simply stored as attributes.

Parameters:
  • base_module (Module) – The NNX module to wrap

  • operator – The operator to apply

Example:

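import jax
import netket as nk
from netket.models import RBM
from netket.nn.apply_operator import ApplyOperatorModuleNNX
from flax import nnx

# Hilbert space, sampler, and a small batch of input configurations
hilbert = nk.hilbert.Spin(s=0.5, N=10)
sampler = nk.sampler.MetropolisLocal(hilbert)
x = hilbert.random_state(jax.random.PRNGKey(0), 4)

# Create base NNX module (already initialized with parameters)
base_module = RBM(N=10, alpha=2, rngs=nnx.Rngs(0))
operator = nk.operator.spin.sigmax(hilbert, 0)

# Create transformed module
transformed = ApplyOperatorModuleNNX(base_module, operator)

# Use it directly (NNX style)
logpsi = transformed(x)

# Or use with MCState
vstate = nk.vqs.MCState(sampler, transformed, n_samples=1000)

# The operator can be updated, for example to sigma^z on the same site
new_operator = nk.operator.spin.sigmaz(hilbert, 0)
transformed.operator = new_operator
logpsi = transformed(x)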

__init__(base_module, operator)

Initialize the transformed module.

Parameters:
  • base_module (Module) – The base NNX module to wrap

  • operator – The operator to apply in front of the base module