netket.utils.struct.Pytree

class netket.utils.struct.Pytree

Bases: object

Abstract base class for JAX-aware PyTree classes.

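A class inheriting from Pytree can be passed to a JAX function as a standard argument and can contain both static and dynamic fields. When the object is flattened, dynamic fields become leaves of the PyTree, while static fields are stored as part of its structure (the treedef) rather than as leaves, so they are not traced by transformations such as jax.jit.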

Static fields must be declared as class attributes whose value is nk.utils.struct.field(pytree_node=False), as in a: int = field(pytree_node=False).

Example

Construct a Pytree with a static (‘constant’) field:

>>> from netket.utils.struct import field, Pytree
>>> import jax
>>>
>>> class MyPyTree(Pytree):
...     a: int = field(pytree_node=False)
...     b: jax.Array
...     def __init__(self, a, b):
...         self.a = a
...         self.b = b
...     def __repr__(self):
...         return f"MyPyTree(a={self.a}, b={self.b})"
>>>
>>> my_pytree = MyPyTree(1, jax.numpy.ones(2))
>>> jax.jit(lambda x: print(x))(my_pytree)  # `a` stays a Python int, `b` is traced
MyPyTree(a=1, b=Traced...

By default, Pytree classes are immutable: once __init__ has returned, their fields cannot be reassigned, so they behave similarly to frozen dataclasses. To make a Pytree mutable, pass mutable=True in the class definition.

Example

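>>> from netket.utils.struct import field, Pytree
>>> import jax
>>> class MyPyTree(Pytree, mutable=True):
...     a: int = field(pytree_node=False)
...     b: jax.Array
...     def __init__(self, a, b):
...         self.a = a
...         self.b = b
>>> my_pytree = MyPyTree(1, jax.numpy.ones(2))
>>> my_pytree.b = jax.numpy.zeros(2)  # allowed because of mutable=True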

By default, only fields declared as class attributes can be set or modified after initialization. To allow new fields to be created after initialization, pass dynamic_nodes=True in the class definition.

A Pytree class can also be subclassed by a NetKet dataclass (a class decorated with @nk.utils.struct.dataclass), in which case the dataclass is initialized with the fields of the Pytree. However, this behaviour is deprecated and will be removed in a future release. We suggest removing the @nk.utils.struct.dataclass decorator and simply defining an __init__ method.

__init__()
Methods
replace(**kwargs)

Return a new object in which the fields named in the keyword arguments are set to the given values; the original object is left unchanged. If the object is a dataclass, dataclasses.replace is used. Otherwise, a new object of the same type as the original is created.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)