2 changes: 1 addition & 1 deletion conda-envs/environment-alternative-backends.yml
@@ -11,7 +11,7 @@ dependencies:
- cloudpickle
- zarr>=2.5.0,<3
- numba
-  - nutpie >= 0.13.4
+  - nutpie >= 0.15.1
# Jaxlib version must not be greater than jax version!
- blackjax>=1.2.2
- jax>=0.4.28
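To pick up the bumped nutpie pin locally, the environment file can be applied with standard conda usage (a generic sketch, not part of the diff):

conda env update -f conda-envs/environment-alternative-backends.yml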
7 changes: 2 additions & 5 deletions pymc/sampling/mcmc.py
@@ -331,18 +331,15 @@ def _sample_external_nuts(
                "`idata_kwargs` are currently ignored by the nutpie sampler",
                UserWarning,
            )
-       if var_names is not None:
-           warnings.warn(
-               "`var_names` are currently ignored by the nutpie sampler",
-               UserWarning,
-           )

        compile_kwargs = {}
        nuts_sampler_kwargs = nuts_sampler_kwargs.copy()
        for kwarg in ("backend", "gradient_backend"):
            if kwarg in nuts_sampler_kwargs:
                compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
        compiled_model = nutpie.compile_pymc_model(
            model,
+           var_names=var_names,
            **compile_kwargs,
        )
        t_start = time.time()
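For context, a minimal usage sketch (not part of the diff) of what this mcmc.py change enables: `var_names` passed to `pm.sample` is now forwarded to `nutpie.compile_pymc_model` instead of being dropped with a warning, so variables omitted from `var_names` are not stored in the trace. The model and variable names below are illustrative only and assume nutpie is installed.

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    pm.Deterministic("x_squared", x**2)

    # Only `x` is kept; the deterministic is excluded from the stored posterior.
    idata = pm.sample(
        nuts_sampler="nutpie",
        var_names=["x"],
        chains=1,
        tune=100,
        draws=100,
    )

assert "x_squared" not in idata.posterior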
55 changes: 54 additions & 1 deletion tests/sampling/test_mcmc_external.py
@@ -15,8 +15,9 @@
import numpy as np
import numpy.testing as npt
import pytest
+import xarray as xr

-from pymc import Data, Model, Normal, sample
+from pymc import Data, Deterministic, HalfNormal, Model, Normal, sample


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
@@ -86,3 +87,55 @@ def test_step_args():
        )

    npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)


@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"])
def test_sample_var_names(nuts_sampler):
    seed = 1234
    kwargs = {
        "chains": 1,
        "tune": 100,
        "draws": 100,
        "random_seed": seed,
        "progressbar": False,
        "compute_convergence_checks": False,
    }

    # Generate data
    rng = np.random.default_rng(seed)

    group = rng.choice(list("ABCD"), size=100)
    x = rng.normal(size=100)
    y = rng.normal(size=100)

    group_values, group_idx = np.unique(group, return_inverse=True)

    coords = {"group": group_values}

    # Create model
    with Model(coords=coords) as model:
        b_group = Normal("b_group", dims="group")
        b_x = Normal("b_x")
        mu = Deterministic("mu", b_group[group_idx] + b_x * x)
        sigma = HalfNormal("sigma")
        Normal("y", mu=mu, sigma=sigma, observed=y)

    free_RVs = [var.name for var in model.free_RVs]

    with model:
        # Sample with and without var_names, but always with the same seed
        idata_1 = sample(nuts_sampler=nuts_sampler, **kwargs)
        # Remove the last free RV from the sampling
        idata_2 = sample(nuts_sampler=nuts_sampler, var_names=free_RVs[:-1], **kwargs)

    assert "mu" in idata_1.posterior
    assert "mu" not in idata_2.posterior

    assert free_RVs[-1] in idata_1.posterior
    assert free_RVs[-1] not in idata_2.posterior

    for var in free_RVs[:-1]:
        assert var in idata_1.posterior
        assert var in idata_2.posterior

        xr.testing.assert_allclose(idata_1.posterior[var], idata_2.posterior[var])
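As a follow-up sketch (not part of the diff): even when "mu" is excluded via `var_names`, it can be recomputed after sampling from the stored free variables. This reuses `group_idx`, `x`, `idata_1`, and `idata_2` from the test above and relies on the same-seed draws matching, which the test itself asserts.

import numpy.testing as npt
import xarray as xr

obs_idx = xr.DataArray(group_idx, dims="obs")
mu_recomputed = (
    idata_2.posterior["b_group"].isel(group=obs_idx)
    + idata_2.posterior["b_x"] * xr.DataArray(x, dims="obs")
)
# Should match the "mu" stored by the full run (same seed, same free-variable draws).
npt.assert_allclose(
    mu_recomputed.transpose("chain", "draw", "obs").values,
    idata_1.posterior["mu"].values,
)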