Source code for netket.jax._jacobian.default_mode
# Copyright 2023 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
from functools import partial
import warnings
import jax
import netket.jax as nkjax
from netket.utils import struct
from netket.utils.types import PyTree, Array
from netket.errors import (
HolomorphicUndeclaredWarning,
IllegalHolomorphicDeclarationForRealParametersError,
)
@struct.dataclass
class JacobianMode:
"""
Jax-compatible string type, used to return static information from a jax-jitted
function.
"""
name: str = struct.field(pytree_node=False)
def __str__(self):
return self.name
def __repr__(self):
return f"JacobianMode({self.name})"
def __hash__(self):
return hash(self.name)
def __eq__(self, o):
if isinstance(o, JacobianMode):
o = o.name
return self.name == o
RealMode = JacobianMode("real")
ComplexMode = JacobianMode("complex")
HolomorphicMode = JacobianMode("holomorphic")
[docs]
@partial(jax.jit, static_argnames=("apply_fun", "holomorphic", "warn"))
def jacobian_default_mode(
apply_fun: Callable[[PyTree, Array], Array],
pars: PyTree,
model_state: Optional[PyTree],
samples: Array,
*,
holomorphic: Optional[bool] = None,
warn: bool = True,
) -> JacobianMode:
"""
Returns the default `mode` for {func}`netket.jax.jacobian` given a certain
wave-function ansatz.
This function uses an abstract evaluation of the ansatz to determine if
the ansatz has real or complex output, and uses that to determine the
default mode to be used to compute the Jacobian.
In particular:
- for functions with a real output, it will return `RealMode`.
- for functions with a complex output, it will return:
- If `holomorphic==False` or it not been specified, it will return
`ComplexMode`, which will force the calculation of both the jacobian
and adjoint jacobian. See the documentation of{func}`nk.jax.jacobian`
for more details.
- If `holomorphic==True`, it will compute only the complex-valued
jacobian, and assumes the adjoint-jacobian to be zero.
This function will also raise an error if `holomorphic` is not specified
but the output is complex.
Args:
apply_fun: A callable taking as input a pytree of parameters and the samples,
and returning the output.
pars: The Pytree of parameters.
model_state: The optional `model_state`, according to the flax model definition.
samples: An array of samples.
holomorphic: A boolean specifying whether `apply_fun` is
holomorphic or not (`None` by default).
warn: A boolean specifying whether to raise a warning
when holomorphic is not specified. For internal use
only.
"""
nkjax.tree_ishomogeneous(pars)
nkjax.tree_leaf_iscomplex(pars)
leaf_isreal = nkjax.tree_leaf_isreal(pars)
if holomorphic is True:
if leaf_isreal:
# all real or mixed real/complex parameters. It's not holomorphic
raise IllegalHolomorphicDeclarationForRealParametersError()
else:
## all complex parameters
mode = HolomorphicMode
else:
complex_output = jax.numpy.iscomplexobj(
jax.eval_shape(
apply_fun,
{"params": pars, **model_state},
samples.reshape(-1, samples.shape[-1]),
)
)
if complex_output:
if not leaf_isreal:
if holomorphic is None and warn:
warnings.warn(
HolomorphicUndeclaredWarning(),
UserWarning,
stacklevel=2,
)
mode = ComplexMode
else:
mode = ComplexMode
else:
mode = RealMode
return mode