From 35b02f3e351b853b15342d99f08feef2aef31ece Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 16 Apr 2024 17:35:08 +0200
Subject: [PATCH] Fix gradient bug in models with max operation

---
 pymc/logprob/rewriting.py   |  8 ++++++--
 tests/logprob/test_order.py | 16 ++++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index 055516d197..e5395e21d1 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -365,8 +365,6 @@ def incsubtensor_rv_replace(fgraph, node):
     "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
 )
 logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
-# Split max_and_argmax
-logprob_rewrites_db.register("local_max_and_argmax", out2in(local_max_and_argmax), "basic")
 
 # These rewrites convert un-measurable variables into their measurable forms,
 # but they need to be reapplied, because some of the measurable forms require
@@ -376,6 +374,12 @@ def incsubtensor_rv_replace(fgraph, node):
 
 logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic")
 
+# Split max_and_argmax
+# We only register this in the measurable IR db because max does not have a grad implemented
+# And running this on any MaxAndArgmax would lead to issues: https://github.com/pymc-devs/pymc/issues/7251
+# This special registering can be removed after https://github.com/pymc-devs/pytensor/issues/334 is fixed
+measurable_ir_rewrites_db.register("local_max_and_argmax", local_max_and_argmax, "basic")
+
 # These rewrites push random/measurable variables "down", making them closer to
 # (or eventually) the graph outputs. Often this is done by lifting other `Op`s
 # "up" through the random/measurable variables and into their inputs.
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 4d15240375..54299fdcb4 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -45,6 +45,7 @@
 import pymc as pm
 
 from pymc import logp
+from pymc.logprob import conditional_logp
 from pymc.testing import assert_no_rvs
 
 
@@ -293,3 +294,18 @@ def test_min_max_bernoulli():
     min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
     np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
     np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))
+
+
+def test_non_measurable_max_grad():
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/711
+    x = pt.random.normal(0, 1, size=(3,))
+    max_x = x.max()
+    y = pt.random.normal(max_x, 1)
+
+    x_vv = x.type()
+    y_vv = y.type()
+    logp_terms = conditional_logp({x: x_vv, y: y_vv}).values()
+    joint_logp = pt.sum([term.sum() for term in logp_terms])
+
+    # Test that calling gradient does not raise a NotImplementedError
+    assert pt.grad(joint_logp, x_vv)