netket.errors.NetKetPyTreeUndeclaredAttributeAssignmentError


exception netket.errors.NetKetPyTreeUndeclaredAttributeAssignmentError

Error thrown when trying to assign an undeclared attribute to a NetKet-style Pytree.

This error is raised when you assign an attribute to a NetKet-style Pytree (a class inheriting from netket.utils.struct.Pytree) that was not declared in the class definition.

It exists to prevent you from accidentally creating a new attribute that is not part of the Pytree structure, which could lead to unexpected behaviour when the object is flattened or passed through JAX transformations.
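The idea behind the check can be pictured with a small pure-Python toy model that uses neither NetKet nor JAX. It assumes that "declared" means "annotated in the class body", and the names `ToyPytree` and `UndeclaredAttributeError` are hypothetical; this is a sketch of the concept, not NetKet's actual implementation.

```python
from typing import Any, Set


class UndeclaredAttributeError(AttributeError):
    """Hypothetical stand-in for NetKet's error, used only in this sketch."""


class ToyPytree:
    """Hypothetical toy base class: only attributes annotated in the class body may be assigned."""

    def __setattr__(self, name: str, value: Any) -> None:
        # Gather the annotated (declared) field names from the whole class hierarchy.
        declared: Set[str] = set()
        for klass in type(self).__mro__:
            declared.update(getattr(klass, "__annotations__", {}))
        if name not in declared:
            raise UndeclaredAttributeError(
                f"Attribute '{name}' is not declared in {type(self).__name__}."
            )
        object.__setattr__(self, name, value)


class Declared(ToyPytree):
    my_attribute: Any  # declared field: assignment is allowed

    def __init__(self, my_attribute):
        self.my_attribute = my_attribute


class Undeclared(ToyPytree):
    def __init__(self, my_attribute):
        self.my_attribute = my_attribute  # not declared in the class body


print(Declared(1.0).my_attribute)  # 1.0
try:
    Undeclared(1.0)
except UndeclaredAttributeError as err:
    print(err)  # Attribute 'my_attribute' is not declared in Undeclared.
```

Failing loudly at assignment time, instead of silently accepting the attribute, is what makes a missing declaration easy to spot.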

To fix this error, you should declare the attribute in the class definition, as shown in the example below:

from typing import Any

from netket.utils import struct

class MyPytree(struct.Pytree):
    # This line below was probably missing in your class definition
    my_attribute: Any

    def __init__(self, my_attribute):
        self.my_attribute = my_attribute

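Note that if the field is not a JAX array or another Pytree, you should instead declare it as a static (non-node) field with struct.field(pytree_node=False). Static fields are stored in the tree structure rather than among its leaves, so passing an object whose static field has a different value to a jitted function triggers recompilation. The example below shows how the two kinds of field are split when the object is flattened: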

from netket.utils import struct
import jax

class MyPytree(struct.Pytree):
    my_dynamic_attribute: jax.Array
    my_static_attribute: int = struct.field(pytree_node=False)

    def __init__(self, dyn_val, static_val):
        self.my_dynamic_attribute = dyn_val
        self.my_static_attribute = static_val

leaves, structure = jax.tree.flatten(MyPytree(1, 2))
print(leaves)
# [1]
print(structure)
# PyTreeDef(CustomNode(MyPytree[{'_pytree__node_fields': ('my_dynamic_attribute',), 'my_static_attribute': 2}], [*]))

The list of leaves contains only the dynamic attribute, while the static attribute is held by the tree structure, which JAX treats as static information.
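To make the leaves/structure split and its effect on recompilation concrete, here is a pure-Python toy model that uses neither JAX nor NetKet. `ToyPytree`, `MyToyPytree`, `toy_flatten` and `ToyJit` are hypothetical names. The sketch assumes a jit-like cache keyed on the structure plus abstract leaf information; real `jax.jit` uses the treedef and the shapes and dtypes of the leaves, which this sketch only approximates with Python types.

```python
from typing import Any, Dict, List, Tuple


class ToyPytree:
    """Hypothetical toy: node_fields become leaves, static_fields go into the structure."""
    node_fields: Tuple[str, ...] = ()
    static_fields: Tuple[str, ...] = ()


class MyToyPytree(ToyPytree):
    node_fields = ("my_dynamic_attribute",)
    static_fields = ("my_static_attribute",)

    def __init__(self, dyn_val, static_val):
        self.my_dynamic_attribute = dyn_val
        self.my_static_attribute = static_val


def toy_flatten(obj: ToyPytree) -> Tuple[List[Any], Tuple]:
    leaves = [getattr(obj, name) for name in obj.node_fields]
    # Static values travel with the class in the (hashable) structure, like a treedef.
    structure = (type(obj), tuple((n, getattr(obj, n)) for n in obj.static_fields))
    return leaves, structure


class ToyJit:
    """Hypothetical stand-in for a jit cache: one 'compilation' per distinct key."""

    def __init__(self, fn):
        self.fn = fn
        self.cache: Dict[Any, Any] = {}
        self.compilations = 0

    def __call__(self, obj: ToyPytree):
        leaves, structure = toy_flatten(obj)
        # Leaf values are not in the key (only their types, standing in for shape/dtype),
        # but static values are, so changing a static field forces a new "compilation".
        key = (structure, tuple(type(leaf) for leaf in leaves))
        if key not in self.cache:
            self.compilations += 1
            self.cache[key] = self.fn  # the cached "compiled" function
        return self.cache[key](obj)


f = ToyJit(lambda p: p.my_dynamic_attribute * p.my_static_attribute)

print(toy_flatten(MyToyPytree(1, 2))[0])  # [1]
f(MyToyPytree(1.0, 2))
f(MyToyPytree(5.0, 2))  # new dynamic value, same static value: cache hit
f(MyToyPytree(5.0, 3))  # new static value: cache miss
print(f.compilations)  # 2
```

Only the change of the static value caused a second "compilation"; changing the dynamic value reused the cached entry.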

Note

You can also declare the class with dynamic nodes, in which case the attributes are inferred automatically during class initialization. However, this is not recommended, because it then becomes impossible to declare attributes that are not JAX arrays or Pytrees.

If you still wish to do so, declare the class as follows:

from netket.utils import struct

class MyPytree(struct.Pytree, dynamic_nodes=True):
    def __init__(self, my_attribute):
        self.my_attribute = my_attribute
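The trade-off mentioned in the note can be illustrated with a pure-Python toy model of a class whose fields are inferred at assignment time. It assumes, as we read the note, that every inferred attribute becomes a pytree node; `ToyDynamicPytree`, `MyToyDynamicPytree`, `toy_flatten_dynamic` and the `n_samples` field are hypothetical and do not reproduce NetKet's implementation.

```python
from typing import Any, List, Tuple


class ToyDynamicPytree:
    """Hypothetical toy of a 'dynamic nodes' Pytree: any assigned attribute is recorded as a node."""

    def __setattr__(self, name: str, value: Any) -> None:
        # Field names are inferred at assignment time instead of being declared in the class body.
        node_fields = self.__dict__.setdefault("_node_fields", [])
        if name not in node_fields:
            node_fields.append(name)
        object.__setattr__(self, name, value)


def toy_flatten_dynamic(obj: ToyDynamicPytree) -> Tuple[List[Any], Tuple[str, ...]]:
    # Every inferred field becomes a leaf; there is no declaration step to mark one as static.
    leaves = [getattr(obj, name) for name in obj._node_fields]
    return leaves, tuple(obj._node_fields)


class MyToyDynamicPytree(ToyDynamicPytree):
    def __init__(self, my_attribute, n_samples):
        self.my_attribute = my_attribute
        self.n_samples = n_samples  # a plain int that should be static, yet it becomes a leaf


leaves, fields = toy_flatten_dynamic(MyToyDynamicPytree(1.5, 100))
print(leaves)  # [1.5, 100]
print(fields)  # ('my_attribute', 'n_samples')
```

Because nothing is declared, the integer ends up among the leaves, whereas with explicit declarations it could have been marked with struct.field(pytree_node=False).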