netket.jax.tree_axpy

netket.jax.tree_axpy(a, x, y)

Compute a * x + y leaf-wise over two pytrees.

Parameters:
a (Any) – scalar
x (Any) – pytree with the same treedef as y
y (Any) – pytree with the same treedef as x

Return type: Any

Returns: The sum of the respective leaves of the two pytrees x and y, where each leaf of x is first scaled by a.
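To illustrate the documented semantics, here is a minimal sketch that reproduces "a * x + y on every pair of matching leaves" with `jax.tree_util.tree_map`. The function name `tree_axpy_sketch` is hypothetical and this is not presented as NetKet's actual source; it only assumes that the behavior matches the description above (same treedef for x and y, scalar a applied to every leaf of x).

```python
import jax
import jax.numpy as jnp


def tree_axpy_sketch(a, x, y):
    """Hypothetical stand-in for netket.jax.tree_axpy: returns a * x + y leaf by leaf.

    tree_map walks x and y in lockstep, so both must share the same tree
    structure; otherwise JAX raises an error.
    """
    return jax.tree_util.tree_map(lambda xi, yi: a * xi + yi, x, y)


# Two pytrees (dicts of arrays) with identical structure.
x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"w": jnp.array([10.0, 20.0]), "b": jnp.array(0.5)}

out = tree_axpy_sketch(2.0, x, y)
# out["w"] holds [12., 24.] and out["b"] holds 6.5
```

A typical use of this pattern is a plain gradient-descent step on a parameter pytree: with learning rate `lr`, the update `params - lr * grads` is `a * x + y` with `a = -lr`, `x = grads`, `y = params`.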