netket.errors.NumbaOperatorGetConnDuringTracingError#

exception netket.errors.NumbaOperatorGetConnDuringTracingError#

Illegal attempt to use Numba-operators inside of a Jax function transformation.

This happens when calling get_conn_padded() or get_conn_flattened() inside a function that is being transformed by jax (for example with jax.jit() or jax.grad()), while the operator is not compatible with Jax.

To avoid this error you can (i) convert your operator to a Jax compatible format if possible, or (ii) compute the connected elements outside of the jax function transformation and pass the results to a jax-transformed function.

(i) Converting an Operator to a Jax compatible format#

Some operators can be converted to a jax-compatible format by calling the method operator.to_jax_operator(). If this method is not available or raises an error, it means that the operator cannot be converted.

If the operator can be converted to a jax-compatible format, it will be possible to pass it as a standard argument to a jax-transformed function and it should not be declared as a static argument.

Jax-compatible operators can be used like standard operators, for example by passing them to the expect() function. However, their performance will differ from standard operators. In general, compile time will be considerably worse, while runtime might be faster or slower, depending on several factors.
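The reason Jax-compatible operators can cross the jit boundary as ordinary (non-static) arguments is that they are registered as pytrees. The sketch below illustrates this mechanism with a hypothetical toy class (`ShiftOp`, `_flatten`, `_unflatten` are invented names, not part of NetKet); it is not how NetKet operators are implemented, only a minimal demonstration of the underlying Jax feature:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical toy "operator" registered as a pytree, mimicking how a
# jax-compatible operator can be passed to a jitted function as a
# standard (non-static) argument.
class ShiftOp:
    def __init__(self, shift):
        self.shift = shift  # array leaf: traced by jax, not static

    def apply(self, x):
        return x + self.shift

def _flatten(op):
    # children (traced leaves), auxiliary static data
    return (op.shift,), None

def _unflatten(aux, children):
    return ShiftOp(*children)

tree_util.register_pytree_node(ShiftOp, _flatten, _unflatten)

@jax.jit
def f(op, x):
    # op is a pytree argument: no static_argnums needed, and no
    # retracing happens when only the value of op.shift changes.
    return op.apply(x)

result = f(ShiftOp(jnp.array(1.0)), jnp.array(2.0))
```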

The biggest advantage of Jax operators, however, is when experimenting with jax code, as you can successfully use them in your own custom functions, as in the example below:

import netket as nk
import jax
import jax.numpy as jnp

graph = nk.graph.Chain(10)
hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
ham = nk.operator.Ising(hilbert, graph, h=1.0)
ham_jax = ham.to_jax_operator()

ma = nk.models.RBM()
pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2,graph.n_nodes)))

samples = hilbert.all_states()

@jax.jit
def compute_local_energies(pars, ham_jax, s):
    # calling ham.get_conn_padded(s) here would raise the error;
    # the jax-compatible operator works inside the transformation:
    sp, mels = ham_jax.get_conn_padded(s)

    logpsi_sp = ma.apply(pars, sp)
    logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)

    return jnp.sum(mels * jnp.exp(logpsi_sp-logpsi_s), axis=-1)

elocs = compute_local_energies(pars, ham_jax, samples)
elocs_grad = jax.jacrev(compute_local_energies)(pars, ham_jax, samples)

Note

This conversion can be a relatively expensive operation, so you should avoid executing it inside a hot loop.

(ii) Precomputing connected elements outside of Jax transformations#

In most cases you won’t be able to convert the operator to a Jax-compatible format. In those cases, the usual workaround is to precompute the connected elements before entering the Jax context, splitting the function into a non-jitted wrapper and a jitted kernel.

import netket as nk
import jax
import jax.numpy as jnp

graph = nk.graph.Chain(10)
hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
ham = nk.operator.Ising(hilbert, graph, h=1.0)

ma = nk.models.RBM()
pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2,graph.n_nodes)))

samples = hilbert.all_states()

def compute_local_energies(pars, ham, s):
    sp, mels = ham.get_conn_padded(s)
    return _compute_local_energies_kernel(pars, s, sp, mels)

@jax.jit
def _compute_local_energies_kernel(pars, s, sp, mels):
    logpsi_sp = ma.apply(pars, sp)
    logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)
    return jnp.sum(mels * jnp.exp(logpsi_sp-logpsi_s), axis=-1)

elocs = compute_local_energies(pars, ham, samples)
elocs_grad = jax.jacrev(compute_local_energies)(pars, ham, samples)
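The kernel above implements the local-energy formula E_loc(s) = Σ_s' ⟨s|H|s'⟩ ψ(s')/ψ(s). As a sanity check of that formula (independent of NetKet; the dense random matrix and positive amplitudes here are illustrative assumptions), a NumPy sketch comparing it against the direct evaluation (Hψ)(s)/ψ(s):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4                             # size of a toy Hilbert space
H = rng.normal(size=(n, n))
H = H + H.T                       # symmetric toy "Hamiltonian"
psi = np.exp(rng.normal(size=n))  # positive amplitudes, so log is safe

# get_conn-style evaluation: for each state s, sum matrix elements
# times amplitude ratios of the connected states. Here H is dense,
# so every state is connected to every other state.
logpsi = np.log(psi)
eloc_conn = np.sum(H * np.exp(logpsi[None, :] - logpsi[:, None]), axis=-1)

# direct evaluation: E_loc(s) = (H @ psi)[s] / psi[s]
eloc_direct = (H @ psi) / psi
```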

You can check whether an operator is a Jax-compatible operator (a DiscreteJaxOperator) by executing

isinstance(operator, nk.operator.DiscreteJaxOperator)