netket.errors.JaxOperatorNotConvertibleToNumba


exception netket.errors.JaxOperatorNotConvertibleToNumba

Illegal attempt to convert a Jax operator that has been flattened and unflattened back to the Numba format.

This probably happened because you passed a Jax operator to a JAX function transformation or a jitted function and then tried to convert it back to the Numba format, as in the example below:

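```python
import jax
import netket as nk

hi = nk.hilbert.Spin(0.5, 2)

op = nk.operator.spin.sigmax(hi, 0)
op = op.to_jax_operator()

@jax.jit
def test(op):
    # Inside the jitted function `op` has been flattened and unflattened,
    # so this conversion raises JaxOperatorNotConvertibleToNumba
    op.to_numba_operator()

test(op)
```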

Unfortunately, once an operator has been flattened with `jax.tree_util.tree_flatten`, which happens at every JAX function transformation boundary, it usually cannot be converted back to its original Numba form.
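The mechanism can be illustrated with plain JAX. The sketch below uses a hypothetical `ToyJaxOperator` class with a hypothetical `to_numba_form` method; it is not NetKet's implementation, only an assumption-laden illustration of one way such a limitation can arise. The class keeps some host-side data outside its pytree leaves, so that data is dropped both by an explicit `tree_flatten`/`tree_unflatten` round trip and by the implicit round trip that `jax.jit` performs on its arguments.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyJaxOperator:
    """Hypothetical stand-in for a Jax operator (not NetKet's implementation)."""

    def __init__(self, mels, numba_data=None):
        self.mels = mels                # array data: becomes a tracer under jit
        self._numba_data = numba_data   # hypothetical host-side data needed to go back

    def tree_flatten(self):
        # Only the array is kept as a leaf; the host-side data is not stored
        # in the auxiliary data, so it does not survive flattening.
        return (self.mels,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # The operator is rebuilt from its leaves alone, without the host-side data.
        return cls(*children)

    def to_numba_form(self):
        if self._numba_data is None:
            raise RuntimeError("operator was flattened and unflattened")
        return self._numba_data


op = ToyJaxOperator(jnp.ones(3), numba_data="numba form")

# Outside any transformation the conversion still works.
print(op.to_numba_form())  # prints: numba form

# An explicit flatten/unflatten round trip drops the host-side data...
leaves, treedef = jax.tree_util.tree_flatten(op)
roundtrip = jax.tree_util.tree_unflatten(treedef, leaves)

# ...and so does the implicit round trip at a jax.jit boundary.
errors = []


@jax.jit
def total(op):
    try:
        op.to_numba_form()
    except RuntimeError as err:
        errors.append(str(err))  # Python side effect, runs once while tracing
    return op.mels.sum()


result = total(op)
```

The original `op` object outside the jitted function is untouched; only the copy rebuilt from the leaves inside the transformation has lost the data needed for the conversion.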

The limitation exists for performance reasons and may be reconsidered in the future. If it is a problem for you, please open an issue.