netket.jax.logsumexp_cplx

netket.jax.logsumexp_cplx(a, b=None, **kwargs)

Compute the log of the sum of exponentials of input elements, always returning a complex number.

Equivalent to, but more numerically stable than, np.log(np.sum(b*np.exp(a))). If the optional argument b is omitted, np.log(np.sum(np.exp(a))) is returned.
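The stability gain comes from the standard max-shift identity log(sum_i b_i exp(a_i)) = m + log(sum_i b_i exp(a_i - m)) with m = max_i a_i, which keeps every exponent at or below zero. The NumPy sketch below contrasts it with the naive formula. The helper name `logsumexp_shifted` is hypothetical, it assumes real-valued a and a positive total sum, and it illustrates the general technique rather than the exact code used by JAX or NetKet.

```python
import numpy as np


def logsumexp_shifted(a, b=None):
    """Hypothetical helper: log(sum(b * exp(a))) via the max-shift trick.

    Assumes real-valued `a` and a positive total sum.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.ones_like(a) if b is None else np.asarray(b, dtype=np.float64)
    m = np.max(a)  # shift so the largest exponent becomes exp(0) = 1
    return m + np.log(np.sum(b * np.exp(a - m)))


a = np.array([1000.0, 1000.0])

with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(a)))  # exp(1000) overflows to inf

stable = logsumexp_shifted(a)  # 1000 + log(2), no overflow

print(naive)   # inf
print(stable)  # approximately 1000.6931
```

Because the largest term is rescaled to exactly 1, the sum can neither overflow nor underflow to zero, which is why the shifted form is preferred whenever the a_i are large in magnitude.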

Wraps jax.scipy.special.logsumexp. If both a and b are real-valued, it calls that function with return_sign=True, so that negative weights (b<0) that make the sum negative yield a complex result (with imaginary part π) instead of nan.
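For real inputs, the sign-aware approach can be sketched as follows: with return_sign=True, jax.scipy.special.logsumexp returns log|sum_i b_i exp(a_i)| together with the sign of the sum, and a negative sign is turned into an imaginary part of π, since log(-x) = log(x) + iπ for x > 0. The name `logsumexp_cplx_sketch` is hypothetical; this is an illustration of the documented behaviour under these assumptions, not NetKet's source.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logsumexp_cplx_sketch(a, b=None, **kwargs):
    """Hypothetical re-creation of the documented behaviour (not NetKet's code)."""
    a = jnp.asarray(a)
    if jnp.iscomplexobj(a) or (b is not None and jnp.iscomplexobj(b)):
        # Complex inputs: the plain call already returns a complex logarithm.
        return logsumexp(a, b=b, **kwargs)
    # Real inputs: get log|sum| and the sign of the sum separately.
    log_abs, sign = logsumexp(a, b=b, return_sign=True, **kwargs)
    # log(-x) = log(x) + i*pi for x > 0, so add i*pi where the sum is negative.
    return log_abs + jnp.where(sign < 0, 1j * jnp.pi, 0j)


a = jnp.array([0.0, 0.0])
b = jnp.array([1.0, -3.0])  # sum(b * exp(a)) = -2

print(logsumexp(a, b=b))              # nan: log of a negative number
print(logsumexp_cplx_sketch(a, b=b))  # approximately log(2) + i*pi
```

Extra keyword arguments such as axis are passed straight through, so the same sketch also reduces row-wise, for example logsumexp_cplx_sketch(a2d, b=b2d, axis=1).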

See the JAX function for details of the calling sequence; additional keyword arguments are passed through to it, except return_sign, which is not supported.

Return type:

Array

Parameters: