netket.jax

This module contains internal utilities for working with JAX. This part of the API is not public and can change without notice.

Utility functions

HashablePartial

A class that behaves like functools.partial but keeps the same hash whenever it is created from the same (lexically equivalent) function with the same partially applied arguments and keywords.
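Plain functools.partial objects compare and hash by identity, so two partials built from the same function and arguments look different to caches such as the one jax.jit keeps for static arguments. The following is a minimal sketch of the idea using a hypothetical class, HashablePartialSketch; it is not NetKet's implementation.

```python
import functools


class HashablePartialSketch(functools.partial):
    """Hypothetical partial whose equality and hash depend only on its contents."""

    def __eq__(self, other):
        return (
            type(other) is type(self)
            and self.func == other.func
            and self.args == other.args
            and self.keywords == other.keywords
        )

    def __hash__(self):
        # keywords is a dict, so convert it to a sorted tuple to make it hashable.
        return hash((self.func, self.args, tuple(sorted(self.keywords.items()))))


def f(x, y, scale=1.0):
    return scale * (x + y)


p1 = HashablePartialSketch(f, 1, scale=2.0)
p2 = HashablePartialSketch(f, 1, scale=2.0)
```

Here p1 and p2 are equal and share a hash, while two ordinary functools.partial(f, 1) objects are not equal.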

PRNGKey

Initialises a PRNGKey using an optional starting seed.
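A minimal sketch of "a key from an optional seed", assuming that a missing seed is replaced by a random one. The helper name prng_key_sketch is hypothetical, and the sketch does not attempt any synchronisation of the seed between processes.

```python
import secrets

import jax


def prng_key_sketch(seed=None):
    """Hypothetical helper: build a JAX PRNG key, drawing a seed if none is given."""
    if seed is None:
        # 31 random bits keep the seed inside the int32 range.
        seed = secrets.randbits(31)
    return jax.random.PRNGKey(seed)


key = prng_key_sketch(1234)
```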

PRNGSeq

A sequence of PRNG keys generated from an initial key.
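One common way to build such a sequence is to split the current key on every step, keeping one half as the new state and handing out the other. The sketch below follows that pattern with a hypothetical class, PRNGSeqSketch; NetKet's PRNGSeq may differ in interface and details.

```python
import jax


class PRNGSeqSketch:
    """Hypothetical iterator yielding a fresh PRNG key on every next() call."""

    def __init__(self, key):
        self._key = key

    def __iter__(self):
        return self

    def __next__(self):
        # Split the internal key: keep one half as new state, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


seq = PRNGSeqSketch(jax.random.PRNGKey(0))
keys = [next(seq) for _ in range(3)]
```

Successive keys differ, yet the whole sequence is reproducible from the initial key.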

mpi_split

Split a key across MPI nodes in the communicator.
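The idea behind splitting a key across processes can be sketched without MPI: every process derives the same set of subkeys from a shared key and keeps the one at its own rank. In this hypothetical helper, rank and n_ranks are passed explicitly instead of being read from a communicator, so it runs without mpi4py.

```python
import jax


def split_for_rank_sketch(key, rank, n_ranks):
    """Hypothetical helper: give each of n_ranks processes its own subkey."""
    # Every process computes the same n_ranks subkeys from the shared key...
    subkeys = jax.random.split(key, n_ranks)
    # ...and keeps only the one at its own index.
    return subkeys[rank]


shared = jax.random.PRNGKey(42)
k0 = split_for_rank_sketch(shared, rank=0, n_ranks=4)
k1 = split_for_rank_sketch(shared, rank=1, n_ranks=4)
```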

Tree Linear Algebra

tree_dot

Compute the dot product of two pytrees.
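A sketch of a pytree dot product, assuming both trees have the same structure and leaf shapes: multiply leaf by leaf, sum each product, then add up the per-leaf sums. The name tree_dot_sketch is hypothetical, and the sketch applies no complex conjugation; whether netket.jax.tree_dot conjugates is not stated here.

```python
import jax
import jax.numpy as jnp


def tree_dot_sketch(a, b):
    """Hypothetical helper: sum over all leaves of sum(a_leaf * b_leaf)."""
    products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return jax.tree_util.tree_reduce(jnp.add, products)


x = {"w": jnp.array([1.0, 2.0]), "bias": jnp.array(3.0)}
y = {"w": jnp.array([4.0, 5.0]), "bias": jnp.array(2.0)}
result = tree_dot_sketch(x, y)  # 1*4 + 2*5 + 3*2
```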

tree_axpy

Compute a * x + y, where x and y are pytrees.
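A leaf-by-leaf sketch of the same operation, assuming a is a scalar and x and y share the same structure; tree_axpy_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_axpy_sketch(a, x, y):
    """Hypothetical helper: a * x + y applied to every pair of matching leaves."""
    return jax.tree_util.tree_map(lambda xi, yi: a * xi + yi, x, y)


out = tree_axpy_sketch(2.0, {"p": jnp.array([1.0, 2.0])}, {"p": jnp.array([10.0, 10.0])})
```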

tree_cast

Cast x to the types of target.
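A sketch assuming "types" means the dtypes of the corresponding leaves of target. The helper name tree_cast_sketch is hypothetical; because it simply calls astype, casting a complex leaf to a real dtype discards the imaginary part.

```python
import jax
import jax.numpy as jnp


def tree_cast_sketch(x, target):
    """Hypothetical helper: cast each leaf of x to the dtype of the matching leaf of target."""
    return jax.tree_util.tree_map(lambda xi, ti: xi.astype(ti.dtype), x, target)


out = tree_cast_sketch({"a": jnp.array([1, 2])}, {"a": jnp.zeros(2, dtype=jnp.float32)})
```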

tree_conj

Conjugate all complex leaves.
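A minimal sketch with a hypothetical name, tree_conj_sketch: complex leaves are conjugated and real leaves are passed through untouched.

```python
import jax
import jax.numpy as jnp


def tree_conj_sketch(t):
    """Hypothetical helper: conjugate complex leaves, return real leaves unchanged."""
    return jax.tree_util.tree_map(
        lambda x: jnp.conj(x) if jnp.iscomplexobj(x) else x, t
    )


out = tree_conj_sketch({"c": jnp.array([1 + 2j]), "r": jnp.array([1.0])})
```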

tree_size

Returns the sum of the size of all leaves in the tree.
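For example, a weight matrix of shape (3, 4) and a bias of shape (4,) hold 12 + 4 = 16 numbers in total. A one-line sketch, with the hypothetical name tree_size_sketch:

```python
import jax
import jax.numpy as jnp


def tree_size_sketch(t):
    """Hypothetical helper: total number of elements across all leaves."""
    return sum(jnp.size(x) for x in jax.tree_util.tree_leaves(t))


n = tree_size_sketch({"w": jnp.zeros((3, 4)), "b": jnp.zeros(4)})
```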

tree_leaf_iscomplex

Returns True if at least one leaf in the tree has a complex dtype.

tree_ishomogeneous

Returns True if either all leaves have a real dtype or all leaves have a complex dtype.
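Both predicates above reduce to checking each leaf's dtype. A sketch with hypothetical names; note that in this sketch an empty tree counts as homogeneous.

```python
import jax
import jax.numpy as jnp


def tree_leaf_iscomplex_sketch(t):
    """Hypothetical helper: True if any leaf is complex."""
    return any(jnp.iscomplexobj(x) for x in jax.tree_util.tree_leaves(t))


def tree_ishomogeneous_sketch(t):
    """Hypothetical helper: True if leaves are all real or all complex."""
    flags = [jnp.iscomplexobj(x) for x in jax.tree_util.tree_leaves(t)]
    return all(flags) or not any(flags)


mixed = {"a": jnp.zeros(2), "b": jnp.zeros(2, dtype=jnp.complex64)}
real = {"a": jnp.zeros(2), "b": jnp.ones(3)}
```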

tree_ravel

Ravel (i.e. flatten) a pytree of arrays down to a 1D array.
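JAX ships a standard utility for this operation, jax.flatten_util.ravel_pytree, which concatenates all leaves into a single 1D array and also returns a function that rebuilds the original tree. The example uses that JAX function directly to illustrate the operation; it does not claim that netket.jax.tree_ravel is implemented this way.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

tree = {"a": jnp.ones(2), "b": jnp.zeros((2, 2))}

# flat holds all 6 numbers; unravel maps a flat vector back to the tree shape.
flat, unravel = ravel_pytree(tree)
restored = unravel(flat)
```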

tree_to_real

Replace all complex leaves of a pytree with a tuple of 2 real leaves.
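Splitting complex leaves into (real, imaginary) pairs lets code that only handles real arrays process complex parameters. A sketch with the hypothetical name tree_to_real_sketch; it performs only the forward transformation.

```python
import jax
import jax.numpy as jnp


def tree_to_real_sketch(t):
    """Hypothetical helper: replace each complex leaf by a (real, imag) tuple."""
    return jax.tree_util.tree_map(
        lambda x: (jnp.real(x), jnp.imag(x)) if jnp.iscomplexobj(x) else x, t
    )


out = tree_to_real_sketch({"a": jnp.array([1 + 2j]), "b": jnp.array([3.0])})
re, im = out["a"]
```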

Dtype tools

dtype_complex

Return the complex dtype corresponding to the type passed in.
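A sketch that finds the complex counterpart by promoting the input with complex64, the smallest complex type: float32 maps to complex64 and float64 to complex128. The name dtype_complex_sketch is hypothetical, and integer inputs follow NumPy's promotion rules, which may not match NetKet's behaviour.

```python
import numpy as np


def dtype_complex_sketch(typ):
    """Hypothetical helper: complex dtype matching the precision of typ."""
    # Promoting with complex64 keeps single precision and lifts double precision.
    return np.promote_types(typ, np.complex64)
```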

is_complex

Returns True if x has a complex dtype.

is_complex_dtype

Returns True if typ is a complex dtype.
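The two checks differ only in their input: is_complex inspects a value, is_complex_dtype inspects a dtype. A sketch of both with hypothetical names:

```python
import jax.numpy as jnp
import numpy as np


def is_complex_sketch(x):
    """Hypothetical helper: does the array x have a complex dtype?"""
    return jnp.iscomplexobj(x)


def is_complex_dtype_sketch(typ):
    """Hypothetical helper: is typ itself a complex dtype?"""
    return np.issubdtype(typ, np.complexfloating)
```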

maybe_promote_to_complex

Promotes the first argument to its complex counterpart, given by dtype_complex(typ), if any of the arguments is complex; otherwise returns it unchanged.
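A sketch assuming the signature (typ, *types) and taking the complex counterpart via NumPy promotion with complex64; both the name maybe_promote_to_complex_sketch and that promotion rule are assumptions.

```python
import numpy as np


def maybe_promote_to_complex_sketch(typ, *types):
    """Hypothetical helper: complex counterpart of typ if any argument is complex, else typ."""
    if any(np.issubdtype(t, np.complexfloating) for t in (typ, *types)):
        # Assumed rule for the complex counterpart: promote with complex64.
        return np.promote_types(typ, np.complex64)
    return np.dtype(typ)
```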

Complex-aware AD

expect

Computes the expectation value over a log-pdf.
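When the samples come from a distribution p_θ that itself depends on the parameters, differentiating a plain sample mean of f misses the contribution of the log-pdf: ∂θ E[f] = E[∂θ f] + E[(f - E[f]) ∂θ log p_θ]. The sketch below encodes this identity with jax.lax.stop_gradient, using the hypothetical names expect_sketch, log_pdf and f. It handles real-valued f only and is one standard construction, not NetKet's implementation.

```python
import jax
import jax.numpy as jnp


def expect_sketch(log_pdf, f, theta, samples):
    """Hypothetical helper: Monte Carlo estimate of E_{x ~ p_theta}[f(theta, x)].

    samples are assumed to be drawn from p_theta and are treated as constants.
    """
    logp = log_pdf(theta, samples)
    f_vals = f(theta, samples)
    f_mean = jnp.mean(f_vals)
    # Centred values, frozen so they act only as weights in the gradient.
    centred = jax.lax.stop_gradient(f_vals - f_mean)
    # Zero in value, but its gradient is mean((f - <f>) * d log p / d theta).
    score_term = jnp.mean(centred * (logp - jax.lax.stop_gradient(logp)))
    return f_mean + score_term


def log_pdf(theta, x):
    return -0.5 * (x - theta) ** 2  # unnormalised Gaussian with unit variance


def f(theta, x):
    return x ** 2  # E[x^2] = theta^2 + 1, so d/dtheta E[x^2] = 2 * theta


theta = 1.5
samples = theta + jax.random.normal(jax.random.PRNGKey(0), (200_000,))
value = expect_sketch(log_pdf, f, theta, samples)
grad = jax.grad(lambda t: expect_sketch(log_pdf, f, t, samples))(theta)
```

Centring f before multiplying by the log-pdf also cancels any parameter-dependent normalisation constant, because the centred values sum to zero; this is why an unnormalised log_pdf suffices in the usage example.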

vjp

Return type: Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]
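The return type above has the same two shapes as the result of JAX's own jax.vjp: a pair (output, pullback) or, with has_aux=True, a triple (output, pullback, aux). The example shows both shapes with jax.vjp itself. It does not show the complex-aware behaviour of netket.jax.vjp, whose signature is not given here.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


def f_with_aux(x):
    # The second output is auxiliary data and is not differentiated.
    return jnp.sum(x ** 2), {"max": jnp.max(x)}


x = jnp.array([1.0, 2.0, 3.0])

# Two-element form: (output, pullback).
y, pullback = jax.vjp(f, x)
(grad_x,) = pullback(jnp.ones_like(y))

# Three-element form: (output, pullback, aux).
y2, pullback2, aux = jax.vjp(f_with_aux, x, has_aux=True)
```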

Chunked operations

chunk

Split an array (or a pytree of arrays) into chunks along the first axis.

unchunk

Merge the first two axes of an array (or a pytree of arrays). The argument x_chunked must be an array (or a pytree of arrays) with at least 2 dimensions.
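Chunking and unchunking are inverse reshapes of the leading axis: (n, ...) becomes (n // chunk_size, chunk_size, ...) and back. A sketch with hypothetical names, assuming n is divisible by chunk_size (otherwise the reshape fails):

```python
import jax
import jax.numpy as jnp


def chunk_sketch(x, chunk_size):
    """Hypothetical helper: reshape every leaf from (n, ...) to (n // chunk_size, chunk_size, ...)."""
    return jax.tree_util.tree_map(lambda a: a.reshape((-1, chunk_size) + a.shape[1:]), x)


def unchunk_sketch(x_chunked):
    """Hypothetical helper: merge the first two axes of every leaf (needs ndim >= 2)."""
    return jax.tree_util.tree_map(lambda a: a.reshape((-1,) + a.shape[2:]), x_chunked)


tree = {"samples": jnp.arange(12).reshape(6, 2), "weights": jnp.arange(6.0)}
chunked = chunk_sketch(tree, chunk_size=3)
restored = unchunk_sketch(chunked)
```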

vjp_chunked

Calculate the vjp in small chunks for a function where the leading dimension of the output depends only on the leading dimension of some of the arguments.
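Because row i of the output depends only on row i of the batched argument, the vjp with respect to shared parameters is the sum of the vjps of each chunk, so only one chunk's intermediates need to be alive at a time. A sketch under those assumptions, differentiating with respect to params only; vjp_chunked_sketch and its argument order are hypothetical.

```python
import jax
import jax.numpy as jnp


def vjp_chunked_sketch(f, params, x, cotangent, chunk_size):
    """Hypothetical helper: VJP of params -> f(params, x), computed chunk by chunk.

    Assumes row i of f(params, x) depends only on row i of x, and that
    x.shape[0] is divisible by chunk_size.
    """
    # Split the batched argument and the cotangent into (n_chunks, chunk_size, ...).
    xc = x.reshape((-1, chunk_size) + x.shape[1:])
    vc = cotangent.reshape((-1, chunk_size) + cotangent.shape[1:])

    def body(acc, chunk):
        x_i, v_i = chunk
        # Pull back only this chunk's cotangent through this chunk's outputs.
        _, pullback = jax.vjp(lambda p: f(p, x_i), params)
        (g,) = pullback(v_i)
        # params is shared by all chunks, so their contributions add up.
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(body, zeros, (xc, vc))
    return total


def f(p, x):
    return x @ p  # one output per row of x


key_p, key_x, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
params = jax.random.normal(key_p, (3,))
x = jax.random.normal(key_x, (8, 3))
v = jax.random.normal(key_v, (8,))
g_chunked = vjp_chunked_sketch(f, params, x, v, chunk_size=2)
```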

vmap_chunked

Behaves like jax.vmap, but uses scan to split the computation into smaller chunks.
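A sketch of the idea for a function of a single array argument: reshape the batch into chunks, run jax.lax.scan over the chunks so that only one chunk is vectorised at a time, and flatten the result back. The name vmap_chunked_sketch is hypothetical, and the batch size is assumed divisible by chunk_size.

```python
import jax
import jax.numpy as jnp


def vmap_chunked_sketch(f, chunk_size):
    """Hypothetical helper: like jax.vmap(f), but evaluated chunk_size rows at a time."""
    vf = jax.vmap(f)

    def wrapped(x):
        n = x.shape[0]
        xc = x.reshape((n // chunk_size, chunk_size) + x.shape[1:])

        def body(carry, x_chunk):
            # No state is carried; each step just maps f over one chunk.
            return carry, vf(x_chunk)

        _, yc = jax.lax.scan(body, None, xc)
        # Merge the (n_chunks, chunk_size) leading axes back into one.
        return yc.reshape((n,) + yc.shape[2:])

    return wrapped


def sq_norm(v):
    return jnp.sum(v ** 2)


x = jnp.arange(12.0).reshape(6, 2)
out = vmap_chunked_sketch(sq_norm, chunk_size=3)(x)
```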