netket.stats.LocalEstimatorsBatch

class netket.stats.LocalEstimatorsBatch

Bases: Pytree

Per-sample K-channel estimators for nonlinear observables.

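data has shape (n_chains, chain_len, K), with one local-estimator channel per entry of the last axis. combinator, a map (K,) -> scalar | array, collapses the K channel means into a result of any shape; its statistical error is propagated from the covariance of the channel means with the delta method (first-order error propagation).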

This class is returned by local_estimators() for operators such as VarianceObservable that require more than one local estimator channel to form the final quantity.
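To make the shapes concrete, here is a minimal pure-JAX sketch of the variance case: two channels (H_loc and H_loc**2) are stacked on the last axis and the combinator is applied to their means. H_loc here is synthetic Gaussian data standing in for real local estimator values (an assumption), and only the point estimate is computed, not the error bar.

```python
import jax
import jax.numpy as jnp

# Synthetic stand-in for local estimator values, shape (n_chains, chain_len).
key = jax.random.PRNGKey(0)
n_chains, chain_len = 4, 256
H_loc = jax.random.normal(key, (n_chains, chain_len)) + 1.0

# Stack the K = 2 channels on the last axis: shape (n_chains, chain_len, 2).
data = jnp.stack([H_loc, H_loc**2], axis=-1)

# Combinator mapping the channel means (shape (K,)) to the variance.
def combinator(mu):
    return mu[1] - mu[0] ** 2

# Point estimate: apply the combinator to the channel means over all samples.
mu = data.reshape(-1, data.shape[-1]).mean(axis=0)  # shape (2,)
variance = combinator(mu)  # equals <H_loc**2> - <H_loc>**2
```

Note that the variance is not a mean of any single channel, which is why more than one channel is needed.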

to_stats() uses jax.eval_shape() to inspect the combinator output shape and returns:

  • a Stats for scalar combinators;

  • a StatsBatch for array-valued combinators, with .mean and .error_of_mean having the same shape as the combinator's output.
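The shape inspection can be illustrated with jax.eval_shape(), which traces a function on abstract inputs without executing any computation. The helper names below (combinator_output_shape, is_scalar_combinator) are hypothetical; this is a sketch of the kind of dispatch described, not NetKet's source.

```python
import jax
import jax.numpy as jnp

def combinator_output_shape(combinator, n_channels, dtype=jnp.float32):
    # Trace the combinator abstractly on a (K,) input; nothing is computed.
    spec = jax.ShapeDtypeStruct((n_channels,), dtype)
    return jax.eval_shape(combinator, spec).shape

def is_scalar_combinator(combinator, n_channels):
    # Output shape () -> scalar result (Stats); any other shape -> StatsBatch.
    return combinator_output_shape(combinator, n_channels) == ()

variance = lambda mu: mu[1] - mu[0] ** 2          # (2,) -> ()
outer = lambda mu: jnp.outer(mu[:2], mu[:2])      # (3,) -> (2, 2)
```

Because eval_shape only needs shapes and dtypes, the check is cheap even for expensive combinators.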

to_online_stats() returns an OnlineStatsBatch.

Examples:

import jax.numpy as jnp
from netket.stats import LocalEstimatorsBatch

# Variance: 2 channels, scalar combinator
# (H_loc: local estimator values with shape (n_chains, chain_len))
le = LocalEstimatorsBatch(
    data=jnp.stack([H_loc, H_loc**2], axis=-1),  # (n_chains, chain_len, 2)
    combinator=lambda mu: mu[1] - mu[0]**2,
)
stats = le.to_stats()    # Stats (scalar combinator → Stats)
acc = le.accumulate()    # OnlineStatsBatch for iterative estimation

# Susceptibility matrix: p + p² channels, array combinator
le = LocalEstimatorsBatch(
    data=channels,          # (n_chains, chain_len, p + p²)
    combinator=chi_matrix,  # (K,) -> (p, p)
)
sb = le.to_stats()                 # StatsBatch, shape (p, p)
sb.mean, sb.error_of_mean          # both shape (p, p)
sb = le.accumulate().get_stats()   # online version, same shapes
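The array-valued example leaves channels and chi_matrix undefined. One possible pair of definitions is sketched below, assuming (hypothetically) that the first p channels hold the observables O_i and the remaining p² hold the products O_i O_j in row-major order, so that chi_ij = <O_i O_j> - <O_i><O_j>. The channel layout, p, and make_channels are illustrative assumptions, not NetKet's conventions.

```python
import jax
import jax.numpy as jnp

p = 3  # number of observables (assumed for illustration)

def make_channels(O):
    # O: (n_chains, chain_len, p) samples of the p observables.
    # Hypothetical layout: [O_1..O_p, O_1*O_1, O_1*O_2, ..., O_p*O_p].
    prods = O[..., :, None] * O[..., None, :]            # (n_chains, chain_len, p, p)
    prods = prods.reshape(O.shape[:-1] + (p * p,))       # (n_chains, chain_len, p²)
    return jnp.concatenate([O, prods], axis=-1)          # (n_chains, chain_len, p + p²)

def chi_matrix(mu):
    # mu: (p + p²,) channel means -> (p, p) connected correlation matrix.
    first = mu[:p]
    second = mu[p:].reshape(p, p)
    return second - jnp.outer(first, first)

key = jax.random.PRNGKey(1)
O = jax.random.normal(key, (2, 500, p))
channels = make_channels(O)
mu = channels.reshape(-1, channels.shape[-1]).mean(axis=0)
chi = chi_matrix(mu)  # shape (p, p)
```

With this layout the point estimate coincides with the biased sample covariance of the observables; the delta method then supplies an error bar for every matrix entry.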
Attributes
n_channels

Number of channels K in the data (last axis of data).

data: Array

Estimator channels with shape (n_chains, chain_len, K).

combinator: Callable

JAX-traceable map from the channel means, shape (K,), to a scalar or an array.

Methods
accumulate(old=None, *, max_lag=64)

Fold this batch into an online accumulator.

Parameters:
  • old – existing OnlineStatsBatch returned by a previous call, or None to start a fresh accumulator.

  • max_lag (int) – maximum ACF lag (only used when creating a fresh accumulator on the first call).

Return type:

OnlineStatsBatch

Returns:

Updated OnlineStatsBatch.
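Because old=None starts a fresh accumulator, a loop over batches can use the single statement acc = le.accumulate(acc), starting from acc = None. The sketch below mimics only that control flow with a hypothetical, heavily simplified accumulator (RunningMeans, accumulate, means are illustrative names) that keeps running channel sums; it is not OnlineStatsBatch and ignores autocorrelation and max_lag entirely.

```python
from typing import NamedTuple, Optional

import jax.numpy as jnp

class RunningMeans(NamedTuple):
    # Hypothetical, simplified accumulator: running per-channel sums only.
    # Unlike OnlineStatsBatch it tracks no autocorrelation (no max_lag).
    total: jnp.ndarray  # (K,) sum of samples per channel
    count: int          # number of samples folded in so far

def accumulate(data, old: Optional[RunningMeans] = None) -> RunningMeans:
    # data: (n_chains, chain_len, K). old=None starts a fresh accumulator,
    # mirroring the accumulate(old=None) convention documented above.
    flat = data.reshape(-1, data.shape[-1])
    if old is None:
        return RunningMeans(flat.sum(axis=0), flat.shape[0])
    return RunningMeans(old.total + flat.sum(axis=0), old.count + flat.shape[0])

def means(acc: RunningMeans) -> jnp.ndarray:
    return acc.total / acc.count

# The same loop body handles the first and all later batches.
batches = [jnp.full((2, 4, 3), float(i)) for i in range(3)]
acc = None
for batch in batches:
    acc = accumulate(batch, acc)
```

Here three batches of 8 samples with values 0, 1 and 2 fold into 24 samples with mean 1.0 in every channel.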

replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)

to_online_stats(*, max_lag=64)

Create an OnlineStatsBatch initialised with this batch.

Subsequent batches are folded in via acc = acc.update(new_le.data). Prefer accumulate() when writing the accumulation loop, as it handles both first and subsequent batches uniformly.

Return type:

OnlineStatsBatch

Parameters:

  • max_lag (int) – maximum ACF lag tracked by the new accumulator.

to_stats()

One-shot delta-method statistics for this batch.

Computes per-chain means, forms their sample covariance matrix, and applies the delta method via jax.eval_shape()-based dispatch:

  • Returns a Stats for scalar combinators.

  • Returns a StatsBatch for array-valued combinators, whose .mean and .error_of_mean both have the same shape as the combinator's output.
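The procedure described for to_stats() can be sketched in pure JAX: per-chain means, their sample covariance Σ, the Jacobian J of the combinator at the overall mean, and Var ≈ J Σ Jᵀ / n_chains. This is a minimal reconstruction from the description above, not NetKet's source; delta_method is a hypothetical name, and dividing by n_chains (batch-means style) is an assumption. NetKet may treat autocorrelation, ddof or complex data differently.

```python
import jax
import jax.numpy as jnp

def delta_method(data, combinator):
    # data: (n_chains, chain_len, K); combinator: (K,) -> scalar or array.
    chain_means = data.mean(axis=1)                  # (n_chains, K)
    n_chains = chain_means.shape[0]
    mu = chain_means.mean(axis=0)                    # (K,)
    # Sample covariance of the chain means (ddof=1), kept 2-D even for K = 1.
    cov = jnp.atleast_2d(jnp.cov(chain_means, rowvar=False))  # (K, K)
    value = combinator(mu)
    # Jacobian at mu, flattened to (n_outputs, K) for scalar or array outputs.
    J = jax.jacobian(combinator)(mu).reshape(-1, mu.shape[0])
    # First-order propagation: Var[f_i] ~ (J Σ Jᵀ)_ii / n_chains.
    var_of_mean = jnp.einsum("ik,kl,il->i", J, cov, J) / n_chains
    error = jnp.sqrt(var_of_mean).reshape(jnp.shape(value))
    return value, error
```

For example, delta_method(data, lambda mu: mu[1] - mu[0]**2) returns a scalar variance estimate and its error, while an array-valued combinator yields value and error arrays of matching shape. For a linear combinator such as mu[0], the result reduces to the usual standard error of the chain means.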