netket.jax.tree_to_real

netket.jax.tree_to_real(pytree)

Replace every complex leaf of a pytree with a RealImagTuple holding two real leaves: the real part and the imaginary part.

Parameters:

pytree (Any) – a pytree to convert to real

Return type:

tuple[Any, Callable]

Returns:

A pair whose first element is the converted real pytree and whose second element is a callable that converts a real pytree back into a complex pytree with the same structure as the input pytree.
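The round trip described above can be illustrated with a minimal pure-JAX sketch. It does not call NetKet: `tree_to_real_sketch` and `RealImagPair` are hypothetical names standing in for `tree_to_real` and `RealImagTuple`, and NetKet's actual implementation may differ (for example, in how its reconstructing callable uses the original tree structure).

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class RealImagPair(NamedTuple):
    """Hypothetical stand-in for NetKet's RealImagTuple.

    NamedTuples are pytree nodes in JAX, so each field becomes its own leaf.
    """
    real: Any
    imag: Any


def tree_to_real_sketch(pytree: Any) -> tuple[Any, Callable[[Any], Any]]:
    """Sketch only: split complex leaves into (real, imag) pairs of real arrays."""

    def split(leaf):
        # Complex leaves become a pair of two real leaves; real leaves pass through.
        if jnp.iscomplexobj(leaf):
            return RealImagPair(jnp.real(leaf), jnp.imag(leaf))
        return leaf

    real_tree = jax.tree_util.tree_map(split, pytree)

    def to_complex(tree: Any) -> Any:
        def merge(node):
            # Recombine a (real, imag) pair into one complex array.
            if isinstance(node, RealImagPair):
                return jax.lax.complex(node.real, node.imag)
            return node

        # Stop descending at RealImagPair nodes so merge sees the whole pair.
        return jax.tree_util.tree_map(
            merge, tree, is_leaf=lambda x: isinstance(x, RealImagPair)
        )

    return real_tree, to_complex


params = {"w": jnp.array([1.0 + 2.0j, 3.0 - 1.0j]), "b": jnp.array([0.5])}
real_params, restore = tree_to_real_sketch(params)
restored = restore(real_params)
```

Here `real_params` has three real leaves (`b`, plus the real and imaginary parts of `w`), which makes it suitable for code that only handles real arrays, and `restore` maps it back to a complex pytree shaped like `params`.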