# Source code for netket.operator._local_operator

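# Copyright 2021-2022 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.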

from typing import Tuple, Union, List, Optional

import numbers

from textwrap import dedent

import numpy as np
import numba

from netket.hilbert import AbstractHilbert
from netket.utils.types import DType, Array

from ._discrete_operator import DiscreteOperator
from ._lazy import Transpose

from ._local_operator_helpers import (
_dtype,
canonicalize_input,
_multiply_operators,
cast_operator_matrix_dtype,
)
from ._local_operator_compile_helpers import pack_internals

def is_hermitian(a: np.ndarray, rtol=1e-05, atol=1e-08) -> bool:
    """Return True if ``a`` equals its conjugate transpose within tolerances.

    Only square 2-D matrices can be Hermitian. The shape is checked explicitly
    because ``np.allclose`` would otherwise broadcast e.g. a (1, n) matrix
    against its (n, 1) adjoint and wrongly report it as Hermitian.

    Args:
        a: matrix to test (any object exposing ``shape``, ``T`` and ``conj``).
        rtol: relative tolerance forwarded to ``np.allclose``.
        atol: absolute tolerance forwarded to ``np.allclose``.

    Returns:
        True if ``a`` is square and ``a ≈ a^†``, False otherwise.
    """
    shape = a.shape
    if len(shape) != 2 or shape[0] != shape[1]:
        return False
    return bool(np.allclose(a, a.T.conj(), rtol=rtol, atol=atol))

def _is_sorted(a):
for i in range(len(a) - 1):
if a[i + 1] < a[i]:
return False
return True

class LocalOperator(DiscreteOperator):
    """A custom local operator. This is a sum of an arbitrary number of operators
    acting locally on a limited set of k quantum numbers (i.e. k-local,
    in the quantum information sense).

    The terms are stored in a dictionary that maps the tuple of sites an
    operator acts on to its matrix. Terms acting on the same tuple of sites
    are summed together. A scalar level shift (``constant``) is stored
    separately. Stored matrices are treated as immutable: operations build
    new arrays instead of modifying existing ones in place.
    """

    def __init__(
        self,
        hilbert: AbstractHilbert,
        operators: Optional[Union[List[Array], Array]] = None,
        acting_on: Optional[Union[List[int], List[List[int]]]] = None,
        constant: numbers.Number = 0,
        dtype: Optional[DType] = None,
    ):
        r"""
        Constructs a new LocalOperator given a hilbert space and (if
        specified) a constant level shift.

        Args:
            hilbert (netket.AbstractHilbert): Hilbert space the operator acts on.
            operators (list(numpy.array) or numpy.array): A list of operators,
                in matrix form. Defaults to no operators.
            acting_on (list(numpy.array) or numpy.array): A list of sites,
                which the corresponding operators act on. Defaults to no sites.
            constant (float): Level shift for operator. Default is 0.0.
            dtype: Optional dtype of the matrix elements. When omitted, it is
                determined by ``canonicalize_input``.

        Raises:
            ValueError: If the local state values of ``hilbert`` are not
                sorted at every site.

        Examples:
            Constructs a LocalOperator without any operators.

            >>> from netket.hilbert import CustomHilbert
            >>> from netket.operator import LocalOperator
            >>> hi = CustomHilbert(local_states=[-1, 1])**20
            >>> empty_hat = LocalOperator(hi)
            >>> print(len(empty_hat.acting_on))
            0
        """
        super().__init__(hilbert)
        self.mel_cutoff = 1.0e-6
        # While this flag is falsy, `_setup` (re)builds the packed data used
        # by the get_conn kernels.
        self._initialized = False

        # Fresh lists instead of shared mutable default arguments.
        if operators is None:
            operators = []
        if acting_on is None:
            acting_on = []

        if not all(
            _is_sorted(hilbert.states_at_index(i)) for i in range(hilbert.size)
        ):
            raise ValueError(
                dedent(
                    """LocalOperator needs an hilbert space with sorted state values at
                    every site.
                    """
                )
            )

        # Canonicalize input. From now on input is guaranteed to be in canonical order
        operators, acting_on, dtype = canonicalize_input(
            self.hilbert, operators, acting_on, constant, dtype=dtype
        )
        self._dtype = dtype
        self._constant = np.array(constant, dtype=dtype)

        self._operators_dict = {}
        for op, aon in zip(operators, acting_on):
            self._add_operator(aon, op)

    def _add_operator(self, acting_on: Tuple, operator: Array):
        """
        Adds an operator acting on a subset of sites.

        If an operator already acts on exactly the same sites, the two
        matrices are summed. Does not modify in-place the operators themselves,
        which are treated as immutables.

        Args:
            acting_on: tuple of the site indices the operator acts on.
            operator: matrix of the operator on those sites.
        """
        assert isinstance(acting_on, tuple)
        if acting_on in self._operators_dict:
            # `a + b` (not `+=`) so that the previously stored array is untouched.
            operator = self._operators_dict[acting_on] + operator

        self._operators_dict[acting_on] = operator

    @property
    def operators(self) -> List[np.ndarray]:
        """List of the matrices of the operators encoded in this Local Operator.
        Returns a new list; the matrices themselves are shared, not copied.
        """
        return list(self._operators_dict.values())

    @property
    def _operators(self) -> List[np.ndarray]:
        return self.operators

    @property
    def acting_on(self) -> List[Tuple[int, ...]]:
        """List containing the sites on which every operator acts.

        Every operator self.operators[i] acts on the sites self.acting_on[i]
        """
        return list(self._operators_dict.keys())

    @property
    def n_operators(self) -> int:
        """Number of distinct groups of sites that carry an operator."""
        return len(self._operators_dict)

    @property
    def dtype(self) -> DType:
        return self._dtype

    @property
    def size(self) -> int:
        """Number of sites of the Hilbert space the operator acts on."""
        # NOTE(review): `_size` was never assigned anywhere, so this property
        # always raised AttributeError. Reporting the Hilbert space size is an
        # assumed meaning of `size`; confirm against the base class.
        return self.hilbert.size

    @property
    # A way to cache the property depending on modifications of self._operators is described here:
    # https://stackoverflow.com/questions/48262273/python-bookkeeping-dependencies-in-cached-attributes-that-might-change
    def is_hermitian(self) -> bool:
        """Returns true if this operator is hermitian."""
        # TODO: (VolodyaCO) I guess that if we have an operator with diagonal elements equal to 1j*C+Y, some complex constant, and
        # self._constant=-1j*C, then the actual diagonal would be Y. How do we check hermiticity taking into account the diagonal
        # elements as well as the self._constant? For the moment I just check hermiticity of the added constant, which must be real.
        return bool(
            all(map(is_hermitian, self.operators)) and np.isreal(self._constant)
        )

    @property
    def mel_cutoff(self) -> float:
        r"""float: Threshold below which matrix elements are treated as zero.

        It is passed to the routine that packs the operator for the get_conn
        kernels. Multiplying by a scalar whose magnitude does not exceed it
        drops every operator term.
        """
        return self._mel_cutoff

    @mel_cutoff.setter
    def mel_cutoff(self, mel_cutoff):
        # Validate before storing so an invalid value is never kept.
        if not mel_cutoff >= 0:
            raise ValueError(f"mel_cutoff must be non-negative, got {mel_cutoff}")
        self._mel_cutoff = mel_cutoff

    @property
    def constant(self) -> numbers.Number:
        """The scalar level shift, stored as a 0-d array of the operator's dtype."""
        return self._constant

    def copy(self, *, dtype: Optional[DType] = None):
        """Returns a copy of the operator, while optionally changing the dtype
        of the operator.

        Args:
            dtype: optional dtype

        Raises:
            ValueError: If the current dtype cannot be cast to ``dtype``
                under ``same_kind`` casting.
        """

        if dtype is None:
            dtype = self.dtype

        if not np.can_cast(self.dtype, dtype, casting="same_kind"):
            raise ValueError(f"Cannot cast {self.dtype} to {dtype}")

        new = LocalOperator(self.hilbert, constant=self.constant, dtype=dtype)
        new.mel_cutoff = self.mel_cutoff

        if dtype == self.dtype:
            # Sharing the matrices is safe because they are never mutated in place.
            new._operators_dict = self._operators_dict.copy()
        else:
            new._operators_dict = {
                aon: cast_operator_matrix_dtype(op, dtype)
                for aon, op in self._operators_dict.items()
            }

        return new

    def transpose(self, *, concrete=False):
        r"""LocalOperator: Returns the transpose of this operator.

        Args:
            concrete: If True, return a LocalOperator in which every matrix is
                transposed. Otherwise return a lazy ``Transpose`` wrapper.
        """
        if concrete:
            new = self.copy()
            new._operators_dict = {
                aon: op.transpose() for aon, op in new._operators_dict.items()
            }
            return new
        else:
            return Transpose(self)

    def conjugate(self, *, concrete=False):
        r"""LocalOperator: Returns the complex conjugate of this operator.

        Both the operator matrices and the constant level shift are conjugated.
        ``concrete`` is accepted for interface compatibility; the result is
        always a concrete LocalOperator.
        """
        new = self.copy()
        # Keep the constant as a 0-d array of the operator's dtype.
        new._constant = np.array(np.conjugate(self._constant), dtype=new.dtype)
        new._operators_dict = {
            aon: op.conjugate() for aon, op in new._operators_dict.items()
        }
        return new

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return other + (-self)

    def __isub__(self, other):
        return self.__iadd__(-other)

    def __neg__(self):
        return -1 * self

    def __add__(self, other: Union["LocalOperator", numbers.Number]):
        # Out-of-place: promote to a dtype able to hold the result, then add in place.
        op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
        op = op.__iadd__(other)
        return op

    def __radd__(self, other):
        return self.__add__(other)

    def __iadd__(self, other):
        """In-place addition of another LocalOperator or of a scalar.

        Returns NotImplemented for unsupported types or mismatching Hilbert
        spaces.

        Raises:
            ValueError: If ``other``'s dtype cannot be cast to this operator's
                dtype.
        """
        if isinstance(other, LocalOperator):
            if self.hilbert != other.hilbert:
                return NotImplemented

            if not np.can_cast(other.dtype, self.dtype, casting="same_kind"):
                raise ValueError(
                    f"Cannot add inplace operator with dtype {other.dtype} "
                    f"to operator with dtype {self.dtype}"
                )

            assert other.mel_cutoff == self.mel_cutoff
            self._constant += other.constant.item()
            # Snapshot the items so that `op += op` does not iterate a dict
            # that is being modified.
            for aon, op in list(other._operators_dict.items()):
                self._add_operator(aon, op)

            self._reset_caches()
            return self

        if isinstance(other, numbers.Number):
            if not np.can_cast(type(other), self.dtype, casting="same_kind"):
                raise ValueError(
                    f"Cannot add inplace operator with dtype {type(other)} "
                    f"to operator with dtype {self.dtype}"
                )

            self._reset_caches()
            self._constant += other
            return self

        return NotImplemented

    def __truediv__(self, other):
        if not isinstance(other, numbers.Number):
            raise TypeError("Only division by a scalar number is supported.")

        if other == 0:
            raise ValueError("Dividing by 0")
        return self.__mul__(1.0 / other)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __mul__(self, other):
        if isinstance(other, DiscreteOperator):
            op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
            return op.__imatmul__(other)
        elif isinstance(other, numbers.Number):
            op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
            return op.__imul__(other)
        return NotImplemented

    def __imul__(self, other):
        """In-place product with an operator (delegated to ``__imatmul__``) or
        with a scalar.

        Raises:
            ValueError: If the scalar's type cannot be cast to this operator's
                dtype.
        """
        if isinstance(other, DiscreteOperator):
            return self.__imatmul__(other)
        elif isinstance(other, numbers.Number):
            if not np.can_cast(type(other), self.dtype, casting="same_kind"):
                raise ValueError(
                    f"Cannot multiply inplace operator with dtype {self.dtype} "
                    f"by scalar of type {type(other)}"
                )

            self._constant *= other
            if np.abs(other) <= self.mel_cutoff:
                # Every term would fall below the cutoff: drop them all.
                self._operators_dict = {}
            else:
                # Build new arrays; stored matrices are never modified in place.
                self._operators_dict = {
                    key: other * op for key, op in self._operators_dict.items()
                }

            self._reset_caches()
            return self

        return NotImplemented

    def __imatmul__(self, other):
        if not isinstance(other, LocalOperator):
            return NotImplemented

        if not np.can_cast(other.dtype, self.dtype, casting="same_kind"):
            raise ValueError(
                f"Cannot multiply inplace operator with dtype {self.dtype} "
                f"by operator with dtype {other.dtype}"
            )

        return self._op_imatmul_(other)

    def _op__matmul__(self, other: "LocalOperator") -> "LocalOperator":
        if not isinstance(other, LocalOperator):
            return NotImplemented
        op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
        return op._op_imatmul_(other)

    def _op_imatmul_(self, other: "LocalOperator") -> "LocalOperator":
        """In-place operator product ``self = self @ other``."""
        if not isinstance(other, LocalOperator):
            return NotImplemented

        # (α + ∑ᵢAᵢ)(β + ∑ⱼBⱼ) =
        #   = αβ + α ∑ⱼBⱼ + β ∑ᵢAᵢ + ∑ᵢⱼAᵢBⱼ
        #   = β(α + ∑ᵢAᵢ) + α ∑ⱼBⱼ + ∑ᵢⱼAᵢBⱼ

        alpha = self.constant.item()  # α
        beta = other.constant.item()  # β
        # Snapshot both dicts before anything is modified. A is changed in
        # place by __imul__(β) and _add_operator. B is the same dict as A when
        # `other is self` (e.g. `op @= op`).
        A_op_dict = self._operators_dict.copy()
        B_op_dict = other._operators_dict.copy()

        # αβ + β ∑ᵢAᵢ
        self.__imul__(beta)

        # α ∑ⱼBⱼ
        if np.abs(alpha) > self.mel_cutoff:
            for aon, op in B_op_dict.items():
                self._add_operator(aon, alpha * op)

        # ∑ᵢⱼAᵢBⱼ
        for supp_A_i, A_i in A_op_dict.items():
            for supp_B_j, B_j in B_op_dict.items():
                self._add_operator(
                    *_multiply_operators(
                        self.hilbert, supp_A_i, A_i, supp_B_j, B_j, dtype=self.dtype
                    )
                )

        self._reset_caches()
        return self

    def _reset_caches(self):
        """
        Cleans the internal caches built on the operator.
        """
        self._initialized = False

    def _setup(self, force: bool = False):
        """Analyze the operator strings and precompute arrays for get_conn inference.

        The result is cached until `_reset_caches` is called, or until
        ``force`` is True.
        """
        if force or not self._initialized:
            data = pack_internals(
                self.hilbert,
                self._operators_dict,
                self.constant,
                self.dtype,
                self.mel_cutoff,
            )

            self._acting_on = data["acting_on"]
            self._acting_size = data["acting_size"]
            self._diag_mels = data["diag_mels"]
            self._mels = data["mels"]
            self._x_prime = data["x_prime"]
            self._n_conns = data["n_conns"]
            self._local_states = data["local_states"]
            self._basis = data["basis"]
            self._nonzero_diagonal = data["nonzero_diagonal"]
            self._max_conn_size = data["max_conn_size"]
            # Mark the cache as valid; otherwise every get_conn call would
            # repack the whole operator.
            self._initialized = True

    @property
    def max_conn_size(self) -> int:
        """The maximum number of non zero ⟨x|O|x'⟩ for every x."""
        self._setup()
        return self._max_conn_size

    def get_conn_flattened(self, x, sections, pad=False):
        r"""Finds the connected elements of the Operator. Starting
        from a given quantum number x, it finds all other quantum numbers x' such
        that the matrix element :math:`O(x,x')` is different from zero. In general there
        will be several different connected states x' satisfying this
        condition, and they are denoted here :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`.

        This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).

        Args:
            x (matrix): A matrix of shape (batch_size,hilbert.size) containing
                the batch of quantum numbers x.
            sections (array): An array of size (batch_size) useful to unflatten
                the output of this function. It is overwritten in place.
                See numpy.split for the meaning of sections.
            pad (bool): Whether to use zero-valued matrix elements in order to
                return all equal sections. The current kernel always pads, so
                this flag has no effect.

        Returns:
            matrix: The connected states x', flattened together in a single matrix.
            array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.

        """
        self._setup()

        return self._get_conn_flattened_kernel(
            np.asarray(x),
            sections,
            self._local_states,
            self._basis,
            self._constant,
            self._diag_mels,
            self._n_conns,
            self._mels,
            self._x_prime,
            self._acting_on,
            self._acting_size,
            self._nonzero_diagonal,
        )

    def _get_conn_flattened_closure(self):
        """Return a numba-jitted ``f(x, sections)`` that runs the flattened
        kernel with this operator's packed data captured in a closure."""
        self._setup()
        _local_states = self._local_states
        _basis = self._basis
        _constant = self._constant
        _diag_mels = self._diag_mels
        _n_conns = self._n_conns
        _mels = self._mels
        _x_prime = self._x_prime
        _acting_on = self._acting_on
        _acting_size = self._acting_size
        # workaround my painfully discovered Numba#6979 (cannot use numpy bools in closures)
        _nonzero_diagonal = bool(self._nonzero_diagonal)

        fun = self._get_conn_flattened_kernel

        def gccf_fun(x, sections):
            return fun(
                x,
                sections,
                _local_states,
                _basis,
                _constant,
                _diag_mels,
                _n_conns,
                _mels,
                _x_prime,
                _acting_on,
                _acting_size,
                _nonzero_diagonal,
            )

        return numba.jit(nopython=True)(gccf_fun)

    @staticmethod
    @numba.jit(nopython=True)
    def _get_conn_flattened_kernel(
        x,
        sections,
        local_states,
        basis,
        constant,
        diag_mels,
        n_conns,
        all_mels,
        all_x_prime,
        acting_on,
        acting_size,
        nonzero_diagonal,
    ):
        # Every sample is padded to `max_conn` entries, so the output has
        # batch_size * max_conn rows and the sections are evenly spaced.
        batch_size = x.shape[0]
        n_sites = x.shape[1]
        dtype = all_mels.dtype

        # TODO remove this line when numba 0.53 is dropped 0.54 is minimum version
        # workaround a bug in Numba arising when NUMBA_BOUNDSCHECK=1
        constant = constant.item()

        assert sections.shape[0] == batch_size

        n_operators = n_conns.shape[0]
        # xs_n[b, i]: index of the local configuration of sample b on the
        # sites of operator i, used to look up the precomputed tables.
        xs_n = np.empty((batch_size, n_operators), dtype=np.intp)

        tot_conn = 0
        max_conn = 0

        for b in range(batch_size):
            # diagonal element
            conn_b = 1 if nonzero_diagonal else 0

            # counting the off-diagonal elements
            for i in range(n_operators):
                acting_size_i = acting_size[i]

                xs_n[b, i] = 0
                x_b = x[b]
                x_i = x_b[acting_on[i, :acting_size_i]]
                for k in range(acting_size_i):
                    xs_n[b, i] += (
                        np.searchsorted(
                            local_states[i, acting_size_i - k - 1],
                            x_i[acting_size_i - k - 1],
                        )
                        * basis[i, k]
                    )

                conn_b += n_conns[i, xs_n[b, i]]

            tot_conn += conn_b
            sections[b] = tot_conn

            max_conn = max(conn_b, max_conn)

        tot_conn = batch_size * max_conn

        x_prime = np.empty((tot_conn, n_sites), dtype=x.dtype)
        mels = np.empty(tot_conn, dtype=dtype)

        c = 0
        for b in range(batch_size):
            c_diag = c
            x_batch = x[b]

            if nonzero_diagonal:
                mels[c_diag] = constant
                x_prime[c_diag] = np.copy(x_batch)
                c += 1

            for i in range(n_operators):
                if nonzero_diagonal:
                    mels[c_diag] += diag_mels[i, xs_n[b, i]]

                n_conn_i = n_conns[i, xs_n[b, i]]

                if n_conn_i > 0:
                    sites = acting_on[i]
                    acting_size_i = acting_size[i]

                    for cc in range(n_conn_i):
                        mels[c + cc] = all_mels[i, xs_n[b, i], cc]
                        x_prime[c + cc] = np.copy(x_batch)

                        for k in range(acting_size_i):
                            x_prime[c + cc, sites[k]] = all_x_prime[
                                i, xs_n[b, i], cc, k
                            ]
                    c += n_conn_i

            # Pad this sample up to max_conn with zero matrix elements on x itself.
            delta_conn = max_conn - (c - c_diag)
            mels[c : c + delta_conn].fill(0)
            x_prime[c : c + delta_conn, :] = np.copy(x_batch)
            c += delta_conn
            sections[b] = c

        return x_prime, mels

    def get_conn_filtered(self, x, sections, filters):
        r"""Finds the connected elements of the Operator using only a subset of operators. Starting
        from a given quantum number x, it finds all other quantum numbers x' such
        that the matrix element :math:`O(x,x')` is different from zero. In general there
        will be several different connected states x' satisfying this
        condition, and they are denoted here :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`.

        This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).

        Args:
            x (matrix): A matrix of shape (batch_size,hilbert.size) containing
                the batch of quantum numbers x.
            sections (array): An array of size (batch_size) useful to unflatten
                the output of this function. It is overwritten in place.
                See numpy.split for the meaning of sections.
            filters (array): Only operators op(filters[i]) are used to find the connected elements of
                x[i].

        Returns:
            matrix: The connected states x', flattened together in a single matrix.
            array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.

        """
        self._setup()
        return self._get_conn_filtered_kernel(
            np.asarray(x),
            sections,
            self._local_states,
            self._basis,
            self._constant,
            self._diag_mels,
            self._n_conns,
            self._mels,
            self._x_prime,
            self._acting_on,
            self._acting_size,
            filters,
        )

    @staticmethod
    @numba.jit(nopython=True)
    def _get_conn_filtered_kernel(
        x,
        sections,
        local_states,
        basis,
        constant,
        diag_mels,
        n_conns,
        all_mels,
        all_x_prime,
        acting_on,
        acting_size,
        filters,
    ):

        batch_size = x.shape[0]
        n_sites = x.shape[1]
        dtype = all_mels.dtype

        assert filters.shape[0] == batch_size and sections.shape[0] == batch_size

        # TODO remove this line when numba 0.53 is dropped 0.54 is minimum version
        # workaround a bug in Numba arising when NUMBA_BOUNDSCHECK=1
        constant = constant.item()

        n_operators = n_conns.shape[0]
        # Only xs_n[b, filters[b]] is ever written and read.
        xs_n = np.empty((batch_size, n_operators), dtype=np.intp)

        tot_conn = 0

        for b in range(batch_size):
            # diagonal element
            tot_conn += 1

            # counting the off-diagonal elements
            i = filters[b]

            assert i < n_operators and i >= 0
            acting_size_i = acting_size[i]

            xs_n[b, i] = 0
            x_b = x[b]
            x_i = x_b[acting_on[i, :acting_size_i]]
            for k in range(acting_size_i):
                xs_n[b, i] += (
                    np.searchsorted(
                        local_states[i, acting_size_i - k - 1],
                        x_i[acting_size_i - k - 1],
                    )
                    * basis[i, k]
                )

            tot_conn += n_conns[i, xs_n[b, i]]
            sections[b] = tot_conn

        # Same dtype as the input states, consistent with the flattened kernel.
        x_prime = np.empty((tot_conn, n_sites), dtype=x.dtype)
        mels = np.empty(tot_conn, dtype=dtype)

        c = 0
        for b in range(batch_size):
            c_diag = c
            mels[c_diag] = constant
            x_batch = x[b]
            x_prime[c_diag] = np.copy(x_batch)
            c += 1

            i = filters[b]
            # Diagonal part
            mels[c_diag] += diag_mels[i, xs_n[b, i]]
            n_conn_i = n_conns[i, xs_n[b, i]]

            if n_conn_i > 0:
                sites = acting_on[i]
                acting_size_i = acting_size[i]

                for cc in range(n_conn_i):
                    mels[c + cc] = all_mels[i, xs_n[b, i], cc]
                    x_prime[c + cc] = np.copy(x_batch)

                    for k in range(acting_size_i):
                        x_prime[c + cc, sites[k]] = all_x_prime[i, xs_n[b, i], cc, k]
                c += n_conn_i

        return x_prime, mels

    def __repr__(self):
        ao = self.acting_on

        acting_str = f"acting_on={ao}"
        if len(acting_str) > 55:
            acting_str = f"#acting_on={len(ao)} locations"
        return f"{type(self).__name__}(dim={self.hilbert.size}, {acting_str}, constant={self.constant}, dtype={self.dtype})"