# Source code for netket.hilbert.homogeneous

# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, List, Callable
from functools import lru_cache

from numbers import Real

import numpy as np

from .discrete_hilbert import DiscreteHilbert
from .hilbert_index import HilbertIndex

# This function has exponential runtime in self.size, so we cache its result in
# order to only compute it once.
# TODO: distribute over MPI... choose a better chunk size
def compute_constrained_to_bare_conversion_table(self, *, chunk_size: int = 100000):
    """
    Computes the conversion table that converts the 'constrained' indices
    of an hilbert space to bare indices, so that routines generating
    only values in an unconstrained space can be used.

    This function operates on blocks of `chunk_size` states at a time in order
    to lower the memory cost. The default chunk size has been chosen by instinct
    and is likely wrong.

    Args:
        self: Object exposing `_hilbert_index` (read for `n_states` and
            `numbers_to_states`) and `_constraint_fn` (called on every batch
            of states).
        chunk_size: Number of bare indices processed per iteration. Must be
            a strictly positive integer.

    Returns:
        A 1D array holding, in increasing order, the bare indices of all the
        states for which `_constraint_fn` returned a truthy value.

    Raises:
        ValueError: If `chunk_size` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")

    n_states = int(self._hilbert_index.n_states)

    bare_number_chunks = []
    # Stepping the range by `chunk_size` yields the start of every chunk using
    # exact integer arithmetic (no float ceil/division involved).
    for id_start in range(0, n_states, chunk_size):
        id_end = min(id_start + chunk_size, n_states)
        ids = np.arange(id_start, id_end)

        states = self._hilbert_index.numbers_to_states(ids)
        is_constrained = self._constraint_fn(states)
        # np.nonzero gives positions relative to the chunk: shift them by
        # `id_start` to recover the global bare indices.
        (chunk_bare_number,) = np.nonzero(is_constrained)
        bare_number_chunks.append(chunk_bare_number + id_start)

    # np.concatenate refuses an empty list, so handle the "no chunks" case
    # explicitly and return an empty index table.
    if not bare_number_chunks:
        return np.empty(0, dtype=np.intp)

    return np.concatenate(bare_number_chunks)

class HomogeneousHilbert(DiscreteHilbert):
    r"""The Abstract base class for homogeneous hilbert spaces.

    This class should only be subclassed and should not be instantiated directly.
    """

    def __init__(
        self,
        local_states: Optional[List[Real]],
        N: int = 1,
        constraint_fn: Optional[Callable] = None,
    ):
        r"""
        Constructs a new ``HomogeneousHilbert`` given a list of eigenvalues of the
        states and a number of sites, or modes, within this hilbert space.

        This method should only be called from the subclasses `__init__` method.

        Args:
            local_states (list or None): Eigenvalues of the states. If the allowed
                states are an infinite number, None should be passed as an argument.
            N: Number of modes in this hilbert space (default 1).
            constraint_fn: A function specifying constraints on the quantum numbers.
                Given a batch of quantum numbers it should return a vector of bools
                specifying whether those states are valid or not.
        """
        assert isinstance(N, int)

        self._is_finite = local_states is not None

        if self._is_finite:
            local_states_array = np.asarray(local_states)
            assert local_states_array.ndim == 1
            self._local_size = local_states_array.shape[0]
            # Stored as a plain list, plus a frozenset copy used by `_attrs`.
            self._local_states = local_states_array.tolist()
            self._local_states_frozen = frozenset(self._local_states)
        else:
            self._local_states = None
            self._local_states_frozen = None
            # Infinite local space: use the largest representable index size.
            self._local_size = np.iinfo(np.intp).max

        self._constraint_fn = constraint_fn

        # Lazily built caches, see `_hilbert_index` and `_bare_numbers`.
        self.__hilbert_index = None
        self.__bare_numbers = None

        shape = tuple(self._local_size for _ in range(N))
        super().__init__(shape=shape)

    @property
    def size(self) -> int:
        r"""The total number of degrees of freedom."""
        return len(self.shape)

    @property
    def local_size(self) -> int:
        r"""Size of the local degrees of freedom that make the total hilbert space."""
        return self._local_size

    def size_at_index(self, i: int) -> int:
        r"""Size of the local degrees of freedom at site `i`.

        The space is homogeneous, so the index is ignored and
        :attr:`local_size` is returned.
        """
        return self.local_size

    @property
    def local_states(self) -> Optional[List[float]]:
        r"""A list of discrete local quantum numbers.
        If the local states are infinitely many, None is returned."""
        return self._local_states

    def states_at_index(self, i: int):
        r"""Local quantum numbers at site `i`.

        The space is homogeneous, so the index is ignored and
        :attr:`local_states` is returned.
        """
        return self.local_states

    @property
    def n_states(self) -> int:
        r"""The total dimension of the many-body Hilbert space.
        Throws an exception iff the space is not indexable."""
        if not self.constrained:
            return self._hilbert_index.n_states
        return self._bare_numbers.shape[0]

    @property
    def is_finite(self) -> bool:
        r"""Whether the local hilbert space is finite."""
        return self._is_finite

    @property
    def constrained(self) -> bool:
        r"""Returns True if the hilbert space is constrained."""
        return self._constraint_fn is not None

    def _numbers_to_states(self, numbers: np.ndarray, out: np.ndarray) -> np.ndarray:
        """
        Converts indices of this space into states, writing them into `out`.

        When the space is constrained the indices are first translated to
        bare (unconstrained) indices through `_bare_numbers`.
        """
        if self.constrained:
            numbers = self._bare_numbers[numbers]

        return self._hilbert_index.numbers_to_states(numbers, out)

    def _states_to_numbers(self, states: np.ndarray, out: np.ndarray):
        """
        Converts states into indices of this space, writing them into `out`.

        Raises:
            RuntimeError: If the space is constrained and at least one of the
                states is not listed in the table of allowed bare indices.
        """
        # After this call `out` holds the bare (unconstrained) indices.
        self._hilbert_index.states_to_numbers(states, out)

        if self.constrained:
            bare_numbers = self._bare_numbers
            positions = np.searchsorted(bare_numbers, out)

            # searchsorted only returns an insertion point: a state is valid
            # only when that point is inside the table AND the entry stored
            # there is exactly its bare index. Checking just the upper bound
            # would silently map forbidden states onto a neighbouring index.
            in_table = positions < bare_numbers.shape[0]
            found = np.zeros(positions.shape, dtype=bool)
            found[in_table] = bare_numbers[positions[in_table]] == out[in_table]

            if not np.all(found):
                raise RuntimeError(
                    "The required state does not satisfy the given constraints."
                )

            out[:] = positions

        return out

    @property
    def _hilbert_index(self) -> HilbertIndex:
        """
        Returns the `HilbertIndex` object, which is a numba jitclass used to convert
        integers to states and vice-versa.
        """
        if self.__hilbert_index is None:
            if not self.is_indexable:
                raise RuntimeError("The hilbert space is too large to be indexed.")

            self.__hilbert_index = HilbertIndex(
                np.asarray(self.local_states, dtype=np.float64), self.size
            )

        return self.__hilbert_index

    @property
    def _bare_numbers(self) -> np.ndarray:
        """
        Returns the conversion table between indices in the constrained space and
        the corresponding unconstrained space.

        Returns None when the space is not constrained. The table is computed
        on first access and then stored on the instance.
        """
        if not self.constrained:
            return None

        if self.__bare_numbers is None:
            self.__bare_numbers = compute_constrained_to_bare_conversion_table(self)

        return self.__bare_numbers

    def __repr__(self):
        constr = f", constrained={self.constrained}" if self.constrained else ""
        clsname = type(self).__name__
        return f"{clsname}(local_size={self._local_size}, N={self.size}{constr})"

    @property
    def _attrs(self):
        """Tuple of the values that characterise this hilbert space."""
        return (
            self.size,
            self.local_size,
            self._local_states_frozen,
            self.constrained,
            self._constraint_fn,
        )