netket.jax.unchunk
- netket.jax.unchunk(x_chunked)
Merge the first two axes of an array (or of every array in a pytree of arrays).
- Parameters: x_chunked, an array (or pytree of arrays) with at least 2 dimensions.
- Returns: a pair (x, chunk_fn)
where x is x_chunked reshaped to (-1,) + x_chunked.shape[2:] (leaf by leaf for a pytree), and chunk_fn is a function that restores x_chunked given x.
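To make the reshape and its inverse concrete, here is a minimal NumPy sketch of the behaviour described above. It is an illustration, not NetKet's source: `unchunk_sketch`, `_chunk`, `_tree_map` and `_first_leaf` are hypothetical names, the pytree handling is a hand-rolled stand-in for `jax.tree_util.tree_map` that only walks dicts, lists and tuples, and the chunk size is assumed to be identical for every leaf and is read from axis 1 of the first leaf.

```python
import functools

import numpy as np


def _tree_map(fn, tree):
    # Hypothetical stand-in for jax.tree_util.tree_map: recurse into
    # dicts, lists and tuples and apply fn to every array leaf.
    if isinstance(tree, dict):
        return {k: _tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_tree_map(fn, v) for v in tree)
    return fn(tree)


def _first_leaf(tree):
    # Return the first array leaf found in the (hypothetical) pytree.
    if isinstance(tree, dict):
        return _first_leaf(next(iter(tree.values())))
    if isinstance(tree, (list, tuple)):
        return _first_leaf(tree[0])
    return tree


def _chunk(x, chunk_size):
    # Inverse operation: split the leading axis of every leaf into
    # (n_chunks, chunk_size), keeping the remaining axes unchanged.
    return _tree_map(lambda a: a.reshape((-1, chunk_size) + a.shape[1:]), x)


def unchunk_sketch(x_chunked):
    # Assumption: all leaves share the same chunk size on axis 1.
    chunk_size = _first_leaf(x_chunked).shape[1]
    # Merge the first two axes: (n_chunks, chunk_size, *rest) -> (-1, *rest).
    x = _tree_map(lambda a: a.reshape((-1,) + a.shape[2:]), x_chunked)
    # chunk_fn maps an unchunked x back to the chunked layout.
    chunk_fn = functools.partial(_chunk, chunk_size=chunk_size)
    return x, chunk_fn


if __name__ == "__main__":
    x_chunked = {"a": np.arange(24).reshape(3, 4, 2), "b": np.ones((3, 4))}
    x, chunk_fn = unchunk_sketch(x_chunked)
    print(x["a"].shape, x["b"].shape)  # (12, 2) (12,)
    restored = chunk_fn(x)
    print(restored["a"].shape)  # (3, 4, 2)
```

Returning `chunk_fn` alongside `x` means a caller can flatten the chunk axis, operate on the merged batch, and then reshape the result back into the original chunked layout without tracking the chunk size separately. Because both reshapes use the same row-major order, `chunk_fn(x)` reproduces `x_chunked` element for element.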