netket.errors.JaxOperatorNotConvertibleToNumba
- exception netket.errors.JaxOperatorNotConvertibleToNumba
Illegal attempt to convert a Jax operator to the Numba format after it has been flattened and unflattened.
This probably happened because you passed a Jax operator to a jax function transformation or a jitted function and then tried to convert it back to the numba format, as in the example below:
    import jax
    import netket as nk

    hi = nk.hilbert.Spin(0.5, 2)
    op = nk.operator.spin.sigmax(hi, 0)
    op = op.to_jax_operator()

    @jax.jit
    def test(op):
        # op was flattened and unflattened at the jit boundary, so this raises
        op.to_numba_operator()

    test(op)
Unfortunately, once an operator is flattened with jax.tree_util.tree_flatten, which happens at every jax function transformation boundary, it usually cannot be converted back to the original numba form.
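The mechanism can be illustrated with a toy pytree. The sketch below is not NetKet's implementation: the class ToyJaxOperator, its `source` attribute, and its `to_original` method are hypothetical names chosen for illustration. It assumes only that whatever is needed to rebuild the original object is left out of the flattened representation, so the rebuilt object no longer has it.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyJaxOperator:
    """Hypothetical stand-in for a Jax operator; not NetKet's actual class."""

    def __init__(self, coeffs, source=None):
        self.coeffs = coeffs  # array data, exported as a pytree leaf
        self.source = source  # handle to the "original" object, not part of the pytree

    def tree_flatten(self):
        # Only the array is exported; `source` is deliberately left out.
        return (self.coeffs,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (coeffs,) = children
        return cls(coeffs)  # `source` cannot be recovered here

    def to_original(self):
        if self.source is None:
            raise RuntimeError("operator was flattened and unflattened")
        return self.source


op = ToyJaxOperator(jnp.ones(3), source="original-operator")

# Converting back works on the object that never crossed a boundary.
converted_before = op.to_original()

# A flatten/unflatten round trip is what happens at a jit boundary.
leaves, treedef = tree_util.tree_flatten(op)
rebuilt = tree_util.tree_unflatten(treedef, leaves)

try:
    rebuilt.to_original()
    roundtrip_failed = False
except RuntimeError:
    roundtrip_failed = True
```

In this toy model, the way to avoid the failure is to perform the conversion on the object held outside the transformed function, before it is flattened, rather than on the copy received inside it.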
This limitation exists for performance reasons, and we might reconsider it. If it is a problem for you, please open an issue.