netket.jax.tree_size


netket.jax.tree_size(tree)

Returns the sum of the sizes of all leaves in the pytree, i.e. the total number of scalar elements it contains.
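The counting rule can be illustrated with a minimal sketch built only on `jax.tree_util.tree_leaves`. The helper `tree_size_sketch` is hypothetical: it mirrors the documented behaviour and is not NetKet's own implementation. It assumes every leaf is an array or a Python scalar, and counts a scalar as one element.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree


def tree_size_sketch(tree) -> int:
    """Hypothetical stand-in for netket.jax.tree_size: sum the sizes of all leaves."""
    # tree_leaves flattens nested dicts/lists/tuples into a flat list of leaves.
    leaves = jax.tree_util.tree_leaves(tree)
    # np.size reads `.size` from arrays and treats Python scalars as size 1.
    return int(sum(np.size(leaf) for leaf in leaves))


params = {
    "dense": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros(4)},
    "scale": jnp.ones(()),
}
print(tree_size_sketch(params))  # 12 + 4 + 1 = 17

# Cross-check: flattening the whole pytree into one vector gives the same count.
flat, _ = ravel_pytree(params)
print(flat.size)  # 17
```

Because only the leaf sizes are summed, the nesting structure of the pytree does not affect the result. A 0-d array and a bare Python float each contribute 1.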

Return type:

int

Parameters:

tree (Any) – The pytree whose leaf sizes are summed.