class netket.driver.AbstractVariationalDriver

Bases: ABC

Abstract base class for NetKet Variational Monte Carlo drivers.

Subclass it to create an optimization driver that works out of the box with NetKet loggers and the callback mechanism.


How to implement a new driver

For a concrete example, look at the file netket/driver/

To inherit the interface of netket.driver.AbstractVariationalDriver, subclass it and define the following methods:

  • The __init__() method should call the base initializer (super().__init__()) with the variational state (machine), the optimizer and, optionally, the name of the minimised loss. If the driver minimises a loss function and you want its name to show up automatically in the progress bar and output files, pass the optional keyword argument minimized_quantity_name.

  • _forward_and_backward() should compute the loss function and its gradient, returning the gradient. If the driver minimises or maximises a loss function, assign the loss to the field self._loss_stats so that it can be monitored.

  • _estimate_stats() should return the expectation value of a single observable over the variational state.

  • reset() should reset the driver (usually the sampler). The base implementation calls reset() on the variational state and sets the step count to 0, but you are responsible for resetting any extra fields of the driver itself.
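The following sketch illustrates the subclassing pattern described above without depending on NetKet. ToyAbstractDriver is a hypothetical, heavily simplified stand-in for AbstractVariationalDriver: it keeps a step count, and its advance method (a hypothetical name in this sketch) calls _forward_and_backward() and then applies plain gradient descent in place of an optax optimizer. QuadraticDriver shows which hooks a concrete driver fills in, including a reset() that calls super().reset() and clears an extra field.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, List


class ToyAbstractDriver(ABC):
    """Hypothetical stand-in for the base class; not NetKet code."""

    def __init__(self, params: List[float], learning_rate: float,
                 minimized_quantity_name: str = "loss") -> None:
        self.params = list(params)
        self.learning_rate = learning_rate
        self._loss_name = minimized_quantity_name
        self._loss_stats = None  # filled in by _forward_and_backward()
        self._step_count = 0

    @property
    def step_count(self) -> int:
        return self._step_count

    @abstractmethod
    def _forward_and_backward(self) -> List[float]:
        """Compute the loss, store it in self._loss_stats, return the gradient."""

    @abstractmethod
    def _estimate_stats(self, observable: Any) -> Any:
        """Return the statistics of a single observable."""

    def reset(self) -> None:
        self._step_count = 0

    def advance(self, steps: int = 1) -> None:
        for _ in range(steps):
            grad = self._forward_and_backward()
            # Plain gradient descent stands in for the optax optimizer
            self.params = [p - self.learning_rate * g
                           for p, g in zip(self.params, grad)]
            self._step_count += 1


class QuadraticDriver(ToyAbstractDriver):
    """Minimises sum_i (x_i - target)**2 over the parameters x."""

    def __init__(self, params: List[float], learning_rate: float,
                 target: float = 3.0) -> None:
        super().__init__(params, learning_rate)
        self.target = target
        self.n_evaluations = 0  # extra field that reset() must clear

    def _forward_and_backward(self) -> List[float]:
        self.n_evaluations += 1
        residuals = [p - self.target for p in self.params]
        self._loss_stats = sum(r * r for r in residuals)
        return [2.0 * r for r in residuals]

    def _estimate_stats(self, observable: Callable[[List[float]], Any]) -> Any:
        # In this toy an "observable" is any function of the parameters
        return observable(self.params)

    def reset(self) -> None:
        super().reset()  # resets the step count kept by the base class
        self.n_evaluations = 0


driver = QuadraticDriver([0.0, 1.0], learning_rate=0.25)
driver.advance(20)
print(driver.step_count, driver.params)
```

In a real NetKet driver the parameters live in the variational state, and the loss is usually a statistics object rather than a plain float.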

__init__(variational_state, optimizer, minimized_quantity_name='loss')

Initializes a variational optimization driver.

Parameters:

  • variational_state (VariationalState) – The variational state to be optimized.

  • optimizer (Any) – An optax optimizer. If you do not want to use an optimizer, pass an SGD optimizer with learning rate -1, which applies the parameter updates unchanged.

  • minimized_quantity_name (str) – The name of the loss function in the logged data (default='loss').
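The advice about a learning rate of -1 relies on the sign convention used by optax: SGD multiplies the incoming pytree by minus the learning rate, and the result is then added to the parameters. With a learning rate of -1 the optimizer therefore passes the updates through unchanged. The pure-Python helpers below (toy_sgd_updates and toy_apply_updates are hypothetical names, not optax functions) mimic that convention on flat lists so the arithmetic can be checked without installing optax.

```python
from typing import List


def toy_sgd_updates(grads: List[float], learning_rate: float) -> List[float]:
    # Same sign convention as optax.sgd: updates = -learning_rate * grads
    return [-learning_rate * g for g in grads]


def toy_apply_updates(params: List[float], updates: List[float]) -> List[float]:
    # Same convention as optax.apply_updates: new params = params + updates
    return [p + u for p, u in zip(params, updates)]


params = [1.0, 2.0]
dp = [0.5, -0.25]

# Learning rate -1: the updates are exactly dp, so params become params + dp
identity_step = toy_apply_updates(params, toy_sgd_updates(dp, -1.0))
# An ordinary positive learning rate steps against dp instead
descent_step = toy_apply_updates(params, toy_sgd_updates(dp, 0.1))
print(identity_step)  # [1.5, 1.75]
```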


The optimizer used to update the parameters at every iteration.


Returns the variational state (machine) optimized by this driver.


Returns a monotonically increasing integer labelling all the steps performed by this driver. It can be used, for example, to identify the corresponding line in a log file.


Performs the number of optimization steps given by steps.

Parameters:

steps (int) – The number of steps to perform (default=1).


Return MCMC statistics for the expectation value of observables in the current state of the driver.


Parameters:

observables – A pytree of operators for which statistics should be computed.

Returns:

A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.
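Here a pytree is a nested container, for example dicts, lists and tuples, with operators at the leaves. The sketch below reproduces the "same structure in, same structure out" contract with a small recursive map. ToyStats and fake_expect are hypothetical stand-ins for the real statistics object and for the per-operator estimate; with NetKet you would pass actual operators instead of strings.

```python
from typing import Any, Callable, NamedTuple


class ToyStats(NamedTuple):
    # Hypothetical container; real statistics objects carry more fields
    mean: float
    error_of_mean: float


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    # Recurse into dicts, lists and tuples; apply fn to every other leaf
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, value) for value in tree)
    return fn(tree)


def fake_expect(operator: str) -> ToyStats:
    # Stand-in for estimating one operator: a fixed lookup table
    table = {"H": -1.5, "Sz": 0.0, "Sx": 0.25}
    return ToyStats(mean=table[operator], error_of_mean=0.01)


observables = {"energy": "H", "magnetization": ["Sz", "Sx"]}
stats = tree_map(fake_expect, observables)
print(stats["magnetization"][1].mean)  # 0.25
```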

iter(n_steps, step=1)

Returns a generator that advances the VMC optimization, yielding once every step steps.

Parameters:

  • n_steps (int) – The total number of steps to perform. With step=1 this equals the length of the iterator.

  • step (int) – The number of internal steps the simulation is advanced between yields of the iterator.

Yields:

int – The current step.
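The yielding behaviour of iter can be pictured with the mock below. ToyIterDriver is hypothetical and simplified: it yields the current step count once at the start of each block of step steps and then runs that block, so iterating with n_steps=10 and step=2 yields five times. In the real driver each step also computes the gradient and updates the parameters; the mock only counts steps.

```python
from typing import Iterator, List


class ToyIterDriver:
    """Hypothetical driver that only counts steps."""

    def __init__(self) -> None:
        self.step_count = 0
        self.performed: List[int] = []  # index of every step actually run

    def _one_step(self) -> None:
        # Stand-in for computing the gradient and updating the parameters
        self.performed.append(self.step_count)
        self.step_count += 1

    def iter(self, n_steps: int, step: int = 1) -> Iterator[int]:
        # Yield the current step at the start of each block of `step` steps,
        # then run that block
        for _ in range(0, n_steps, step):
            yield self.step_count
            for _ in range(step):
                self._one_step()


driver = ToyIterDriver()
for current in driver.iter(10, step=2):
    print("logging at step", current)  # prints 0, 2, 4, 6, 8 in turn
```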


Resets the driver.

Subclasses should make sure to call super().reset() to ensure that the step count is set to 0.

run(n_iter, out=(), obs=None, step_size=1, show_progress=True, save_params_every=50, write_every=50, callback=<function AbstractVariationalDriver.<lambda>>, timeit=False)

Runs this variational driver for n_iter steps, updating the weights of the network it stores and logging the values of the observables obs to the output loggers.

One or more callbacks, passed as a list of functions to the keyword argument callback, can be used to control which quantities are logged, to decide when to stop the optimisation, or to execute arbitrary code at every step.

Callbacks are functions that follow this signature:

def callback(step, log_data, driver) -> bool:
    # Return True to continue the optimisation, False to stop it.
    return True

If a callback returns True, the optimisation continues, otherwise it is stopped. The log_data is a dictionary that can be modified in-place to change what is logged at every step. For example, this can be used to log additional quantities such as the acceptance rate of a sampler.
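As a sketch of the callback protocol, the block below defines two callbacks with the documented signature: one stops the optimisation once the logged loss falls below a threshold, the other adds an extra entry to log_data in place. toy_run is a hypothetical, simplified loop (not NetKet's run) that calls every callback at each logged step and stops if any of them returns False. The "loss" key assumes the default minimized_quantity_name, and the loss is a plain float here, whereas a real driver would typically log a statistics object.

```python
from typing import Any, Callable, Dict, List

# Documented signature: (step, log_data, driver) -> bool
Callback = Callable[[int, Dict[str, Any], Any], bool]


def make_early_stopping(threshold: float, key: str = "loss") -> Callback:
    """Build a callback that stops once log_data[key] falls below threshold."""

    def callback(step: int, log_data: Dict[str, Any], driver: Any) -> bool:
        # Assumes the logged loss is a plain float (a simplification)
        return log_data[key] >= threshold

    return callback


def log_extra(step: int, log_data: Dict[str, Any], driver: Any) -> bool:
    # Modify log_data in place; the value is a hypothetical placeholder
    log_data["acceptance"] = 0.5
    return True


def toy_run(losses: List[float], callbacks: List[Callback]) -> List[Dict[str, Any]]:
    """Hypothetical simplified loop, not NetKet's run()."""
    history = []
    for step, loss in enumerate(losses):
        log_data: Dict[str, Any] = {"loss": loss}
        # Call every callback (no short-circuit), then stop if any returned False
        keep_going = [cb(step, log_data, None) for cb in callbacks]
        history.append(log_data)
        if not all(keep_going):
            break
    return history


history = toy_run([1.0, 0.5, 0.05, 0.01], [make_early_stopping(0.1), log_extra])
print(len(history))  # 3
```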

Loggers are specified as an iterable passed to the keyword argument out. If a string is passed instead, a nk.logging.JsonLog is created by default, using the string as output prefix; see its documentation for the output format. The logger object is also returned at the end of this function, so that you can inspect the results without reading the JSON output.

Parameters:

  • n_iter (int) – The total number of iterations to be performed during this run.

  • out (Optional[Iterable[AbstractLog]]) – A logger object, or an iterable of loggers, to be used to store simulation log and data. If this argument is a string, it will be used as output prefix for the standard JSON logger.

  • obs (Optional[dict[str, AbstractObservable]]) – A dictionary of the observables that should be computed, keyed by the name under which they are logged.

  • step_size (int) – How often, in steps, the observables are computed and logged (default=1).

  • callback (Callable[[int, dict, AbstractVariationalDriver], bool]) – A callable, or a list of callables, invoked during the optimisation; the optimisation stops as soon as any of them returns False.

  • show_progress (bool) – If True, displays a progress bar (default=True).

  • save_params_every (int) – How often, in steps, the network parameters are serialized to disk (default=50; ignored if a logger object is passed as out).

  • write_every (int) – How often, in steps, the JSON data is flushed to disk (default=50; ignored if a logger object is passed as out).

  • timeit (bool) – If True, provides timing information (default=False).


Updates the parameters of the variational state (machine) using the optimizer of this driver.

Parameters:

dp – The pytree containing the updates to the parameters.
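The update step described above can be sketched as follows, assuming an optax-style split between an optimizer, which turns dp into the actual updates while carrying its own state between calls, and a leaf-wise addition of those updates to the parameters. ToyDriver and toy_momentum_update are hypothetical; momentum is chosen only because it makes the role of the optimizer state visible, and the parameters are a flat dict rather than a general pytree.

```python
from typing import Dict, Tuple

Params = Dict[str, float]


def toy_momentum_update(grads: Params, velocity: Params,
                        learning_rate: float = 0.1,
                        beta: float = 0.9) -> Tuple[Params, Params]:
    # Hypothetical optimizer with internal state (one velocity per parameter):
    # v <- beta * v + g, update = -learning_rate * v
    new_velocity = {k: beta * velocity[k] + grads[k] for k in grads}
    updates = {k: -learning_rate * new_velocity[k] for k in grads}
    return updates, new_velocity


class ToyDriver:
    """Hypothetical driver holding parameters and an optimizer state."""

    def __init__(self, params: Params) -> None:
        self.parameters = dict(params)
        self.optimizer_state = {k: 0.0 for k in params}

    def update_parameters(self, dp: Params) -> None:
        # 1. The optimizer turns dp into updates, threading its state along
        updates, self.optimizer_state = toy_momentum_update(dp, self.optimizer_state)
        # 2. The updates are added to the current parameters, leaf by leaf
        self.parameters = {k: self.parameters[k] + updates[k]
                           for k in self.parameters}


driver = ToyDriver({"w": 1.0})
driver.update_parameters({"w": 1.0})
driver.update_parameters({"w": 1.0})
print(driver.parameters["w"])  # approximately 0.71
```

Because the velocity persists between calls, the second update is larger than the first even though dp is the same, which is why the optimizer state must be stored in the driver alongside the parameters.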