Source code for netket.operator._pauli_strings.numba

# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, TYPE_CHECKING
from functools import wraps

import numpy as np
from numba import jit

import jax.numpy as jnp

from netket.hilbert import AbstractHilbert, HomogeneousHilbert
from netket.errors import concrete_or_error, NumbaOperatorGetConnDuringTracingError
from netket.utils.types import DType

from .base import PauliStringsBase
from .jax import pack_internals

    from .jax import PauliStringsJax

def pack_internals_numba(
    hilbert: AbstractHilbert,
    operators: dict,
    dtype: DType,
    cutoff: float,  # unused
    acting = pack_internals(operators, weights)

    # the most Z we need to do anywhere
    _n_z_check_max = 0
    for v in acting.values():
        for _, b_z_check in v:
            _n_z_check_max = max(_n_z_check_max, len(b_z_check))

    n_operators = len(acting)
    # maximum number of strings which have the same sites to act on with X, but have different sites for Z
    _n_op_max = max(
        list(map(lambda x: len(x), list(acting.values()))), default=n_operators

    # Check if there are Y operators in the strings, and in that
    # case uppromote float to complex
    if not jnp.issubdtype(dtype, jnp.complexfloating):
        # this checks if there is an Y in one of the strings
        if np.any(np.char.find(operators, "Y") != -1):
            dtype = jnp.promote_types(jnp.complex64, dtype)

    # unpacking the dictionary into fixed-size arrays

    # the sites each X string is acting on, padded
    _sites = np.empty((n_operators, hilbert.size), dtype=np.intp)
    # number of sites each X string is acting on
    _ns = np.empty((n_operators), dtype=np.intp)
    # the number of operators for the same X string, with different sites we need to apply Z on
    _n_op = np.empty(n_operators, dtype=np.intp)
    # weights, padded
    _weights = np.empty((n_operators, _n_op_max), dtype=dtype)
    # the number of Z sites in each operators of the X strings, padded
    _nz_check = np.empty((n_operators, _n_op_max), dtype=np.intp)
    # sites to act on with Z for each operators of the X strings, padded
    _z_check = np.empty((n_operators, _n_op_max, _n_z_check_max), dtype=np.intp)

    for i, act in enumerate(acting.items()):
        sites = act[0]
        nsi = len(sites)
        _sites[i, :nsi] = sites
        _ns[i] = nsi
        values = act[1]
        _n_op[i] = len(values)
        for j in range(_n_op[i]):
            _weights[i, j] = values[j][0]
            _nz_check[i, j] = len(values[j][1])
            _z_check[i, j, : _nz_check[i, j]] = values[j][1]

    return {
        "sites": _sites,
        "ns": _ns,
        "n_op": _n_op,
        "weights_numba": _weights,
        "nz_check": _nz_check,
        "z_check": _z_check,
        "n_operators": n_operators,
        "mel_dtype": dtype,

class PauliStrings(PauliStringsBase):
    """A Hamiltonian consisting of the sum of products of Pauli operators."""

    def __init__(
        hilbert: AbstractHilbert,
        operators: Union[None, str, list[str]] = None,
        weights: Union[None, float, complex, list[Union[float, complex]]] = None,
        cutoff: float = 1.0e-10,
        dtype: DType = None,
        super().__init__(hilbert, operators, weights, cutoff=cutoff, dtype=dtype)

        # check that it is homogeneous, throw error if it's not
        if not isinstance(self.hilbert, HomogeneousHilbert):
            local_states = self.hilbert.states_at_index(0)
            if not all(
                np.allclose(local_states, self.hilbert.states_at_index(i))
                for i in range(self.hilbert.size)
                raise ValueError(
                    "Hilbert spaces with non homogeneous local_states are not "
                    "yet supported by PauliStrings."

        self._initialized = False

[docs] def to_jax_operator(self) -> "PauliStringsJax": # noqa: F821 """ Returns the jax-compatible version of this operator, which is an instance of :class:`netket.operator.PauliStringsJax`. """ from .jax import PauliStringsJax return PauliStringsJax( self.hilbert, self.operators, self.weights, dtype=self.dtype, # TODO: don't ignore the cutoff. # cutoff=self._cutoff, )
@property def max_conn_size(self) -> int: """The maximum number of non zero ⟨x|O|x'⟩ for every x.""" # 1 connection for every operator X, Y, Z... self._setup() return self._n_operators def _setup(self, force=False): """Analyze the operator strings and precompute arrays for get_conn inference""" if force or not self._initialized: data = pack_internals_numba( self.hilbert, self.operators, self.weights, self.dtype, self._cutoff ) self._sites = data["sites"] self._ns = data["ns"] self._n_op = data["n_op"] self._weights_numba = data["weights_numba"] self._nz_check = data["nz_check"] self._z_check = data["z_check"] self._n_operators = data["n_operators"] # caches for execution self._x_prime_max = np.empty((self._n_operators, self.hilbert.size)) self._mels_max = np.empty((self._n_operators), dtype=data["mel_dtype"]) self._local_states = np.array(self.hilbert.states_at_index(0)) self._initialized = True def _reset_caches(self): super()._reset_caches() self._initialized = False @staticmethod @jit(nopython=True) def _flattened_kernel( x, sections, x_prime, mels, sites, ns, n_op, weights, nz_check, z_check, cutoff, max_conn, local_states, pad=False, ): x_prime = np.empty((x.shape[0] * max_conn, x_prime.shape[1]), dtype=x.dtype) mels = np.zeros((x.shape[0] * max_conn), dtype=mels.dtype) state_1 = local_states[-1] n_c = 0 for b in range(x.shape[0]): xb = x[b] # initialize with the old state x_prime[b * max_conn : (b + 1) * max_conn, :] = np.copy(xb) for i in range(sites.shape[0]): # iterate over the X strings mel = 0.0 # iterate over the Z substrings for j in range(n_op[i]): # apply all the Z (check the qubits at all affected sites) if nz_check[i, j] > 0: to_check = z_check[i, j, : nz_check[i, j]] n_z = np.count_nonzero(xb[to_check] == state_1) else: n_z = 0 # multiply with -1 for every site we did Z which was state_1 mel += weights[i, j] * (-1.0) ** n_z if abs(mel) > cutoff: x_prime[n_c] = np.copy(xb) # now flip all the sites in the X string for site in sites[i, : ns[i]]: new_state_idx = int(x_prime[n_c, site] == local_states[0]) x_prime[n_c, site] = local_states[new_state_idx] mels[n_c] = mel n_c += 1 if pad: n_c = (b + 1) * max_conn sections[b] = n_c return x_prime[:n_c], mels[:n_c]
[docs] def get_conn_flattened(self, x, sections, pad=False): r"""Finds the connected elements of the Operator. Starting from a given quantum number x, it finds all other quantum numbers x' such that the matrix element :math:`O(x,x')` is different from zero. In general there will be several different connected states x' satisfying this condition, and they are denoted here :math:`x'(k)`, for :math:`k=0,1...N_{\mathrm{connected}}`. This is a batched version, where x is a matrix of shape (batch_size,hilbert.size). Args: x (matrix): A matrix of shape (batch_size,hilbert.size) containing the batch of quantum numbers x. sections (array): An array of size (batch_size) useful to unflatten the output of this function. See numpy.split for the meaning of sections. Returns: matrix: The connected states x', flattened together in a single matrix. array: An array containing the matrix elements :math:`O(x,x')` associated to each x'. """ self._setup() x = concrete_or_error( np.asarray, x, NumbaOperatorGetConnDuringTracingError, self, ) assert ( x.shape[-1] == self.hilbert.size ), "size of hilbert space does not match size of x" return self._flattened_kernel( x, sections, self._x_prime_max, self._mels_max, self._sites, self._ns, self._n_op, self._weights_numba, self._nz_check, self._z_check, self._cutoff, self._n_operators, self._local_states, pad, )
def _get_conn_flattened_closure(self): self._setup() _x_prime_max = self._x_prime_max _mels_max = self._mels_max _sites = self._sites _ns = self._ns _n_op = self._n_op _weights = self._weights_numba _nz_check = self._nz_check _z_check = self._z_check _cutoff = self._cutoff _n_operators = self._n_operators fun = self._flattened_kernel _local_states = self._local_states def gccf_fun(x, sections): return fun( x, sections, _x_prime_max, _mels_max, _sites, _ns, _n_op, _weights, _nz_check, _z_check, _cutoff, _n_operators, _local_states, ) return jit(nopython=True)(gccf_fun)