netket.jax#
This module contains internal utilities for working with JAX. This part of the API is not public and can change without notice.
Utility functions#
A class behaving like functools.partial, but that retains its hash if it is created with a lexically equivalent (the same) function and with the same partially applied arguments and keywords. |
|
Initialises a PRNGKey using an optional starting seed. |
|
A sequence of PRNG keys generated based on an initial key. |
|
Split a key across MPI nodes in the communicator. |
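The hashing behaviour described for the first entry above can be illustrated with a minimal pure-Python sketch. `HashablePartialSketch` is a hypothetical stand-in, not the class from this module: it compares the wrapped function by identity, which is a simplification of "lexically equivalent", and it assumes all bound arguments are themselves hashable.

```python
from functools import partial


class HashablePartialSketch:
    """Hypothetical illustration of a partial whose hash depends only on
    the wrapped function and the bound arguments."""

    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def _key(self):
        # Keyword arguments are sorted so that their order does not matter.
        return (self.func, self.args, tuple(sorted(self.kwargs.items())))

    def __eq__(self, other):
        return isinstance(other, HashablePartialSketch) and self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    def __call__(self, *args, **kwargs):
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})


def scale(x, factor=1.0):
    return x * factor


# Two functools.partial objects built the same way do not compare equal ...
p1 = partial(scale, factor=2.0)
p2 = partial(scale, factor=2.0)

# ... while two instances of the sketch do, and they share a hash.
h1 = HashablePartialSketch(scale, factor=2.0)
h2 = HashablePartialSketch(scale, factor=2.0)
```

A stable hash of this kind matters when the object is passed as a static argument to `jax.jit`, since an equal hash lets the compiled function be reused instead of retraced.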
Tree Linear Algebra#
Compute the dot product of two pytrees. |
|
Compute a * x + y. |
|
Cast x to the types of target. |
|
Conjugate all complex leaves. |
|
Returns the sum of the size of all leaves in the tree. |
|
Returns true if at least one leaf in the tree has complex dtype. |
|
Returns true if all leaves have real dtype or all leaves have complex dtype. |
|
Ravel (i.e. flatten) a pytree of arrays down to a 1D array. |
|
Replace all complex leaves of a pytree with a RealImagTuple of 2 real leaves. |
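A few of the tree operations listed above can be sketched with plain `jax.tree_util` calls. The `*_sketch` names are hypothetical and are not the functions of this module. The dot product below is assumed to be a plain sum of leaf-wise products without conjugation, which the descriptions above do not specify.

```python
import jax
import jax.numpy as jnp


def tree_dot_sketch(a, b):
    # Sum of element-wise products over every pair of matching leaves.
    # Assumption: no complex conjugation is applied.
    prods = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(prods))


def tree_axpy_sketch(a, x, y):
    # Leaf-wise a * x + y for a scalar a.
    return jax.tree_util.tree_map(lambda xi, yi: a * xi + yi, x, y)


def tree_size_sketch(tree):
    # Total number of elements across all leaves.
    return sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))


def tree_has_complex_leaf_sketch(tree):
    # True if at least one leaf has a complex dtype.
    return any(jnp.iscomplexobj(leaf) for leaf in jax.tree_util.tree_leaves(tree))


x = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
y = {"w": jnp.array([0.5, 0.5]), "b": jnp.array([1.0])}

dot_xy = tree_dot_sketch(x, y)      # 1*0.5 + 2*0.5 + 3*1
z = tree_axpy_sketch(2.0, x, y)     # same structure as x and y
```

For flattening a pytree to a 1D array, `jax.flatten_util.ravel_pytree` in JAX itself provides comparable behaviour and also returns a function to undo the flattening.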
Dtype tools#
Return the complex dtype corresponding to the type passed in. |
|
Returns True if typ is a complex dtype. |
|
Promotes the first argument to its complex counterpart, given by dtype_complex(typ), if any of the arguments is complex; otherwise leaves it unchanged.
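The dtype rules described above can be approximated with NumPy's promotion machinery. This is a sketch under stated assumptions, not the module's code: the `*_sketch` names are hypothetical, and using `np.result_type` with `complex64` to find the complex counterpart is a choice made here for illustration.

```python
import numpy as np


def dtype_complex_sketch(typ):
    # Promoting against complex64 keeps the precision of float32/float64
    # inputs while making the result complex.
    return np.result_type(typ, np.complex64)


def is_complex_dtype_sketch(typ):
    return bool(np.issubdtype(typ, np.complexfloating))


def maybe_promote_to_complex_sketch(typ, *others):
    # Promote `typ` only if some argument (including `typ`) is complex.
    if any(is_complex_dtype_sketch(t) for t in (typ, *others)):
        return dtype_complex_sketch(typ)
    return np.dtype(typ)


single = dtype_complex_sketch(np.float32)   # complex counterpart of float32
double = dtype_complex_sketch(np.float64)   # complex counterpart of float64
kept = maybe_promote_to_complex_sketch(np.float32, np.float64)
promoted = maybe_promote_to_complex_sketch(np.float32, np.complex128)
```

Note that in this sketch the promoted type follows the precision of the first argument only, so `float32` combined with `complex128` yields `complex64`.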
Complex-aware AD#
Computes the expectation value over a log-pdf. |
|
Computes the jacobian of a NN model with respect to its parameters. |
|
Returns the default mode for netket.jax.jacobian given a certain wave-function ansatz.
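To make "the jacobian of a NN model with respect to its parameters" concrete, here is a minimal sketch for the simplest case: real parameters and a real-valued output, using only `jax.vmap` and `jax.grad`. The model `log_psi` and all variable names are hypothetical. The complex-aware handling that the functions above provide is not reproduced here.

```python
import jax
import jax.numpy as jnp


def log_psi(params, x):
    # Hypothetical toy model: a real, linear log-amplitude for one sample x.
    return jnp.dot(params["w"], x) + params["b"]


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
samples = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# One gradient per sample: parameters are shared (None), samples are mapped (0).
per_sample_grad = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))
jac = per_sample_grad(params, samples)
# jac is a pytree with the structure of params and a leading sample axis:
# jac["w"] has shape (3, 2) and jac["b"] has shape (3,).
```

For this linear model the row of `jac["w"]` for each sample is the sample itself, which gives an easy sanity check.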
Chunked operations#
Split an array (or a pytree of arrays) into chunks along the first axis. |
|
Merge the first two axes of an array (or a pytree of arrays). The input x_chunked is an array (or pytree of arrays) of at least 2 dimensions. |
|
Takes an implicitly vmapped function over the axis 0 and uses scan to do the computations in smaller chunks over the 0-th axis of all input arguments. |
|
Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks. |
|
Calculate the VJP in small chunks for a function where the leading dimension of the output only depends on the leading dimension of some of the arguments.
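The chunking pattern described above (split along axis 0, process chunk by chunk with a scan, merge back) can be sketched as follows. The `*_sketch` names are hypothetical, and this version assumes the leading axis length is an exact multiple of the chunk size. `jax.lax.map` is used because it is implemented on top of scan.

```python
import jax
import jax.numpy as jnp


def chunk_sketch(x, chunk_size):
    # Reshape every leaf from (N, ...) to (N // chunk_size, chunk_size, ...).
    def _chunk(a):
        n = a.shape[0]
        if n % chunk_size != 0:
            raise ValueError("leading axis must be a multiple of chunk_size")
        return a.reshape((n // chunk_size, chunk_size) + a.shape[1:])

    return jax.tree_util.tree_map(_chunk, x)


def unchunk_sketch(x_chunked):
    # Merge the first two axes of every leaf.
    return jax.tree_util.tree_map(
        lambda a: a.reshape((-1,) + a.shape[2:]), x_chunked
    )


def apply_chunked_sketch(f, x, chunk_size):
    # f is assumed to act independently on each row along axis 0, so that
    # evaluating it chunk by chunk gives the same result as f(x).
    y_chunked = jax.lax.map(f, chunk_sketch(x, chunk_size))
    return unchunk_sketch(y_chunked)


def row_norm_sq(xb):
    # Implicitly batched over axis 0: one output per row.
    return jnp.sum(xb**2, axis=-1)


x = jnp.arange(18.0).reshape(6, 3)
y_full = row_norm_sq(x)
y_chunked = apply_chunked_sketch(row_norm_sq, x, chunk_size=2)
```

The trade-off is the usual one: only one chunk of intermediate values is alive at a time, which lowers peak memory at the cost of sequential execution across chunks.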
Math#
Compute the log of the sum of exponentials of input elements, always returning a complex number. |
|
Log-determinant, with automatic upconversion to a complex output dtype in order to encode the sign. |
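Both entries above encode a possibly negative sign in the imaginary part of a complex logarithm. A minimal sketch of that idea, using `jax.scipy.special.logsumexp` with `return_sign=True` and `jnp.linalg.slogdet`, is given below. The `*_sketch` names are hypothetical and the sketch assumes real inputs.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logsumexp_complex_sketch(a, b=None):
    # log|sum(b * exp(a))| plus i*pi when the sum is negative.
    log_abs, sign = logsumexp(a, b=b, return_sign=True)
    return log_abs + jnp.where(sign < 0, 1j * jnp.pi, 0j)


def logdet_complex_sketch(A):
    # log|det A| plus the complex log of the sign (i*pi for a negative sign).
    sign, log_abs = jnp.linalg.slogdet(A)
    return log_abs + jnp.log(sign + 0j)


# sum = 1*exp(0) - 3*exp(0) = -2, so the result is log(2) + i*pi.
lse = logsumexp_complex_sketch(jnp.array([0.0, 0.0]), b=jnp.array([1.0, -3.0]))

# det of the swap matrix is -1, so the result is 0 + i*pi.
ld = logdet_complex_sketch(jnp.array([[0.0, 1.0], [1.0, 0.0]]))
```

Exponentiating either result recovers the signed quantity, which is why a complex output dtype is needed even for real inputs.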