netket.jax.tree_ravel

netket.jax.tree_ravel(pytree)

Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

Parameters:

pytree (Any) – a pytree to ravel

Return type:

tuple[Array, Callable]

Returns:

A pair whose first element is a 1D array containing the flattened and concatenated leaf values, and whose second element is a callable that unflattens a 1D vector of the same length back into a pytree with the same structure as the input pytree.
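The return contract above can be illustrated with a minimal sketch. The function `tree_ravel_sketch` below is hypothetical and is not NetKet's implementation. It assumes only standard JAX utilities (`jax.tree_util.tree_flatten`, `jax.tree_util.tree_unflatten`, `jnp.concatenate`, `jnp.split`) and records each leaf's shape and dtype so the returned callable can restore them. JAX's own `jax.flatten_util.ravel_pytree` follows the same (flat vector, unravel function) convention.

```python
import numpy as np
import jax
import jax.numpy as jnp


def tree_ravel_sketch(pytree):
    """Hypothetical illustration of the tree_ravel contract (not NetKet's source).

    Flattens every leaf to 1D, concatenates them into one vector, and
    returns a closure that maps a vector of the same length back to a
    pytree with the original structure, shapes, and dtypes.
    """
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    leaves = [jnp.asarray(x) for x in leaves]
    shapes = [x.shape for x in leaves]
    dtypes = [x.dtype for x in leaves]
    sizes = [x.size for x in leaves]
    # Offsets at which the flat vector is cut back into per-leaf chunks.
    split_points = [int(i) for i in np.cumsum(sizes)[:-1]]

    if leaves:
        flat = jnp.concatenate([jnp.ravel(x) for x in leaves])
    else:
        flat = jnp.zeros((0,))

    def unravel(vec):
        chunks = jnp.split(vec, split_points)
        # Restore each leaf's shape and cast back to its original dtype,
        # since concatenation may have promoted mixed dtypes.
        new_leaves = [
            c.reshape(s).astype(d) for c, s, d in zip(chunks, shapes, dtypes)
        ]
        return jax.tree_util.tree_unflatten(treedef, new_leaves)

    return flat, unravel


# Usage: flatten a small parameter dict, then map a vector back.
params = {"w": jnp.ones((2, 3)), "b": jnp.arange(3.0)}
flat, unravel = tree_ravel_sketch(params)
print(flat.shape)  # (9,)
restored = unravel(flat)
doubled = unravel(2.0 * flat)  # same structure, every value scaled by 2
```

A common use of such a pair is to apply a single vector operation (here, scaling) to all leaves at once and map the result back onto the original structure. Because `jnp.concatenate` promotes mixed leaf dtypes to a common dtype, the sketch casts each chunk back to its recorded dtype when unflattening.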