Compute the norm of a pytree, interpreted as a 1D vector of values.
Equivalent to jnp.linalg.norm(nk.jax.tree_ravel(a)[0], ord).
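The equivalence above amounts to "flatten every leaf into one vector, then take an ordinary vector norm". The sketch below illustrates this using only JAX. It assumes jax.flatten_util.ravel_pytree as a stand-in for nk.jax.tree_ravel, and tree_norm_sketch is a hypothetical helper name, not part of any library.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_norm_sketch(a, ord=2):
    # Hypothetical helper: flatten all leaves of the pytree into one 1D vector
    # (ravel_pytree is used here in place of nk.jax.tree_ravel).
    flat, _unravel = ravel_pytree(a)
    # Apply the ordinary vector norm of the requested order to that vector.
    return jnp.linalg.norm(flat, ord)


params = {"w": jnp.array([[3.0, 0.0], [0.0, 0.0]]), "b": jnp.array([4.0])}
print(tree_norm_sketch(params))  # Euclidean norm of [4, 3, 0, 0, 0]
```

Because the leaves are concatenated before the norm is taken, the pytree's nesting structure and leaf shapes do not affect the result.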
- Parameters:
a (Any) – A pytree, interpreted as a vector.
ord (int) – Order of the vector norm to compute, i.e. the L^ord norm. Defaults to 2 (the Euclidean norm).
- Return type:
Any
- Returns:
A scalar.
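To make the role of ord and the scalar return value concrete, here is a minimal sketch that computes the same quantity leaf by leaf instead of materializing the concatenated vector. It assumes a pytree of numeric JAX arrays, finite ord > 0 or ord = inf. The name tree_norm_leafwise is hypothetical, and this is an alternative formulation for illustration, not how the library necessarily implements it.

```python
import jax
import jax.numpy as jnp


def tree_norm_leafwise(a, ord=2):
    # Hypothetical helper: collect the array leaves of the pytree.
    leaves = jax.tree_util.tree_leaves(a)
    if ord == jnp.inf:
        # L-infinity norm: largest absolute entry across all leaves.
        return jnp.max(jnp.stack([jnp.max(jnp.abs(x)) for x in leaves]))
    # Finite L^ord norm: (sum over all entries of |x|^ord)^(1/ord).
    total = sum(jnp.sum(jnp.abs(x) ** ord) for x in leaves)
    return total ** (1.0 / ord)


params = {"w": jnp.array([[3.0, 0.0], [0.0, 0.0]]), "b": jnp.array([4.0])}
for p in (1, 2, jnp.inf):
    # Each call returns a 0-d array, i.e. a scalar.
    print(p, tree_norm_leafwise(params, p))
```

Summing per-leaf contributions gives the same value as the flattened vector norm, since the L^ord norm only depends on the multiset of entries, not on how they are grouped.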