netket.nn.freeze_parameters

netket.nn.freeze_parameters(model, variables, is_frozen)

Freeze a subset of model parameters, identified by a filter function.

This function supports flax.linen.Module, flax.nnx.Module, and plain apply functions, as follows:

  • NNX modules: matching nnx.Param variables are converted to Frozen, which NetKet puts in the model_state dictionary instead of the parameters.

  • Flax Linen modules: frozen parameters land in the "frozen_params" collection.

  • Plain apply functions (any callable that is not a Module): frozen parameters live in "frozen_params".

Parameters:
  • model (Any) – A module, or a plain (variables, x) -> y apply function.

  • variables (dict) – Variables dict matching model (with at least a "params" key for Linen and functional models; ignored for NNX, whose parameters are stored inside the module).

  • is_frozen (Callable[[tuple[str, ...], Any], bool]) – Callable filter (path, leaf) -> bool.

Return type:

tuple[Any, dict]

Returns:

(new_model, new_variables), where new_model has the same calling convention as model (Module → Module, apply_fun → apply_fun). For NNX modules, new_variables is None because the parameters live inside the module.

For variational states, prefer calling netket.vqs.freeze_parameters(), which operates on the variational state itself.
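The is_frozen filter receives the path to each parameter leaf and the leaf itself, and returns True for leaves that should be frozen. The sketch below illustrates that convention on a plain nested dict using only the standard library. It is not NetKet's implementation: split_by_filter is a hypothetical helper, and the assumption that a path is the tuple of dictionary keys leading from the root to the leaf is inferred from the tuple[str, ...] type hint.

```python
from typing import Any, Callable

Path = tuple[str, ...]


def split_by_filter(
    params: dict, is_frozen: Callable[[Path, Any], bool], prefix: Path = ()
) -> tuple[dict, dict]:
    """Hypothetical helper: partition a nested dict into (trainable, frozen)."""
    trainable, frozen = {}, {}
    for key, value in params.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, extending the path with this key.
            sub_trainable, sub_frozen = split_by_filter(value, is_frozen, path)
            if sub_trainable:
                trainable[key] = sub_trainable
            if sub_frozen:
                frozen[key] = sub_frozen
        elif is_frozen(path, value):
            frozen[key] = value
        else:
            trainable[key] = value
    return trainable, frozen


# Toy parameter tree (assumed layout, for illustration only).
params = {"dense": {"kernel": [[0.1, 0.2]], "bias": [0.0]}}

# Same filter as in the NNX example: freeze every leaf whose path contains "kernel".
trainable, frozen = split_by_filter(params, lambda path, _: "kernel" in path)
```

With this filter, the kernel ends up in the frozen tree and the bias stays trainable. A filter can also inspect the leaf itself (the second argument), for example to freeze parameters by shape or dtype.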

Example (NNX):

import jax.numpy as jnp
import netket as nk
from flax import nnx

class RBM(nnx.Module):
    def __init__(self):
        self.dense = nnx.Linear(4, 8, rngs=nnx.Rngs(0))

    def __call__(self, x):
        # Sum over the feature axis only, so a batch of inputs yields one value per sample
        return jnp.sum(jnp.log(jnp.cosh(self.dense(x.astype(float)))), axis=-1)

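# Illustrative setup: any sampler on a 4-site Hilbert space works here,
# since RBM expects 4 input features
hi = nk.hilbert.Spin(s=0.5, N=4)
sampler = nk.sampler.MetropolisLocal(hi)

new_model, new_variables = nk.nn.freeze_parameters(
    RBM(), {}, lambda path, _: "kernel" in path
)
# new_variables is None for NNX: the parameters travel inside new_model
vstate = nk.vqs.MCState(sampler, new_model, variables=new_variables)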