netket.jax.tree_axpy

netket.jax.tree_axpy(a, x, y)

Compute a * x + y leaf by leaf, where x and y are pytrees and a is a scalar.

Parameters
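  • a (Any) – scalar by which every leaf of x is multiplied.

  • x (Any) – pytree to be scaled by a; must have the same treedef as y.

  • y (Any) – pytree added to the scaled x; must have the same treedef as x.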
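The same-treedef requirement on x and y can be checked with plain JAX before combining the two pytrees. The helper have_same_treedef below is hypothetical and not part of NetKet; it is a sketch that only relies on jax.tree_util.tree_structure, which returns the treedef of a pytree.

```python
import jax
import jax.numpy as jnp


def have_same_treedef(x, y) -> bool:
    """Hypothetical helper (not NetKet API): True if x and y share a treedef."""
    # tree_structure discards the leaves and keeps only the container layout
    # (dict keys, tuple/list arity, nesting), which is what "treedef" refers to.
    return jax.tree_util.tree_structure(x) == jax.tree_util.tree_structure(y)


# Two parameter pytrees with matching structure: same dict keys, same nesting.
x = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
y = {"w": jnp.full((2, 2), 3.0), "b": jnp.ones(2)}

# A pytree whose structure differs from x (the "b" key is missing).
z = {"w": jnp.ones((2, 2))}

print(have_same_treedef(x, y))  # True
print(have_same_treedef(x, z))  # False
```

Treedef equality says nothing about leaf shapes or dtypes: each pair of corresponding leaves must still support the elementwise expression a * x_leaf + y_leaf on its own.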

Return type

Any

Returns

A pytree with the same structure as x and y, whose leaves are the corresponding leaves of x scaled by a and then added to the corresponding leaves of y.
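The leaf-wise semantics described above can be sketched with jax.tree_util.tree_map. The function tree_axpy_sketch is a hypothetical re-implementation written for illustration; it is not claimed to be how NetKet implements tree_axpy, only a way to see what result to expect.

```python
import jax
import jax.numpy as jnp


def tree_axpy_sketch(a, x, y):
    """Hypothetical sketch (not NetKet's code): leaves become a * x_leaf + y_leaf."""
    # tree_map walks x and y in lockstep and raises an error if their
    # structures differ, mirroring the same-treedef requirement.
    return jax.tree_util.tree_map(lambda x_leaf, y_leaf: a * x_leaf + y_leaf, x, y)


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(1.0)}

result = tree_axpy_sketch(0.5, x, y)
# result["w"] holds [0.5 * 1 + 10, 0.5 * 2 + 20] = [10.5, 21.0]
# result["b"] holds 0.5 * 3 + 1 = 2.5
print(result)
```

With NetKet installed, calling netket.jax.tree_axpy(0.5, x, y) on the same inputs should produce a pytree with the same structure and leaf values, per the description above.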