netket.stats.LocalEstimatorsBatch
- class netket.stats.LocalEstimatorsBatch
Bases: Pytree

Per-sample K-channel estimators for nonlinear observables.
data has shape (n_chains, chain_len, K). The combinator, a function (K,) -> scalar | array, collapses the K channel means into a result of any shape using the delta method (first-order error propagation).

This class is returned by local_estimators() for operators such as VarianceObservable that need more than one local-estimator channel to form the final quantity.

to_stats() uses jax.eval_shape() to inspect the shape of the combinator output and returns:
  - a Stats for scalar combinators;
  - a StatsBatch for array-valued combinators, with .mean and .error_of_mean having the same shape as the combinator output.

to_online_stats() returns an OnlineStatsBatch.

Examples:
```python
import jax.numpy as jnp
from netket.stats import LocalEstimatorsBatch

# Variance: 2 channels, scalar combinator.
# H_loc: local energies, shape (n_chains, chain_len).
le = LocalEstimatorsBatch(
    data=jnp.stack([H_loc, H_loc**2], axis=-1),  # (n_chains, chain_len, 2)
    combinator=lambda mu: mu[1] - mu[0]**2,
)
stats = le.to_stats()  # Stats (scalar combinator → Stats)
acc = le.accumulate()  # OnlineStatsBatch for iterative estimation

# Susceptibility matrix: p + p² channels, array combinator.
le = LocalEstimatorsBatch(
    data=channels,          # (n_chains, chain_len, p + p²)
    combinator=chi_matrix,  # (K,) -> (p, p)
)
sb = le.to_stats()                # StatsBatch, shape (p, p)
sb.mean, sb.error_of_mean         # both shape (p, p)
sb = le.accumulate().get_stats()  # online version, same shapes
```
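The shape-based dispatch described for to_stats() can be sketched in plain JAX. jax.eval_shape traces a function abstractly, so the shape of the combinator output is known without evaluating it on real data. The helper combinator_is_scalar and the channel layout assumed for chi_matrix (first p channels are ⟨O_i⟩, the remaining p² are ⟨O_i O_j⟩ in row-major order) are hypothetical illustrations, not netket code.

```python
import jax
import jax.numpy as jnp

def combinator_is_scalar(combinator, n_channels):
    """Return True if `combinator` maps a (K,) vector to a 0-d output."""
    # jax.eval_shape traces the function abstractly: no arithmetic is run,
    # only the output shape/dtype (a ShapeDtypeStruct) is computed.
    probe = jax.ShapeDtypeStruct((n_channels,), jnp.float32)
    return jax.eval_shape(combinator, probe).shape == ()

def variance(mu):
    # Channels: [<H>, <H^2>]  ->  <H^2> - <H>^2
    return mu[1] - mu[0] ** 2

def chi_matrix(mu, p=2):
    # Hypothetical layout: first p channels are <O_i>,
    # the remaining p*p channels are <O_i O_j> in row-major order.
    first = mu[:p]
    second = mu[p:].reshape(p, p)
    return second - jnp.outer(first, first)

scalar_case = combinator_is_scalar(variance, 2)  # scalar output -> Stats branch
matrix_shape = jax.eval_shape(
    chi_matrix, jax.ShapeDtypeStruct((6,), jnp.float32)
).shape                                          # (2, 2) output -> StatsBatch branch
```

Because the probe is abstract, this check costs only a trace, which is why it is a cheap way to choose between a scalar and an array-valued result.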
- Attributes
- n_channels
Number of channels K in the data (last axis of data).
- data: Array
Estimator channels with shape (n_chains, chain_len, K).
- Methods
- accumulate(old=None, *, max_lag=64)
Fold this batch into an online accumulator.
- Parameters:
  - old – an existing OnlineStatsBatch returned by a previous call, or None to start a fresh accumulator.
  - max_lag (int) – maximum lag of the autocorrelation function (ACF); only used when creating a fresh accumulator on the first call.
- Return type:
OnlineStatsBatch
- Returns:
The updated OnlineStatsBatch.
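With the documented method, an accumulation loop starts from acc = None and calls acc = le.accumulate(acc) for each new batch. The sketch below uses a hypothetical RunningChannelSums accumulator, which keeps only per-chain running sums, to illustrate why folding batches along the chain axis can reproduce the one-shot channel means. It makes no claim about the internals of OnlineStatsBatch, which additionally takes a max_lag for its ACF; this sketch ignores autocorrelation entirely.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp

class RunningChannelSums(NamedTuple):
    """Hypothetical accumulator (NOT netket's OnlineStatsBatch)."""
    sums: jnp.ndarray  # (n_chains, K): running per-chain sum of each channel
    count: int         # samples per chain folded in so far

def fold_batch(data, old: Optional[RunningChannelSums] = None) -> RunningChannelSums:
    """Fold a (n_chains, chain_len, K) batch in; old=None starts fresh."""
    batch_sums = data.sum(axis=1)
    if old is None:
        return RunningChannelSums(batch_sums, data.shape[1])
    return RunningChannelSums(old.sums + batch_sums, old.count + data.shape[1])

def per_chain_means(acc: RunningChannelSums) -> jnp.ndarray:
    return acc.sums / acc.count

# Three batches of 8 chains x 100 steps with K = 2 channels.
batches = jax.random.normal(jax.random.PRNGKey(1), (3, 8, 100, 2))
acc = None
for batch in batches:
    acc = fold_batch(batch, acc)

# Per-chain means obtained by processing all 300 steps at once.
one_shot = jnp.concatenate(list(batches), axis=1).mean(axis=1)
```

Treating the first call (old=None) and later calls through one function is the same design the accumulate() signature follows, so the loop body never needs a special first-iteration branch.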
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace is used; otherwise, a new object of the same type as the original is created.
- to_online_stats(*, max_lag=64)
Create an OnlineStatsBatch initialised with this batch.
Subsequent batches are folded in via acc = acc.update(new_le.data). Prefer accumulate() when writing the accumulation loop, as it handles both the first and subsequent batches uniformly.
- Parameters:
max_lag (int) – maximum ACF lag.
- Return type:
OnlineStatsBatch
- to_stats()
One-shot delta-method statistics for this batch.
Computes per-chain means, forms their sample covariance matrix, and applies the delta method, dispatching on the shape of the combinator output via jax.eval_shape():
  - for scalar combinators, returns a Stats;
  - for array-valued combinators, returns the mean and the error of the mean, both with the same shape as the combinator output.
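A minimal JAX sketch of the one-shot procedure, assuming each chain's mean is treated as one independent sample of the channel-mean vector: the grand mean mu is the average of the per-chain means, its covariance is estimated as Σ / n_chains (Σ being the sample covariance of the per-chain means), and the delta method gives Var f(mu) ≈ J Σ Jᵀ / n_chains, with J the Jacobian of the combinator at mu. The function name delta_method_stats is hypothetical, and netket's actual implementation may differ in details such as normalization or correlation handling.

```python
import jax
import jax.numpy as jnp

def delta_method_stats(data, combinator):
    """Sketch: first-order (delta-method) mean and error of combinator(mu)."""
    n_chains = data.shape[0]
    # Per-chain means of every channel: shape (n_chains, K).
    chain_means = data.mean(axis=1)
    # Grand mean of each channel: shape (K,).
    mu = chain_means.mean(axis=0)
    # Sample covariance of the per-chain means (columns are channels): (K, K).
    sigma = jnp.atleast_2d(jnp.cov(chain_means, rowvar=False))
    # Jacobian of the combinator at mu: shape out_shape + (K,).
    jac = jax.jacobian(combinator)(mu)
    # Delta method, elementwise over the output: J Σ Jᵀ / n_chains.
    var = jnp.einsum("...k,kl,...l->...", jac, sigma, jac) / n_chains
    return combinator(mu), jnp.sqrt(var)

# Usage: variance of H from channels [H, H^2] on synthetic Gaussian samples.
key = jax.random.PRNGKey(0)
h = jax.random.normal(key, (16, 200))  # 16 chains, 200 samples each
data = jnp.stack([h, h**2], axis=-1)   # (16, 200, 2)
var_mean, var_err = delta_method_stats(data, lambda mu: mu[1] - mu[0] ** 2)
```

For the variance combinator the Jacobian is [-2⟨H⟩, 1], so the error mixes the fluctuations of both channels through their covariance, which is exactly what a per-channel error bar would miss.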