Source code for netket.optimizer.qgt.default

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from functools import partial

import jax

import netket.jax as nkjax

from .qgt_jacobian_dense import QGTJacobianDense
from .qgt_jacobian_pytree import QGTJacobianPyTree
from .qgt_onthefly import QGTOnTheFly

from .. import solver as nk_solver_module

# Collect every public attribute of the solver module. Membership in this
# list is what `_is_dense_solver` uses to recognise a solver.
solvers = []

for solver in dir(nk_solver_module):
    # Dunder attributes (``__name__``, ``__doc__``, ...) are module
    # machinery, not solvers; everything else is collected as-is.
    if not solver.startswith("__"):
        solvers.append(getattr(nk_solver_module, solver))


def _is_dense_solver(solver: Any) -> bool:
    """
    Tell whether `solver` is one of the objects collected in `solvers`.

    A `functools.partial` is unwrapped first, so a solver with pre-bound
    arguments is recognised through its underlying function.

    Args:
        solver: The candidate solver (any object).

    Returns:
        True if the (unwrapped) solver is found in the module-level
        `solvers` list, False otherwise.
    """
    candidate = solver.func if isinstance(solver, partial) else solver
    return candidate in solvers


def default_qgt_matrix(variational_state, solver=False, **kwargs):
    """
    Pick a default QGT implementation for the given state and solver.

    Args:
        variational_state: The variational state; its `parameters` and
            `n_parameters` attributes are inspected.
        solver: The solver that will be used with the QGT. If it is
            recognised by `_is_dense_solver`, the dense Jacobian
            implementation is selected.
        kwargs: Bound as keyword arguments to the selected QGT constructor.

    Returns:
        A `functools.partial` wrapping one of `QGTJacobianDense`,
        `QGTJacobianPyTree` or `QGTOnTheFly`, with `kwargs` pre-bound.
    """
    from netket.vqs import FullSumState

    # Full-summation states always get the PyTree Jacobian implementation.
    if isinstance(variational_state, FullSumState):
        return partial(QGTJacobianPyTree, **kwargs)

    leaf_count = len(jax.tree_util.tree_leaves(variational_state.parameters))
    param_count = variational_state.n_parameters

    # These solvers need a dense matrix, for which this QGT is known to
    # be faster.
    if _is_dense_solver(solver):
        return partial(QGTJacobianDense, **kwargs)

    # TODO: Remove this once all QGT support diag_scale.
    uses_diag_scale = kwargs.get("diag_scale", None) is not None

    # Arbitrary heuristic: when the parameters have many leaves (an RBM
    # has 3) a Jacobian-based QGT might be faster. The thresholds below
    # are rather arbitrary and should be tuned.
    prefers_jacobian = (leaf_count > 6 and param_count > 800) or uses_diag_scale

    if not prefers_jacobian:
        return partial(QGTOnTheFly, **kwargs)

    if nkjax.tree_ishomogeneous(variational_state.parameters):
        return partial(QGTJacobianDense, **kwargs)

    return partial(QGTJacobianPyTree, **kwargs)


class QGTAuto:
    """
    Automatically select the 'best' Quantum Geometric Tensor
    computing format according to some rather untested heuristic.

    The instance is callable: calling it with a variational state selects
    a QGT implementation through `default_qgt_matrix` and constructs it.

    Args:
        solver: The solver forwarded to `default_qgt_matrix` when
            selecting the QGT implementation.
        kwargs: are passed on to the QGT constructor.
    """

    _last_vstate = None
    """Cached last variational state to skip logic to decide what type of
    QGT to choose.
    """

    _last_matrix = None
    """
    Cached last QGT. Used when vstate == _last_vstate
    """

    _kwargs = {}
    """
    Kwargs passed at construction. Used when constructing a QGT.
    """

    def __init__(self, solver=None, **kwargs):
        self._solver = solver
        self._kwargs = kwargs

    def __call__(self, variational_state, *args, **kwargs):
        """
        Construct the QGT for `variational_state`.

        Args:
            variational_state: The variational state to build the QGT for.
            args: Extra positional arguments forwarded to the QGT
                constructor.
            kwargs: Extra keyword arguments forwarded to the QGT
                constructor. They override same-named keyword arguments
                given at construction.

        Returns:
            Whatever the selected QGT constructor returns.
        """
        if self._last_vstate != variational_state:
            # Merge instead of unpacking both dicts into the call: a key
            # present in both would otherwise raise a TypeError. Call-time
            # values win, consistent with how the cached partial resolves
            # overlapping keywords below.
            merged_kwargs = {**self._kwargs, **kwargs}

            selected = default_qgt_matrix(
                variational_state, solver=self._solver, **merged_kwargs
            )

            # Update the cache only once the selection succeeded, so a
            # failure never pairs the new state with a stale constructor.
            # NOTE: the cached constructor keeps the keyword arguments of
            # the call that populated it until the state changes.
            self._last_matrix = selected
            self._last_vstate = variational_state

        return self._last_matrix(variational_state, *args, **kwargs)

    def __repr__(self):
        return "QGTAuto()"