@@ -278,12 +278,18 @@ def fgraph_from_model(
278
278
return fgraph , memo
279
279
280
280
281
- def model_from_fgraph (fgraph : FunctionGraph ) -> Model :
281
+ def model_from_fgraph (fgraph : FunctionGraph , mutate_fgraph : bool = False ) -> Model :
282
282
"""Convert FunctionGraph to PyMC model.
283
283
284
- This requires nodes to be properly tagged with `ModelVar` dummy Ops.
284
+ Parameters
285
+ ----------
286
+ fgraph: FunctionGraph
287
+ fgraph representation of a PyMC model, with dummy `ModelVar` Ops.
288
+ See `fgraph_from_model` for more details.
285
289
286
- See: fgraph_from_model
290
+ mutate_fgraph: bool, default False
291
+ Whether the function is allowed to modify the fgraph (and it's variables) in place.
292
+ This is useful if these are not needed anymore after the model is created.
287
293
"""
288
294
289
295
def first_non_model_var (var ):
@@ -300,11 +306,12 @@ def first_non_model_var(var):
300
306
_coords = getattr (fgraph , "_coords" , {})
301
307
_dim_lengths = getattr (fgraph , "_dim_lengths" , {})
302
308
303
- fgraph , memo = fgraph .clone_get_equiv (check_integrity = False , attach_feature = False )
304
- # Shared dim lengths are not extracted from the fgraph representation,
305
- # so we need to update after we clone the fgraph
306
- # TODO: Consider representing/extracting them from the fgraph!
307
- _dim_lengths = {k : memo .get (v , v ) for k , v in _dim_lengths .items ()}
309
+ if not mutate_fgraph :
310
+ fgraph , memo = fgraph .clone_get_equiv (check_integrity = False , attach_feature = False )
311
+ # Shared dim lengths are not extracted from the fgraph representation,
312
+ # so we need to update after we clone the fgraph
313
+ # TODO: Consider representing/extracting them from the fgraph!
314
+ _dim_lengths = {k : memo .get (v , v ) for k , v in _dim_lengths .items ()}
308
315
309
316
model ._coords = _coords
310
317
model ._dim_lengths = _dim_lengths
@@ -385,7 +392,7 @@ def clone_model(model: Model) -> Model:
385
392
z = pm.Deterministic("z", clone_x + 1)
386
393
387
394
"""
388
- return model_from_fgraph (fgraph_from_model (model )[0 ])
395
+ return model_from_fgraph (fgraph_from_model (model )[0 ], mutate_fgraph = True )
389
396
390
397
391
398
def extract_dims (var ) -> tuple :
0 commit comments