Source code for netket.experimental.driver.tdvp

# Copyright 2020, 2021  The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union
from functools import partial

import jax

import netket as nk
from netket.driver.vmc_common import info
from netket.operator import AbstractOperator
from netket.optimizer import LinearOperator
from netket.optimizer.qgt import QGTAuto
from netket.vqs import (
    VariationalState,
    VariationalMixedState,
    MCState,
    FullSumState,
)
from netket.jax import tree_cast

from netket.experimental.dynamics import RKIntegratorConfig

from .tdvp_common import TDVPBaseDriver, odefun


class TDVP(TDVPBaseDriver):
    """
    Variational time evolution based on the time-dependent variational principle which,
    when used with Monte Carlo sampling via :class:`netket.vqs.MCState`, is the time-dependent VMC
    (t-VMC) method.

    .. note::
        This TDVP Driver uses the time-integrators from the `nkx.dynamics` module, which are
        automatically executed under a `jax.jit` context.

        When running computations on GPU, this can lead to infinite hangs or extremely long
        compilation times. In those cases, you might try setting the configuration variable
        `nk.config.netket_experimental_disable_ode_jit = True` to mitigate those issues.

    """

[docs] def __init__( self, operator: AbstractOperator, variational_state: VariationalState, integrator: RKIntegratorConfig, *, t0: float = 0.0, propagation_type="real", qgt: LinearOperator = None, linear_solver=nk.optimizer.solver.pinv_smooth, linear_solver_restart: bool = False, error_norm: Union[str, Callable] = "euclidean", ): r""" Initializes the time evolution driver. Args: operator: The generator of the dynamics (Hamiltonian for pure states, Lindbladian for density operators). variational_state: The variational state. integrator: Configuration of the algorithm used for solving the ODE. t0: Initial time at the start of the time evolution. propagation_type: Determines the equation of motion: "real" for the real-time Schrödinger equation (SE), "imag" for the imaginary-time SE. qgt: The QGT specification. linear_solver: The solver for solving the linear system determining the time evolution. This must be a jax-jittable function :code:`f(A,b) -> x` that accepts a Matrix-like, Linear Operator PyTree object :math:`A` and a vector-like PyTree :math:`b` and returns the PyTree :math:`x` solving the system :math:`Ax=b`. Defaults to :func:`nk.optimizer.solver.pinv_smooth` with the default svd threshold of 1e-10. To change the svd threshold you can use :func:`functools.partial` as follows: :code:`functools.partial(nk.optimizer.solver.pinv_smooth, rcond_smooth=1e-5)`. linear_solver_restart: If False (default), the last solution of the linear system is used as initial value in subsequent steps. error_norm: Norm function used to calculate the error with adaptive integrators. Can be either "euclidean" for the standard L2 vector norm :math:`w^\dagger w`, "maximum" for the maximum norm :math:`\max_i |w_i|` or "qgt", in which case the scalar product induced by the QGT :math:`S` is used to compute the norm :math:`\Vert w \Vert^2_S = w^\dagger S w` as suggested in PRL 125, 100503 (2020). Additionally, it possible to pass a custom function with signature :code:`norm(x: PyTree) -> float` which maps a PyTree of parameters :code:`x` to the corresponding norm. Note that norm is used in jax.jit-compiled code. """ if qgt is None: qgt = QGTAuto(solver=linear_solver) self.propagation_type = propagation_type if isinstance(variational_state, VariationalMixedState): # assuming Lindblad Dynamics # TODO: support density-matrix imaginary time evolution if propagation_type == "real": self._loss_grad_factor = 1.0 else: raise ValueError( "only real-time Lindblad evolution is supported for " "mixed states" ) else: if propagation_type == "real": self._loss_grad_factor = -1.0j elif propagation_type == "imag": self._loss_grad_factor = -1.0 else: raise ValueError("propagation_type must be one of 'real', 'imag'") self.qgt = qgt self.linear_solver = linear_solver self.linear_solver_restart = linear_solver_restart super().__init__( operator, variational_state, integrator, t0=t0, error_norm=error_norm )
[docs] def info(self, depth=0): lines = [ f"{name}: {info(obj, depth=depth + 1)}" for name, obj in [ ("generator ", self._generator_repr), ("integrator ", self._integrator), ("linear solver ", self.linear_solver), ("state ", self.state), ] ] return "\n{}".format(" " * 3 * (depth + 1)).join([str(self), *lines])
@odefun.dispatch def odefun_tdvp( # noqa: F811 state: Union[MCState, FullSumState], driver: TDVP, t, w, *, stage=0 ): # pylint: disable=protected-access state.parameters = w state.reset() op_t = driver.generator(t) driver._loss_stats, driver._loss_forces = state.expect_and_forces( op_t, ) driver._loss_grad = _map_parameters( driver._loss_forces, state.parameters, driver._loss_grad_factor, driver.propagation_type, type(state), ) qgt = driver.qgt(driver.state) if stage == 0: # TODO: This does not work with FSAL. driver._last_qgt = qgt initial_dw = None if driver.linear_solver_restart else driver._dw driver._dw, _ = qgt.solve(driver.linear_solver, driver._loss_grad, x0=initial_dw) # If parameters are real, then take only real part of the gradient (if it's complex) driver._dw = tree_cast(driver._dw, state.parameters) return driver._dw @partial(jax.jit, static_argnums=(3, 4)) def _map_parameters(forces, parameters, loss_grad_factor, propagation_type, state_T): forces = jax.tree_map( lambda x, target: loss_grad_factor * x, forces, parameters, ) forces = tree_cast(forces, parameters) return forces