# Source code for netket.experimental.operator._fermion_operator_2nd_numba

# Copyright 2022 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

import numpy as np
import numba

from netket.utils.types import DType
from netket.errors import concrete_or_error, NumbaOperatorGetConnDuringTracingError

from ._fermion_operator_2nd_utils import _is_diag_term, OperatorDict
from ._fermion_operator_2nd_base import FermionOperator2ndBase

if TYPE_CHECKING:
    from ._fermion_operator_2nd_jax import FermionOperator2ndJax


class FermionOperator2nd(FermionOperator2ndBase):
    r"""
    A fermionic operator in :math:`2^{nd}` quantization, using Numba for indexing.

    .. warning::

        This class is not a Pytree, so it cannot be used inside of jax-transformed
        functions like `jax.grad` or `jax.jit`.

        The standard usage is to index into the operator from outside the jax
        function transformation and pass the results to the jax-transformed functions.

        To use this operator inside of a jax function transformation, convert it to a
        jax operator (:class:`netket.experimental.operator.FermionOperator2ndJax`) by using
        :meth:`netket.experimental.operator.FermionOperator2nd.to_jax_operator()`.

    When using native (experimental) sharding, or when working with GPUs,
    we recommend using the Jax implementations of the operators for potentially
    better performance.
    """

    def _setup(self, force: bool = False):
        """Analyze the operator strings and precompute the arrays used by get_conn.

        Args:
            force: If True, rebuild the arrays even if they were already built.
        """
        if force or not self._initialized:
            # The following arrays are consumed by the numba kernel; they are
            # built from ``self._operators`` by ``_pack_internals``.
            out = _pack_internals(self._operators, self._dtype)
            (
                self._orb_idxs,
                self._daggers,
                self._numba_weights,
                self._diag_idxs,
                self._off_diag_idxs,
                self._term_split_idxs,
            ) = out

            # All diagonal terms accumulate into a single connected element,
            # while each off-diagonal term contributes at most one element.
            self._max_conn_size = 0
            if len(self._diag_idxs) > 0:
                self._max_conn_size += 1
            # the following could be reduced further
            self._max_conn_size += len(self._off_diag_idxs)
            self._initialized = True

    def to_jax_operator(self) -> "FermionOperator2ndJax":  # noqa: F821
        """
        Returns the jax version of this operator, which is an
        instance of :class:`netket.experimental.operator.FermionOperator2ndJax`.
        """
        from ._fermion_operator_2nd_jax import FermionOperator2ndJax

        new_op = FermionOperator2ndJax(
            self.hilbert, cutoff=self._cutoff, dtype=self.dtype
        )
        new_op._operators = self._operators.copy()
        return new_op

    def _get_conn_flattened_closure(self):
        """Return a numba-jitted ``(x, sections) -> (x_prime, mels)`` function.

        The returned function closes over the packed arrays computed by
        ``_setup`` and always calls the kernel with its default ``pad=False``.
        """
        self._setup()
        _max_conn_size = self.max_conn_size
        _orb_idxs = self._orb_idxs
        _daggers = self._daggers
        _weights = self._numba_weights
        _diag_idxs = self._diag_idxs
        _off_diag_idxs = self._off_diag_idxs
        _term_split_idxs = self._term_split_idxs
        _cutoff = self._cutoff

        fun = self._flattened_kernel

        def gccf_fun(x, sections):
            return fun(
                x,
                sections,
                _max_conn_size,
                _orb_idxs,
                _daggers,
                _weights,
                _diag_idxs,
                _off_diag_idxs,
                _term_split_idxs,
                _cutoff,
            )

        return numba.jit(nopython=True)(gccf_fun)

    def get_conn_flattened(self, x, sections, pad=False):
        r"""Finds the connected elements of the Operator.

        Starting from a given quantum number x, it finds all other quantum numbers x' such
        that the matrix element :math:`O(x,x')` is different from zero. In general there
        will be several different connected states x' satisfying this
        condition, and they are denoted here :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`.

        This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).

        Args:
            x: A matrix of shape (batch_size,hilbert.size) containing
                the batch of quantum numbers x.
            sections: An array of size (batch_size) useful to unflatten
                the output of this function.
                See numpy.split for the meaning of sections.
                It is filled in place with the end offset (exclusive) of the
                connected elements of each sample.
            pad: If True, every sample is given exactly ``max_conn_size`` rows;
                unused rows hold a copy of that sample with a zero matrix element.

        Returns:
            matrix: The connected states x', flattened together in a single matrix.
            array: An array containing the matrix elements :math:`O(x,x')` associated to each x'.

        Raises:
            AssertionError: if the last dimension of x differs from ``hilbert.size``.
        """
        self._setup()

        # Numba cannot operate on jax tracers; fail with a clear error if x is abstract.
        x = concrete_or_error(
            np.asarray,
            x,
            NumbaOperatorGetConnDuringTracingError,
            self,
        )

        assert (
            x.shape[-1] == self.hilbert.size
        ), "size of hilbert space does not match size of x"
        return self._flattened_kernel(
            x,
            sections,
            self.max_conn_size,
            self._orb_idxs,
            self._daggers,
            self._numba_weights,
            self._diag_idxs,
            self._off_diag_idxs,
            self._term_split_idxs,
            self._cutoff,
            pad,
        )

    @staticmethod
    @numba.jit(nopython=True)
    def _flattened_kernel(  # pragma: no cover
        x,
        sections,
        max_conn,
        orb_idxs,
        daggers,
        weights,
        diag_idxs,
        off_diag_idxs,
        term_split_idxs,
        cutoff,
        pad=False,
    ):
        """Numba kernel computing the connected elements of a batch of samples.

        ``orb_idxs``/``daggers`` hold the single-fermion operators of all terms
        back to back; ``term_split_idxs`` gives the cumulative end index of each
        term. ``sections`` is written in place. Returns ``(x_prime, mels)``,
        truncated to the used rows unless ``pad`` is True.
        """
        x_prime = np.empty((x.shape[0] * max_conn, x.shape[1]), dtype=x.dtype)
        mels = np.zeros((x.shape[0] * max_conn), dtype=weights.dtype)

        # do not split at the last one (gives empty array)
        term_split_idxs = term_split_idxs[:-1]
        orb_idxs_list = np.split(orb_idxs, term_split_idxs)
        daggers_list = np.split(daggers, term_split_idxs)

        # loop over the batch dimension
        n_c = 0
        for b in range(x.shape[0]):
            xb = x[b, :]

            # we can already fill up with default values
            if pad:
                x_prime[b * max_conn : (b + 1) * max_conn, :] = np.copy(xb)

            # Diagonal terms leave xb unchanged, so all of them accumulate
            # into the same single connected element.
            non_zero_diag = False
            for term_idx in diag_idxs:
                mel = weights[term_idx]
                xt = np.copy(xb)
                has_xp = True
                for orb_idx, dagger in zip(
                    orb_idxs_list[term_idx], daggers_list[term_idx]
                ):
                    # xt is modified in place by _apply_operator
                    _, mel, op_has_xp = _apply_operator(
                        xt, orb_idx, dagger, mel, cutoff
                    )
                    if not op_has_xp:
                        # the term vanishes on xb: skip its remaining operators
                        has_xp = False
                        break
                if has_xp:
                    x_prime[n_c, :] = xb  # should be untouched
                    mels[n_c] += mel
                non_zero_diag = non_zero_diag or has_xp
            # end of the diagonal terms
            if non_zero_diag:
                n_c += 1

            # now do the off-diagonal terms, one connected element each
            for term_idx in off_diag_idxs:
                mel = weights[term_idx]
                xt = np.copy(xb)
                has_xp = True
                for orb_idx, dagger in zip(
                    orb_idxs_list[term_idx], daggers_list[term_idx]
                ):
                    xt, mel, op_has_xp = _apply_operator(
                        xt, orb_idx, dagger, mel, cutoff
                    )
                    if not op_has_xp:
                        # detect zeros: skip the remaining operators of this term
                        has_xp = False
                        break
                if has_xp:
                    x_prime[n_c, :] = xt  # should be different
                    mels[n_c] += mel
                    n_c += 1

            # end of this sample
            if pad:
                n_c = (b + 1) * max_conn

            sections[b] = n_c

        if pad:
            return x_prime, mels
        else:
            return x_prime[:n_c], mels[:n_c]
def _pack_internals(operators: OperatorDict, dtype: DType):
    """
    Flatten all terms of an operator dictionary into the arrays used by the
    numba kernel.

    Each term is a tuple of ``(orbital, dagger)`` pairs, e.g. ``((1, 1), (2, 0))``.
    Its single-fermion operators are appended back to back to ``orb_idxs`` and
    ``daggers``, and ``term_split_idxs`` records the cumulative number of
    single-fermion operators at the end of each term, so that the kernel can
    split them back per term.

    Args:
        operators: mapping from term tuples to their weights.
        dtype: dtype of the returned weights array.

    Returns:
        A tuple ``(orb_idxs, daggers, weights, diag_idxs, off_diag_idxs,
        term_split_idxs)`` of numpy arrays.

    Raises:
        ValueError: if a non-empty term contains an element that is not a pair.
    """
    # properties of single-fermion operators, e.g. "0^"
    orb_idxs = []
    daggers = []
    # properties of multi-body operators, e.g. "0^ 1"
    weights = []
    diag_idxs = []
    off_diag_idxs = []
    # connects the terms to the single-fermion lists (used to split them)
    term_split_idxs = []

    for term_idx, (term, weight) in enumerate(operators.items()):
        if len(term) > 0 and not all(len(t) == 2 for t in term):  # pragma: no cover
            raise ValueError(f"terms must contain (i, dag) pairs, but received {term}")

        # fill some info about the term
        weights.append(weight)
        if _is_diag_term(term):
            diag_idxs.append(term_idx)
        else:
            off_diag_idxs.append(term_idx)

        # single-fermion operators
        for orb_idx, dagger in term:
            # orb_idxs: holds the hilbert index of the orbital
            orb_idxs.append(orb_idx)
            # daggers: whether the operator is a creator or an annihilator.
            # The flag is stored inverted, and the kernel applies the operators
            # in the term's left-to-right order. Together this amounts to applying
            # the adjoint of the term to |x>.
            # NOTE(review): presumably done so the kernel yields <x|O|x'> — confirm.
            daggers.append(not bool(dagger))
        term_split_idxs.append(len(orb_idxs))

    return (
        np.array(orb_idxs, dtype=np.intp),
        np.array(daggers, dtype=bool),
        np.array(weights, dtype=dtype),
        np.array(diag_idxs, dtype=np.intp),
        np.array(off_diag_idxs, dtype=np.intp),
        np.array(term_split_idxs, dtype=np.intp),
    )


@numba.jit(nopython=True)
def _isclose(a, b, cutoff):  # pragma: no cover
    """Return True when ``|a - b| < cutoff``."""
    return np.abs(a - b) < cutoff


@numba.jit(nopython=True)
def _is_empty(site):  # pragma: no cover
    """Return True when the occupation ``site`` is (numerically) zero."""
    return _isclose(site, 0, 1e-10)


@numba.jit(nopython=True)
def _flip(site):  # pragma: no cover
    """Toggle an occupation between 0 and 1."""
    return 1 - site


@numba.jit(nopython=True)
def _apply_operator(xt, orb_idx, dagger, mel, cutoff):  # pragma: no cover
    """
    Apply one single-fermion operator at ``orb_idx`` to ``xt``, in place.

    A creation operator (``dagger=True``) needs an empty site and an
    annihilation operator (``dagger=False``) needs an occupied one. When the
    operator acts, ``mel`` is multiplied by the Jordan-Wigner sign and the
    site occupation is flipped.

    Returns:
        ``(xt, mel, has_xp)``. ``has_xp`` is False when the operator annihilates
        the state or when ``|mel| < cutoff``.
    """
    empty_site = _is_empty(xt[orb_idx])
    # creator on an empty site, or annihilator on an occupied site
    has_xp = dagger == empty_site
    if has_xp:
        mel *= (-1) ** np.sum(xt[:orb_idx])  # jordan wigner sign
        xt[orb_idx] = _flip(xt[orb_idx])

    if _isclose(mel, 0, cutoff):
        has_xp = False
    return xt, mel, has_xp