Just because we wrap a graph in an OpFromGraph doesn't mean we also want its gradient wrapped in one. Wrapping the gradient can also cause problems, as in pymc-devs/pymc#7657.
I suggest adding a `wrap_grad_in_ofg: bool` kwarg that toggles this behavior. The easiest way to implement this is to keep the current behavior and, when wrapping is not wanted, manually call the `inline_ofg` rewrite after the gradient Op is applied here:
pytensor/pytensor/compile/builders.py, lines 612 to 614 in 2f1d25a:

```python
connected_input_grads = iter(
    lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs)
)
```
And here:
pytensor/pytensor/compile/builders.py, line 689 in 2f1d25a:

```python
connected_output_grads = iter(rop_op(*inputs, **kwargs))
```
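To make the proposed toggle concrete, here is a minimal, self-contained toy in plain Python. It is hypothetical and does not use PyTensor: `Var`, `BinOp`, `Wrapped` (a stand-in for an OpFromGraph node), `inline` (a stand-in for the `inline_ofg` rewrite), `grad` and `count_wrapped` are illustrative names only, and `grad` is a plain symbolic derivative rather than a real `L_op`/`R_op` that receives output gradients. It follows the suggested approach: the wrapped gradient is built exactly as today, and only when `wrap_grad_in_ofg=False` is it inlined immediately afterwards.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union

# Hypothetical toy IR for illustration only, NOT the PyTensor API. Constants are
# plain floats; `Wrapped` stands in for an OpFromGraph node.

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class BinOp:
    op: str  # "add" or "mul"
    left: "Expr"
    right: "Expr"

@dataclass(frozen=True)
class Wrapped:
    inner_inputs: Tuple[Var, ...]  # placeholders used inside the inner graph
    inner_output: "Expr"           # inner graph, closed over inner_inputs
    inputs: Tuple["Expr", ...]     # outer expressions fed into the placeholders

Expr = Union[Var, float, BinOp, Wrapped]

def substitute(expr: Expr, mapping: Dict[Var, Expr]) -> Expr:
    """Rebuild `expr` with variables replaced; Wrapped inner graphs are left closed."""
    if isinstance(expr, Var):
        return mapping.get(expr, expr)
    if isinstance(expr, (int, float)):
        return expr
    if isinstance(expr, BinOp):
        return BinOp(expr.op, substitute(expr.left, mapping),
                     substitute(expr.right, mapping))
    outer = tuple(substitute(i, mapping) for i in expr.inputs)
    return Wrapped(expr.inner_inputs, expr.inner_output, outer)

def inline(expr: Expr) -> Expr:
    """Expand every Wrapped node into its inner graph (toy analogue of inline_ofg)."""
    if isinstance(expr, (Var, int, float)):
        return expr
    if isinstance(expr, BinOp):
        return BinOp(expr.op, inline(expr.left), inline(expr.right))
    outer = tuple(inline(i) for i in expr.inputs)
    return inline(substitute(expr.inner_output, dict(zip(expr.inner_inputs, outer))))

def grad(expr: Expr, wrt: Var, wrap_grad_in_ofg: bool = True) -> Expr:
    """Symbolic derivative of `expr` w.r.t. `wrt` (flag name taken from the proposal)."""
    if isinstance(expr, Var):
        return 1.0 if expr == wrt else 0.0
    if isinstance(expr, (int, float)):
        return 0.0
    if isinstance(expr, BinOp):
        dl = grad(expr.left, wrt, wrap_grad_in_ofg)
        dr = grad(expr.right, wrt, wrap_grad_in_ofg)
        if expr.op == "add":
            return BinOp("add", dl, dr)
        # Product rule: (l * r)' = l' * r + l * r'
        return BinOp("add", BinOp("mul", dl, expr.right), BinOp("mul", expr.left, dr))
    # Chain rule through each input of the Wrapped node.
    total: Expr = 0.0
    for inner_in, outer_in in zip(expr.inner_inputs, expr.inputs):
        inner_grad = grad(expr.inner_output, inner_in, wrap_grad_in_ofg)
        # Current behavior: the inner gradient gets its own Wrapped node.
        grad_node: Expr = Wrapped(expr.inner_inputs, inner_grad, expr.inputs)
        if not wrap_grad_in_ofg:
            # Proposed opt-out: build it exactly as before, then inline it at once.
            grad_node = inline(grad_node)
        outer_grad = grad(outer_in, wrt, wrap_grad_in_ofg)
        total = BinOp("add", total, BinOp("mul", grad_node, outer_grad))
    return total

def evaluate(expr: Expr, env: Dict[Var, float]) -> float:
    if isinstance(expr, Var):
        return env[expr]
    if isinstance(expr, (int, float)):
        return float(expr)
    if isinstance(expr, BinOp):
        left, right = evaluate(expr.left, env), evaluate(expr.right, env)
        return left + right if expr.op == "add" else left * right
    args = tuple(evaluate(i, env) for i in expr.inputs)
    return evaluate(expr.inner_output, dict(zip(expr.inner_inputs, args)))

def count_wrapped(expr: Expr) -> int:
    if isinstance(expr, (Var, int, float)):
        return 0
    if isinstance(expr, BinOp):
        return count_wrapped(expr.left) + count_wrapped(expr.right)
    inner = count_wrapped(expr.inner_output)
    return 1 + inner + sum(count_wrapped(i) for i in expr.inputs)

x, a, b = Var("x"), Var("a"), Var("b")
f = Wrapped((a, b), BinOp("mul", a, b), (x, x))  # f(x) = x * x, wrapped
g_wrapped = grad(f, x)                          # gradient kept in Wrapped nodes
g_inlined = grad(f, x, wrap_grad_in_ofg=False)  # gradient inlined right away
print(count_wrapped(g_wrapped), count_wrapped(g_inlined))  # 2 0
print(evaluate(g_wrapped, {x: 3.0}), evaluate(g_inlined, {x: 3.0}))  # 6.0 6.0
```

Both gradients evaluate to the same value; the flag only decides whether composite nodes survive in the gradient graph. Building the wrapped gradient first and inlining it afterwards keeps a single construction path, and in this toy example `inline(g_wrapped) == g_inlined` holds.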
jessegrabowski