Writing Tests#

NetKet’s tests are written using the PyTest framework. In particular, we make extensive use of pytest.parametrize in order to run the same test functions on different input types. We also use often pytest.fixture to initialize expensive objects only once among different tests.

Test structure and common files.#

NetKet’s Test folder is a python module because every folder (aka, submodule) has an __init__.py file. Tests are grouped in submodules according to the NetKet’s submodule they tests. If you add a new submodule, you should add a new submodule to the tests too.

If you add a new file to netket, if might be a good idea to split it’s tests into a new file in the relevant submodule, too.

Common functions and methods used throughout our testing infrastructure are defined in the file test/common.py and every test is expected to use them if necessary. Some common fixtures are also defined inside test/conftest.py and are available to all tests. Those do not need to be imported explicitly, as pytest will take care of it.

Tests and MPI#

Tests are expected to run with or without mpi4py installed, under MPI and not under MPI. Therefore you should never import mpi4py or mpi4jax in the global test module, but only inside individual tests.

Tests not testing MPI-related functionality should be skipped when executed under MPI. To mark a whole module to be skipped under mpi, you can define the following variable

# test/hilbert/test_tensor.py

from .. import common

pytestmark = common.skipif_mpi

def test_tensors():

Alternatively, you can skip individual tests by decorating them with common.skipif_mpi.

def my_serial_test():

To execute a test only when run with MPI, you can use the decorator common.onlyif_mpi in the same way as shown in the two examples above.

If, inside your tests, you need to run some NetKet functions with MPI disabled, for example to check that the MPI code gives the same result as the non-MPI code, you can use the netket_disable_mpi object as follows:

from .. import common

def test_matmul():
	x_mpi = A@v

	with common.netket_disable_mpi():
		x_serial = A@v

		np.testing.assert_allclose(x_mpi, x_serial)

For simplicity, you can also use the fixtures _mpi_size, _mpi_rank and _mpi_comm as inputs to your test functions to get easily those information. See the example below:

# run this with mpi and not
def test_mpi_things(_mpi_rank, _mpi_size):
	if _mpi_size == 1:
		# mpi disabled
		# mpi enabled