# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
from typing import Optional, TYPE_CHECKING
import jax
import numpy as np
from numba import jit
from netket.graph import AbstractGraph
from netket.hilbert import Spin
from netket.utils.types import DType
from netket.errors import concrete_or_error, NumbaOperatorGetConnDuringTracingError
from .base import IsingBase
if TYPE_CHECKING:
from .jax import IsingJax
[docs]
class Ising(IsingBase):
r"""
The Transverse-Field Ising Hamiltonian :math:`-h\sum_i \sigma_i^{(x)} +J\sum_{\langle i,j\rangle} \sigma_i^{(z)}\sigma_j^{(z)}`.
This implementation is considerably faster than the Ising hamiltonian constructed by summing :class:`~netket.operator.LocalOperator` s.
"""
@wraps(IsingBase.__init__)
def __init__(
self,
hilbert: Spin,
graph: AbstractGraph,
h: float,
J: float = 1.0,
dtype: Optional[DType] = None,
):
r"""
Constructs the Ising Operator from an hilbert space and a
graph specifying the connectivity.
Args:
hilbert: Hilbert space the operator acts on.
h: The strength of the transverse field.
J: The strength of the coupling. Default is 1.0.
dtype: The dtype of the matrix elements.
Examples:
Constructs an ``Ising`` operator for a 1D system.
>>> import netket as nk
>>> g = nk.graph.Hypercube(length=20, n_dim=1, pbc=True)
>>> hi = nk.hilbert.Spin(s=0.5, N=g.n_nodes)
>>> op = nk.operator.Ising(h=1.321, hilbert=hi, J=0.5, graph=g)
>>> print(op)
Ising(J=0.5, h=1.321; dim=20)
"""
if not isinstance(hilbert, Spin):
raise TypeError(
"""The Hilbert space used by Ising must be a `Spin-1/2` space.
This limitation could be lifted by 'fixing' the method
`_flattened_kernel` to work with arbitrary hilbert spaces, which
should be relatively straightforward to do, but we have not done so
yet.
In the meantime, you can just use `nk.operator.IsingJax` as a
workaround.
"""
)
if len(hilbert.local_states) != 2:
raise ValueError("Ising only supports Spin-1/2 hilbert spaces.")
h = np.array(h, dtype=dtype)
J = np.array(J, dtype=dtype)
if isinstance(graph, jax.Array):
graph = np.asarray(graph)
super().__init__(hilbert, graph=graph, h=h, J=J, dtype=dtype)
[docs]
def to_jax_operator(self) -> "IsingJax": # noqa: F821
"""
Returns the jax-compatible version of this operator, which is an
instance of :class:`netket.operator.IsingJax`.
"""
from .jax import IsingJax
return IsingJax(
self.hilbert, graph=self.edges, h=self.h, J=self.J, dtype=self.dtype
)
@staticmethod
@jit(nopython=True)
def _flattened_kernel(x, sections, edges, h, J): # pragma: no cover
n_sites = x.shape[1]
n_conn = n_sites + 1
x_prime = np.empty((x.shape[0] * n_conn, n_sites), dtype=x.dtype)
mels = np.empty(x.shape[0] * n_conn, dtype=h.dtype)
diag_ind = 0
for i in range(x.shape[0]):
mels[diag_ind] = 0.0
for k in range(edges.shape[0]):
mels[diag_ind] += J * x[i, edges[k, 0]] * x[i, edges[k, 1]]
odiag_ind = 1 + diag_ind
mels[odiag_ind : (odiag_ind + n_sites)].fill(-h)
x_prime[diag_ind : (diag_ind + n_conn)] = np.copy(x[i])
for j in range(n_sites):
x_prime[j + odiag_ind][j] *= -1.0
diag_ind += n_conn
sections[i] = diag_ind
return x_prime, mels
[docs]
def get_conn_flattened(self, x, sections, pad=False):
r"""Finds the connected elements of the Operator. Starting
from a given quantum number x, it finds all other quantum numbers x' such
that the matrix element :math:`O(x,x')` is different from zero. In general there
will be several different connected states x' satisfying this
condition, and they are denoted here :math:`x'(k)`, for
:math:`k=0,1...N_{\mathrm{connected}}`.
This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).
Args:
x (matrix): A matrix of shape (batch_size,hilbert.size) containing
the batch of quantum numbers x.
sections (array): An array of size (batch_size) useful to unflatten
the output of this function.
See numpy.split for the meaning of sections.
pad (bool): no effect here
Returns:
matrix: The connected states x', flattened together in a single matrix.
array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.
"""
x = concrete_or_error(
np.asarray,
x,
NumbaOperatorGetConnDuringTracingError,
self,
)
return self._flattened_kernel(x, sections, self.edges, self._h, self._J)
def _get_conn_flattened_closure(self):
_edges = self._edges
_h = self._h
_J = self._J
fun = self._flattened_kernel
def gccf_fun(x, sections): # pragma: no cover
return fun(x, sections, _edges, _h, _J)
return jit(nopython=True)(gccf_fun)