Compute the norm of a pytree, treating all of its leaves as a single flat 1D vector of values.
Equivalent to jnp.linalg.norm(jax.flatten_util.ravel_pytree(a)[0], ord).
- Parameters:
a (Any) – A pytree whose leaves are interpreted together as one vector.
ord (int) – The order p of the vector norm (the L^p norm) to compute. Defaults to 2, the Euclidean norm.
- Return type:
Any
- Returns:
A scalar: the ord-norm of all leaves of a, flattened and concatenated.
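The equivalence stated above can be checked with a small sketch. The function tree_norm_sketch below is hypothetical and is not this library's implementation. It assumes ord is a positive integer p and accumulates the sum of |x|^p leaf by leaf instead of materializing one concatenated vector, then compares the result against the ravel_pytree formula.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_norm_sketch(a, ord=2):
    """Hypothetical sketch: L^p norm of all leaves of `a` viewed as one vector.

    Assumes `ord` is a positive integer p, so the norm is (sum_i |x_i|^p)^(1/p).
    """
    leaves = jax.tree_util.tree_leaves(a)
    # Accumulate sum |x|^p leaf by leaf; no single flat vector is built.
    total = sum(jnp.sum(jnp.abs(x) ** ord) for x in leaves)
    return total ** (1.0 / ord)


# Usage: a small parameter pytree whose entries are 3, 0, 0, 0 and 4.
params = {"w": jnp.array([[3.0, 0.0], [0.0, 0.0]]), "b": jnp.array([4.0])}

# Reference value from the docstring's formula.
reference = jnp.linalg.norm(ravel_pytree(params)[0], 2)
print(jnp.allclose(tree_norm_sketch(params, 2), reference))
```

Working per leaf avoids allocating a copy of every parameter just to take a norm. The trade-off is that this sketch covers only positive integer orders, while the flattened formula inherits whatever orders jnp.linalg.norm accepts for vectors.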