# Source code for netket.utils.struct.fields

# Copyright 2023 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
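# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.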
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Common utilities to define fields and properties of netket dataclasses
and Pytrees.
"""

from functools import partial

import dataclasses
from dataclasses import MISSING

import jax

def _cache_name(property_name):
    """Return the attribute name under which a cached property is stored.

    The name is built by wrapping ``property_name`` between a leading
    ``"__"`` and a trailing ``"_cache"``.

    Args:
        property_name: Name of the cached property.

    Returns:
        The string ``"__<property_name>_cache"``.
    """
    prefix, suffix = "__", "_cache"
    return prefix + property_name + suffix

def _raw_cache_name(cache_name):
    """Recover the property name from a cache attribute name.

    Drops the first two characters (the leading ``"__"``) and the last six
    characters (the trailing ``"_cache"``). The slice is applied
    unconditionally: the input is not checked for that prefix or suffix.

    Args:
        cache_name: A name of the form ``"__<property_name>_cache"``.

    Returns:
        ``cache_name`` without its first two and last six characters.
    """
    prefix_len = len("__")
    suffix_len = len("_cache")
    return cache_name[prefix_len:-suffix_len]

## Our stuff
class _Uninitialized:
    """
    Sentinel value used to signal uninitialized values
    """

    def __repr__(self):
        # Fixed, human-readable representation for the sentinel.
        return "Uninitialized"

# Module-level sentinel instance; cached properties compare against it with
# ``is`` to detect that no value has been stored yet.
Uninitialized = _Uninitialized()

# Register the sentinel type with jax as a pytree node without children:
# flattening yields an empty tuple of children (with the sentinel itself as
# auxiliary data) and unflattening always gives back the same sentinel.
jax.tree_util.register_pytree_node(
    _Uninitialized, lambda x: ((), Uninitialized), lambda *args: Uninitialized
)

def field(pytree_node=True, serialize=True, cache=False, **kwargs):
    """Declare a field of a dataclass or PyTree with netket-specific metadata.

    Thin wrapper around :func:`dataclasses.field` that stores the three flags
    below in the field's ``metadata`` mapping under the keys
    ``"pytree_node"``, ``"serialize"`` and ``"cache"``.

    Args:
        pytree_node: a leaf node in the pytree representation of this
            dataclass. If False this must be hashable.
        serialize: If True the node is included in the serialization.
            In general you should not specify this.
        cache: If True this node is a cache and will be reset every time
            fields are modified.
        **kwargs: Forwarded unchanged to :func:`dataclasses.field`
            (e.g. ``default``, ``default_factory``). ``metadata`` must not
            be passed here, as it is already set by this function.

    Returns:
        The :class:`dataclasses.Field` object created by
        :func:`dataclasses.field`.
    """
    return dataclasses.field(
        metadata={"pytree_node": pytree_node, "serialize": serialize, "cache": cache},
        **kwargs,
    )
def static_field(**kwargs):
    """Declare a field that is not a pytree node.

    Shorthand for ``field(pytree_node=False, **kwargs)``.

    Args:
        **kwargs: Forwarded unchanged to ``field``.

    Returns:
        Whatever ``field`` returns for these arguments.
    """
    return field(pytree_node=False, **kwargs)


class CachedProperty:
    """Sentinel attribute wrapper to signal that a method is a property
    but must be cached.

    Instances act as descriptors: on first access through an instance the
    wrapped method is called and its result is stored on that instance under
    ``self.cache_name``; later accesses return the stored value.
    """

    def __init__(self, method, pytree_node=False):
        """Wrap ``method`` as a cached property.

        Args:
            method: The method to wrap. It must carry a return type
                annotation.
            pytree_node: Stored as-is on the wrapper; a leaf node in the
                pytree representation of the owning dataclass. If False
                the cached value must be hashable.

        Raises:
            TypeError: If ``method`` has no return type annotation.
        """
        self.name = method.__name__
        self.cache_name = _cache_name(self.name)
        self.method = method
        self.pytree_node = pytree_node
        # MISSING marks "no annotation", so an explicit ``-> None``
        # annotation is still accepted.
        self.type = method.__annotations__.get("return", MISSING)
        self.doc = method.__doc__

        if self.type is MISSING:
            raise TypeError(
                f"Cached property {method} requires a return type annotation."
            )

    def __get__(self, obj, objtype=None):
        # Accessed on the class rather than on an instance: return the
        # descriptor itself instead of calling the method with ``None``.
        if obj is None:
            return self

        val = getattr(obj, self.cache_name, Uninitialized)
        if val is Uninitialized:
            # Nothing cached yet: compute once and store on the instance.
            val = self.method(obj)
            setattr(obj, self.cache_name, val)
        return val

    def __repr__(self):
        return (
            f"CachedProperty(name={self.name}, "
            f"type={self.type}, pytree_node={self.pytree_node})"
        )


def property_cached(fun=None, pytree_node=False):
    """Decorator to make a method behave as a property whose value is cached
    and cleared upon replace.

    Can be used bare (``@property_cached``) or with arguments
    (``@property_cached(pytree_node=True)``). This function only builds the
    ``CachedProperty`` wrapper; clearing the cache is not implemented here.

    Args:
        fun: The method to wrap. When omitted, a partial of this decorator
            with ``pytree_node`` already bound is returned.
        pytree_node: a leaf node in the pytree representation of this
            dataclass. If False this must be hashable.

    Returns:
        A ``CachedProperty`` wrapping ``fun``, or a partially-applied
        decorator when ``fun`` is None.
    """
    if fun is None:
        return partial(property_cached, pytree_node=pytree_node)

    return CachedProperty(fun, pytree_node=pytree_node)