diff --git a/pymc/sampling/deterministic.py b/pymc/sampling/deterministic.py index b300d5ee97..b0b04f38ec 100644 --- a/pymc/sampling/deterministic.py +++ b/pymc/sampling/deterministic.py @@ -84,6 +84,7 @@ def compute_deterministics( if var_names is None: deterministics = model.deterministics + var_names = [det.name for det in deterministics] else: deterministics = [model[var_name] for var_name in var_names] if not set(deterministics).issubset(set(model.deterministics)): @@ -101,7 +102,7 @@ def compute_deterministics( new_dataset = apply_function_over_dataset( fn, dataset[[rv.name for rv in model.free_RVs]], - output_var_names=[det.name for det in model.deterministics], + output_var_names=var_names, dims=dims, coords=coords, sample_dims=sample_dims, diff --git a/tests/sampling/test_deterministic.py b/tests/sampling/test_deterministic.py index f693a788c5..d1d2b8474c 100644 --- a/tests/sampling/test_deterministic.py +++ b/tests/sampling/test_deterministic.py @@ -59,6 +59,11 @@ def test_compute_deterministics(): assert extended_with_mu["mu"].dims == ("chain", "draw", "group") assert_allclose(extended_with_mu["mu"], dataset["mu_raw"].cumsum("group")) + only_sigma = compute_deterministics(dataset, var_names=["sigma"], model=m, progressbar=False) + assert set(only_sigma.data_vars.variables) == {"sigma"} + assert only_sigma["sigma"].dims == ("chain", "draw") + assert_allclose(only_sigma["sigma"], np.exp(dataset["sigma_raw"])) + def test_docstring_example(): import pymc as pm