netket.optimizer.SR

class netket.optimizer.SR

Bases: netket.optimizer.preconditioner.AbstractLinearPreconditioner

Stochastic Reconfiguration or Natural Gradient preconditioner for the gradient.

Constructs the structure holding the parameters for using the Stochastic Reconfiguration/Natural gradient method.

This preconditioner changes the gradient \(\nabla_i E\) such that the preconditioned gradient \(\Delta_j\) solves the system of equations

\[(S_{i,j} + \delta_{i,j}(\epsilon_1 S_{i,i} + \epsilon_2)) \Delta_{j} = \nabla_i E\]

where \(S\) is the Quantum Geometric Tensor (or Fisher Information Matrix), regularised according to the diagonal scale \(\epsilon_1\) (diag_scale) and the diagonal shift \(\epsilon_2\) (diag_shift). The default regularisation takes \(\epsilon_1=0\) and \(\epsilon_2=0.01\).
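As a rough illustration of the regularised system above, the following pure-Python sketch builds \(S + \mathrm{diag}(\epsilon_1 S_{i,i} + \epsilon_2)\) for a toy 2x2 matrix and solves for \(\Delta\). The helper names `regularised_matrix` and `solve_2x2` are hypothetical and not part of NetKet, and the matrix and gradient values are made up for the example.

```python
def regularised_matrix(S, diag_scale=0.0, diag_shift=0.01):
    """Return S + diag(diag_scale * S_ii + diag_shift) for a square nested list S."""
    n = len(S)
    return [
        [
            S[i][j] + ((diag_scale * S[i][i] + diag_shift) if i == j else 0.0)
            for j in range(n)
        ]
        for i in range(n)
    ]


def solve_2x2(A, b):
    """Solve the 2x2 linear system A x = b with Cramer's rule."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    x0 = (b[0] * A[1][1] - A[0][1] * b[1]) / det
    x1 = (A[0][0] * b[1] - b[0] * A[1][0]) / det
    return [x0, x1]


# A singular toy S matrix: without regularisation there is no unique solution.
S = [[1.0, 1.0], [1.0, 1.0]]
grad = [1.0, 0.0]

# Default regularisation from the text: epsilon_1 = 0, epsilon_2 = 0.01.
A = regularised_matrix(S, diag_scale=0.0, diag_shift=0.01)
delta = solve_2x2(A, grad)

# Alternative: a shift proportional to the diagonal of S (epsilon_1 = 0.1).
A_scaled = regularised_matrix(S, diag_scale=0.1, diag_shift=0.0)

print(delta)
```

The diagonal shift makes the otherwise singular matrix invertible, which is the reason a regularisation is applied before solving.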

Depending on the arguments, an implementation is chosen. For details on all possible kwargs check the specific SR implementations in the documentation.

You can also construct one of those structures directly.

Parameters
  • qgt (Optional[Callable]) – The Quantum Geometric Tensor type to use.

  • solver (Callable) – The method used to solve the linear system. Must be a JAX-jittable function taking as input a pytree and outputting a tuple of the solution and extra data.

  • diag_shift (Union[Any, Callable[[Union[Array, float, int]], Union[Array, float, int]]]) – (Default 0.01) Diagonal shift added to the S matrix. Can be a Scalar value, an optax schedule or a Callable function.

  • diag_scale (Union[Any, Callable[[Union[Array, float, int]], Union[Array, float, int]], None]) –

    (Default 0) Scale of the shift proportional to the diagonal of the S matrix that is added to it. Can be a Scalar value, an optax schedule or a Callable function.

  • solver_restart (bool) – If False, uses the last solution of the linear system as a starting point for the solution of the next (default=False).

  • holomorphic – boolean indicating if the ansatz is holomorphic or not. May speed up computations for models with complex-valued parameters.
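The diag_shift and diag_scale entries above accept either a scalar or a schedule, that is, a callable mapping the step number to a value. A minimal sketch of that idea in plain Python follows. The names `exponential_decay_shift` and `resolve` are hypothetical, and how NetKet evaluates a schedule internally is an assumption here, not something this page states.

```python
def exponential_decay_shift(step, init_value=0.01, decay_rate=0.9, end_value=1e-4):
    """Hypothetical schedule: the shift decays geometrically with the step, down to a floor."""
    return max(end_value, init_value * decay_rate**step)


def resolve(value, step):
    """Return value(step) for a schedule or callable, or value itself for a scalar."""
    return value(step) if callable(value) else value


# A constant shift ignores the step; a schedule is evaluated at it.
for step in (0, 10, 100):
    print(step, resolve(0.01, step), resolve(exponential_decay_shift, step))
```

A decaying shift of this kind applies a strong regularisation early in the optimisation and relaxes it later; whether that helps depends on the problem.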

Inheritance
Inheritance diagram of netket.optimizer.SR
__init__(qgt=None, solver=<function cg>, *, diag_shift=0.01, diag_scale=None, solver_restart=False, **kwargs)

Constructs the structure holding the parameters for using the Stochastic Reconfiguration/Natural gradient method.

Depending on the arguments, an implementation is chosen. For details on all possible kwargs check the specific SR implementations in the documentation.

You can also construct one of those structures directly.

Parameters
  • qgt (Optional[Callable]) – The Quantum Geometric Tensor type to use.

  • solver (Callable) – The method used to solve the linear system. Must be a JAX-jittable function taking as input a pytree and outputting a tuple of the solution and extra data.

  • diag_shift (Union[Any, Callable[[Union[Array, float, int]], Union[Array, float, int]]]) –

    (Default 0.01) Diagonal shift added to the S matrix. Can be a Scalar value, an optax schedule or a Callable function.

  • diag_scale (Union[Any, Callable[[Union[Array, float, int]], Union[Array, float, int]], None]) –

    (Default 0) Scale of the shift proportional to the diagonal of the S matrix that is added to it. Can be a Scalar value, an optax schedule or a Callable function.

  • solver_restart (bool) – If False, uses the last solution of the linear system as a starting point for the solution of the next (default=False).

  • holomorphic – boolean indicating if the ansatz is holomorphic or not. May speed up computations for models with complex-valued parameters.

Attributes
diag_scale: Optional[Union[Any, Callable[[Union[jax.Array, float, int]], Union[jax.Array, float, int]]]] = None
diag_shift: Union[Any, Callable[[Union[jax.Array, float, int]], Union[jax.Array, float, int]]] = 0.01
info: Any = None

Additional information returned by the solver when solving the last linear system.

qgt_constructor: Callable = None
qgt_kwargs: dict = None
solver_restart: bool = False

If False, uses the last solution of the linear system as a starting point for the solution of the next.

x0: Optional[PyTree] = None

Solution of the last linear system solved.

solver: SolverT

Function used to solve the linear system.
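The solver, x0 and solver_restart entries above describe an iterative solver that returns a tuple of the solution and extra data, and that may start from the previous solution. The sketch below illustrates both ideas with a toy Jacobi iteration in pure Python. It is not the default cg solver, and the exact call signature used here, `(A, b, x0=None)` with a nested-list matrix, is an assumption made for illustration only.

```python
def jacobi_solver(A, b, x0=None, tol=1e-10, maxiter=1000):
    """Toy iterative solver returning (solution, info).

    A is a square nested list with a non-zero diagonal, b is a list, and x0 is
    an optional starting point (zeros when omitted).
    """
    n = len(b)
    x = list(x0) if x0 is not None else [0.0] * n

    def residual_norm(v):
        r = [b[i] - sum(A[i][j] * v[j] for j in range(n)) for i in range(n)]
        return sum(ri * ri for ri in r) ** 0.5

    iterations = 0
    while residual_norm(x) > tol and iterations < maxiter:
        # Jacobi update: solve each equation for its own unknown.
        x = [
            (b[i] - sum(A[i][j] * x[j] for j in range(n) if j != i)) / A[i][i]
            for i in range(n)
        ]
        iterations += 1
    return x, {"iterations": iterations, "residual_norm": residual_norm(x)}


A = [[4.0, 1.0], [1.0, 3.0]]  # stands in for the regularised S matrix

# First linear system, solved from scratch.
x_prev, _ = jacobi_solver(A, [1.0, 2.0])

# Next system with a slightly different right-hand side.
b_next = [1.01, 2.0]
_, cold = jacobi_solver(A, b_next)                  # start from zeros
x_next, warm = jacobi_solver(A, b_next, x0=x_prev)  # start from the last solution

print(cold["iterations"], warm["iterations"])
```

When consecutive gradients change slowly, the previous solution is already close to the new one, so the warm-started solve needs fewer iterations to reach the same tolerance.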

Methods
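__call__(vstate, gradient, step=None)

Applies the preconditioner: solves the linear system for the variational state vstate and the given gradient, and returns the preconditioned gradient. If given, step is the optimisation step at which schedule-valued diag_shift and diag_scale are evaluated.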

Return type

Any

Parameters
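lhs_constructor(vstate, step=None)

Constructs the left-hand side of the linear system, that is, the regularised S matrix for the variational state vstate, evaluating diag_shift and diag_scale at step when they are schedules.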

Parameters