# Copyright 2021-2023 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import jax.numpy as jnp
from numba import jit
from scipy.sparse import csr_matrix as _csr_matrix
from scipy.sparse import issparse
from netket.hilbert import DiscreteHilbert
from netket.operator import AbstractOperator
from netket.utils.optional_deps import import_optional_dependency
from netket.jax.sharding import replicate_sharding_decorator_for_get_conn_padded
class DiscreteOperator(AbstractOperator):
    r"""This class is the base class for operators defined on a
    discrete Hilbert space. Users interested in implementing new
    quantum Operators for discrete Hilbert spaces should derive
    their own class from this class.

    The base class leaves :meth:`get_conn_flattened` and
    :attr:`max_conn_size` unimplemented. If a subclass does not override
    :meth:`get_conn_padded`, its :meth:`get_conn_flattened` must accept
    the ``pad`` keyword, because the default :meth:`get_conn_padded`
    calls it with ``pad=True``.
    """

    def __init__(self, hilbert: DiscreteHilbert):
        """Initialize the operator on the given Hilbert space.

        Args:
            hilbert: The discrete Hilbert space the operator acts on.

        Raises:
            ValueError: If ``hilbert`` is not a :class:`DiscreteHilbert`.
        """
        if not isinstance(hilbert, DiscreteHilbert):
            raise ValueError(
                "A Discrete Operator can only act upon a discrete Hilbert space."
            )
        super().__init__(hilbert)

    @property
    def max_conn_size(self) -> int:
        """The maximum number of non zero ⟨x|O|x'⟩ for every x."""
        raise NotImplementedError(
            f"max_conn_size is not implemented for operators of type {type(self)}."
        )

    @replicate_sharding_decorator_for_get_conn_padded
    def get_conn_padded(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        r"""Finds the connected elements of the Operator.

        Starting from a batch of quantum numbers :math:`x={x_1, ... x_n}` of
        size :math:`B \times M`, where :math:`B` is the batch size and
        :math:`M` is the size of the Hilbert space, finds all states
        :math:`y_i^1, ..., y_i^K` connected to every :math:`x_i`.

        Returns a matrix of size :math:`B \times K_{max} \times M`, where
        :math:`K_{max}` is the maximum number of connections over all
        :math:`x_i`.

        Args:
            x: A N-tensor of shape :math:`(...,hilbert.size)` containing
                the batch/batches of quantum numbers :math:`x`.

        Returns:
            **(x_primes, mels)**: The connected states x', in a N+1-tensor and an
            N-tensor containing the matrix elements :math:`O(x,x')`
            associated to each x' for every batch.
        """
        n_visible = x.shape[-1]
        n_samples = x.size // n_visible

        sections = np.empty(n_samples, dtype=np.int32)
        # pad=True asks the implementation to return the same number of
        # connected elements for every sample, so that the flat output can be
        # reshaped into a dense (..., n_primes, n_visible) tensor below.
        x_primes, mels = self.get_conn_flattened(
            x.reshape(-1, n_visible), sections, pad=True
        )

        # sections holds cumulative end offsets; with padding every sample has
        # the same count, so the first end offset equals that common count.
        n_primes = sections[0]

        x_primes_r = x_primes.reshape(*x.shape[:-1], n_primes, n_visible)
        mels_r = mels.reshape(*x.shape[:-1], n_primes)
        return x_primes_r, mels_r

    def get_conn_flattened(
        self, x: np.ndarray, sections: np.ndarray, pad: bool = False
    ) -> tuple[np.ndarray, np.ndarray]:
        r"""Finds the connected elements of the Operator.

        Starting from a given quantum number :math:`x`, it finds all
        other quantum numbers :math:`x'` such that the matrix element
        :math:`O(x,x')` is different from zero. In general there will be
        several different connected states :math:`x'` satisfying this
        condition, and they are denoted here :math:`x'(k)`, for
        :math:`k=0,1...N_{\mathrm{connected}}`.

        This is a batched version, where x is a matrix of shape
        :code:`(batch_size,hilbert.size)`.

        Args:
            x: A matrix of shape `(batch_size, hilbert.size)`
                containing the batch of quantum numbers x.
            sections: An array of sections for the flattened x'.
                See numpy.split for the meaning of sections.
                Implementations fill it in place.
            pad: If True, the output must contain the same number of
                connected elements for every row of ``x``. The default
                :meth:`get_conn_padded` relies on this to reshape the output.

        Returns:
            (matrix, array): The connected states x', flattened together in
            a single matrix, and an array containing the matrix elements
            :math:`O(x,x')` associated to each x'.

        Raises:
            NotImplementedError: Always, in this base class. Subclasses must
                override this method.
        """
        raise NotImplementedError(
            f"""
            The method get_conn_flattened has not been implemented for the object of
            type {type(self)}.

            This may happen if you defined a custom class inheriting from DiscreteOperator
            and you have not implemented this method. In that case, you should define
            `get_conn_flattened(self, x: array, sections: array, pad: bool = False)`
            according to the docstring provided on the documentation.

            Otherwise, please open an issue on netket's github repository.
            """
        )

    def get_conn(self, x: np.ndarray):
        r"""Finds the connected elements of the Operator.

        Starting from a given quantum number x, it finds all other quantum
        numbers x' such that the matrix element :math:`O(x,x')` is different
        from zero. In general there will be several different connected
        states x' satisfying this condition, and they are denoted here
        :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`.

        Args:
            x: An array of shape `(hilbert.size, )` containing the quantum numbers x.

        Returns:
            matrix: The connected states x' of shape (N_connected,hilbert.size)
            array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.

        Raises:
            ValueError: If ``x`` is batched (not 1-D), or if its length does
                not match the hilbert space size.
        """
        if x.ndim != 1:
            raise ValueError(
                "get_conn does not support batches. Please use get_conn_flattened instead."
            )
        if x.shape[0] != self.hilbert.size:
            raise ValueError(
                "The given quantum numbers do not match the hilbert space because "
                f"it has shape {x.shape} but expected ({self.hilbert.size},)."
            )

        # Use an integer sections buffer like every other call site. It is
        # filled by the implementation and not needed afterwards.
        return self.get_conn_flattened(
            x.reshape((1, -1)),
            np.ones(1, dtype=np.int32),
        )

    def n_conn(self, x, out=None) -> np.ndarray:
        r"""Return the number of states connected to x.

        Args:
            x (matrix): A matrix of shape (batch_size,hilbert.size) containing
                the batch of quantum numbers x.
            out (array): If None an output array is allocated. Otherwise it is
                filled in place and returned.

        Returns:
            array: The number of connected states x' for each x[i].
        """
        if out is None:
            out = np.empty(x.shape[0], dtype=np.int32)
        # get_conn_flattened writes cumulative end offsets into `out`; convert
        # them into per-row counts in place.
        self.get_conn_flattened(x, out)
        out = self._n_conn_from_sections(out)
        return out

    @staticmethod
    @jit(nopython=True)
    def _n_conn_from_sections(out):
        # Turn cumulative end offsets [e0, e1, ...] into counts
        # [e0 - 0, e1 - e0, ...], overwriting `out` in place.
        low = 0
        for i in range(out.shape[0]):
            old_out = out[i]
            out[i] = out[i] - low
            low = old_out
        return out

    def to_sparse(self) -> _csr_matrix:
        r"""Returns the sparse matrix representation of the operator.

        Note that, in general, the size of the matrix is exponential in the
        number of quantum numbers, and this operation should thus only be
        performed for low-dimensional Hilbert spaces or sufficiently sparse
        operators.

        This method requires an indexable Hilbert space.

        Returns:
            The sparse matrix representation of the operator.
        """
        concrete_op = self.collect()
        hilb = self.hilbert

        x = hilb.all_states()
        sections = np.empty(x.shape[0], dtype=np.int32)
        x_prime, mels = concrete_op.get_conn_flattened(x, sections)

        numbers = hilb.states_to_numbers(x_prime)

        # CSR row pointer: prepend 0 to the cumulative end offsets so that row
        # i spans [indptr[i], indptr[i + 1]).
        indptr = np.empty(sections.size + 1, dtype=np.int32)
        indptr[0] = 0
        indptr[1:] = sections

        # Repeated (row, column) pairs are not merged here; scipy treats
        # duplicate CSR entries as their sum.
        return _csr_matrix(
            (mels, numbers, indptr),
            shape=(hilb.n_states, hilb.n_states),
        )

    def to_dense(self) -> np.ndarray:
        r"""Returns the dense matrix representation of the operator.

        Note that, in general, the size of the matrix is exponential in the
        number of quantum numbers, and this operation should thus only be
        performed for low-dimensional Hilbert spaces or sufficiently sparse
        operators.

        This method requires an indexable Hilbert space.

        Returns:
            The dense matrix representation of the operator as a Numpy array.
        """
        # toarray() returns a plain ndarray directly, avoiding the np.matrix
        # intermediate produced by todense().
        return self.to_sparse().toarray()

    def to_qobj(self):  # -> "qutip.Qobj"
        r"""Convert the operator to a qutip's Qobj.

        Returns:
            A :class:`qutip.Qobj` object.
        """
        qutip = import_optional_dependency("qutip", descr="to_qobj")

        return qutip.Qobj(
            self.to_sparse(), dims=[list(self.hilbert.shape), list(self.hilbert.shape)]
        )

    def __call__(self, v: np.ndarray) -> np.ndarray:
        """Apply the operator to ``v``; equivalent to :meth:`apply`."""
        return self.apply(v)

    def apply(self, v: np.ndarray) -> np.ndarray:
        """Apply the operator to ``v`` through :meth:`to_linear_operator`.

        Args:
            v: The array (dense or sparse) to multiply from the left.

        Returns:
            The product ``to_linear_operator() @ v``.
        """
        op = self.to_linear_operator()
        return op @ v

    def __matmul__(self, other):
        # Arrays (numpy, jax, scipy sparse) are multiplied directly; operator
        # products are delegated to subclasses via _op__matmul__.
        if isinstance(other, (np.ndarray, jnp.ndarray)) or issparse(other):
            return self.apply(other)
        elif isinstance(other, AbstractOperator):
            return self._op__matmul__(other)
        else:
            return NotImplemented

    def _op__matmul__(self, other):
        "Implementation on subclasses of __matmul__"
        return NotImplemented

    def __rmatmul__(self, other):
        # Left-multiplication by a dense or sparse array is deliberately not
        # supported (NotImplemented). Operator products are delegated to
        # subclasses via _op__rmatmul__.
        if isinstance(other, (np.ndarray, jnp.ndarray)) or issparse(other):
            return NotImplemented
        elif isinstance(other, AbstractOperator):
            return self._op__rmatmul__(other)
        else:
            return NotImplemented

    def _op__rmatmul__(self, other):
        "Implementation on subclasses of __rmatmul__"
        return NotImplemented

    def to_linear_operator(self):
        """Return the object used by :meth:`apply` for matrix products.

        Defaults to :meth:`to_sparse`; subclasses may override it.
        """
        return self.to_sparse()

    def _get_conn_flattened_closure(self):
        raise NotImplementedError(
            """
            _get_conn_flattened_closure not implemented for this operator type.
            You were probably trying to use an operator with a sampler.
            Please report this bug.

            numba4jax won't work.
            """
        )