Merged
8 changes: 6 additions & 2 deletions pymc/logprob/rewriting.py
@@ -365,8 +365,6 @@ def incsubtensor_rv_replace(fgraph, node):
"local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
)
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
# Split max_and_argmax
logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")

# These rewrites convert un-measurable variables into their measurable forms,
# but they need to be reapplied, because some of the measurable forms require
@@ -376,6 +374,12 @@ def incsubtensor_rv_replace(fgraph, node):

logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic")

# Split max_and_argmax
# We only register this in the measurable IR db because max does not have a grad implemented
# And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
# This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
measurable_ir_rewrites_db.register("local_max_and_argmax", local_max_and_argmax, "basic")

# These rewrites push random/measurable variables "down", making them closer to
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
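For context on why the registration moved, here is a minimal sketch (not part of this diff; assumed usage of the public API) of the two paths the new comment distinguishes. During logprob inference the split rewrite still runs, so pm.logp can derive the density of a max; outside of it, pt.grad never encounters the split Max op, whose gradient is not implemented.

import pytensor.tensor as pt
import pymc as pm

# Measurable path: conditional logprob inference applies local_max_and_argmax
# internally, so the max of iid normals is recognized as measurable.
x = pt.random.normal(0, 1, size=(3,))
value = pt.scalar("value")
max_logp = pm.logp(pt.max(x), value)

# Non-measurable path: the rewrite no longer runs in the global
# pre-canonicalize phase, so an ordinary max keeps its MaxAndArgmax
# gradient and this does not raise NotImplementedError.
z = pt.vector("z")
g = pt.grad(pt.max(z), z)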
16 changes: 16 additions & 0 deletions tests/logprob/test_order.py
@@ -45,6 +45,7 @@
 import pymc as pm
 
 from pymc import logp
+from pymc.logprob import conditional_logp
 from pymc.testing import assert_no_rvs
 
 
@@ -293,3 +294,18 @@ def test_min_max_bernoulli():
     min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
     np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
     np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))
+
+
+def test_non_measurable_max_grad():
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/711
+    x = pt.random.normal(0, 1, size=(3,))
+    max_x = x.max()
+    y = pt.random.normal(max_x, 1)
+
+    x_vv = x.type()
+    y_vv = y.type()
+    logp_terms = conditional_logp({x: x_vv, y: y_vv}).values()
+    joint_logp = pt.sum([term.sum() for term in logp_terms])
+
+    # Test that calling gradient does not raise a NotImplementedError
+    assert pt.grad(joint_logp, x_vv)
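As a usage note (not part of the diff): the assertion above only checks that the gradient graph can be built. A hedged extension, reusing the test's variables and shapes, would compile and evaluate that gradient:

import numpy as np
import pytensor

# Continuing from the test body above; x_vv is a length-3 vector value
# variable and y_vv is a scalar value variable.
grad_fn = pytensor.function([x_vv, y_vv], pt.grad(joint_logp, x_vv))
grad_fn(np.zeros(3), 0.0)  # should evaluate without NotImplementedError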