netket.jax.apply_chunked
netket.jax.apply_chunked(f, in_axes=0, *, chunk_size)
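Takes a function f that is implicitly vmapped over axis 0 and uses scan to perform the computation in smaller chunks of at most chunk_size entries along the 0-th axis of the input arguments. Processing one chunk at a time bounds the memory taken by intermediate results.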
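The sketch below shows the chunk-and-scan idea in plain JAX. It is not NetKet's implementation. It assumes a single array argument x that is chunked along axis 0, and the helper name chunked_apply_sketch is hypothetical. The part of the batch that divides evenly is reshaped into (n_chunks, chunk_size, ...) and passed through jax.lax.scan one chunk at a time. Any leftover rows are evaluated in one extra call.

```python
import jax
import jax.numpy as jnp


def chunked_apply_sketch(f, x, chunk_size):
    """Hypothetical helper: evaluate f on x in chunks of chunk_size rows.

    Assumes f is vectorized along axis 0, e.g. f = jax.vmap(f_orig).
    """
    n = x.shape[0]
    n_chunks = n // chunk_size
    if n_chunks == 0:
        # Batch smaller than one chunk: nothing to scan over.
        return f(x)
    n_full = n_chunks * chunk_size

    # Reshape the evenly divisible part into (n_chunks, chunk_size, ...).
    x_full = x[:n_full].reshape((n_chunks, chunk_size) + x.shape[1:])

    # lax.scan calls the body once per chunk, so f only sees
    # chunk_size rows at a time. The carry is unused (None).
    def body(carry, x_chunk):
        return carry, f(x_chunk)

    _, y_full = jax.lax.scan(body, None, x_full)
    # Merge the (n_chunks, chunk_size) leading axes back into one axis.
    y_full = y_full.reshape((n_full,) + y_full.shape[2:])

    if n_full == n:
        return y_full
    # Rows left over after the last full chunk are evaluated in one extra call.
    y_rest = f(x[n_full:])
    return jnp.concatenate([y_full, y_rest], axis=0)


# Usage: a per-sample function, vmapped over axis 0.
f = jax.vmap(lambda v: jnp.sum(v ** 2))
x = jnp.arange(30.0).reshape(10, 3)
y = chunked_apply_sketch(f, x, chunk_size=4)  # 2 scanned chunks + 2 leftover rows
print(jnp.allclose(y, f(x)))  # True
```

The chunked result matches the unchunked f(x) only because f treats each row independently. That is the condition the next paragraph states. NetKet's own handling of leftover rows and of in_axes may differ from this sketch.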
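For this to work, f must be vectorized along the in_axes of its arguments. Evaluating f on the whole batch x must give the same result as splitting x along axis 0 into chunks x_i, evaluating f on each chunk, and concatenating the outputs along axis 0:
assert jnp.allclose(f(x), jnp.concatenate([f(x_i) for x_i in jnp.split(x, n_chunks, axis=0)], axis=0))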
which is automatically satisfied if f is obtained by vmapping a function, such as:
f = jax.vmap(f_orig)
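The condition can be checked numerically. In the sketch below, the helper satisfies_chunking_condition and the functions f_orig and g are hypothetical. n_chunks is assumed to divide the batch size, because jnp.split requires an even split. The sketch compares a vmapped function, which satisfies the condition, with a function that subtracts the batch mean and therefore mixes rows.

```python
import jax
import jax.numpy as jnp


def satisfies_chunking_condition(f, x, n_chunks):
    """Check f(x) == concat(f(x_i)) for x split into n_chunks along axis 0."""
    chunks = jnp.split(x, n_chunks, axis=0)
    y_chunked = jnp.concatenate([f(x_i) for x_i in chunks], axis=0)
    return bool(jnp.allclose(f(x), y_chunked))


def f_orig(v):
    # Acts on a single sample v of shape (3,).
    return jnp.sin(v).sum()


# Vectorized along axis 0, so it satisfies the condition.
f = jax.vmap(f_orig)


def g(x):
    # Subtracts the batch mean: every output row depends on the whole batch,
    # so evaluating g chunk by chunk changes the result.
    return x - x.mean(axis=0)


x = jnp.arange(12.0).reshape(4, 3)
print(satisfies_chunking_condition(f, x, 2))  # True
print(satisfies_chunking_condition(g, x, 2))  # False
```

A function like g would give different results when evaluated in chunks. Only functions that act on each entry of the chunked axis independently are safe to chunk.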