diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index f1d69c9282..572241e49a 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -567,10 +567,8 @@ def __init__(self, f): def __call__(self, state): return self.f(**state) - def __getattr__(self, item): - """Allow access to the original function attributes.""" - # This is only reached if `__getattribute__` fails. - return getattr(self.f, item) + def dprint(self, **kwrags): + return self.f.dprint(**kwrags) class CallableTensor: diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index f7efd5f6d4..de7f7f1cbe 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -746,7 +746,7 @@ def test_hessian_sign_change_warning(func): assert equal_computations([res_neg], [-res]) -def test_point_func(): +def test_point_func(capsys): x, y = pt.vectors("x", "y") outs = x * 2 + y**2 f = compile([x, y], outs) @@ -758,3 +758,30 @@ def test_point_func(): dprint_res = point_f.dprint(file="str") expected_dprint_res = point_f.f.dprint(file="str") assert dprint_res == expected_dprint_res + + point_f.dprint(print_shape=True) + captured = capsys.readouterr() + + # The shape=(?,) arises because the inputs are dvector. This checks that the dprint works, and the print_shape + # kwargs was correctly forwarded + assert "shape=(?,)" in captured.out + + +def test_pickle_point_func(): + """ + Regression test for https://github.com/pymc-devs/pymc/issues/7857 + """ + import cloudpickle + + x, y = pt.vectors("x", "y") + outs = x * 2 + y**2 + f = compile([x, y], outs) + + point_f = PointFunc(f) + point_f_pickled = cloudpickle.dumps(point_f) + point_f_unpickled = cloudpickle.loads(point_f_pickled) + + # Check that the function survived the round-trip + np.testing.assert_allclose( + point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]}) + )