Source code for netket.utils.array

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
import jax

from .types import Array, DType, Shape


class HashableArray:
    """
    Wrap a numpy or jax array to make it hashable and equality comparable.

    A well-defined hashable object must satisfy
    :code:`hash(obj1) == hash(obj2)` whenever :code:`obj1 == obj2`. The
    converse does not hold (hashes may collide), so equality is decided on
    the array content and the hash is only used as a cheap early rejection.

    The underlying array can also be accessed using
    :code:`numpy.asarray(self)`.
    """

    def __init__(self, wrapped: Array):
        """
        Wraps an array into an object that is hashable, and that can be
        converted again into an array.

        Forces all arrays to numpy and sets them to readonly. They can be
        converted back to jax later or a writeable numpy copy can be created
        by using `np.array(...)`

        The hash is computed by hashing the whole content of the array.

        Args:
            wrapped: array to be wrapped. Another :class:`HashableArray`
                is accepted as well, in which case its (read-only) storage
                and its cached hash are shared instead of copied.
        """
        cached_hash: Optional[int] = None

        if isinstance(wrapped, HashableArray):
            # The inner array is already a private read-only copy: share it.
            cached_hash = wrapped._hash
            wrapped = wrapped.wrapped
        elif isinstance(wrapped, np.ndarray):
            # Private copy, so freezing it does not affect the caller's array.
            wrapped = wrapped.copy()
        else:
            # JAX arrays, lists, scalars, ...: np.array always returns a
            # fresh numpy array, which is what __hash__/__array__ rely on.
            wrapped = np.array(wrapped)

        wrapped.flags.writeable = False

        self._wrapped: np.ndarray = wrapped
        self._hash: Optional[int] = cached_hash

    @property
    def wrapped(self):
        """The read-only wrapped array."""
        return self._wrapped

    def __hash__(self):
        # Computed lazily and cached: the wrapped array is read-only, so the
        # hash of its content cannot change afterwards.
        if self._hash is None:
            self._hash = hash(self.wrapped.tobytes())
        return self._hash

    def __eq__(self, other):
        """
        Two instances are equal when shape, dtype and raw content coincide.

        The comparison is bytewise, consistently with :meth:`__hash__`
        (e.g. ``0.0`` and ``-0.0`` are different, identical NaN bit patterns
        are equal).
        """
        if self is other:
            return True
        if type(other) is not HashableArray:
            return False
        if self.shape != other.shape or self.dtype != other.dtype:
            return False
        # Cheap rejection first; equal hashes alone do not prove equality
        # (collisions), so the content is compared as well.
        if hash(self) != hash(other):
            return False
        return self.wrapped.tobytes() == other.wrapped.tobytes()

    def __array__(self, dtype: DType = None, copy: Optional[bool] = None):
        """
        NumPy array protocol.

        Args:
            dtype: requested dtype; `None` keeps the dtype of the wrapped
                array.
            copy: NumPy 2 protocol keyword. `True` forces a (writeable)
                copy, `False` forbids copying, `None` copies only if a
                dtype conversion requires it.

        Raises:
            ValueError: if `copy=False` but a dtype conversion is needed.
        """
        arr = self.wrapped
        if dtype is not None and np.dtype(dtype) != arr.dtype:
            if copy is False:
                raise ValueError(
                    "Unable to avoid copy while converting the wrapped array "
                    "to the requested dtype."
                )
            # The conversion already produces a new array.
            return arr.__array__(dtype)
        return arr.copy() if copy else arr

    @property
    def dtype(self) -> DType:
        """dtype of the wrapped array."""
        return self.wrapped.dtype

    @property
    def size(self) -> int:
        """Total number of elements of the wrapped array."""
        return self.wrapped.size

    @property
    def ndim(self) -> int:
        """Number of dimensions of the wrapped array."""
        return self.wrapped.ndim

    @property
    def shape(self) -> Shape:
        """Shape of the wrapped array."""
        return self.wrapped.shape

    def __repr__(self) -> str:
        return f"HashableArray({self.wrapped},\n shape={self.shape}, dtype={self.dtype}, hash={hash(self)})"

    def __str__(self) -> str:
        return (
            f"HashableArray(shape={self.shape}, dtype={self.dtype}, hash={hash(self)})"
        )
def array_in(x, ys):
    """
    Interpret ys as a list of arrays (stacked along its first axis), and
    test if x is equal to any y in ys. A match requires exactly the same
    shape; the dtypes need not be the same, only the values must compare
    equal.

    Note:
        In numpy, :code:`x in ys` is equivalent to :code:`any(x == ys)`,
        which is usually not what we intend when :code:`x.size > 1`.
        JAX arrays will raise an error rather than silently compute it.

    Args:
        x: the array to look for.
        ys: array whose leading axis enumerates the candidates.

    Returns:
        A truthy value if some :code:`ys[i]` has the same shape and the
        same entries as :code:`x`, a falsy value otherwise (including when
        :code:`ys` holds no candidates).
    """
    # Without this guard, broadcasting could report a match between arrays
    # of different shapes (e.g. a size-1 x against longer candidates).
    if tuple(ys.shape[1:]) != tuple(x.shape):
        return False

    n_candidates = ys.shape[0]
    n_entries = x.size
    # Explicit sizes instead of -1, so that an empty `ys` reshapes cleanly.
    x_flat = x.reshape(1, n_entries)
    ys_flat = ys.reshape(n_candidates, n_entries)
    return (x_flat == ys_flat).all(axis=1).any()