class netket.experimental.driver.VMC_SRt

Bases: VMC

Energy minimization using Variational Monte Carlo (VMC) and the kernel formulation of Stochastic Reconfiguration (SR). This approach leads to exactly the same parameter updates as standard SR with a diagonal shift regularization. For this reason, it is equivalent to the standard nk.driver.VMC with the preconditioner nk.optimizer.SR(solver=netket.optimizer.solver.solvers.solve). In the kernel SR framework, the parameter updates can be written as:

\[\delta \theta = \tau X(X^TX + \lambda \mathbb{I}_{2M})^{-1} f,\]

where \(X \in \mathbb{R}^{P \times 2M}\) is the concatenation of the real and imaginary parts of the centered Jacobian, with P the number of parameters and M the number of samples. The vector f is the concatenation of the real and imaginary parts of the centered local energy. Note that, to compute the updates, it is sufficient to invert a \(2M\times 2M\) matrix instead of a \(P\times P\) one. As a consequence, this formulation is useful in the typical deep learning regime where \(P \gg M\).
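The equivalence with standard SR stated above follows from the push-through identity \(X(X^TX + \lambda \mathbb{I}_{2M})^{-1} = (XX^T + \lambda \mathbb{I}_{P})^{-1}X\). The NumPy sketch below illustrates this with random data; it is not NetKet's implementation. It builds X and f by stacking real and imaginary parts as described above, then checks that the kernel update (a 2M×2M solve) matches the standard SR update (a P×P solve). The helper names are hypothetical, and the normalization and sign conventions used in practice (for example a 1/√M prefactor) are deliberately left out.

```python
import numpy as np


def build_x_and_f(jac, e_loc):
    """Stack real and imaginary parts as described in the text.

    jac:   complex centered Jacobian, shape (M, P)  (M samples, P parameters)
    e_loc: complex centered local energies, shape (M,)
    Returns X with shape (P, 2M) and f with shape (2M,).
    """
    x = np.concatenate([jac.real, jac.imag], axis=0).T  # (2M, P) -> (P, 2M)
    f = np.concatenate([e_loc.real, e_loc.imag])        # (2M,)
    return x, f


def kernel_sr_update(x, f, lam, tau=1.0):
    """tau * X (X^T X + lam I_{2M})^{-1} f: solves a 2M x 2M system."""
    n = x.shape[1]
    kernel = x.T @ x + lam * np.eye(n)
    return tau * x @ np.linalg.solve(kernel, f)


def standard_sr_update(x, f, lam, tau=1.0):
    """tau * (X X^T + lam I_P)^{-1} X f: solves a P x P system."""
    p = x.shape[0]
    s = x @ x.T + lam * np.eye(p)
    return tau * np.linalg.solve(s, x @ f)


rng = np.random.default_rng(0)
M, P = 16, 200  # deep-learning regime: many more parameters than samples
jac = rng.normal(size=(M, P)) + 1j * rng.normal(size=(M, P))
jac -= jac.mean(axis=0)  # center over samples
e_loc = rng.normal(size=M) + 1j * rng.normal(size=M)
e_loc -= e_loc.mean()

x, f = build_x_and_f(jac, e_loc)
d_kernel = kernel_sr_update(x, f, lam=1e-3)
d_standard = standard_sr_update(x, f, lam=1e-3)
print(np.allclose(d_kernel, d_standard))  # True
```

The size of the system solved in kernel_sr_update depends only on the number of samples, which is why this route pays off when P is much larger than M.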

See R. Rende, L.L. Viteritti, L. Bardone, F. Becca and S. Goldt for a detailed description of the derivation. A similar result can be obtained by minimizing the Fubini-Study distance with a specific constraint; see A. Chen and M. Heyl for details.

__init__(hamiltonian, optimizer, *, diag_shift, linear_solver_fn=<function <lambda>>, jacobian_mode=None, variational_state=None)

Initializes the driver class.

  • hamiltonian (AbstractOperator) – The Hamiltonian of the system.

  • optimizer – Determines how optimization steps are performed given the bare energy gradient.

  • diag_shift (Union[Any, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – The diagonal shift of the stochastic reconfiguration matrix. Typical values are between 1e-4 and 1e-3. Can also be an optax schedule.

  • linear_solver_fn (Callable[[Array, Array], Array]) – Callable that solves the linear problem associated with the parameter updates.

  • jacobian_mode (Optional[str]) – The mode used to compute the Jacobian of the variational state. Can be ‘real’ or ‘complex’; by default, it is chosen according to the dtype of the model’s output.

  • variational_state (MCState) – The netket.vqs.MCState to be optimised. Other variational states are not supported.
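The two callable parameters can be illustrated with a minimal sketch, under two assumptions: that linear_solver_fn is called with the regularized matrix first and the right-hand side second, and that a diag_shift schedule is simply a function of the step count (as optax schedules are). Both functions below are hypothetical and use plain NumPy for readability; a solver actually passed to the driver would most likely have to be written with jax.numpy so that it can be traced.

```python
import numpy as np


def cholesky_solver(a, b):
    """Hypothetical linear_solver_fn: solve a @ x = b for symmetric positive-definite a.

    X^T X + diag_shift * I is positive definite whenever diag_shift > 0,
    so a Cholesky factorization a = L L^T can be used.
    """
    l = np.linalg.cholesky(a)
    y = np.linalg.solve(l, b)       # forward substitution: L y = b
    return np.linalg.solve(l.T, y)  # back substitution: L^T x = y


def diag_shift_schedule(step, init=1e-3, final=1e-4, decay_steps=1000):
    """Hypothetical schedule: geometric decay from init to final, then constant."""
    frac = min(step / decay_steps, 1.0)
    return init * (final / init) ** frac


rng = np.random.default_rng(1)
x = rng.normal(size=(50, 8))  # stand-in for X with P=50 and 2M=8
a = x.T @ x + diag_shift_schedule(0) * np.eye(8)
b = rng.normal(size=8)
sol = cholesky_solver(a, b)
print(np.allclose(a @ sol, b))  # True
print(diag_shift_schedule(0), diag_shift_schedule(2000))
```

Cholesky exploits the symmetry and positive definiteness that the diagonal shift guarantees; a general LU-based solve would also work, at a somewhat higher cost.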


Return MCMC statistics for the expectation value of observables in the current state of the driver.


The mode used to compute the Jacobian of the variational state. Can be ‘real’ or ‘complex’.

Real mode truncates the imaginary part of the wavefunction, while complex mode does not. This internally uses netket.jax.jacobian(); see that function for more complete documentation.


The optimizer used to update the parameters at every iteration.


The preconditioner used to modify the gradient.

This is a function with the following signature

preconditioner(vstate: VariationalState,
               grad: PyTree,
               step: Optional[Scalar] = None)

Here, the first argument is the variational state, the second is the PyTree of the gradient to precondition, and the optional last argument is the step, used to change some parameters during the optimisation.

Often, this is taken to be SR(). If it is set to None, then the identity is used.
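A function matching the signature above can be sketched as follows. The name scaled_identity_preconditioner and its decay rule are hypothetical; plain nested dicts stand in for the gradient PyTree, and with real JAX pytrees one would normally use jax.tree_util.tree_map instead of the hand-written recursion.

```python
from typing import Any, Optional


def scaled_identity_preconditioner(vstate: Any, grad: Any, step: Optional[float] = None) -> Any:
    """Hypothetical preconditioner with the (vstate, grad, step) signature.

    It ignores the variational state and rescales every leaf of the gradient
    (nested dicts of floats here) by a factor that decays with the step.
    """
    scale = 1.0 if step is None else 1.0 / (1.0 + 0.01 * step)

    def scale_tree(tree):
        if isinstance(tree, dict):
            return {k: scale_tree(v) for k, v in tree.items()}
        return tree * scale

    return scale_tree(grad)


grad = {"dense": {"kernel": 2.0, "bias": -1.0}}
print(scaled_identity_preconditioner(None, grad))            # unchanged: step is None
print(scaled_identity_preconditioner(None, grad, step=100))  # every leaf halved
```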


Returns the machine that is optimized by this driver.


Returns a monotonic integer labelling all the steps performed by this driver. This can be used, for example, to identify the line in a log file.


Performs a number of optimization steps equal to steps.


steps (int) – Number of steps to perform (default=1).


Return MCMC statistics for the expectation value of observables in the current state of the driver.


observables – A pytree of operators for which statistics should be computed.


A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.
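To make "MCMC statistics" and the structure-preserving output more concrete, here is a simplified sketch. It is a stand-in built on assumptions, not the driver's estimation routine: instead of operators it takes precomputed local-estimator samples per Markov chain, and it reports only the mean and a naive error of the mean that ignores autocorrelation. The helpers simple_stats and estimate_tree are hypothetical.

```python
import numpy as np


def simple_stats(local_values):
    """Simplified statistics for one observable.

    local_values: array of shape (n_chains, n_samples_per_chain) holding the
    local estimator for each sample. Returns the mean and a naive error of the
    mean; proper MCMC statistics would also account for autocorrelation.
    """
    flat = np.asarray(local_values).ravel()
    mean = flat.mean()
    error = flat.std(ddof=1) / np.sqrt(flat.size)
    return {"mean": mean, "error_of_mean": error}


def estimate_tree(tree):
    """Apply simple_stats to every leaf of a nested dict, preserving its structure."""
    if isinstance(tree, dict):
        return {k: estimate_tree(v) for k, v in tree.items()}
    return simple_stats(tree)


rng = np.random.default_rng(2)
samples = {
    "energy": rng.normal(loc=-1.0, scale=0.1, size=(4, 256)),
    "magnetization": {"x": rng.normal(loc=0.5, scale=0.2, size=(4, 256))},
}
stats = estimate_tree(samples)
print(stats["magnetization"]["x"]["mean"])
```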


Returns an info string used to print information to screen about this driver.

iter(n_steps, step=1)

Returns a generator that advances the VMC optimization, yielding once every step steps.

  • n_steps (int) – The total number of steps to perform (equal to the length of the iterator when step=1)

  • step (int) – The number of internal steps the simulation is advanced between yielding from the iterator


int – The current step.
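The iteration contract can be sketched with a toy object. ToyDriver is hypothetical and performs no optimization; it assumes the current step count is yielded before each block of step internal steps, which is enough to show that with n_steps=6 and step=2 the loop body runs three times while six steps are performed.

```python
class ToyDriver:
    """Hypothetical stand-in for a driver, mimicking the documented iter() contract.

    Assumption: the current step count is yielded before each block of
    `step` internal steps is performed.
    """

    def __init__(self):
        self.step_count = 0

    def _do_one_step(self):
        self.step_count += 1  # stands in for sampling, gradient and parameter update

    def iter(self, n_steps, step=1):
        for _ in range(0, n_steps, step):
            yield self.step_count
            for _ in range(step):
                self._do_one_step()


driver = ToyDriver()
seen = []
for current in driver.iter(6, step=2):
    seen.append(current)  # e.g. log or check a stopping condition here
print(seen, driver.step_count)  # [0, 2, 4] 6
```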


Resets the driver.

Subclasses should make sure to call super().reset() to ensure that the step count is set to 0.

run(n_iter, out=None, obs=None, show_progress=True, save_params_every=50, write_every=50, step_size=1, callback=<function AbstractVariationalDriver.<lambda>>)

Executes the Variational Monte Carlo optimization, updating the weights of the network stored in this driver for n_iter steps and writing the values of the observables obs to the output logger. If no logger is specified, creates a JSON file at out, overwriting files with the same prefix.

By default, nk.logging.JsonLog is used; to learn about the output format, check its documentation. The logger object is also returned at the end of this function so that you can inspect the results without reading the JSON output.

  • n_iter – the total number of iterations

  • out – A logger object, or an iterable of loggers, to be used to store simulation log and data. If this argument is a string, it will be used as output prefix for the standard JSON logger.

  • obs – An iterable containing all observables that should be computed

  • save_params_every – How often, in steps, the network parameters should be serialized to disk (ignored if a logger is provided)

  • write_every – How often, in steps, the JSON data should be flushed to disk (ignored if a logger is provided)

  • step_size – How often, in steps, observables should be logged to disk (default=1)

  • show_progress – If True, displays a progress bar (default=True)

  • callback – A callable, or a list of callables, used to stop training when a given condition is met
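A condition-based callback can be sketched as follows. PlateauStopping is a hypothetical class, and the sketch rests on assumptions not stated above: that each callback is called as callback(step, log_data, driver), that returning False stops the optimization, and that log_data["Energy"] exposes the current energy estimate through a .mean attribute (possibly complex-valued).

```python
import math
from types import SimpleNamespace


class PlateauStopping:
    """Hypothetical early-stopping callback.

    Assumes the call signature callback(step, log_data, driver), that returning
    False stops training, and that log_data["Energy"].mean holds the energy.
    """

    def __init__(self, patience=20, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.best_step = 0

    def __call__(self, step, log_data, driver):
        energy = float(log_data["Energy"].mean.real)  # the mean may be complex-valued
        if energy < self.best - self.min_delta:
            self.best, self.best_step = energy, step
        return step - self.best_step < self.patience  # False -> stop training


# Simulated log entries stand in for what the driver would pass.
callback = PlateauStopping(patience=3)
energies = [-1.0, -1.5, -1.5, -1.5, -1.5, -1.5]
keep_going = [callback(i, {"Energy": SimpleNamespace(mean=e)}, None) for i, e in enumerate(energies)]
print(keep_going)  # [True, True, True, True, False, False]
```

Under those assumptions it would be passed as, for example, driver.run(500, callback=PlateauStopping(patience=20)).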


Updates the parameters of the machine using the optimizer in this driver.


dp – the pytree containing the updates to the parameters
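As a sketch only, plain SGD on nested dicts (standing in for the parameter pytree) shows what applying dp means. sgd_update is hypothetical; the optimizers used by the driver are typically optax gradient transformations, which may keep internal state (momentum, adaptive scales) between calls.

```python
def sgd_update(params, dp, learning_rate=0.01):
    """Hypothetical plain-SGD update: new_leaf = leaf - learning_rate * dp_leaf.

    params and dp are nested dicts with identical structure (a stand-in for
    pytrees); a new tree is returned and the inputs are left untouched.
    """
    if isinstance(params, dict):
        return {k: sgd_update(params[k], dp[k], learning_rate) for k in params}
    return params - learning_rate * dp


params = {"w": 1.0, "b": {"x": 0.5}}
dp = {"w": 10.0, "b": {"x": -5.0}}
new_params = sgd_update(params, dp)
print(new_params)
```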