Timing and Profiling in NetKet
NetKet provides a built-in timing system in netket.utils that helps profile your code and understand where time is being spent. This is particularly useful when optimizing performance or debugging slow computations.
The timing system is designed to work with JAX and reports hierarchical timing information as a tree, with each scope's share of its parent's time shown as a percentage.
Basic Usage
Context Managers
The simplest way to time code is using the Timer class as a context manager:
import time

import jax
import jax.numpy as jnp

import netket as nk
from netket.utils import timing

# Basic timing with Timer
with timing.Timer() as timer:
    time.sleep(0.1)  # Simulate some work

    # Nested timing with timed_scope
    with timing.timed_scope("matrix multiplication"):
        a = jnp.ones((100, 100))
        b = jnp.ones((100, 100))
        result = a @ b
        # Important for JAX: block until computation is done
        timer.block_until_ready(result)

    with timing.timed_scope("more work"):
        time.sleep(0.05)

print(timer)
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.230                                                                                                    │
│ ├── (30.5%) | matrix multiplication : 0.070 s                                                                   │
│ └── (23.9%) | more work : 0.055 s                                                                               │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Hierarchical Timing
You can create nested timing structures by combining multiple timing scopes:
# Nested timing example
with timing.Timer() as timer:
    with timing.timed_scope("setup"):
        time.sleep(0.02)

    with timing.timed_scope("computation"):
        with timing.timed_scope("part 1"):
            time.sleep(0.01)
        with timing.timed_scope("part 2"):
            time.sleep(0.01)

print(timer)
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.050                                                                                                    │
│ ├── (50.0%) | setup : 0.025 s                                                                                   │
│ └── (50.0%) | computation : 0.025 s                                                                             │
│     ├── (50.0%) | part 1 : 0.013 s                                                                              │
│     └── (50.0%) | part 2 : 0.013 s                                                                              │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
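Using timed_scope with Force
The timed_scope() context manager times a specific section of code within a larger timing context. On its own it only records time when a parent timer is active; passing force=True enables timing even without one, and the scope then acts as the top-level timer: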
# Using timed_scope with force=True to enable timing even without a parent timer
with timing.timed_scope("main computation", force=True) as timer:
    # Some initial setup
    key = jax.random.key(42)
    data = jax.random.normal(key, (1000, 1000))
    timer.block_until_ready(data)

    with timing.timed_scope("eigenvalue decomposition"):
        eigenvals = jnp.linalg.eigvals(data)
        timer.block_until_ready(eigenvals)

    with timing.timed_scope("statistical analysis"):
        mean_val = jnp.mean(eigenvals)
        std_val = jnp.std(eigenvals)
        timer.block_until_ready((mean_val, std_val))

print(timer)
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.538                                                                                                    │
│ ├── (67.2%) | eigenvalue decomposition : 0.362 s                                                                │
│ └── (11.9%) | statistical analysis : 0.064 s                                                                    │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Using the @timed Decorator
The timed() decorator allows you to automatically time function calls. This is especially useful for timing functions that are called multiple times:
@timing.timed(name="expensive_computation")
def expensive_function(x):
    """A function that does some expensive computation."""
    result = jnp.sin(x) * jnp.cos(x) + jnp.exp(-x**2)
    return jnp.sum(result)


@timing.timed(name="data_processing")
def process_data(data):
    """Process some data."""
    return jnp.fft.fft(data)


# The decorated functions will only be timed when inside a timing context
with timing.Timer() as timer:
    # Call the functions multiple times
    for i in range(3):
        x = jnp.linspace(0, 10, 1000)
        result1 = expensive_function(x)

        key = jax.random.key(i)
        data = jax.random.normal(key, (512,))
        result2 = process_data(data)

        # Block until JAX computations are complete
        timer.block_until_ready((result1, result2))

print(timer)
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.269                                                                                                    │
│ ├── (53.7%) | expensive_computation : 0.144 s                                                                   │
│ └── (16.2%) | data_processing : 0.043 s                                                                         │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
JAX Gotchas: JIT and Block-Until-Ready
JIT Compilation
Timing decorators and context managers do not give useful results inside JIT-compiled functions. Under jax.jit, the Python body of a function runs only while JAX traces it, so a timing scope placed inside records at most the tracing step, never the execution of the compiled computation, and may cause errors.
# DON'T: Use timing inside a JIT function - this won't work
@jax.jit
def bad_jitted_function(x):
    with timing.timed_scope("inside jit"):  # Only runs during tracing; may cause errors
        return jnp.sum(x**2)


# DO: Time the JIT function from outside
@timing.timed(name="jitted_computation")
@jax.jit
def good_jitted_function(x):
    return jnp.sum(x**2)


# Or time the call to the JIT function
@jax.jit
def my_jitted_function(x):
    return jnp.sum(x**2)


x = jnp.ones((1000, 1000))

# This will show no meaningful timing information from inside the JIT function
print("Bad example (no timing info):")
with timing.Timer() as timer_bad:
    result = bad_jitted_function(x)
    timer_bad.block_until_ready(result)
print(timer_bad)

# This will properly time the JIT function
print("\nGood example (proper timing):")
with timing.Timer() as timer_good:
    with timing.timed_scope("jitted function call"):
        result = my_jitted_function(x)
        timer_good.block_until_ready(result)
print(timer_good)
Bad example (no timing info):
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.030                                                                                                    │
│ └── (1.5%) | inside jit : 0.000 s                                                                               │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Good example (proper timing):
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.031                                                                                                    │
│ └── (99.9%) | jitted function call : 0.031 s                                                                    │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
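The reason a scope inside a jitted function measures so little can be checked directly: Python-level code in the function body runs only while JAX traces it. The sketch below uses only jax.jit and jax.numpy; the counter and function names are made up for the illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented by a Python side effect inside the jitted function


@jax.jit
def sum_of_squares(x):
    global trace_count
    trace_count += 1  # runs only while JAX traces the function
    return jnp.sum(x**2)


x = jnp.ones((4, 4))
first = sum_of_squares(x)   # traces and compiles for this shape and dtype
second = sum_of_squares(x)  # reuses the compiled function: body is not re-run
count_same_shape = trace_count

third = sum_of_squares(jnp.ones((2, 2)))  # a new shape triggers a new trace
count_new_shape = trace_count

print(count_same_shape, count_new_shape)
```

A timing scope placed in the function body behaves like the counter: it is entered once per trace, not once per call.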
Block-Until-Ready
When timing JAX functions, it's crucial to call timer.block_until_ready() to get accurate results. JAX dispatches computations asynchronously, so a call can return before its result has actually been computed, and a scope that exits without blocking may only measure the time needed to dispatch the work.
# Demonstration of why block_until_ready is important
print("Without block_until_ready (inaccurate timing):")
with timing.Timer() as timer_bad:
    with timing.timed_scope("jax computation"):
        large_matrix = jnp.ones((2000, 2000))
        result = jnp.linalg.inv(large_matrix)
        # Not blocking - timing will be inaccurate!
print(timer_bad)

print("\nWith block_until_ready (accurate timing):")
with timing.Timer() as timer_good:
    with timing.timed_scope("jax computation"):
        large_matrix = jnp.ones((2000, 2000))
        result = jnp.linalg.inv(large_matrix)
        # Properly blocking until computation is done
        timer_good.block_until_ready(result)
print(timer_good)
Without block_until_ready (inaccurate timing):
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.068                                                                                                    │
│ └── (100.0%) | jax computation : 0.068 s                                                                        │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
With block_until_ready (accurate timing):
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.205                                                                                                    │
│ └── (100.0%) | jax computation : 0.205 s                                                                        │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
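The same precaution applies when timing by hand. As a minimal sketch that assumes only time.perf_counter and jax.block_until_ready (the function and variable names are illustrative), the first and second calls of a jitted function can be timed separately, since the first call also includes tracing and compilation:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def gram_trace(a):
    # Trace of a @ a.T, used here only as a small stand-in workload.
    return jnp.trace(a @ a.T)


a = jnp.ones((200, 200))

# First call: includes tracing and compilation as well as execution.
t0 = time.perf_counter()
jax.block_until_ready(gram_trace(a))
first_call = time.perf_counter() - t0

# Second call: reuses the compiled function; block before stopping the clock.
t0 = time.perf_counter()
out = jax.block_until_ready(gram_trace(a))
second_call = time.perf_counter() - t0

print(f"first call:  {first_call:.4f} s")
print(f"second call: {second_call:.4f} s")
```

Running one untimed warm-up call before measuring is a common way to keep compilation out of the reported numbers.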
Interaction with NetKet Drivers
Real-World Example: Timing a NetKet Simulation
Here's how the timing system is used in actual NetKet simulations:
# Create a simple quantum system
L = 8
g = nk.graph.Chain(length=L, pbc=True)
hi = nk.hilbert.Spin(s=0.5, N=g.n_nodes)

# Define the Hamiltonian
ha = nk.operator.Ising(hilbert=hi, graph=g, h=1.0)

# Define the variational ansatz
model = nk.models.RBM(alpha=1)
sampler = nk.sampler.MetropolisLocal(hi)
optimizer = nk.optimizer.Sgd(learning_rate=0.1)

# Create the variational state
vs = nk.vqs.MCState(sampler, model, n_samples=1000)

# Time the creation and execution of a VMC driver
with timing.Timer() as timer:
    with timing.timed_scope("driver setup"):
        driver = nk.driver.VMC(ha, optimizer, variational_state=vs)

    with timing.timed_scope("optimization"):
        # Run a few optimization steps with timing enabled
        driver.run(n_iter=5, timeit=True)

print("\nTotal timing breakdown:")
print(timer)
/Users/filippo.vicentini/Nextcloud/Codes/Python/netket/netket/vqs/mc/mc_state/state.py:299: UserWarning: n_samples=1000 (250 per JAX device) does not divide n_chains=64, increased to 1024 (256 per JAX device)
self.n_samples = n_samples
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.728                                                                                                    │
│ └── (85.1%) | VMC._forward_and_backward : 0.619 s                                                               │
│     └── (97.6%) | MCState.expect_and_grad : 0.604 s                                                             │
│         └── (69.3%) | MCState.sample : 0.419 s                                                                  │
│             └── (46.2%) | sampling n_discarded samples : 0.193 s                                                │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Total timing breakdown:
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.843                                                                                                    │
│ └── (100.0%) | optimization : 0.843 s                                                                           │
│     └── (86.3%) |                                                                                               │
│         /Users/filippo.vicentini/Nextcloud/Codes/Python/netket/netket/driver/abstract_variational_driver.py:336 │
│         : 0.728 s                                                                                               │
│         └── (85.1%) | VMC._forward_and_backward : 0.619 s                                                       │
│             └── (97.6%) | MCState.expect_and_grad : 0.604 s                                                     │
│                 └── (69.3%) | MCState.sample : 0.419 s                                                          │
│                     └── (46.2%) | sampling n_discarded samples : 0.193 s                                        │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Timing Custom Observable Calculations
You can also use the timing system to profile custom observable calculations within your NetKet workflows:
# Example of how timing is used internally (simplified version)
@timing.timed(name="estimate observables")
def estimate_observables(state, observables):
    """This mimics how NetKet drivers time observable estimation."""
    results = {}
    for name, obs in observables.items():
        with timing.timed_scope(f"observable: {name}"):
            results[name] = state.expect(obs)
    return results


# Demonstrate the pattern
observables = {"energy": ha, "magnetization": nk.operator.spin.sigmax(hi, 0)}

with timing.Timer() as timer:
    for i in range(3):
        with timing.timed_scope(f"iteration {i}"):
            estimates = estimate_observables(vs, observables)

print(timer)
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.516                                                                                                    │
│ └── (99.1%) | iteration 0 : 0.511 s                                                                             │
│     └── (100.0%) | estimate observables : 0.511 s                                                               │
│         ├── (19.0%) | observable: energy : 0.097 s                                                              │
│         │   └── (100.0%) | MCState.expect : 0.097 s                                                             │
│         │       └── (3.2%) | MCState.sample : 0.003 s                                                           │
│         │           └── (26.0%) | sampling n_discarded samples : 0.001 s                                        │
│         └── (81.0%) | observable: magnetization : 0.414 s                                                       │
│             └── (100.0%) | MCState.expect : 0.414 s                                                             │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
The driver.run() method with timeit=True shows timing information automatically, and it is implemented with the same timing system explored here.
Running the same estimate_observables cell a second time gives the output below. The total is far smaller and all three iterations are now listed, presumably because the first run included JAX compilation, which made iterations 1 and 2 negligible by comparison:
╭────────────────────────────────────────────── Timing Information ───────────────────────────────────────────────╮
│ Total: 0.007                                                                                                    │
│ ├── (39.0%) | iteration 0 : 0.003 s                                                                             │
│ │   └── (97.9%) | estimate observables : 0.003 s                                                                │
│ │       ├── (42.3%) | observable: energy : 0.001 s                                                              │
│ │       │   └── (99.3%) | MCState.expect : 0.001 s                                                              │
│ │       └── (55.7%) | observable: magnetization : 0.002 s                                                       │
│ │           └── (99.4%) | MCState.expect : 0.002 s                                                              │
│ ├── (29.4%) | iteration 1 : 0.002 s                                                                             │
│ │   └── (97.7%) | estimate observables : 0.002 s                                                                │
│ │       ├── (53.0%) | observable: energy : 0.001 s                                                              │
│ │       │   └── (99.4%) | MCState.expect : 0.001 s                                                              │
│ │       └── (44.9%) | observable: magnetization : 0.001 s                                                       │
│ │           └── (99.0%) | MCState.expect : 0.001 s                                                              │
│ └── (31.0%) | iteration 2 : 0.002 s                                                                             │
│     └── (97.5%) | estimate observables : 0.002 s                                                                │
│         ├── (50.7%) | observable: energy : 0.001 s                                                              │
│         │   └── (99.2%) | MCState.expect : 0.001 s                                                              │
│         └── (46.2%) | observable: magnetization : 0.001 s                                                       │
│             └── (98.6%) | MCState.expect : 0.001 s                                                              │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯