netket.jax.vmap_chunked

netket.jax.vmap_chunked(f, in_axes=0, *, chunk_size)

Behaves like jax.vmap, but uses scan to split the computation into smaller chunks.

This function is essentially equivalent to:

nk.jax.apply_chunked(jax.vmap(f, in_axes), in_axes, chunk_size=chunk_size)

in_axes is more restricted than in jax.vmap: only 0 and None are supported.
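To make the idea of chunking concrete, the following is a minimal sketch in plain JAX. It is not NetKet's implementation: the helper name vmap_chunked_sketch is hypothetical, it handles a single array argument batched along axis 0, it assumes f returns a single array, and it assumes the batch size is a multiple of chunk_size.

```python
import jax
import jax.numpy as jnp


def vmap_chunked_sketch(f, chunk_size):
    """Hypothetical helper: vmap f over axis 0 of one array, chunk by chunk."""
    vmapped = jax.vmap(f)

    def wrapped(xs):
        n = xs.shape[0]
        if chunk_size is None or chunk_size >= n:
            # Chunking disabled or unnecessary: fall back to a plain vmap.
            return vmapped(xs)
        # Simplifying assumption of this sketch: the batch divides evenly.
        assert n % chunk_size == 0, "sketch only handles evenly divisible batches"
        # Split the batch axis into (n_chunks, chunk_size).
        chunks = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])
        # lax.map evaluates one chunk at a time (it is built on lax.scan).
        out = jax.lax.map(vmapped, chunks)
        # Merge the two leading axes back into a single batch axis.
        return out.reshape((n,) + out.shape[2:])

    return wrapped


def f(x):
    # Per-sample function: sum of squares of a 1D vector.
    return jnp.sum(x ** 2)


xs = jnp.arange(24.0).reshape(8, 3)
chunked = vmap_chunked_sketch(f, chunk_size=4)(xs)
reference = jax.vmap(f)(xs)
```

Here chunked and reference should agree. The difference is that the chunked version only evaluates chunk_size samples at once, which can lower peak memory at the cost of running the chunks sequentially.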

Parameters
  • f (Callable) – The function to be vectorised.

  • in_axes – The axes to scan along. Only 0 and None are supported.

  • chunk_size (Optional[int]) – The maximum size of the chunks to be used. If it is None, chunking is disabled.
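The meaning of 0 and None in in_axes can be illustrated with jax.vmap itself, whose behaviour this function mirrors. The function f and the arrays below are made up for the example.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # x: one sample of shape (3,); w: shared weights of shape (3,).
    return jnp.dot(x, w)


xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 samples
w = jnp.array([1.0, 0.0, 2.0])       # shared by every sample

# 0 maps over axis 0 of xs; None passes w unchanged to every call.
out = jax.vmap(f, in_axes=(0, None))(xs, w)
```

Assuming NetKet is imported as nk and that in_axes accepts a per-argument tuple as it does in jax.vmap, the chunked counterpart would be nk.jax.vmap_chunked(f, in_axes=(0, None), chunk_size=2)(xs, w), which should return the same values while processing two samples at a time.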

Return type

Callable

Returns

A vectorised and chunked version of f.