# Source code for netket.jax._utils_tree

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
from typing import Callable


import jax
import netket.jax as nkjax
from jax import numpy as jnp
from jax.tree_util import (
    register_pytree_node,
    tree_flatten,
    tree_unflatten,
    tree_map,
    tree_leaves,
)

from netket.utils.types import PyTree, Scalar
from netket.utils.numbers import is_scalar


def tree_ravel(pytree: PyTree) -> tuple[jnp.ndarray, Callable]:
    """Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

    Args:
        pytree: a pytree to ravel

    Returns:
        A pair ``(flat, unravel)``. ``flat`` is a 1D array holding the
        flattened and concatenated leaf values. ``unravel`` is a callable
        that converts a 1D vector of the same length back into a pytree with
        the same structure as the input ``pytree``.
    """
    leaf_list, structure = tree_flatten(pytree)
    # The vjp function of the concatenation is what maps a flat vector back
    # to the list of leaves.
    flat, split_into_leaves = nkjax.vjp(_ravel_list, *leaf_list)

    def unravel_pytree(flat_vector):
        return tree_unflatten(structure, split_into_leaves(flat_vector))

    return flat, unravel_pytree
def _ravel_list(*lst):
    """Concatenate the raveled arguments into a single 1D array.

    When called without arguments, an empty array is returned.
    """
    if not lst:
        return jnp.array([])
    return jnp.concatenate([jnp.ravel(item) for item in lst])


def eval_shape(fun, *args, has_aux=False, **kwargs):
    """
    Returns the result of ``jax.eval_shape(fun, *args, **kwargs)``.

    Args:
        fun: the function passed to ``jax.eval_shape``
        *args: positional arguments forwarded to ``jax.eval_shape``
        has_aux: if True, the result is unpacked as a pair and only its
            first element is returned; the second one is discarded
        **kwargs: keyword arguments forwarded to ``jax.eval_shape``
    """
    result = jax.eval_shape(fun, *args, **kwargs)
    if has_aux:
        # Keep only the main output and drop the auxiliary one
        result, _ = result
    return result
def tree_size(tree: PyTree) -> int:
    """
    Returns the total number of scalars stored in the pytree, i.e. the sum
    of the ``size`` attribute of all its leaves.
    """
    return sum(leaf.size for leaf in tree_leaves(tree))
def tree_leaf_iscomplex(pars: PyTree) -> bool:
    """
    Returns true if at least one leaf in the tree has complex dtype.
    """
    leaves = jax.tree_util.tree_leaves(pars)
    return any(jnp.iscomplexobj(leaf) for leaf in leaves)
def tree_leaf_isreal(pars: PyTree) -> bool:
    """
    Returns true if at least one leaf in the tree has real dtype.
    """
    leaves = jax.tree_util.tree_leaves(pars)
    return any(jnp.isrealobj(leaf) for leaf in leaves)
def tree_ishomogeneous(pars: PyTree) -> bool:
    """
    Returns true if all leaves have real dtype or all leaves have complex
    dtype, i.e. if the tree does not mix real and complex leaves.
    """
    # Equivalent to "not (has a real leaf and has a complex leaf)"
    return not tree_leaf_isreal(pars) or not tree_leaf_iscomplex(pars)
@jax.jit
def tree_conj(t: PyTree) -> PyTree:
    r"""
    Conjugate all complex leaves. The real leaves are left untouched.

    Args:
        t: pytree

    Returns:
        A pytree with the same structure as ``t`` where every complex leaf
        has been replaced by its complex conjugate.
    """

    def _conj_leaf(leaf):
        if jnp.iscomplexobj(leaf):
            return jax.lax.conj(leaf)
        return leaf

    return jax.tree_util.tree_map(_conj_leaf, t)
@jax.jit
def tree_dot(a: PyTree, b: PyTree) -> Scalar:
    r"""
    compute the dot product of two pytrees

    Corresponding leaves are multiplied elementwise, each product is summed,
    and the per-leaf sums are added together. No complex conjugation is
    applied to either argument.

    Args:
        a, b: pytrees with the same treedef

    Returns:
        A scalar equal to the dot product of the flattened arrays of a and b.
    """
    leaf_products = jax.tree_util.tree_map(jax.numpy.multiply, a, b)
    leaf_sums = jax.tree_util.tree_map(jax.numpy.sum, leaf_products)
    return jax.tree_util.tree_reduce(jax.numpy.add, leaf_sums)
@jax.jit
def tree_cast(x: PyTree, target: PyTree) -> PyTree:
    r"""
    cast x to the types of target

    Args:
        x: a pytree with arrays as leaves
        target: a pytree with the same treedef as x
                where only the dtypes of the leaves are accessed

    Returns:
        A pytree where each leaf of x is cast to the dtype of the
        corresponding leaf in target. The imaginary part of complex leaves
        which are cast to real is discarded.
    """

    def _cast_leaf(leaf, target_leaf):
        # astype alone would also work, however that raises ComplexWarning
        # when casting complex to real, therefore the real part is taken
        # first where needed
        if jnp.iscomplexobj(target_leaf):
            source = leaf
        else:
            source = leaf.real
        return source.astype(target_leaf.dtype)

    return jax.tree_util.tree_map(_cast_leaf, x, target)
@jax.jit
def tree_axpy(a: Scalar, x: PyTree, y: PyTree) -> PyTree:
    r"""
    compute a * x + y

    Args:
        a: scalar; when ``is_scalar(a)`` is false, ``a`` is instead mapped
           leaf by leaf together with x and y
        x, y: pytrees with the same treedef

    Returns:
        The sum of the respective leaves of the two pytrees x and y
        where the leaves of x are first scaled with a.
    """
    if not is_scalar(a):
        # `a` is itself a pytree: scale every leaf of x by the matching leaf
        return jax.tree_util.tree_map(
            lambda a_leaf, x_leaf, y_leaf: a_leaf * x_leaf + y_leaf, a, x, y
        )

    return jax.tree_util.tree_map(lambda x_leaf, y_leaf: a * x_leaf + y_leaf, x, y)
class RealImagTuple(tuple):
    """
    A special kind of tuple which marks complex parameters which were split.
    Behaves like a regular tuple.

    The first entry is exposed as ``real`` and the second one as ``imag``.
    """

    @property
    def real(self):
        # First entry of the pair
        return self[0]

    @property
    def imag(self):
        # Second entry of the pair
        return self[1]


# Register the class as a pytree node: its entries are the children, there is
# no auxiliary data, and unflattening rebuilds a RealImagTuple.
register_pytree_node(
    RealImagTuple,
    lambda node: (node, None),
    lambda _aux, children: RealImagTuple(children),
)


def _tree_to_real(x):
    """Split a pytree with complex leaves into a RealImagTuple of two trees.

    If no leaf is complex, ``x`` is returned unchanged. Otherwise the first
    tree holds the real part of the complex leaves (other leaves are kept as
    they are) and the second tree holds the imaginary part of the complex
    leaves and ``None`` in place of the other leaves.
    """
    if not tree_leaf_iscomplex(x):
        return x

    # TODO find a way to make it a nop?
    # return jax.vmap(lambda y: jnp.array((y.real, y.imag)))(x)
    def _real_part(leaf):
        return leaf.real if jnp.iscomplexobj(leaf) else leaf

    def _imag_part(leaf):
        return leaf.imag if jnp.iscomplexobj(leaf) else None

    real_tree = jax.tree_util.tree_map(_real_part, x)
    imag_tree = jax.tree_util.tree_map(_imag_part, x)
    return RealImagTuple((real_tree, imag_tree))


def _tree_to_real_inverse(x):
    """Recombine a RealImagTuple into a single pytree with complex leaves.

    Anything that is not a RealImagTuple is returned unchanged.
    """
    if not isinstance(x, RealImagTuple):
        return x

    def _recombine(re, im):
        # not using jax.lax.complex because it would convert scalars to arrays
        if im is None:
            return re
        return re + 1j * im

    real_tree, imag_tree = x
    return jax.tree_util.tree_map(_recombine, real_tree, imag_tree)
def tree_to_real(pytree: PyTree) -> tuple[PyTree, Callable]:
    """Convert a pytree with complex leaves to a pytree with only real leaves.

    If at least one leaf is complex, the result is a RealImagTuple of two
    pytrees holding the real and imaginary parts; otherwise the input is
    returned unchanged.

    Args:
        pytree: a pytree to convert to real

    Returns:
        A pair where the first element is the converted real pytree,
        and the second element is a callable for converting back a real
        pytree to a complex pytree of the same structure as the input pytree.
    """
    real_pytree = _tree_to_real(pytree)
    return real_pytree, _tree_to_real_inverse
def compose(*funcs):
    """
    function composition

    compose(f,g,h)(x) is equivalent to f(g(h(x)))

    The innermost (last) function receives all positional and keyword
    arguments; every other function receives the single return value of the
    function to its right.
    """

    def _chain(outer, inner):
        def _chained(*args, **kwargs):
            return outer(inner(*args, **kwargs))

        return _chained

    return reduce(_chain, funcs)