Source code for netket.utils.struct.pytree

import dataclasses
import inspect
import typing as tp
from abc import ABCMeta
from copy import copy
from functools import partial
from types import MappingProxyType

import jax

from .fields import CachedProperty, _cache_name, _raw_cache_name, Uninitialized
from netket.utils import config

P = tp.TypeVar("P", bound="Pytree")

DATACLASS_USER_INIT_N_ARGS = "_pytree_n_args_max"
"""
variable name used by dataclasses inheriting from a pytree to
store the topmost non-dataclass class in a mro.
"""


class PytreeMeta(ABCMeta):
    """
    Metaclass for PyTrees, takes care of initializing and turning
    frozen PyTrees to immutable after __init__.
    """

    def __call__(cls: type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
        obj: P = cls.__new__(cls, *args, **kwargs)
        obj.__dict__["_pytree__initializing"] = True
        try:
            obj.__init__(*args, **kwargs)
        finally:
            del obj.__dict__["_pytree__initializing"]

        vars_dict = vars(obj)
        if obj._pytree__class_dynamic_nodes:
            vars_dict["_pytree__node_fields"] = tuple(
                sorted(
                    field
                    for field in vars_dict
                    if field not in cls._pytree__static_fields
                )
            )
        else:
            vars_dict["_pytree__node_fields"] = cls._pytree__data_fields

        for field in obj._pytree__cachedprop_fields:
            vars_dict[field] = Uninitialized

        return obj


[docs] class Pytree(metaclass=PytreeMeta): """ Astract Base class for jaw-aware PyTree classes. A class inheriting from PyTree can be passed to a jax function as a standard argument, and can contain both static and dynamic fields. Those will be correctly handled when flattening the PyTree. Static fields must be specified as class attributes, by specifying the :func:`nk.utils.struct.field(pytree_node=False)`. Example: Construct a PyTree with a 'constant' value >>> from netket.utils.struct import field, Pytree >>> import jax >>> >>> class MyPyTree(Pytree): ... a: int = field(pytree_node=False) ... b: jax.Array ... def __init__(self, a, b): ... self.a = a ... self.b = b ... def __repr__(self): ... return f"MyPyTree(a={self.a}, b={self.b})" >>> >>> my_pytree = MyPyTree(1, jax.numpy.ones(2)) >>> jax.jit(lambda x: print(x))(my_pytree) # doctest:+ELLIPSIS MyPyTree(a=1, b=Traced... PyTree classes by default are not mutable, therefore they behave similarly to frozen dataclasses. If you want to make a PyTree mutable, you can specify the `mutable=True` argument in the class definition. Example: >>> from netket.utils.struct import field, Pytree >>> >>> class MyPyTree(Pytree, mutable=True): ... a: int = field(pytree_node=False) ... ... By default only the fields declared as class attributes can be set and/or modified after initialization. If you want to allow the creation of new fields after initialization, you can specify the `dynamic_nodes=True` argument in the class definition. PyTree classes can also be inherited by a netket dataclass, in which case the dataclass will be initialized with the fields of the PyTree. However, this behaviour is deprecated and will be removed in the future. We suggest you to remove the `@nk.utils.struct.dataclass` decorator and simply define an `__init__` method. """ _pytree__initializing: bool _pytree__class_is_mutable: bool _pytree__static_fields: tuple[str, ...] _pytree__node_fields: tuple[str, ...] _pytree__setter_descriptors: frozenset[str] _pytree__cachedprop_fields: tuple[str, ...] def __init_subclass__(cls, mutable: bool = False, dynamic_nodes: bool = False): super().__init_subclass__() # gather class info class_vars = vars(cls) setter_descriptors = set() static_fields = _inherited_static_fields(cls) # add special static fields static_fields.add("_pytree__node_fields") # new data_fields = _inherited_data_fields(cls) cached_prop_fields = set() for field, value in class_vars.items(): if isinstance(value, dataclasses.Field) and not value.metadata.get( "pytree_node", True ): static_fields.add(field) elif isinstance(value, CachedProperty): cached_prop_fields.add(field) elif isinstance(value, dataclasses.Field) and value.metadata.get( "pytree_node", True ): data_fields.add(field) # add setter descriptors if hasattr(value, "__set__"): setter_descriptors.add(field) for field in cached_prop_fields: # setattr(cls, _cache_name(field), Uninitialized) if class_vars[field].pytree_node: data_fields.add(_cache_name(field)) else: static_fields.add(_cache_name(field)) cached_prop_fields = cached_prop_fields.union( _inherited_cachedproperty_fields(cls) ) # If no annotations in this class, skip, otherwise we'd process # parent's annotations twice if "__annotations__" in cls.__dict__: # fields that are only type annotations, feed them forward for field, _ in cls.__annotations__.items(): if field not in static_fields and field not in data_fields: data_fields.add(field) if mutable and len(cached_prop_fields) != 0: raise ValueError("cannot use cached properties with " "mutable pytrees.") if config.netket_sphinx_build: for k in static_fields: try: delattr(cls, k) except AttributeError: pass for k in data_fields: try: delattr(cls, k) except AttributeError: pass # new init_fields = tuple(sorted(data_fields.union(static_fields))) data_fields = tuple(sorted(data_fields)) cached_prop_fields = tuple(sorted(cached_prop_fields)) cached_prop_fields = tuple(_cache_name(f) for f in cached_prop_fields) static_fields = tuple(sorted(static_fields)) # init class variables cls._pytree__initializing = False cls._pytree__class_is_mutable = mutable cls._pytree__static_fields = static_fields cls._pytree__setter_descriptors = frozenset(setter_descriptors) # new cls._pytree__class_dynamic_nodes = dynamic_nodes cls._pytree__data_fields = data_fields cls._pytree__cachedprop_fields = cached_prop_fields cls._pytree__init_fields = init_fields # TODO: clean up this in the future once minimal supported version is 0.4.7 if ( "flatten_func" in inspect.signature(jax.tree_util.register_pytree_with_keys).parameters ): jax.tree_util.register_pytree_with_keys( cls, partial( cls._pytree__flatten, with_key_paths=True, ), cls._pytree__unflatten, flatten_func=partial( cls._pytree__flatten, with_key_paths=False, ), ) else: jax.tree_util.register_pytree_with_keys( cls, partial( cls._pytree__flatten, with_key_paths=True, ), cls._pytree__unflatten, ) # flax serialization support from flax import serialization serialization.register_serialization_state( cls, partial(cls._to_flax_state_dict, cls._pytree__static_fields), partial(cls._from_flax_state_dict, cls._pytree__static_fields), ) def __pre_init__(self, *args, **kwargs): # Default implementation of __pre_init__, used by netket's # dataclasses for preinitialisation shuffling of parameters. # # This is necessary for PyTrees that are subclassed by a dataclass # (like a user-implemented sampler using legacy logic). # # This class takes out all arguments and kw-arguments that are # directed to the PyTree from a processing and 'hides' them # in a proprietary kwargument for later manipulation. # # This is necessary so we call the dataclass init only with # the arguments that it needs. # process keyword arguments kwargs_dataclass = {} kwargs_pytree = {} for k, v in kwargs.items(): if k in self.__dataclass_fields__.keys(): kwargs_dataclass[k] = v else: kwargs_pytree[k] = v # process positional args. Identify max positional arguments of the # topmost user defined init method max_pytree_args = getattr(self, DATACLASS_USER_INIT_N_ARGS, len(args)) n_args_pytree = min(len(args), max_pytree_args) # First n args are for the pytree initialiser (lower) and later # positional arguments are for the dataclass initializer args_pytree = args[:n_args_pytree] args_dataclass = args[n_args_pytree:] signature_pytree = (args_pytree, kwargs_pytree) kwargs_dataclass["__base_init_args"] = signature_pytree return args_dataclass, kwargs_dataclass def __post_init__(self): pass @classmethod def _pytree__flatten( cls, pytree: "Pytree", *, with_key_paths: bool, ) -> tuple[tuple[tp.Any, ...], tp.Mapping[str, tp.Any],]: all_vars = vars(pytree).copy() static = {k: all_vars.pop(k) for k in pytree._pytree__static_fields} if with_key_paths: node_values = tuple( (jax.tree_util.GetAttrKey(field), all_vars.pop(field)) for field in pytree._pytree__node_fields ) else: node_values = tuple( all_vars.pop(field) for field in pytree._pytree__node_fields ) if all_vars: raise ValueError( f"Unexpected fields in {cls.__name__}: {', '.join(all_vars.keys())}." "You cannot add new fields to a Pytree after it has been initialized." ) return node_values, MappingProxyType(static) @classmethod def _pytree__unflatten( cls: type[P], static_fields: tp.Mapping[str, tp.Any], node_values: tuple[tp.Any, ...], ) -> P: pytree = object.__new__(cls) pytree.__dict__.update(zip(static_fields["_pytree__node_fields"], node_values)) pytree.__dict__.update(static_fields) return pytree @classmethod def _to_flax_state_dict( cls, static_field_names: tuple[str, ...], pytree: "Pytree" ) -> dict[str, tp.Any]: from flax import serialization state_dict = { name: serialization.to_state_dict(getattr(pytree, name)) for name in pytree.__dict__ if name not in static_field_names } return state_dict @classmethod def _from_flax_state_dict( cls, static_field_names: tuple[str, ...], pytree: P, state: dict[str, tp.Any], ) -> P: """Restore the state of a data class.""" from flax import serialization state = state.copy() # copy the state so we can pop the restored fields. updates = {} for name in pytree.__dict__: if name in static_field_names: continue if name not in state: raise ValueError( f"Missing field {name} in state dict while restoring" f" an instance of {type(pytree).__name__}," f" at path {serialization.current_path()}" ) value = getattr(pytree, name) value_state = state.pop(name) updates[name] = serialization.from_state_dict(value, value_state, name=name) if state: names = ",".join(state.keys()) raise ValueError( f'Unknown field(s) "{names}" in state dict while' f" restoring an instance of {type(pytree).__name__}" f" at path {serialization.current_path()}" ) return pytree.replace(**updates)
[docs] def replace(self: P, **kwargs: tp.Any) -> P: """ Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, `dataclasses.replace` will be used. Otherwise, a new object will be created with the same type as the original object. """ if dataclasses.is_dataclass(self): pytree = dataclasses.replace(self, **kwargs) else: unknown_keys = set(kwargs) - set(vars(self)) if unknown_keys: raise ValueError( f"Trying to replace unknown fields {unknown_keys} " f"for '{type(self).__name__}'" ) pytree = copy(self) pytree.__dict__.update(kwargs) # Reset cached properties for fname in pytree._pytree__cachedprop_fields: setattr(pytree, fname, Uninitialized) return pytree
if not tp.TYPE_CHECKING: def __setattr__(self: P, field: str, value: tp.Any): if self._pytree__initializing: if self._pytree__class_dynamic_nodes: pass elif field not in self._pytree__init_fields: raise AttributeError( f"Cannot set field {field} in init that was not described " "as a class attribute above." ) else: if field in self._pytree__setter_descriptors: pass elif field in self._pytree__cachedprop_fields: pass elif not hasattr(self, field): raise AttributeError( f"Cannot add new fields to {type(self)} after initialization" ) elif not self._pytree__class_is_mutable: raise AttributeError( f"{type(self)} is immutable, trying to update field {field}" ) object.__setattr__(self, field, value)
def _inherited_static_fields(cls: type) -> set[str]: """ Returns the set of static fields of base classes of the input class """ static_fields = set() for parent_class in cls.mro(): if parent_class is not cls and parent_class is not Pytree: if issubclass(parent_class, Pytree): static_fields.update(parent_class._pytree__static_fields) elif dataclasses.is_dataclass(parent_class): for field in dataclasses.fields(parent_class): if not field.metadata.get("pytree_node", True): static_fields.add(field.name) return static_fields def _inherited_data_fields(cls: type) -> set[str]: """ Returns the set of data fields of base classes of the input class. """ data_fields = set() for parent_class in cls.mro(): if parent_class is not cls and parent_class is not Pytree: if issubclass(parent_class, Pytree): data_fields.update(parent_class._pytree__data_fields) elif dataclasses.is_dataclass(parent_class): for field in dataclasses.fields(parent_class): if field.metadata.get("pytree_node", True): data_fields.add(field.name) return data_fields def _inherited_cachedproperty_fields(cls: type) -> set[str]: """ Returns the set of cached properties of base classes of the input class. """ cachedproperty_fields = set() for parent_class in cls.mro(): if parent_class is not cls and parent_class is not Pytree: if issubclass(parent_class, Pytree): fields = tuple( _raw_cache_name(f) for f in parent_class._pytree__cachedprop_fields ) cachedproperty_fields.update(fields) elif dataclasses.is_dataclass(parent_class): for field in dataclasses.fields(parent_class): if isinstance(field, CachedProperty): cachedproperty_fields.add(field.name) return cachedproperty_fields