@@ -34,7 +34,7 @@
 from aesara import scalar
 from aesara.compile.mode import Mode, get_mode
 from aesara.gradient import grad
-from aesara.graph import node_rewriter
+from aesara.graph import node_rewriter, rewrite_graph
 from aesara.graph.basic import (
     Apply,
     Constant,
@@ -55,10 +55,13 @@
     RandomGeneratorSharedVariable,
     RandomStateSharedVariable,
 )
+from aesara.tensor.rewriting.basic import topo_constant_folding
+from aesara.tensor.rewriting.shape import ShapeFeature
 from aesara.tensor.sharedvar import SharedVariable
 from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
 from aesara.tensor.var import TensorConstant, TensorVariable
 
+from pymc.exceptions import NotConstantValueError
 from pymc.vartypes import continuous_types, isgenerator, typefilter
 
 PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
@@ -82,6 +85,7 @@
     "at_rng",
     "convert_observed_data",
     "compile_pymc",
+    "constant_fold",
 ]
 
 
@@ -971,3 +975,30 @@ def compile_pymc(
         **kwargs,
     )
     return aesara_function
+
+
+def constant_fold(
+    xs: Sequence[TensorVariable], raise_not_constant: bool = True
+) -> Tuple[np.ndarray, ...]:
+    """Use constant folding to get constant values of a graph.
+
+    Parameters
+    ----------
+    xs: Sequence of TensorVariable
+        The variables that are to be constant folded.
+    raise_not_constant: bool, default True
+        Raises NotConstantValueError if any of the variables cannot be constant folded.
+        This should only be disabled with care, as the graphs are cloned before
+        attempting constant folding, and any old non-shared inputs will not work with
+        the returned outputs.
+    """
+    fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
+
+    folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs
+
+    if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
+        raise NotConstantValueError
+
+    return tuple(
+        folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
+    )
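
Usage note (not part of the diff): a minimal sketch of how the new `constant_fold` helper behaves, assuming it lands in `pymc.aesaraf` alongside `compile_pymc`. The import path and the example graphs below are illustrative assumptions, not part of this change.

```python
import aesara.tensor as at

from pymc.aesaraf import constant_fold  # module path assumed for illustration
from pymc.exceptions import NotConstantValueError

# A graph built only from constants folds down to concrete NumPy values.
static_size = at.as_tensor([2, 3]).prod()
(folded,) = constant_fold([static_size])
assert folded == 6

# With a symbolic input, the default behaviour is to raise ...
x = at.vector("x")
try:
    constant_fold([x.shape[0]])
except NotConstantValueError:
    print("x.shape[0] cannot be constant folded")

# ... while raise_not_constant=False returns the (possibly unfolded) variable as-is.
(maybe_folded,) = constant_fold([x.shape[0]], raise_not_constant=False)
```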