netket.jax.tree_dot

netket.jax.tree_dot(a, b)

Compute the dot product of two pytrees.

Parameters
  • a (Any) – a pytree with the same treedef as b.

  • b (Any) – a pytree with the same treedef as a.

Return type

Any

Returns

A scalar equal to the dot product of the flattened arrays of a and b.
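The following is a minimal sketch of the computation described above, written in plain JAX rather than taken from NetKet's source. It assumes that matching leaves are multiplied elementwise, without complex conjugation, and summed; the function name tree_dot_sketch is hypothetical. The result is then checked against the dot product of the flattened arrays produced by jax.flatten_util.ravel_pytree.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_dot_sketch(a, b):
    """Hypothetical re-implementation: sum of elementwise products over all leaves.

    Assumes a and b share the same treedef and leaf shapes.
    No complex conjugation is applied (an assumption, not confirmed by the docs).
    """
    # Per-leaf dot product: multiply elementwise and sum to a scalar.
    leaf_dots = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    # Add up the per-leaf scalars into a single scalar.
    return jax.tree_util.tree_reduce(jnp.add, leaf_dots)


# Two pytrees with the same structure (a dict of arrays).
a = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([1.0, -1.0])}
b = {"w": jnp.array([[0.5, 0.5], [0.5, 0.5]]), "b": jnp.array([2.0, 3.0])}

result = tree_dot_sketch(a, b)

# Same value obtained by flattening both pytrees into 1-D vectors.
flat_a, _ = ravel_pytree(a)
flat_b, _ = ravel_pytree(b)
flat_result = jnp.dot(flat_a, flat_b)
```

Reducing each leaf to a scalar before summing across leaves avoids building one concatenated vector, yet gives the same value as the flattened dot product: here 0.5 * (1 + 2 + 3 + 4) + (2 - 3) = 4.0.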