netket.jax.tree_cast


netket.jax.tree_cast(x, target)

Cast x to the dtypes of target.

Parameters:
  • x (Any) – a pytree with arrays as leaves

  • target (Any) – a pytree with the same treedef as x where only the dtypes of the leaves are accessed

Return type:

Any

Returns:

A pytree where each leaf of x is cast to the dtype of the corresponding leaf in target. When a complex leaf is cast to a real dtype, its imaginary part is discarded.
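The behaviour described above can be illustrated with a minimal sketch in JAX. This is not NetKet's source: the name tree_cast_sketch is hypothetical, and the sketch assumes that x and target share the same tree structure and that every leaf of target exposes a .dtype attribute (for example a JAX array or a jax.ShapeDtypeStruct). It walks both trees in lockstep with jax.tree_util.tree_map and explicitly takes the real part before a complex-to-real cast.

```python
import jax
import jax.numpy as jnp


def tree_cast_sketch(x, target):
    """Hypothetical illustration: cast each leaf of `x` to the dtype of the
    matching leaf of `target`. Assumes identical tree structures and that
    every leaf of `target` has a `.dtype` attribute."""

    def cast_leaf(leaf, ref):
        dtype = ref.dtype  # only the dtype of the target leaf is used
        leaf = jnp.asarray(leaf)
        # Complex -> real: keep the real part explicitly, so the imaginary
        # part is dropped without relying on astype's implicit behaviour.
        if jnp.iscomplexobj(leaf) and not jnp.issubdtype(dtype, jnp.complexfloating):
            leaf = jnp.real(leaf)
        return leaf.astype(dtype)

    # tree_map pairs each leaf of `x` with the leaf at the same position in `target`.
    return jax.tree_util.tree_map(cast_leaf, x, target)


# Usage: a complex leaf cast to real, and a real leaf cast to complex.
params = {"w": jnp.array([1.0 + 2.0j, 3.0 - 1.0j]), "b": jnp.array([0.5])}
target = {"w": jnp.zeros(2, dtype=jnp.float32), "b": jnp.zeros(1, dtype=jnp.complex64)}
out = tree_cast_sketch(params, target)
print(out["w"].dtype)  # float32
print(out["b"].dtype)  # complex64
```

Because only the dtype of each target leaf is read, a tree of jax.ShapeDtypeStruct objects (for example the output of jax.eval_shape) would serve as a target in this sketch without allocating any arrays.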