netket.errors.NumbaOperatorGetConnDuringTracingError
exception netket.errors.NumbaOperatorGetConnDuringTracingError
Illegal attempt to use Numba operators inside of a Jax function transformation.
This happens when calling get_conn_padded() or get_conn_flattened() inside of a function that is being transformed by Jax, for example with jax.jit() or jax.grad(), and the operator is not compatible with Jax. Numba operators need concrete NumPy arrays, while inside a transformation Jax only provides abstract tracers.
To avoid this error you can (i) convert your operator to a Jax-compatible format, if possible, or (ii) compute the connected elements outside of the Jax function transformation and pass the results to a Jax-transformed function.
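To see why the error arises, consider a toy stand-in for a Numba routine. The sketch below does not use NetKet; `numpy_only_kernel` and `fails_under_jit` are hypothetical names, and the only assumption is that the routine converts its input with `np.asarray`, as compiled Numba code effectively requires. Called eagerly it works, but under `jax.jit` its input is an abstract tracer and JAX raises `TracerArrayConversionError`. NetKet raises the more descriptive `NumbaOperatorGetConnDuringTracingError` in the corresponding situation.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_only_kernel(x):
    # Hypothetical stand-in for a Numba-compiled routine: like Numba code,
    # it needs a concrete NumPy array, so it converts its input first.
    x = np.asarray(x)
    return x * 2


def fails_under_jit(fn, x):
    """Return True if `fn` raises TracerArrayConversionError under jax.jit."""
    try:
        jax.jit(fn)(x)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


x = jnp.arange(3.0)
eager_result = numpy_only_kernel(x)  # works: x is a concrete array
jit_fails = fails_under_jit(numpy_only_kernel, x)  # under jit, x is a tracer
```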
(i) Converting an operator to a Jax-compatible format
Some operators can be converted to a jax-compatible format by calling the method operator.to_jax_operator(). If this method is not available or raises an error, it means that the operator cannot be converted.
If the operator can be converted, the resulting Jax operator can be passed as a standard argument to a Jax-transformed function; it should not be declared as a static argument.
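To illustrate why no static argument is needed, here is a toy Jax-compatible operator registered as a JAX pytree. `ToyTransverseField` is hypothetical and represents only the term -h * sum_i sigma^x_i on +-1 spins, and `s @ w` is an assumed toy log-amplitude; we assume, without showing NetKet's internals here, that NetKet's own Jax operators follow a similar pytree pattern. Because the operator's data `h` is a pytree leaf, `jax.jit` traces it like any array, and `jax.grad` can even differentiate with respect to it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyTransverseField:
    """Hypothetical Jax operator for -h * sum_i sigma^x_i on +-1 spins."""

    def __init__(self, h):
        self.h = h

    def tree_flatten(self):
        # h is a leaf: jit traces it instead of treating it as static.
        return (self.h,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def get_conn_padded(self, s):
        n = s.shape[-1]
        # flip[i, j] is -1 on the diagonal and +1 elsewhere,
        # so sp[..., i, :] equals s with spin i flipped.
        flip = 1 - 2 * jnp.eye(n, dtype=s.dtype)
        sp = s[..., None, :] * flip
        mels = -self.h * jnp.ones(s.shape[:-1] + (n,), dtype=s.dtype)
        return sp, mels


@jax.jit
def local_energies(w, op, s):
    # Toy log-amplitude log psi(s) = s @ w (an assumption for illustration).
    sp, mels = op.get_conn_padded(s)
    logpsi_sp = sp @ w
    logpsi_s = (s @ w)[..., None]
    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)


s = jnp.ones((1, 4))
w = jnp.zeros(4)
op = ToyTransverseField(0.5)

elocs = local_energies(w, op, s)
# Gradient with respect to the operator: a ToyTransverseField holding dE/dh.
dop = jax.grad(lambda op: local_energies(w, op, s).sum())(op)
```

Calling `local_energies` again with, say, `ToyTransverseField(1.0)` reuses the compiled function, because only the value of the leaf changes, not the structure of the pytree.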
Jax-compatible operators can be used like standard operators, for example by passing them to the expect() function. However, their performance differs from that of standard operators: in general, compile time can be much worse, while runtime might be faster or slower depending on several factors. The biggest advantage of Jax operators is when experimenting with Jax code, as you can successfully use them in your own custom functions, as in the example below:
```python
import netket as nk
import jax
import jax.numpy as jnp

graph = nk.graph.Chain(10)
hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
ham = nk.operator.Ising(hilbert, graph, h=1.0)
ham_jax = ham.to_jax_operator()

ma = nk.models.RBM()
pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2, graph.n_nodes)))

samples = hilbert.all_states()

@jax.jit
def compute_local_energies(pars, ham_jax, s):
    # this would raise the error, because `ham` is a Numba operator:
    # sp, mels = ham.get_conn_padded(s)

    # this works, because `ham_jax` is a Jax operator
    sp, mels = ham_jax.get_conn_padded(s)

    logpsi_sp = ma.apply(pars, sp)
    logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)

    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)

elocs = compute_local_energies(pars, ham_jax, samples)
elocs_grad = jax.jacrev(compute_local_energies)(pars, ham_jax, samples)
```
Note
Converting an operator with to_jax_operator() might be a relatively expensive operation, so you should avoid executing it inside a hot loop.
(ii) Precomputing connected elements outside of Jax transformations
In most cases you won’t be able to convert the operator to a Jax-compatible format. In those cases, the workaround we usually employ is to precompute the connected elements before entering the Jax context, splitting the function into a non-jitted wrapper, which calls the Numba operator, and a jitted kernel, which receives the precomputed arrays.
```python
import netket as nk
import jax
import jax.numpy as jnp

graph = nk.graph.Chain(10)
hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
ham = nk.operator.Ising(hilbert, graph, h=1.0)

ma = nk.models.RBM()
pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2, graph.n_nodes)))

samples = hilbert.all_states()

def compute_local_energies(pars, ham, s):
    # Not jitted: `s` is never traced, so the Numba operator sees concrete arrays
    sp, mels = ham.get_conn_padded(s)
    return _compute_local_energies_kernel(pars, s, sp, mels)

@jax.jit
def _compute_local_energies_kernel(pars, s, sp, mels):
    logpsi_sp = ma.apply(pars, sp)
    logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)
    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)

elocs = compute_local_energies(pars, ham, samples)
elocs_grad = jax.jacrev(compute_local_energies)(pars, ham, samples)
```
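The same split can be reproduced without NetKet, which makes the data flow explicit. In this sketch, `toy_get_conn_padded` is a hypothetical NumPy stand-in for a Numba operator (the term -h * sum_i sigma^x_i on +-1 spins), and `s @ w` is an assumed toy log-amplitude. The non-jitted wrapper calls the NumPy routine on concrete arrays and hands only plain arrays to the jitted kernel, so `jax.grad` with respect to `w` works: the operator never sees a tracer because `s` is not differentiated.

```python
import numpy as np
import jax
import jax.numpy as jnp


def toy_get_conn_padded(s, h):
    """Hypothetical NumPy stand-in for a Numba operator: -h * sum_i sigma^x_i."""
    s = np.asarray(s)  # needs concrete data, like a Numba operator
    n = s.shape[-1]
    sp = np.repeat(s[:, None, :], n, axis=1)  # (batch, n, n) copies of each state
    idx = np.arange(n)
    sp[:, idx, idx] *= -1  # copy i of each state has spin i flipped
    mels = np.full(s.shape[:-1] + (n,), -h)
    return sp, mels


@jax.jit
def _elocs_kernel(w, s, sp, mels):
    # Pure Jax code: only receives precomputed arrays.
    logpsi_sp = sp @ w
    logpsi_s = (s @ w)[:, None]
    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)


def compute_elocs(w, s, h):
    # Non-jitted wrapper: the NumPy "operator" runs outside any transformation.
    sp, mels = toy_get_conn_padded(s, h)
    return _elocs_kernel(w, s, sp, mels)


s = np.ones((1, 4))
w = jnp.zeros(4)
elocs = compute_elocs(w, s, 0.5)
grad_w = jax.grad(lambda w: compute_elocs(w, s, 0.5).sum())(w)
```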
Most
DiscreteJaxOperator
by executing isinstance(operator, nk.operator.DiscreteJaxOperator)