netket.vqs.MCState

class netket.vqs.MCState

Bases: VariationalState

Variational State for a Variational Neural Quantum State.

The state is sampled according to the provided sampler.

__init__(sampler, model=None, *, n_samples=None, n_samples_per_rank=None, n_discard_per_chain=None, chunk_size=None, variables=None, init_fun=None, apply_fun=None, seed=None, sampler_seed=None, mutable=False, training_kwargs={})

Constructs the MCState.

Parameters:
  • sampler (Sampler) – The sampler used to generate configurations from the variational state.

  • model – (Optional) The neural quantum state ansatz, encoded into a model. This should be a flax.linen.Module instance, or any other supported neural network framework. If not provided, you must specify init_fun and apply_fun.

  • n_samples (int | None) – The total number of samples across chains and processes when sampling (default=1000).

  • n_samples_per_rank (int | None) – The total number of samples across chains on one process when sampling. Cannot be specified together with n_samples (default=None).

  • n_discard_per_chain (int | None) – Number of discarded samples at the beginning of each Monte Carlo chain (default=5, except for ‘direct’ samplers where it is 0).

  • seed (Union[int, Any, None]) – RNG seed used to generate a set of parameters (only used if variables is not passed). Defaults to a random one.

  • sampler_seed (Union[int, Any, None]) – RNG seed used to initialise the sampler. Defaults to a random one.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Name or list of names of mutable arguments. Use it to specify if the model has a state that can change during evaluation, but that should not be optimised. See also flax.linen.Module.apply() documentation (default=False)

  • init_fun (Callable[[Any, Sequence[int], Union[None, str, type[Any], dtype, _SupportsDType]], Array] | None) – Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has a non-standard init method.

  • variables (Any | None) – Optional initial value for the variables (parameters and model state) of the model.

  • apply_fun (Callable | None) – Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to model.apply(variables, σ). Specify only if your network has a non-standard apply method.

  • training_kwargs (dict) – A dict containing the optional keyword arguments to be passed to the apply_fun during training. Useful, for example, when you have a batchnorm layer that computes the average/mean only during training.

  • chunk_size (int | None) – (Defaults to None) If specified, calculations are split into chunks where the neural network is evaluated at most on chunk_size samples at once. This does not change the mathematical results, but will trade a higher computational cost for lower memory cost.

Attributes
chain_length

Length of the Markov chain used for sampling configurations.

If running under JAX sharding, the total number of samples will be n_devices * chain_length * n_batches.
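As a rough illustration of how the total sample count relates to the chain length, the sketch below computes the chain length needed to reach a requested n_samples for a given total number of chains. The helper chain_length_for is hypothetical, not part of NetKet, and it assumes that a requested n_samples is rounded up to a whole number of samples per chain.

```python
def chain_length_for(n_samples: int, n_chains: int) -> int:
    """Hypothetical helper: samples per chain needed to reach at least n_samples."""
    if n_samples <= 0 or n_chains <= 0:
        raise ValueError("n_samples and n_chains must be positive")
    # Integer ceiling division, so that n_chains * chain_length >= n_samples.
    return (n_samples + n_chains - 1) // n_chains


n_chains = 16  # total number of chains, summed over all devices
length = chain_length_for(1000, n_chains)
print(length, n_chains * length)  # 63 1008
```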

chunk_size

Suggested maximum size of the chunks used in forward and backward evaluations of the Neural Network model.

If your inputs are smaller than the chunk size this setting is ignored.

This can be used to lower the memory required to run a computation with a very high number of samples or on a very large lattice. Notice that inputs and outputs must still fit in memory, but the intermediate computations will now require less memory.

This option comes at an increased computational cost. While this cost should be negligible for large-enough chunk sizes, don’t use it unless you are memory bound!

This option is a hint: only some operations support chunking. If you perform an operation that is not implemented with chunking support, it will fall back to no chunking. To check if this happened, set the environment variable NETKET_DEBUG=1.
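The chunking idea can be sketched in plain NumPy: the same function is applied to slices of at most chunk_size rows and the pieces are concatenated, so the result is unchanged while each call only holds the intermediates of one chunk. The names apply_chunked and toy_log_psi are hypothetical; NetKet's own implementation works on JAX functions and also covers backward passes, which this sketch does not.

```python
import numpy as np


def apply_chunked(fn, x, chunk_size=None):
    """Evaluate fn on the rows of x in chunks of at most chunk_size rows."""
    if chunk_size is None or x.shape[0] <= chunk_size:
        # Inputs smaller than the chunk size: evaluate in one go.
        return fn(x)
    outputs = [fn(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)]
    return np.concatenate(outputs, axis=0)


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 4))


def toy_log_psi(sigma):
    # Stand-in for a neural network: one log-amplitude per configuration (row).
    return np.sum(np.log(np.cosh(sigma @ W)), axis=-1)


sigma = rng.choice([-1.0, 1.0], size=(1000, 8))
full = toy_log_psi(sigma)
chunked = apply_chunked(toy_log_psi, sigma, chunk_size=128)
print(np.allclose(full, chunked))  # True
```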

hilbert

The descriptor of the Hilbert space on which this variational state is defined.

model

Returns the model definition of this variational state.

When using model frameworks that encode the parameters directly into the model, such as equinox, bound flax.linen modules, or flax.nnx, this will return the model including the parameters.

If you want access to the raw model without the parameters that is used internally by netket, use MCState._model instead.

model_state

The optional PyTree with the mutable state of the model, which is not optimized.

n_discard_per_chain

Number of discarded samples at the beginning of the Markov chain.

n_parameters

The total number of parameters in the model.

n_samples

The total number of samples generated at every sampling step.

n_samples_per_rank

The number of samples generated on every JAX device at every sampling step.

parameters

The pytree of the parameters of the model.

sampler

The Monte Carlo sampler used by this Monte Carlo variational state.

samples

Returns the set of cached samples.

The samples returned are guaranteed to be valid for the current parameters and model state of the variational state. If no cached samples are available, they are sampled first and then cached.

To obtain a new set of samples either use reset() or sample().

variables

The PyTree containing the parameters and state of the model, used when evaluating it.

sampler_state: SamplerState

The current state of the sampler.

mutable: bool | str | Collection[str] | DenyList

Specifies which collections in the model_state should be treated as mutable. Largely unused.

Methods
check_mc_convergence(op, *, min_chain_length=50, max_chain_length=500, plot=False)

Diagnose whether the Markov-chain sweep size is long enough to produce decorrelated samples for the expectation value of op.

Algorithm overview

The method operates on a temporary copy of the state whose internal sweep size is reset to 1, exposing every elementary MC step as an individual sample row. Batches of local estimators are fed one at a time into online_statistics(), which maintains running estimates of the mean, variance, and the autocorrelation function (ACF) via the Geyer initial positive sequence (IPS) estimator.

The loop continues until both of the following conditions are met:

  1. The ACF window is not saturated, i.e. the IPS found a non-positive consecutive pair, confirming that max_lag was large enough to capture the full tail of the ACF.

  2. The integrated autocorrelation time τ is reliable: the number of effective samples per chain is at least 50 (i.e. n_per_chain / τ ≥ 50).
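A minimal offline version of the two checks above might look as follows. It is a sketch under stated assumptions, not the code of online_statistics(): it computes the ACF of complete chains in one pass, uses the textbook integrated autocorrelation time τ_int = 1 + 2 Σ_{t≥1} ρ(t) (the normalisation of τ used internally may differ), and tests it on synthetic AR(1) chains, for which τ_int = (1 + φ)/(1 - φ). All function names are hypothetical.

```python
import numpy as np


def acf(chains, max_lag):
    """Normalised autocorrelation rho[0..max_lag], averaged over chains (rows)."""
    x = chains - chains.mean(axis=1, keepdims=True)
    n = x.shape[1]
    cov = np.array([np.mean(x[:, : n - t] * x[:, t:]) for t in range(max_lag + 1)])
    return cov / cov[0]


def geyer_ips_tau(rho):
    """Integrated autocorrelation time via Geyer's initial positive sequence.

    Returns (tau_int, saturated); saturated is True when every pair
    rho[2k] + rho[2k+1] inside the window was positive, so tau_int is only
    a lower bound.
    """
    total = 0.0
    for k in range(len(rho) // 2):
        pair = rho[2 * k] + rho[2 * k + 1]
        if pair <= 0:
            # First non-positive pair: truncate the sum here.
            return -1.0 + 2.0 * total, False
        total += pair
    return -1.0 + 2.0 * total, True


def converged(tau_int, saturated, n_per_chain, min_eff=50):
    """Window not saturated and at least min_eff effective samples per chain."""
    return (not saturated) and n_per_chain / tau_int >= min_eff


# Synthetic AR(1) chains x[t] = phi * x[t-1] + noise, with tau_int = (1 + phi) / (1 - phi).
rng = np.random.default_rng(1)
phi, n_chains, n = 0.5, 64, 20_000
chains = np.empty((n_chains, n))
chains[:, 0] = rng.normal(size=n_chains) / np.sqrt(1 - phi**2)  # start in equilibrium
for t in range(1, n):
    chains[:, t] = phi * chains[:, t - 1] + rng.normal(size=n_chains)

tau, saturated = geyer_ips_tau(acf(chains, max_lag=100))
print(tau, saturated, converged(tau, saturated, n))  # tau should be close to 3
```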

Adaptive coarsening when the ACF window saturates

If the ACF window saturates (every consecutive pair (ρ[2t]+ρ[2t+1]) is positive up to max_lag), the current τ estimate is merely a lower bound — the chains are too short or the sweep too fine. The algorithm then:

  • doubles the internal sweep size (sweep_size *= 2), thinning the Markov chain to make long-range correlations visible within the window;

  • calls thin_acf_by_2() to re-index the ACF accumulator for the coarser cadence; and

  • calls expand_max_lag() to restore the lag window width so the next iterations can probe further.

This doubling is repeated as needed until the window is no longer saturated, or until max_chain_length samples/chain are exhausted.
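The effect of doubling the sweep size on the ACF can be illustrated on the exact ACF of an AR(1) chain: keeping every second sample maps lag t of the thinned chain onto lag 2t of the original one, which is presumably what re-indexing with thin_acf_by_2() amounts to. The helper below is a hypothetical stand-in, not NetKet's function. After thinning, the window holds only half as many lags, which is the gap expand_max_lag() is described as filling.

```python
import numpy as np


def thin_acf_by_2_sketch(rho):
    """Hypothetical stand-in for re-indexing an ACF after doubling the sweep size."""
    # Lag t of the chain that keeps every second sample is lag 2t of the original chain.
    return rho[::2]


def tau_int_ar1(phi):
    """Integrated autocorrelation time 1 + 2 * sum_{t>=1} phi**t of an AR(1) chain."""
    return (1 + phi) / (1 - phi)


phi = 0.9
rho = phi ** np.arange(64)            # exact ACF of an AR(1) chain, lags 0..63
rho_thin = thin_acf_by_2_sketch(rho)  # 32 lags of the coarser chain
print(np.allclose(rho_thin, (phi**2) ** np.arange(32)))  # True
# Measured in kept samples, the correlation time drops from 19 to about 9.5.
print(tau_int_ar1(phi), tau_int_ar1(phi**2))
```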

Final diagnosis

After convergence the correlation time is re-expressed in terms of raw MC steps (τ_mc = τ_acf × final_sweep_size) and of the user’s original sweep units (τ_sweeps = τ_mc / orig_sweep_size). A sweep size is considered adequate when τ_sweeps < 1, meaning that consecutive samples produced by the user’s MCState are effectively independent. The recommended minimum sweep size is 2 τ_mc raw steps.
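The final bookkeeping is plain arithmetic; the hypothetical helper below restates it so the units are easy to follow. The numbers in the example are made up for illustration.

```python
def sweep_diagnosis(tau_acf, final_sweep_size, orig_sweep_size):
    """Re-express tau in raw MC steps and in the user's sweep units (hypothetical helper)."""
    tau_mc = tau_acf * final_sweep_size      # raw MC steps
    tau_sweeps = tau_mc / orig_sweep_size    # units of the user's original sweep
    return {
        "tau_mc": tau_mc,
        "tau_sweeps": tau_sweeps,
        "adequate": tau_sweeps < 1,
        "recommended_min_sweep_size": 2 * tau_mc,
    }


# Example: tau_acf = 3 measured after coarsening to an internal sweep size of 8,
# for a state whose sampler originally used a sweep size of 16.
print(sweep_diagnosis(3.0, 8, 16))
# {'tau_mc': 24.0, 'tau_sweeps': 1.5, 'adequate': False, 'recommended_min_sweep_size': 48.0}
```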

Parameters:
  • op (AbstractOperator) – The operator whose local estimators are used to probe correlations.

  • min_chain_length (int) – Minimum number of samples per chain to accumulate before the convergence check is applied.

  • plot (bool) – Controls diagnostic figure output. False (default) skips the figure entirely. True shows it interactively when a display is available, or saves it to mc_convergence.png when running non-interactively (e.g. a SLURM batch job — detected via sys.stdout.isatty() and the matplotlib backend). A str or Path always saves to that path.

  • max_chain_length (int) – Hard upper limit on samples per chain. The loop is terminated unconditionally once this many samples have been drawn.

Returns:

A tuple (stats, hist_data) where stats is the final OnlineStatistics accumulator and hist_data is a HistoryDict recording the evolution of key diagnostics (mean, error, R̂, τ) as the number of samples is increased.

See also

thermalise_mcmc() — advance chains to stationarity before measuring.

expect(O)

Estimates the quantum expectation value for a given operator \(O\) or generic observable. In the case of a pure state \(\psi\) and an operator, this is \(\langle O\rangle= \langle \psi|O|\psi\rangle/\langle\psi|\psi\rangle\); otherwise, for a mixed state \(\rho\), this is \(\langle O\rangle= \textrm{Tr}[\rho \hat{O}]/\textrm{Tr}[\rho]\).
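For a pure state, the Monte Carlo estimate rests on the identity ⟨O⟩ = Σ_s p(s) O_loc(s), with p(s) = |ψ(s)|²/⟨ψ|ψ⟩ and O_loc as defined for local_estimators(). The dense NumPy sketch below checks this identity on a random 16-dimensional state and then estimates ⟨O⟩ from samples; exact sampling with numpy.random.Generator.choice stands in for the Markov chain sampler, and all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 16
psi = rng.normal(size=dim) + 1j * rng.normal(size=dim)
A = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
O = (A + A.conj().T) / 2                            # Hermitian operator

exact = np.vdot(psi, O @ psi) / np.vdot(psi, psi)   # <psi|O|psi> / <psi|psi>

O_loc = (O @ psi) / psi                             # local estimator for every basis state
p = np.abs(psi) ** 2 / np.sum(np.abs(psi) ** 2)     # Born distribution |psi(s)|^2
weighted = np.sum(p * O_loc)                        # exact average of O_loc over p(s)

samples = rng.choice(dim, size=200_000, p=p)        # exact sampling in place of MCMC
mc = O_loc[samples].mean()                          # Monte Carlo estimate of <O>

print(np.allclose(weighted, exact), abs(mc - exact))
```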

Parameters:

O (AbstractOperator) – the operator or observable for which to compute the expectation value.

Return type:

Stats

Returns:

An estimation of the quantum expectation value \(\langle O\rangle\).

expect_and_forces(O, *, mutable=None)

Estimates the quantum expectation value and the corresponding force vector for a given operator O.

The force vector \(F_j\) is defined as the covariance of the log-derivative of the trial wave function and the local estimators of the operator. For complex holomorphic states, this is equivalent to the expectation gradient \(\frac{\partial\langle O\rangle}{\partial(\theta_j)^\star} = F_j\). For real-parameter states, the gradient is given by \(\frac{\partial\langle O\rangle}{\partial\theta_j} = 2 \textrm{Re}[F_j]\).

Parameters:
  • O (AbstractOperator) – The operator O for which expectation value and force are computed.

  • mutable (Union[bool, str, Collection[str], DenyList, None]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable. bool: all/no collections are mutable; str: the name of a single mutable collection; list: a list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm). Consult Flax’s Module.apply documentation for a more in-depth explanation.

Return type:

tuple[Stats, Any]

Returns:

A tuple containing an estimate of the quantum expectation value \(\langle O\rangle\) and an estimate of the force vector \(F_j = \textrm{Cov}[\partial_j\log\psi, O_{\textrm{loc}}]\).
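The relation between forces and gradients for real parameters can be checked numerically on a toy model. The sketch below assumes a hypothetical real ansatz log ψ(s) = Σ_j θ_j f_j(s) on an 8-dimensional space and a real symmetric observable, computes F_j as the covariance over p(s) = |ψ(s)|²/⟨ψ|ψ⟩ (with the log-derivative conjugated), and compares 2 Re[F_j] with a central finite-difference gradient of the exact ⟨O⟩. It uses exact averages rather than samples, so agreement is limited only by finite-difference error.

```python
import numpy as np

rng = np.random.default_rng(1)
dim, n_par = 8, 3
feats = rng.normal(size=(dim, n_par))   # d log psi(s) / d theta_j for each basis state s
A = rng.normal(size=(dim, dim))
O = (A + A.T) / 2                       # real symmetric observable


def psi(theta):
    # Toy ansatz: log psi(s) = sum_j theta_j * feats[s, j]
    return np.exp(feats @ theta)


def expval(theta):
    v = psi(theta)
    return v @ O @ v / (v @ v)


def forces(theta):
    v = psi(theta)
    p = v**2 / np.sum(v**2)             # |psi(s)|^2 / <psi|psi>
    O_loc = (O @ v) / v
    d_centered = feats - p @ feats      # centred log-derivatives
    e_centered = O_loc - p @ O_loc      # centred local estimators
    # F_j = Cov[d_j log psi, O_loc], conjugating the log-derivative
    return (p * e_centered) @ d_centered.conj()


theta = 0.3 * rng.normal(size=n_par)
eps = 1e-6
fd_grad = np.array([(expval(theta + eps * e) - expval(theta - eps * e)) / (2 * eps)
                    for e in np.eye(n_par)])
print(np.allclose(2 * np.real(forces(theta)), fd_grad, atol=1e-6))  # True
```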

expect_and_grad(O, *, mutable=None, **kwargs)

Estimates the quantum expectation value and its gradient for a given operator \(O\).

Parameters:
  • O (AbstractOperator) – The operator \(O\) for which expectation value and gradient are computed.

  • mutable (Union[bool, str, Collection[str], DenyList, None]) –

    Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable. bool: all/no collections are mutable; str: the name of a single mutable collection; list: a list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm). Consult Flax’s Module.apply documentation for a more in-depth explanation.

  • use_covariance – Whether to use the covariance formula \(\textrm{Cov}[\partial\log\psi, O_{\textrm{loc}}]\), usually reserved for Hermitian operators.

Return type:

tuple[Stats, Any]

Returns:

A tuple containing an estimate of the quantum expectation value \(\langle O\rangle\) and an estimate of its gradient.

expect_to_precision(op, *, atol=None, rtol=None, max_iter=10000, max_lag=64, verbose=True)

Sample until the standard error of \(\langle O \rangle\) meets the requested tolerance.

Warning

Experimental functionality. This method is subject to change without notice in future NetKet releases. If you find it useful (or not!), please let us know with a 👍 / 👎 on GitHub or on Slack.

Iteratively draws new batches of samples and updates an online statistics accumulator until the estimated standard error of the mean satisfies the requested absolute and/or relative tolerance, or until max_iter iterations are exhausted. A progress bar is shown by default.

At least one of atol or rtol must be provided. Sampling stops when error_of_mean ≤ atol + rtol * |mean|, with a missing tolerance treated as 0 (NumPy convention): with only atol the criterion is absolute, with only rtol it is relative, and with both, atol acts as an absolute floor that keeps the relative criterion well-behaved when the mean is close to zero.
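The stopping rule can be written as a small predicate. The sketch below is illustrative only (precision_reached is a hypothetical name, not a NetKet function); it mirrors the criterion and the ValueError described here.

```python
def precision_reached(error_of_mean, mean, atol=None, rtol=None):
    """Stopping rule: error_of_mean <= atol + rtol * |mean|, a missing tolerance counting as 0."""
    if atol is None and rtol is None:
        raise ValueError("At least one of atol or rtol must be provided.")
    atol = 0.0 if atol is None else atol
    rtol = 0.0 if rtol is None else rtol
    return error_of_mean <= atol + rtol * abs(mean)


print(precision_reached(0.004, -1.0, rtol=1e-2))             # True: 0.004 <= 0.01
print(precision_reached(0.004, 1e-4, rtol=1e-2))             # False: bound is only 1e-6
print(precision_reached(0.004, 1e-4, atol=5e-3, rtol=1e-2))  # True: atol acts as a floor
```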

Unlike expect(), this method modifies the state’s sampler in place (new samples are drawn on self directly), so the sampler state is advanced as a side effect.

Parameters:
  • op (AbstractOperator) – The operator \(O\) whose expectation value \(\langle O \rangle\) is estimated.

  • atol (float | None) – Desired absolute standard error of the mean.

  • rtol (float | None) – Desired relative standard error of the mean.

  • max_iter (int) – Maximum number of sampling iterations before stopping unconditionally.

  • max_lag (int) – Maximum lag used by the online autocorrelation estimator.

  • verbose (bool) – If True (default), display a tqdm progress bar showing the current error and tolerances.

Returns:

The final OnlineStatistics accumulator. Call .get_stats() on it to obtain a standard Stats object with mean, variance, and error of the mean.

Raises:

ValueError – If neither atol nor rtol is provided, or if the sampler is not a MetropolisSampler.

See also

netket._src.vqs.expect_to_precision.expect_to_precision()

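grad(Ô, *, use_covariance=None, mutable=None)

Estimates the gradient of the quantum expectation value of a given operator O.

Parameters:
  • Ô (AbstractOperator) – The operator O for which the gradient of the expectation value is computed.

  • use_covariance – Whether to use the covariance formula, usually reserved for Hermitian operators.

  • mutable (Union[bool, str, Collection[str], DenyList, None]) – Specifies which collections in the model_state should be treated as mutable, as in expect_and_grad().
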
Returns:

An estimation of the average gradient of the quantum expectation value <O>.

Return type:

array

init(seed=None, dtype=None)

Initialises the variational parameters of the variational state.

init_parameters(init_fun=None, *, seed=None)

Re-initializes all the parameters with the provided initialization function, defaulting to the normal distribution of standard deviation 0.01.

Warning

The init function will not change the dtype of the parameters, which is determined by the model. Do not specify the dtype inside the init function.

Parameters:
  • init_fun (Callable[[Any, Sequence[int], Union[None, str, type[Any], dtype, _SupportsDType]], Array] | None) – A JAX initializer such as jax.nn.initializers.normal(). Must be a callable taking three inputs (the JAX PRNG key, the shape and the dtype) and returning an array with the requested dtype and shape. If left unspecified, defaults to jax.nn.initializers.normal(stddev=0.01).

  • seed (Any | None) – Optional seed to be used. The seed is synced across all JAX processes. If unspecified, uses a random seed.

local_estimators(op, *, chunk_size=None)

Compute the local estimators for the operator op (also known as local energies when op is the Hamiltonian) at the current configuration samples self.samples.

For standard operators the local estimator is the ratio

\[O_\mathrm{loc}(s) = \frac{\langle s | \hat{O} | \psi \rangle}{\langle s | \psi \rangle}\]
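Because most physical operators are sparse, the numerator \(\langle s | \hat{O} | \psi \rangle = \sum_{s'} \langle s | \hat{O} | s' \rangle \langle s' | \psi \rangle\) only runs over the few configurations s' connected to s by a non-zero matrix element. The sketch below illustrates this on a hand-written transverse-field Ising chain with 4 spins and arbitrary amplitudes; connected_elements and local_estimator are hypothetical helpers, not NetKet's operator interface, and the result is cross-checked against the dense formula.

```python
import numpy as np

N, h = 4, 0.7
rng = np.random.default_rng(2)
psi = rng.normal(size=2**N) + 1j * rng.normal(size=2**N)  # arbitrary toy amplitudes


def spins(s):
    # Basis index -> array of +/-1 spins (bit i of s encodes spin i).
    return np.array([1 - 2 * ((s >> i) & 1) for i in range(N)])


def connected_elements(s):
    """(s', <s|H|s'>) pairs for H = -sum_i Z_i Z_{i+1} - h sum_i X_i on an open chain."""
    z = spins(s)
    pairs = [(s, -np.sum(z[:-1] * z[1:]))]    # diagonal ZZ term
    for i in range(N):
        pairs.append((s ^ (1 << i), -h))      # X_i flips spin i
    return pairs


def local_estimator(s):
    # O_loc(s) = sum_{s'} <s|H|s'> psi(s') / psi(s), over N + 1 terms instead of 2**N.
    return sum(mel * psi[sp] for sp, mel in connected_elements(s)) / psi[s]


# Cross-check against the dense matrix: O_loc = (H psi)(s) / psi(s).
H = np.zeros((2**N, 2**N))
for s in range(2**N):
    for sp, mel in connected_elements(s):
        H[s, sp] += mel
O_loc_dense = (H @ psi) / psi
O_loc_sparse = np.array([local_estimator(s) for s in range(2**N)])
print(np.allclose(O_loc_dense, O_loc_sparse))  # True
```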

Return type and shape

The return type depends on the operator:

  • LocalEstimators — for operators that produce one scalar estimate per sample. result.data has shape (n_chains, chain_length).

  • LocalEstimatorsBatch — for nonlinear observables (e.g. VarianceObservable) that require K > 1 channels to form the final quantity. result.data has shape (n_chains, chain_length, K).

In both cases the underlying array is at .data; never treat the returned object as a plain JAX array.

Typical usage

One-shot statistics:

le    = vstate.local_estimators(H)
stats = le.to_stats()          # Stats with mean, error_of_mean, …

Online accumulation over multiple sampling steps:

acc = None
for _ in range(n_steps):
    vstate.sample(n_discard_per_chain=0)
    le  = vstate.local_estimators(H)
    acc = le.accumulate(acc)   # OnlineStats or OnlineStatsBatch
print(acc.get_stats())

Nonlinear observables return multiple channels plus a combining rule. For example, a variance observable returns a LocalEstimatorsBatch whose last axis stores [O_loc, (O^2)_loc] for each sample:

import jax.numpy as jnp
import netket as nk

var_op = nk.observable.VarianceObservable(H)
le = vstate.local_estimators(var_op)

O_loc = le.data[..., 0]
O2_loc = le.data[..., 1]

channel_means = jnp.mean(le.data, axis=(0, 1))
variance_mean = le.combinator(channel_means)

Here variance_mean matches le.to_stats().mean. Prefer netket.stats.LocalEstimatorsBatch.to_stats() if you also want the error bar, since it applies delta-method error propagation.
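The combine-then-propagate step can be sketched in NumPy. The example below uses synthetic data in which the second channel is simply O_loc², which does not hold for a real operator, and propagates the covariance of the channel means through the gradient of the combinator (the delta method). It treats all samples as independent; how to_stats() accounts for autocorrelation is not shown and may differ.

```python
import numpy as np

rng = np.random.default_rng(3)
# Synthetic samples shaped (n_chains, chain_length). For a real operator the second
# channel is (O^2)_loc, which in general differs from O_loc**2.
O_loc = rng.normal(loc=1.0, scale=0.5, size=(8, 500))
data = np.stack([O_loc, O_loc**2], axis=-1)       # shape (8, 500, 2)


def combinator(m):
    """Var[O] = <O^2> - <O>^2, evaluated on the two channel means."""
    return m[1] - m[0] ** 2


channel_means = data.mean(axis=(0, 1))
variance_mean = combinator(channel_means)

# Delta method: error^2 ~= g^T C g, with g the gradient of the combinator at the
# channel means and C the covariance of the channel means (samples assumed iid).
flat = data.reshape(-1, 2)
C = np.cov(flat, rowvar=False) / flat.shape[0]
g = np.array([-2.0 * channel_means[0], 1.0])
error_of_mean = np.sqrt(g @ C @ g)
print(variance_mean, error_of_mean)
```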

Warning

Samples differ between JAX processes, so the local estimators will take different values on each process. Use expect() (or to_stats() for multi-channel estimators) to obtain process-independent averages.

Parameters:
  • op (AbstractOperator) – The operator or observable.

  • chunk_size (int | None) – Maximum forward-pass chunk size. (Default: self.chunk_size)

log_value(σ)

Evaluates the variational state for a batch of states and returns the logarithm of the amplitude of the quantum state.

For pure states, this is \(\log(\langle\sigma|\psi\rangle)\), whereas for mixed states this is \(\log(\langle\sigma_r|\rho|\sigma_c\rangle)\), where \(\psi\) and \(\rho\) are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left- and right-acting states (row and column) are obtained as σ_r = σ[:, 0:N] and σ_c = σ[:, N:].

Given a batch of inputs of shape (Nb, N), returns a batch of outputs of shape (Nb,).
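The row/column split for density matrices can be illustrated with plain NumPy slicing. The log-amplitude toy_log_rho below is hypothetical and only serves to show the expected input and output shapes.

```python
import numpy as np

N = 3
rng = np.random.default_rng(4)
sigma = rng.choice([-1.0, 1.0], size=(5, 2 * N))  # batch of concatenated (row, column) pairs

sigma_r = sigma[:, :N]   # row configurations
sigma_c = sigma[:, N:]   # column configurations


def toy_log_rho(sr, sc):
    # Hypothetical toy for log <sigma_r|rho|sigma_c>; not a NetKet model.
    return -0.5 * np.sum((sr - sc) ** 2, axis=-1) + 0.1j * np.sum(sr - sc, axis=-1)


out = toy_log_rho(sigma_r, sigma_c)
print(sigma_r.shape, sigma_c.shape, out.shape)  # (5, 3) (5, 3) (5,)
```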

Return type:

Array

Parameters:

σ (Array)

quantum_geometric_tensor(qgt_T=None)

Computes an estimate of the quantum geometric tensor \(G_{ij}\). This function returns a linear operator that can be used to apply \(G_{ij}\) to a given vector or can be converted to a full matrix.
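A common Monte Carlo estimator is \(G_{ij} = \langle \Delta O_i^\star \Delta O_j \rangle\), with \(\Delta O_j = O_j - \langle O_j \rangle\) the centred log-derivatives. The sketch below, on random toy data, shows why returning a linear operator is useful: \(G v\) can be applied as two matrix-vector products without forming \(G\). It is an illustration under these assumptions, not one of NetKet's QGT implementations, and qgt_matvec is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(5)
n_samples, n_par = 400, 6
# Toy per-sample log-derivatives O_j(s) = d log psi(s) / d theta_j
O = rng.normal(size=(n_samples, n_par)) + 1j * rng.normal(size=(n_samples, n_par))
dO = O - O.mean(axis=0)                    # centred log-derivatives


def qgt_matvec(v):
    """Apply G = <dO^dagger dO> to v without forming the matrix."""
    return dO.conj().T @ (dO @ v) / n_samples


G_dense = dO.conj().T @ dO / n_samples     # full matrix, for comparison only
v = rng.normal(size=n_par)
print(np.allclose(qgt_matvec(v), G_dense @ v))  # True
```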

Parameters:

qgt_T (Callable[[VariationalState], LinearOperator] | None) – the optional type of the quantum geometric tensor. By default it’s automatically selected.

Returns:

A linear operator representing the quantum geometric tensor.

Return type:

nk.optimizer.LinearOperator

reset()

Resets the sampled states. This method is called automatically every time the parameters or the model state are updated.

sample(*, chain_length=None, n_samples=None, n_discard_per_chain=None)

Sample a certain number of configurations.

If either chain_length or n_samples is specified, the corresponding number of samples is generated. Otherwise, the value set internally is used.

Parameters:
  • chain_length (int | None) – The length of the Markov chains.

  • n_samples (int | None) – The total number of samples across all JAX devices.

  • n_discard_per_chain (int | None) – Number of discarded samples at the beginning of the Markov chain.

Return type:

Array

thermalise(op, *, min_chain_length=10, max_chain_length=100, rhat_tol=1.05, decay=0.9, patience=1, verbose=True, raise_on_failure=False)

Advance the Markov chains until they are thermalized (R̂ converged).

Unlike check_mc_convergence(), this method mutates the state in place: on return, state.sampler_state reflects the position of the chains after thermalisation. The sampler’s sweep_size is not changed.

The function monitors the Gelman-Rubin R̂ diagnostic computed from a sliding EMA window of recent batches (controlled by decay). Thermalisation is declared when R̂ < rhat_tol for patience consecutive iterations and at least min_chain_length samples/chain have been drawn.
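The sketch below shows the classic (non-split) Gelman-Rubin R̂ for an array of shape (n_chains, n_samples_per_chain), together with the patience rule. It is a simplification under stated assumptions: it computes R̂ on a single block of samples instead of the EMA-weighted sliding window used here, and the function names are hypothetical.

```python
import numpy as np


def gelman_rubin_rhat(x):
    """Classic (non-split) R-hat for x of shape (n_chains, n_samples_per_chain)."""
    n = x.shape[1]
    W = x.var(axis=1, ddof=1).mean()          # mean within-chain variance
    B_over_n = x.mean(axis=1).var(ddof=1)     # variance of the chain means (= B / n)
    var_hat = (n - 1) / n * W + B_over_n
    return np.sqrt(var_hat / W)


def thermalised(rhat_history, rhat_tol=1.05, patience=1, n_per_chain=0, min_chain_length=10):
    """Converged once the last `patience` R-hat values are below rhat_tol and enough samples exist."""
    recent = rhat_history[-patience:]
    return (n_per_chain >= min_chain_length and len(recent) == patience
            and all(r < rhat_tol for r in recent))


rng = np.random.default_rng(6)
mixed = rng.normal(size=(16, 200))                  # chains sampling the same distribution
stuck = mixed + np.linspace(-2, 2, 16)[:, None]     # chains stuck around different means
print(gelman_rubin_rhat(mixed) < 1.05, gelman_rubin_rhat(stuck) > 1.05)  # True True
```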

Warning

Experimental functionality. This method is subject to change without notice in future NetKet releases.

Note

R̂ is unreliable when the total samples/chain accumulated so far is small (roughly < 50). With very short chains the between-chain variance is noisy and R̂ tends to be overestimated, so the function may not declare convergence until min_chain_length is satisfied even when the chains are already well-mixed. If max_chain_length is set too low (e.g. below 50 × chain_length) you may receive a failure warning even though the sampler is actually thermalized. In that case either increase max_chain_length or reduce min_chain_length.

Parameters:
  • op (AbstractOperator) – The operator whose local estimators are used to monitor convergence. An operator is required because computing R̂ needs per-chain scalar values, and local estimators are the only quantity that provides this. The operator does not need to be the one you ultimately care about — prefer a cheap observable (e.g. a single-site magnetisation \(\hat{\sigma}^z_0\), or the total magnetisation) over the full Hamiltonian, which may have many terms and be slow to evaluate. The convergence criterion is the same regardless of which operator you choose.

  • min_chain_length (int) – Minimum samples/chain before the convergence check is applied (default: 10).

  • max_chain_length (int) – Hard upper limit on samples/chain (default: 100). If R̂ has not converged by this limit a UserWarning is emitted (or a RuntimeError is raised when raise_on_failure=True). Make sure this is comfortably above min_chain_length; otherwise a false failure warning may be triggered before R̂ has had enough data to be meaningful.

  • rhat_tol (float) – R̂ threshold below which chains are considered mixed (default: 1.05).

  • decay (float) – EMA decay factor for the sliding-window R̂ (default: 0.9, effective window ≈ 10 batches). Lower values react faster to recent mixing but are noisier.

  • patience (int) – Number of consecutive iterations with R̂ < rhat_tol required before declaring convergence (default: 1).

  • verbose (bool) – If True, display a tqdm progress bar (default: True).

  • raise_on_failure (bool) – If True, raise RuntimeError on failure instead of emitting a UserWarning (default: False).

Returns:

A tuple (stats, hist_data) where stats is the final OnlineStats accumulator and hist_data is a HistoryDict recording the evolution of mean, variance, and R̂ across iterations.

See also

check_mc_convergence() — diagnose autocorrelation without mutating state.

to_array(normalize=True)

Returns the dense-vector representation of this state.

Parameters:

normalize (bool) – If True, the vector is normalized to have L2-norm 1.

Return type:

Array

Returns:

An exponentially large vector representing the state in the computational basis.
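Building the dense vector from log-amplitudes needs some care, since exponentiating large log-values overflows. A minimal sketch, assuming the log-amplitudes of the full basis are already available as an array (the helper name is hypothetical), shifts by the largest real part before normalising; whether to_array does exactly this is not stated here.

```python
import numpy as np


def dense_vector_from_log(log_psi, normalize=True):
    """Hypothetical helper: amplitudes exp(log_psi) over the full basis."""
    if not normalize:
        return np.exp(log_psi)
    # Shift by the largest real part before exponentiating: this avoids overflow,
    # and the constant factor cancels once the vector is normalised.
    v = np.exp(log_psi - np.max(log_psi.real))
    return v / np.linalg.norm(v)


log_psi = np.array([800.0, 801.0 + 0.5j, 799.0, 802.0 - 1.0j])  # np.exp alone overflows
v = dense_vector_from_log(log_psi)
print(np.linalg.norm(v))  # 1.0 up to rounding
```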

to_qobj()

Convert the variational state to a QuTiP ket Qobj.

Returns:

A qutip.Qobj object.