netket.jax.tree_axpy


netket.jax.tree_axpy(a, x, y)

Compute a * x + y leaf by leaf over the pytrees x and y.
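A minimal sketch of an equivalent leaf-wise computation, assuming JAX array leaves and using `jax.tree_util.tree_map` to pair up the leaves of both trees. The name `tree_axpy_sketch` is hypothetical, and this is not claimed to be NetKet's actual implementation:

```python
import jax
import jax.numpy as jnp


def tree_axpy_sketch(a, x, y):
    """Hypothetical stand-in for netket.jax.tree_axpy: a * x + y leaf by leaf."""
    # tree_map pairs each leaf of x with the corresponding leaf of y,
    # so both trees must share the same structure (treedef).
    return jax.tree_util.tree_map(lambda x_leaf, y_leaf: a * x_leaf + y_leaf, x, y)


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
y = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(1.0)}

z = tree_axpy_sketch(2.0, x, y)
# z["w"] -> [12., 24.], z["b"] -> 2.0
```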

Parameters:
  • a (Any) – Scalar by which each leaf of x is multiplied.

  • x (Any) – Pytree with the same treedef as y.

  • y (Any) – Pytree with the same treedef as x.
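"Same treedef" means the same container structure: same nesting, same dict keys, same number of leaves. The sketch below, which assumes only standard JAX utilities, shows how to check this with `jax.tree_util.tree_structure` before combining two trees. The treedef does not record leaf shapes, so, assuming the leaf-wise expression a * x_leaf + y_leaf, paired leaves must also have compatible (equal or broadcastable) shapes.

```python
import jax
import jax.numpy as jnp

# Two parameter-like pytrees with identical structure: same keys, same nesting.
x = {"dense": {"kernel": jnp.ones((3, 2)), "bias": jnp.zeros((2,))}}
y = {"dense": {"kernel": jnp.zeros((3, 2)), "bias": jnp.ones((2,))}}

# The same kind of leaves stored in a flat list: a different treedef.
y_list = [jnp.zeros((3, 2)), jnp.ones((2,))]

same_structure = jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(y)
mismatched = jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(y_list)
print(same_structure, mismatched)  # True False
```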

Return type:

Any

Returns:

The pytree obtained by adding the respective leaves of x and y, where each leaf of x is first scaled by a.
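As an illustration, a plain gradient-descent step `params - lr * grads` has exactly this form, with a = -lr, x = grads and y = params. The sketch below spells the step out with `jax.tree_util.tree_map` so that it only assumes JAX; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.0)}
grads = {"w": jnp.array([0.5, 0.5]), "b": jnp.array(2.0)}
lr = 0.1

# a * x + y with a = -lr, x = grads, y = params, applied to each pair of leaves.
new_params = jax.tree_util.tree_map(lambda g, p: -lr * g + p, grads, params)
# new_params["w"] -> approx. [0.95, -1.05], new_params["b"] -> approx. -0.2
```

With NetKet installed, the same update would read `netket.jax.tree_axpy(-lr, grads, params)`.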