Source code for netket.experimental.driver.tdvp_schmitt

# Copyright 2020, 2021  The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Union, Optional

from functools import partial

import jax
import jax.numpy as jnp

from netket import stats
from netket.driver.vmc_common import info
from netket.operator import AbstractOperator
from netket.optimizer.qgt import QGTJacobianDense
from netket.optimizer.qgt.qgt_jacobian_dense import convert_tree_to_dense_format
from netket.vqs import VariationalState, VariationalMixedState, MCState
from netket.jax import tree_cast

from netket.experimental.dynamics import RKIntegratorConfig


from .tdvp_common import TDVPBaseDriver, odefun


class TDVPSchmitt(TDVPBaseDriver):
    r"""
    Variational time evolution based on the time-dependent variational principle which,
    when used with Monte Carlo sampling via :class:`netket.vqs.MCState`, is the time-dependent VMC
    (t-VMC) method.

    This driver, which only works with standard MCState variational states, uses the regularization
    procedure described in `M. Schmitt's PRL 125 <https://journals.aps.org/prl/pdf/10.1103/PhysRevLett.125.100503>`_.

    With the force vector

    .. math::

        F_k=\langle \mathcal O_{\theta_k}^* E_{loc}^{\theta}\rangle_c

    and the quantum Fisher matrix

    .. math::

        S_{k,k'} = \langle \mathcal O_{\theta_k} (\mathcal O_{\theta_{k'}})^*\rangle_c

    and for real parameters :math:`\theta\in\mathbb R`, the TDVP equation reads

    .. math::

        q\big[S_{k,k'}\big]\dot\theta_{k'} = -q\big[xF_k\big]

    Here, either :math:`q=\text{Re}` or :math:`q=\text{Im}` and :math:`x=1` for ground state
    search or :math:`x=i` (the imaginary unit) for real time dynamics.

    For ground state search a regularization controlled by a parameter :math:`\rho` can be included
    by increasing the diagonal entries and solving

    .. math::

        q\big[(1+\rho\delta_{k,k'})S_{k,k'}\big]\dot\theta_{k'} = -q\big[F_k\big]

    This driver solves the TDVP equation by computing a pseudo-inverse of :math:`S` via
    an eigendecomposition, yielding

    .. math::

        S = V\Sigma V^\dagger

    with a diagonal matrix :math:`\Sigma_{kk}=\sigma_k`.
    Assuming that :math:`\sigma_1` is the largest eigenvalue, the pseudo-inverse is constructed
    from the regularized inverted eigenvalues

    .. math::

        \tilde\sigma_k^{-1}=\frac{1}{\sigma_k\Big(1+\big(\frac{\epsilon_{SVD}}{\sigma_k/\sigma_1}\big)^6\Big)\Big(1+\big(\frac{\epsilon_{SNR}}{\text{SNR}(\rho_k)}\big)^6\Big)}

    with :math:`\text{SNR}(\rho_k)` the signal-to-noise ratio of
    :math:`\rho_k=V_{k,k'}^{\dagger}F_{k'}` (see
    `[arXiv:1912.08828] <https://arxiv.org/pdf/1912.08828.pdf>`_ for details).
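
    To make the role of the two soft cutoffs concrete, the following sketch applies the same
    regularization to a dense matrix using plain `jax.numpy`. All names in it
    (`S`, `F`, `snr`, `rhs_coeff`, `eps_svd`, `eps_snr`) are placeholders used only for
    illustration and are not part of this driver's API:

    .. code-block:: python

        import jax.numpy as jnp

        def regularized_tdvp_update(S, F, snr, rhs_coeff,
                                    rcond=1e-14, eps_svd=1e-8, eps_snr=1.0):
            # S: dense quantum geometric tensor, F: force vector,
            # snr: signal-to-noise ratio of the forces rotated into the eigenbasis,
            # rhs_coeff: right-hand-side coefficient of the TDVP equation
            # (this driver uses -1j for real-time and -1 for imaginary-time evolution).
            ev, V = jnp.linalg.eigh(S)  # eigenvalues in ascending order
            rho = V.conj().T @ F        # forces rotated into the eigenbasis
            # hard cutoff of eigenvalues below numerical precision (rcond)
            ev_inv = jnp.where(jnp.abs(ev / ev[-1]) > rcond, 1.0 / ev, 0.0)
            # soft cutoff of small eigenvalues, controlled by eps_svd (rcond_smooth)
            reg = 1.0 / (1.0 + (eps_svd / jnp.abs(ev / ev[-1])) ** 6)
            # soft cutoff of noisy components, controlled by eps_snr (snr_atol)
            reg = reg * (1.0 / (1.0 + (eps_snr / snr) ** 6))
            # regularized solution, rotated back to parameter space
            return V @ (ev_inv * reg * rhs_coeff * rho)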


    .. note::

        This TDVP driver uses the time integrators from the `nkx.dynamics` module, which are
        automatically executed under a `jax.jit` context.

        When running computations on GPU, this can lead to infinite hangs or extremely long
        compilation times. In those cases, you might try setting the configuration variable
        `nk.config.netket_experimental_disable_ode_jit = True` to mitigate those issues.
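
        For example, the flag can be set at the top of a script, before the driver is
        constructed (a minimal sketch):

        .. code-block:: python

            import netket as nk

            nk.config.netket_experimental_disable_ode_jit = True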

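    A minimal usage sketch (assuming a Hamiltonian `H` and a :class:`netket.vqs.MCState`
    `vstate` built elsewhere; the integrator settings shown are only illustrative):

    .. code-block:: python

        import netket.experimental as nkx

        # illustrative adaptive Runge-Kutta configuration; the step size is problem dependent
        integrator = nkx.dynamics.RK23(dt=0.01, adaptive=True, rtol=1e-4)

        driver = nkx.driver.TDVPSchmitt(
            H, vstate, integrator, propagation_type="real", snr_atol=1.0
        )
        driver.run(T=1.0)
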
    """

    def __init__(
        self,
        operator: AbstractOperator,
        variational_state: VariationalState,
        integrator: RKIntegratorConfig,
        *,
        t0: float = 0.0,
        propagation_type="real",
        holomorphic: Optional[bool] = None,
        diag_shift: float = 0.0,
        diag_scale: Optional[float] = None,
        error_norm: Union[str, Callable] = "qgt",
        rcond: float = 1e-14,
        rcond_smooth: float = 1e-8,
        snr_atol: float = 1,
    ):
        r"""
        Initializes the time evolution driver.

        Args:
            operator: The generator of the dynamics (Hamiltonian for pure states,
                Lindbladian for density operators).
            variational_state: The variational state.
            integrator: Configuration of the algorithm used for solving the ODE.
            t0: Initial time at the start of the time evolution.
            propagation_type: Determines the equation of motion: "real" for the
                real-time Schrödinger equation (SE), "imag" for the imaginary-time SE.
            error_norm: Norm function used to calculate the error with adaptive integrators.
                Can be either "euclidean" for the standard L2 vector norm
                :math:`w^\dagger w`, "maximum" for the maximum norm :math:`\max_i |w_i|`,
                or "qgt", in which case the scalar product induced by the QGT :math:`S` is
                used to compute the norm :math:`\Vert w \Vert^2_S = w^\dagger S w` as
                suggested in PRL 125, 100503 (2020).
                Additionally, it is possible to pass a custom function with signature
                :code:`norm(x: PyTree) -> float` which maps a PyTree of parameters
                :code:`x` to the corresponding norm. Note that the norm is used in
                jax.jit-compiled code.
            holomorphic: A flag to indicate that the wavefunction is holomorphic.
            diag_shift: Diagonal shift of the quantum geometric tensor (QGT).
            diag_scale: If not None, rescales the diagonal shift of the QGT.
            rcond: Cut-off ratio for small singular values :math:`\sigma_k` of the
                Quantum Geometric Tensor. For the purposes of rank determination,
                singular values are treated as zero if they are smaller than rcond
                times the largest singular value :math:`\sigma_{\max}`.
            rcond_smooth: Smooth cut-off ratio for singular values of the Quantum
                Geometric Tensor. This regularization parameter is used with a similar
                effect to `rcond`, but with a softer curve. See :math:`\epsilon_{SVD}`
                in the formula above.
            snr_atol: Noise regularisation absolute tolerance, meaning that eigenvalues
                of the S matrix whose signal-to-noise ratio is below this quantity will
                be (softly) truncated. This is :math:`\epsilon_{SNR}` in the formulas
                above.
        """
        self.propagation_type = propagation_type
        if isinstance(variational_state, VariationalMixedState):
            # assuming Lindblad Dynamics
            # TODO: support density-matrix imaginary time evolution
            if propagation_type == "real":
                self._loss_grad_factor = 1.0
            else:
                raise ValueError(
                    "only real-time Lindblad evolution is supported for mixed states"
                )
        else:
            if propagation_type == "real":
                self._loss_grad_factor = -1.0j
            elif propagation_type == "imag":
                self._loss_grad_factor = -1.0
            else:
                raise ValueError("propagation_type must be one of 'real', 'imag'")

        self.rcond = rcond
        self.rcond_smooth = rcond_smooth
        self.snr_atol = snr_atol

        self.diag_shift = diag_shift
        self.holomorphic = holomorphic
        self.diag_scale = diag_scale

        super().__init__(
            operator, variational_state, integrator, t0=t0, error_norm=error_norm
        )

    def info(self, depth=0):
        lines = [
            f"{name}: {info(obj, depth=depth + 1)}"
            for name, obj in [
                ("generator ", self._generator_repr),
                ("integrator ", self._integrator),
                ("state ", self.state),
            ]
        ]
        return "\n{}".format(" " * 3 * (depth + 1)).join([str(self), *lines])


# Copyright notice:
# The function `_impl` below includes lines copied from the jVMC repository
# found at github.com/markusschmitt/vmc_jax and licensed according to
# MIT License, Copyright (c) 2021 Markus Schmitt


@partial(jax.jit, static_argnames="n_samples")
def _impl(parameters, n_samples, E_loc, S, rhs_coeff, rcond, rcond_smooth, snr_atol):
    E = stats.statistics(E_loc)
    ΔE_loc = E_loc.reshape(-1, 1) - E.mean

    stack_jacobian = S.mode == "complex"

    O = S.O / jnp.sqrt(n_samples)  # already divided by jnp.sqrt(n_s)
    if stack_jacobian:
        O = O.reshape(-1, 2, S.O.shape[-1])
        O = O[:, 0, :] + 1j * O[:, 1, :]

    Sd = S.to_dense()
    ev, V = jnp.linalg.eigh(Sd)

    OEdata = O.conj() * ΔE_loc
    F = stats.sum(OEdata, axis=0)

    # Note: this implementation differs from Eq. 20 in Markus's paper, which I would
    # implement as `rho = mpi.mean(QEdata, axis=0)`. However, this is different from
    # changing the basis AFTER averaging over the samples, and leads to the wrong
    # normalisation of rho.
    Q = jnp.tensordot(V.conj().T, O.T, axes=1).T
    QEdata = Q.conj() * ΔE_loc
    rho = V.conj().T @ F

    # Compute the SNR according to Eq. 21
    snr = jnp.abs(rho) * jnp.sqrt(n_samples) / jnp.sqrt(stats.var(QEdata, axis=0))

    # Discard eigenvalues below numerical precision
    ev_inv = jnp.where(jnp.abs(ev / ev[-1]) > rcond, 1.0 / ev, 0.0)

    # Set regularizer for singular value cutoff
    regularizer = 1.0 / (1.0 + (rcond_smooth / jnp.abs(ev / ev[-1])) ** 6)

    # Construct a soft cutoff based on the SNR
    regularizer2 = regularizer * (1.0 / (1.0 + (snr_atol / snr) ** 6))

    # solve the linear system by hand
    eta_p = ev_inv * regularizer2 * rhs_coeff * rho
    # convert back to the parameter space
    update = V @ eta_p

    # remainder of the solution
    rmd = jnp.linalg.norm(Sd.dot(update) - rhs_coeff * F) / jnp.linalg.norm(F)

    y, reassemble = convert_tree_to_dense_format(parameters, S.mode)
    update_tree = reassemble(update if jnp.iscomplexobj(y) else update.real)

    # If parameters are real, then take only the real part of the gradient (if it's complex)
    dw = tree_cast(update_tree, parameters)

    return E, dw, rmd, snr


@odefun.dispatch
def odefun_schmitt(state: MCState, self: TDVPSchmitt, t, w, *, stage=0):  # noqa: F811
    # pylint: disable=protected-access

    state.parameters = w
    state.reset()

    op_t = self.generator(t)
    E_loc = state.local_estimators(op_t)

    self._S = QGTJacobianDense(
        state,
        diag_shift=self.diag_shift,
        diag_scale=self.diag_scale,
        holomorphic=self.holomorphic,
    )

    self._loss_stats, self._dw, self._rmd, self._snr = _impl(
        state.parameters,
        state.n_samples,
        E_loc,
        self._S,
        self._loss_grad_factor,
        self.rcond,
        self.rcond_smooth,
        self.snr_atol,
    )

    if stage == 0:  # TODO: This does not work with FSAL.
        self._last_qgt = self._S

    return self._dw


@partial(jax.jit, static_argnums=(3, 4))
def _map_parameters(forces, parameters, loss_grad_factor, propagation_type, state_T):
    forces = jax.tree_map(
        lambda x, target: loss_grad_factor * x,
        forces,
        parameters,
    )

    forces = tree_cast(forces, parameters)

    return forces