netket.jax.tree_norm


netket.jax.tree_norm(a, ord=2)

Compute the norm of a pytree, treating all of its leaves as a single flattened 1D vector of values.

Equivalent to jnp.linalg.norm(nk.jax.tree_ravel(a)[0], ord).
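The equivalence above can be sketched without NetKet. This is a minimal illustration, assuming `jax.flatten_util.ravel_pytree` as a stand-in for `nk.jax.tree_ravel` (both flatten a pytree into one 1D array). The helper `tree_norm_sketch` is hypothetical and is not NetKet's actual implementation.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_norm_sketch(a, ord=2):
    """Hypothetical helper: norm of a pytree viewed as one flat vector.

    Mirrors the documented equivalence
    jnp.linalg.norm(nk.jax.tree_ravel(a)[0], ord),
    using jax.flatten_util.ravel_pytree in place of nk.jax.tree_ravel.
    """
    flat, _unravel = ravel_pytree(a)  # concatenate all leaves into one 1D array
    return jnp.linalg.norm(flat, ord)


# A small pytree: a dict holding a 2x2 matrix and a length-1 vector.
params = {"w": jnp.array([[3.0, 0.0], [0.0, 0.0]]), "b": jnp.array([4.0])}
print(tree_norm_sketch(params))  # sqrt(3**2 + 4**2) = 5.0
```

Because the leaves are concatenated before the norm is taken, the result does not depend on how the values are split across leaves or shaped within them.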

Parameters:
  • a (Any) – A pytree; its leaves are flattened and treated as a single vector.

  • ord (int) – Order of the vector norm to compute (ord=p gives the L^p norm). Defaults to 2, the Euclidean norm.
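To illustrate how ord changes the result, the sketch below applies the documented equivalence to a small example pytree. It assumes `jax.flatten_util.ravel_pytree` as a stand-in for `nk.jax.tree_ravel`; the pytree `tree` is hypothetical.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree; ravel_pytree stands in for nk.jax.tree_ravel.
tree = {"x": jnp.array([3.0, -4.0]), "y": jnp.array([12.0])}
flat, _ = ravel_pytree(tree)  # 1D vector [3., -4., 12.]

norm_l1 = jnp.linalg.norm(flat, 1)  # ord=1: sum of |v_i| = 3 + 4 + 12 = 19.0
norm_l2 = jnp.linalg.norm(flat, 2)  # ord=2 (default): sqrt(9 + 16 + 144) = 13.0
```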

Return type:

Any

Returns:

A scalar: the ord-norm of the flattened pytree.