netket.stats.OnlineStats
- class netket.stats.OnlineStats
Bases: Pytree
Streaming accumulator for MCMC statistics across multiple batches.
Accumulates mean, variance, tau_corr, R_hat, and error_of_mean incrementally using the parallel Welford algorithm. Supports optional exponential decay (EMA) to down-weight old data.
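The batch-by-batch mean/variance update can be sketched in plain Python. This is a minimal sketch, assuming the textbook parallel Welford combination rule (Chan, Golub and LeVeque). The `Moments`, `batch_moments` and `merge` names are hypothetical. It tracks a single scalar stream rather than per-chain JAX arrays, and it does not model the optional exponential decay.

```python
import statistics
from dataclasses import dataclass


@dataclass(frozen=True)
class Moments:
    """Running count, mean and sum of squared deviations (M2)."""
    n: int = 0
    mean: float = 0.0
    m2: float = 0.0

    @property
    def variance(self) -> float:
        # Population variance; NaN until at least one sample has been seen.
        return self.m2 / self.n if self.n > 0 else float("nan")


def batch_moments(batch: list[float]) -> Moments:
    """Summarise a single batch with a two-pass mean and M2."""
    n = len(batch)
    mean = sum(batch) / n
    m2 = sum((x - mean) ** 2 for x in batch)
    return Moments(n, mean, m2)


def merge(a: Moments, b: Moments) -> Moments:
    """Combine two summaries exactly, as if their samples had been pooled."""
    if a.n == 0:
        return b
    if b.n == 0:
        return a
    n = a.n + b.n
    delta = b.mean - a.mean
    mean = a.mean + delta * b.n / n
    m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / n
    return Moments(n, mean, m2)


# Stream three batches through the accumulator, one merge per batch.
batches = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0], [-5.0]]
acc = Moments()
for batch in batches:
    acc = merge(acc, batch_moments(batch))

# Cross-check against a single pass over all samples.
flat = [x for b in batches for x in b]
print(acc.n, acc.mean, acc.variance)
print(len(flat), statistics.fmean(flat), statistics.pvariance(flat))
```

Because `merge` is exact, the streamed result matches the one-shot statistics up to floating-point rounding, whatever the batch sizes.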
When max_lag > 0 (default 64), the autocovariance function at lags 0..max_lag is also tracked online via a per-chain sample buffer. This enables a Geyer IPS+IMS estimate of the integrated autocorrelation time (tau_corr_acf) that works even with a single chain.
All per-chain arrays are JAX arrays (pytree_node=True), so the object is a valid JAX pytree. The scalar configuration fields (decay, max_lag) and the Python-int counters (_n_samples_total, _buf_len) are static (pytree_node=False).
Use online_statistics() as the functional API, or from_data() to construct an accumulator directly from a first batch.
Example:

    estimator = OnlineStats.from_data(first_batch)
    for batch in remaining_batches:
        estimator = estimator.update(batch)
    stats = estimator.get_stats()
- __init__(n_chains, dtype, *, decay=None, max_lag=64)
Initialize empty online statistics buffers.
- Attributes
- R_hat
The split-R̂ convergence diagnostic.
Compares intra-chain and inter-chain variance. Returns NaN when fewer than 2 chains have been observed.
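The intra/inter-chain comparison can be illustrated offline. This is a minimal sketch of the standard split-R̂ from Gelman et al. (Bayesian Data Analysis, 3rd ed.). It operates on full in-memory chains rather than a streaming accumulator. Whether NetKet's online version uses the same degrees-of-freedom conventions is an assumption, and `split_r_hat` is a hypothetical helper.

```python
import random
import statistics


def split_r_hat(chains: list[list[float]]) -> float:
    """Split-R-hat (BDA3 form). Returns NaN with fewer than 2 chains."""
    if len(chains) < 2:
        return float("nan")
    # Split every chain in half so that a within-chain trend also
    # shows up as disagreement between half-chains.
    halves = []
    for chain in chains:
        h = len(chain) // 2
        halves += [chain[:h], chain[h:2 * h]]
    n = len(halves[0])  # samples per half-chain
    means = [statistics.fmean(c) for c in halves]
    # W: average of the within-half-chain sample variances.
    w = statistics.fmean([statistics.variance(c) for c in halves])
    # B: n times the sample variance of the half-chain means.
    b = n * statistics.variance(means)
    var_plus = (n - 1) / n * w + b / n
    return (var_plus / w) ** 0.5


rng = random.Random(0)
iid = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
shifted = [[rng.gauss(mu, 1.0) for _ in range(1000)] for mu in (0.0, 0.0, 5.0, 5.0)]
print(split_r_hat(iid))      # close to 1 for well-mixed chains
print(split_r_hat(shifted))  # well above 1: chains disagree
```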
- acf
Normalized autocorrelation function, shape (max_lag+1,). Averaged over chains. Returns None if no ACF data has been accumulated yet (max_lag == 0, decay != 1.0, or no updates).
Follows OnlineStats.jl AutoCov: the per-lag autocovariance is recovered as
\[C(k) = E[x_t x_{t-k}] - E_{\rm lag}[x_{t-k}] \cdot E_{\rm cur}[x_t]\]
where \(E_{\rm lag}\) and \(E_{\rm cur}\) are the means over the lagged elements \(x_{t-k}\) and the current elements \(x_t\) of the pairs available at lag \(k\). These actual pair-subset means are used for each lag, not the global mean.
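The pair-subset formula can be checked on a single complete chain. This is a minimal sketch that recomputes C(k) over a whole in-memory series instead of maintaining the streaming per-chain buffer, and it does not average over chains. The helper names are hypothetical. Normalizing by C(0) is an assumption about how the ACF is normalized. `max_lag` must be smaller than the series length.

```python
import random


def autocov_pair_means(x: list[float], max_lag: int) -> list[float]:
    """C(k) = E[x_t x_{t-k}] - E_lag[x_{t-k}] * E_cur[x_t] for k = 0..max_lag."""
    out = []
    for k in range(max_lag + 1):
        cur = x[k:]              # x_t for t = k..N-1
        lag = x[: len(x) - k]    # the matching x_{t-k}
        m = len(cur)
        e_prod = sum(a * b for a, b in zip(cur, lag)) / m
        e_cur = sum(cur) / m     # mean over the current elements only
        e_lag = sum(lag) / m     # mean over the lagged elements only
        out.append(e_prod - e_lag * e_cur)
    return out


def normalized_acf(x: list[float], max_lag: int) -> list[float]:
    c = autocov_pair_means(x, max_lag)
    return [ck / c[0] for ck in c]  # rho(0) == 1 by construction


# AR(1) test series: x_t = 0.9 * x_{t-1} + noise, so rho(k) should be near 0.9**k.
rng = random.Random(1)
phi, x, prev = 0.9, [], 0.0
for _ in range(20000):
    prev = phi * prev + rng.gauss(0.0, 1.0)
    x.append(prev)
rho = normalized_acf(x, max_lag=5)
print([round(r, 3) for r in rho])
```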
- chain_means
Per-chain mean estimates, shape (n_chains,).
Useful for computing cross-chain covariance matrices when combining multiple OnlineStats accumulators (e.g. in OnlineStatsBatch).
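One way per-chain means can feed a cross-chain covariance matrix is sketched below. This is an illustration only. It assumes one list of per-chain means per observable, with the same chain ordering across observables, and uses the sample (ddof=1) covariance. How OnlineStatsBatch actually combines accumulators is not shown here. `cross_chain_cov` and `cov_matrix` are hypothetical helpers.

```python
def cross_chain_cov(means_a: list[float], means_b: list[float]) -> float:
    """Sample covariance (ddof=1) between two observables' per-chain means."""
    n = len(means_a)
    ma = sum(means_a) / n
    mb = sum(means_b) / n
    return sum((a - ma) * (b - mb) for a, b in zip(means_a, means_b)) / (n - 1)


def cov_matrix(per_observable_means: list[list[float]]) -> list[list[float]]:
    """Covariance matrix whose entry (i, j) compares observables i and j across chains."""
    return [
        [cross_chain_cov(a, b) for b in per_observable_means]
        for a in per_observable_means
    ]


# Two observables, three chains each.
print(cov_matrix([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]))
```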
- mean
The current mean estimate (dtype matches the input data).
- n_chains
Number of Markov chains.
- n_samples
Total number of raw samples accumulated (never decayed).
- tau_corr
Integrated autocorrelation time: the ACF-based estimate (tau_corr_acf) when it is available, otherwise the batch estimate (tau_corr_batch).
- tau_corr_acf
Integrated autocorrelation time from the online ACF via Geyer IPS+IMS.
Uses Geyer's initial positive sequence (IPS): builds the paired sums P[t] = rho[2t] + rho[2t+1] and discards every pair from the first non-positive one onward. It then enforces the initial monotone sequence (IMS) condition by taking a cumulative minimum over P, and finally returns tau = 2 * sum(P) - 1.
This is more robust than the Sokal window when max_lag is finite. The Sokal condition m >= c*tau (with c ≈ 5) requires a window m ≈ 5*tau, which exceeds max_lag whenever tau > max_lag/5 and leads to inflated estimates. Geyer IPS instead truncates where the ACF naturally decays, regardless of tau.
Returns NaN when max_lag == 0, when decay != 1.0, or before enough data has accumulated.
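The IPS+IMS rule described above can be sketched as a standalone function. It operates on an already-normalized ACF list `rho`, such as the max_lag+1 values returned by `acf`. `geyer_tau` is a hypothetical helper, not NetKet's implementation. An odd trailing lag is ignored, and returning NaN when no positive pair exists is an assumption.

```python
import math


def geyer_tau(rho: list[float]) -> float:
    """tau = 2 * sum(P) - 1 with IPS truncation and IMS cumulative minimum."""
    pairs = []
    for t in range(len(rho) // 2):
        p = rho[2 * t] + rho[2 * t + 1]
        if p <= 0.0:
            break  # IPS: drop this pair and every later one
        pairs.append(p)
    if not pairs:
        return math.nan
    # IMS: force the paired sums to be non-increasing.
    for i in range(1, len(pairs)):
        pairs[i] = min(pairs[i], pairs[i - 1])
    return 2.0 * sum(pairs) - 1.0


# AR(1) autocorrelation rho[k] = 0.5**k, truncated at lag 5.
print(geyer_tau([1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]))  # 2.9375
```

For this AR(1) autocorrelation, the untruncated value 1 + 2 * sum_{k>=1} rho[k] is 3. The six-lag list gives 2.9375 because the tail beyond lag 5 is cut off. For white noise, the second pair is already zero, so only the first pair survives and the estimate is 1.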
- tau_corr_batch
Integrated autocorrelation time estimated from between-chain (batch) variance.
For M independent Markov chains, each of effective length n, the variance of the chain means B and the pooled within-chain variance W are related by:
\[\frac{B}{W} \approx \frac{1 + 2\tau}{n}\]
Rearranging gives the estimator:
\[\tau = \frac{1}{2}\left(\frac{n \, B}{W} - 1\right)\]
where:
  - \(n\) = effective samples per chain (n_samples_total / n_chains)
  - \(B\) = between-chain variance = Var(chain_means)
  - \(W\) = within-chain (pooled) variance = global_variance
Requires at least 2 chains and decay == 1.0. Returns NaN otherwise.
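The estimator can be evaluated directly on in-memory chains. This is a minimal sketch of the rearranged formula. It assumes that "Var" means the population variance (ddof=0) for both B and W, and that every chain has the same length. `batch_tau` is a hypothetical helper.

```python
import random
import statistics


def batch_tau(chains: list[list[float]]) -> float:
    """tau = 0.5 * (n * B / W - 1). Returns NaN with fewer than 2 chains."""
    if len(chains) < 2:
        return float("nan")
    flat = [x for c in chains for x in c]
    n = len(flat) / len(chains)                                       # samples per chain
    b = statistics.pvariance([statistics.fmean(c) for c in chains])  # Var(chain_means)
    w = statistics.pvariance(flat)                                    # global variance
    return 0.5 * (n * b / w - 1.0)


# Uncorrelated samples: B is about W / n, so tau should come out near 0.
rng = random.Random(2)
iid_chains = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(64)]
print(batch_tau(iid_chains))
```

In this convention, tau is about 0 for independent samples, which matches B/W ≈ (1 + 2τ)/n above. The estimate is noisy when there are only a few chains, since B is then computed from few chain means.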
- variance
The current variance estimate (float64).
- Methods
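- classmethod from_data(data, *, decay=None, max_lag=64)
Construct an OnlineStats from an initial batch of samples.
Equivalent to calling online_statistics() with old_estimator=None.
- Parameters:
data – Initial batch of samples.
decay – Optional exponential decay (EMA) factor used to down-weight old data.
max_lag – Maximum lag of the online autocovariance (default 64).
- Return type:
OnlineStats
- Returns:
A new OnlineStats initialized from data.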
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
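For the dataclass case, the behavior matches the standard library's `dataclasses.replace`, which returns a new instance with some fields overridden and works on frozen dataclasses. The sketch below shows that immutable-update pattern on a hypothetical `Counter` class. It only illustrates the "return a new object" style documented for `update()` and is not OnlineStats' actual implementation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Counter:
    """Hypothetical immutable accumulator, for illustration only."""
    n: int = 0
    total: float = 0.0

    def update(self, batch: list[float]) -> "Counter":
        # Build a new object instead of mutating self (frozen fields cannot be assigned).
        return replace(self, n=self.n + len(batch), total=self.total + sum(batch))


c0 = Counter()
c1 = c0.update([1.0, 2.0, 3.0])
print(c0, c1)  # c0 is left unchanged
```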
- to_compound()
Call get_stats() and return its compound representation.
- to_dict()
Call get_stats() and return its dictionary representation.
- update(data)
Incorporate a new batch of samples and return an updated OnlineStats.
- Parameters:
data – Array of shape (n_samples,) or (n_chains, n_samples_per_chain). Must have the same number of chains as the existing accumulator.
- Return type:
OnlineStats
- Returns:
A new OnlineStats with updated state.