# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
from functools import partial
import warnings
import jax
from jax import numpy as jnp
from flax import struct
import netket.jax as nkjax
from netket.utils import timing
from netket.utils.types import PyTree
from netket.errors import (
IllegalHolomorphicDeclarationForRealParametersError,
NonHolomorphicQGTOnTheFlyDenseRepresentationError,
HolomorphicUndeclaredWarning,
)
from netket.nn import split_array_mpi
from .common import check_valid_vector_type
from .qgt_onthefly_logic import mat_vec_factory, mat_vec_chunked_factory
from ..linear_operator import LinearOperator, Uninitialized
@timing.timed
def QGTOnTheFly(
    vstate=None, *, chunk_size=None, holomorphic: Optional[bool] = None, **kwargs
) -> "QGTOnTheFlyT":
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The ⟨S⟩⁻¹⟨F⟩ system is solved through the returned operator, which
    delegates to the solver function it is given.

    Args:
        vstate: The variational State. If :code:`None`, a partial of this
            function with all the other arguments bound is returned, so that
            it can later be called with the variational state.
        chunk_size: If supplied, overrides the chunk size of the variational state
            (useful for models where the backward pass requires more
            memory than the forward pass).
        holomorphic: :code:`True` to declare the ansatz holomorphic (only
            allowed for complex parameters), :code:`False` to declare it
            non-holomorphic. If left to :code:`None` and the parameters are
            complex, a :class:`HolomorphicUndeclaredWarning` is emitted and the
            non-holomorphic ("complex") mode is used.
        **kwargs: Forwarded to :class:`QGTOnTheFlyT` (presumably fields of the
            base linear operator such as :code:`diag_shift`). A non-None
            :code:`diag_scale` is rejected.

    Returns:
        A :class:`QGTOnTheFlyT` operator, or a partial of this function when
        :code:`vstate` is :code:`None`.

    Raises:
        NotImplementedError: If a non-None :code:`diag_scale` is passed.
        IllegalHolomorphicDeclarationForRealParametersError: If
            :code:`holomorphic=True` but the parameters are real.
    """
    if vstate is None:
        # Deferred construction: bind the options now, supply the state later.
        return partial(
            QGTOnTheFly, chunk_size=chunk_size, holomorphic=holomorphic, **kwargs
        )

    if kwargs.pop("diag_scale", None) is not None:
        raise NotImplementedError(
            "\n`diag_scale` argument is not yet supported by QGTOnTheFly."
            "Please use `QGTJacobianPyTree` or `QGTJacobianDense`.\n\n"
            "You are also encouraged to nag the developers to support "
            "this feature.\n\n"
        )

    # TODO: Find a better way to handle this case
    # (imported locally, presumably to avoid a circular import with netket.vqs)
    from netket.vqs import FullSumState

    if isinstance(vstate, FullSumState):
        # No Monte Carlo samples: use all states weighted by their probability,
        # split with `split_array_mpi` (presumably across MPI ranks).
        samples = split_array_mpi(vstate._all_states)
        pdf = split_array_mpi(vstate.probability_distribution())
    else:
        samples = vstate.samples
        if samples.ndim >= 3:
            # Merge the two leading axes into a single sample axis.
            # use jit so that we can do it on global shared array
            samples = jax.jit(jax.lax.collapse, static_argnums=(1, 2))(samples, 0, 2)
        pdf = None

    if chunk_size is None and hasattr(vstate, "chunk_size"):
        chunk_size = vstate.chunk_size

    n_samples = samples.shape[0]
    # Only chunk when a chunk is strictly smaller than the whole sample set.
    if chunk_size is None or chunk_size >= n_samples:
        mv_factory = mat_vec_factory
        chunking = False
    else:
        mv_factory = partial(mat_vec_chunked_factory, chunk_size=chunk_size)
        chunking = True

    # Select the differentiation mode from the holomorphic declaration and the
    # dtype of the parameters.
    if holomorphic:
        if nkjax.tree_leaf_isreal(vstate.parameters):
            raise IllegalHolomorphicDeclarationForRealParametersError()
        mode = "holomorphic"
    elif not nkjax.tree_leaf_iscomplex(vstate.parameters):
        mode = "real"
    else:
        if holomorphic is None:
            warnings.warn(HolomorphicUndeclaredWarning(), UserWarning)
        mode = "complex"

    # NOTE(review): the return value of `jacobian_default_mode` is discarded;
    # the call is kept only for whatever checks it performs on the model.
    # Confirm whether it is still needed.
    nkjax.jacobian_default_mode(
        vstate._apply_fun,
        vstate.parameters,
        vstate.model_state,
        samples,
        holomorphic=holomorphic,
    )

    mat_vec = mv_factory(
        forward_fn=vstate._apply_fun,
        params=vstate.parameters,
        model_state=vstate.model_state,
        samples=samples,
        pdf=pdf,
    )
    return QGTOnTheFlyT(
        _mat_vec=mat_vec,
        _params=vstate.parameters,
        _chunking=chunking,
        _mode=mode,
        **kwargs,
    )
@struct.dataclass
class QGTOnTheFlyT(LinearOperator):
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The ⟨S⟩⁻¹⟨F⟩ system is solved by :code:`_solve`, which delegates to the
    solver function it receives.
    """

    _mat_vec: Callable[[PyTree, float], PyTree] = Uninitialized
    """The S matrix-vector product as generated by mat_vec_factory.
    It's a jax.Partial, so can be used as pytree_node."""

    _params: PyTree = Uninitialized
    """The first input to apply_fun (parameters of the ansatz).
    Only used as a shape placeholder."""

    _chunking: bool = struct.field(pytree_node=False, default=False)
    """Whether the implementation with chunks is used which currently does not support vmapping over it"""

    _mode: Optional[str] = struct.field(pytree_node=False, default=None)
    """Differentiation mode:
        - "real": for real-valued R->R and C->R Ansätze, splits the complex inputs
          into real and imaginary part.
        - "complex": for complex-valued R->C and C->C Ansätze, splits the complex
          inputs and outputs into real and imaginary part
        - "holomorphic": for any Ansätze. Does not split complex values.
        - "auto": autoselect real or complex.
    """

    def __matmul__(self, y):
        # Lazy mat-vec: `y` may be a PyTree like the parameters or a 1-D array.
        return onthefly_mat_treevec(self, y)

    def _solve(self, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs) -> PyTree:
        # Delegates to the module-level jitted `_solve`; extra kwargs are not
        # forwarded.
        return _solve(self, solve_fun, y, x0=x0)

    def to_dense(self) -> jnp.ndarray:
        """
        Convert the lazy matrix representation to a dense matrix representation.

        Returns:
            A dense matrix representation of this S matrix.

        Raises:
            NonHolomorphicQGTOnTheFlyDenseRepresentationError: If the operator
                was built in "complex" (non-holomorphic, complex parameters)
                mode.
        """
        # This condition will be true if the user specified `holomorphic=False` and
        # if the parameters are complex. If the parameters are real and the user
        # did not specify holomorphic we will have mode==real and if holomorphic is
        # True mode==holomorphic.
        #
        # We must check this because the AD implementation will compute the wrong
        # QGT in that case
        if self._mode == "complex":
            raise NonHolomorphicQGTOnTheFlyDenseRepresentationError()

        return _to_dense(self)

    def __repr__(self):
        return f"QGTOnTheFly(diag_shift={self.diag_shift})"
@jax.jit
def onthefly_mat_treevec(
    S: "QGTOnTheFlyT", vec: Union[PyTree, jnp.ndarray]
) -> Union[PyTree, jnp.ndarray]:
    """
    Perform the lazy mat-vec product, where vec is either a tree with the same structure as
    params or a ravelled vector.

    Args:
        S: The lazy on-the-fly QGT operator.
        vec: A PyTree with the same structure as :code:`S._params`, or a 1-D
            array with one entry per parameter.

    Returns:
        The result of :code:`S._mat_vec(vec, S.diag_shift)`, ravelled into a
        1-D array if :code:`vec` was an array, otherwise as a PyTree.

    Raises:
        ValueError: If :code:`vec` is an array that is not 1-D, or whose size
            differs from the number of parameters.
    """
    # if has a ndim it's an array and not a pytree
    if hasattr(vec, "ndim"):
        if vec.ndim != 1:
            raise ValueError("Unsupported mat-vec for chunks of vectors")
        # If the input is a vector
        n_params = nkjax.tree_size(S._params)
        if n_params != vec.size:
            # f-string so the actual sizes appear in the message.
            raise ValueError(
                f"Size mismatch between number of parameters ({n_params}) "
                f"and vector size {vec.size}."
            )

        _, unravel = nkjax.tree_ravel(S._params)
        vec = unravel(vec)
        ravel_result = True
    else:
        ravel_result = False

    check_valid_vector_type(S._params, vec)

    # Match the dtypes of the parameters before applying the operator.
    vec = nkjax.tree_cast(vec, S._params)

    res = S._mat_vec(vec, S.diag_shift)

    if ravel_result:
        res, _ = nkjax.tree_ravel(res)

    return res
# `solve_fun` is a Python callable (e.g. a solver function or a
# functools.partial of one), which is not a valid JAX type: it must be marked
# static so that jit does not try to trace it. It must therefore be hashable.
@partial(jax.jit, static_argnums=1)
def _solve(
    self: "QGTOnTheFlyT", solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs
) -> PyTree:
    """
    Solve the linear system :code:`self @ x = y` with :code:`solve_fun`.

    Args:
        self: The lazy on-the-fly QGT operator.
        solve_fun: Solver called as :code:`solve_fun(self, y, x0=x0)` and
            expected to return a pair :code:`(solution, info)`.
        y: Right-hand side, a PyTree compatible with the parameters.
        x0: Initial guess; zeros shaped like :code:`y` when :code:`None`.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        The :code:`(out, info)` pair returned by :code:`solve_fun`.
    """
    check_valid_vector_type(self._params, y)

    # Match the dtypes of the parameters.
    y = nkjax.tree_cast(y, self._params)

    # we could cache this...
    if x0 is None:
        x0 = jax.tree_util.tree_map(jnp.zeros_like, y)

    out, info = solve_fun(self, y, x0=x0)
    return out, info
@jax.jit
def _to_dense(self: QGTOnTheFlyT) -> jnp.ndarray:
    """
    Materialize the lazy S matrix by applying it to every canonical basis vector.

    Returns:
        A dense matrix representation of this S matrix.
    """
    basis = jnp.eye(nkjax.tree_size(self._params))

    def apply_to(basis_vec):
        return self @ basis_vec

    if not self._chunking:
        # Row i of `stacked` holds S @ e_i.
        stacked = jax.vmap(apply_to, in_axes=0)(basis)
    else:
        # The linear_call used by the chunked mat-vec has no batching rule,
        # so vmap is unavailable; a scan works and also keeps peak memory lower.
        _, stacked = jax.lax.scan(
            lambda _carry, basis_vec: (None, apply_to(basis_vec)), None, basis
        )

    # Complex results are transposed so that column i is S @ e_i; the real
    # result is returned as stacked (presumably S is symmetric then — confirm).
    return stacked.T if jnp.iscomplexobj(stacked) else stacked