netket.jax.unchunk



netket.jax.unchunk(x_chunked)

Merge the first two axes of an array (or of every leaf in a pytree of arrays).

Parameters: x_chunked, an array (or pytree of arrays) with at least 2 dimensions.
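For a single array, merging the first two axes is a reshape from (n_chunks, chunk_size, ...) to (n_chunks * chunk_size, ...); for a pytree, the same reshape is applied to every leaf. The sketch below illustrates only that shape arithmetic with plain JAX. It is not NetKet's implementation, it assumes JAX is installed, and `merge_first_two_axes` is a hypothetical helper introduced for illustration.

```python
import jax
import jax.numpy as jnp


def merge_first_two_axes(x_chunked):
    """Hypothetical sketch: merge axes 0 and 1 of every leaf in a pytree."""
    # (n_chunks, chunk_size, *rest) -> (n_chunks * chunk_size, *rest)
    return jax.tree_util.tree_map(
        lambda a: a.reshape((-1,) + a.shape[2:]), x_chunked
    )


# A pytree with two leaves, both split into 3 chunks of 4 samples.
tree = {
    "sigma": jnp.zeros((3, 4, 10)),
    "weights": jnp.ones((3, 4)),
}
flat = merge_first_two_axes(tree)
print(flat["sigma"].shape)    # (12, 10)
print(flat["weights"].shape)  # (12,)
```

Leaves with exactly two dimensions become one-dimensional, since `x.shape[2:]` is then the empty tuple.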

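Returns: a pair (x, chunk_fn), where

x is x_chunked reshaped to (-1,) + x_chunked.shape[2:], i.e. with its first two axes merged into one, and chunk_fn is a function that restores x_chunked given x, splitting the first axis back into the original two.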
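The round trip between x and chunk_fn can be sketched as follows. This is a minimal re-implementation of the documented contract under stated assumptions, not NetKet's source: `unchunk_sketch` and `_chunk` are hypothetical names, the sketch assumes every leaf shares the same leading shape (n_chunks, chunk_size), and it assumes chunk_fn should reject inputs whose first dimension is not a multiple of chunk_size.

```python
import jax
import jax.numpy as jnp


def _chunk(x, chunk_size):
    # Hypothetical helper: (n, *rest) -> (n // chunk_size, chunk_size, *rest).
    n_chunks, residual = divmod(x.shape[0], chunk_size)
    if residual != 0:
        raise ValueError("first dimension must be divisible by chunk_size")
    return x.reshape((n_chunks, chunk_size) + x.shape[1:])


def unchunk_sketch(x_chunked):
    """Hypothetical sketch of the documented (x, chunk_fn) contract."""
    # Assumption: all leaves agree on their first two dimensions.
    leading = {leaf.shape[:2] for leaf in jax.tree_util.tree_leaves(x_chunked)}
    if len(leading) != 1:
        raise ValueError("leaves disagree on (n_chunks, chunk_size)")
    chunk_size = leading.pop()[1]

    # Merge the first two axes of every leaf.
    x = jax.tree_util.tree_map(
        lambda a: a.reshape((-1,) + a.shape[2:]), x_chunked
    )

    def chunk_fn(y):
        # Split the first axis of every leaf back into (n_chunks, chunk_size).
        return jax.tree_util.tree_map(lambda a: _chunk(a, chunk_size), y)

    return x, chunk_fn


x_chunked = {
    "a": jnp.arange(24.0).reshape(2, 3, 4),
    "b": jnp.arange(6).reshape(2, 3),
}
x, chunk_fn = unchunk_sketch(x_chunked)
restored = chunk_fn(x)  # same shapes and values as x_chunked
```

Because chunk_fn only depends on the chunk size, in this sketch it can also be applied to any other pytree whose leaves have the same merged leading dimension, for example a per-sample result computed from x.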