netket.jax.vmap_chunked
- netket.jax.vmap_chunked(f, in_axes=0, *, chunk_size)
Behaves like jax.vmap, but uses a scan to split the computation into smaller chunks of at most chunk_size elements, which reduces peak memory usage.
This function is essentially equivalent to:
nk.jax.apply_chunked(jax.vmap(f, in_axes), in_axes, chunk_size)
Some limitations apply to in_axes: every entry must be either 0 or None.
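To illustrate the idea, here is a minimal pure-JAX sketch of a chunked vmap. It is not NetKet's implementation: the name chunked_vmap_sketch is hypothetical, it only accepts in_axes entries of 0 or None, and, for simplicity, it assumes the mapped axis length is divisible by chunk_size (no remainder handling).

```python
import jax
import jax.numpy as jnp


def chunked_vmap_sketch(f, in_axes=0, *, chunk_size):
    """Hypothetical sketch: vmap `f`, evaluating `chunk_size` elements at a time.

    Assumes every in_axes entry is 0 or None and that the mapped axis length
    is divisible by chunk_size.
    """
    vf = jax.vmap(f, in_axes=in_axes)
    if chunk_size is None:
        return vf  # no chunking requested: plain vmap

    def wrapped(*args):
        # Normalise in_axes to one entry per positional argument.
        axes = in_axes if isinstance(in_axes, tuple) else (in_axes,) * len(args)
        if not set(axes) <= {0, None}:
            raise NotImplementedError("only in_axes of 0 or None are supported")

        # Length of the mapped axis, taken from the first mapped argument.
        n = next(a.shape[0] for a, ax in zip(args, axes) if ax == 0)
        if n % chunk_size != 0:
            raise ValueError("mapped axis length must be divisible by chunk_size")
        n_chunks = n // chunk_size

        # Reshape each mapped argument from (n, ...) to (n_chunks, chunk_size, ...).
        mapped = tuple(
            a.reshape((n_chunks, chunk_size) + a.shape[1:])
            for a, ax in zip(args, axes)
            if ax == 0
        )

        def body(carry, chunk):
            # Rebuild the argument list: one chunk for mapped args,
            # the untouched value for unmapped (in_axes=None) args.
            it = iter(chunk)
            call_args = [next(it) if ax == 0 else a for a, ax in zip(args, axes)]
            return carry, vf(*call_args)

        # lax.scan calls `body` once per chunk and stacks the outputs.
        _, out = jax.lax.scan(body, None, mapped)

        # Merge the (n_chunks, chunk_size, ...) leading axes back into (n, ...).
        return jax.tree_util.tree_map(lambda y: y.reshape((n,) + y.shape[2:]), out)

    return wrapped
```

For example, chunked_vmap_sketch(f, in_axes=(0, None), chunk_size=2)(x, w) should agree with jax.vmap(f, in_axes=(0, None))(x, w), while each call to the vmapped function inside the scan only sees a batch of 2 rows.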
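- Parameters
  - f – The function to be vectorised.
  - in_axes – The axes of the inputs to map over, as in jax.vmap (default 0).
  - chunk_size – The maximum number of elements processed per chunk. If None, chunking is disabled.
- Returns
  A vectorised and chunked function.
- Return type
  Callable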