netket.jax.tree_ravel
- netket.jax.tree_ravel(pytree)
Ravel (i.e. flatten) a pytree of arrays down to a 1D array.
- Parameters:
pytree (Any) – a pytree to ravel.
- Returns:
A pair whose first element is a 1D array holding the flattened and concatenated leaf values, and whose second element is a callable that unflattens a 1D vector of the same length back into a pytree with the same structure as the input pytree.
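To make the ravel/unravel contract concrete, here is a minimal sketch of the technique written directly with `jax.tree_util`. The helper `tree_ravel_sketch` is hypothetical and is not NetKet's implementation: it assumes at least one leaf and that all leaves share one dtype. For comparison, it also calls JAX's own `jax.flatten_util.ravel_pytree`, which follows the same flatten-and-concatenate idea.

```python
import math

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_ravel_sketch(pytree):
    """Hypothetical illustration of the ravel/unravel contract.

    Assumes at least one leaf and that all leaves share a single dtype
    (no dtype promotion is performed).
    """
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    sizes = [math.prod(shape) for shape in shapes]
    # Flatten each leaf to 1D and concatenate them in tree_flatten order.
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def unravel(vec):
        # Slice the 1D vector into per-leaf chunks and restore each shape.
        chunks, offset = [], 0
        for shape, size in zip(shapes, sizes):
            chunks.append(jnp.reshape(vec[offset:offset + size], shape))
            offset += size
        return jax.tree_util.tree_unflatten(treedef, chunks)

    return flat, unravel


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel = tree_ravel_sketch(params)  # flat has shape (9,)
restored = unravel(flat)                   # same dict structure and shapes

# Reference result from JAX's built-in utility, for comparison.
ref_flat, ref_unravel = ravel_pytree(params)
```

The unflattening callable closes over the tree definition and the leaf shapes recorded during flattening, which is why it can only rebuild vectors of the original length. `ravel_pytree` additionally promotes mixed leaf dtypes to a common dtype and casts them back on unravel, which this sketch deliberately omits.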