# Source code for netket.operator._sumoperators

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union
from collections.abc import Hashable, Iterable

from netket.utils.numbers import is_scalar
from netket.utils.types import DType, PyTree, Array

from netket.jax import canonicalize_dtypes
from netket.operator import ContinuousOperator
from netket.utils import struct, HashableArray

import jax.numpy as jnp


@struct.dataclass
class SumOperatorPyTree:
    """Internal class used to pass data from the operator to the jax kernel.

    This is used such that we can pass a PyTree containing some static data.
    We could avoid this if the operator itself was a pytree, but as this is not
    the case we need to pass as a separate object all fields that are used in
    the kernel.

    We could forego this, but then the kernel could not be marked as
    @staticmethod and we would recompile every time we construct a new operator,
    even if it is identical.
    """

    # The summed operators. Declared with pytree_node=False so they are carried
    # as static data of the PyTree rather than as leaves.
    ops: tuple[ContinuousOperator, ...] = struct.field(pytree_node=False)
    # Weights of the terms, one per entry of `ops`.
    coeffs: Array
    # Kernel data of each operator, aligned index-by-index with `ops`.
    op_data: tuple[PyTree, ...]


def _flatten_sumoperators(operators: Iterable[ContinuousOperator], coefficients: Array):
    """Expand nested ``SumOperator`` terms into their individual operators.

    Args:
        operators: The terms of the sum.
        coefficients: One weight per entry of ``operators``.

    Returns:
        A pair ``(operators, coefficients)`` of lists. Every ``SumOperator``
        term is replaced by its own operators, whose coefficients are
        multiplied by the weight that term carried. Other terms are kept
        unchanged.
    """
    flat_ops = []
    flat_coeffs = []
    for term, weight in zip(operators, coefficients):
        if not isinstance(term, SumOperator):
            flat_ops.append(term)
            flat_coeffs.append(weight)
            continue
        # A nested sum contributes each of its terms, rescaled by the outer
        # weight. One level suffices: a SumOperator is already flattened when
        # it is constructed.
        flat_ops.extend(term.operators)
        flat_coeffs.extend(weight * term.coefficients)
    return flat_ops, flat_coeffs


class SumOperator(ContinuousOperator):
    r"""A weighted sum of ContinuousOperator objects.

    Implements the ``_expect_kernel`` method of ContinuousOperator for
    :math:`\sum_i c_i O_i`. The local estimator is the coefficient-weighted sum
    of the local estimators of the individual terms.
    """

    def __init__(
        self,
        *operators: ContinuousOperator,
        coefficients: Union[float, Iterable[float]] = 1.0,
        dtype: Optional[DType] = None,
    ):
        r"""
        Returns the action of a sum of local operators.

        Args:
            operators: The ContinuousOperator objects to sum. Nested
                SumOperators are flattened into their terms.
            coefficients: A single scalar applied to every operator, or one
                coefficient for each ContinuousOperator object.
            dtype: Data type of the coefficients.

        Raises:
            ValueError: if no operator is given.
            NotImplementedError: if the operators act on different Hilbert
                spaces.
            AssertionError: if the number of coefficients does not match the
                number of operators.
        """
        if not operators:
            # Without this check the constructor would fail with an unrelated
            # IndexError on `hi_spaces[0]` further down.
            raise ValueError("SumOperator requires at least one operator.")

        hi_spaces = [op.hilbert for op in operators]
        if not all(hi == hi_spaces[0] for hi in hi_spaces):
            raise NotImplementedError(
                "Cannot add operators on different hilbert spaces"
            )

        if is_scalar(coefficients):
            coefficients = [coefficients for _ in operators]
        else:
            # Materialise once so that any iterable admitted by the annotation
            # (e.g. a generator) supports len() and can be iterated below.
            coefficients = list(coefficients)

        if len(operators) != len(coefficients):
            raise AssertionError("Each operator needs a coefficient")

        operators, coefficients = _flatten_sumoperators(operators, coefficients)

        dtype = canonicalize_dtypes(float, *operators, *coefficients, dtype=dtype)

        self._operators = tuple(operators)
        self._coefficients = jnp.asarray(coefficients, dtype=dtype)

        super().__init__(hi_spaces[0], self._coefficients.dtype)

        # A sum of hermitian terms is hermitian only if the weights are real.
        # For instance, 1j * H is anti-hermitian even when H is hermitian.
        self._is_hermitian = all(
            op.is_hermitian for op in self._operators
        ) and self._coefficients_are_real(self._coefficients)
        self.__attrs = None

    @staticmethod
    def _coefficients_are_real(coeffs: Array) -> bool:
        """Return True if every entry of ``coeffs`` has zero imaginary part."""
        if not jnp.iscomplexobj(coeffs):
            # A real dtype is real by construction; no need to inspect values.
            return True
        return bool(jnp.all(jnp.imag(coeffs) == 0))

    @property
    def is_hermitian(self) -> bool:
        return self._is_hermitian

    @property
    def operators(self) -> tuple[ContinuousOperator, ...]:
        """The list of all operators in the terms of this sum. Every
        operator is summed with a corresponding coefficient
        """
        return self._operators

    @property
    def coefficients(self) -> Array:
        return self._coefficients

    @staticmethod
    def _expect_kernel(
        logpsi: Callable, params: PyTree, x: Array, data: Optional[PyTree]
    ):
        # `data` is expected to be the SumOperatorPyTree built by
        # `_pack_arguments`. Each term's kernel gets its own packed data and is
        # weighted by the matching coefficient.
        result = [
            data.coeffs[i] * op._expect_kernel(logpsi, params, x, op_data)
            for i, (op, op_data) in enumerate(zip(data.ops, data.op_data))
        ]
        return sum(result)

    def _pack_arguments(self) -> SumOperatorPyTree:
        return SumOperatorPyTree(
            self.operators,
            self.coefficients,
            tuple(op._pack_arguments() for op in self.operators),
        )

    @property
    def _attrs(self) -> tuple[Hashable, ...]:
        # Computed lazily and memoised. The coefficient array is wrapped in
        # HashableArray so the resulting tuple is hashable.
        if self.__attrs is None:
            self.__attrs = (
                self.hilbert,
                self.operators,
                HashableArray(self.coefficients),
                self.dtype,
            )
        return self.__attrs

    def __repr__(self):
        return (
            f"SumOperator(operators={self.operators}, coefficients={self.coefficients})"
        )