netket.jax.tree_ax

netket.jax.tree_ax(a, x)

Compute a * x, where a is a scalar and x is a pytree.

Parameters:
  • a (Any) – scalar

  • x (Any) – pytree

Return type:

Any

Returns:

The pytree x scaled by a, i.e. with every leaf multiplied by a.
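The scaling can be illustrated with a minimal sketch that does not import NetKet. It assumes that tree_ax multiplies every leaf of x by a using jax.tree_util.tree_map. The helper name tree_ax_sketch and the params dictionary are hypothetical and exist only for illustration.

```python
import jax
import jax.numpy as jnp


def tree_ax_sketch(a, x):
    # Hypothetical stand-in for netket.jax.tree_ax (assumption: the real
    # function multiplies every leaf of the pytree x by the scalar a).
    return jax.tree_util.tree_map(lambda leaf: a * leaf, x)


# A small pytree: a dict of arrays, like a typical parameter tree.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

scaled = tree_ax_sketch(2.0, params)
# scaled["w"] holds [2., 4.] and scaled["b"] holds 6.
```

The result keeps the dictionary structure of params; only the leaf values change. If NetKet is installed, netket.jax.tree_ax(2.0, params) is expected to return the same values.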