diff --git a/conda-envs/environment-alternative-backends.yml b/conda-envs/environment-alternative-backends.yml
index 5030e7bacf..fcf78c6991 100644
--- a/conda-envs/environment-alternative-backends.yml
+++ b/conda-envs/environment-alternative-backends.yml
@@ -11,7 +11,7 @@ dependencies:
 - cloudpickle
 - zarr>=2.5.0,<3
 - numba
-- nutpie >= 0.13.4
+- nutpie >= 0.15.1
 # Jaxlib version must not be greater than jax version!
 - blackjax>=1.2.2
 - jax>=0.4.28
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 76695d7603..d3a02b91b6 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -331,11 +331,7 @@ def _sample_external_nuts(
                 "`idata_kwargs` are currently ignored by the nutpie sampler",
                 UserWarning,
             )
-        if var_names is not None:
-            warnings.warn(
-                "`var_names` are currently ignored by the nutpie sampler",
-                UserWarning,
-            )
+
         compile_kwargs = {}
         nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
         for kwarg in ("backend", "gradient_backend"):
@@ -343,6 +339,7 @@
                 compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
         compiled_model = nutpie.compile_pymc_model(
             model,
+            var_names=var_names,
             **compile_kwargs,
         )
         t_start = time.time()
diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py
index 2d32277061..ed3fc09dd0 100644
--- a/tests/sampling/test_mcmc_external.py
+++ b/tests/sampling/test_mcmc_external.py
@@ -15,8 +15,9 @@
 import numpy as np
 import numpy.testing as npt
 import pytest
+import xarray as xr
 
-from pymc import Data, Model, Normal, sample
+from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample
 
 
 @pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
@@ -86,3 +87,55 @@ def test_step_args():
         )
 
     npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
+
+
+@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
+def test_sample_var_names(nuts_sampler):
+    seed = 1234
+    kwargs = {
+        "chains": 1,
+        "tune": 100,
+        "draws": 100,
+        "random_seed": seed,
+        "progressbar": False,
+        "compute_convergence_checks": False,
+    }
+
+    # Generate data
+    rng = np.random.default_rng(seed)
+
+    group = rng.choice(list("ABCD"), size=100)
+    x = rng.normal(size=100)
+    y = rng.normal(size=100)
+
+    group_values, group_idx = np.unique(group, return_inverse=True)
+
+    coords = {"group": group_values}
+
+    # Create model
+    with Model(coords=coords) as model:
+        b_group = Normal("b_group", dims="group")
+        b_x = Normal("b_x")
+        mu = Deterministic("mu", b_group[group_idx] + b_x * x)
+        sigma = HalfNormal("sigma")
+        Normal("y", mu=mu, sigma=sigma, observed=y)
+
+    free_RVs = [var.name for var in model.free_RVs]
+
+    with model:
+        # Sample with and without var_names, but always with the same seed
+        idata_1 = sample(nuts_sampler=nuts_sampler, **kwargs)
+        # Remove the last free RV from the sampling
+        idata_2 = sample(nuts_sampler=nuts_sampler, var_names=free_RVs[:-1], **kwargs)
+
+    assert "mu" in idata_1.posterior
+    assert "mu" not in idata_2.posterior
+
+    assert free_RVs[-1] in idata_1.posterior
+    assert free_RVs[-1] not in idata_2.posterior
+
+    for var in free_RVs[:-1]:
+        assert var in idata_1.posterior
+        assert var in idata_2.posterior
+
+        xr.testing.assert_allclose(idata_1.posterior[var], idata_2.posterior[var])
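
For context, a minimal usage sketch (not part of the patch, model is illustrative only) of the behaviour the change enables: with nutpie >= 0.15.1 installed, `var_names` passed to `pm.sample` is now forwarded to `nutpie.compile_pymc_model`, so only the named variables are stored in the posterior.

# Illustrative only -- not part of the patch. Assumes nutpie >= 0.15.1.
import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
x = rng.normal(size=100)
y = 2.0 * x + rng.normal(size=100)

with pm.Model() as model:
    b = pm.Normal("b")
    sigma = pm.HalfNormal("sigma")
    mu = pm.Deterministic("mu", b * x)  # potentially large deterministic
    pm.Normal("y", mu=mu, sigma=sigma, observed=y)

    # var_names is forwarded to nutpie.compile_pymc_model, so only "b" and
    # "sigma" appear in idata.posterior; "mu" is not stored.
    idata = pm.sample(
        nuts_sampler="nutpie",
        var_names=["b", "sigma"],
        tune=100,
        draws=100,
        chains=1,
    )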