netket.errors.JaxOperatorSetupDuringTracingError


exception netket.errors.JaxOperatorSetupDuringTracingError

Illegal attempt to use a Jax operator that was constructed inside of a Jax function transformation with non-constant data.

This happens when a DiscreteJaxOperator is built inside of a function that is being transformed by jax, with transformations such as jax.jit() or jax.grad(), and the operator's setup procedure is not compatible with Jax tracing.
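The failure mode can be reproduced without NetKet. The following pure-JAX sketch is only an analogy, not NetKet code: the hypothetical function setup_nonzero stands in for an operator setup step, and jnp.nonzero is used because its output shape depends on the values of its input. Called eagerly it works, but under jax.jit the input is an abstract tracer and jax cannot determine the output shape.

```python
import jax
import jax.numpy as jnp


def setup_nonzero(weights):
    # Hypothetical setup step: the number of nonzero entries, and hence
    # the shape of the output, depends on the *values* in `weights`.
    (idx,) = jnp.nonzero(weights)
    return idx


weights = jnp.array([0.3, 0.0, 0.4])

# Eager call: `weights` is concrete, so the output length is known.
print(setup_nonzero(weights))

# Under jax.jit, `weights` is a tracer carrying only shape and dtype,
# so jax cannot decide how long the output array should be.
try:
    jax.jit(setup_nonzero)(weights)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True
print(failed)
```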

Note that a DiscreteJaxOperator can be used inside of jax function transformations; NetKet is currently limited only in that such operators cannot be built inside of Jax transformations.

To avoid this error you should build your operators outside of the jax context.
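The same advice can be sketched in pure JAX, again as an analogy rather than NetKet's implementation: the data-dependent step runs eagerly, where the data is concrete, and only its fixed-shape results are passed into the jitted function. The names nonzero_idx, nonzero_w and apply are illustrative.

```python
import jax
import jax.numpy as jnp

weights = jnp.array([0.3, 0.0, 0.4])

# Outside of any jax transformation the data is concrete, so an
# operation with a data-dependent output shape is allowed here.
(nonzero_idx,) = jnp.nonzero(weights)
nonzero_w = weights[nonzero_idx]


@jax.jit
def apply(idx, w, x):
    # Inside jit, `idx` and `w` are traced, but their shapes were
    # fixed when they were built above, so tracing succeeds.
    return jnp.sum(w * x[idx])


x = jnp.array([1.0, 2.0, 3.0])
result = apply(nonzero_idx, nonzero_w, x)
print(result)
```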

(i) Building a Jax operator outside of a jax context

Build the operator outside of any jax transformation, and pass it as an argument to the jitted function:

import netket as nk
import jax
import jax.numpy as jnp

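N = 2  # number of sites; each Pauli string below has length N

# Build the operator eagerly, outside of any jax transformation.
ham = nk.operator.PauliStringsJax(['XI', 'IX'], jnp.array([0.3, 0.4]))

samples = ham.hilbert.all_states()

# The already-built operator is passed to the jitted function as an argument.
@jax.jit
def compute_values(ham, s):
    return ham.get_conn_padded(s)

compute_values(ham, samples)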

Note

This limitation is not fundamental, and it could be lifted in the future by an interested contributor. If by then Jax supports dynamic shapes, this feature could be implemented at no additional runtime cost. If Jax does not yet support dynamic shapes, it should instead be implemented as a secondary code path (replacing this error) that is taken only when the operator is constructed inside of a jax context. That path would produce slightly less optimized operators with a higher computational cost.

If you are interested in contributing to NetKet and need to build operators inside of a jax context (for example, because you are doing optimal control), get in touch with us by opening an issue.
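The idea behind such a secondary, fixed-shape path can be sketched in pure JAX. This is a hypothetical illustration, not a proposed NetKet implementation: instead of a data-dependent length, the index array is padded to a worst-case length that is known at trace time (here weights.shape[0]), using the size and fill_value arguments of jnp.nonzero. The price is that every later step must mask out the padding and operates on more entries than strictly necessary.

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_weight_sum(weights):
    # Hypothetical fixed-shape setup: the number of nonzero entries is
    # unknown at trace time, so the index array is padded to the worst
    # case, weights.shape[0], which *is* static. Padding slots hold -1.
    (idx,) = jnp.nonzero(weights, size=weights.shape[0], fill_value=-1)
    # Later steps must mask out the padding slots, and they process
    # weights.shape[0] entries instead of only the nonzero ones.
    valid = idx >= 0
    total = jnp.sum(jnp.where(valid, weights[idx], 0.0))
    return idx, total


weights = jnp.array([0.3, 0.0, 0.4])
idx, total = padded_weight_sum(weights)
print(idx)    # indices 0 and 2, followed by one padding slot
print(total)
```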

Note

Most operators lazily initialize the fields used to compute the connected elements, doing so only when they are first needed. To check whether an operator has been initialized, you can probe the boolean flag operator._initialized. If operator._initialized is True, you can safely call get_conn_padded() and similar methods. If it is False, the setup procedure will be carried out by an internal method, usually called operator._setup(). If you see this error, it means that this method internally uses dynamically determined shapes, and it is this method that should be converted to be jax-friendly.
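The lazy-initialization pattern described above can be sketched with a toy class. ToyOperator, its apply method and the example weights are hypothetical and only mirror the _initialized / _setup() naming from the text; this is not NetKet's implementation. The point is that a setup step with data-dependent shapes is run while the data is still concrete, so that later use inside jax.jit only touches fixed-shape arrays.

```python
import jax
import jax.numpy as jnp


class ToyOperator:
    """Hypothetical stand-in for an operator with lazy setup.

    Mimics the `_initialized` flag and `_setup()` method described
    in the text; it is NOT NetKet's implementation.
    """

    def __init__(self, weights):
        self.weights = jnp.asarray(weights)
        self._initialized = False

    def _setup(self):
        # Data-dependent output shape: needs concrete data, so this
        # should run before entering any jax transformation.
        (self._idx,) = jnp.nonzero(self.weights)
        self._initialized = True

    def apply(self, x):
        if not self._initialized:
            self._setup()
        return jnp.sum(self.weights[self._idx] * x[self._idx])


op = ToyOperator([0.3, 0.0, 0.4])
print(op._initialized)  # False: nothing has been computed yet
op._setup()             # run setup eagerly, while data is concrete
print(op._initialized)  # True: only fixed-shape arrays are used from now on

x = jnp.array([1.0, 2.0, 3.0])
result = jax.jit(lambda x: op.apply(x))(x)
print(result)
```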