netket.vqs.MCState#
- class netket.vqs.MCState[source]#
Bases: VariationalState
Variational State for a Variational Neural Quantum State.
The state is sampled according to the provided sampler.
- __init__(sampler, model=None, *, n_samples=None, n_samples_per_rank=None, n_discard_per_chain=None, chunk_size=None, variables=None, init_fun=None, apply_fun=None, seed=None, sampler_seed=None, mutable=False, training_kwargs={})[source]#
Constructs the MCState.
- Parameters:
sampler (Sampler) – The sampler.
model – (Optional) The neural quantum state ansatz, encoded into a model. This should be a flax.linen.Module instance, or a model from any other supported neural network framework. If not provided, you must specify init_fun and apply_fun.
n_samples (int | None) – The total number of samples across chains and processes when sampling (default=1000).
n_samples_per_rank (int | None) – The total number of samples across chains on one process when sampling. Cannot be specified together with n_samples (default=None).
n_discard_per_chain (int | None) – Number of discarded samples at the beginning of each Monte Carlo chain (default=5, except for ‘direct’ samplers where it is 0).
seed (Union[int, Any, None]) – RNG seed used to generate a set of parameters (only if variables is not passed). Defaults to a random one.
sampler_seed (Union[int, Any, None]) – RNG seed used to initialise the sampler. Defaults to a random one.
mutable (Union[bool, str, Collection[str], DenyList]) – Name or list of names of mutable arguments. Use it to specify if the model has a state that can change during evaluation, but that should not be optimised. See also the flax.linen.Module.apply() documentation (default=False).
init_fun (Callable[[Any, Sequence[int], Union[None, str, type[Any], dtype, _SupportsDType]], Array] | None) – Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters, used to initialise the parameters. Defaults to the standard flax initialiser. Only specify it if your network has a non-standard init method.
variables (Any | None) – Optional initial value for the variables (parameters and model state) of the model.
apply_fun (Callable | None) – Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to model.apply(variables, σ). Specify it only if your network has a non-standard apply method.
training_kwargs (dict) – A dict containing the optional keyword arguments to be passed to the apply_fun during training. Useful for example when you have a batchnorm layer that constructs the average/mean only during training.
chunk_size (int | None) – (Defaults to None) If specified, calculations are split into chunks where the neural network is evaluated on at most chunk_size samples at once. This does not change the mathematical results, but trades a higher computational cost for a lower memory cost.
- Attributes
- chain_length#
Length of the Markov chain used for sampling configurations.
If running under JAX sharding, the total samples will be n_devices * chain_length * n_batches.
- chunk_size#
Suggested maximum size of the chunks used in forward and backward evaluations of the Neural Network model.
If your inputs are smaller than the chunk size this setting is ignored.
This can be used to lower the memory required to run a computation with a very high number of samples or on a very large lattice. Notice that inputs and outputs must still fit in memory, but the intermediate computations will now require less memory.
This option comes at an increased computational cost. While this cost should be negligible for large-enough chunk sizes, don’t use it unless you are memory bound!
This option is a hint: only some operations support chunking. If you perform an operation that is not implemented with chunking support, it will fall back to no chunking. To check if this happened, set the environment variable NETKET_DEBUG=1.
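The memory/compute trade-off of chunking can be illustrated without NetKet. The sketch below is a pure-Python illustration, not NetKet's implementation: `apply_in_chunks` and `log_amplitude` are hypothetical names, and the "model" is a stand-in function. It shows that evaluating a batch in pieces of at most `chunk_size` inputs gives the same values as a single full-batch call, and that inputs smaller than the chunk size are evaluated in one go.

```python
def log_amplitude(batch):
    """Stand-in for a model forward pass (hypothetical): one value per input row."""
    return [sum(x * x for x in row) for row in batch]


def apply_in_chunks(fun, batch, chunk_size=None):
    """Evaluate `fun` on `batch`, at most `chunk_size` rows at a time."""
    # No chunking requested, or the input already fits in one chunk:
    # fall back to a single full-batch call.
    if chunk_size is None or len(batch) <= chunk_size:
        return fun(batch)
    out = []
    for start in range(0, len(batch), chunk_size):
        # Only this slice (and its intermediates) is processed at once.
        out.extend(fun(batch[start:start + chunk_size]))
    return out


samples = [[i, i + 1, i + 2] for i in range(10)]
full = apply_in_chunks(log_amplitude, samples)
chunked = apply_in_chunks(log_amplitude, samples, chunk_size=4)
print(full == chunked)
```

The extra cost comes from issuing several smaller calls instead of one large one, which is why the text recommends chunking only when memory is the bottleneck.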
- hilbert#
The descriptor of the Hilbert space on which this variational state is defined.
- model#
Returns the model definition of this variational state.
When using model frameworks that encode the parameters directly into the model, such as equinox, bound flax.linen modules, or flax.nnx, this will return the model including the parameters.
If you want access to the raw model without the parameters that is used internally by netket, use MCState._model instead.
- model_state#
The optional PyTree with the mutable state of the model, which is not optimized.
- n_discard_per_chain#
Number of discarded samples at the beginning of the Markov chain.
- n_parameters#
The total number of parameters in the model.
- n_samples#
The total number of samples generated at every sampling step.
- n_samples_per_rank#
The number of samples generated on every JAX device at every sampling step.
- parameters#
The pytree of the parameters of the model.
- sampler#
The Monte Carlo sampler used by this Monte Carlo variational state.
- samples#
Returns the set of cached samples.
The samples returned are guaranteed valid for the current state of the variational state. If no cached samples are available, then they are sampled first and then cached.
To obtain a new set of samples use either reset() or sample().
- variables#
The PyTree containing the parameters and state of the model, used when evaluating it.
- sampler_state: SamplerState#
The current state of the sampler.
- mutable: bool | str | Collection[str] | DenyList#
Specifies which collections in the model_state should be treated as mutable. Largely unused.
- Methods
- check_mc_convergence(op, *, min_chain_length=50, max_chain_length=500, plot=False)[source]#
Diagnose whether the Markov-chain sweep size is long enough to produce decorrelated samples for the expectation value of op.
Algorithm overview
The function operates on a temporary copy of state_ whose internal sweep size is reset to 1, exposing every elementary MC step as an individual sample row. Batches of local estimators are fed one at a time into online_statistics(), which maintains running estimates of the mean, variance, and the autocorrelation function (ACF) via the Geyer initial positive sequence (IPS) estimator.
The loop continues until both of the following conditions are met:
- The ACF window is not saturated, i.e. the IPS found a non-positive consecutive pair, confirming that max_lag was large enough to capture the full tail of the ACF.
- The integrated autocorrelation time τ is reliable: the number of effective samples per chain is at least 50 (i.e. n_per_chain / τ ≥ 50).
Adaptive coarsening when the ACF window saturates
If the ACF window saturates (every consecutive pair (ρ[2t]+ρ[2t+1]) is positive up to max_lag), the current τ estimate is merely a lower bound: the chains are too short or the sweep too fine. The algorithm then:
- doubles the internal sweep size (sweep_size *= 2), thinning the Markov chain to make long-range correlations visible within the window;
- calls thin_acf_by_2() to re-index the ACF accumulator for the coarser cadence; and
- calls expand_max_lag() to restore the lag window width so the next iterations can probe further.
This doubling is repeated as needed until the window is no longer saturated, or until max_chain_length samples/chain are exhausted.
Final diagnosis
After convergence the correlation time is re-expressed in terms of raw MC steps (τ_mc = τ_acf × final_sweep_size) and of the user’s original sweep units (τ_sweeps = τ_mc / orig_sweep_size). A sweep size is considered adequate when τ_sweeps < 1, meaning that consecutive samples produced by the user’s MCState are effectively independent. The recommended minimum sweep size is 2 τ_mc raw steps.
- Parameters:
op (AbstractOperator) – The operator whose local estimators are used to probe correlations.
min_chain_length (int) – Minimum number of samples per chain to accumulate before the convergence check is applied.
plot (bool) – Controls diagnostic figure output. False (default) skips the figure entirely. True shows it interactively when a display is available, or saves it to mc_convergence.png when running non-interactively (e.g. a SLURM batch job, detected via sys.stdout.isatty() and the matplotlib backend). A str or Path always saves to that path.
max_chain_length (int) – Hard upper limit on samples per chain. The loop is terminated unconditionally once this many samples have been drawn.
- Returns:
A tuple (stats, hist_data) where stats is the final OnlineStatistics accumulator and hist_data is a HistoryDict recording the evolution of key diagnostics (mean, error, R̂, τ) as the number of samples is increased.
See also
thermalise_mcmc(): advance chains to stationarity before measuring.
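The Geyer initial positive sequence estimator and the saturation test described above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not NetKet's implementation: `autocorrelation`, `tau_ips` and `ar1_chain` are hypothetical helper names, the input is a single scalar chain rather than batches of local estimators, and the sweep sizes at the end are made-up numbers used only to show the unit conversion.

```python
import random


def autocorrelation(x, max_lag):
    """Normalised sample autocorrelation rho[0..max_lag] of a scalar series."""
    n = len(x)
    mean = sum(x) / n
    centred = [v - mean for v in x]
    var = sum(c * c for c in centred) / n
    return [
        sum(centred[i] * centred[i + t] for i in range(n - t)) / (n * var)
        for t in range(max_lag + 1)
    ]


def tau_ips(x, max_lag=64):
    """Integrated autocorrelation time from Geyer's initial positive sequence.

    Returns (tau, saturated). `saturated` is True when every pair
    rho[2t] + rho[2t+1] inside the window was positive, in which case
    tau is only a lower bound.
    """
    rho = autocorrelation(x, max_lag)
    total = 0.0
    for t in range(0, max_lag, 2):
        pair = rho[t] + rho[t + 1]
        if pair <= 0.0:
            # First non-positive pair: the tail of the ACF has been reached.
            return 2.0 * total - 1.0, False
        total += pair
    return 2.0 * total - 1.0, True


def ar1_chain(phi, n, seed):
    """Toy correlated chain x[k+1] = phi * x[k] + noise, with tau = (1+phi)/(1-phi)."""
    rng = random.Random(seed)
    x, out = 0.0, []
    for _ in range(n):
        x = phi * x + rng.gauss(0.0, 1.0)
        out.append(x)
    return out


chain = ar1_chain(phi=0.5, n=20000, seed=0)
tau_acf, saturated = tau_ips(chain, max_lag=64)

# Unit conversion quoted in the text, with hypothetical sweep sizes.
final_sweep_size, orig_sweep_size = 1, 4
tau_mc = tau_acf * final_sweep_size
tau_sweeps = tau_mc / orig_sweep_size
print(f"tau_acf={tau_acf:.2f} saturated={saturated} tau_sweeps={tau_sweeps:.2f}")
```

For this toy chain the exact value is τ = 3, so a sweep of 4 raw steps gives τ_sweeps below 1. A strongly correlated chain analysed with a short window would instead return `saturated=True`, which is the case where the documented algorithm doubles the sweep size.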
- expect(O)[source]#
Estimates the quantum expectation value for a given operator \(O\) or generic observable. For a pure state \(\psi\) and an operator, this is \(\langle O\rangle= \langle \psi|O|\psi\rangle/\langle\psi|\psi\rangle\); for a mixed state \(\rho\), it is \(\langle O\rangle= \textrm{Tr}[\rho \hat{O}]/\textrm{Tr}[\rho]\).
- Parameters:
O (AbstractOperator) – the operator or observable for which to compute the expectation value.
- Returns:
An estimation of the quantum expectation value \(\langle O\rangle\).
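The pure-state formula can be checked by brute force on a tiny system. The sketch below uses plain Python and an arbitrary unnormalised 4-component state with a made-up Hermitian matrix (both are assumptions for illustration, not NetKet objects). It evaluates \(\langle\psi|O|\psi\rangle/\langle\psi|\psi\rangle\) directly, then again as a \(|\psi(s)|^2\)-weighted average of the ratios \(\langle s|O|\psi\rangle/\langle s|\psi\rangle\), which is the quantity a Monte Carlo state averages over its samples.

```python
# Two-spin toy system: basis states are indexed 0..3, psi is NOT normalised.
psi = [1.0 + 0.5j, 0.3 - 0.2j, -0.7 + 0.1j, 0.4 + 0.9j]

# A Hermitian 4x4 operator given as a dense matrix (hypothetical example).
O = [
    [1.0, 0.5, 0.0, 0.2j],
    [0.5, -1.0, 0.3, 0.0],
    [0.0, 0.3, 0.5, 1.0],
    [-0.2j, 0.0, 1.0, -0.5],
]


def matvec(mat, vec):
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def braket(bra, ket):
    """Inner product <bra|ket>, conjugating the left argument."""
    return sum(a.conjugate() * b for a, b in zip(bra, ket))


O_psi = matvec(O, psi)
norm = braket(psi, psi).real

# Direct evaluation of <psi|O|psi> / <psi|psi>.
expectation = braket(psi, O_psi) / norm

# Same number as an average of local estimators weighted by |psi(s)|^2.
o_loc = [O_psi[s] / psi[s] for s in range(4)]
prob = [abs(psi[s]) ** 2 / norm for s in range(4)]
weighted_average = sum(p * e for p, e in zip(prob, o_loc))

print(expectation, weighted_average)
```

In a sampled state the exact weights `prob` are replaced by the empirical frequencies of the Monte Carlo samples, which is why the returned value is an estimate with a statistical error.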
- expect_and_forces(O, *, mutable=None)[source]#
Estimates the quantum expectation value and the corresponding force vector for a given operator O.
The force vector \(F_j\) is defined as the covariance of the log-derivative of the trial wave function and the local estimators of the operator. For complex holomorphic states, this is equivalent to the expectation gradient \(\frac{\partial\langle O\rangle}{\partial(\theta_j)^\star} = F_j\). For real-parameter states, the gradient is given by \(\frac{\partial\langle O\rangle}{\partial\theta_j} = 2 \textrm{Re}[F_j]\).
- Parameters:
O (AbstractOperator) – The operator O for which expectation value and force are computed.
mutable (Union[bool, str, Collection[str], DenyList, None]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. This is used to mutate the state of the model while you train it (for example to implement BatchNorm). Consult Flax’s Module.apply documentation for a more in-depth explanation.
- Returns:
An estimate of the quantum expectation value <O>. An estimate of the force vector \(F_j = \textrm{Cov}[\partial_j\log\psi, O_{\textrm{loc}}]\).
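The real-parameter relation between the force and the gradient can be verified on a one-spin toy model. The sketch below is an illustration, not NetKet code: it assumes the hypothetical ansatz ψ(s) = exp(θ s) with s ∈ {+1, −1} and the operator σ_x, for which the local estimator is ψ(−s)/ψ(s) and the exact expectation is 1/cosh(2θ). The covariance is computed by exact enumeration instead of sampling, and twice its value is compared with a finite-difference derivative.

```python
import math


def exact_expectation(theta):
    """<sigma_x> for psi(s) = exp(theta * s), s in {+1, -1}."""
    return 1.0 / math.cosh(2.0 * theta)


def force(theta):
    """F = Cov[d log psi / d theta, O_loc], enumerating both spin values exactly."""
    spins = (+1, -1)
    weights = [math.exp(2.0 * theta * s) for s in spins]   # |psi(s)|^2
    z = sum(weights)
    prob = [w / z for w in weights]
    dlog = [float(s) for s in spins]                        # d log psi / d theta = s
    o_loc = [math.exp(-2.0 * theta * s) for s in spins]     # psi(-s) / psi(s)
    mean_d = sum(p * d for p, d in zip(prob, dlog))
    mean_o = sum(p * o for p, o in zip(prob, o_loc))
    return sum(
        p * (d - mean_d) * (o - mean_o) for p, d, o in zip(prob, dlog, o_loc)
    )


theta = 0.3
grad_from_force = 2.0 * force(theta)  # real parameters: gradient = 2 Re[F]

# Central finite difference of the exact expectation value, for comparison.
eps = 1e-6
grad_fd = (exact_expectation(theta + eps) - exact_expectation(theta - eps)) / (2 * eps)

print(grad_from_force, grad_fd)
```

Everything here is real-valued, so no complex conjugation appears; for complex log-derivatives the conjugation convention inside the covariance matters and should be taken from the library itself.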
- expect_and_grad(O, *, mutable=None, **kwargs)[source]#
Estimates the quantum expectation value and its gradient for a given operator \(O\).
- Parameters:
O (AbstractOperator) – The operator \(O\) for which expectation value and gradient are computed.
mutable (Union[bool, str, Collection[str], DenyList, None]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. This is used to mutate the state of the model while you train it (for example to implement BatchNorm). Consult Flax’s Module.apply documentation for a more in-depth explanation.
use_covariance – whether to use the covariance formula, usually reserved for hermitian operators, \(\textrm{Cov}[\partial\log\psi, O_{\textrm{loc}}]\).
- Returns:
An estimate of the quantum expectation value <O>. An estimate of the gradient of the quantum expectation value <O>.
- expect_to_precision(op, *, atol=None, rtol=None, max_iter=10000, max_lag=64, verbose=True)[source]#
Sample until the standard error of \(\langle O \rangle\) meets the requested tolerance.
Warning
Experimental functionality. This method is subject to change without notice in future NetKet releases. If you find it useful (or not!), please let us know with a 👍 / 👎 on GitHub or on Slack.
Iteratively draws new batches of samples and updates an online statistics accumulator until the estimated standard error of the mean satisfies the requested absolute and/or relative tolerance, or until max_iter iterations are exhausted. A progress bar is shown by default.
At least one of atol or rtol must be provided. Sampling stops when error_of_mean ≤ atol + rtol * |mean|, with a missing tolerance treated as 0 (NumPy convention): with only atol the criterion is absolute, with only rtol it is relative, and with both, atol acts as an absolute floor that keeps the relative criterion well-behaved when the mean is close to zero.
Unlike expect(), this method modifies the state’s sampler in place (new samples are drawn on self directly), so the sampler state is advanced as a side effect.
- Parameters:
op (AbstractOperator) – The operator \(O\) whose expectation value \(\langle O \rangle\) is estimated.
atol (float | None) – Desired absolute standard error of the mean.
rtol (float | None) – Desired relative standard error of the mean.
max_iter (int) – Maximum number of sampling iterations before stopping unconditionally.
max_lag (int) – Maximum lag used by the online autocorrelation estimator.
verbose (bool) – If True (default), display a tqdm progress bar showing the current error and tolerances.
- Returns:
The final OnlineStatistics accumulator. Call .get_stats() on it to obtain a standard Stats object with mean, variance, and error of the mean.
- Raises:
ValueError – If neither atol nor rtol is provided, or if the sampler is not a MetropolisSampler.
See also
netket._src.vqs.expect_to_precision.expect_to_precision()
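The stopping rule error_of_mean ≤ atol + rtol * |mean| can be sketched with a small online accumulator. The code below is a simplified illustration, not the NetKet implementation: `tolerance_met` and `sample_to_precision` are hypothetical names, the samples are assumed independent (so there is no autocorrelation correction and no max_lag), and the batches come from a toy Gaussian generator instead of a sampler.

```python
import math
import random


def tolerance_met(mean, error_of_mean, atol=None, rtol=None):
    """Stopping rule: error <= atol + rtol * |mean|, a missing tolerance counts as 0."""
    if atol is None and rtol is None:
        raise ValueError("Provide at least one of atol or rtol.")
    atol = 0.0 if atol is None else atol
    rtol = 0.0 if rtol is None else rtol
    return error_of_mean <= atol + rtol * abs(mean)


def sample_to_precision(draw_batch, atol=None, rtol=None, max_iter=10000):
    """Accumulate batches until the tolerance is met (independent samples assumed)."""
    if atol is None and rtol is None:
        raise ValueError("Provide at least one of atol or rtol.")
    n, mean, m2 = 0, 0.0, 0.0
    for iteration in range(1, max_iter + 1):
        for x in draw_batch():
            # Welford's online update of the mean and the sum of squared deviations.
            n += 1
            delta = x - mean
            mean += delta / n
            m2 += delta * (x - mean)
        error = math.sqrt(m2 / (n - 1) / n)  # standard error of the mean
        if tolerance_met(mean, error, atol, rtol):
            break
    return mean, error, iteration


rng = random.Random(0)


def draw():
    """Toy batch of 100 independent samples with mean 2 and unit variance."""
    return [rng.gauss(2.0, 1.0) for _ in range(100)]


mean, error, n_iter = sample_to_precision(draw, atol=0.02)
print(f"mean={mean:.3f} error={error:.4f} iterations={n_iter}")
```

With only rtol set, the same loop would never terminate early for an observable whose mean is exactly zero, which is the situation the absolute floor atol is meant to handle.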
- grad(Ô, *, use_covariance=None, mutable=None)[source]#
Estimates the gradient of the quantum expectation value of a given operator O.
- Parameters:
op (netket.operator.AbstractOperator) – the operator O.
is_hermitian – optional override for whether or not to use the hermitian logic. By default it’s automatically detected.
use_covariance (bool | None)
mutable (bool | str | Collection[str] | DenyList | None)
- Returns:
An estimation of the average gradient of the quantum expectation value <O>.
- Return type:
array
- init(seed=None, dtype=None)[source]#
Initialises the variational parameters of the variational state.
- init_parameters(init_fun=None, *, seed=None)[source]#
Re-initializes all the parameters with the provided initialization function, defaulting to the normal distribution of standard deviation 0.01.
Warning
The init function will not change the dtype of the parameters, which is determined by the model. DO NOT SPECIFY IT INSIDE THE INIT FUNCTION
- Parameters:
init_fun (Callable[[Any, Sequence[int], Union[None, str, type[Any], dtype, _SupportsDType]], Array] | None) – a jax initializer such as jax.nn.initializers.normal(). Must be a Callable taking 3 inputs, the jax PRNG key, the shape and the dtype, and outputting an array with the valid dtype and shape. If left unspecified, defaults to jax.nn.initializers.normal(stddev=0.01).
seed (Any | None) – Optional seed to be used. The seed is synced across all JAX processes. If unspecified, uses a random seed.
- local_estimators(op, *, chunk_size=None)[source]#
Compute the local estimators for the operator op (also known as local energies when op is the Hamiltonian) at the current configuration samples self.samples.
For standard operators the local estimator is the ratio
\[O_\mathrm{loc}(s) = \frac{\langle s | \hat{O} | \psi \rangle}{\langle s | \psi \rangle}\]
Return type and shape
The return type depends on the operator:
- LocalEstimators: for operators that produce one scalar estimate per sample. result.data has shape (n_chains, chain_length).
- LocalEstimatorsBatch: for nonlinear observables (e.g. VarianceObservable) that require K > 1 channels to form the final quantity. result.data has shape (n_chains, chain_length, K).
In both cases the underlying array is at .data; never treat the returned object as a plain JAX array.
Typical usage
One-shot statistics:
    le = vstate.local_estimators(H)
    stats = le.to_stats()  # Stats with mean, error_of_mean, …
Online accumulation over multiple sampling steps:
    acc = None
    for _ in range(n_steps):
        vstate.sample(n_discard_per_chain=0)
        le = vstate.local_estimators(H)
        acc = le.accumulate(acc)  # OnlineStats or OnlineStatsBatch
    print(acc.get_stats())
Nonlinear observables return multiple channels plus a combining rule. For example, a variance observable returns a LocalEstimatorsBatch whose last axis stores [O_loc, (O^2)_loc] for each sample:
    import jax.numpy as jnp
    import netket as nk

    var_op = nk.observable.VarianceObservable(H)
    le = vstate.local_estimators(var_op)
    O_loc = le.data[..., 0]
    O2_loc = le.data[..., 1]
    channel_means = jnp.mean(le.data, axis=(0, 1))
    variance_mean = le.combinator(channel_means)
Here variance_mean matches le.to_stats().mean. Prefer netket.stats.LocalEstimatorsBatch.to_stats() if you also want the error bar, since it applies delta-method error propagation.
Warning
Samples differ between JAX processes, so the local estimators will take different values on each process. Use expect() (or to_stats() for multi-channel estimators) to obtain process-independent averages.
- Parameters:
op (AbstractOperator) – The operator or observable.
chunk_size (int | None) – Maximum forward-pass chunk size. (Default: self.chunk_size)
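Delta-method error propagation for a two-channel variance estimate can be written out explicitly. The sketch below is a generic textbook version in plain Python, not the code behind to_stats(): `variance_with_error` is a hypothetical helper, the channels are assumed real-valued and the samples independent, and the toy data mimic a diagonal operator for which (O^2)_loc equals O_loc squared.

```python
import math
import random


def variance_with_error(o_loc, o2_loc):
    """Combine channel means into Var = <O^2> - <O>^2 with a delta-method error bar."""
    n = len(o_loc)
    m1 = sum(o_loc) / n
    m2 = sum(o2_loc) / n
    value = m2 - m1 * m1  # combining rule applied to the channel means

    # Sample covariance matrix of the two channels.
    c11 = sum((a - m1) ** 2 for a in o_loc) / (n - 1)
    c22 = sum((b - m2) ** 2 for b in o2_loc) / (n - 1)
    c12 = sum((a - m1) * (b - m2) for a, b in zip(o_loc, o2_loc)) / (n - 1)

    # Gradient of f(m1, m2) = m2 - m1^2 with respect to the channel means.
    g1, g2 = -2.0 * m1, 1.0
    var_of_value = (g1 * g1 * c11 + 2.0 * g1 * g2 * c12 + g2 * g2 * c22) / n
    return value, math.sqrt(max(var_of_value, 0.0))


# Toy data: local estimators drawn from a Gaussian with mean 1 and variance 4.
rng = random.Random(0)
o_loc = [rng.gauss(1.0, 2.0) for _ in range(20000)]
o2_loc = [x * x for x in o_loc]

value, err = variance_with_error(o_loc, o2_loc)
print(f"variance = {value:.3f} +/- {err:.3f}")
```

The cross term c12 is what a naive combination of the two per-channel error bars would miss: the channels are strongly correlated, so their errors do not simply add in quadrature.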
- log_value(σ)[source]#
Evaluates the variational state for a batch of states and returns the logarithm of the amplitude of the quantum state.
For pure states, this is \(\log(\langle\sigma|\psi\rangle)\), whereas for mixed states this is \(\log(\langle\sigma_r|\rho|\sigma_c\rangle)\), where \(\psi\) and \(\rho\) are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left and right-acting states (row and column) are obtained as σr=σ[::,0:N] and σc=σ[::,N:].
Given a batch of inputs (Nb, N), returns a batch of outputs (Nb,).
- Return type:
Array
- Parameters:
σ (Array)
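The row/column slicing used for density matrices can be made concrete with nested lists standing in for arrays. This is only a sketch: `split_row_col` is a hypothetical helper, and it assumes each input row is the concatenation of a row configuration and a column configuration of equal length N.

```python
def split_row_col(sigma_batch):
    """Split each doubled configuration of length 2N into (row, column) halves."""
    n = len(sigma_batch[0]) // 2
    sigma_r = [row[0:n] for row in sigma_batch]  # same as σ[:, 0:N]
    sigma_c = [row[n:] for row in sigma_batch]   # same as σ[:, N:]
    return sigma_r, sigma_c


# Batch of Nb = 2 doubled configurations for N = 3 spins.
batch = [
    [1, -1, 1, 1, -1, -1],
    [-1, -1, 1, 1, 1, -1],
]
sigma_r, sigma_c = split_row_col(batch)
print(sigma_r)
print(sigma_c)
```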
- quantum_geometric_tensor(qgt_T=None)[source]#
Computes an estimate of the quantum geometric tensor G_ij. This function returns a linear operator that can be used to apply G_ij to a given vector or can be converted to a full matrix.
- Parameters:
qgt_T (Callable[[VariationalState], LinearOperator] | None) – the optional type of the quantum geometric tensor. By default it’s automatically selected.
- Returns:
A linear operator representing the quantum geometric tensor.
- Return type:
nk.optimizer.LinearOperator
- reset()[source]#
Resets the sampled states. This method is called automatically every time that the parameters/state is updated.
- sample(*, chain_length=None, n_samples=None, n_discard_per_chain=None)[source]#
Sample a certain number of configurations.
If either chain_length or n_samples is specified, that number of samples is generated. Otherwise the value set internally is used.
- thermalise(op, *, min_chain_length=10, max_chain_length=100, rhat_tol=1.05, decay=0.9, patience=1, verbose=True, raise_on_failure=False)[source]#
Advance the Markov chains until they are thermalized (R̂ converged).
Unlike check_mc_convergence(), this function mutates state in-place: on return, state.sampler_state reflects the position of the chains after thermalisation. The sampler’s sweep_size is not changed.
The function monitors the Gelman-Rubin R̂ diagnostic computed from a sliding EMA window of recent batches (controlled by decay). Thermalisation is declared when R̂ < rhat_tol for patience consecutive iterations and at least min_chain_length samples/chain have been drawn.
Warning
Experimental functionality. This method is subject to change without notice in future NetKet releases.
Note
R̂ is unreliable when the total samples/chain accumulated so far is small (roughly < 50). With very short chains the between-chain variance is noisy and R̂ tends to be overestimated, so the function may not declare convergence until min_chain_length is satisfied even when the chains are already well-mixed. If max_chain_length is set too low (e.g. below 50 × chain_length) you may receive a failure warning even though the sampler is actually thermalized. In that case either increase max_chain_length or reduce min_chain_length.
- Parameters:
op (AbstractOperator) – The operator whose local estimators are used to monitor convergence. An operator is required because computing R̂ needs per-chain scalar values, and local estimators are the only quantity that provides this. The operator does not need to be the one you ultimately care about: prefer a cheap observable (e.g. a single-site magnetisation \(\hat{\sigma}^z_0\), or the total magnetisation) over the full Hamiltonian, which may have many terms and be slow to evaluate. The convergence criterion is the same regardless of which operator you choose.
min_chain_length (int) – Minimum samples/chain before the convergence check is applied (default: 10).
max_chain_length (int) – Hard upper limit on samples/chain (default: 100). If R̂ has not converged by this limit a UserWarning is emitted (or a RuntimeError is raised when raise_on_failure=True). Make sure this is comfortably above min_chain_length; otherwise a false failure warning may be triggered before R̂ has had enough data to be meaningful.
rhat_tol (float) – R̂ threshold below which chains are considered mixed (default: 1.05).
decay (float) – EMA decay factor for the sliding-window R̂ (default: 0.9, effective window ≈ 10 batches). Lower values react faster to recent mixing but are noisier.
patience (int) – Number of consecutive iterations with R̂ < rhat_tol required before declaring convergence (default: 1).
verbose (bool) – If True, display a tqdm progress bar (default: True).
raise_on_failure (bool) – If True, raise RuntimeError on failure instead of emitting a UserWarning (default: False).
- Returns:
A tuple (stats, hist_data) where stats is the final OnlineStats accumulator and hist_data is a HistoryDict recording the evolution of mean, variance, and R̂ across iterations.
See also
check_mc_convergence(): diagnose autocorrelation without mutating state.
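The Gelman-Rubin diagnostic and the patience rule used by thermalise() can be sketched in plain Python. The code below uses the standard textbook R̂ formula on complete chains and is not NetKet's implementation: `gelman_rubin` and `converged` are hypothetical names, and the EMA sliding window controlled by decay is deliberately left out.

```python
import math
import random


def gelman_rubin(chains):
    """R-hat for a list of equal-length chains of scalar values."""
    m = len(chains)
    n = len(chains[0])
    means = [sum(c) / n for c in chains]
    grand = sum(means) / m
    # W: average within-chain variance; B/n: variance of the chain means.
    within = sum(
        sum((x - mu) ** 2 for x in c) / (n - 1) for c, mu in zip(chains, means)
    ) / m
    between_over_n = sum((mu - grand) ** 2 for mu in means) / (m - 1)
    pooled = (n - 1) / n * within + between_over_n
    return math.sqrt(pooled / within)


def converged(rhat_history, rhat_tol=1.05, patience=1):
    """True when the last `patience` R-hat values are all below `rhat_tol`."""
    recent = rhat_history[-patience:]
    return len(recent) == patience and all(r < rhat_tol for r in recent)


rng = random.Random(0)
# Eight chains sampling the same distribution: R-hat should be close to 1.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(200)] for _ in range(8)]
# Eight chains stuck around different values: R-hat should be well above 1.
stuck = [[rng.gauss(float(k), 1.0) for _ in range(200)] for k in range(8)]

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(f"mixed: {rhat_mixed:.3f}  stuck: {rhat_stuck:.3f}")
print(converged([1.20, 1.04, 1.03], rhat_tol=1.05, patience=2))
```

A larger patience guards against a single lucky batch dipping below rhat_tol while the chains are still drifting, at the price of a few extra iterations.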