Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytensor
import pytensor.tensor as pt

from arviz import InferenceData
from numpy import random as nr
from numpy import testing as npt
from pytensor.compile.mode import Mode
Expand All @@ -31,6 +32,7 @@
from pytensor.tensor.random.op import RandomVariable
from scipy import special as sp
from scipy import stats as st
from xarray import DataArray

import pymc as pm

Expand Down Expand Up @@ -982,3 +984,104 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
rvs = rvs_in_graph(vars)
if rvs:
raise AssertionError(f"RV found in graph: {rvs}")


def mock_sample(draws: int = 10, **kwargs):
"""Mock the pm.sample function by returning prior predictive samples as posterior.

Useful for testing models that use pm.sample without running MCMC sampling.

Examples
--------
Using mock_sample with pytest

.. code-block:: python

import pytest

import pymc as pm
from pymc.testing import mock_sample


@pytest.fixture(scope="module")
def mock_pymc_sample():
original_sample = pm.sample
pm.sample = mock_sample

yield

pm.sample = original_sample

"""
random_seed = kwargs.get("random_seed", None)
model = kwargs.get("model", None)
draws = kwargs.get("draws", draws)
n_chains = kwargs.get("chains", 1)
idata: InferenceData = pm.sample_prior_predictive(
model=model,
random_seed=random_seed,
draws=draws,
)

expanded_chains = DataArray(
np.ones(n_chains),
coords={"chain": np.arange(n_chains)},
)
idata.add_groups(
posterior=(idata["prior"].mean("chain") * expanded_chains).transpose("chain", "draw", ...),
)
del idata["prior"]
if "prior_predictive" in idata:
del idata["prior_predictive"]
return idata


def mock_sample_setup_and_teardown():
"""Set up and tear down mocking of PyMC sampling functions for testing.

This function is designed to be used with pytest fixtures to temporarily replace
PyMC's sampling functionality with faster alternatives for testing purposes.

Effects during the fixture's active period:
* Replaces pm.sample with mock_sample, which uses prior predictive sampling
instead of MCMC
* Replaces pm.Flat with pm.Normal to avoid issues with unbounded priors
* Replaces pm.HalfFlat with pm.HalfNormal to avoid issues with semi-bounded priors
* Automatically restores all original functions after the test completes

Examples
--------
.. code-block:: python

import pytest
import pymc as pm
from pymc.testing import mock_sample_setup_and_teardown

# Register as a pytest fixture
mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


# Use in a test function
def test_model_inference(mock_pymc_sample):
with pm.Model() as model:
x = pm.Normal("x", 0, 1)
# This will use mock_sample instead of actual MCMC
idata = pm.sample()
# Test with the inference data...

"""
import pymc as pm

original_flat = pm.Flat
original_half_flat = pm.HalfFlat
original_sample = pm.sample

pm.sample = mock_sample
pm.Flat = pm.Normal
pm.HalfFlat = pm.HalfNormal

yield

pm.sample = original_sample
pm.Flat = original_flat
pm.HalfFlat = original_half_flat
49 changes: 48 additions & 1 deletion tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

import pytest

from pymc.testing import Domain
import pymc as pm

from pymc.testing import Domain, mock_sample, mock_sample_setup_and_teardown
from tests.models import simple_normal


@pytest.mark.parametrize(
Expand All @@ -32,3 +35,47 @@
def test_domain(values, edges, expectation):
with expectation:
Domain(values, edges=edges)


@pytest.mark.parametrize(
"args, kwargs, expected_draws",
[
pytest.param((), {}, 10, id="default"),
pytest.param((100,), {}, 100, id="positional-draws"),
pytest.param((), {"draws": 100}, 100, id="keyword-draws"),
],
)
def test_mock_sample(args, kwargs, expected_draws) -> None:
_, model, _ = simple_normal(bounded_prior=True)

with model:
idata = mock_sample(*args, **kwargs)

assert "posterior" in idata
assert "observed_data" in idata
assert "prior" not in idata
assert "posterior_predictive" not in idata
assert "sample_stats" not in idata

assert idata.posterior.sizes == {"chain": 1, "draw": expected_draws}


mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)


@pytest.fixture(scope="function")
def dummy_model() -> pm.Model:
with pm.Model() as model:
pm.Flat("flat")
pm.HalfFlat("half_flat")

return model


def test_fixture(mock_pymc_sample, dummy_model) -> None:
with dummy_model:
idata = pm.sample()

posterior = idata.posterior
assert posterior.sizes == {"chain": 1, "draw": 10}
assert (posterior["half_flat"] >= 0).all()