Defining Custom Preconditioners#

NetKet uses the term gradient preconditioner for the class of techniques that transform the gradient of the cost function, in order to improve convergence, before passing it to an optimiser like netket.optimizer.Sgd() or netket.optimizer.Adam().

Examples of gradient preconditioners are the Stochastic Reconfiguration (SR) method (also known as Natural Gradient Descent in the ML literature), the Linear Method, and second-order Hessian-based optimisation.

We call those methods gradient preconditioners because they take as input the gradient of the cost function (e.g., the energy gradient in VMC ground state optimisation) and output a transformed gradient.

In the current version, NetKet provides Stochastic Reconfiguration (netket.optimizer.SR()) as a built-in method, but it is also possible to define your own. If you implement the API outlined in this document, you will be able to use your own preconditioner with NetKet's optimisation drivers without issues.
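For reference, this is how the built-in preconditioner plugs into an optimisation driver. The sketch below assumes a Hamiltonian, an optimiser and a variational state vstate have already been constructed as in the standard tutorials, and the diag_shift value is arbitrary:

import netket as nk

# built-in Stochastic Reconfiguration preconditioner
sr = nk.optimizer.SR(diag_shift=0.01)

# the driver calls the preconditioner at every step to transform the energy
# gradient before handing it to the optimiser
driver = nk.driver.VMC(hamiltonian, optimizer, variational_state=vstate, preconditioner=sr)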

Keep in mind that writing your own optimisation loop only requires about 10 lines of code, and you are not forced to use NetKet’s drivers! We believe our design to be fairly modular and flexible, so it should accommodate a wide variety of use-cases, but there will be algorithms that are hard to express within the boundaries of our API. Do not turn away from NetKet in those cases: take the pieces that you need and write your own optimisation loop!

The preconditioner interface#

Preconditioners must be implemented as a Callable object or function with the following signature:

def preconditioner(vstate: VariationalState, gradient: PyTree) -> PyTree:

The Callable must accept two (positional) inputs, where the first is the variational state itself and the latter is the current gradient, stored as a PyTree. The output of the preconditioner must be the transformed gradient stored as a PyTree.

This general API will allow you to implement any preconditioner and use it together with NetKet’s variational optimization drivers. However, note that any performance optimisation (such as calling jax.jit() on the code) will be your responsibility.
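As a minimal sketch, the following function rescales every leaf of the gradient PyTree by a constant factor. The function name and the scale parameter are purely illustrative; only the two positional arguments are required by the interface:

import jax

def rescaling_preconditioner(vstate, gradient, scale=0.01):
    # the variational state is ignored here, but more elaborate preconditioners
    # can use it (for instance to estimate curvature information from its samples)
    return jax.tree_util.tree_map(lambda g: scale * g, gradient)

Such a function can then be passed directly wherever a driver expects a preconditioner.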

The LinearPreconditioner interface#

Several preconditioners, including Stochastic Reconfiguration, transform the gradient by solving a linear system of equations

\[ S \bf{x} = \bf{F} \]

where \( S \) is a linear operator, \( \bf{F} \) is the gradient and the solution \( \bf{x} \) is the preconditioned gradient.

NetKet implements a basic interface called netket.optimizer.LinearPreconditioner to make it easier to implement this kind of solver. It is especially tuned for the case where \( S \) is a linear operator.

To construct a netket.optimizer.LinearPreconditioner you must supply two objects: the lhs_constructor, a function or closure with signature (VariationalState) -> LinearOperator that takes the variational state and constructs the linear operator associated with it, and a linear solver function, which must accept the linear operator and the gradient and compute the solution. The gradient is always provided as a PyTree.

To give a concrete example: in the case of the Stochastic Reconfiguration (SR) method, if we call \( \bf{F} \) the gradient of the energy and \( S \) the Quantum Geometric Tensor (also known as the SR matrix), we need to solve the system of equations \( S d\bf{w} = \bf{F} \) to compute the preconditioned gradient. The \( S \) matrix is in this case the lhs, or LinearOperator, of the preconditioner, while the solver can be any linear solver such as a Cholesky-based one, jax.numpy.linalg.solve(), or an iterative solver such as jax.scipy.sparse.linalg.cg().
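The sketch below shows how these pieces could be composed, assuming that netket.optimizer.LinearPreconditioner accepts the lhs constructor and the solver as its first two positional arguments; the diag_shift and maxiter values are arbitrary examples:

from functools import partial

import jax
import netket as nk

# lhs_constructor: builds the S matrix (here the Quantum Geometric Tensor)
# from the variational state
lhs_constructor = partial(nk.optimizer.qgt.QGTJacobianPyTree, diag_shift=0.01)

# solver: a function solving S x = F given the linear operator and the gradient
solver = partial(jax.scipy.sparse.linalg.cg, maxiter=100)

preconditioner = nk.optimizer.LinearPreconditioner(lhs_constructor, solver)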

As there are different ways to compute the \( S \) matrix, each with different computational performance characteristics, and different solvers to pair them with, we believe this design makes the code more modular and easier to reason about.

When defining a preconditioner object you have two options: you can implement the bare API, which gives you maximum freedom but makes you responsible for all optimisations, or you can implement the netket.optimizer.LinearOperator interface, which constrains you a bit but takes care of a few performance optimisations for you.

Bare interface#

The bare-minimum API a preconditioner lhs must implement is the following:

- It must be a class

- There must be a function to build it from a variational state. This function will be called with the variational state as the first positional argument, and it need not be a method of the class.

- The class must have a `solve(self, function, gradient, *, x0=None)` method taking as arguments the preconditioner function and the gradient to be preconditioned, and it must not error if the keyword argument `x0` is passed to it. `x0` is the output of `solve` the last time it was called, and may be ignored if not needed. `function` is the preconditioner function that performs the actual computation.

You can subclass the abstract base class netket.optimizer.PreconditionerObject to be sure that you are implementing the correct interface, but you are not obliged to subclass it.

When you implement this interface you are left with maximum flexibility; however, you will be responsible for jax.jit()-compiling all computationally intensive methods (most likely solve).

def MyObject(variational_state):
    # constructor function: extract whatever data the preconditioner needs
    # from the variational state
    stuff_a, stuff_b = compute_stuff(variational_state)
    return MyObjectT(stuff_a, stuff_b)


class MyObjectT:
    def __init__(self, a, b):
        # store the data needed to apply the preconditioner
        self.a = a
        self.b = b

    def solve(self, preconditioner_function, y, *, x0=None):
        # prepare any data the preconditioner function needs
        ...
        # delegate the actual computation to the preconditioner function
        return preconditioner_function(self, y, x0=x0)

Be warned that if you want to jax.jit()-compile the solve method, as it is usually computationally intensive, you must either specify how to flatten and unflatten your MyObjectT to and from a PyTree, or mark it as a flax.struct.dataclass, a frozen dataclass that does this automatically. Since you cannot write the __init__ method of a frozen dataclass, we usually define a constructor function as shown above.
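As a sketch, assuming the stored fields are arrays (the field names and the compute_stuff helper are hypothetical), the same object written as a frozen flax dataclass with a jit-compiled solve could look like this:

from functools import partial

import jax
import flax


@flax.struct.dataclass
class MyJittableObjectT:
    a: jax.numpy.ndarray
    b: jax.numpy.ndarray

    # the function argument is not a JAX type, so it is marked as static
    @partial(jax.jit, static_argnums=1)
    def solve(self, preconditioner_function, y, *, x0=None):
        # self is a frozen dataclass registered as a PyTree, so jit can trace
        # through its fields
        return preconditioner_function(self, y, x0=x0)


def MyJittableObject(variational_state):
    # constructor function playing the role of __init__
    stuff_a, stuff_b = compute_stuff(variational_state)
    return MyJittableObjectT(stuff_a, stuff_b)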

You might be asking why each object needs a solve method to which the preconditioner function is passed, instead of the other way around. The reason is to invert the control: preconditioner functions must obey a certain API, but even when they do, different objects might need to perform different initialisation to compute the preconditioned gradient more efficiently. This architecture allows every object to run arbitrary logic before executing the preconditioner function of choice. Concrete examples of this approach can be seen in the implementations of netket.optimizer.qgt.QGTJacobianDense() and netket.optimizer.qgt.QGTJacobianPyTree().

LinearOperator interface#

You can also subclass LinearOperator. A LinearOperator must be a flax dataclass, which is an immutable object (therefore after construction you cannot modify its attributes).

LinearOperators have several convenience methods and act like matrices: you can right-multiply them by a PyTree vector or a dense vector, and you can obtain their dense representation. The interface also automatically jit-compiles all computationally intensive operations.

To implement the LinearOperator interface you should implement the following methods:


from typing import Any, Optional

import jax
import flax
import netket as nk

# type alias for nested parameter structures (gradients, parameters, ...)
PyTree = Any


def MyLinearOperator(variational_state):
    # constructor function: extract the data needed to apply the operator
    stuff_a, stuff_b = compute_stuff(variational_state)
    return MyLinearOperatorT(stuff_a, stuff_b)


@flax.struct.dataclass
class MyLinearOperatorT(nk.optimizer.LinearOperator):

    @jax.jit
    def __matmul__(self, y):
        # prepare
        ...
        # compute the matrix-vector product of the operator with the PyTree y
        return result

    @jax.jit
    def to_dense(self):
        # assemble and return the dense matrix representation of the operator
        ...
        return dense_matrix

    # optional
    @jax.jit
    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None) -> PyTree:
        # run any setup needed, then delegate to the solve function
        ...
        return solution, solution_info

The bare minimum to implement is __matmul__, specifying how to multiply the linear operator by a PyTree. You can also define a custom _solve method if you have some computationally intensive setup code you wish to run before executing a solve function (which will call __matmul__ repeatedly). _solve takes as first input the solve function, which is passed as a closure so it does not need to be marked as static (even though it effectively is). x0 is an optional argument that must be accepted but may be ignored; it is the previous solution of the linear system. Optionally, one can also define the to_dense method.
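Once defined, such an operator can be used like a matrix. The usage sketch below assumes vstate is a variational state, gradient is a PyTree with the same structure as its parameters, and the base class exposes a solve method that delegates to _solve when it is defined:

import jax

S = MyLinearOperator(vstate)

Sx = S @ gradient        # matrix-vector product through __matmul__
S_dense = S.to_dense()   # dense matrix representation

# solve S x = gradient with an iterative solver
x, info = S.solve(jax.scipy.sparse.linalg.cg, gradient)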

The preconditioner function API#

The preconditioner function must have the following signature:

def preconditioner_function(object, gradient, *, x0=None):
    #...
    return preconditioned_gradient, x0

The object that will be passed is the selected preconditioner object, previously constructed. The gradient is the gradient of the loss function to precondition. x0 is an optional initial condition that might be ignored.

The gradient might be given as a PyTree or as a dense ravelling of the PyTree; the function should return a preconditioned gradient in the same format. Additional keyword arguments can be present, but they will in general have to be set through a closure or functools.partial(), because the function will always be called with the signature above. If you have a peculiar preconditioner you can assume that preconditioner_function will only ever be called from your preconditioner object, but in general it is good practice to respect the interface above so that different functions can work with different objects.
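As a sketch, a preconditioner function built on top of JAX's conjugate-gradient solver might look as follows. It assumes the preconditioner object supports matrix-vector products through __matmul__, as the LinearOperators described above do; the maxiter keyword is an example of an extra option fixed through functools.partial:

from functools import partial

import jax


def cg_preconditioner_function(object, gradient, *, x0=None, maxiter=None):
    # solve S x = gradient iteratively, using the object's __matmul__ to apply S
    solution, info = jax.scipy.sparse.linalg.cg(
        lambda v: object @ v, gradient, x0=x0, maxiter=maxiter
    )
    # return the preconditioned gradient and the new initial guess for the next call
    return solution, solution


# fix the extra keyword so the function matches the interface above
my_preconditioner_function = partial(cg_preconditioner_function, maxiter=100)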