netket.errors.JaxOperatorGetConnInJitError
exception netket.errors.JaxOperatorGetConnInJitError
Illegal attempt to use get_conn_flattened() or similar methods of a Jax operator inside a Jax function transformation.
This happens when calling get_conn_flattened() or get_conn() inside a function that is being transformed by jax, for example with jax.jit(), jax.grad(), jax.vmap(), or other jax function transformations. These methods need to materialize their output as a concrete numpy array, which is not compatible with Jax’s abstract tracing mechanism: while a function is being transformed, its inputs are abstract tracers that carry a shape and dtype but no concrete values, so they cannot be converted to numpy arrays.
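The underlying failure can be reproduced in plain JAX, without any NetKet operator. The sketch below is illustrative only: to_numpy_sum and jnp_sum are hypothetical functions, not NetKet internals. Converting a traced value with np.asarray works eagerly but fails under jax.jit, while the equivalent jax.numpy computation traces fine.

```python
import numpy as np
import jax
import jax.numpy as jnp


def to_numpy_sum(x):
    # Hypothetical function: np.asarray needs concrete values, which a tracer does not have
    return np.asarray(x).sum()


def jnp_sum(x):
    # Hypothetical function: jnp operations only record the computation, so tracers are fine
    return jnp.sum(x)


x = jnp.arange(4.0)

# Outside of any transformation, materializing to numpy works.
eager_result = to_numpy_sum(x)

# Inside jax.jit, the same conversion is rejected while tracing.
try:
    jax.jit(to_numpy_sum)(x)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True

# The pure-jnp version is traceable.
jit_result = jax.jit(jnp_sum)(x)
```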
To avoid this error, use get_conn_padded() instead, which returns Jax arrays that are fully compatible with jax function transformations.
Example
Instead of:
    import netket as nk
    import jax
    import jax.numpy as jnp

    graph = nk.graph.Chain(10)
    hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
    ham = nk.operator.Ising(hilbert, graph, h=1.0)
    ham_jax = ham.to_jax_operator()  # Jax version of the operator
    ma = nk.models.RBM()
    pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2, graph.n_nodes)))
    samples = hilbert.all_states()  # all 2^10 basis states

    @jax.jit
    def compute_local_energies(pars, ham_jax, s):
        # This will raise JaxOperatorGetConnInJitError
        sp, mels = ham_jax.get_conn_flattened(s)
        logpsi_sp = ma.apply(pars, sp)
        logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)
        return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)

    elocs = compute_local_energies(pars, ham_jax, samples)
Use:
    import netket as nk
    import jax
    import jax.numpy as jnp

    graph = nk.graph.Chain(10)
    hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
    ham = nk.operator.Ising(hilbert, graph, h=1.0)
    ham_jax = ham.to_jax_operator()  # Jax version of the operator
    ma = nk.models.RBM()
    pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2, graph.n_nodes)))
    samples = hilbert.all_states()  # all 2^10 basis states

    @jax.jit
    def compute_local_energies(pars, ham_jax, s):
        # Use get_conn_padded instead: this works in jax.jit
        sp, mels = ham_jax.get_conn_padded(s)
        logpsi_sp = ma.apply(pars, sp)
        logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)
        return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)

    elocs = compute_local_energies(pars, ham_jax, samples)
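Because the padded arrays have shapes fixed before tracing, the same local-energy formula also works under jax.vmap and jax.grad. The sketch below uses pure JAX with hand-written arrays standing in for what get_conn_padded() would return; log_psi, the data, and the (n_samples, n_conn, n_sites) / (n_samples, n_conn) layout are hypothetical and chosen only so the results are easy to check.

```python
import jax
import jax.numpy as jnp


def log_psi(theta, s):
    # Hypothetical toy log-amplitude, linear in the configuration s
    return jnp.dot(s, theta)


def local_energies(theta, s, sp, mels):
    # Shapes: s (..., n_sites), sp (..., n_conn, n_sites), mels (..., n_conn)
    logpsi_sp = log_psi(theta, sp)
    logpsi_s = jnp.expand_dims(log_psi(theta, s), -1)
    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)


def mean_energy(theta, s, sp, mels):
    return jnp.mean(local_energies(theta, s, sp, mels))


# Hypothetical padded data: 2 samples, 3 sites, n_conn = 2
s = jnp.array([[1.0, -1.0, 1.0],
               [1.0, 1.0, 1.0]])
sp = jnp.array([[[1.0, -1.0, 1.0], [-1.0, -1.0, 1.0]],   # sample 0: itself, site 0 flipped
                [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]])     # sample 1: itself, one padding row
mels = jnp.array([[0.5, -1.0],
                  [1.5, 0.0]])                            # padding row has a zero matrix element
theta = jnp.zeros(3)

# Batched call under jit
eloc_jit = jax.jit(local_energies)(theta, s, sp, mels)
# Per-sample call mapped over the batch axis
eloc_vmap = jax.vmap(local_energies, in_axes=(None, 0, 0, 0))(theta, s, sp, mels)
# Gradient of the mean energy with respect to the toy parameters
grad_theta = jax.jit(jax.grad(mean_energy))(theta, s, sp, mels)
```

With theta set to zero every amplitude ratio is 1, so each local energy reduces to the sum of that sample's matrix elements, and the padding row contributes nothing to either the energy or its gradient.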
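Note
get_conn_padded() returns arrays whose connected-element dimension has a fixed, padded size, so their shapes do not depend on the values of the input; this is what makes them compatible with jax transformations. The padding entries have zero matrix elements, so they do not contribute to sums over connected elements such as the local energy computed above.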
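To make the padded layout concrete, here is a NumPy sketch that packs ragged, per-sample connected elements into fixed-shape arrays and checks that the zero-padded matrix elements leave the local energy unchanged. The helper pad_connected, the toy log_psi, and the choice to fill padding configuration rows with copies of the input sample are assumptions of this sketch; NetKet's own padding of the configuration array may differ.

```python
import numpy as np


def pad_connected(x, conn_list):
    """Hypothetical helper: pack ragged connected elements into fixed-shape arrays.

    x:         (n_samples, n_sites) configurations
    conn_list: one (xp_i, mels_i) pair per sample, with shapes (k_i, n_sites) and (k_i,)
    """
    n_samples = x.shape[0]
    max_conn = max(len(mels_i) for _, mels_i in conn_list)
    # Assumption of this sketch: padding rows repeat the input configuration so that
    # log psi stays finite there; only the zero matrix elements matter for sums.
    xp = np.repeat(x[:, None, :], max_conn, axis=1)
    mels = np.zeros((n_samples, max_conn))
    for i, (xp_i, mels_i) in enumerate(conn_list):
        k = len(mels_i)
        xp[i, :k] = xp_i
        mels[i, :k] = mels_i
    return xp, mels


def log_psi(s):
    # Hypothetical toy log-amplitude
    return 0.3 * s.sum(axis=-1)


x = np.array([[1.0, -1.0],
              [1.0, 1.0]])
conn_list = [
    (np.array([[1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]), np.array([0.5, -1.0, -1.0])),
    (np.array([[1.0, 1.0]]), np.array([2.0])),
]

xp, mels = pad_connected(x, conn_list)  # shapes (2, 3, 2) and (2, 3)

# Local energies from the ragged lists, one sample at a time
ragged = np.array([
    np.sum(m * np.exp(log_psi(xp_i) - log_psi(x[i])))
    for i, (xp_i, m) in enumerate(conn_list)
])
# Same quantity from the padded arrays in one vectorized expression
padded = np.sum(mels * np.exp(log_psi(xp) - log_psi(x)[:, None]), axis=-1)
```

The vectorized expression only works because every sample has the same number of connected entries after padding, which is the same property that lets the padded arrays pass through jax transformations.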