netket.errors.JaxOperatorGetConnInJitError


exception netket.errors.JaxOperatorGetConnInJitError

Illegal attempt to call get_conn_flattened() or similar methods on a Jax operator inside a Jax function transformation.

This error is raised when get_conn_flattened() or get_conn() is called inside a function that Jax is transforming, for example with jax.jit(), jax.grad(), jax.vmap() or any other Jax function transformation.

These methods must materialize their output as concrete NumPy arrays, which is incompatible with the abstract tracers that Jax passes through a function while transforming it.
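The underlying mechanism can be reproduced in plain JAX, without NetKet. In the sketch below, the hypothetical helper `materialize_sum` calls `np.asarray`, standing in for the conversion to a concrete NumPy array that get_conn_flattened() needs. Under jax.jit, jax.vmap and jax.grad the argument is a tracer, and JAX itself raises jax.errors.TracerArrayConversionError. This is only an analogy: for a Jax operator, NetKet raises its own JaxOperatorGetConnInJitError instead.

```python
import jax
import jax.numpy as jnp
import numpy as np


def materialize_sum(x):
    # Hypothetical helper: np.asarray forces a concrete NumPy array,
    # which is impossible when x is an abstract tracer.
    return jnp.asarray(np.asarray(x).sum())


x = jnp.arange(3.0)

# Outside any transformation x is concrete, so this succeeds.
eager_result = materialize_sum(x)

# Under each transformation the function receives a tracer instead.
transforms = {
    "jit": jax.jit(materialize_sum),
    "vmap": jax.vmap(materialize_sum),
    "grad": jax.grad(materialize_sum),
}

failures = {}
for name, fn in transforms.items():
    try:
        fn(x)
        failures[name] = False
    except jax.errors.TracerArrayConversionError:
        failures[name] = True
    print(name, "raised TracerArrayConversionError:", failures[name])
```

The same restriction applies to every transformation, not only jax.jit, which is why the error also appears under jax.grad() and jax.vmap().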

To avoid this error, use get_conn_padded() instead: it returns Jax arrays with a fixed shape, which are fully compatible with Jax function transformations.
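What makes a padded method traceable is that the number of connected elements is fixed by the input shape alone, never by the input values. The sketch below illustrates this with a hypothetical `toy_conn_padded` written in plain JAX for a transverse-field Ising chain. Its convention is an assumption chosen for illustration (spins in {-1, +1}, periodic boundaries, diagonal element J Σ s_i s_{i+1}, off-diagonal element -h per single-spin flip); it is not NetKet's implementation, and its signs may differ from nk.operator.Ising.

```python
import jax
import jax.numpy as jnp


def toy_conn_padded(s, h=1.0, J=1.0):
    """Hypothetical padded get_conn for a transverse-field Ising chain.

    s has shape (..., N). Returns sp with shape (..., N + 1, N) and
    mels with shape (..., N + 1): always N + 1 connected elements,
    so the output shape depends only on the input shape.
    """
    n = s.shape[-1]
    # Diagonal element: J * sum_i s_i s_{i+1} on a periodic chain.
    diag = J * jnp.sum(s * jnp.roll(s, -1, axis=-1), axis=-1)
    # Off-diagonal elements: row i of `flips` flips spin i only.
    flips = 1.0 - 2.0 * jnp.eye(n, dtype=s.dtype)
    flipped = s[..., None, :] * flips  # shape (..., N, N)
    sp = jnp.concatenate([s[..., None, :], flipped], axis=-2)
    mels = jnp.concatenate(
        [diag[..., None], jnp.full(s.shape[:-1] + (n,), -h, dtype=s.dtype)],
        axis=-1,
    )
    return sp, mels


s = jnp.array([[1.0, -1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]])
sp, mels = jax.jit(toy_conn_padded)(s)
print(sp.shape, mels.shape)  # (2, 5, 4) (2, 5)
```

Because the shapes are static, jax.jit can trace the function once per input shape and never needs a concrete array.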

Example

Instead of:

import netket as nk
import jax
import jax.numpy as jnp

graph = nk.graph.Chain(10)
hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
ham = nk.operator.Ising(hilbert, graph, h=1.0)
ham_jax = ham.to_jax_operator()

ma = nk.models.RBM()
pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2, graph.n_nodes)))

samples = hilbert.all_states()

@jax.jit
def compute_local_energies(pars, ham_jax, s):
    # This will raise JaxOperatorGetConnInJitError
    sp, mels = ham_jax.get_conn_flattened(s)

    logpsi_sp = ma.apply(pars, sp)
    logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)

    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)

elocs = compute_local_energies(pars, ham_jax, samples)

Use:

import netket as nk
import jax
import jax.numpy as jnp

graph = nk.graph.Chain(10)
hilbert = nk.hilbert.Spin(1/2, graph.n_nodes)
ham = nk.operator.Ising(hilbert, graph, h=1.0)
ham_jax = ham.to_jax_operator()

ma = nk.models.RBM()
pars = ma.init(jax.random.PRNGKey(1), jnp.ones((2, graph.n_nodes)))

samples = hilbert.all_states()

@jax.jit
def compute_local_energies(pars, ham_jax, s):
    # get_conn_padded works inside jax.jit: sp has shape (..., n_conn, N)
    # and mels has shape (..., n_conn)
    sp, mels = ham_jax.get_conn_padded(s)

    logpsi_sp = ma.apply(pars, sp)
    logpsi_s = jnp.expand_dims(ma.apply(pars, s), -1)

    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)

elocs = compute_local_energies(pars, ham_jax, samples)

Note

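get_conn_padded() returns arrays whose connected-elements dimension is padded to a fixed size, which keeps their shapes static and therefore compatible with Jax transformations. Padded entries have a matrix element of zero, so they contribute nothing to sums over connected elements such as the local energy computed above.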
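The effect of zero padding can be checked in plain JAX. The sketch below reuses the local-energy formula from compute_local_energies, with a hypothetical `toy_log_psi` in place of the RBM and hand-written connected elements. Filling the padded configurations with copies of s is an assumption made for illustration; the Note only states that the padded matrix elements are zero.

```python
import jax
import jax.numpy as jnp


def toy_log_psi(x):
    # Hypothetical log-amplitude standing in for ma.apply(pars, x).
    return 0.3 * jnp.sum(x, axis=-1) + 0.1 * jnp.sum(x * jnp.roll(x, 1, axis=-1), axis=-1)


@jax.jit
def local_energies(s, sp, mels):
    # Same formula as compute_local_energies in the example above.
    logpsi_sp = toy_log_psi(sp)                     # (..., n_conn)
    logpsi_s = jnp.expand_dims(toy_log_psi(s), -1)  # (..., 1)
    return jnp.sum(mels * jnp.exp(logpsi_sp - logpsi_s), axis=-1)


s = jnp.array([[1.0, -1.0, 1.0], [1.0, 1.0, -1.0]])
# Two made-up connected elements per sample: s itself and s with site 0 flipped.
sp = jnp.stack([s, s * jnp.array([-1.0, 1.0, 1.0])], axis=-2)  # (2, 2, 3)
mels = jnp.array([[0.5, -1.0], [0.25, -1.0]])                  # (2, 2)

# Pad to 4 connections: padded configurations repeat s, padded mels are 0.
pad = 2
sp_padded = jnp.concatenate([sp, jnp.repeat(s[:, None, :], pad, axis=1)], axis=1)
mels_padded = jnp.concatenate([mels, jnp.zeros((2, pad))], axis=1)

e_unpadded = local_energies(s, sp, mels)
e_padded = local_energies(s, sp_padded, mels_padded)
print(jnp.allclose(e_unpadded, e_padded))  # True
```

Each padded term is 0 times a finite amplitude ratio, so the sum over the connected-elements axis is unchanged.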