# Source code for netket.jax._vjp

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Any, Union
from functools import partial

import jax

from jax import numpy as jnp
from jax.tree_util import Partial, tree_map

from netket.utils import HashablePartial

from ._utils_tree import tree_leaf_iscomplex, eval_shape


# _grad_CC, _RR and _RC are the chunked gradient functions for machines going
# from C -> C, R -> R and R -> C. Ditto for vjp (vjp_cc, vjp_rr and vjp_rc).
# The reason why R -> C is more complicated is that it splits the calculation
# into the real and imaginary part in order to be more efficient.


def _cmplx(re, im, conj=False):
    """
    Safely convert real and imaginary part to a complex number, considering
    `float0` dtypes which cannot be summed upon.

    Those types appear when computing the `vjp` of functions with integer
    inputs.

    Args:
        re: array-like with a ``dtype`` attribute, used as the real part.
        im: array-like with a ``dtype`` attribute, used as the imaginary part.
        conj: if True return ``re - 1j * im`` instead of ``re + 1j * im``.

    Returns:
        ``re`` unchanged if either ``re`` or ``im`` has the `float0` dtype
        (no arithmetic is attempted in that case), otherwise the combined
        complex value.
    """
    # detect tangent-0 dtypes; each argument must be checked on its own dtype,
    # otherwise a float0 imaginary part would reach the arithmetic below.
    is_re_0 = jax.dtypes.issubdtype(re.dtype, jax.dtypes.float0)
    is_im_0 = jax.dtypes.issubdtype(im.dtype, jax.dtypes.float0)
    if is_re_0 or is_im_0:
        return re

    if conj:
        return re - 1j * im
    return re + 1j * im


def vjp_fun_cc(out_dtype, conjugate, _vjp_fun, ȳ):
    """
    function computing the vjp product through a single pullback.

    The cotangent `ȳ` is cast to `out_dtype`, passed to `_vjp_fun`, and every
    leaf of the result is complex-conjugated when `conjugate` is True.
    """
    cotangent = jnp.asarray(ȳ, dtype=out_dtype)
    pulled_back = _vjp_fun(cotangent)
    return tree_map(jnp.conjugate, pulled_back) if conjugate else pulled_back


def vjp_cc(
    fun: Callable, *primals, has_aux: bool = False, conjugate: bool = False
) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
    """
    vjp built from a single call to `jax.vjp`.

    Returns `(out, vjp_fun)` or, when `has_aux` is True, `(out, vjp_fun, aux)`.
    The returned `vjp_fun` is `vjp_fun_cc` with the output dtype, the
    `conjugate` flag and the raw jax pullback already bound.
    """
    # jax.vjp yields (out, pullback) or (out, pullback, aux); the starred
    # target collects the optional aux so one code path covers both cases.
    out, raw_pullback, *maybe_aux = jax.vjp(fun, *primals, has_aux=has_aux)

    bound = HashablePartial(vjp_fun_cc, out.dtype, conjugate)
    vjp_fun = Partial(bound, raw_pullback)

    return (out, vjp_fun, *maybe_aux)


def vjp_fun_rr(primals_out_dtype, conjugate, _vjp_fun, ȳ):
    """
    function computing the vjp product for a R->R function.

    A real cotangent `ȳ` is cast to `primals_out_dtype` and pulled back
    directly. A complex cotangent is split into its real and imaginary parts,
    each part is pulled back separately, and the two results are recombined
    leaf by leaf with `_cmplx` (conjugated when `conjugate` is True).
    """

    def pull(cotangent):
        # every cotangent handed to the pullback is cast to the output dtype
        return _vjp_fun(jnp.asarray(cotangent, dtype=primals_out_dtype))

    if jnp.iscomplexobj(ȳ):
        from_real = pull(ȳ.real)
        from_imag = pull(ȳ.imag)
        return tree_map(
            lambda re, im: _cmplx(re, im, conj=conjugate), from_real, from_imag
        )

    return pull(ȳ)


def vjp_rr(
    fun: Callable, *primals, has_aux: bool = False, conjugate: bool = False
) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
    """
    vjp for a R->R function, built from a single call to `jax.vjp`.

    Returns `(primals_out, vjp_fun)` or, when `has_aux` is True,
    `(primals_out, vjp_fun, aux)`. The returned `vjp_fun` is `vjp_fun_rr`
    with the output dtype, the `conjugate` flag and the raw jax pullback
    already bound.
    """
    # The starred target is empty without aux and holds the aux value with it.
    primals_out, raw_pullback, *maybe_aux = jax.vjp(
        fun, *primals, has_aux=has_aux
    )

    bound = HashablePartial(vjp_fun_rr, primals_out.dtype, conjugate)
    vjp_fun = Partial(bound, raw_pullback)

    return (primals_out, vjp_fun, *maybe_aux)


def vjp_fun_rc(vals_r_dtype, vals_j_dtype, conjugate, vjp_r_fun, vjp_j_fun, ȳ):
    """
    function computing the vjp product for a R->C function.

    `vjp_r_fun` and `vjp_j_fun` are the pullbacks of the real and imaginary
    part of the function. Both are applied to the real and to the imaginary
    part of the cotangent `ȳ`, and the four results are combined leaf by leaf
    with `_cmplx` as

        (vjp_r(ȳ.real) + 1j * vjp_r(ȳ.imag))
        + 1j * (vjp_j(ȳ.real) + 1j * vjp_j(ȳ.imag))

    The result is complex-conjugated when `conjugate` is True.
    """
    cot_re = ȳ.real
    cot_im = ȳ.imag

    def pull_r(cotangent):
        # pullback of the real part of the output
        return vjp_r_fun(jnp.asarray(cotangent, dtype=vals_r_dtype))

    def pull_j(cotangent):
        # pullback of the imaginary part of the output
        return vjp_j_fun(jnp.asarray(cotangent, dtype=vals_j_dtype))

    through_real = tree_map(_cmplx, pull_r(cot_re), pull_r(cot_im))
    through_imag = tree_map(_cmplx, pull_j(cot_re), pull_j(cot_im))
    combined = tree_map(_cmplx, through_real, through_imag)

    return tree_map(jnp.conjugate, combined) if conjugate else combined


def vjp_rc(
    fun: Callable, *primals, has_aux: bool = False, conjugate: bool = False
) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
    """
    vjp for a R->C function.

    `jax.vjp` is called twice, once on the real part and once on the imaginary
    part of the output of `fun`. The two pullbacks are bound into `vjp_fun_rc`.

    Returns `(primals_out, vjp_fun)` or, when `has_aux` is True,
    `(primals_out, vjp_fun, aux)`, where `primals_out` is rebuilt as
    `vals_r + 1j * vals_j` and `aux` comes from the real-part evaluation.
    """

    def component(attr):
        # Wrap `fun` so that it returns only the `.real` or `.imag` attribute
        # of its output, forwarding the aux value untouched when present.
        if has_aux:

            def selected(*args):
                value, extra = fun(*args)
                return getattr(value, attr), extra

        else:

            def selected(*args):
                return getattr(fun(*args), attr)

        return selected

    # The aux of the imaginary-part evaluation is discarded.
    vals_r, vjp_r_fun, *maybe_aux = jax.vjp(
        component("real"), *primals, has_aux=has_aux
    )
    vals_j, vjp_j_fun, *_ = jax.vjp(component("imag"), *primals, has_aux=has_aux)

    primals_out = vals_r + 1j * vals_j

    bound = HashablePartial(vjp_fun_rc, vals_r.dtype, vals_j.dtype, conjugate)
    vjp_fun = Partial(bound, vjp_r_fun, vjp_j_fun)

    return (primals_out, vjp_fun, *maybe_aux)


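# This function dispatches to the right vjp implementation (vjp_cc, vjp_rr or
# vjp_rc) depending on whether the inputs and the output are real or complex.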
def vjp(
    fun: Callable, *primals, has_aux: bool = False, conjugate: bool = False
) -> Union[tuple[Any, Callable], tuple[Any, Callable, Any]]:
    """
    Compute the vjp of `fun` at `primals`, picking the implementation from
    the real/complex nature of the inputs and of the output.

    Args:
        fun: the function to differentiate.
        *primals: the arguments at which `fun` is evaluated.
        has_aux: forwarded to the selected implementation; when True the
            result has a third element holding the auxiliary output.
        conjugate: forwarded to the selected implementation, which conjugates
            the result of the returned vjp function when True.

    Returns:
        Whatever the selected implementation returns: `(out, vjp_fun)` or
        `(out, vjp_fun, aux)`.
    """
    # output dtype
    # NOTE(review): presumably eval_shape returns only the shape/dtype of the
    # main output when has_aux is True — confirm against ._utils_tree.
    out_shape = eval_shape(fun, *primals, has_aux=has_aux)

    if tree_leaf_iscomplex(primals):
        # C -> C and C -> R are both handled by the same implementation
        implementation = vjp_cc
    elif jnp.iscomplexobj(out_shape):
        # R -> C
        implementation = vjp_rc
    else:
        # R -> R
        implementation = vjp_rr

    return implementation(fun, *primals, has_aux=has_aux, conjugate=conjugate)