netket.jax.tree_size

netket.jax.tree_size(tree)

Returns the sum of the sizes of all leaves in the tree, which equals the total number of scalars stored in the pytree.

Return type

int

Parameters

tree (Any) – the pytree whose leaf sizes are summed.
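To illustrate what "sum of the sizes of all leaves" means, here is a minimal sketch that computes the same quantity with `jax.tree_util.tree_leaves` and `numpy.size`. The function name `tree_size_sketch` and the example `params` dictionary are hypothetical and chosen for illustration only; this is not NetKet's source code, although for ordinary array leaves `netket.jax.tree_size(params)` is expected to give the same count.

```python
import jax
import jax.numpy as jnp
import numpy as np


def tree_size_sketch(tree) -> int:
    """Hypothetical re-implementation for illustration (not NetKet's code).

    Flattens the pytree into its leaves and adds up the number of scalar
    elements in each leaf.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    # np.size works on JAX arrays, NumPy arrays and Python scalars alike.
    return sum(int(np.size(leaf)) for leaf in leaves)


# A parameter pytree shaped like a small dense layer plus a scalar.
params = {
    "dense": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros(4)},
    "scale": jnp.ones(()),
}

print(tree_size_sketch(params))  # 3*4 + 4 + 1 = 17
```

A count like this is handy, for example, when reporting how many variational parameters a model has, since it does not depend on how the parameters are nested inside the pytree.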