Source code for netket.optimizer.qgt.qgt_jacobian_pytree

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union
from functools import partial

import jax
from jax import numpy as jnp
from flax import struct

from netket.utils.types import PyTree, Array
from netket.utils import mpi, timing
import netket.jax as nkjax
from netket.nn import split_array_mpi

from ..linear_operator import LinearOperator, Uninitialized

from .common import check_valid_vector_type
from .qgt_jacobian_pytree_logic import mat_vec
from .qgt_jacobian_common import (
    sanitize_diag_shift,
    to_shift_offset,
    rescale,
)


[docs] @timing.timed def QGTJacobianPyTree( vstate=None, *, mode: Optional[str] = None, holomorphic: Optional[bool] = None, diag_shift=None, diag_scale=None, rescale_shift=None, chunk_size=None, **kwargs, ) -> "QGTJacobianPyTreeT": """ Semi-lazy representation of an S Matrix where the Jacobian O_k is precomputed and stored as a PyTree. The matrix of gradients O is computed on initialisation, but not S, which can be computed by calling :code:`to_dense`. The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in the field `sr`. Numerical estimates of the QGT are usually ill-conditioned and require regularisation. The standard approach is to add a positive constant to the diagonal; alternatively, Becca and Sorella (2017) propose scaling this offset with the diagonal entry itself. NetKet allows using both in tandem: .. math:: S_{ii} \\mapsto S_{ii} + \\epsilon_1 S_{ii} + \\epsilon_2; :math:`\\epsilon_{1,2}` are specified using `diag_scale` and `diag_shift`, respectively. Args: vstate: The variational state mode: "real", "complex" or "holomorphic": specifies the implementation used to compute the jacobian. "real" discards the imaginary part of the output of the model. "complex" splits the real and imaginary part of the parameters and output. It works also for non holomorphic models. holomorphic works for any function assuming it's holomorphic or real valued. holomorphic: a flag to indicate that the function is holomorphic. diag_scale: Fractional shift :math:`\\epsilon_1` added to diagonal entries (see above). diag_shift: Constant shift :math:`\\epsilon_2` added to diagonal entries (see above). chunk_size: If supplied, overrides the chunk size of the variational state (useful for models where the backward pass requires more memory than the forward pass). """ if mode is not None and holomorphic is not None: raise ValueError("Cannot specify both `mode` and `holomorphic`.") if rescale_shift is not None and diag_scale is not None: raise ValueError("Cannot specify both `rescale_shift` and `diag_scale`.") if vstate is None: return partial( QGTJacobianPyTree, mode=mode, holomorphic=holomorphic, chunk_size=chunk_size, diag_shift=diag_shift, diag_scale=diag_scale, rescale_shift=rescale_shift, **kwargs, ) diag_shift, diag_scale = sanitize_diag_shift(diag_shift, diag_scale, rescale_shift) # TODO: Find a better way to handle this case from netket.vqs import FullSumState if isinstance(vstate, FullSumState): samples = split_array_mpi(vstate._all_states) pdf = split_array_mpi(vstate.probability_distribution()) else: samples = vstate.samples if samples.ndim >= 3: # use jit so that we can do it on global shared array samples = jax.jit(jax.lax.collapse, static_argnums=(1, 2))(samples, 0, 2) pdf = None # Choose sensible default mode if mode is None: mode = nkjax.jacobian_default_mode( vstate._apply_fun, vstate.parameters, vstate.model_state, samples, holomorphic=holomorphic, ) if chunk_size is None and hasattr(vstate, "chunk_size"): chunk_size = vstate.chunk_size shift, offset = to_shift_offset(diag_shift, diag_scale) jacobians = nkjax.jacobian( vstate._apply_fun, vstate.parameters, samples, vstate.model_state, mode=mode, pdf=pdf, chunk_size=chunk_size, dense=False, center=True, _sqrt_rescale=True, ) if offset is not None: ndims = 1 if mode != "complex" else 2 jacobians, scale = rescale(jacobians, offset, ndims=ndims) else: scale = None return QGTJacobianPyTreeT( O=jacobians, scale=scale, _params_structure=vstate.parameters, mode=mode, diag_shift=shift, **kwargs, )
@struct.dataclass class QGTJacobianPyTreeT(LinearOperator): """ Semi-lazy representation of an S Matrix behaving like a linear operator. The matrix of gradients O is computed on initialisation, but not S, which can be computed by calling :code:`to_dense`. The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in the field `sr`. """ O: PyTree = Uninitialized """Centred gradients ΔO_ij = O_ij - <O_j> of the neural network, where O_ij = ∂log ψ(σ_i)/∂p_j, for all samples σ_i at given values of the parameters p_j Divided through with sqrt(#samples) to normalise S matrix If scale is not None, O_ij for is normalised to unit norm for each parameter j """ scale: Optional[PyTree] = None """If not None, contains 2-norm of each column of the gradient matrix, i.e., the sqrt of the diagonal elements of the S matrix """ mode: str = struct.field(pytree_node=False, default=Uninitialized) """Differentiation mode: - "real": for real-valued R->R and C->R Ansätze, splits the complex inputs into real and imaginary part. - "complex": for complex-valued R->C and C->C Ansätze, splits the complex inputs and outputs into real and imaginary part - "holomorphic": for any Ansätze. Does not split complex values. - "auto": autoselect real or complex. """ _params_structure: PyTree = struct.field(default=Uninitialized) """Parameters of the network. Its only purpose is to represent its own shape.""" _in_solve: bool = struct.field(pytree_node=False, default=False) """Internal flag used to signal that we are inside the _solve method and matmul should not take apart into real and complex parts the other vector""" def __matmul__(self, vec: Union[PyTree, Array]) -> Union[PyTree, Array]: return _matmul(self, vec) def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None) -> PyTree: """ Solve the linear system x=⟨S⟩⁻¹⟨y⟩ with the chosen iterative solver. Args: y: the vector y in the system above. x0: optional initial guess for the solution. Returns: x: the PyTree solving the system. info: optional additional information provided by the solver. Might be None if there are no additional information provided. """ return _solve(self, solve_fun, y, x0=x0) def to_dense(self) -> jnp.ndarray: """ Convert the lazy matrix representation to a dense matrix representation. Returns: A dense matrix representation of this S matrix. In R→R and R→C modes, real and imaginary parts of parameters get own rows/columns """ return _to_dense(self) def __repr__(self): return ( f"QGTJacobianPyTree(diag_shift={self.diag_shift}, " f"scale={self.scale}, mode={self.mode})" ) @jax.jit def _matmul( self: QGTJacobianPyTreeT, vec: Union[PyTree, Array] ) -> Union[PyTree, Array]: # Turn vector RHS into PyTree if hasattr(vec, "ndim"): _, unravel = nkjax.tree_ravel(self._params_structure) vec = unravel(vec) ravel = True else: ravel = False check_valid_vector_type(self._params_structure, vec) # Real-imaginary split RHS in R→R and R→C modes reassemble = None if self.mode != "holomorphic" and not self._in_solve: vec, reassemble = nkjax.tree_to_real(vec) if self.scale is not None: vec = jax.tree_util.tree_map(jnp.multiply, vec, self.scale) result = mat_vec(vec, self.O, self.diag_shift) if self.scale is not None: result = jax.tree_util.tree_map(jnp.multiply, result, self.scale) # Reassemble real-imaginary split as needed if reassemble is not None: result = reassemble(result) # Ravel PyTree back into vector as needed if ravel: result, _ = nkjax.tree_ravel(result) return result @jax.jit def _solve( self: QGTJacobianPyTreeT, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None ) -> PyTree: check_valid_vector_type(self._params_structure, y) # Real-imaginary split RHS in R→R and R→C modes if self.mode != "holomorphic": y, reassemble = nkjax.tree_to_real(y) if x0 is not None: x0, _ = nkjax.tree_to_real(x0) if self.scale is not None: y = jax.tree_util.tree_map(jnp.divide, y, self.scale) if x0 is not None: x0 = jax.tree_util.tree_map(jnp.multiply, x0, self.scale) # to pass the object LinearOperator itself down # but avoid rescaling, we pass down an object with # scale = None # mode=holomorphic to disable splitting the complex part unscaled_self = self.replace(scale=None, _in_solve=True) out, info = solve_fun(unscaled_self, y, x0=x0) if self.scale is not None: out = jax.tree_util.tree_map(jnp.divide, out, self.scale) # Reassemble real-imaginary split as needed if self.mode != "holomorphic": out = reassemble(out) return out, info @jax.jit def _to_dense(self: QGTJacobianPyTreeT) -> jnp.ndarray: O = self.O if self.mode == "complex": # I want to iterate across the samples and real/imaginary part O = jax.tree_util.tree_map(lambda x: x.reshape(-1, *x.shape[2:]), O) O = jax.vmap(lambda l: nkjax.tree_ravel(l)[0])(O) if self.scale is None: diag = jnp.eye(O.shape[1]) else: scale, _ = nkjax.tree_ravel(self.scale) O = O * scale[jnp.newaxis, :] diag = jnp.diag(scale**2) return mpi.mpi_sum_jax(O.T.conj() @ O)[0] + self.diag_shift * diag