netket.driver.Infidelity_SR

class netket.driver.Infidelity_SR

Bases: AbstractOptimizationDriver

Infidelity minimization with respect to a target state \(|\Phi\rangle\) using Variational Monte Carlo (VMC) and Stochastic Reconfiguration/Natural Gradient Descent. Optionally, an operator \(U\) can be applied to the target, in which case \(|\Phi\rangle\) stands for \(U|\Phi\rangle\) in everything that follows. The optimization is analogous to the one performed by netket.driver.VMC_SR for ground-state search. The infidelity \(I\) between the variational state \(|\Psi\rangle\) and the target state \(|\Phi\rangle\) is:

\[I = 1 - \frac{|\langle\Psi|\Phi\rangle|^2 }{ \langle\Psi|\Psi\rangle \langle\Phi|\Phi\rangle } = 1 - \frac{\langle\Psi|\hat{I}_{op}|\Psi\rangle }{ \langle\Psi|\Psi\rangle },\]

where:

\[\hat{I}_{op} = \frac{|\Phi\rangle\langle\Phi|}{\langle\Phi|\Phi\rangle}\]

is the projector onto the target state \(|\Phi\rangle\) and plays the role of an effective Hamiltonian. The corresponding effective local energy is \(H^{loc}(x) = \frac{\Phi(x)}{\Psi(x)} \mathbb{E}_{y \sim |\Phi(y)|^2}\left[\frac{\Psi(y)}{\Phi(y)}\right]\).
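The equivalence between the two forms of \(I\) and the local estimator \(H^{loc}\) can be checked on a toy Hilbert space small enough to enumerate. The following standalone NumPy sketch (not NetKet code) assumes random, unnormalised amplitudes for \(\Psi\) and \(\Phi\), chosen purely for illustration, evaluates the inner expectation over \(y \sim |\Phi(y)|^2\) exactly, and compares \(1 - \mathbb{E}_{x \sim |\Psi(x)|^2}[H^{loc}(x)]\) with the infidelity computed from its definition.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8  # size of the toy Hilbert space (illustrative assumption)

# Random, unnormalised amplitudes Psi(x) and Phi(x) on all n configurations
psi = rng.normal(size=n) + 1j * rng.normal(size=n)
phi = rng.normal(size=n) + 1j * rng.normal(size=n)

# Infidelity from its definition: 1 - |<Psi|Phi>|^2 / (<Psi|Psi> <Phi|Phi>)
overlap = np.vdot(psi, phi)  # np.vdot conjugates its first argument
norm_psi = np.vdot(psi, psi).real
norm_phi = np.vdot(phi, phi).real
infidelity_exact = 1.0 - abs(overlap) ** 2 / (norm_psi * norm_phi)

# Normalised Born distributions |Psi(x)|^2 and |Phi(y)|^2
p_psi = np.abs(psi) ** 2 / norm_psi
p_phi = np.abs(phi) ** 2 / norm_phi

# Inner expectation E_{y ~ |Phi|^2}[Psi(y) / Phi(y)], summed exactly here
inner = np.sum(p_phi * psi / phi)

# Effective local energy H_loc(x) = Phi(x) / Psi(x) * inner
h_loc = phi / psi * inner

# 1 - E_{x ~ |Psi|^2}[H_loc(x)] reproduces the infidelity up to round-off
infidelity_local = 1.0 - np.sum(p_psi * h_loc)

print(infidelity_exact, infidelity_local.real, abs(infidelity_local.imag))
```

In an actual VMC run both expectations would instead be estimated from samples drawn from \(|\Psi|^2\) and \(|\Phi|^2\).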

For details see Sinibaldi et al. and Gravina et al.

__init__(target_state, optimizer, *, operator=None, diag_shift, proj_reg=None, momentum=None, linear_solver=<function cholesky>, linear_solver_fn=Deprecated, variational_state=None, chunk_size_bwd=None, mode=None, use_ntk=None, on_the_fly=None)

Initialize the driver with the given arguments.

Warning

The optimizer should be an instance of optax.sgd. Other optimizers, while they might work, will not make mathematical sense in the context of the SR/NGD optimization.
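The reason for this restriction can be seen with a small NumPy sketch (not NetKet or Optax code). Given a natural-gradient direction g returned by the SR/NGD solve, plain SGD without momentum moves along exactly -g, whereas the first step of Adam, written out by hand with its standard bias correction, reduces to roughly -lr·sign(g) and so discards the curvature information encoded in g. The vector g and all hyperparameters are arbitrary illustrative values.

```python
import numpy as np

# Hypothetical natural-gradient direction produced by the SR/NGD solver
g = np.array([0.5, -2.0, 0.01, 3.0])
lr = 0.1

# Plain SGD without momentum: params <- params - lr * g
sgd_update = -lr * g

# First Adam step written out by hand (standard Adam with bias correction)
b1, b2, eps = 0.9, 0.999, 1e-8
m = (1 - b1) * g            # first-moment estimate after one step
v = (1 - b2) * g**2         # second-moment estimate after one step
m_hat = m / (1 - b1)        # bias-corrected first moment, equal to g
v_hat = v / (1 - b2)        # bias-corrected second moment, equal to g**2
adam_update = -lr * m_hat / (np.sqrt(v_hat) + eps)

def cosine(a, b):
    """Cosine of the angle between two vectors."""
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# Alignment of each update with the SR descent direction -g
cos_sgd = cosine(sgd_update, -g)
cos_adam = cosine(adam_update, -g)
print(cos_sgd, cos_adam)
```

SGD stays perfectly aligned with the SR direction, while Adam's first update is close to -lr·sign(g) and points elsewhere.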

Parameters:
  • target_state (VariationalState) – The target state \(|\Phi\rangle\) that must be matched.

  • optimizer (Any) – The optimizer to use for the parameter updates. To perform proper SR/NGD optimization this should be an instance of optax.sgd, but can be any other optimizer if you are brave.

  • operator (AbstractOperator | None) – The optional operator \(U\) applied to the target state.

  • variational_state (VariationalState | None) – The variational state to optimize.

  • diag_shift (Union[Any, Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]]]) – The diagonal regularization parameter \(\lambda\) for the QGT/NTK.

  • proj_reg (Union[Any, Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]], None]) – The regularization parameter for the projection of the updates. (This is usually not critical and can be left as None.)

  • momentum (Union[Any, Callable[[Union[Array, ndarray, bool, number, bool, int, float, complex]], Union[Array, ndarray, bool, number, bool, int, float, complex]], None]) – (SPRING, disabled by default; see the momentum attribute below for details) A number in \([0, 1]\) that specifies the damping factor of the previous updates and works somewhat similarly to the beta parameter of ADAM. The maximum amplification of the step size in SPRING is \(A(\mu)=1/\sqrt{1-\mu^2}\), so the amplification is at most a factor of \(A(0.9)\approx 2.3\) or \(A(0.99)\approx 7.1\). Values around momentum = 0.8 empirically work well. (Defaults to None)

  • linear_solver (Callable[[Union[ndarray, Array], Union[ndarray, Array]], Union[ndarray, Array]]) – The linear solver function to use for the NGD solver.

  • mode (JacobianMode | None) – The mode used to compute the jacobian of the variational state. Can be ‘real’ or ‘complex’. Real can be used for real-valued wavefunctions with a sign, to truncate the arbitrary phase of the wavefunction. This leads to lower computational cost.

  • on_the_fly (bool | None) – Whether to compute the QGT or NTK using lazy evaluation methods. This usually requires less memory. (Defaults to None, which automatically chooses the method expected to perform best.)

  • chunk_size_bwd (int | None) – The number of rows of the NTK or of the Jacobian evaluated in a single sweep.

  • use_ntk (bool | None) – Whether to compute the updates using the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT), i.e. switching between SR and minSR. (Defaults to None, which automatically chooses the best method.)

  • linear_solver_fn (Callable[[ndarray | Array, ndarray | Array], ndarray | Array] | DeprecatedArg) – Deprecated; use linear_solver instead.
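The choice controlled by use_ntk can be illustrated with the push-through identity \((O^\dagger O + \lambda I)^{-1} O^\dagger = O^\dagger (O O^\dagger + \lambda I)^{-1}\): in exact arithmetic, solving the SR system in parameter space or the NTK (minSR) system in sample space gives the same update for the same diag_shift \(\lambda\), so the choice mainly affects cost. The NumPy sketch below assumes a random complex Jacobian O and residual vector eps with the usual centring and \(1/\sqrt{N}\) scaling; it is a textbook illustration, not NetKet's implementation, and its Cholesky-based solve only mirrors the default linear_solver in spirit.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_params = 6, 10  # fewer samples than parameters: the minSR regime
lam = 1e-2                   # diagonal shift lambda

# Centred, 1/sqrt(N)-scaled Jacobian O (n_samples x n_params) and residual vector eps
O = rng.normal(size=(n_samples, n_params)) + 1j * rng.normal(size=(n_samples, n_params))
O = (O - O.mean(axis=0)) / np.sqrt(n_samples)
eps = rng.normal(size=n_samples) + 1j * rng.normal(size=n_samples)
eps = (eps - eps.mean()) / np.sqrt(n_samples)

def cholesky_solve(A, b):
    """Solve A x = b for Hermitian positive-definite A via a Cholesky factorisation."""
    L = np.linalg.cholesky(A)               # A = L L^H
    y = np.linalg.solve(L, b)               # forward substitution (generic solver for brevity)
    return np.linalg.solve(L.conj().T, y)   # back substitution

# SR: (S + lam I) dtheta = F, with S = O^H O (n_params x n_params) and F = O^H eps
S = O.conj().T @ O
F = O.conj().T @ eps
dtheta_sr = cholesky_solve(S + lam * np.eye(n_params), F)

# minSR / NTK: dtheta = O^H (T + lam I)^{-1} eps, with T = O O^H (n_samples x n_samples)
T = O @ O.conj().T
dtheta_ntk = O.conj().T @ cholesky_solve(T + lam * np.eye(n_samples), eps)

print(np.max(np.abs(dtheta_sr - dtheta_ntk)))
```

The SR route factorises an n_params × n_params matrix, the NTK route an n_samples × n_samples one, which is why the NTK form is attractive when there are far more parameters than samples.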

Attributes
chunk_size_bwd

Chunk size for backward-mode differentiation. This reduces memory pressure at a potential cost of higher computation time.

If computing the jacobian, the jacobian is computed in blocks of chunk_size_bwd rows. If computing the NTK lazily, this is the number of rows of NTK evaluated in a single sweep. The chunk size does not affect the result, up to numerical precision.
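A minimal NumPy sketch of why chunking leaves the result unchanged, using a hypothetical toy model \(\log\Psi(x) = \sum_j \log\cosh(\theta_j x_j)\) whose per-sample gradient is known in closed form. Only the peak memory changes, since each block holds at most chunk_size_bwd rows; NetKet performs its chunking with JAX transformations rather than the explicit Python loop shown here.

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples, n_params = 10, 4
chunk_size_bwd = 3  # need not divide n_samples; the last block is simply smaller

samples = rng.normal(size=(n_samples, n_params))
theta = rng.normal(size=n_params)

def jacobian_block(x_block, theta):
    """Rows of d log psi / d theta for the toy model log psi(x) = sum_j log cosh(theta_j x_j)."""
    return x_block * np.tanh(theta * x_block)

# Jacobian computed in a single pass over all samples
J_full = jacobian_block(samples, theta)

# Same Jacobian assembled from blocks of at most chunk_size_bwd rows
J_chunked = np.concatenate(
    [jacobian_block(samples[i:i + chunk_size_bwd], theta)
     for i in range(0, n_samples, chunk_size_bwd)],
    axis=0,
)

# NTK rows evaluated block by block: NTK[block, :] = J[block] @ J^T
ntk_chunked = np.concatenate(
    [J_full[i:i + chunk_size_bwd] @ J_full.T for i in range(0, n_samples, chunk_size_bwd)],
    axis=0,
)
```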

mode

The mode used to compute the jacobian of the variational state. Can be ‘real’, ‘complex’, or ‘onthefly’.

  • ‘real’ mode truncates the imaginary part of the wavefunction, which is useful for real-valued wavefunctions with a sign.

  • ‘complex’ is the general implementation that always works.

  • ‘onthefly’ uses a lazy implementation of the neural tangent kernel and does not compute the jacobian.

This internally uses netket.jax.jacobian(). See that function for a more complete documentation.

on_the_fly

Whether the QGT or NTK is computed using lazy evaluation methods, which usually requires less memory.

optimizer

The optimizer used to update the parameters at every iteration.

state

Returns the machine that is optimized by this driver.

step_count

Returns a monotonic integer labelling all the steps performed by this driver. This can be used, for example, to identify the line in a log file.

use_ntk

Whether to use the Neural Tangent Kernel (NTK) instead of the Quantum Geometric Tensor (QGT) to compute the update.

target_state: VariationalState

The target variational state \(|\Phi\rangle\).

operator: AbstractOperator

Operator \(U\).

cv_coeff: float

Optional control variate coefficient for variance reduction in the Monte Carlo estimation (see Sinibaldi et al., https://quantum-journal.org/papers/q-2023-10-10-1131/). If None, no control variate is used. Defaults to the optimal value -0.5.
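The effect of the control variate can be illustrated under an assumption not confirmed by this page: that it adds \(c\,(|A(x,y)|^2 - 1)\) to the per-sample fidelity estimator \(\mathrm{Re}\,A(x,y)\), with \(A(x,y) = \frac{\Phi(x)\Psi(y)}{\Psi(x)\Phi(y)}\), \(x \sim |\Psi|^2\) and \(y \sim |\Phi|^2\). Since \(\mathbb{E}[|A|^2] = 1\), the extra term has zero mean and leaves the estimator unbiased. Writing \(\Psi = \Phi(1+\epsilon)\), \(\mathrm{Re}\,A - 1\) and \((|A|^2 - 1)/2\) agree to first order in \(\epsilon\), so \(c = -1/2\) cancels the leading fluctuations near convergence. The NumPy sketch below (toy amplitudes, exact sums over all pairs instead of sampling) checks both properties.

```python
import numpy as np

rng = np.random.default_rng(4)
n = 8  # toy Hilbert space size (illustrative assumption)

phi = rng.normal(size=n) + 1j * rng.normal(size=n)
# A variational state close to the target: Psi = Phi * (1 + small noise)
psi = phi * (1 + 0.01 * (rng.normal(size=n) + 1j * rng.normal(size=n)))

p_psi = np.abs(psi) ** 2 / np.sum(np.abs(psi) ** 2)
p_phi = np.abs(phi) ** 2 / np.sum(np.abs(phi) ** 2)

# A[x, y] = Phi(x) Psi(y) / (Psi(x) Phi(y)); its mean over the joint distribution is the fidelity
A = np.outer(phi / psi, psi / phi)
weights = np.outer(p_psi, p_phi)  # probability of drawing the pair (x, y)

def estimator(c):
    """Assumed per-pair estimator Re A + c (|A|^2 - 1); the added term has zero mean."""
    return A.real + c * (np.abs(A) ** 2 - 1)

def exact_mean_and_var(f):
    """Mean and variance of f under the joint distribution, by exact summation."""
    mean = np.sum(weights * f)
    return mean, np.sum(weights * (f - mean) ** 2)

fidelity = abs(np.vdot(psi, phi)) ** 2 / (np.vdot(psi, psi).real * np.vdot(phi, phi).real)
mean_plain, var_plain = exact_mean_and_var(estimator(0.0))
mean_cv, var_cv = exact_mean_and_var(estimator(-0.5))
print(fidelity, mean_plain, mean_cv, var_plain, var_cv)
```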

diag_shift: Any | Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex]

The diagonal shift \(\lambda\) in the curvature matrix.

This can be a scalar or a schedule. If it is a schedule, it should be a function that takes the current step as input and returns the value of the shift.
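As an illustration of the schedule form, here is a hypothetical diag_shift schedule (the function name and constants are arbitrary choices, not NetKet defaults) that starts at \(10^{-2}\) and decays geometrically towards \(10^{-4}\).

```python
def diag_shift_schedule(step):
    """Hypothetical diag_shift schedule: decays geometrically from 1e-2 towards a floor of 1e-4.

    Only arithmetic operators are used, so the same function works whether `step`
    is a Python int, a NumPy scalar or a JAX array.
    """
    start, floor, rate = 1e-2, 1e-4, 0.99
    return floor + (start - floor) * rate**step
```

Such a callable would be passed as diag_shift=diag_shift_schedule. Optax schedules such as optax.linear_schedule(init_value, end_value, transition_steps) have the same step-to-value form and should be usable in the same way.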

proj_reg: Any | Callable[[Array | ndarray | bool | number | bool | int | float | complex], Array | ndarray | bool | number | bool | int | float | complex]

The regularization parameter for the projection of the updates.

momentum: bool

Flag specifying whether to use momentum in the optimisation.

If True, the optimizer uses momentum to accumulate previous updates following the SPRING optimizer of G. Goldshlager, N. Abrahamsen and L. Lin, giving a better approximation of the exact SR update with no significant performance penalty.
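As a sketch, assuming that a SPRING step solves the regularised least-squares problem \(\min_d \|O d - \varepsilon\|^2 + \lambda \|d - \mu\, d_{k-1}\|^2\) (our reading of the published method; NetKet's implementation may differ in details such as the projection controlled by proj_reg), the update can be computed in sample space as below. Setting \(\mu = 0\) recovers the plain minSR/NTK update. The last line evaluates the amplification formula \(A(\mu) = 1/\sqrt{1-\mu^2}\) quoted for the momentum parameter. O, eps and prev are random toy data.

```python
import numpy as np

rng = np.random.default_rng(5)
n_samples, n_params = 6, 10
lam, mu = 1e-2, 0.8  # diag_shift and momentum (illustrative values)

O = rng.normal(size=(n_samples, n_params)) / np.sqrt(n_samples)  # toy real Jacobian
eps = rng.normal(size=n_samples) / np.sqrt(n_samples)             # toy residual vector
prev = rng.normal(size=n_params)                                  # previous update d_{k-1}

def spring_step(O, eps, prev, lam, mu):
    """Assumed SPRING objective: minimise ||O d - eps||^2 + lam ||d - mu prev||^2 (NTK form)."""
    T = O @ O.T
    r = eps - mu * (O @ prev)  # residual left after reusing the damped previous step
    w = np.linalg.solve(T + lam * np.eye(len(eps)), r)
    return mu * prev + O.T @ w

d = spring_step(O, eps, prev, lam, mu)

# Gradient of the objective at d, which should vanish: O^T (O d - eps) + lam (d - mu prev)
grad = O.T @ (O @ d - eps) + lam * (d - mu * prev)

# Maximum step-size amplification A(mu) = 1 / sqrt(1 - mu^2)
amplification = {m: 1 / np.sqrt(1 - m**2) for m in (0.8, 0.9, 0.99)}
print(amplification)
```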

info: Any | None

PyTree used to pass on information from the solver, e.g. the quadratic model.

Methods
compute_loss_and_update()

Performs a step of the optimization driver, returning the PyTree of the gradients that will be optimized.

Concrete drivers must override this method.

Note

When implementing this function on a subclass, you must return the gradient, whose pytree structure must match that of the parameters of the variational state.

The gradient will then be passed on to the optimizer in order to update the parameters.

Moreover, if you are minimising a loss function you must set the field self._loss_stats with the current value of the loss function.

This will be logged to any logger during optimisation.

Returns:

the update for the weights.

estimate(observables, fullsum=False)

Return MCMC statistics for the expectation value of observables in the current state of the driver.

Parameters:
  • observables – A pytree of operators for which statistics should be computed.

  • fullsum (bool)

Returns:

A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.

replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound= Pytree)

Parameters:
  • self (P)

  • kwargs (Any)

reset()

Deprecated since version 3.22: Use reset_step() to reset the sampler state at the beginning of a step. Note that the old reset() also reset step_count to 0; this behaviour is no longer supported.

reset_step(hard=False)

Resets the state of the driver at the beginning of a new step.

This method is called at the beginning of every step in the optimization.

Parameters:
  • hard (bool) – If True, the reset is a hard reset, resulting in a complete resampling even if resample_fraction is not None.

run(n_iter, out=(), obs=None, step_size=1, show_progress=True, save_params_every=50, write_every=50, callback=None, timeit=False, _graceful_keyboard_interrupt=True)

Runs this variational driver, updating the weights of the network stored in this driver for n_iter steps and dumping values of the observables obs in the output logger.

It is possible to control more specifically what quantities are logged, when to stop the optimisation, or to execute arbitrary code at every step by specifying one or more callbacks, which are passed as a list of functions to the keyword argument callback.

Callbacks are functions that follow this signature:

def callback(step, log_data, driver) -> bool:
    # Inspect or modify log_data here.
    return True  # True continues the optimisation, False stops it

If a callback returns True, the optimisation continues, otherwise it is stopped. The log_data is a dictionary that can be modified in-place to change what is logged at every step. For example, this can be used to log additional quantities such as the acceptance rate of a sampler.
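A minimal sketch of a concrete callback following this signature, which stops the run once a logged quantity falls below a tolerance. The key name "Infidelity" is a hypothetical assumption not confirmed by this page; inspect log_data to find the name this driver actually logs under. The getattr fallback assumes the logged value is either a plain number or a statistics object exposing a mean attribute.

```python
def make_early_stopping(key="Infidelity", tol=1e-3):
    """Return a callback that stops the run once log_data[key] falls below tol.

    `key` is a hypothetical log entry name; check which keys are actually
    present in log_data before relying on it.
    """
    def callback(step, log_data, driver) -> bool:
        value = log_data.get(key)
        if value is None:
            return True  # nothing logged under this key yet: keep going
        # Statistics objects may expose a .mean attribute; plain numbers do not.
        value = getattr(value, "mean", value)
        return bool(abs(value) >= tol)  # True continues, False stops the optimisation
    return callback
```

It would be passed to run as callback=make_early_stopping(tol=1e-4).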

Alternatively, AbstractCallback subclasses can be used to hook into more stages of the loop. To stop the optimisation early from any callback hook, raise StopRun: the driver will catch it, finalise all callbacks via their on_run_end method, and return normally without propagating the exception.

Loggers are specified as an iterable passed to the keyword argument out. If only a string is specified, a nk.logging.JsonLog is created by default; see its documentation for the output format. The logger object is also returned at the end of this function so that you can inspect the results without reading the json output.

Parameters:
  • n_iter (int) – the total number of iterations to be performed during this run.

  • out (Iterable[AbstractLog] | None) – A logger object, or an iterable of loggers, to be used to store simulation log and data. If this argument is a string, it will be used as output prefix for the standard JSON logger.

  • obs (dict[str, AbstractObservable] | None) – A dictionary of all observables that should be computed.

  • step_size (int) – How often, in steps, observables should be logged to disk (default=1).

  • callback (Union[Callable[[int, dict, AbstractDriver], bool], AbstractCallback, None]) – Callable or list of callable callback functions to stop training given a condition

  • show_progress (bool) – If true displays a progress bar (default=True)

  • save_params_every (int) – How often, in steps, the parameters of the network should be serialized to disk (ignored if a logger is provided).

  • write_every (int) – How often, in steps, the json data should be flushed to disk (ignored if a logger is provided).

  • timeit (bool) – If True, provide timing information.

  • _graceful_keyboard_interrupt (bool) – (Internal flag, defaults to True) If True, the driver will gracefully handle a KeyboardInterrupt, usually arising from doing ctrl-C, returning the current state of the simulation. If False, the KeyboardInterrupt will be raised as usual. This only has an effect when running in interactive mode.

update_parameters(dp)

Updates the parameters of the machine using the optimizer in this driver.

Parameters:

dp – the pytree containing the updates to the parameters.