Skip to content
Merged
52 changes: 52 additions & 0 deletions pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from collections.abc import Callable, Sequence
from typing import Any

from arviz import InferenceData
from xarray import DataArray
import numpy as np
import pytensor
import pytensor.tensor as pt
Expand Down Expand Up @@ -982,3 +984,53 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
rvs = rvs_in_graph(vars)
if rvs:
raise AssertionError(f"RV found in graph: {rvs}")


def mock_sample(*args, **kwargs):
"""Mock the pm.sample function by returning prior predictive samples as posterior.

Useful for testing models that use pm.sample without running MCMC sampling.

Examples
--------
Using mock_sample with pytest

.. code-block:: python

import pytest

import pymc as pm
from pymc.testing import mock_sample


@pytest.fixture(scope="module")
def mock_pymc_sample():
original_sample = pm.sample
pm.sample = mock_sample

yield

pm.sample = original_sample

"""
random_seed = kwargs.get("random_seed", None)
model = kwargs.get("model", None)
draws = kwargs.get("draws", 10)
n_chains = kwargs.get("chains", 1)
idata: InferenceData = pm.sample_prior_predictive(
model=model,
random_seed=random_seed,
draws=draws,
)

expanded_chains = DataArray(
np.ones(n_chains),
coords={"chain": np.arange(n_chains)},
)
idata.add_groups(
posterior=(idata.prior.mean("chain") * expanded_chains).transpose("chain", "draw", ...)
)
del idata.prior
if "prior_predictive" in idata:
del idata.prior_predictive
return idata
Loading