|
22 | 22 | import pytensor
|
23 | 23 | import pytensor.tensor as pt
|
24 | 24 |
|
| 25 | +from arviz import InferenceData |
25 | 26 | from numpy import random as nr
|
26 | 27 | from numpy import testing as npt
|
27 | 28 | from pytensor.compile.mode import Mode
|
@@ -982,3 +983,115 @@ def assert_no_rvs(vars: Sequence[Variable]) -> None:
|
982 | 983 | rvs = rvs_in_graph(vars)
|
983 | 984 | if rvs:
|
984 | 985 | raise AssertionError(f"RV found in graph: {rvs}")
|
| 986 | + |
| 987 | + |
| 988 | +def mock_sample(draws: int = 10, **kwargs): |
| 989 | + """Mock :func:`pymc.sample` with :func:`pymc.sample_prior_predictive`. |
| 990 | +
|
| 991 | + Useful for testing models that use pm.sample without running MCMC sampling. |
| 992 | +
|
| 993 | + Examples |
| 994 | + -------- |
| 995 | + Using mock_sample with pytest |
| 996 | +
|
| 997 | + .. note:: |
| 998 | +
|
| 999 | + Use :func:`pymc.testing.mock_sample_setup_and_teardown` directly for pytest fixtures. |
| 1000 | +
|
| 1001 | + .. code-block:: python |
| 1002 | +
|
| 1003 | + import pytest |
| 1004 | +
|
| 1005 | + import pymc as pm |
| 1006 | + from pymc.testing import mock_sample |
| 1007 | +
|
| 1008 | +
|
| 1009 | + @pytest.fixture(scope="module") |
| 1010 | + def mock_pymc_sample(): |
| 1011 | + original_sample = pm.sample |
| 1012 | + pm.sample = mock_sample |
| 1013 | +
|
| 1014 | + yield |
| 1015 | +
|
| 1016 | + pm.sample = original_sample |
| 1017 | +
|
| 1018 | + """ |
| 1019 | + random_seed = kwargs.get("random_seed", None) |
| 1020 | + model = kwargs.get("model", None) |
| 1021 | + draws = kwargs.get("draws", draws) |
| 1022 | + n_chains = kwargs.get("chains", 1) |
| 1023 | + idata: InferenceData = pm.sample_prior_predictive( |
| 1024 | + model=model, |
| 1025 | + random_seed=random_seed, |
| 1026 | + draws=draws, |
| 1027 | + ) |
| 1028 | + |
| 1029 | + idata.add_groups( |
| 1030 | + posterior=( |
| 1031 | + idata["prior"] |
| 1032 | + .isel(chain=0) |
| 1033 | + .expand_dims({"chain": range(n_chains)}) |
| 1034 | + .transpose("chain", "draw", ...) |
| 1035 | + ) |
| 1036 | + ) |
| 1037 | + del idata["prior"] |
| 1038 | + if "prior_predictive" in idata: |
| 1039 | + del idata["prior_predictive"] |
| 1040 | + return idata |
| 1041 | + |
| 1042 | + |
| 1043 | +def mock_sample_setup_and_teardown(): |
| 1044 | + """Set up and tear down mocking of PyMC sampling functions for testing. |
| 1045 | +
|
| 1046 | + This function is designed to be used with pytest fixtures to temporarily replace |
| 1047 | + PyMC's sampling functionality with faster alternatives for testing purposes. |
| 1048 | +
|
| 1049 | + Effects during the fixture's active period: |
| 1050 | +
|
| 1051 | + * Replaces :func:`pymc.sample` with :func:`pymc.testing.mock_sample`, which uses |
| 1052 | + prior predictive sampling instead of MCMC |
| 1053 | + * Replaces distributions: |
| 1054 | + * :class:`pymc.Flat` with :class:`pymc.Normal` |
| 1055 | + * :class:`pymc.HalfFlat` with :class:`pymc.HalfNormal` |
| 1056 | + * Automatically restores all original functions and distributions after the test completes |
| 1057 | +
|
| 1058 | + Examples |
| 1059 | + -------- |
| 1060 | + Use with `pytest` to mock actual PyMC sampling in test suite. |
| 1061 | +
|
| 1062 | + .. code-block:: python |
| 1063 | +
|
| 1064 | + # tests/conftest.py |
| 1065 | + import pytest |
| 1066 | + import pymc as pm |
| 1067 | + from pymc.testing import mock_sample_setup_and_teardown |
| 1068 | +
|
| 1069 | + # Register as a pytest fixture |
| 1070 | + mock_pymc_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown) |
| 1071 | +
|
| 1072 | +
|
| 1073 | + # tests/test_model.py |
| 1074 | + # Use in a test function |
| 1075 | + def test_model_inference(mock_pymc_sample): |
| 1076 | + with pm.Model() as model: |
| 1077 | + x = pm.Normal("x", 0, 1) |
| 1078 | + # This will use mock_sample instead of actual MCMC |
| 1079 | + idata = pm.sample() |
| 1080 | + # Test with the inference data... |
| 1081 | +
|
| 1082 | + """ |
| 1083 | + import pymc as pm |
| 1084 | + |
| 1085 | + original_flat = pm.Flat |
| 1086 | + original_half_flat = pm.HalfFlat |
| 1087 | + original_sample = pm.sample |
| 1088 | + |
| 1089 | + pm.sample = mock_sample |
| 1090 | + pm.Flat = pm.Normal |
| 1091 | + pm.HalfFlat = pm.HalfNormal |
| 1092 | + |
| 1093 | + yield |
| 1094 | + |
| 1095 | + pm.sample = original_sample |
| 1096 | + pm.Flat = original_flat |
| 1097 | + pm.HalfFlat = original_half_flat |
0 commit comments