Skip to content

Commit dcb68e6

Browse files
committed
Fix independent_rvs determination in vectorize_over_posterior
1 parent 0960323 commit dcb68e6

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

pymc/sampling/forward.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,10 +1062,8 @@ def vectorize_over_posterior(
10621062
if rv in all_rvs
10631063
]:
10641064
rv_ancestors = ancestors([rv], blockers=[*needed_rvs, *independent_rvs, *outputs])
1065-
if (
1066-
rv not in needed_rvs
1067-
and not ({*outputs, *independent_rvs} & set(rv_ancestors))
1068-
and {var for var in rv_ancestors if var in all_rvs} <= {rv, *needed_rvs}
1065+
if rv not in needed_rvs and not (
1066+
{*outputs, *needed_rvs, *independent_rvs} & set(rv_ancestors)
10691067
):
10701068
independent_rvs.append(rv)
10711069
for rv in independent_rvs:

tests/sampling/test_forward.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample():
19581958
atol=0.6 / np.sqrt(10000),
19591959
)
19601960
assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
1961+
1962+
1963+
def test_vectorize_over_posterior_with_intermediate_rvs():
1964+
with pm.Model() as model:
1965+
a = pm.Normal("a")
1966+
b = pm.Normal.dist(a)
1967+
c = b + 1
1968+
d = pm.Normal.dist(c)
1969+
idata = pm.sample_prior_predictive(100, var_names=["a"])
1970+
idata.add_groups({"posterior": idata.prior})
1971+
_, _, vectorized_no_intermediate = vectorize_over_posterior(
1972+
outputs=[b, c, d],
1973+
posterior=idata.posterior,
1974+
input_rvs=[a],
1975+
allow_rvs_in_graph=True,
1976+
)
1977+
[vectorized_intermediate_rvs] = vectorize_over_posterior(
1978+
outputs=[d],
1979+
posterior=idata.posterior,
1980+
input_rvs=[a],
1981+
allow_rvs_in_graph=True,
1982+
)
1983+
assert vectorized_no_intermediate.type.shape == (1, 100)
1984+
assert vectorized_no_intermediate.type.shape == vectorized_intermediate_rvs.type.shape
1985+
a_ancestor1 = get_var_by_name([vectorized_no_intermediate], "a")[0]
1986+
a_ancestor2 = get_var_by_name([vectorized_intermediate_rvs], "a")[0]
1987+
assert isinstance(a_ancestor1, TensorConstant)
1988+
assert np.array_equiv(a_ancestor1.eval(), idata.posterior.a.data)
1989+
assert isinstance(a_ancestor2, TensorConstant)
1990+
assert np.array_equiv(a_ancestor2.eval(), idata.posterior.a.data)

0 commit comments

Comments
 (0)