netket.jax.logsumexp_cplx
- netket.jax.logsumexp_cplx(a, b=None, **kwargs)
Compute the log of the sum of exponentials of input elements, always returning a complex number.
Equivalent to, but more numerically stable than, np.log(np.sum(b*np.exp(a))). If the optional argument b is omitted, np.log(np.sum(np.exp(a))) is returned.
Wraps jax.scipy.special.logsumexp, but calls it with return_sign=True when both a and b are real, so that a negative sum (possible when b<0) gives a complex result instead of nan.
See the documentation of jax.scipy.special.logsumexp for details of the calling sequence; the return_sign keyword argument is not supported.
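The behaviour for real inputs can be illustrated with a small standard-library sketch. This is not NetKet's implementation: the name logsumexp_cplx_sketch is hypothetical, only flat sequences of real numbers are handled, and the axis-related keyword arguments of the JAX function are left out. It only shows how the sign of a negative sum can be carried in the imaginary part of the logarithm.

```python
import cmath
import math


def logsumexp_cplx_sketch(a, b=None):
    """Hypothetical reference: log(sum(b * exp(a))) for real a, b, returned as complex."""
    if b is None:
        b = [1.0] * len(a)
    # Shift by the maximum so that exp() cannot overflow.
    a_max = max(a)
    total = sum(bi * math.exp(ai - a_max) for ai, bi in zip(a, b))
    if total == 0.0:
        # log(0) is -inf; the sketch returns it with zero phase.
        return complex(-math.inf, 0.0)
    # The magnitude goes into the real part; a negative sign becomes
    # a phase of pi in the imaginary part, since log(-x) = log(x) + i*pi.
    log_abs = a_max + math.log(abs(total))
    phase = math.pi if total < 0 else 0.0
    return complex(log_abs, phase)


# Positive sum: the imaginary part is zero.
positive = logsumexp_cplx_sketch([0.0, 1.0])

# Negative sum (b contains a negative weight): a real-valued log would be nan.
negative = logsumexp_cplx_sketch([0.0, 1.0], b=[1.0, -1.0])

# Exponentiating recovers the signed sum 1 - e.
recovered = cmath.exp(negative)
```

Exponentiating the complex result gives back the signed sum, which is the property a real-valued return value cannot provide once the sum is negative.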