netket.utils.struct.ShardedFieldSpec
- class netket.utils.struct.ShardedFieldSpec
Bases: object
Specification of a sharded field.
For the time being, it is used only to specify how a field of a NetKet-style Pytree is serialized and deserialized with flax.serialization.
- Attributes
- deserialization_function: ShardedDeserializationFunction | str | None = 'relaxed'
Function to use to deserialize the data. Can be a callable with the signature:
def f(value_target: jax.Array, value_state: jax.Array, *, name: str = ".") -> jax.Array
or one of the following strings:
- "fail": Raise an error if the sharded data does not match the target data.
- "strict": Raise an error if the sharded data does not match the target data.
- "relaxed": Ignore extra data in the sharded data if the target data is smaller than the serialized data; raise an error if the target data is larger.
- "relaxed-ignore-errors": Ignore extra data in the sharded data if the target data is smaller, and do nothing if the target data is larger.
- "relaxed-rng-key": Special case for RNG keys, where we can safely truncate the serialized data if the target data is larger.
The default is "relaxed".