netket.errors.RealQGTComplexDomainError


exception netket.errors.RealQGTComplexDomainError

This error is raised when you apply the Quantum Geometric Tensor (QGT) of a non-holomorphic function to a complex-valued vector, because the operation is ambiguous: for non-holomorphic functions the QGT object only represents the real part of the full QGT.

As explained in the documentation of the geometric tensors, the QGT implementation for non-holomorphic functions corresponds to the real part of the QGT, not the full QGT.

This is because most applications of Variational Monte Carlo, such as the Time-Dependent Variational Principle, ground-state search, or supervised learning, only require the real part, so the imaginary part, which would increase the computational cost, can safely be discarded.
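To make the distinction between the full QGT and its real part concrete, the NumPy sketch below builds a QGT from a hypothetical matrix of per-sample log-derivatives O[s, k] = d log psi(sigma_s) / d theta_k, assuming real parameters theta_k and uniformly weighted samples, using S_kl = <ΔO_k* ΔO_l> with ΔO_k = O_k - <O_k>. The names `O`, `dO`, `S` and `S_real` are illustrative only and are not part of the NetKet API.

```python
import numpy as np

# Hypothetical per-sample log-derivatives O[s, k] (complex, since psi is complex),
# for real parameters and uniformly weighted samples (both are assumptions).
rng = np.random.default_rng(0)
n_samples, n_params = 200, 4
O = rng.normal(size=(n_samples, n_params)) + 1j * rng.normal(size=(n_samples, n_params))

# Center the log-derivatives: dO_k = O_k - <O_k>
dO = O - O.mean(axis=0)

# Full QGT: S_kl = <dO_k^* dO_l>, a complex Hermitian matrix
S = dO.conj().T @ dO / n_samples

# Real part of the QGT, which is what the non-holomorphic QGT represents
S_real = S.real

print(np.allclose(S, S.conj().T))     # True: S is Hermitian
print(np.allclose(S_real, S_real.T))  # True: Re(S) is real symmetric
print(np.abs(S.imag).max() > 0)       # True: the discarded imaginary part is generally nonzero
```

Because S is Hermitian, its real part is a real symmetric matrix, which is why it can be applied directly to real vectors such as the real part of a gradient.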

While the product of the real part of the QGT with a complex vector is well-defined, we explicitly raise this error to prevent the common mistake of assuming that the QGT is complex. This forces users to multiply the QGT by the real and imaginary parts of the vector separately, which is what this class would have to do internally anyway.
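The manual multiplication described above can be sketched with dense NumPy arrays: a random complex Hermitian matrix `S` stands in for the full QGT and a random complex vector `v` stands in for, say, a gradient of complex parameters. Both are hypothetical stand-ins, not NetKet objects; the sketch only shows the linear algebra.

```python
import numpy as np

# Hypothetical stand-in for the full QGT: a random complex Hermitian matrix
rng = np.random.default_rng(1)
A = rng.normal(size=(50, 3)) + 1j * rng.normal(size=(50, 3))
S = A.conj().T @ A / 50
S_real = S.real  # what the non-holomorphic QGT represents

# Hypothetical complex vector, e.g. a gradient with respect to complex parameters
v = rng.normal(size=3) + 1j * rng.normal(size=3)

# Manual application: multiply Re(S) by the real and imaginary parts separately
manual = S_real @ v.real + 1j * (S_real @ v.imag)

print(np.allclose(manual, S_real @ v))  # True: identical to Re(S) @ v
print(np.allclose(manual, S @ v))       # False: Re(S) @ v is not the full QGT times v
```

The second check is the mistake the error guards against: the result differs from the full QGT product by the term i·Im(S)·v, which the non-holomorphic QGT never computes.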

If this really is the mathematical operation you want to perform, you can do it manually. However, we have often found that applying the real part of the QGT to a complex vector is a sign that your math is wrong. In such cases, what you actually wanted may have been:

>>> import netket as nk; import jax
>>>
>>> vstate = nk.vqs.FullSumState(nk.hilbert.Spin(0.5, 5),
...                              nk.models.RBM(param_dtype=complex))
>>> _, vec = vstate.expect_and_grad(nk.operator.spin.sigmax(vstate.hilbert, 1))
>>> G = nk.optimizer.qgt.QGTOnTheFly(vstate, holomorphic=False)
>>>
>>> # Keep only the real part of the gradient before applying the QGT
>>> vec_real = jax.tree.map(lambda x: x.real, vec)
>>> sol = G@vec_real

Or, if you used the QGT in a linear solver, try using:

>>> import netket as nk; import jax
>>>
>>> vstate = nk.vqs.FullSumState(nk.hilbert.Spin(0.5, 5),
...                              nk.models.RBM(param_dtype=complex))
>>> _, vec = vstate.expect_and_grad(nk.operator.spin.sigmax(vstate.hilbert, 1))
>>>
>>> G = nk.optimizer.qgt.QGTOnTheFly(vstate, holomorphic=False)
>>> # Keep only the real part of the gradient before solving
>>> vec_real = jax.tree.map(lambda x: x.real, vec)
>>>
>>> linear_solver = jax.scipy.sparse.linalg.cg
>>> solution, info = G.solve(linear_solver, vec_real)
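If you genuinely need to solve a system with the real part of the QGT against a complex right-hand side, the same splitting applies: since Re(S) is real, the complex system decomposes into two real systems, one for the real part and one for the imaginary part of the right-hand side. The dense NumPy sketch below illustrates this with a hypothetical stand-in matrix `S_real` and vector `b`; it is not the NetKet solver, which would instead be called once on each part of the vector.

```python
import numpy as np

# Hypothetical stand-in for Re(S): real part of a random positive-definite Hermitian matrix,
# which is itself real symmetric positive-definite
rng = np.random.default_rng(2)
A = rng.normal(size=(50, 3)) + 1j * rng.normal(size=(50, 3))
S_real = (A.conj().T @ A / 50).real

# Hypothetical complex right-hand side
b = rng.normal(size=3) + 1j * rng.normal(size=3)

# Since Re(S) is real, solve for the real and imaginary parts separately and recombine
x = np.linalg.solve(S_real, b.real) + 1j * np.linalg.solve(S_real, b.imag)

print(np.allclose(S_real @ x, b))  # True
```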