netket.jax.vmap_chunked
- netket.jax.vmap_chunked(f, in_axes=0, *, chunk_size, axis_0_is_sharded=None, pvary_argnums=None)
Behaves like jax.vmap, but uses jax.lax.scan to split the computation into smaller chunks.
This function is essentially equivalent to:
nk.jax.apply_chunked(jax.vmap(f, in_axes), in_axes, chunk_size=chunk_size)
Some limitations apply to in_axes: only 0 (chunked along axis 0) or None (not chunked) are supported.
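The idea of "vmap, but evaluated chunk by chunk with scan" can be illustrated with a plain-JAX sketch. This is not NetKet's implementation: vmap_chunked_sketch is a hypothetical helper, in_axes must be a tuple with one entry (0 or None) per positional argument, and the batch size is assumed to be a multiple of chunk_size. Chunked arguments are reshaped to (n_chunks, chunk_size, ...), scan feeds one chunk per step to the vmapped function, and arguments with in_axes=None are passed unchanged to every chunk.

```python
import jax
import jax.numpy as jnp


def vmap_chunked_sketch(f, in_axes, chunk_size):
    """Hypothetical helper (not NetKet's implementation).

    Vectorises `f` with jax.vmap and evaluates it over axis 0 of the chunked
    arguments, at most `chunk_size` elements at a time, using jax.lax.scan.
    `in_axes` is a tuple with one entry (0 or None) per positional argument.
    Assumes the batch size is a multiple of `chunk_size`.
    """
    vf = jax.vmap(f, in_axes=in_axes)
    chunked_idx = [i for i, ax in enumerate(in_axes) if ax == 0]

    def chunked(*args):
        n = args[chunked_idx[0]].shape[0]
        if chunk_size is None or n <= chunk_size:
            # Chunking disabled or unnecessary: plain vmap
            return vf(*args)
        if n % chunk_size != 0:
            raise ValueError("this sketch requires n to be a multiple of chunk_size")
        n_chunks = n // chunk_size

        # Reshape every chunked argument from (n, ...) to (n_chunks, chunk_size, ...)
        xs = tuple(
            args[i].reshape((n_chunks, chunk_size) + args[i].shape[1:])
            for i in chunked_idx
        )

        def body(carry, x_chunks):
            # Substitute the current chunk for each chunked argument;
            # arguments with in_axes=None are reused unchanged for every chunk.
            full = list(args)
            for i, x in zip(chunked_idx, x_chunks):
                full[i] = x
            return carry, vf(*full)

        _, ys = jax.lax.scan(body, None, xs)
        # Merge (n_chunks, chunk_size, ...) back into (n, ...) for every output leaf
        return jax.tree_util.tree_map(lambda y: y.reshape((n,) + y.shape[2:]), ys)

    return chunked


def f(w, x):
    # w: (3,), x: (3,) -> scalar
    return jnp.sin(x @ w)


w = jnp.arange(3.0)
x = jnp.linspace(0.0, 1.0, 24).reshape(8, 3)
chunked_f = vmap_chunked_sketch(f, in_axes=(None, 0), chunk_size=2)
print(jnp.allclose(chunked_f(w, x), jax.vmap(f, in_axes=(None, 0))(w, x)))  # True
```

With chunk_size=2, the 8 samples are processed in 4 scan steps, so only 2 samples are held by the vmapped function at any time, while the result matches a single jax.vmap call.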
Note
If netket_sharding is enabled, this function assumes by default that the arguments chunked along axis 0 are sharded across devices. This can be overridden by passing axis_0_is_sharded=False.
- Parameters:
f (Callable) – The function to be vectorised.
in_axes – The axes along which to scan. Only 0 or None are supported.
chunk_size (int | None) – The maximum size of the chunks to be used. If it is None, chunking is disabled.
axis_0_is_sharded (bool) – Specifies whether axis 0 of the scanned arrays is sharded among multiple devices. If it is, the function is computed in chunks of size chunk_size on every device. Defaults to True if config.netket_sharding is enabled, otherwise to False.
pvary_argnums (tuple[int, ...] | None) – Explicit tuple of argument indices that should receive a pvary annotation inside the sharded chunking path. If None, all non-chunked arguments are marked as pvary, matching the historical behaviour. This does not change how those arguments are physically sharded; it only tells shard_map to treat them as varying rather than invariant. The annotation mainly matters for differentiated inputs such as parameters, where treating them as invariant can trigger unwanted implicit psums.
- Returns:
A vectorised and chunked function.
- Return type:
Callable
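A common pattern is a model f(params, x) where the parameters are not chunked (in_axes=None) and the samples are (in_axes=0). The plain-JAX sketch below, assuming a toy model and a hypothetical helper chunked_log_amplitude, shows that the parameters are shared by every chunk and that differentiating through the scan accumulates their gradient over all chunks, so the chunked and unchunked gradients agree. It runs on a single device without shard_map, so it does not exercise the sharded path or the pvary annotation controlled by pvary_argnums.

```python
import jax
import jax.numpy as jnp


def log_amplitude(params, x):
    # Toy model (assumption): one tanh layer summed to a scalar per sample x of shape (3,)
    return jnp.sum(jnp.tanh(x @ params["w"]))


def chunked_log_amplitude(params, xs, chunk_size):
    # Hypothetical helper: samples are chunked (in_axes=0), params are not (in_axes=None).
    # Assumes xs.shape[0] is a multiple of chunk_size.
    n = xs.shape[0]
    xs_chunks = xs.reshape((n // chunk_size, chunk_size) + xs.shape[1:])

    def body(carry, x_chunk):
        # The same params are closed over and reused for every chunk
        out = jax.vmap(log_amplitude, in_axes=(None, 0))(params, x_chunk)
        return carry, out

    _, out = jax.lax.scan(body, None, xs_chunks)
    return out.reshape(n)


def loss(params, xs, chunk_size=None):
    if chunk_size is None:
        vals = jax.vmap(log_amplitude, in_axes=(None, 0))(params, xs)
    else:
        vals = chunked_log_amplitude(params, xs, chunk_size)
    return jnp.mean(vals)


params = {"w": jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)}
xs = jnp.linspace(0.0, 2.0, 30).reshape(10, 3)

# Gradient w.r.t. the non-chunked parameters: contributions of all chunks are summed
g_chunked = jax.grad(loss)(params, xs, 5)
g_full = jax.grad(loss)(params, xs)
print(jnp.allclose(g_chunked["w"], g_full["w"], atol=1e-6))  # True
```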