netket.stats.OnlineStatsBatch

class netket.stats.OnlineStatsBatch

Bases: Pytree

Batched accumulator for K OnlineStats estimators and a combining function.

Wraps K separate OnlineStats instances and applies the delta method at get_stats() time to compute statistics for a smooth functional of the K marginal means.
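As a rough illustration of the delta method: linearizing f around the channel-mean vector X̄ gives Var f(X̄) ≈ ∇f(X̄)ᵀ Σ ∇f(X̄), where Σ is the covariance matrix of X̄. The sketch below computes this for a scalar combinator. It assumes independent chains, so that Σ can be estimated from the spread of the per-chain means. `delta_method_scalar` is a hypothetical helper, not part of NetKet, and OnlineStatsBatch may estimate Σ differently (for example, with autocorrelation corrections from its per-channel accumulators).

```python
import jax
import jax.numpy as jnp


def delta_method_scalar(data, f):
    """Hypothetical helper: estimate f(mean) and its delta-method standard error.

    data: array of shape (n_chains, chain_len, K)
    f:    JAX-traceable map (K,) -> scalar
    Assumes independent chains: the covariance of the grand mean is estimated
    from the spread of the per-chain means (requires n_chains >= 2).
    """
    n_chains = data.shape[0]
    chain_means = data.mean(axis=1)                      # (n_chains, K)
    x_bar = chain_means.mean(axis=0)                     # (K,) grand mean
    # Sample covariance of the chain means, divided by n_chains -> Cov(x_bar)
    sigma = jnp.atleast_2d(jnp.cov(chain_means, rowvar=False)) / n_chains
    grad = jax.jacfwd(f)(x_bar)                          # (K,) gradient of f at x_bar
    var = grad @ sigma @ grad                            # grad^T Sigma grad
    return f(x_bar), jnp.sqrt(var)


# Example: variance of x built from the K = 2 channels x and x**2
x = jax.random.normal(jax.random.PRNGKey(0), (8, 100))  # 8 chains, 100 samples each
data = jnp.stack([x, x**2], axis=-1)                     # (8, 100, 2)
mean, err = delta_method_scalar(data, lambda X: X[1] - X[0] ** 2)
```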

The combinator f: (K,) -> scalar | array must be JAX-traceable so that jax.jacfwd can compute its Jacobian. get_stats() uses jax.eval_shape() to inspect the combinator output shape and returns:

  • a Stats object when f returns a scalar;

  • a StatsBatch object when f returns an array of any shape (access results via .mean and .error_of_mean).
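You can mimic this shape-based dispatch with jax.eval_shape. It traces a function on abstract inputs and reports only the output shape and dtype, without doing any numerical work. In this minimal sketch, the helper `output_kind` and its string labels are hypothetical. Only the scalar-versus-array decision mirrors the behavior described above.

```python
import jax
import jax.numpy as jnp


def output_kind(combinator, K):
    """Hypothetical helper: classify a combinator's output without evaluating it.

    jax.eval_shape traces `combinator` on an abstract (K,) float32 input and
    returns a jax.ShapeDtypeStruct carrying only the output shape and dtype.
    """
    out = jax.eval_shape(combinator, jax.ShapeDtypeStruct((K,), jnp.float32))
    return "scalar" if out.shape == () else f"array{out.shape}"


p = 2


def scalar_f(X):  # (2,) -> ()
    return X[1] - X[0] ** 2


def chi(X):  # (p + p*p,) -> (p, p)
    return X[p:].reshape(p, p) - X[:p, None] * X[None, :p]


print(output_kind(scalar_f, 2))      # scalar
print(output_kind(chi, p + p * p))   # array(2, 2)
```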

Use online_statistics_batch() as the functional API, or from_data() to construct directly from a first batch:

acc = None
for batch in batches:                       # batch: (n_chains, chain_len, K)
    acc = online_statistics_batch(batch, combinator, acc)
stats = acc.get_stats()                     # Stats for scalar combinator

For array-valued combinators (e.g., a susceptibility matrix):

def chi_matrix(X):                          # combinator: (K,) -> (p, p), K = p + p*p
    # X[:p]: means of p observables; X[p:]: flattened means of their pairwise products
    return X[p:].reshape(p, p) - X[:p, None] * X[None, :p]

acc = online_statistics_batch(data, chi_matrix, acc)
sb = acc.get_stats()                        # StatsBatch, shape (p, p)
sb.mean, sb.error_of_mean                   # both shape (p, p)
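The chi_matrix combinator implies a channel layout. For p observables, X[:p] holds their means and X[p:] holds the p*p means of their pairwise products. Hence K = p + p*p, and chi_matrix(X) is the (biased) covariance matrix of the observables. Below is a minimal sketch of building such channels from raw samples, using a hypothetical helper `make_channels` and synthetic Gaussian data:

```python
import jax
import jax.numpy as jnp

p = 2  # number of observables (illustrative)


def chi_matrix(X):
    # X[:p]: means of the p observables; X[p:]: flattened means of their pairwise products
    return X[p:].reshape(p, p) - X[:p, None] * X[None, :p]


def make_channels(m):
    """Hypothetical helper: map samples of shape (..., p) to channels of shape (..., p + p*p)."""
    outer = m[..., :, None] * m[..., None, :]                      # (..., p, p)
    return jnp.concatenate([m, outer.reshape(*m.shape[:-1], p * p)], axis=-1)


m = jax.random.normal(jax.random.PRNGKey(1), (4, 500, p))  # (n_chains, chain_len, p)
data = make_channels(m)                                    # (4, 500, 6)
chi = chi_matrix(data.mean(axis=(0, 1)))                   # (p, p): biased covariance of m
```

Feeding batches of this `data` to online_statistics_batch() together with chi_matrix, as in the example above, would then estimate the same (p, p) matrix, with error bars from the delta method.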

This class is used internally by expect_to_precision() whenever the operator has a LocalEstimatorsBatch dispatch.

__init__(estimators, combinator)
Parameters:
  • estimators (tuple) – per-channel OnlineStats accumulators, one for each of the K channels.

  • combinator (Callable) – JAX-traceable (K,) -> scalar | array map combining the K channel means into the final observable.

Attributes
error_of_mean

Standard error of the mean, computed via the delta method.

mean

Current estimate of combinator(X), where X is the vector of the K channel means, as a JAX array.

n_chains

Number of chains.

n_samples

Total number of samples accumulated.

estimators: tuple[OnlineStats, ...]

Per-channel OnlineStats accumulators.

combinator: Callable

JAX-traceable (K,) -> scalar | array map combining channel means.

Methods
classmethod from_data(data, combinator, *, decay=None, max_lag=64)

Construct an OnlineStatsBatch from an initial batch.

Equivalent to calling online_statistics_batch() with old_estimator=None.

Parameters:
  • data – Array of shape (n_chains, n_samples_per_chain, K).

  • combinator (Callable) – f: (K,) -> scalar | array, must be JAX-traceable.

  • decay (float | None) – Exponential-moving-average (EMA) decay factor applied to each per-channel accumulator on every update call (default None, i.e. no decay).

  • max_lag (int) – Maximum lag for the per-channel online ACF estimator.

Return type:

OnlineStatsBatch

Returns:

A new OnlineStatsBatch initialized from data.
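The exact semantics of decay are not spelled out here. One common reading of an EMA decay "applied per update call" is that previously accumulated weight is multiplied by decay before each new batch is merged, so older batches are gradually forgotten. The sketch below illustrates that assumption for a single running mean. `DecayedMean` is hypothetical and does not reproduce OnlineStats.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecayedMean:
    """Hypothetical exponentially decayed running mean (illustration only)."""
    weight: float = 0.0   # effective number of samples after decay
    mean: float = 0.0

    def update(self, batch, decay=None):
        n = len(batch)
        batch_mean = sum(batch) / n
        # Assumption: the old weight is multiplied by `decay` once per update call
        w_old = self.weight if decay is None else decay * self.weight
        w_new = w_old + n
        mean = (w_old * self.mean + n * batch_mean) / w_new
        return DecayedMean(weight=w_new, mean=mean)


acc = DecayedMean().update([1.0, 1.0], decay=0.5).update([3.0, 3.0], decay=0.5)
plain = DecayedMean().update([1.0, 1.0]).update([3.0, 3.0])
```

With decay=0.5, the first batch ends up with half the weight of the second, so the mean moves to 7/3 instead of the undecayed value 2.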

get_stats()

Delta-method statistics for the combined functional.

Uses jax.eval_shape() to inspect the combinator output shape and returns:

  • a Stats for scalar combinators;

  • a StatsBatch for array-valued combinators, with .mean and .error_of_mean having the same shape as combinator(X).
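For an array-valued combinator, jax.jacfwd returns a Jacobian of shape S + (K,), where S is the output shape, and the delta-method variance can be formed entry by entry. The sketch below assumes the covariance Σ of the channel means is already available. `delta_method_array` is hypothetical and, unlike a full treatment, returns no covariances between output entries.

```python
import jax
import jax.numpy as jnp


def delta_method_array(x_bar, sigma, f):
    """Hypothetical helper: per-entry delta-method errors for an array-valued f.

    x_bar: (K,) channel means; sigma: (K, K) covariance of x_bar;
    f: JAX-traceable map (K,) -> array of shape S.
    Returns (mean, error_of_mean), both of shape S.
    """
    jac = jax.jacfwd(f)(x_bar)                                 # shape S + (K,)
    var = jnp.einsum("...k,kl,...l->...", jac, sigma, jac)     # shape S
    return f(x_bar), jnp.sqrt(var)


x_bar = jnp.array([1.0, 3.0])
sigma = jnp.diag(jnp.array([0.04, 0.01]))                     # independent channels
mean, err = delta_method_array(x_bar, sigma, lambda X: jnp.stack([X[0], 2.0 * X[1]]))
# mean ≈ [1., 6.], err ≈ [0.2, 0.2]
```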

replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)
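When the object is a dataclass, replace() reduces to the standard library's dataclasses.replace, which returns a new instance and leaves the original untouched. The sketch below uses a hypothetical frozen dataclass standing in for OnlineStatsBatch. Swapping the combinator this way assumes, per the description above, that the combinator is only applied inside get_stats().

```python
import dataclasses
from typing import Callable


@dataclasses.dataclass(frozen=True)
class ToyAccumulator:
    """Hypothetical stand-in for a dataclass-style Pytree."""
    estimators: tuple
    combinator: Callable


acc = ToyAccumulator(estimators=(0, 0), combinator=lambda X: X[0])
# New object with a different combinator; the accumulated estimators are shared, not copied
acc2 = dataclasses.replace(acc, combinator=lambda X: X[1] - X[0] ** 2)
```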

to_compound()

Call get_stats() and return its compound representation (scalar combinators only).

to_dict()

Call get_stats() and return its dictionary representation (scalar combinators only).

update(data)

Incorporate a new batch of samples and return the updated accumulator.

Parameters:

data – Array of shape (n_chains, chain_len, K).

Return type:

OnlineStatsBatch
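Because update() returns an OnlineStatsBatch, it presumably produces a new accumulator rather than mutating in place. Per channel, one standard way to fold a batch into running moments is the pairwise merge of Chan, Golub and LeVeque, which combines counts, means and sums of squared deviations. The sketch below assumes that scheme for a single channel. `RunningMoments` is hypothetical and omits the online ACF estimation (controlled by max_lag) that the per-channel OnlineStats accumulators perform.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunningMoments:
    """Hypothetical single-channel accumulator: count, mean, sum of squared deviations."""
    n: int = 0
    mean: float = 0.0
    m2: float = 0.0

    def update(self, batch):
        """Return a new accumulator that also includes `batch` (pairwise merge)."""
        n_b = len(batch)
        mean_b = sum(batch) / n_b
        m2_b = sum((x - mean_b) ** 2 for x in batch)
        n = self.n + n_b
        delta = mean_b - self.mean
        mean = self.mean + delta * n_b / n
        m2 = self.m2 + m2_b + delta**2 * self.n * n_b / n
        return RunningMoments(n, mean, m2)

    @property
    def variance(self):
        # Unbiased sample variance of all samples seen so far
        return self.m2 / (self.n - 1)


acc = RunningMoments().update([1.0, 2.0, 3.0]).update([4.0, 5.0])
```

Merging the two batches gives the same count, mean and variance as processing 1, 2, 3, 4, 5 in one pass.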