netket.jax.tree_dot



netket.jax.tree_dot(a, b)

Compute the dot product of two pytrees, treating each pytree as the concatenation of its flattened leaves.

Parameters:
  • a (Any) – a pytree with the same treedef as b

  • b (Any) – a pytree with the same treedef as a

Return type:

Any

Returns:

A scalar equal to the dot product of the flattened arrays of a and b.
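To illustrate what "dot product of the flattened arrays" means, here is a minimal sketch written directly with JAX. It is not NetKet's implementation: `tree_dot_sketch` is a hypothetical helper that multiplies matching leaves elementwise, sums each product, and adds the per-leaf results. It then checks the value against an explicit flatten-then-dot using `jax.flatten_util.ravel_pytree`. The sketch does not conjugate either argument, so for complex leaves it computes the plain sum of `x * y`. Consult the NetKet source for how the library itself treats complex inputs.

```python
import operator

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_dot_sketch(a, b):
    """Hypothetical illustration of a pytree dot product (not NetKet's code).

    Assumes `a` and `b` share the same treedef; jax.tree_util.tree_map
    raises an error if their structures differ.
    """
    # For each pair of matching leaves, multiply elementwise and sum to a scalar.
    leaf_dots = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    # Add the per-leaf scalars to get the dot product over all leaves.
    return jax.tree_util.tree_reduce(operator.add, leaf_dots)


# Two pytrees with the same structure (dict keys "w" and "b").
a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
b = {"w": jnp.array([4.0, 5.0]), "b": jnp.array(6.0)}

# Leaf-wise version: 1*4 + 2*5 + 3*6.
print(float(tree_dot_sketch(a, b)))  # 32.0

# Equivalent explicit version: flatten both pytrees, then take an ordinary dot.
flat_a, _ = ravel_pytree(a)
flat_b, _ = ravel_pytree(b)
print(float(jnp.dot(flat_a, flat_b)))  # 32.0
```

The leaf-wise form avoids materializing one long concatenated vector. Because both pytrees share a treedef, their leaves are visited in the same order, which is why the two computations agree.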