netket.jax

This module contains internal utilities for working with JAX. This part of the API is not public and may change without notice.

Utility functions

HashablePartial

A class that behaves like functools.partial but keeps the same hash whenever it is created from the same function with the same partially applied arguments and keywords.
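A minimal sketch of the idea, assuming that "the same function" means the same function object. It subclasses functools.partial and derives equality and hashing from the wrapped function, the positional arguments and the keywords. The class name HashablePartialSketch is hypothetical and this is not NetKet's implementation. A plausible motivation (our assumption) is that such objects can be passed as static arguments to jax.jit without forcing a recompilation every time an equivalent partial is rebuilt.

```python
import functools


class HashablePartialSketch(functools.partial):
    """Hypothetical functools.partial whose equality and hash depend only on
    the wrapped function, the bound positional args and the bound keywords.
    Assumes all bound arguments and keyword values are themselves hashable."""

    def _key(self):
        # Sort keywords so that the order in which they were passed is irrelevant.
        return (self.func, self.args, tuple(sorted(self.keywords.items())))

    def __eq__(self, other):
        return type(other) is type(self) and self._key() == other._key()

    def __hash__(self):
        return hash(self._key())


def add(x, y, scale=1):
    return scale * (x + y)


# Two separately constructed partials compare equal and hash identically.
p1 = HashablePartialSketch(add, 1, scale=2)
p2 = HashablePartialSketch(add, 1, scale=2)
```

A plain functools.partial compares by identity, so two equivalent partials would look different to any cache keyed on them; the sketch removes that distinction.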

PRNGKey

Initialises a PRNGKey using an optional starting seed.

PRNGSeq

A sequence of PRNG keys generated based on an initial key.
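A minimal sketch of the pattern PRNGSeq describes, written as a plain Python generator. The name prng_seq_sketch is hypothetical and not NetKet's API. Each step splits the current key and hands out one half, so no key is ever reused and the whole stream is reproducible from the seed.

```python
import jax


def prng_seq_sketch(seed=0):
    """Hypothetical generator yielding an endless stream of fresh PRNG keys."""
    key = jax.random.PRNGKey(seed)
    while True:
        # Keep one half as the new state, hand out the other half.
        key, subkey = jax.random.split(key)
        yield subkey


keys = prng_seq_sketch(seed=42)
k1, k2 = next(keys), next(keys)
x1 = jax.random.normal(k1, (3,))
x2 = jax.random.normal(k2, (3,))
```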

mpi_split

Split a key across MPI nodes in the communicator.

Tree Linear Algebra

tree_dot

Compute the dot product of two pytrees.

tree_axpy

Compute a * x + y leaf-wise, for pytrees x and y with the same structure and a scalar a.
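The two operations above can be sketched with jax.tree_util.tree_map. The helper names are hypothetical, and whether NetKet's tree_dot conjugates either argument is not stated here, so this sketch simply does not.

```python
import jax
import jax.numpy as jnp


def tree_dot_sketch(a, b):
    """Sum over all leaves of sum(a_leaf * b_leaf); no conjugation applied."""
    leaf_dots = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(leaf_dots))


def tree_axpy_sketch(alpha, x, y):
    """Leaf-wise alpha * x + y for pytrees x and y of identical structure."""
    return jax.tree_util.tree_map(lambda xi, yi: alpha * xi + yi, x, y)


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"w": jnp.array([4.0, 5.0]), "b": jnp.array(6.0)}
dot = tree_dot_sketch(x, y)       # 1*4 + 2*5 + 3*6
z = tree_axpy_sketch(2.0, x, y)   # {"w": [6, 9], "b": 12}
```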

tree_cast

Cast the leaves of x to the dtypes of the corresponding leaves of target.

tree_conj

Conjugate all complex leaves.

tree_size

Returns the total number of elements across all leaves of the tree.

tree_leaf_iscomplex

Returns True if at least one leaf in the tree has a complex dtype.

tree_ishomogeneous

Returns True if all leaves have a real dtype or all leaves have a complex dtype.
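The three tree predicates above reduce to simple loops over the leaves. A sketch with hypothetical *_sketch names, assuming that any non-complex leaf (including integer leaves) counts as "real":

```python
import jax
import jax.numpy as jnp


def tree_size_sketch(tree):
    # Total number of scalar entries across all leaves.
    return sum(jnp.size(leaf) for leaf in jax.tree_util.tree_leaves(tree))


def tree_leaf_iscomplex_sketch(tree):
    # True if at least one leaf has a complex dtype.
    return any(jnp.iscomplexobj(leaf) for leaf in jax.tree_util.tree_leaves(tree))


def tree_ishomogeneous_sketch(tree):
    # True if the leaves are either all real or all complex.
    flags = [jnp.iscomplexobj(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
    return all(flags) or not any(flags)


mixed = {"a": jnp.ones((2, 3)), "b": jnp.zeros(4, dtype=jnp.complex64)}
real_only = {"a": jnp.ones(2), "c": jnp.zeros((1, 1))}
```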

tree_ravel

Ravel (i.e. flatten) a pytree of arrays down to a 1D array.
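JAX itself provides jax.flatten_util.ravel_pytree, which flattens a pytree into a single 1D array and returns a function that undoes the flattening. It is used here only to illustrate what raveling means; we do not assume it is what tree_ravel calls internally.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 2)), "b": jnp.arange(3.0)}

# Concatenate every leaf into one vector (dict keys are visited in sorted
# order, so "b" comes before "w"); `unravel` rebuilds the original pytree.
flat, unravel = ravel_pytree(params)
restored = unravel(flat)
```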

tree_to_real

Replace all complex leaves of a pytree with a RealImagTuple of 2 real leaves.

Dtype tools

dtype_complex

Return the complex dtype corresponding to the type passed in.

is_complex_dtype

Returns True if typ is a complex dtype.

maybe_promote_to_complex

Promotes the first argument to its complex counterpart, dtype_complex(typ), if any of the arguments is complex; otherwise the first argument is left unchanged.
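A sketch of the three dtype helpers in terms of NumPy's type-promotion rules, assuming that "the complex dtype corresponding to" a real floating type means the complex type of the same precision (float32 to complex64, float64 to complex128). The *_sketch names and the (typ, *types) signature are hypothetical.

```python
import numpy as np


def dtype_complex_sketch(typ):
    # Promoting against complex64 yields the complex type of matching
    # precision: float32 -> complex64, float64 -> complex128.
    return np.promote_types(typ, np.complex64)


def is_complex_dtype_sketch(typ):
    return np.issubdtype(typ, np.complexfloating)


def maybe_promote_to_complex_sketch(typ, *types):
    # Promote `typ` only if at least one of the given dtypes is complex.
    if any(is_complex_dtype_sketch(t) for t in (typ, *types)):
        return dtype_complex_sketch(typ)
    return np.dtype(typ)
```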

Complex-aware AD

expect

Computes the expectation value over a log-pdf.
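A minimal sketch of an expectation over samples drawn from p(s) ∝ exp(log_pdf(params, s)), whose gradient with respect to params includes the score-function (log-derivative) term. That this is the estimator behind `expect` is our assumption, and expect_sketch is a hypothetical name with a hypothetical signature; it also assumes a real-valued f and that f and log_pdf accept the whole batch of samples at once.

```python
import jax
import jax.numpy as jnp


def expect_sketch(log_pdf, f, params, samples):
    """Value: mean of f over the samples. Gradient: mean(grad f) plus
    mean((f - mean f) * grad log_pdf). log_pdf may be unnormalized, because
    centering f cancels the gradient of the normalization constant."""
    f_vals = f(params, samples)
    f_mean = jnp.mean(f_vals)
    weights = jax.lax.stop_gradient(f_vals - f_mean)
    score_term = jnp.mean(weights * log_pdf(params, samples))
    # Adds zero to the value but contributes the score term to the gradient.
    return f_mean + score_term - jax.lax.stop_gradient(score_term)


def log_pdf(theta, s):
    return theta * s  # unnormalized log-probability of s in {0, 1}


def f(theta, s):
    return s


# At theta = 0 both states are equally likely; E[s] = sigmoid(theta),
# whose derivative at 0 is 0.25.
samples = jnp.array([0.0, 1.0])
value = expect_sketch(log_pdf, f, 0.0, samples)
grad = jax.grad(lambda th: expect_sketch(log_pdf, f, th, samples))(0.0)
```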

vjp

Complex-aware counterpart of jax.vjp.

Return type: Union[tuple[Any, Callable], tuple[Any, Callable, Any]]

jacobian

Computes the jacobian of a NN model with respect to its parameters.

jacobian_default_mode

Returns the default mode for netket.jax.jacobian for a given wave-function ansatz.
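For a real-valued model, a per-sample Jacobian with respect to the parameters can be sketched by vectorizing jax.grad over the samples. The name per_sample_jacobian_sketch is hypothetical; the handling of complex-valued outputs, which is presumably what the mode returned by jacobian_default_mode controls, is deliberately left out.

```python
import jax
import jax.numpy as jnp


def per_sample_jacobian_sketch(apply_fn, params, samples):
    """Row i is d apply_fn(params, samples[i]) / d params (real outputs only)."""
    grad_fn = jax.grad(apply_fn)  # gradient for a single sample
    # Share params across the batch, map over axis 0 of samples.
    return jax.vmap(grad_fn, in_axes=(None, 0))(params, samples)


def linear_model(w, x):
    return jnp.dot(x, w)


w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.arange(6.0).reshape(2, 3)
# For a linear model the Jacobian rows are the inputs themselves.
jac = per_sample_jacobian_sketch(linear_model, w, xs)
```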

Chunked operations

chunk

Split an array (or a pytree of arrays) into chunks along the first axis.

unchunk

Merge the first two axes of an array (or a pytree of arrays). Parameter x_chunked: an array (or pytree of arrays) with at least 2 dimensions.
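chunk and unchunk are inverse reshapes of the leading axis. A sketch with hypothetical names, assuming the first axis of every leaf is divisible by chunk_size (reshape raises an error otherwise):

```python
import jax
import jax.numpy as jnp


def chunk_sketch(x, chunk_size):
    # (n, ...) -> (n // chunk_size, chunk_size, ...) for every leaf.
    return jax.tree_util.tree_map(
        lambda a: a.reshape((-1, chunk_size) + a.shape[1:]), x
    )


def unchunk_sketch(x_chunked):
    # (n_chunks, chunk_size, ...) -> (n_chunks * chunk_size, ...) for every leaf.
    return jax.tree_util.tree_map(
        lambda a: a.reshape((-1,) + a.shape[2:]), x_chunked
    )


x = {"s": jnp.arange(12.0).reshape(6, 2), "w": jnp.arange(6)}
xc = chunk_sketch(x, 3)
back = unchunk_sketch(xc)
```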

apply_chunked

Takes a function that is implicitly vectorized over axis 0 and uses scan to evaluate it in smaller chunks along axis 0 of all input arguments.

vmap_chunked

Behaves like jax.vmap but uses scan to split the computation into smaller chunks.
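A sketch of the chunked-vmap idea with jax.lax.scan, restricted to a single array argument whose leading dimension is divisible by chunk_size; vmap_chunked_sketch is a hypothetical name, not NetKet's signature. Each scan step applies jax.vmap(f) to one chunk, so intermediate buffers scale with chunk_size rather than with the full batch, at the cost of running the chunks sequentially.

```python
import jax
import jax.numpy as jnp


def vmap_chunked_sketch(f, chunk_size):
    """Vectorize single-sample `f` over axis 0, chunk_size samples per scan step."""
    batched_f = jax.vmap(f)

    def wrapped(x):
        n = x.shape[0]
        assert n % chunk_size == 0, "sketch assumes n is divisible by chunk_size"
        x_chunks = x.reshape((n // chunk_size, chunk_size) + x.shape[1:])

        def body(carry, x_chunk):
            # No state is carried between chunks; only outputs are stacked.
            return carry, batched_f(x_chunk)

        _, out_chunks = jax.lax.scan(body, None, x_chunks)
        # Merge (n_chunks, chunk_size, ...) back into (n, ...).
        return out_chunks.reshape((n,) + out_chunks.shape[2:])

    return wrapped


def sq_norm(v):
    return jnp.sum(v ** 2)


x = jnp.arange(12.0).reshape(6, 2)
out = vmap_chunked_sketch(sq_norm, 2)(x)
```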

vjp_chunked

Calculate the VJP in small chunks for a function where the leading dimension of the output depends only on the leading dimension of some of the arguments.

Math

logsumexp_cplx

Compute the log of the sum of exponentials of input elements, always returning a complex number.
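Returning a complex number lets the result represent sums that are negative: log(-|S|) = log|S| + iπ. A sketch built on jax.scipy.special.logsumexp with return_sign=True, assuming real inputs a and optional real weights b; logsumexp_cplx_sketch is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logsumexp_cplx_sketch(a, b=None):
    """log(sum(b * exp(a))) as a complex number (real a and b assumed)."""
    value, sign = logsumexp(a, b=b, return_sign=True)
    # log(sign) is 0 for a positive sum and i*pi for a negative one.
    return value + jnp.log(sign + 0j)


# 1*exp(0) - 2*exp(0) = -1, whose log is i*pi.
neg = logsumexp_cplx_sketch(jnp.array([0.0, 0.0]), b=jnp.array([1.0, -2.0]))
pos = logsumexp_cplx_sketch(jnp.array([0.0, jnp.log(2.0)]))
```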

logdet_cmplx

Log-determinant, with automatic upconversion to a complex output dtype in order to encode the sign.
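The sign can be encoded in the same way starting from jnp.linalg.slogdet: log det A = log|det A| + log(sign), where log(sign) is iπ for a negative real determinant and i times the phase for a complex matrix. logdet_cmplx_sketch is a hypothetical name, not NetKet's implementation.

```python
import jax.numpy as jnp


def logdet_cmplx_sketch(A):
    """log(det(A)) returned with a complex dtype so that the sign survives."""
    sign, logabsdet = jnp.linalg.slogdet(A)
    return logabsdet + jnp.log(sign + 0j)


A = jnp.array([[-2.0, 0.0], [0.0, 1.0]])  # det(A) = -2
result = logdet_cmplx_sketch(A)
```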