netket.jax.chunk

netket.jax.chunk(x, chunk_size=None)

Split an array (or a pytree of arrays) into chunks along the first axis.

Parameters:
  • x – an array (or pytree of arrays)

  • chunk_size – an integer or None (default). The size of the first axis of x must be a multiple of chunk_size.
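The divisibility requirement (together with the None default described under Returns) comes down to simple shape arithmetic. The sketch below assumes a hypothetical helper, chunked_shape, that is not part of NetKet; it computes the chunked shape for a single array and raises a ValueError when the first axis is not a multiple of chunk_size.

```python
def chunked_shape(shape, chunk_size=None):
    """Hypothetical helper (not NetKet API): shape of x after chunking axis 0."""
    if len(shape) == 0:
        raise ValueError("a 0-dimensional array cannot be chunked")
    n = shape[0]
    if chunk_size is None:
        # Default: a single chunk that holds the whole first axis.
        chunk_size = n
    n_chunks, remainder = divmod(n, chunk_size)
    if remainder != 0:
        raise ValueError(f"x.shape[0]={n} is not a multiple of chunk_size={chunk_size}")
    # Leading axis split into (number of chunks, chunk size); other axes unchanged.
    return (n_chunks, chunk_size) + tuple(shape[1:])


print(chunked_shape((12, 3), chunk_size=4))  # (3, 4, 3)
print(chunked_shape((12, 3)))  # (1, 12, 3)
```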

Returns: a pair (x_chunked, unchunk_fn) where
  • x_chunked is x reshaped to (-1, chunk_size) + x.shape[1:]. If chunk_size is None, it defaults to x.shape[0], i.e. a single chunk.

  • unchunk_fn is a function that restores x from x_chunked.
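The documented return pair can be illustrated with a small re-implementation built on jax.tree_util.tree_map. This is a minimal sketch of the behaviour described above, not NetKet's source: the names chunk_sketch, _chunk_leaf and _unchunk_leaf are hypothetical, and raising ValueError on a non-divisible first axis is an assumption of this sketch.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _chunk_leaf(x, chunk_size=None):
    # Reshape one array from (n, ...) to (n // chunk_size, chunk_size, ...).
    n = x.shape[0]
    if chunk_size is None:
        chunk_size = n  # default: a single chunk
    if n % chunk_size != 0:
        raise ValueError(f"x.shape[0]={n} is not a multiple of chunk_size={chunk_size}")
    return x.reshape((n // chunk_size, chunk_size) + x.shape[1:])


def _unchunk_leaf(x_chunked):
    # Merge the two leading axes (chunks, chunk_size) back into one.
    return x_chunked.reshape((-1,) + x_chunked.shape[2:])


def chunk_sketch(x, chunk_size=None):
    """Hypothetical sketch of the documented (x_chunked, unchunk_fn) pair."""
    x_chunked = jax.tree_util.tree_map(partial(_chunk_leaf, chunk_size=chunk_size), x)

    def unchunk_fn(x_chunked):
        # Apply the inverse reshape to every leaf of the pytree.
        return jax.tree_util.tree_map(_unchunk_leaf, x_chunked)

    return x_chunked, unchunk_fn


tree = {"a": jnp.arange(12.0).reshape(6, 2), "b": jnp.arange(6)}
tree_chunked, unchunk_fn = chunk_sketch(tree, chunk_size=3)
print(tree_chunked["a"].shape)  # (2, 3, 2)
print(tree_chunked["b"].shape)  # (2, 3)
restored = unchunk_fn(tree_chunked)  # same shapes and values as tree
```

With NetKet installed, the documented call netket.jax.chunk(tree, chunk_size=3) returns the same kind of pair; every leaf of the pytree must share a first axis that is a multiple of chunk_size.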