netket.stats.OnlineStats#

class netket.stats.OnlineStats[source]#

Bases: Pytree

Streaming accumulator for MCMC statistics across multiple batches.

Accumulates mean, variance, tau_corr, R_hat, and error_of_mean incrementally using the parallel Welford algorithm. Supports optional exponential decay (EMA) to down-weight old data.
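The merge step of the parallel Welford algorithm can be sketched in plain Python. This is a minimal illustration and not NetKet's implementation: the `Moments` class and the `summarize` and `merge` helpers are hypothetical names, the variance uses the population convention, and exponential decay is left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Moments:
    """Running summary: sample count, mean, and sum of squared deviations (M2)."""
    n: int = 0
    mean: float = 0.0
    m2: float = 0.0

    @property
    def variance(self) -> float:
        # Population variance (assumed convention); NaN while empty.
        return self.m2 / self.n if self.n > 0 else float("nan")


def summarize(batch):
    """Compute the summary of a single batch."""
    n = len(batch)
    mean = sum(batch) / n
    m2 = sum((x - mean) ** 2 for x in batch)
    return Moments(n, mean, m2)


def merge(a, b):
    """Combine two summaries without revisiting the raw samples."""
    if a.n == 0:
        return b
    if b.n == 0:
        return a
    n = a.n + b.n
    delta = b.mean - a.mean
    mean = a.mean + delta * b.n / n
    # Cross term accounts for the shift between the two batch means.
    m2 = a.m2 + b.m2 + delta**2 * a.n * b.n / n
    return Moments(n, mean, m2)


# Fold three batches of unequal size into one accumulator.
acc = Moments()
for batch in ([1.0, 2.0, 3.0], [4.0, 5.0], [6.0]):
    acc = merge(acc, summarize(batch))
```

Because `merge` only needs the three summary numbers of each side, batches can be combined in any order, which is what makes the scheme suitable for streaming and for per-chain accumulation.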

When max_lag > 0 (default 64), the autocovariance function at lags 0..max_lag is also tracked online via a per-chain sample buffer. This enables a Geyer IPS+IMS estimate of the integrated autocorrelation time (tau_corr_acf) that works even with a single chain.

All per-chain arrays are JAX arrays (pytree_node=True) so that the object is a valid JAX pytree. The scalar configuration fields (decay, max_lag) and the Python-int counters (_n_samples_total, _buf_len) are static (pytree_node=False).

Use online_statistics() as the functional API, or from_data() to construct directly from a first batch.

Example:

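from netket.stats import OnlineStats

# first_batch and remaining_batches hold arrays of shape
# (n_samples,) or (n_chains, n_samples_per_chain).
estimator = OnlineStats.from_data(first_batch)
for batch in remaining_batches:
    estimator = estimator.update(batch)
stats = estimator.get_stats()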

Inheritance

[Inheritance diagram of netket.stats.OnlineStats]

__init__(n_chains, dtype, *, decay=None, max_lag=64)[source]#

Initialize empty online statistics buffers.

Parameters:
  • n_chains (int) – number of independent chains

  • dtype – dtype of incoming samples (used for means)

  • decay (float | None) – EMA decay factor applied per update call (default None ≡ 1.0 → no decay)

  • max_lag (int) – maximum lag for online ACF estimator (set 0 to disable)

Attributes
R_hat#

The split-R̂ convergence diagnostic.

Compares intra-chain and inter-chain variance. Returns NaN when fewer than 2 chains have been observed.

acf#

Normalized autocorrelation function, shape (max_lag+1,).

Averaged over chains. Returns None if no ACF data has been accumulated yet (max_lag == 0, decay != 1.0, or no updates).

Follows OnlineStats.jl AutoCov: per-lag autocovariance is recovered as

\[C(k) = E[x_t x_{t-k}] - E_{\rm lag}[x_{t-k}] \cdot E_{\rm cur}[x_t]\]

using the actual pair-subset means for each lag (not the global mean).
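As an offline illustration of the formula above, the following sketch computes C(k) for a single chain held in memory, using the means of the two pair subsets at each lag. The function names `autocov_pair_means` and `normalize` are hypothetical, and the online buffer and chain averaging used by the class are not reproduced here.

```python
def autocov_pair_means(x, max_lag):
    """Autocovariance at lags 0..max_lag using per-lag pair-subset means."""
    n = len(x)
    if max_lag >= n:
        raise ValueError("max_lag must be smaller than the number of samples")
    cov = []
    for k in range(max_lag + 1):
        cur = x[k:]        # x_t for t = k .. n-1
        lag = x[: n - k]   # x_{t-k} for the same values of t
        m = n - k          # number of (x_t, x_{t-k}) pairs at this lag
        mean_prod = sum(a * b for a, b in zip(cur, lag)) / m
        mean_cur = sum(cur) / m
        mean_lag = sum(lag) / m
        cov.append(mean_prod - mean_lag * mean_cur)
    return cov


def normalize(cov):
    """Divide by the lag-0 value so that rho[0] == 1."""
    return [c / cov[0] for c in cov]


# A perfectly alternating chain is strongly anti-correlated at lag 1.
x = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]
rho = normalize(autocov_pair_means(x, max_lag=2))
```

Using the subset means rather than the global mean matters for short series: at lag 1 the two subsets of the alternating chain have means of opposite sign, which a global-mean formula would ignore.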

chain_means#

Per-chain mean estimates, shape (n_chains,).

Useful for computing cross-chain covariance matrices when combining multiple OnlineStats accumulators (e.g. in OnlineStatsBatch).

mean#

The current mean estimate (dtype matches the input data).

n_chains#

Number of Markov chains.

n_samples#

Total number of raw samples accumulated (never decayed).

tau_corr#

Integrated autocorrelation time (ACF-based if available, else batch).

tau_corr_acf#

Integrated autocorrelation time from the online ACF via Geyer IPS+IMS.

Uses Geyer's initial positive sequence (IPS): builds paired sums P[t] = rho[2t] + rho[2t+1] and discards all pairs from the first non-positive one onward. Then enforces the initial monotone sequence (IMS) condition by taking a cumulative minimum. Finally returns tau = 2 * sum(P) - 1.

This is more robust than the Sokal window when max_lag is finite: the Sokal condition m >= c*tau requires m ~ 5*tau, which exceeds max_lag for any tau > max_lag/5, leading to inflated estimates. Geyer IPS truncates at the natural decay of the ACF regardless of tau.

Returns NaN when max_lag == 0, when decay != 1.0, or before enough data has accumulated.
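The IPS and IMS steps described above can be sketched as a short function acting on a normalized ACF. This is an illustration under stated assumptions, not the library's code: `geyer_tau` is a hypothetical name, a trailing unpaired lag is ignored, and returning NaN when no positive pair survives is a choice made for this sketch.

```python
def geyer_tau(rho):
    """Return tau = 2 * sum(P) - 1 from a normalized ACF rho[0..max_lag]."""
    n_pairs = len(rho) // 2  # a trailing unpaired lag is ignored (assumption)
    pairs = [rho[2 * t] + rho[2 * t + 1] for t in range(n_pairs)]

    # Initial positive sequence: keep pairs up to the first non-positive one.
    kept = []
    for p in pairs:
        if p <= 0.0:
            break
        kept.append(p)

    if not kept:
        return float("nan")  # assumption: no usable pair means no estimate

    # Initial monotone sequence: replace each pair by the running minimum.
    running = float("inf")
    mono = []
    for p in kept:
        running = min(running, p)
        mono.append(running)

    return 2.0 * sum(mono) - 1.0


# Geometric decay rho_k = 0.5**k truncated at lag 11.
tau_geometric = geyer_tau([0.5**k for k in range(12)])

# Noisy tail: the pair at lags 6-7 is negative, so later lags are dropped,
# and the third pair (0.4) is clipped to the running minimum (0.3).
tau_noisy = geyer_tau([1.0, 0.5, 0.2, 0.1, 0.3, 0.1, -0.2, 0.1, 0.4, 0.4])
```

In the noisy case the large values at lags 8 and 9 never enter the sum, which shows how the truncation follows the decay of the ACF rather than a fixed window.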

tau_corr_batch#

Integrated autocorrelation time estimated from between-chain (batch) variance.

For M independent Markov chains each of effective length n, the variance of the chain means B and the pooled within-chain variance W are related by:

\[\frac{B}{W} \approx \frac{1 + 2\tau}{n}\]

Rearranging gives the estimator:

\[\tau = \frac{1}{2}\left(\frac{n \, B}{W} - 1\right)\]

where:

  • \(n\) = effective samples per chain (n_samples_total / n_chains)

  • \(B\) = between-chain variance = Var(chain_means)

  • \(W\) = within-chain (pooled) variance = global_variance

Requires at least 2 chains and decay == 1.0. Returns NaN otherwise.
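The estimator above can be checked by hand on a small example. The sketch below assumes equal-length chains held in memory, takes Var(chain_means) with the population convention (the normalization used by the library is not stated here), and uses the global variance for W as described; the function name `tau_corr_batch_sketch` is hypothetical.

```python
def tau_corr_batch_sketch(chains):
    """tau = 0.5 * (n * B / W - 1) for a list of equal-length chains."""
    n_chains = len(chains)
    if n_chains < 2:
        return float("nan")  # at least 2 chains are required

    n = len(chains[0])  # samples per chain
    chain_means = [sum(c) / n for c in chains]
    global_mean = sum(chain_means) / n_chains

    # B: variance of the chain means (population convention, an assumption).
    B = sum((m - global_mean) ** 2 for m in chain_means) / n_chains

    # W: variance of all samples around the global mean.
    W = sum((x - global_mean) ** 2 for c in chains for x in c) / (n_chains * n)

    return 0.5 * (n * B / W - 1.0)


# Two chains of four samples each: B = 0.25, W = 0.4375, so tau = 9/14.
tau = tau_corr_batch_sketch([[0.0, 0.0, 0.0, 1.0], [1.0, 1.0, 1.0, 2.0]])
```

With only a handful of chains, B is itself a noisy quantity, so the estimate can fluctuate strongly and may even come out negative.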

variance#

The current variance estimate (float64).

max_lag: int#
Methods
classmethod from_data(data, *, decay=None, max_lag=64)[source]#

Construct an OnlineStats from an initial batch of samples.

Equivalent to calling online_statistics() with old_estimator=None.

Parameters:
  • data – Array of shape (n_samples,) or (n_chains, n_samples_per_chain).

  • decay (float | None) – EMA decay factor (default None = no decay).

  • max_lag (int) – Maximum lag for the online ACF estimator (default 64).

Return type:

OnlineStats

Returns:

A new OnlineStats initialized from data.

get_stats()[source]#

Convert accumulated state into a Stats object.

Return type:

Stats

replace(**kwargs)[source]#

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)

to_compound()[source]#

Call get_stats() and return its compound representation.

to_dict()[source]#

Call get_stats() and return its dictionary representation.

update(data)[source]#

Incorporate a new batch of samples and return an updated OnlineStats.

Parameters:

data – Array of shape (n_samples,) or (n_chains, n_samples_per_chain). Must have the same number of chains as the existing accumulator.

Return type:

OnlineStats

Returns:

A new OnlineStats with updated state.