netket.jax.tree_ravel

netket.jax.tree_ravel(pytree)

Ravel (i.e. flatten) a pytree of arrays down to a 1D array.

Parameters

pytree (Any) – a pytree to ravel

Return type

Tuple[ndarray, Callable]

Returns

A pair whose first element is a 1D array holding the flattened and concatenated leaf values, and whose second element is a callable that unflattens a 1D vector of the same length back into a pytree with the same structure as the input pytree.
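To illustrate the ravel/unravel contract described above, here is a minimal sketch built only on `jax.tree_util`. The function name `tree_ravel_sketch` is hypothetical, and this is not necessarily how NetKet implements `tree_ravel`. JAX itself ships `jax.flatten_util.ravel_pytree` with a similar `(flat, unravel)` return pair. The sketch assumes all leaves are array-like and restores each leaf's original shape and dtype on unravel.

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


def tree_ravel_sketch(pytree):
    """Hypothetical re-implementation illustrating the tree_ravel contract."""
    # Split the pytree into its leaves and a description of its structure.
    leaves, treedef = tree_flatten(pytree)
    leaves = [jnp.asarray(x) for x in leaves]
    shapes = [x.shape for x in leaves]
    sizes = [x.size for x in leaves]
    dtypes = [x.dtype for x in leaves]

    # Flatten every leaf and concatenate them into one 1D array
    # (jnp.concatenate promotes mixed dtypes to a common dtype).
    if leaves:
        flat = jnp.concatenate([jnp.ravel(x) for x in leaves])
    else:
        flat = jnp.zeros((0,))

    def unravel(vec):
        # Cut the vector back into per-leaf chunks, restoring shape and dtype.
        if vec.shape != flat.shape:
            raise ValueError(f"expected shape {flat.shape}, got {vec.shape}")
        chunks, start = [], 0
        for shape, size, dtype in zip(shapes, sizes, dtypes):
            chunks.append(jnp.reshape(vec[start:start + size], shape).astype(dtype))
            start += size
        return tree_unflatten(treedef, chunks)

    return flat, unravel


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel = tree_ravel_sketch(params)
restored = unravel(flat)
print(flat.shape)  # (9,)
```

Because JAX flattens dicts in sorted key order, the three entries of `"b"` come first in `flat`, followed by the six entries of `"w"`. Applying `unravel` to any vector of length 9 (for example, an updated parameter vector) yields a dict with the same keys and leaf shapes.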