Source code for netket.utils.static_range

# Copyright 2022 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from numbers import Number

import numpy as np

import jax
import jax.numpy as jnp

from netket.utils import struct
from netket.utils.types import DType
from netket.jax import canonicalize_dtypes


class StaticRange(struct.Pytree):
    """
    An object representing a range similar to python's range, but that
    works with `jax.jit`.

    This range object can also be used to convert 'computational basis'
    configurations to integer indices ∈ [0, length).

    This object is used inside of Hilbert spaces.

    This object can be converted to a numpy or jax array:

    .. code-block:: python

        >>> import netket as nk; import numpy as np
        >>> n_max = 10
        >>> ran = nk.utils.StaticRange(start=0, step=1, length=n_max)
        >>> np.array(ran)
        array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    And it can be used to convert between integer values starting at 0
    and the values in the range.

    .. code-block:: python

        >>> import netket as nk; import numpy as np
        >>> ran = nk.utils.StaticRange(start=-2, step=2, length=3)
        >>> np.array(ran)
        array([-2, 0, 2])
        >>> len(ran)
        3
        >>> ran.states_to_numbers(0)
        array(1)
        >>> ran.numbers_to_states(0)
        -2
        >>> ran.numbers_to_states(1)
        0
        >>> ran.numbers_to_states(2)
        2

    """

    start: Number = struct.field(pytree_node=False)
    """The first value in the range."""
    step: Number = struct.field(pytree_node=False)
    """The difference between two consecutive values in the range."""
    length: int = struct.field(pytree_node=False)
    """The number of entries in the range."""
    dtype: DType = struct.field(pytree_node=False)
    """The dtype of the range."""

    def __init__(self, start: Number, step: Number, length: int, dtype: DType = None):
        """
        Constructs a Static Range object.

        To construct it, one must specify the start value, the step and
        the length. It is also possible to specify a `dtype`.
        In case it's not specified, it's inferred from the input arguments.

        For example, the :class:`~netket.utils.StaticRange` of a Fock Hilbert
        space is constructed as

        .. code-block:: python

            >>> import netket as nk
            >>> n_max = 10
            >>> nk.utils.StaticRange(start=0, step=1, length=n_max)
            StaticRange(start=0, step=1, length=10, dtype=int64)

        and the range of a Spin-1/2 Hilbert space is constructed as:

        .. code-block:: python

            >>> import netket as nk
            >>> n_max = 10
            >>> nk.utils.StaticRange(start=-1, step=2, length=2)
            StaticRange(start=-1, step=2, length=2, dtype=int64)

        Args:
            start: Value of the first entry
            step: Step between the entries
            length: Length of this range
            dtype: The data type
        """
        dtype = canonicalize_dtypes(start, step, dtype=dtype)

        # Store plain Python scalars (not 0-d arrays) after casting them to
        # the resolved dtype.
        self.start = np.array(start, dtype=dtype).item()
        self.step = np.array(step, dtype=dtype).item()
        self.length = int(length)

        self.dtype = dtype

    @property
    def shape(self):
        """The shape of the range, if converted to an array. It's always (length,)."""
        return (self.length,)

    @property
    def ndim(self):
        """The number of dimensions of the range, if converted to an array. It's always 1."""
        return 1

    def astype(self, dtype: DType):
        """Returns a new StaticRange with a different dtype."""
        return StaticRange(self.start, self.step, self.length, dtype=dtype)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        """Return the ``i``-th value of the range.

        Negative indices count from the end, as for python's ``range``.

        Raises:
            IndexError: if ``i`` falls outside ``[-length, length)``.
        """
        # Build a new value instead of using `+=`, so that an array-like
        # index passed by the caller is never modified in place.
        idx = i + self.length if i < 0 else i
        if not 0 <= idx < self.length:
            raise IndexError(
                f"index {i} is out of bounds for StaticRange of length {self.length}"
            )
        return self.start + self.step * idx

    def states_to_numbers(self, x, dtype: DType = int):
        """Given an element in the range, returns its index.

        Args:
            x: array of elements belonging to this range. No bounds
                checking is performed.
            dtype: Optional dtype to be used for the output. If `None`, the
                result of the division is returned without any cast.

        Returns:
            The indices of the elements `x`, cast to `dtype`.
        """
        idx = (x - self.start) / self.step

        if dtype is not None:
            # The true division above yields floats: when an integer output
            # is requested, round to the nearest integer first so that
            # floating point error (e.g. 5.999999999999999) is not truncated
            # to the wrong index by the cast.
            to_integer = np.issubdtype(np.dtype(dtype), np.integer)
            if not hasattr(idx, "astype"):
                # Plain Python scalar.
                if to_integer:
                    idx = round(idx)
                idx = np.array(idx, dtype=dtype)
            else:
                if to_integer:
                    idx = idx.round()
                idx = idx.astype(dtype)
        return idx

    def numbers_to_states(self, i, dtype: DType = None):
        """Given an integer index, returns the i-th elements in the range.

        Args:
            i: indices to extract from the range.
            dtype: Optional dtype to be used for the output.

        Returns:
            An array of values from the range. The dtype by default
            is that of the range.
        """
        if dtype is None:
            dtype = self.dtype

        # Use jax.numpy when the indices are a jax array, numpy otherwise.
        npx = jnp if isinstance(i, jax.Array) else np
        start = npx.array(self.start, dtype=dtype)
        step = npx.array(self.step, dtype=dtype)

        return (start + step * i).astype(dtype)

    def flip_state(self, state):
        """Only works if this range has length 2. Given a state, returns the other state.

        Raises:
            ValueError: if the length of the range is not 2.
        """
        if len(self) != 2:
            raise ValueError(
                f"flip_state requires a range of length 2, but length is {len(self)}."
            )
        # The two values are `start` and `start + step`: subtracting a state
        # from their sum yields the other one.
        constant_sum = 2 * self.start + self.step
        return constant_sum - state

    def all_states(self, dtype: DType = None):
        """Return all elements in the range. Equal to __array__

        Args:
            dtype: Optional dtype to be used for the output.

        Returns:
            An array with all values from the range. The dtype by default
            is that of the range.
        """
        return self.__array__(dtype=dtype)

    @property
    def n_states(self):
        """The number of states in the range. Equal to length."""
        return self.length

    @property
    def is_indexable(self):
        """If the range is indexable. Always True"""
        return True

    def __array__(self, dtype=None, copy=None):
        """Materialize the range as a numpy array.

        Args:
            dtype: Optional dtype of the output. Defaults to the dtype of
                the range.
            copy: Accepted for compatibility with the NumPy 2 `__array__`
                protocol. A new array is built on every call, so its value
                does not affect the result.
        """
        if dtype is None:
            dtype = self.dtype
        states = self.start + np.arange(self.length, dtype=dtype) * self.step
        return states.astype(dtype)

    def __hash__(self):
        return hash(("StaticRange", self.start, self.step, self.length))

    def __eq__(self, o):
        """Compare with another range or with an array-like object.

        Two StaticRange are equal when start, step and length coincide.
        For an array-like object of the same shape the element-wise
        comparison with the materialized range is returned. In every other
        case the result is `False`.
        """
        if isinstance(o, StaticRange):
            return (
                self.start == o.start
                and self.step == o.step
                and self.length == o.length
            )
        if hasattr(o, "shape"):
            if self.shape == o.shape:
                return self.__array__() == o
            return False
        if hasattr(o, "__array__"):
            # Objects exposing only `__array__` (no `shape`) must be
            # converted before their shape can be compared.
            other = np.asarray(o)
            if self.shape == other.shape:
                return self.__array__() == other
        return False

    def __repr__(self):
        return f"StaticRange(start={self.start}, step={self.step}, length={self.length}, dtype={self.dtype})"