# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from functools import partial
import jax
from jax import numpy as jnp
from flax import struct
from netket.utils.types import Scalar, PyTree
from netket.utils import mpi, timing
import netket.jax as nkjax
from netket.nn import split_array_mpi
from ..linear_operator import LinearOperator, Uninitialized
from .common import check_valid_vector_type
from .qgt_jacobian_common import (
sanitize_diag_shift,
to_shift_offset,
rescale,
)
# [docs]  (appears to be a Sphinx "view source" link marker left over from HTML extraction)
@timing.timed
def QGTJacobianDense(
    vstate=None,
    *,
    mode: Optional[str] = None,
    holomorphic: Optional[bool] = None,
    diag_shift=None,
    diag_scale=None,
    rescale_shift=None,
    chunk_size=None,
    **kwargs,
) -> "QGTJacobianDenseT":
    """
    Semi-lazy S matrix (quantum geometric tensor) backed by a dense, precomputed
    Jacobian O_k.

    Only the matrix of gradients O is evaluated when this constructor runs; the
    S matrix itself is never formed unless :code:`to_dense` is called on the
    returned object.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.

    Numerical estimates of the QGT are usually ill-conditioned and need to be
    regularised. The usual remedy is to add a positive constant to the diagonal;
    Becca and Sorella (2017) alternatively propose making that offset
    proportional to the diagonal entry itself. Both can be combined:

    .. math::

        S_{ii} \\mapsto S_{ii} + \\epsilon_1 S_{ii} + \\epsilon_2;

    :math:`\\epsilon_{1,2}` are specified using `diag_scale` and `diag_shift`,
    respectively.

    Args:
        vstate: The variational state. If omitted, nothing is computed and a
            :code:`functools.partial` of this constructor with all the other
            options already bound is returned instead.
        mode: "real", "complex" or "holomorphic": specifies the implementation
            used to compute the jacobian. "real" discards the imaginary part
            of the output of the model. "complex" splits the real and imaginary
            part of the parameters and output. It works also for non holomorphic
            models. holomorphic works for any function assuming it's holomorphic
            or real valued. When left to `None` the mode is chosen by
            `netket.jax.jacobian_default_mode`.
        holomorphic: a flag to indicate that the function is holomorphic.
            Mutually exclusive with `mode`.
        diag_scale: Fractional shift :math:`\\epsilon_1` added to diagonal entries (see above).
        diag_shift: Constant shift :math:`\\epsilon_2` added to diagonal entries (see above).
        rescale_shift: Mutually exclusive with `diag_scale`; it is handed to
            `sanitize_diag_shift` together with the two shifts (presumably a
            legacy way of requesting the rescaling -- TODO confirm).
        chunk_size: If supplied, overrides the chunk size of the variational state
            (useful for models where the backward pass requires more
            memory than the forward pass).
        **kwargs: Forwarded unchanged to :code:`QGTJacobianDenseT`.

    Returns:
        A :code:`QGTJacobianDenseT`, or a partially-applied constructor when
        `vstate` is `None`.

    Raises:
        ValueError: if both `mode` and `holomorphic` are given, or both
            `rescale_shift` and `diag_scale` are given.
    """
    if mode is not None and holomorphic is not None:
        raise ValueError("Cannot specify both `mode` and `holomorphic`.")

    if rescale_shift is not None and diag_scale is not None:
        raise ValueError("Cannot specify both `rescale_shift` and `diag_scale`.")

    if vstate is None:
        # No state yet: hand back a constructor with the options frozen in.
        return partial(
            QGTJacobianDense,
            mode=mode,
            holomorphic=holomorphic,
            chunk_size=chunk_size,
            diag_shift=diag_shift,
            diag_scale=diag_scale,
            rescale_shift=rescale_shift,
            **kwargs,
        )

    diag_shift, diag_scale = sanitize_diag_shift(diag_shift, diag_scale, rescale_shift)

    # TODO: Find a better way to handle this case
    from netket.vqs import FullSumState

    if isinstance(vstate, FullSumState):
        # Full summation: every state is used, weighted by its probability.
        samples = split_array_mpi(vstate._all_states)
        pdf = split_array_mpi(vstate.probability_distribution())
    else:
        samples = vstate.samples
        if samples.ndim >= 3:
            # use jit so that we can do it on global shared array
            collapse = jax.jit(jax.lax.collapse, static_argnums=(1, 2))
            samples = collapse(samples, 0, 2)
        pdf = None

    if mode is None:
        mode = nkjax.jacobian_default_mode(
            vstate._apply_fun,
            vstate.parameters,
            vstate.model_state,
            samples,
            holomorphic=holomorphic,
        )

    if chunk_size is None:
        # Fall back to the state's own chunk size, when it defines one.
        chunk_size = getattr(vstate, "chunk_size", None)

    shift, offset = to_shift_offset(diag_shift, diag_scale)

    jacobians = nkjax.jacobian(
        vstate._apply_fun,
        vstate.parameters,
        samples,
        vstate.model_state,
        mode=mode,
        pdf=pdf,
        chunk_size=chunk_size,
        dense=True,
        center=True,
        _sqrt_rescale=True,
    )

    scale = None
    if offset is not None:
        # "complex" mode is rescaled with ndims=2, every other mode with ndims=1.
        ndims = 2 if mode == "complex" else 1
        jacobians, scale = rescale(jacobians, offset, ndims=ndims)

    def _shape_and_dtype(leaf):
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)

    # Shape/dtype skeleton of the parameters, kept to validate vectors later on.
    pars_struct = jax.tree_util.tree_map(_shape_and_dtype, vstate.parameters)

    return QGTJacobianDenseT(
        O=jacobians,
        scale=scale,
        mode=mode,
        _params_structure=pars_struct,
        diag_shift=shift,
        **kwargs,
    )
@struct.dataclass
class QGTJacobianDenseT(LinearOperator):
    """
    Semi-lazy S matrix that can be used like a linear operator.

    Only the matrix of gradients O is stored; the S matrix itself is not
    materialised unless :code:`to_dense` is called.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.
    """

    O: jnp.ndarray = Uninitialized
    """Gradients O_ij = ∂log ψ(σ_i)/∂p_j of the neural network
    for all samples σ_i at given values of the parameters p_j
    Average <O_j> subtracted for each parameter
    Divided through with sqrt(#samples) to normalise S matrix
    If scale is not None, columns normalised to unit norm
    """

    scale: Optional[jnp.ndarray] = None
    """If not None, contains 2-norm of each column of the gradient matrix,
    i.e., the sqrt of the diagonal elements of the S matrix
    """

    mode: str = struct.field(pytree_node=False, default=Uninitialized)
    """Differentiation mode:
        - "real": for real-valued R->R and C->R Ansätze, splits the complex inputs
                  into real and imaginary part.
        - "complex": for complex-valued R->C and C->C Ansätze, splits the complex
                  inputs and outputs into real and imaginary part
        - "holomorphic": for any Ansätze. Does not split complex values.
        - "auto": autoselect real or complex.
    """

    _in_solve: bool = struct.field(pytree_node=False, default=False)
    """Internal flag used to signal that we are inside the _solve method and matmul should
    not take apart into real and complex parts the other vector"""

    _params_structure: PyTree = struct.field(pytree_node=False, default=Uninitialized)

    @jax.jit
    def __matmul__(self, vec: Union[PyTree, jnp.ndarray]) -> Union[PyTree, jnp.ndarray]:
        """
        Apply the (regularised) S matrix to `vec`.

        `vec` may be a PyTree or an array; the result is converted back to the
        format of the input. PyTree inputs are validated against the stored
        parameter structure, except while running inside `_solve`.
        """
        inside_solve = self._in_solve
        if not (inside_solve or hasattr(vec, "ndim")):
            check_valid_vector_type(self._params_structure, vec)

        dense_vec, to_input_format = convert_tree_to_dense_format(
            vec, self.mode, disable=inside_solve
        )

        col_norms = self.scale
        if col_norms is None:
            return to_input_format(mat_vec(dense_vec, self.O, self.diag_shift))

        # O has normalised columns: apply the column norms on both sides of
        # the product.
        product = mat_vec(dense_vec * col_norms, self.O, self.diag_shift)
        return to_input_format(product * col_norms)

    @jax.jit
    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None) -> PyTree:
        """
        Solve the linear system with right-hand side `y` using `solve_fun`.

        Args:
            solve_fun: callable invoked as `solve_fun(operator, rhs, x0=x0)`
                and returning a `(solution, info)` pair.
            y: right-hand side, as a PyTree or an array.
            x0: optional initial guess, in the same format as `y`.

        Returns:
            The solution in the format of `y`, and the `info` object returned
            by `solve_fun`.
        """
        if not hasattr(y, "ndim"):
            check_valid_vector_type(self._params_structure, y)

        rhs, to_input_format = convert_tree_to_dense_format(y, self.mode)
        has_scale = self.scale is not None

        if x0 is not None:
            x0, _ = convert_tree_to_dense_format(x0, self.mode)
            if has_scale:
                x0 = x0 * self.scale

        if has_scale:
            rhs = rhs / self.scale

        # to pass the object LinearOperator itself down
        # but avoid rescaling, we pass down an object with
        # scale = None
        inner_operator = self.replace(scale=None, _in_solve=True)
        solution, info = solve_fun(inner_operator, rhs, x0=x0)

        if has_scale:
            solution = solution / self.scale

        return to_input_format(solution), info

    @jax.jit
    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.
        """
        if self.scale is not None:
            # Undo the column normalisation; the shift then multiplies scale**2
            # on the diagonal.
            jac = self.O * self.scale[jnp.newaxis, :]
            shift_pattern = jnp.diag(self.scale**2)
        else:
            jac = self.O
            shift_pattern = jnp.eye(self.O.shape[-1])

        # concatenate samples with real/Imaginary dimension
        jac = jac.reshape(-1, jac.shape[-1])

        gram = mpi.mpi_sum_jax(jac.conj().T @ jac)[0]
        return gram + self.diag_shift * shift_pattern

    def __repr__(self):
        """Short textual summary showing the shift, the scale and the mode."""
        return (
            f"QGTJacobianDense(diag_shift={self.diag_shift}, "
            f"scale={self.scale}, mode={self.mode})"
        )
#################################################
##### QGT internal Logic #####
#################################################
def mat_vec(v: PyTree, O: PyTree, diag_shift: Scalar) -> PyTree:
    """
    Compute :code:`(O^† O + diag_shift) v` without forming :code:`O^† O`.

    Args:
        v: the dense vector to multiply.
        O: the jacobian matrix.
        diag_shift: constant multiplying `v` in the additive diagonal term.

    Returns:
        The first output of `mpi.mpi_sum_jax` applied to :code:`O^† (O v)`,
        plus :code:`diag_shift * v`.
    """
    projected = O @ v
    # Contract every axis of `projected` with the leading axes of O; the two
    # conjugations together give O^† (O v).
    contracted = jnp.tensordot(projected.conj(), O, axes=projected.ndim).conj()
    reduced = mpi.mpi_sum_jax(contracted)[0]
    return reduced + diag_shift * v
def convert_tree_to_dense_format(vec, mode, *, disable=False):
    """
    Bring a PyTree/vector, real or complex, into the dense
    (possibly real) vector layout used by QGTJacobian.

    The layout follows the sequence of operations performed by
    `nk.jax.jacobian(..., dense=True)`: the parameter pytree is first
    converted to real, then real and imaginary terms are concatenated by a
    tree_ravel. The same steps are therefore replayed here.

    Args:
        vec: the PyTree or array to convert.
        mode: the differentiation mode; the conversion to real is skipped
            for "holomorphic".
        disable: if True, `vec` is passed through untouched.

    Returns:
        The converted vector and a function mapping a dense vector back to
        the format of the input.
    """

    def _identity(x):
        return x

    to_complex = _identity
    to_tree = _identity

    if not disable:
        if mode != "holomorphic":
            vec, to_complex = nkjax.tree_to_real(vec)
        if not hasattr(vec, "ndim"):
            # Still a PyTree: flatten it into a single dense vector.
            vec, to_tree = nkjax.tree_ravel(vec)

    def restore(x):
        # Undo the steps above in reverse order.
        return to_complex(to_tree(x))

    return vec, restore