netket.nn.freeze_parameters
- netket.nn.freeze_parameters(model, variables, is_frozen)
Freeze a subset of model parameters, identified by a filter function.
This function supports flax.linen.Module, flax.nnx.Module, and plain apply functions, as follows:
- NNX modules: matching nnx.Param variables are converted to Frozen, which NetKet puts in the model_state dictionary instead of the parameters.
- Flax Linen modules: frozen parameters land in the "frozen_params" collection.
- Plain apply functions (any callable that is not a Module): frozen parameters also live in "frozen_params".
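For Linen modules and plain apply functions, freezing amounts to partitioning the parameter tree: leaves for which is_frozen(path, leaf) returns True leave the "params" collection and move to "frozen_params", the idea being that only what remains in "params" is optimized. The pure-Python sketch below illustrates that bookkeeping on nested dicts. split_frozen is a hypothetical helper written for this illustration, not NetKet's implementation, and it assumes parameters are plain nested dicts with string keys (real Flax variables may arrive as FrozenDict and hold arrays).

```python
from typing import Any, Callable

# Filter signature documented for is_frozen: (path, leaf) -> bool
Filter = Callable[[tuple[str, ...], Any], bool]


def split_frozen(variables: dict, is_frozen: Filter) -> dict:
    """Hypothetical helper: move leaves of variables["params"] that match
    is_frozen into a separate "frozen_params" collection."""

    def walk(tree: dict, prefix: tuple[str, ...]):
        trainable, frozen = {}, {}
        for key, value in tree.items():
            path = prefix + (key,)
            if isinstance(value, dict):
                # Recurse into a sub-module; keep only non-empty branches
                t, f = walk(value, path)
                if t:
                    trainable[key] = t
                if f:
                    frozen[key] = f
            elif is_frozen(path, value):
                frozen[key] = value
            else:
                trainable[key] = value
        return trainable, frozen

    trainable, frozen = walk(variables["params"], ())
    # Pass any other collections through unchanged; the input is not mutated
    out = {k: v for k, v in variables.items() if k != "params"}
    out["params"] = trainable
    out["frozen_params"] = frozen
    return out


variables = {
    "params": {
        "Dense_0": {"kernel": [[0.1, 0.2]], "bias": [0.0]},
        "visible_bias": [0.5],
    }
}
new_vars = split_frozen(variables, lambda path, _: "kernel" in path)
print(new_vars["params"])         # {'Dense_0': {'bias': [0.0]}, 'visible_bias': [0.5]}
print(new_vars["frozen_params"])  # {'Dense_0': {'kernel': [[0.1, 0.2]]}}
```

This sketch drops sub-trees that end up empty, so the two resulting trees are disjoint and contain no empty modules; that is a design choice of the sketch.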
- Parameters:
  - model (Any) – A module, or a plain (variables, x) -> y apply function.
  - variables (dict) – Variables dict matching model (with at least a "params" key for Linen and functional models; ignored for NNX, whose parameters are stored inside the module).
  - is_frozen (Callable[[tuple[str, ...], Any], bool]) – Callable filter (path, leaf) -> bool. It should return True for every leaf that is to be frozen, where path is the tuple of keys leading to that leaf.
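Because is_frozen sees each leaf's path as a tuple of strings, typical filters match a leaf name anywhere in the path, a top-level submodule, or everything except an allow-list. The filters below are illustrative; the paths used to exercise them, such as ("Dense_0", "kernel"), are assumed Linen-style names, and the actual paths depend on how the model names its submodules and parameters.

```python
from typing import Any


def freeze_kernels(path: tuple[str, ...], leaf: Any) -> bool:
    # Freeze every weight matrix named "kernel", wherever it sits in the tree
    return "kernel" in path


def freeze_submodule(name: str):
    # Build a filter that freezes everything below one top-level submodule
    def is_frozen(path: tuple[str, ...], leaf: Any) -> bool:
        return len(path) > 0 and path[0] == name
    return is_frozen


def freeze_all_but(*trainable_names: str):
    # Freeze every leaf whose final path component is not listed as trainable
    def is_frozen(path: tuple[str, ...], leaf: Any) -> bool:
        return path[-1] not in trainable_names
    return is_frozen


paths = [("Dense_0", "kernel"), ("Dense_0", "bias"), ("Dense_1", "kernel"), ("visible_bias",)]
print([p for p in paths if freeze_kernels(p, None)])
# [('Dense_0', 'kernel'), ('Dense_1', 'kernel')]
print([p for p in paths if freeze_submodule("Dense_0")(p, None)])
# [('Dense_0', 'kernel'), ('Dense_0', 'bias')]
print([p for p in paths if freeze_all_but("visible_bias")(p, None)])
# [('Dense_0', 'kernel'), ('Dense_0', 'bias'), ('Dense_1', 'kernel')]
```

Any of these can be passed as the third argument, e.g. nk.nn.freeze_parameters(model, variables, freeze_submodule("Dense_0")). The leaf argument can be inspected as well, for instance to freeze only two-dimensional arrays via getattr(leaf, "ndim", None) == 2.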
- Return type:
- Returns:
  (new_model, new_variables): new_model has the same calling convention as model (Module → Module, apply_fun → apply_fun). new_variables is None for NNX modules (parameters live inside the module).
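For plain apply functions, keeping the (variables, x) -> y calling convention means the returned function must accept variables that hold both "params" and "frozen_params". One way such a wrapper could be written is sketched below: it merges the two collections back into a single parameter tree before calling the original function. wrap_apply_fun and merge_trees are hypothetical names, and this is not necessarily how NetKet implements it.

```python
from typing import Any, Callable


def merge_trees(a: dict, b: dict) -> dict:
    # Recursively combine two nested dicts whose leaves are disjoint
    out = dict(a)
    for key, value in b.items():
        if key in out and isinstance(out[key], dict) and isinstance(value, dict):
            out[key] = merge_trees(out[key], value)
        else:
            out[key] = value
    return out


def wrap_apply_fun(apply_fun: Callable[[dict, Any], Any]) -> Callable[[dict, Any], Any]:
    # Hypothetical wrapper: same (variables, x) -> y convention as apply_fun,
    # but it rebuilds the full parameter tree from "params" + "frozen_params"
    def new_apply_fun(variables: dict, x: Any) -> Any:
        frozen = variables.get("frozen_params", {})
        rest = {k: v for k, v in variables.items() if k not in ("params", "frozen_params")}
        full = {**rest, "params": merge_trees(variables["params"], frozen)}
        return apply_fun(full, x)

    return new_apply_fun


def apply_fun(variables: dict, x: float) -> float:
    # Toy model: y = w * x + b
    p = variables["params"]
    return p["w"] * x + p["b"]


new_apply_fun = wrap_apply_fun(apply_fun)
print(new_apply_fun({"params": {"b": 1.0}, "frozen_params": {"w": 2.0}}, 3.0))  # 7.0
```

Because the wrapper only reads "frozen_params" and never returns it as part of "params", anything differentiating with respect to "params" alone leaves the frozen leaves untouched in this sketch.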
For variational states, prefer calling netket.vqs.freeze_parameters(), which operates on the variational state itself.

Example (NNX):
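```python
import jax.numpy as jnp
import netket as nk
from flax import nnx


class RBM(nnx.Module):
    def __init__(self):
        self.dense = nnx.Linear(4, 8, rngs=nnx.Rngs(0))

    def __call__(self, x):
        # Sum over hidden units only, so a batch of configurations
        # yields one log-amplitude per configuration
        return jnp.sum(jnp.log(jnp.cosh(self.dense(x.astype(float)))), axis=-1)


# Freeze every parameter whose path contains "kernel" (the dense weights)
new_model, new_variables = nk.nn.freeze_parameters(
    RBM(), {}, lambda path, _: "kernel" in path
)
# new_variables is None for NNX: params travel inside new_model

# A sampler over 4 spins, matching the 4 input features of the RBM
hi = nk.hilbert.Spin(s=1 / 2, N=4)
sampler = nk.sampler.MetropolisLocal(hi)
vstate = nk.vqs.MCState(sampler, new_model, variables=new_variables)
```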