import netket as nk
import jax.numpy as jnp
import numpy as np
Defining custom local estimators (new)
When you call expect(), NetKet internally computes a per-sample local estimator \(O_{\text{loc}}(x)\), the quantity whose Monte Carlo average converges to \(\langle O \rangle\).
If an operator expectation value can be written as
\[\langle O \rangle = \mathbb{E}_{x \sim |\psi(x)|^2}\left[ O_{\text{loc}}(x) \right],\]
then computing the mean and estimating the variance and error bar is straightforward: you pass the matrix of local estimators to statistics().
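As a rough illustration of the scalar case, the sketch below computes a mean, a variance, and an error bar from an array of local estimators with shape (n_chains, chain_length). It uses plain NumPy on synthetic real-valued data and a simple batch-means error (each chain mean is treated as one independent sample). The helper name scalar_stats is hypothetical, and this is not NetKet's own implementation of statistics().

```python
import numpy as np


def scalar_stats(o_loc):
    """Mean, variance and error of the mean for local estimators
    of shape (n_chains, chain_length). Hypothetical helper."""
    o_loc = np.asarray(o_loc)
    n_chains = o_loc.shape[0]
    mean = o_loc.mean()
    variance = o_loc.var()
    # Batch-means error: treat every chain mean as one independent sample,
    # which stays meaningful even if samples within a chain are correlated.
    chain_means = o_loc.mean(axis=1)
    error_of_mean = chain_means.std(ddof=1) / np.sqrt(n_chains)
    return mean, variance, error_of_mean


rng = np.random.default_rng(0)
o_loc = 1.5 + rng.standard_normal((16, 200))  # synthetic stand-in for O_loc(x)
mean, variance, error_of_mean = scalar_stats(o_loc)
print(f"{mean:.3f} +/- {error_of_mean:.3f} (variance {variance:.3f})")
```

The chain-based error only needs the chains to be independent of each other, which is why the data is kept as a matrix rather than flattened.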
Other observables, however, are more complex: their expectation value depends nonlinearly on several random variables. A standard example is the variance,
\[\mathrm{Var}(O) = \langle O^2 \rangle - \langle O \rangle^2,\]
so computing the observable and estimating its variance and error bar is not as straightforward.
This page discusses how NetKet estimates errors and variances in this setting, and how you can define new observables or operators that benefit from the same functionality.
In particular, it focuses on the local_estimators() interface, which is also what makes expect_to_precision() work for custom observables.
What local_estimators() returns
The local_estimators() interface returns one of two pytrees.
For observables whose expectation is obtained by averaging a single local estimator, the returned samples are
\[O_{\text{loc}}(x_i), \qquad i = 1, \dots, N_{\text{samples}},\]
and the return type is LocalEstimators.
For observables built from several channels, it returns all the local quantities needed by the final nonlinear combination. For the variance example above this means the pair of per-sample estimators whose averages give \(\langle O \rangle\) and \(\langle O^2 \rangle\).
In that case the return type is LocalEstimatorsBatch.
LocalEstimators, for scalar observables: data has shape (n_chains, chain_length).
LocalEstimatorsBatch, for nonlinear observables that are built from K scalar expectations: data has shape (n_chains, chain_length, K), and combinator maps the channel means to the final scalar observable.
Both containers expose to_stats() for one-shot statistics and accumulate() for online accumulation.
# Set up a small system
hi = nk.hilbert.Spin(s=1 / 2, N=8)
g = nk.graph.Chain(8)
H = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)
sa = nk.sampler.MetropolisLocal(hi, n_chains=32)
vs = nk.vqs.MCState(sa, nk.models.RBM(alpha=1), n_samples=512)
le = vs.local_estimators(H)
print(type(le).__name__)
print(le.data.shape)
print(le.to_stats())
print(vs.expect(H))
LocalEstimators is a JAX pytree whose dynamic leaf is data.
LocalEstimatorsBatch is also a pytree; its combinator is stored as static metadata in the treedef.
You can pass both containers through jax.jit() or jax.vmap() as long as the combinator itself is JAX-traceable.
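To make the split between dynamic leaves and static metadata concrete, here is a minimal sketch of a custom pytree that stores its array as a leaf and its combinator in the treedef. ToyBatch, variance_combinator, and combined_mean are hypothetical names used only for illustration; this is not how NetKet defines LocalEstimatorsBatch, only the general JAX pattern the paragraph above describes.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyBatch:
    """Hypothetical stand-in container, not NetKet's LocalEstimatorsBatch."""

    def __init__(self, data, combinator):
        self.data = data              # dynamic leaf, traced by jax.jit
        self.combinator = combinator  # static metadata, kept in the treedef

    def tree_flatten(self):
        return (self.data,), self.combinator

    @classmethod
    def tree_unflatten(cls, combinator, children):
        return cls(children[0], combinator)


def variance_combinator(mu):
    # Uses only JAX-traceable arithmetic on the channel means.
    return mu[1] - mu[0] ** 2


@jax.jit
def combined_mean(batch):
    # Average every channel over chains and samples, then combine.
    mu = batch.data.reshape(-1, batch.data.shape[-1]).mean(axis=0)
    return batch.combinator(mu)


x = jnp.arange(12.0).reshape(2, 6)  # shape (n_chains, chain_length)
toy = ToyBatch(jnp.stack([x, x**2], axis=-1), variance_combinator)
value = combined_mean(toy)
leaves = jax.tree_util.tree_leaves(toy)
print(value, len(leaves))
```

Because the combinator is part of the treedef, passing a different function object triggers a recompilation under jax.jit, so reusing a named function rather than a fresh lambda at each call is usually preferable.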
Nonlinear observables and the delta method
Consider the variance of the Hamiltonian, \(\mathrm{Var}(H) = \langle H^2 \rangle - \langle H \rangle^2\). If we estimate \(\mu_0 = \langle H \rangle\) and \(\mu_1 = \langle H^2 \rangle\), the observable is the nonlinear function \(f(\mu) = \mu_1 - \mu_0^2\). The correct error bar comes from the delta method applied to the covariance of the channel means.
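In the delta method, the variance of \(f(\hat\mu)\) is approximated to first order by \(\sigma_f^2 \approx \nabla f(\mu)^{T} \, \Sigma \, \nabla f(\mu)\), where \(\Sigma\) is the covariance matrix of the channel means. The following sketch applies this to \(f(\mu) = \mu_1 - \mu_0^2\) with plain NumPy. It assumes real-valued, uncorrelated synthetic samples, so \(\Sigma\) is simply the sample covariance divided by the number of samples; Markov-chain data would additionally need an autocorrelation correction, which is not shown here.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples = 20_000
h_loc = rng.standard_normal(n_samples)           # synthetic stand-in for H_loc
channels = np.stack([h_loc, h_loc**2], axis=-1)  # shape (n_samples, 2)

mu = channels.mean(axis=0)                       # (mu_0, mu_1)
value = mu[1] - mu[0] ** 2                       # f(mu)

# Covariance of the channel means, valid for independent samples only.
cov_of_means = np.cov(channels, rowvar=False) / n_samples

# Analytic gradient of f(mu) = mu_1 - mu_0**2.
grad = np.array([-2.0 * mu[0], 1.0])

error = np.sqrt(grad @ cov_of_means @ grad)
print(f"Var(H) ~ {value:.3f} +/- {error:.3f}")
```

The off-diagonal entry of the covariance matters in general: the two channels are built from the same samples, so treating their errors as independent would give a wrong error bar.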
NetKet represents this with LocalEstimatorsBatch.
For one-shot summaries, call to_stats(); for online accumulation, use online_statistics_batch().
from netket.stats import LocalEstimatorsBatch
var_op = nk.observable.VarianceObservable(H, use_Oloc_squared=True)
le_var = vs.local_estimators(var_op)
print(type(le_var).__name__)
print(le_var.data.shape)
print(le_var.to_stats())
# Manual equivalent: two channels, <H> and <H^2>, combined as mu[1] - mu[0]**2.
H_loc = vs.local_estimators(H).data
manual_var = LocalEstimatorsBatch(
    data=jnp.stack([H_loc, H_loc**2], axis=-1),
    combinator=lambda mu: mu[1] - mu[0] ** 2,
)
print(manual_var.to_stats())
In practice you usually use built-in observables such as VarianceObservable or InfidelityOperator, but the same pattern is available for custom operators.
Online accumulation with accumulate
When you need to pool samples across multiple Monte Carlo draws, call accumulate() on successive containers.
Scalar batches return OnlineStats; multi-channel batches return OnlineStatsBatch.
Both expose get_stats().
acc_scalar = None
acc_batch = None
for _ in range(5):
    vs.sample(n_discard_per_chain=0)
    acc_scalar = vs.local_estimators(H).accumulate(acc_scalar, max_lag=64)
    acc_batch = vs.local_estimators(var_op).accumulate(acc_batch, max_lag=64)
print(acc_scalar.get_stats())
print(acc_batch.get_stats())
OnlineStatsBatch is also a JAX pytree: its internal tuple of per-channel OnlineStats objects is dynamic, while the combinator is static.
That makes it suitable for checkpointing or for use as a carry in jax.lax.scan().
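The idea behind pooling successive draws can be pictured with a running (count, mean, sum of squared deviations) triple that is merged with each new batch, so that earlier samples never need to be kept in memory. The sketch below uses the standard pairwise-merge formulas in plain NumPy. The merge function is a hypothetical helper, and the sketch ignores autocorrelation entirely, which NetKet's accumulators presumably handle (the max_lag argument above suggests so).

```python
import numpy as np


def merge(acc, batch):
    """Fold a new batch into a running (count, mean, M2) triple,
    where M2 is the sum of squared deviations from the mean."""
    batch = np.asarray(batch).ravel()
    n_b, mean_b = batch.size, batch.mean()
    m2_b = ((batch - mean_b) ** 2).sum()
    if acc is None:
        return n_b, mean_b, m2_b
    n_a, mean_a, m2_a = acc
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
    return n, mean, m2


rng = np.random.default_rng(2)
draws = [rng.standard_normal((16, 50)) for _ in range(5)]

acc = None
for draw in draws:
    acc = merge(acc, draw)

count, mean, m2 = acc
variance = m2 / count
print(count, mean, variance)
```

The merged statistics agree with those computed on the concatenated data, which is the property that makes the accumulator usable as a loop carry.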
Example: V-score custom operator
Any custom AbstractOperator that registers a local_estimators() dispatch gets expect() and expect_to_precision() for free.
For MCState, define separate chunk_size=None and chunk_size=int overloads so the custom implementation stays unambiguous with the generic scalar-operator path.
Here is a minimal example.
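Suppose we want to estimate the relative variance (V-score),
\[\text{V-score} = \frac{\mathrm{Var}(H)}{\langle H \rangle^2} = \frac{\langle H^2 \rangle - \langle H \rangle^2}{\langle H \rangle^2}.\]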
We reuse the Hamiltonian local estimator and return a LocalEstimatorsBatch with two channels.
from netket.operator import AbstractOperator
from netket.stats import LocalEstimatorsBatch
from netket.vqs import MCState
from netket.vqs.mc import local_estimators
class VScore(AbstractOperator):
    """V-score = Var(H) / <H>^2 for a given Hamiltonian."""

    def __init__(self, hamiltonian):
        super().__init__(hamiltonian.hilbert)
        self.hamiltonian = hamiltonian

    @property
    def dtype(self):
        return self.hamiltonian.dtype


def _vscore_local_estimators(vstate: MCState, op: VScore, chunk_size):
    # Reuse the Hamiltonian's local estimator and build the two channels.
    le = vstate.local_estimators(op.hamiltonian, chunk_size=chunk_size)
    H_loc = le.data
    data = jnp.stack([H_loc, H_loc**2], axis=-1)
    return LocalEstimatorsBatch(
        data=data,
        combinator=lambda mu: (mu[1] - mu[0] ** 2) / mu[0] ** 2,
    )


@local_estimators.dispatch
def _(vstate: MCState, op: VScore, chunk_size: None) -> LocalEstimatorsBatch:
    return _vscore_local_estimators(vstate, op, chunk_size)


@local_estimators.dispatch
def _(vstate: MCState, op: VScore, chunk_size: int) -> LocalEstimatorsBatch:
    return _vscore_local_estimators(vstate, op, chunk_size)
vscore = VScore(H)
print(vs.expect(vscore))
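For reference, the delta-method error bar for this particular combinator can be written out by hand. The derivation below assumes real-valued channel means \(\mu_0 = \langle H \rangle\), \(\mu_1 = \langle H^2 \rangle\) and a covariance matrix \(\Sigma\) of those means; whether NetKet evaluates the gradient analytically or by automatic differentiation is not stated here, so treat this as a worked illustration only.

```latex
\begin{align}
f(\mu) &= \frac{\mu_1 - \mu_0^2}{\mu_0^2} = \frac{\mu_1}{\mu_0^2} - 1, \\
\frac{\partial f}{\partial \mu_0} &= -\frac{2\mu_1}{\mu_0^3}, \qquad
\frac{\partial f}{\partial \mu_1} = \frac{1}{\mu_0^2}, \\
\sigma_f^2 &\approx \nabla f^{T} \, \Sigma \, \nabla f
= \frac{4\mu_1^2}{\mu_0^6}\,\Sigma_{00}
- \frac{4\mu_1}{\mu_0^5}\,\Sigma_{01}
+ \frac{1}{\mu_0^4}\,\Sigma_{11}.
\end{align}
```

The factors of \(1/\mu_0\) show that the error bar grows quickly when \(\langle H \rangle\) is close to zero, where the first-order approximation itself becomes unreliable.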
Because VScore now has a local_estimators() dispatch returning a LocalEstimatorsBatch, expect_to_precision() can accumulate both channels online and apply the same combinator at each step.
result = vs.expect_to_precision(vscore, atol=0.5, max_iter=50, verbose=False)
print(result.get_stats())
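The stopping logic can be pictured as a loop that pools batches, recomputes the combined value and its delta-method error after each one, and stops once the error falls below atol or max_iter is reached. The sketch below does this in plain NumPy on synthetic, uncorrelated data. The names run_to_precision, numerical_gradient, draw_batch, and vscore_combinator are hypothetical, and the finite-difference gradient is a stand-in chosen for simplicity, not a description of NetKet's internals.

```python
import numpy as np


def numerical_gradient(f, mu, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at mu."""
    grad = np.zeros_like(mu)
    for k in range(mu.size):
        step = np.zeros_like(mu)
        step[k] = eps
        grad[k] = (f(mu + step) - f(mu - step)) / (2 * eps)
    return grad


def run_to_precision(draw_batch, combinator, atol, max_iter):
    """Hypothetical helper: pool batches until the error drops below atol."""
    count, s1, s2 = 0, 0.0, 0.0
    for iteration in range(1, max_iter + 1):
        batch = np.asarray(draw_batch(), dtype=float)
        batch = batch.reshape(-1, batch.shape[-1])  # (n_samples, K)
        count += batch.shape[0]
        s1 = s1 + batch.sum(axis=0)                 # running channel sums
        s2 = s2 + batch.T @ batch                   # running second moments
        mu = s1 / count
        # Covariance of the channel means, assuming independent samples.
        cov_of_means = (s2 / count - np.outer(mu, mu)) / count
        grad = numerical_gradient(combinator, mu)
        error = np.sqrt(grad @ cov_of_means @ grad)
        if error < atol:
            break
    return combinator(mu), error, iteration


rng = np.random.default_rng(3)


def draw_batch():
    h_loc = 2.0 + rng.standard_normal((16, 100))    # synthetic H_loc
    return np.stack([h_loc, h_loc**2], axis=-1)


def vscore_combinator(mu):
    return (mu[1] - mu[0] ** 2) / mu[0] ** 2


value, error, n_iter = run_to_precision(
    draw_batch, vscore_combinator, atol=0.005, max_iter=50
)
print(value, error, n_iter)
```

Only the running sums are carried between iterations, so the memory cost is independent of the number of batches.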
Example: using LocalEstimatorsBatch without MCState
All the same building blocks also work on plain arrays. This is useful for post-processing saved data or for writing unit tests.
from netket.stats import (
    LocalEstimators,
    LocalEstimatorsBatch,
    online_statistics,
    online_statistics_batch,
)
rng = np.random.default_rng(0)
n_chains, chain_len = 16, 200
# Scalar: mean of a noisy signal
data_1d = jnp.array(rng.standard_normal((n_chains, chain_len)))
le_scalar = LocalEstimators(data=data_1d)
print(le_scalar.to_stats())
# Vector: variance-like functional with correct delta-method error bars
data_2d = jnp.stack([data_1d, data_1d**2], axis=-1)
le_batch = LocalEstimatorsBatch(
    data=data_2d,
    combinator=lambda mu: mu[1] - mu[0] ** 2,
)
print(le_batch.to_stats())
# Online accumulation with the explicit batch helper
batches = jnp.array(rng.standard_normal((10, n_chains, chain_len, 2)))
acc = None
for batch in batches:
    acc = online_statistics_batch(
        batch,
        combinator=lambda mu: mu[1] - mu[0] ** 2,
        old_estimator=acc,
    )
print(acc.get_stats())
# The generic online_statistics() entry point also accepts LocalEstimatorsBatch
acc2 = None
for batch in batches:
    le_batch = LocalEstimatorsBatch(
        batch,
        combinator=lambda mu: mu[1] - mu[0] ** 2,
    )
    acc2 = online_statistics(le_batch, old_estimator=acc2)
print(acc2.get_stats())
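When writing unit tests for a new combinator, one possible sanity check is to compare the delta-method error bar against an independent estimate such as the jackknife. The sketch below does this with plain NumPy on uncorrelated synthetic data, so it does not depend on any NetKet call; the combinator and variable names are illustrative only.

```python
import numpy as np


def combinator(mu):
    return mu[1] - mu[0] ** 2


rng = np.random.default_rng(4)
x = rng.standard_normal(5_000)
channels = np.stack([x, x**2], axis=-1)  # shape (n, 2)
n = channels.shape[0]

# Delta-method error with the analytic gradient of the combinator.
mu = channels.mean(axis=0)
grad = np.array([-2.0 * mu[0], 1.0])
delta_error = np.sqrt(grad @ (np.cov(channels, rowvar=False) / n) @ grad)

# Jackknife error: recompute the observable with each sample left out.
loo_means = (channels.sum(axis=0) - channels) / (n - 1)
loo_values = loo_means[:, 1] - loo_means[:, 0] ** 2
jackknife_error = np.sqrt(
    (n - 1) / n * ((loo_values - loo_values.mean()) ** 2).sum()
)

print(delta_error, jackknife_error)
```

For a smooth combinator and independent samples the two estimates should agree closely, so a large discrepancy usually points to a mistake in the combinator or in the channel ordering.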