# Copyright 2021-2022 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union, List, Optional
import numbers
from textwrap import dedent
import numpy as np
import numba
from netket.hilbert import AbstractHilbert
from netket.utils.types import DType, Array
from ._discrete_operator import DiscreteOperator
from ._lazy import Transpose
from ._local_operator_helpers import (
_dtype,
canonicalize_input,
_multiply_operators,
cast_operator_matrix_dtype,
)
from ._local_operator_compile_helpers import pack_internals
def is_hermitian(a: np.ndarray, rtol=1e-05, atol=1e-08) -> bool:
    """Return True if the array `a` equals its conjugate transpose.

    Args:
        a: The array to test. It must expose `.shape` and `.T`.
        rtol: Relative tolerance forwarded to `numpy.allclose`.
        atol: Absolute tolerance forwarded to `numpy.allclose`.

    Returns:
        True when `a` and `a.T.conj()` have the same shape and agree
        element-wise within the given tolerances, False otherwise.
    """
    a_dagger = a.T.conj()
    # Without this guard a (1, n) or (n, 1) array would be silently broadcast
    # against its transpose by allclose (and could compare equal), while any
    # other non-square shape would raise instead of answering "no".
    if a.shape != a_dagger.shape:
        return False
    return bool(np.allclose(a, a_dagger, rtol=rtol, atol=atol))
def _is_sorted(a):
    """Return True if the indexable sequence `a` is in non-decreasing order.

    Empty and single-element sequences count as sorted.
    """
    # Sorted means no element is strictly smaller than its predecessor.
    last = len(a) - 1
    return not any(a[pos + 1] < a[pos] for pos in range(last))
class LocalOperator(DiscreteOperator):
    """A custom local operator. This is a sum of an arbitrary number of operators
    acting locally on a limited set of k quantum numbers (i.e. k-local,
    in the quantum information sense).

    The local terms are stored in a dictionary mapping the tuple of sites a
    term acts on to its matrix. Terms added on the same sites are summed.
    """

    def __init__(
        self,
        hilbert: AbstractHilbert,
        operators: Optional[Union[List[Array], Array]] = None,
        acting_on: Optional[Union[List[int], List[List[int]]]] = None,
        constant: numbers.Number = 0,
        dtype: Optional[DType] = None,
    ):
        r"""
        Constructs a new ``LocalOperator`` given a hilbert space and (if
        specified) a constant level shift.

        Args:
           hilbert (netket.AbstractHilbert): Hilbert space the operator acts on.
           operators (list(numpy.array) or numpy.array): A list of operators, in
                matrix form. Defaults to no operators.
           acting_on (list(numpy.array) or numpy.array): A list of sites, which
                the corresponding operators act on. Defaults to no sites.
           constant (float): Level shift for operator. Default is 0.0.
           dtype: Optional dtype, forwarded to ``canonicalize_input`` which
                returns the dtype actually used.

        Raises:
            ValueError: if the local states of any site of ``hilbert`` are not
                sorted.

        Examples:
           Constructs a ``LocalOperator`` without any operators.

           >>> from netket.hilbert import CustomHilbert
           >>> from netket.operator import LocalOperator
           >>> hi = CustomHilbert(local_states=[-1, 1])**20
           >>> empty_hat = LocalOperator(hi)
           >>> print(len(empty_hat.acting_on))
           0
        """
        super().__init__(hilbert)

        # None stands for "no operators"; this avoids sharing one mutable
        # default list between calls.
        if operators is None:
            operators = []
        if acting_on is None:
            acting_on = []

        self._initialized = False
        self.mel_cutoff = 1.0e-6

        if not all(
            _is_sorted(hilbert.states_at_index(i)) for i in range(hilbert.size)
        ):
            raise ValueError(
                dedent(
                    """\
                    LocalOperator needs an hilbert space with sorted state values at
                    every site.
                    """
                )
            )

        # Canonicalize input. From now on input is guaranteed to be in canonical order
        operators, acting_on, dtype = canonicalize_input(
            self.hilbert, operators, acting_on, constant, dtype=dtype
        )
        self._dtype = dtype
        self._constant = np.array(constant, dtype=dtype)

        self._operators_dict = {}
        for (op, aon) in zip(operators, acting_on):
            self._add_operator(aon, op)

    def _add_operator(self, acting_on: Tuple, operator: Array):
        """
        Adds an operator acting on a subset of sites.

        If a term acting on the same sites is already stored, the sum of the
        two is stored under that key.

        Does not modify in-place the operators themselves which are treated as
        immutables.
        """
        assert isinstance(acting_on, tuple)
        if acting_on in self._operators_dict:
            operator = self._operators_dict[acting_on] + operator

        self._operators_dict[acting_on] = operator

    @property
    def operators(self) -> List[np.ndarray]:
        """List of the matrices of the operators encoded in this Local Operator.

        The list itself is newly built on every access; the matrices in it
        are the stored objects, not copies.
        """
        return list(self._operators_dict.values())

    @property
    def _operators(self) -> List[np.ndarray]:
        """Alias of :attr:`operators`."""
        return self.operators

    @property
    def acting_on(self) -> List[Tuple[int, ...]]:
        """List containing the tuple of the sites on which every operator acts.

        Every operator `self.operators[i]` acts on the sites `self.acting_on[i]`
        """
        return list(self._operators_dict.keys())

    @property
    def n_operators(self) -> int:
        """Number of distinct site tuples that have a stored operator."""
        return len(self._operators_dict)

    @property
    def dtype(self) -> DType:
        """The dtype of the operator, as returned by ``canonicalize_input``."""
        return self._dtype

    @property
    def size(self) -> int:
        # NOTE(review): `_size` is never assigned anywhere in this class, so
        # unless a base class provides it this raises AttributeError.
        # Confirm the intended value (possibly `self.hilbert.size`).
        return self._size

    @property
    # A way to cache the property depending on modifications of self._operators is described here:
    # https://stackoverflow.com/questions/48262273/python-bookkeeping-dependencies-in-cached-attributes-that-might-change
    def is_hermitian(self) -> bool:
        """Returns true if this operator is hermitian.

        Every stored local matrix must be hermitian and the constant must be
        real; the two are checked independently of each other.
        """
        # TODO: (VolodyaCO) I guess that if we have an operator with diagonal elements equal to 1j*C+Y, some complex constant, and
        # self._constant=-1j*C, then the actual diagonal would be Y. How do we check hermiticity taking into account the diagonal
        # elements as well as the self._constant? For the moment I just check hermiticity of the added constant, which must be real.

        # `is_hermitian` below resolves to the module-level function, not to
        # this property. bool() turns a possible numpy bool into a Python bool.
        return bool(
            all(map(is_hermitian, self.operators)) and np.isreal(self._constant)
        )

    @property
    def mel_cutoff(self) -> float:
        r"""float: The cutoff for matrix elements.
        Only matrix elements such that abs(O(i,i))>mel_cutoff
        are considered"""
        return self._mel_cutoff

    @mel_cutoff.setter
    def mel_cutoff(self, mel_cutoff):
        # Validate before storing, so that a rejected value does not leave the
        # operator holding an invalid cutoff. `not (x >= 0)` also rejects NaN.
        if not mel_cutoff >= 0:
            raise ValueError(
                f"mel_cutoff must be non-negative, but got {mel_cutoff}."
            )
        self._mel_cutoff = mel_cutoff
        # The packed arrays built by `_setup` depend on the cutoff.
        self._initialized = False

    @property
    def constant(self) -> numbers.Number:
        """The constant level shift, stored as a 0-dimensional numpy array."""
        return self._constant

    def copy(self, *, dtype: Optional[DType] = None):
        """Returns a copy of the operator, while optionally changing the dtype
        of the operator.

        The dictionary of local terms is copied; the matrices are shared
        unless the dtype changes, in which case they are cast.

        Args:
            dtype: optional dtype

        Raises:
            ValueError: if the current dtype cannot be cast to `dtype` with
                numpy's "same_kind" rule.
        """
        if dtype is None:
            dtype = self.dtype

        if not np.can_cast(self.dtype, dtype, casting="same_kind"):
            raise ValueError(f"Cannot cast {self.dtype} to {dtype}")

        new = LocalOperator(self.hilbert, constant=self.constant, dtype=dtype)
        new.mel_cutoff = self.mel_cutoff

        if dtype == self.dtype:
            new._operators_dict = self._operators_dict.copy()
        else:
            new._operators_dict = {
                aon: cast_operator_matrix_dtype(op, dtype)
                for aon, op in self._operators_dict.items()
            }

        return new

    def transpose(self, *, concrete=False):
        r"""LocalOperator: Returns the transpose of this operator.

        Args:
            concrete: if True, return a new ``LocalOperator`` whose local
                matrices are transposed; otherwise return a lazy
                ``Transpose`` wrapper around this operator.
        """
        if not concrete:
            return Transpose(self)

        new = self.copy()
        new._operators_dict = {
            aon: op.transpose() for aon, op in new._operators_dict.items()
        }
        return new

    def conjugate(self, *, concrete=False):
        r"""LocalOperator: Returns the complex conjugate of this operator.

        Args:
            concrete: accepted but not used; a new ``LocalOperator`` is
                always returned.
        """
        new = self.copy()
        new._operators_dict = {
            aon: op.copy().conjugate() for aon, op in new._operators_dict.items()
        }
        return new

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return other + (-self)

    def __isub__(self, other):
        return self.__iadd__(-other)

    def __neg__(self):
        return -1 * self

    def __add__(self, other: Union["LocalOperator", numbers.Number]):
        """Return the sum as a new operator with the promoted dtype."""
        op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
        op = op.__iadd__(other)
        return op

    def __iadd__(self, other):
        """Add a ``LocalOperator`` or a scalar to this operator in place.

        Returns ``NotImplemented`` for unsupported operands or for operators
        defined on a different hilbert space.

        Raises:
            ValueError: if the dtype of `other` cannot be cast to the dtype
                of this operator with numpy's "same_kind" rule.
        """
        if isinstance(other, LocalOperator):
            if self.hilbert != other.hilbert:
                return NotImplemented

            if not np.can_cast(other.dtype, self.dtype, casting="same_kind"):
                raise ValueError(
                    f"Cannot add inplace operator with dtype {other.dtype} "
                    f"to operator with dtype {self.dtype}"
                )

            assert other.mel_cutoff == self.mel_cutoff

            self._constant += other.constant.item()
            for (aon, op) in other._operators_dict.items():
                self._add_operator(aon, op)

            self._reset_caches()
            return self

        if isinstance(other, numbers.Number):
            if not np.can_cast(type(other), self.dtype, casting="same_kind"):
                raise ValueError(
                    f"Cannot add inplace operator with dtype {type(other)} "
                    f"to operator with dtype {self.dtype}"
                )

            self._reset_caches()
            self._constant += other
            return self

        return NotImplemented

    def __truediv__(self, other):
        """Divide the operator by a scalar.

        Raises:
            TypeError: if `other` is not a number.
            ValueError: if `other` is zero.
        """
        if not isinstance(other, numbers.Number):
            raise TypeError("Only division by a scalar number is supported.")

        if other == 0:
            raise ValueError("Dividing by 0")

        return self.__mul__(1.0 / other)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __mul__(self, other):
        """Multiply by an operator (operator product) or by a scalar.

        The result is a new operator with the promoted dtype.
        """
        if isinstance(other, DiscreteOperator):
            op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
            return op.__imatmul__(other)
        elif isinstance(other, numbers.Number):
            op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
            return op.__imul__(other)
        return NotImplemented

    def __imul__(self, other):
        """In-place multiplication by an operator or by a scalar.

        Raises:
            ValueError: if the type of the scalar cannot be cast to the dtype
                of this operator with numpy's "same_kind" rule.
        """
        if isinstance(other, DiscreteOperator):
            return self.__imatmul__(other)
        elif isinstance(other, numbers.Number):
            if not np.can_cast(type(other), self.dtype, casting="same_kind"):
                raise ValueError(
                    f"Cannot multiply inplace operator with dtype {self.dtype} "
                    f"by scalar with dtype {type(other)}"
                )

            self._constant *= other
            if np.abs(other) <= self.mel_cutoff:
                # A factor below the cutoff drops every local term.
                self._operators_dict = {}
            else:
                # Rebind each entry to a new array: the stored matrices are
                # treated as immutable and may be shared with other operators.
                for key in self._operators_dict:
                    self._operators_dict[key] = other * self._operators_dict[key]

            self._reset_caches()
            return self

        return NotImplemented

    def __imatmul__(self, other):
        """In-place operator product with another ``LocalOperator``.

        Raises:
            ValueError: if the dtype of `other` cannot be cast to the dtype
                of this operator with numpy's "same_kind" rule.
        """
        if not isinstance(other, LocalOperator):
            return NotImplemented

        if not np.can_cast(other.dtype, self.dtype, casting="same_kind"):
            raise ValueError(
                f"Cannot multiply inplace operator with dtype {self.dtype} "
                f"by operator with dtype {other.dtype}"
            )

        return self._op_imatmul_(other)

    def _op__matmul__(self, other: "LocalOperator") -> "LocalOperator":
        """Return the operator product as a new operator with promoted dtype."""
        if not isinstance(other, LocalOperator):
            return NotImplemented

        op = self.copy(dtype=np.promote_types(self.dtype, _dtype(other)))
        return op._op_imatmul_(other)

    def _op_imatmul_(self, other: "LocalOperator") -> "LocalOperator":
        """Replace this operator in place by its product with `other`."""
        if not isinstance(other, LocalOperator):
            return NotImplemented

        # (α + ∑ᵢAᵢ)(β + ∑ᵢBᵢ) =
        # = αβ + α ∑ᵢBᵢ + β ∑ᵢAᵢ + ∑ᵢⱼAᵢBⱼ
        # = β(α + ∑ᵢAᵢ) + α ∑ᵢBᵢ + ∑ᵢⱼAᵢBⱼ

        α = self.constant.item()
        β = other.constant.item()
        # copy A dict because it is modified inplace in __imul__(β) and add_operators
        A_op_dict = self._operators_dict.copy()
        B_op_dict = other._operators_dict

        # αβ + β ∑ᵢAᵢ
        self.__imul__(β)

        # α ∑ᵢBᵢ
        if np.abs(α) > self.mel_cutoff:
            for aon, op in B_op_dict.items():
                self._add_operator(aon, α * op)

        # ∑ᵢⱼAᵢBⱼ
        for supp_A_i, A_i in A_op_dict.items():
            for supp_B_j, B_j in B_op_dict.items():
                self._add_operator(
                    *_multiply_operators(
                        self.hilbert, supp_A_i, A_i, supp_B_j, B_j, dtype=self.dtype
                    )
                )

        self._reset_caches()
        return self

    def _reset_caches(self):
        """
        Cleans the internal caches built on the operator.
        """
        self._initialized = False

    def _setup(self, force: bool = False):
        """Analyze the operator strings and precompute arrays for get_conn inference

        The arrays are rebuilt only when `force` is True or when the operator
        changed since they were last built.
        """
        if force or not self._initialized:
            data = pack_internals(
                self.hilbert,
                self._operators_dict,
                self.constant,
                self.dtype,
                self.mel_cutoff,
            )

            self._acting_on = data["acting_on"]
            self._acting_size = data["acting_size"]
            self._diag_mels = data["diag_mels"]
            self._mels = data["mels"]
            self._x_prime = data["x_prime"]
            self._n_conns = data["n_conns"]
            self._local_states = data["local_states"]
            self._basis = data["basis"]
            self._nonzero_diagonal = data["nonzero_diagonal"]
            self._max_conn_size = data["max_conn_size"]

            # Mark the arrays as valid, otherwise they would be repacked on
            # every call.
            self._initialized = True

    @property
    def max_conn_size(self) -> int:
        """The maximum number of non zero ⟨x|O|x'⟩ for every x."""
        self._setup()
        return self._max_conn_size

    def get_conn_flattened(self, x, sections, pad=False):
        r"""Finds the connected elements of the Operator. Starting
        from a given quantum number x, it finds all other quantum numbers x' such
        that the matrix element :math:`O(x,x')` is different from zero. In general there
        will be several different connected states x' satisfying this
        condition, and they are denoted here :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`.

        This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).

        Args:
            x (matrix): A matrix of shape (batch_size,hilbert.size) containing
                        the batch of quantum numbers x.
            sections (array): An array of size (batch_size) useful to unflatten
                        the output of this function. It is filled in place.
                        See numpy.split for the meaning of sections.
            pad (bool): Whether to use zero-valued matrix elements in order to return all equal sections.

        Returns:
            matrix: The connected states x', flattened together in a single matrix.
            array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.
        """
        self._setup()

        return self._get_conn_flattened_kernel(
            np.asarray(x),
            sections,
            self._local_states,
            self._basis,
            self._constant,
            self._diag_mels,
            self._n_conns,
            self._mels,
            self._x_prime,
            self._acting_on,
            self._acting_size,
            self._nonzero_diagonal,
            pad,
        )

    def _get_conn_flattened_closure(self):
        """Return a numba-compiled function ``f(x, sections)`` that calls the
        flattened kernel with the packed arrays of this operator bound in.

        The arrays are captured when this method is called; later changes to
        the operator are not seen by a closure that was already returned.
        """
        self._setup()
        _local_states = self._local_states
        _basis = self._basis
        _constant = self._constant
        _diag_mels = self._diag_mels
        _n_conns = self._n_conns
        _mels = self._mels
        _x_prime = self._x_prime
        _acting_on = self._acting_on
        _acting_size = self._acting_size
        # workaround my painfully discovered Numba#6979 (cannot use numpy bools in closures)
        _nonzero_diagonal = bool(self._nonzero_diagonal)

        fun = self._get_conn_flattened_kernel

        def gccf_fun(x, sections):
            return fun(
                x,
                sections,
                _local_states,
                _basis,
                _constant,
                _diag_mels,
                _n_conns,
                _mels,
                _x_prime,
                _acting_on,
                _acting_size,
                _nonzero_diagonal,
            )

        return numba.jit(nopython=True)(gccf_fun)

    @staticmethod
    @numba.jit(nopython=True)
    def _get_conn_flattened_kernel(
        x,
        sections,
        local_states,
        basis,
        constant,
        diag_mels,
        n_conns,
        all_mels,
        all_x_prime,
        acting_on,
        acting_size,
        nonzero_diagonal,
        pad=False,
    ):
        # Compiled kernel behind get_conn_flattened.
        # Pass 1 computes, for every sample and operator, the index of the
        # local configuration and counts the connected elements.
        # Pass 2 allocates the outputs and fills them.
        # `sections` is written in place with the cumulative end offset of
        # every sample.
        batch_size = x.shape[0]
        n_sites = x.shape[1]
        dtype = all_mels.dtype

        # TODO remove this line when numba 0.53 is dropped 0.54 is minimum version
        # workaround a bug in Numba arising when NUMBA_BOUNDSCHECK=1
        constant = constant.item()

        assert sections.shape[0] == batch_size

        n_operators = n_conns.shape[0]
        xs_n = np.empty((batch_size, n_operators), dtype=np.intp)

        tot_conn = 0
        max_conn = 0

        for b in range(batch_size):
            # diagonal element
            conn_b = 1 if nonzero_diagonal else 0

            # counting the off-diagonal elements
            for i in range(n_operators):
                acting_size_i = acting_size[i]

                xs_n[b, i] = 0
                x_b = x[b]
                x_i = x_b[acting_on[i, :acting_size_i]]
                # Encode the local configuration as one integer: the position
                # of each local value among the site's states, weighted by
                # basis[i, k].
                for k in range(acting_size_i):
                    xs_n[b, i] += (
                        np.searchsorted(
                            local_states[i, acting_size_i - k - 1],
                            x_i[acting_size_i - k - 1],
                        )
                        * basis[i, k]
                    )

                conn_b += n_conns[i, xs_n[b, i]]

            tot_conn += conn_b
            sections[b] = tot_conn

            if pad:
                max_conn = max(conn_b, max_conn)

        if pad:
            # Every sample gets the same number of rows when padding.
            tot_conn = batch_size * max_conn

        x_prime = np.empty((tot_conn, n_sites), dtype=x.dtype)
        mels = np.empty(tot_conn, dtype=dtype)

        c = 0
        for b in range(batch_size):
            c_diag = c
            x_batch = x[b]

            if nonzero_diagonal:
                mels[c_diag] = constant
                x_prime[c_diag] = np.copy(x_batch)
                c += 1

            for i in range(n_operators):
                # Diagonal part
                if nonzero_diagonal:
                    mels[c_diag] += diag_mels[i, xs_n[b, i]]

                n_conn_i = n_conns[i, xs_n[b, i]]

                if n_conn_i > 0:
                    sites = acting_on[i]
                    acting_size_i = acting_size[i]

                    for cc in range(n_conn_i):
                        mels[c + cc] = all_mels[i, xs_n[b, i], cc]
                        x_prime[c + cc] = np.copy(x_batch)

                        for k in range(acting_size_i):
                            x_prime[c + cc, sites[k]] = all_x_prime[
                                i, xs_n[b, i], cc, k
                            ]
                    c += n_conn_i

            if pad:
                # Fill the rest of this sample's block with zero matrix
                # elements attached to copies of the input configuration.
                delta_conn = max_conn - (c - c_diag)
                mels[c : c + delta_conn].fill(0)
                x_prime[c : c + delta_conn, :] = np.copy(x_batch)
                c += delta_conn
                sections[b] = c

        return x_prime, mels

    def get_conn_filtered(self, x, sections, filters):
        r"""Finds the connected elements of the Operator using only a subset of operators. Starting
        from a given quantum number x, it finds all other quantum numbers x' such
        that the matrix element :math:`O(x,x')` is different from zero. In general there
        will be several different connected states x' satisfying this
        condition, and they are denoted here :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`.

        This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).

        Args:
            x (matrix): A matrix of shape (batch_size,hilbert.size) containing
                        the batch of quantum numbers x.
            sections (array): An array of size (batch_size) useful to unflatten
                        the output of this function. It is filled in place.
                        See numpy.split for the meaning of sections.
            filters (array): Only operators op(filters[i]) are used to find the connected elements of
                        x[i].

        Returns:
            matrix: The connected states x', flattened together in a single matrix.
            array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.
        """
        self._setup()

        return self._get_conn_filtered_kernel(
            # Same coercion as in get_conn_flattened.
            np.asarray(x),
            sections,
            self._local_states,
            self._basis,
            self._constant,
            self._diag_mels,
            self._n_conns,
            self._mels,
            self._x_prime,
            self._acting_on,
            self._acting_size,
            filters,
        )

    @staticmethod
    @numba.jit(nopython=True)
    def _get_conn_filtered_kernel(
        x,
        sections,
        local_states,
        basis,
        constant,
        diag_mels,
        n_conns,
        all_mels,
        all_x_prime,
        acting_on,
        acting_size,
        filters,
    ):
        # Compiled kernel behind get_conn_filtered. It has the same two-pass
        # layout as the flattened kernel, but sample b only uses operator
        # filters[b], and a diagonal entry is always emitted.
        batch_size = x.shape[0]
        n_sites = x.shape[1]
        dtype = all_mels.dtype

        assert filters.shape[0] == batch_size and sections.shape[0] == batch_size

        # TODO remove this line when numba 0.53 is dropped 0.54 is minimum version
        # workaround a bug in Numba arising when NUMBA_BOUNDSCHECK=1
        constant = constant.item()

        n_operators = n_conns.shape[0]
        # Only the entries [b, filters[b]] of xs_n are written and read.
        xs_n = np.empty((batch_size, n_operators), dtype=np.intp)

        tot_conn = 0

        for b in range(batch_size):
            # diagonal element
            tot_conn += 1

            # counting the off-diagonal elements
            i = filters[b]

            assert i < n_operators and i >= 0
            acting_size_i = acting_size[i]

            xs_n[b, i] = 0
            x_b = x[b]
            x_i = x_b[acting_on[i, :acting_size_i]]
            for k in range(acting_size_i):
                xs_n[b, i] += (
                    np.searchsorted(
                        local_states[i, acting_size_i - k - 1],
                        x_i[acting_size_i - k - 1],
                    )
                    * basis[i, k]
                )

            tot_conn += n_conns[i, xs_n[b, i]]
            sections[b] = tot_conn

        # NOTE(review): this is allocated with numpy's default dtype, whereas
        # the flattened kernel uses dtype=x.dtype. Confirm whether the
        # difference is intended.
        x_prime = np.empty((tot_conn, n_sites))
        mels = np.empty(tot_conn, dtype=dtype)

        c = 0
        for b in range(batch_size):
            c_diag = c
            mels[c_diag] = constant
            x_batch = x[b]
            x_prime[c_diag] = np.copy(x_batch)
            c += 1

            i = filters[b]
            # Diagonal part
            mels[c_diag] += diag_mels[i, xs_n[b, i]]
            n_conn_i = n_conns[i, xs_n[b, i]]

            if n_conn_i > 0:
                sites = acting_on[i]
                acting_size_i = acting_size[i]

                for cc in range(n_conn_i):
                    mels[c + cc] = all_mels[i, xs_n[b, i], cc]
                    x_prime[c + cc] = np.copy(x_batch)

                    for k in range(acting_size_i):
                        x_prime[c + cc, sites[k]] = all_x_prime[i, xs_n[b, i], cc, k]
                c += n_conn_i

        return x_prime, mels

    def __repr__(self):
        ao = self.acting_on

        acting_str = f"acting_on={ao}"
        # Long site lists are summarised by their length.
        if len(acting_str) > 55:
            acting_str = f"#acting_on={len(ao)} locations"
        return f"{type(self).__name__}(dim={self.hilbert.size}, {acting_str}, constant={self.constant}, dtype={self.dtype})"