netket.jax.chunk
- netket.jax.chunk(x, chunk_size=None)
Split an array (or a pytree of arrays) into chunks along the first axis.
- Parameters:
x – an array (or pytree of arrays)
chunk_size – an integer or None (default). The length of the first axis of x must be a multiple of chunk_size.
- Returns: a pair (x_chunked, unchunk_fn) where
x_chunked is x reshaped to (-1, chunk_size) + x.shape[1:]. If chunk_size is None, it defaults to x.shape[0], i.e. x is returned as a single chunk.
unchunk_fn is a function that restores x from x_chunked.
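The shape contract above can be illustrated with a small sketch. This is a hypothetical re-implementation written only from the behaviour described on this page, not NetKet's actual source: the names chunk_sketch, _chunk_leaf, _unchunk_leaf and per_chunk are invented for illustration. It assumes JAX is installed and uses jax.tree_util.tree_map to apply the reshape to every leaf of a pytree, then shows one common way to consume the chunks, mapping a function over the chunk axis with jax.lax.map (which assumes all leaves share the same first-axis length).

```python
from functools import partial

import jax
import jax.numpy as jnp


def _chunk_leaf(x, chunk_size=None):
    # Reshape one array from (n, ...) to (n // chunk_size, chunk_size, ...).
    if x.ndim == 0:
        raise ValueError("a 0-dimensional array cannot be chunked")
    n = x.shape[0]
    if chunk_size is None:
        # Default: a single chunk holding the whole first axis.
        chunk_size = n
    if n % chunk_size != 0:
        raise ValueError(
            f"first axis ({n}) must be a multiple of chunk_size ({chunk_size})"
        )
    return x.reshape((-1, chunk_size) + x.shape[1:])


def _unchunk_leaf(x_chunked):
    # Merge the chunk axis and the within-chunk axis back into one axis.
    return x_chunked.reshape((-1,) + x_chunked.shape[2:])


def chunk_sketch(x, chunk_size=None):
    """Hypothetical stand-in for netket.jax.chunk: returns (x_chunked, unchunk_fn)."""
    x_chunked = jax.tree_util.tree_map(partial(_chunk_leaf, chunk_size=chunk_size), x)

    def unchunk_fn(x_chunked):
        # Undo the reshape on every leaf of the (possibly processed) pytree.
        return jax.tree_util.tree_map(_unchunk_leaf, x_chunked)

    return x_chunked, unchunk_fn


# A pytree whose leaves share a first-axis length of 8.
params = {"a": jnp.arange(8.0), "b": jnp.ones((8, 3))}

chunked, unchunk = chunk_sketch(params, chunk_size=4)
print(chunked["a"].shape)  # (2, 4)
print(chunked["b"].shape)  # (2, 4, 3)


def per_chunk(c):
    # Work on one chunk of 4 rows at a time; result has shape (4, 3).
    return c["a"][:, None] * c["b"]


# jax.lax.map loops over the leading (chunk) axis, giving shape (2, 4, 3).
out = jax.lax.map(per_chunk, chunked)
restored = unchunk(out)  # back to shape (8, 3)
```

Processing chunks sequentially this way bounds peak memory by the size of one chunk rather than the whole first axis; passing chunk_size=None yields a single chunk with a leading axis of length 1.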