netket.driver.VMC_SR#
- class netket.driver.VMC_SR[source]#
Bases: AbstractOptimizationDriver
Energy minimization using Variational Monte Carlo (VMC) and Stochastic Reconfiguration/Natural Gradient Descent. This driver is mathematically equivalent to the standard netket.driver.VMC with the preconditioner netket.optimizer.SR(solver=netket.optimizer.solver.cholesky_with_fallback), but can easily switch between the standard and the kernel/minSR formulation of Natural Gradient Descent.
The standard formulation computes the updates as:
\[\delta \theta = \tau (X^TX + \lambda \mathbb{I}_{N_p})^{-1} X^T E^{loc},\]
where \(X \in \mathbb{R}^{N_s \times N_p}\) is the Jacobian of the log-wavefunction, with \(N_p\) the number of parameters and \(N_s\) the number of samples. The vector \(E^{loc}\) is the centered local estimator for the local energies.
The kernel/minSR formulation computes the updates as:
\[\delta \theta = \tau X^T(XX^T + \lambda \mathbb{I}_{2N_s})^{-1} E^{loc},\]
The regularization parameter \(\lambda\) is the diag_shift parameter of the driver, which can be a scalar or a schedule. The updates are then applied to the parameters using the optimizer, which in general should be optax.sgd.
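The equivalence of the two formulations can be checked numerically on a toy problem. The sketch below uses plain NumPy with a random real-valued Jacobian, so the identity matrix in the kernel form has size \(N_s\) rather than \(2N_s\); all variable names are illustrative and none of this is NetKet code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_params = 8, 20

# Toy real-valued Jacobian and centered local energies (random stand-ins).
X = rng.normal(size=(n_samples, n_params))
e_loc = rng.normal(size=n_samples)
e_loc -= e_loc.mean()

tau, lam = 0.1, 1e-3

# Standard formulation: solve an (N_p x N_p) system.
delta_sr = tau * np.linalg.solve(X.T @ X + lam * np.eye(n_params), X.T @ e_loc)

# Kernel/minSR formulation: solve an (N_s x N_s) system.
delta_minsr = tau * X.T @ np.linalg.solve(X @ X.T + lam * np.eye(n_samples), e_loc)

# Largest componentwise difference between the two updates.
max_diff = np.max(np.abs(delta_sr - delta_minsr))
print(max_diff)
```

The two updates agree up to round-off, while the kernel form only inverts a matrix whose size is set by the number of samples, which is why it is preferable when \(N_p \gg N_s\).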
Matrix Inversion#
The matrix inversion of both methods is performed using a linear solver, which can be specified by the user. This must be a function, passed as the linear_solver argument, with the following signature:
linear_solver(A: Matrix, b: vector) -> tuple[jax.Array[vector], dict]
where the vector is the solution and the dictionary may contain additional information about the solver or be None. The standard solver is based on the Cholesky decomposition cholesky(), but any other solver from JAX, the NetKet solvers, or a custom-written one can be used.
Natural Gradient Descent#
Stochastic Reconfiguration is equivalent to the Natural Gradient Descent method introduced by Amari 1998 in the context of neural network training, assuming that the natural metric of the space of wave-functions is the Fubini-Study metric. This was first studied by Stokes et Al 2019 and called quantum Natural Gradient Descent.
While stochastic reconfiguration has been heavily studied in the context of VMC, there is a vast literature in the Machine Learning community on the use of NGD and on carefully tuning the diagonal shift and the learning rate.
A very good introduction to the mathematics of Information Geometry and NGD is found in Bai et Al and further studied in Shrestha et Al 2022. From the physicist's point of view, a good discussion on the choice of the metric function (QGT vs Fisher Matrix) is found in Stokes et Al 2022 (section 4 in particular). For a comprehensive review of the method, we suggest the review by Martens 2014.
Momentum / SPRING#
When momentum is used, this driver implements the SPRING optimizer of Goldshlager et Al. (2024), which accumulates previous updates to better approximate the exact SR direction with no significant performance penalty.
The momentum \(\mu\) is a number in [0,1] that specifies the damping factor of the previous updates and works somewhat similarly to the beta parameter of ADAM. The difference is that rather than simply adding the damped previous update to the new update, SPRING uses the damped previous update to fill in the components of the SR direction that are not sampled by the current batch of walkers, resulting in a more accurate and less noisy estimate. Since SPRING only uses the previous update to fill in directions that are orthogonal to the current one, the maximum amplification of the step size in SPRING is \(A(\mu) = 1/\sqrt{1-\mu^2}\) rather than \(1/(1-\mu)\).
Thus the amplification is at most a factor of \(A(0.9)\approx 2.3\) or \(A(0.99)\approx 7.1\). Values around 0.8 empirically work well.
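The amplification bounds quoted above follow directly from the two formulas. A small standard-library sketch (the function names are illustrative, not part of NetKet) makes the comparison with plain momentum explicit:

```python
import math

def spring_max_amplification(mu):
    """Upper bound on the step-size amplification in SPRING: 1/sqrt(1 - mu^2)."""
    return 1.0 / math.sqrt(1.0 - mu**2)

def plain_momentum_amplification(mu):
    """Amplification of a geometric sum of damped updates: 1/(1 - mu)."""
    return 1.0 / (1.0 - mu)

for mu in (0.8, 0.9, 0.99):
    print(
        f"mu={mu}: SPRING <= {spring_max_amplification(mu):.2f}, "
        f"plain momentum {plain_momentum_amplification(mu):.1f}"
    )
```

For the same value of the momentum, the SPRING bound grows much more slowly as \(\mu \to 1\), so a learning rate tuned without momentum needs a smaller correction.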
Some progress has been made on theoretically analyzing this parameter. In particular, Section 3 of Epperly et Al. demonstrates (albeit in a significantly simplified linear least-squares setting) that SPRING can be interpreted as iteratively estimating a regularized SR direction, with the amount of regularization proportional to the value of 1-momentum. Additional insights regarding the behavior of some SPRING-like algorithms, still in the linear least-squares setting, are presented in Goldshlager et Al. (2025).
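As a rough illustration of how the previous update enters, the sketch below assumes the regularized least-squares form of the SPRING step, \(\min_{d} \|Xd - E^{loc}\|^2 + \lambda \|d - \mu\, d_{prev}\|^2\), solved with the kernel trick. It is a NumPy toy with hypothetical names; it omits the learning rate and any further regularization (such as proj_reg), and NetKet's actual implementation may differ in its details.

```python
import numpy as np

def spring_update(X, eps, prev_update, lam, mu):
    """Least-squares step regularized towards mu * prev_update (assumed form)."""
    n_samples = X.shape[0]
    shifted = mu * prev_update
    # Only the part of eps not already explained by the damped previous update
    # is fitted; the result stays inside the span of the current samples.
    residual = eps - X @ shifted
    kernel = X @ X.T + lam * np.eye(n_samples)
    correction = X.T @ np.linalg.solve(kernel, residual)
    return shifted + correction

rng = np.random.default_rng(1)
X = rng.normal(size=(6, 15))     # toy Jacobian: 6 samples, 15 parameters
eps = rng.normal(size=6)         # toy centered local energies
prev = rng.normal(size=15)       # previous update
lam, mu = 1e-2, 0.8

step = spring_update(X, eps, prev, lam, mu)

# Gradient of the regularized objective at the returned step (should vanish).
grad = X.T @ (X @ step - eps) + lam * (step - mu * prev)
print(np.max(np.abs(grad)))
```

With mu = 0 the function reduces to the kernel/minSR update, which matches the statement that momentum is an optional addition on top of the same linear solve.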
Implementation details#
The kernel-trick/NTK based implementation can run either with a direct calculation of the Jacobian (on_the_fly=False) or with a lazy evaluation of the NTK (on_the_fly=True). The latter is more computationally efficient for networks that reuse the parameters many times in every forward pass (convolutions and attention layers, but not dense layers) and generally uses less memory.
However, the on the fly implementation relies on some JAX compiler behaviour, so it might at times have worse performance. We suggest you check on your specific model. For a more detailed explanation of the on-the-fly implementation of the NTK, we refer to Novak et Al 2022. The algorithm netket uses is the layer-wise jacobian contraction method (sec 3.2) of the manuscript.
The default choice is to use the on_the_fly=True mode.
References
Stochastic Reconfiguration was originally introduced in the QMC field by Sorella. The method was later shown to be equivalent to the Natural Gradient Descent method introduced by Amari for the Fubini-Study metric.
The kernel trick which makes NGD/SR feasible in the large-parameter count limit was originally introduced to the field of NQS by Chen & Heyl under the name of minSR. Rende & Al proposed a simpler derivation in terms of the Kernel trick.
It’s interesting to note that those tricks were first mentioned by Ren & Goldfarb in the ML community.
When using Momentum you should cite G.Goldshlager et Al. (2024).
- __init__(hamiltonian, optimizer, *, diag_shift, proj_reg=None, momentum=None, linear_solver=<function cholesky_with_fallback>, linear_solver_fn=Deprecated, variational_state, chunk_size_bwd=None, mode=None, use_ntk=None, on_the_fly=None)[source]#
Initialize the driver with the given arguments.
Warning
The optimizer should be an instance of optax.sgd. Other optimizers, while they might work, will not make mathematical sense in the context of the SR/NGD optimization.
- Parameters:
hamiltonian (AbstractOperator) – The Hamiltonian of which the ground-state is to be found.
optimizer (Any) – The optimizer to use for the parameter updates. To perform proper SR/NGD optimization this should be an instance of optax.sgd(), but can be any other optimizer if you are brave.
variational_state (MCState) – The variational state to optimize.
diag_shift (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex]],Union[Array,ndarray,bool,number,bool,int,float,complex]]]) – The diagonal regularization parameter \(\lambda\) for the QGT/NTK.
proj_reg (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex]],Union[Array,ndarray,bool,number,bool,int,float,complex]],None]) – The regularization parameter for the projection of the updates. (This usually is not very important and can be left to None)
momentum (Union[Any,Callable[[Union[Array,ndarray,bool,number,bool,int,float,complex]],Union[Array,ndarray,bool,number,bool,int,float,complex]],None]) – (SPRING, disabled by default, read above for details) A number in [0,1] that specifies the damping factor of the previous updates and works somewhat similarly to the beta parameter of ADAM. The maximum amplification of the step size in SPRING is \(A(\mu)=1/\sqrt{1-\mu^2}\). Thus the amplification is at most a factor of \(A(0.9)\approx 2.3\) or \(A(0.99)\approx 7.1\). Values around momentum = 0.8 empirically work well. (Defaults to None)
linear_solver (Callable[[Union[ndarray,Array],Union[ndarray,Array]],Union[ndarray,Array]]) – The linear solver function to use for the NGD solver. Defaults to netket.optimizer.solver.cholesky_with_fallback(), which runs a Cholesky factorisation and automatically falls back to pinv_smooth() if NaN or Inf values are detected in the result. Other available solvers:
cholesky() is faster because it relies on LU decomposition instead of a full diagonalization, but it is more prone to numerical instabilities, especially with explicitly symmetrized networks or very singular QGT/NTKs. Often the issue is that your matrix is numerically not hermitian/positive semidefinite because of numerical errors, and this breaks the method. If you see NaNs in your weights during your optimization, this is likely the cause.
pinv_smooth() (a smoothed variant of pinv()) is considerably more stable and usually does not cause NaNs, but it is also considerably more expensive.
cg() and other iterative solvers can sometimes be faster, but the quality of the solution is bad and they can lead to unpredictable step times because the number of CG iterations might vary with the condition number. While we used those a lot in the past, there is considerable evidence that they should be avoided (see Chen & Heyl Nature Physics).
In general, if you are not bottlenecked by the linear solver, it is a good idea to use the more reliable pinv_smooth().
mode (JacobianMode | None) – The mode used to compute the jacobian of the variational state. Can be 'real' or 'complex'. Real can be used for real-valued wavefunctions with a sign, to truncate the arbitrary phase of the wavefunction. In complex mode, the QGT/NTK is concretized as a real-valued \(2N \times 2N\) matrix, where \(N\) is either the number of parameters or the number of samples. If your wavefunction is real (as is usually the case for fermionic hamiltonians) you should really set this to real.
on_the_fly (bool | None) – Whether to compute the QGT or NTK using lazy evaluation methods. This usually requires less memory. (Defaults to None, which will automatically choose the potentially best method).
chunk_size_bwd (int | None) – The number of rows of the NTK or of the Jacobian evaluated in a single sweep.
use_ntk (bool | None) – Whether to compute the updates using the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT), aka switching between SR and minSR. (Defaults to None, which will automatically choose the best method)
linear_solver_fn (Callable[[ndarray | Array, ndarray | Array], ndarray | Array] | DeprecatedArg)
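To make the solver interface of the linear_solver parameter concrete, here is a minimal sketch of a custom solver that follows the (A, b) -> (solution, info) convention stated in the Matrix Inversion section. It is written with NumPy purely for illustration: the function name and the rcond cutoff are hypothetical, and a solver actually passed to the driver would presumably need to be written with jax.numpy so that JAX can trace it.

```python
import numpy as np

def eigh_pinv_solver(A, b, rcond=1e-10):
    """Solve A x = b for a symmetric PSD matrix, discarding near-null directions."""
    # Remove any numerical asymmetry before diagonalizing.
    A = 0.5 * (A + A.conj().T)
    evals, evecs = np.linalg.eigh(A)
    # Keep only eigenvalues that are large relative to the biggest one.
    keep = evals > rcond * evals.max()
    safe_evals = np.where(keep, evals, 1.0)
    inv_evals = np.where(keep, 1.0 / safe_evals, 0.0)
    x = evecs @ (inv_evals * (evecs.conj().T @ b))
    info = {"n_discarded": int((~keep).sum())}
    return x, info

# Usage on a small, well-conditioned toy system.
rng = np.random.default_rng(2)
M = rng.normal(size=(5, 5))
A_toy = M @ M.T + np.eye(5)
b_toy = rng.normal(size=5)
x_toy, info_toy = eigh_pinv_solver(A_toy, b_toy)
print(np.max(np.abs(A_toy @ x_toy - b_toy)), info_toy)
```

The trade-off mirrors the one described for the built-in solvers: a full diagonalization is more expensive than a Cholesky factorization but tolerates matrices that are singular or slightly non-positive because of round-off.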
- Attributes
- chunk_size_bwd#
Chunk size for backward-mode differentiation. This reduces memory pressure at a potential cost of higher computation time.
If computing the jacobian, the jacobian is computed in blocks of chunk_size_bwd rows. If computing the NTK lazily, this is the number of rows of NTK evaluated in a single sweep. The chunk size does not affect the result, up to numerical precision.
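The idea that chunking changes memory usage but not the result can be illustrated with a toy NumPy sketch. Here the Jacobian is already available as an array, so only the row-blocking is shown; in the driver the rows of each block would instead be produced by backward-mode differentiation. The function name is hypothetical.

```python
import numpy as np

def ntk_in_chunks(X, chunk_size):
    """Assemble X @ X.T one block of rows at a time."""
    n_rows = X.shape[0]
    ntk = np.empty((n_rows, n_rows), dtype=X.dtype)
    for start in range(0, n_rows, chunk_size):
        stop = min(start + chunk_size, n_rows)
        # Only `chunk_size` rows of the kernel are computed in each sweep.
        ntk[start:stop] = X[start:stop] @ X.T
    return ntk

rng = np.random.default_rng(4)
X = rng.normal(size=(10, 7))

full = X @ X.T
chunked = ntk_in_chunks(X, chunk_size=4)  # last chunk holds only 2 rows
print(np.max(np.abs(full - chunked)))
```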
- mode#
The mode used to compute the jacobian of the variational state. Can be ‘real’, ‘complex’, or ‘onthefly’.
‘real’ mode truncates the imaginary part of the wavefunction, useful for real-valued wavefunctions with a sign.
‘complex’ is the general implementation that always works.
‘onthefly’ uses a lazy implementation of the neural tangent kernel and does not compute the jacobian.
This internally uses netket.jax.jacobian(). See that function for more complete documentation.
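One common way to see why ‘complex’ mode leads to real matrices of doubled size (the \(\mathbb{I}_{2N_s}\) in the kernel formula) is to stack the real and imaginary parts of a complex Jacobian. The NumPy sketch below assumes real parameters and a complex log-wavefunction; it illustrates the algebra only and is not necessarily the memory layout NetKet uses internally.

```python
import numpy as np

rng = np.random.default_rng(3)
n_samples, n_params = 5, 4

# Toy complex Jacobian of a complex log-wavefunction w.r.t. real parameters.
X_c = rng.normal(size=(n_samples, n_params)) + 1j * rng.normal(size=(n_samples, n_params))

# Stack real and imaginary parts along the sample axis: shape (2*N_s, N_p).
X_r = np.concatenate([X_c.real, X_c.imag], axis=0)

# The real part of the complex Gram matrix equals the Gram matrix of the stack.
qgt_from_complex = (X_c.conj().T @ X_c).real
qgt_from_real = X_r.T @ X_r

# The corresponding kernel matrix is real and has size 2*N_s x 2*N_s.
ntk_real = X_r @ X_r.T
print(qgt_from_real.shape, ntk_real.shape)
```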
- on_the_fly#
Whether to use a lazy implementation of the NTK or QGT which does not concretize the jacobian.
This usually requires less memory.
- optimizer#
The optimizer used to update the parameters at every iteration.
- state#
Returns the machine that is optimized by this driver.
- step_count#
Returns a monotonic integer labelling all the steps performed by this driver. This can be used, for example, to identify the line in a log file.
- use_ntk#
Whether to use the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT) to compute the update.
- diag_shift: Any | Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex]#
The diagonal shift \(\lambda\) in the curvature matrix.
This can be a scalar or a schedule. If it is a schedule, it should be a function that takes the current step as input and returns the value of the shift.
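As an example of such a schedule, the sketch below builds a function that decays the shift exponentially from a large initial value to a small final one and then holds it constant. The helper name and the default values are illustrative assumptions, and the code assumes the step is passed as a plain Python number.

```python
def make_diag_shift_schedule(initial=1e-2, final=1e-4, decay_steps=100):
    """Return a schedule step -> shift, decaying geometrically then constant."""
    def schedule(step):
        # Fraction of the decay completed, clipped to 1 after decay_steps.
        frac = min(step / decay_steps, 1.0)
        return initial * (final / initial) ** frac
    return schedule

shift = make_diag_shift_schedule()
for step in (0, 50, 100, 1000):
    print(step, shift(step))
```

A large shift early on keeps the first updates stable while the state is far from the ground state; a smaller shift later brings the update closer to the unregularized SR direction.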
- proj_reg: Any | Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex]#
- momentum: bool#
Flag specifying whether to use momentum in the optimisation.
If True, the optimizer will accumulate previous updates following the SPRING optimizer from G.Goldshlager, N.Abrahamsen and L.Lin, giving a better approximation of the exact SR with no significant performance penalty.
- Methods
- compute_loss_and_update()[source]#
Performs a step of the optimization driver, returning the PyTree of the gradients that will be optimized.
Concrete drivers must override this method.
Note
When implementing this function on a subclass, you must return the gradient which must match the pytree structure of the parameters of the variational state.
The gradient will then be passed on to the optimizer in order to update the parameters.
Moreover, if you are minimising a loss function you must set the field self._loss_stats with the current value of the loss function.
This will be logged to any logger during optimisation.
- Returns:
the update for the weights.
- estimate(observables, fullsum=False)[source]#
Return MCMC statistics for the expectation value of observables in the current state of the driver.
- Parameters:
observables – A pytree of operators for which statistics should be computed.
fullsum (bool)
- Returns:
A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.
- replace(**kwargs)[source]#
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
- reset()[source]#
Deprecated since version 3.22: Use reset_step() to reset the sampler state at the beginning of a step. Note that the old reset() also reset step_count to 0; this behaviour is no longer supported.
- reset_step(hard=False)[source]#
Resets the state of the driver at the beginning of a new step.
This method is called at the beginning of every step in the optimization.
- Parameters:
hard (bool) – If True, the reset is a hard reset, resulting in a complete resampling even if resample_fraction is not None.
- run(n_iter, out=(), obs=None, step_size=1, show_progress=True, save_params_every=50, write_every=50, callback=None, timeit=False, _graceful_keyboard_interrupt=True)[source]#
Runs this variational driver, updating the weights of the network stored in this driver for n_iter steps and dumping values of the observables obs in the output logger.
It is possible to control more specifically what quantities are logged, when to stop the optimisation, or to execute arbitrary code at every step by specifying one or more callbacks, which are passed as a list of functions to the keyword argument callback.
Callbacks are functions that follow this signature:
def callback(step, log_data, driver) -> bool:
    ...
    return True  # or False to stop the optimisation
If a callback returns True, the optimisation continues, otherwise it is stopped. The log_data is a dictionary that can be modified in-place to change what is logged at every step. For example, this can be used to log additional quantities such as the acceptance rate of a sampler.
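The callback contract can be sketched without running a real optimisation. Below, stop_after builds a callback that writes an extra entry into log_data and requests a stop once a given step is reached; mock_run is a hypothetical stand-in for the driver loop, used only to show how the return value is interpreted, and is not how NetKet implements run.

```python
def stop_after(max_steps):
    """Build a callback that stops the run once `step` reaches max_steps."""
    def callback(step, log_data, driver) -> bool:
        # log_data can be modified in-place to log additional quantities.
        log_data["steps_remaining"] = max(max_steps - step, 0)
        # True means "keep going", False means "stop".
        return step < max_steps
    return callback

def mock_run(n_iter, callbacks):
    """Hypothetical stand-in for the driver loop (illustration only)."""
    history = []
    for step in range(n_iter):
        log_data = {}
        keep_going = all([cb(step, log_data, None) for cb in callbacks])
        history.append(log_data)
        if not keep_going:
            break
    return history

history = mock_run(100, [stop_after(3)])
print(len(history), history[-1])
```

With a real driver, the same callback would be passed through the callback keyword argument of run.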
Alternatively, AbstractCallback subclasses can be used to hook into more stages of the loop. To stop the optimisation early from any callback hook, raise StopRun: the driver will catch it, finalise all callbacks via their on_run_end method, and return normally without propagating the exception.
Loggers are specified as an iterable passed to the keyword argument out. If only a string is specified, this will create by default a nk.logging.JsonLog. To know about the output format check its documentation. The logger object is also returned at the end of this function so that you can inspect the results without reading the json output.
- Parameters:
n_iter (int) – the total number of iterations to be performed during this run.
out (Iterable[AbstractLog] | None) – A logger object, or an iterable of loggers, to be used to store simulation log and data. If this argument is a string, it will be used as output prefix for the standard JSON logger.
obs (dict[str,AbstractObservable] | None) – An iterable containing all observables that should be computed
step_size (int) – Every how many steps should observables be logged to disk (default=1)
callback (Union[Callable[[int,dict,AbstractDriver],bool],AbstractCallback,None]) – Callable or list of callable callback functions to stop training given a condition
show_progress (bool) – If true displays a progress bar (default=True)
save_params_every (int) – Every how many steps the parameters of the network should be serialized to disk (ignored if logger is provided)
write_every (int) – Every how many steps the json data should be flushed to disk (ignored if logger is provided)
timeit (bool) – If True, provide timing information.
_graceful_keyboard_interrupt (bool) – (Internal flag, defaults to True) If True, the driver will gracefully handle a KeyboardInterrupt, usually arising from doing ctrl-C, returning the current state of the simulation. If False, the KeyboardInterrupt will be raised as usual. This only has an effect when running in interactive mode.