netket.jax.apply_chunked


netket.jax.apply_chunked(f, in_axes=0, *, chunk_size, axis_0_is_sharded=None, pvary_argnums=None)

Takes a function that is implicitly vmapped over axis 0 and uses scan to evaluate it in smaller chunks along axis 0 of the arguments selected by in_axes.

For this to work, the function f should be vectorized along the in_axes of the arguments. This means that the function f should respect the following condition:

assert jnp.allclose(f(x), jnp.concatenate([f(x[i:i+1]) for i in range(x.shape[0])], axis=0))

which is automatically satisfied if f is obtained by vmapping a function, such as:

f = jax.vmap(f_orig)
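The condition above can be checked numerically for a small case. The following is a minimal sketch, assuming a hypothetical per-sample function f_orig that maps one sample of shape (3,) to a scalar; the names f_orig, x, full and row_by_row are illustrative and not part of NetKet.

```python
import jax
import jax.numpy as jnp


def f_orig(x_i):
    # Hypothetical per-sample function: one sample of shape (3,) -> scalar.
    return jnp.sum(x_i**2)


# vmapping over axis 0 gives a function that acts on a whole batch.
f = jax.vmap(f_orig)

x = jnp.arange(12.0).reshape(4, 3)  # batch of 4 samples

# Evaluate the whole batch in one call ...
full = f(x)
# ... and one row at a time, keeping the batch axis with x[i : i + 1].
row_by_row = jnp.concatenate(
    [f(x[i : i + 1]) for i in range(x.shape[0])], axis=0
)

# True when f is vectorized along axis 0 in the sense described above.
condition_holds = bool(jnp.allclose(full, row_by_row))
```

A function that mixes information across rows (for example, one that subtracts the batch mean) would generally fail this check, and chunking it would change its result.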

Note

If netket_experimental_sharding is enabled, this function assumes that chunked in_axes are sharded by default. This can be overridden by specifying axis_0_is_sharded=False.

Parameters:
  • f (Callable) – A function that satisfies the condition above

  • in_axes – The axes that should be scanned along. Only supports 0 or None

  • chunk_size (int | None) – The maximum size of the chunks to be used. If it is None, chunking is disabled

  • axis_0_is_sharded (bool) – Specifies whether axis 0 of the scanned arrays is sharded among multiple devices. If so, the function is computed in chunks of size chunk_size on every device. Defaults to True if config.netket_experimental_sharding is enabled, otherwise defaults to False.

  • pvary_argnums (tuple[int, ...] | None) – Explicit tuple of argument indices that should receive a pvary annotation inside the sharded chunking path. If None, all non-chunked arguments are marked as pvary, matching the historical behaviour. This does not change how those arguments are physically sharded: it only tells shard_map to treat them as varying rather than invariant. This is mainly useful for differentiated inputs like parameters, where treating them as invariant can trigger unwanted implicit psums.

Return type:

Callable
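To illustrate the idea of evaluating a vectorized function in chunks with scan, here is a minimal single-argument sketch written with plain JAX. It is not NetKet's implementation: the helper chunked_apply_sketch is hypothetical, it ignores in_axes, sharding and pvary_argnums, it assumes the input has at least one row, and its handling of leftover rows is an assumption made for the example.

```python
import jax
import jax.numpy as jnp


def chunked_apply_sketch(f, x, chunk_size):
    """Evaluate f over axis 0 of x in chunks (illustration only)."""
    n = x.shape[0]
    n_chunks, n_rest = divmod(n, chunk_size)
    n_main = n_chunks * chunk_size

    outputs = []
    if n_chunks > 0:
        # Reshape (n_main, ...) -> (n_chunks, chunk_size, ...).
        x_main = x[:n_main].reshape((n_chunks, chunk_size) + x.shape[1:])

        def body(carry, x_chunk):
            # No state is carried between chunks; only outputs are stacked.
            return carry, f(x_chunk)

        _, y_main = jax.lax.scan(body, None, x_main)
        # Merge the chunk axis back into axis 0.
        outputs.append(y_main.reshape((n_main,) + y_main.shape[2:]))
    if n_rest > 0:
        # Rows that do not fill a whole chunk are evaluated in one extra call.
        outputs.append(f(x[n_main:]))
    return jnp.concatenate(outputs, axis=0)


# Hypothetical vectorized function and input batch of 10 samples.
f = jax.vmap(lambda x_i: jnp.tanh(x_i).sum())
x = jnp.linspace(-1.0, 1.0, 30).reshape(10, 3)

y_chunked = chunked_apply_sketch(f, x, chunk_size=4)  # 2 full chunks + 2 rows
y_full = f(x)
```

Because f satisfies the vectorization condition, y_chunked and y_full agree up to floating-point tolerance, while each scan step only processes chunk_size rows. Based on the signature documented above, the library counterpart would be called as netket.jax.apply_chunked(f, chunk_size=4)(x), with chunk_size passed as a keyword argument.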