netket.errors.JaxOperatorSetupDuringTracingError
exception netket.errors.JaxOperatorSetupDuringTracingError
Illegal attempt to use a Jax operator constructed inside of a Jax function transformation with non-constant data.
This happens when building a DiscreteJaxOperator inside of a function that is being transformed by jax, with transformations such as jax.jit() or jax.grad(), and the operator's setup is not compatible with Jax tracing.

Note that a DiscreteJaxOperator can be used inside of jax function transformations; NetKet is currently only limited in that you cannot build one inside of a Jax transformation. To avoid this error, build your operators outside of the jax context.
Building a Jax operator outside of a jax context
Build the operator outside of a jax context.
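```python
import netket as nk
import jax
import jax.numpy as jnp

N = 2  # number of sites; each Pauli string below acts on N qubits

# Build the operator eagerly, outside of any jax transformation.
ham = nk.operator.PauliStringsJax(['XI', 'IX'], jnp.array([0.3, 0.4]))
samples = ham.hilbert.all_states()

@jax.jit
def compute_values(ham, s):
    # The already-built operator is passed in as an argument and can be
    # used inside the jitted function.
    return ham.get_conn_padded(s)

compute_values(ham, samples)
```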
Note
This limitation is not fundamental, and it could be lifted in the future by an interested contributor. If by then Jax supports dynamic shapes, this feature could be implemented at no additional runtime cost. If Jax does not yet support dynamic shapes, it could instead be implemented as a secondary code path (replacing this error) that is taken only when the operator is constructed inside of a jax context. This path would produce slightly less optimized operators with a higher computational cost.
If you are interested in contributing to NetKet and find yourself needing to build operators inside a jax context (for example, because you are doing optimal control), get in touch with us by opening an issue.
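To make the trade-off described in the first note more concrete, here is a minimal, JAX-only sketch that does not use NetKet at all. The helpers nonzero_dynamic, nonzero_padded and dynamic_fails_under_jit are hypothetical and only stand in for a setup step: the first returns an array whose length depends on the values of its input, while the second pads the result to a fixed max_size known at trace time. Padding with a fill value is only one possible way to obtain static shapes; it is an assumption made for illustration, not a description of how NetKet would implement such a secondary path.

```python
import jax
import jax.numpy as jnp


def nonzero_dynamic(weights):
    # The output length equals the number of non-zero weights, so it depends
    # on the *values* of `weights`: a dynamically determined shape.
    (idx,) = jnp.nonzero(weights)
    return idx


def nonzero_padded(weights, max_size):
    # The output length is always `max_size`, which is known at trace time.
    # Unused slots are filled with -1 and must be masked out downstream,
    # which is the extra cost of working with a static shape.
    (idx,) = jnp.nonzero(weights, size=max_size, fill_value=-1)
    return idx


# `max_size` must be a static (Python) value for jax.jit.
padded_jit = jax.jit(nonzero_padded, static_argnums=1)


def dynamic_fails_under_jit(weights):
    """Return True if the dynamic-shape variant cannot be traced by jax.jit."""
    try:
        jax.jit(nonzero_dynamic)(weights)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


weights = jnp.array([0.3, 0.0, 0.4])
eager_idx = nonzero_dynamic(weights)  # fine outside of any transformation
jitted_idx = padded_jit(weights, 3)  # fine under jit thanks to the static size
```

Called eagerly, both helpers find the non-zero indices [0, 2]; only the padded one can be traced by jax.jit, at the price of carrying and masking the -1 fill entries.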
Note
Most operators lazily initialize the fields used to compute the connected elements, only when they are first needed. To check whether an operator has been initialized, you can probe the boolean flag operator._initialized. If operator._initialized is True, you can safely call get_conn_padded() and similar methods. If it is False, the setup procedure will be run by an internal method, usually called operator._setup(). If you see this error, it means that this method internally uses dynamically determined shapes, and it is this method that should be converted to be jax-friendly.
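The lazy-initialization pattern described above can be mimicked with a small class, sketched here under loud assumptions. ToyLazyOperator, scaled_total and build_inside_jit_fails are hypothetical names, not NetKet code; the class only reproduces an _initialized flag and a _setup() method. Its setup converts the weights to NumPy and keeps a value-dependent number of them, which works on concrete data but fails when the weights are jax tracers. In this toy the failure surfaces as JAX's own TracerArrayConversionError rather than the NetKet error documented on this page. Forcing the setup eagerly, outside of jax.jit, mirrors the advice of building operators outside of the jax context.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyLazyOperator:
    """Hypothetical stand-in for a lazily set-up operator (not NetKet code).

    It only mimics the `_initialized` flag and `_setup()` method described above.
    """

    def __init__(self, weights):
        self.weights = weights
        self._initialized = False  # the setup is deferred until first use

    def _setup(self):
        if self._initialized:
            return
        # np.asarray needs concrete values, and the number of kept terms
        # depends on those values (a dynamically determined shape). Both are
        # impossible while jax is tracing `self.weights`.
        w = np.asarray(self.weights)
        self._nonzero_weights = jnp.asarray(w[w != 0])
        self._initialized = True

    def total_weight(self):
        self._setup()  # no-op if the operator is already initialized
        return jnp.sum(self._nonzero_weights)


# Build and set up the operator outside of any jax transformation...
op = ToyLazyOperator(jnp.array([0.3, 0.0, 0.4]))
op._setup()


@jax.jit
def scaled_total(x):
    # ...so that only jax-friendly code runs while tracing.
    return x * op.total_weight()


def build_inside_jit_fails(weights):
    """Return True if constructing the toy operator inside jax.jit fails."""

    @jax.jit
    def build_and_use(w):
        # Here `w` is a tracer, so `_setup()` cannot convert it to NumPy.
        return ToyLazyOperator(w).total_weight()

    try:
        build_and_use(weights)
    except jax.errors.TracerArrayConversionError:
        return True
    return False
```

Because op is initialized before jax.jit traces scaled_total, its _setup() returns immediately during tracing and the precomputed array is treated as a constant.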