NetKet’s tests are written using the PyTest framework.
In particular, we make extensive use of
pytest.mark.parametrize in order to run the same
test functions on different input types.
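As a minimal illustration (the test name and the chosen dtypes are hypothetical, not taken from NetKet's test suite), pytest.mark.parametrize turns one test function into one test case per listed value:

```python
import numpy as np
import pytest


# Hypothetical test: pytest collects and runs this body once per dtype below.
@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex128])
def test_sum_preserves_dtype(dtype):
    x = np.ones(4, dtype=dtype)
    # Summing a floating-point or complex array keeps its dtype.
    assert x.sum().dtype == dtype
```

pytest reports each parameter value as a separate test case, so a failure for one input type does not hide the results for the others.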
We also often use pytest fixtures
to initialize expensive objects only once across different tests.
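A minimal sketch of such a fixture, assuming a random matrix stands in for an expensive object (all names here are hypothetical): with scope="module", pytest builds the value once per test module and passes the same object to every test in that module that lists the fixture as an argument.

```python
import numpy as np
import pytest


@pytest.fixture(scope="module")
def random_matrix():
    # Stand-in for an expensive object: built once per module, then shared.
    rng = np.random.default_rng(seed=1234)
    return rng.normal(size=(64, 64))


def test_matrix_shape(random_matrix):
    assert random_matrix.shape == (64, 64)


def test_matrix_is_finite(random_matrix):
    assert np.all(np.isfinite(random_matrix))
```

Because the fixture is module-scoped, tests must not mutate the shared object, or later tests would see the modified value.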
Test structure and common files
The test folder is a Python package, because every folder (a.k.a. submodule) contains an __init__.py file.
Tests are grouped in submodules according to the NetKet submodule they test.
If you add a new submodule to NetKet, you should add a corresponding submodule to the tests too.
If you add a new file to NetKet, it might be a good idea to put its tests in a new file in the relevant test submodule, too.
Common functions and methods used throughout our testing infrastructure are defined in the file
test/common.py and every test is expected to use them if necessary.
Some common fixtures are also defined inside
test/conftest.py and are available to all tests.
Those do not need to be imported explicitly, as pytest will take care of it.
Tests and MPI
Tests are expected to run both with and without mpi4py installed, and both under MPI and without it.
Therefore you should never import
mpi4jax at the top level of a test module, but only
inside individual tests.
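One way to follow this rule is sketched below. The test name is hypothetical, and using pytest.importorskip is our choice rather than something this guide prescribes: the import happens inside the test body, so importing the test module itself never touches mpi4jax.

```python
import pytest


def test_uses_mpi4jax():
    # Lazy import: a top-level import would make the whole test module fail to
    # import on machines without mpi4jax. importorskip instead skips only this
    # test when the package cannot be imported.
    mpi4jax = pytest.importorskip("mpi4jax")
    assert hasattr(mpi4jax, "allreduce")
    # ... MPI-specific checks would go here ...
```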
Tests that do not test MPI-related functionality should be skipped when executed under MPI. To mark a whole module to be skipped under MPI, you can define the module-level variable pytestmark:
```python
# test/hilbert/test_tensor.py
from .. import common

pytestmark = common.skipif_mpi

def test_tensors():
    ...
```
Alternatively, you can skip individual tests by decorating them with
```python
@common.skipif_mpi
def test_serial_only():  # the test_ prefix lets pytest collect it
    ...
```
To execute a test only when run with MPI, you can use the decorator
common.onlyif_mpi in the
same way as shown in the two examples above.
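The actual markers are defined in test/common.py. As a hedged sketch of how markers with this behaviour can be built on pytest.mark.skipif (the helper _mpi_world_size and the constant N_NODES are hypothetical, and detecting MPI through mpi4py is an assumption), consider:

```python
import pytest


def _mpi_world_size() -> int:
    # Assumption: MPI is detected via mpi4py; without mpi4py we run serially.
    try:
        from mpi4py import MPI
    except ImportError:
        return 1
    return MPI.COMM_WORLD.Get_size()


N_NODES = _mpi_world_size()

# Hypothetical markers mirroring the names used in test/common.py.
skipif_mpi = pytest.mark.skipif(N_NODES > 1, reason="serial-only test, skipped under MPI")
onlyif_mpi = pytest.mark.skipif(N_NODES < 2, reason="MPI-only test, needs at least 2 ranks")


@onlyif_mpi
def test_only_under_mpi():
    assert N_NODES >= 2
```

Both markers are ordinary skipif conditions evaluated once at import time, which is why they can be used either as decorators or assigned to pytestmark.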
If, inside your tests, you need to run some NetKet functions with MPI disabled, for example to
check that the MPI code gives the same result as the non-MPI code, you can use the
common.netket_disable_mpi context manager as follows:
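```python
import numpy as np

from .. import common

@common.onlyif_mpi
def test_matmul():
    ...  # setup of A and v elided
    x_mpi = A @ v
    # Recompute the same product with MPI switched off and compare.
    with common.netket_disable_mpi():
        x_serial = A @ v
    np.testing.assert_allclose(x_mpi, x_serial)
```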
For convenience, you can also use the fixtures _mpi_size, _mpi_rank and
_mpi_comm as arguments of your test
functions to easily obtain this information. See the example below:
```python
# run this both with and without MPI
def test_mpi_things(_mpi_rank, _mpi_size):
    if _mpi_size == 1:
        ...  # MPI disabled (single process)
    else:
        ...  # MPI enabled
```