netket.stats.OnlineStatsBatch
- class netket.stats.OnlineStatsBatch
Bases: Pytree
Batched accumulator for K OnlineStats estimators and a combining function.
Wraps K separate OnlineStats instances and applies the delta method at get_stats() time to compute statistics for a smooth functional of the K marginal means.
The combinator f: (K,) -> scalar | array must be JAX-traceable so that jax.jacfwd can compute its Jacobian. get_stats() uses jax.eval_shape() to inspect the combinator output shape and returns:
- a Stats object when f returns a scalar;
- a StatsBatch object when f returns an array of any shape (access results via .mean and .error_of_mean).
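The delta method mentioned above can be sketched in a few lines of JAX. This is a minimal illustration, not NetKet's implementation: the helper name `delta_method`, the `ratio` combinator, and the covariance matrix passed in are all hypothetical, and how the library estimates the covariance of the channel means is not shown here.

```python
import jax
import jax.numpy as jnp


def delta_method(f, means, cov_of_means):
    """First-order error propagation for f(means) (illustrative helper).

    means:        (K,) per-channel sample means
    cov_of_means: (K, K) covariance of those means, taken as given here
    """
    value = f(means)
    # Jacobian of f at the means; shape is value.shape + (K,)
    jac = jax.jacfwd(f)(means)
    j2 = jac.reshape(-1, means.shape[0])
    # Var[f]_a ~ sum_ij J_ai * Sigma_ij * J_aj, one entry per output element
    var = jnp.sum((j2 @ cov_of_means) * j2, axis=-1).reshape(value.shape)
    return value, jnp.sqrt(var)


def ratio(x):
    # Hypothetical scalar combinator: ratio of two channel means
    return x[0] / x[1]


means = jnp.array([2.0, 4.0])
cov = jnp.diag(jnp.array([0.04, 0.16]))  # assumed uncorrelated channels
value, err = delta_method(ratio, means, cov)
```

Because the Jacobian is flattened to one row per output element, the same helper handles scalar and array-valued combinators.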
Use online_statistics_batch() as the functional API, or from_data() to construct directly from a first batch:
    acc = None
    for batch in batches:  # batch: (n_chains, chain_len, K)
        acc = online_statistics_batch(batch, combinator, acc)
    stats = acc.get_stats()  # Stats for scalar combinator
For array-valued combinators (e.g. susceptibility matrix):
    def chi_matrix(X):  # combinator: (K,) -> (p, p)
        return X[p:].reshape(p, p) - X[:p, None] * X[None, :p]

    acc = online_statistics_batch(data, chi_matrix, acc)
    sb = acc.get_stats()  # StatsBatch, shape (p, p)
    sb.mean, sb.error_of_mean  # both shape (p, p)
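To make the channel layout implied by chi_matrix concrete, the sketch below evaluates it on a hand-written vector. It assumes p = 2 and reads the indexing in the example as "first p entries are the means, the remaining p*p entries are the second moments in row-major order"; the numerical values are made up for illustration.

```python
import jax.numpy as jnp

p = 2  # number of observables, chosen here only for illustration


def chi_matrix(X):
    # X has length K = p + p*p:
    #   X[:p]  -> means <O_i>
    #   X[p:]  -> second moments <O_i O_j>, row-major
    # Result: connected correlation <O_i O_j> - <O_i><O_j>, shape (p, p)
    return X[p:].reshape(p, p) - X[:p, None] * X[None, :p]


# Means (1, 2) followed by second moments [[2, 3], [3, 5]]
X = jnp.array([1.0, 2.0, 2.0, 3.0, 3.0, 5.0])
chi = chi_matrix(X)
```

With this layout K = p + p*p = 6, so the data batches fed to the accumulator would need 6 channels in the last axis.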
This class is used internally by expect_to_precision() whenever the operator has a LocalEstimatorsBatch dispatch.

- __init__(estimators, combinator)
- Parameters:
estimators (tuple) – per-channel OnlineStats accumulators, one for each of the K channels.
combinator (Callable) – JAX-traceable (K,) -> scalar | array map combining the K channel means into the final observable.
- Attributes
- error_of_mean
Standard error of the mean, computed via the delta method.
- mean
Current estimate of combinator(X) as a JAX array.
- n_chains
Number of chains.
- n_samples
Total samples accumulated.
- estimators: tuple[OnlineStats, ...]
Per-channel OnlineStats accumulators.
- Methods
- classmethod from_data(data, combinator, *, decay=None, max_lag=64)
Construct an OnlineStatsBatch from an initial batch.
Equivalent to calling online_statistics_batch() with old_estimator=None.
- Parameters:
data – Array of shape (n_chains, n_samples_per_chain, K).
combinator (Callable) – f: (K,) -> scalar | array, must be JAX-traceable.
decay (float | None) – Exponential moving average (EMA) decay factor applied per update call to each per-channel accumulator (default None = no decay).
max_lag (int) – Maximum lag for the per-channel online autocorrelation function (ACF) estimator.
- Return type: OnlineStatsBatch
- Returns: A new OnlineStatsBatch initialized from data.
- get_stats()
Delta-method statistics for the combined functional.
Uses jax.eval_shape() to inspect the combinator output shape and returns:
- a Stats for scalar combinators;
- a StatsBatch for array-valued combinators, with .mean and .error_of_mean having the same shape as combinator(X).
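The shape inspection described above can be illustrated with jax.eval_shape(), which traces a function abstractly and returns only the output shape and dtype. The helper `output_kind` below is hypothetical and only mirrors the scalar-versus-array rule stated in this section; it is not taken from NetKet's source, and the float32 input dtype is an assumption.

```python
import jax
import jax.numpy as jnp


def output_kind(combinator, K):
    """Classify a combinator's output without running it on real data."""
    # Abstract stand-in for the (K,) vector of channel means
    spec = jax.ShapeDtypeStruct((K,), jnp.float32)
    out = jax.eval_shape(combinator, spec)
    # Shape () means a scalar result; anything else is treated as an array
    return "scalar" if out.shape == () else "array"


kind_scalar = output_kind(lambda x: x[0] / x[1], 2)     # scalar combinator
kind_array = output_kind(lambda x: jnp.outer(x, x), 3)  # (3, 3) combinator
```

Under the rule in this section, the first case would correspond to a Stats result and the second to a StatsBatch.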
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
- to_compound()
Call get_stats() and return its compound representation (scalar combinators only).
- to_dict()
Call get_stats() and return its dictionary representation (scalar combinators only).