22
22
from sys import modules
23
23
from typing import (
24
24
TYPE_CHECKING ,
25
- Any ,
26
25
Literal ,
27
26
Optional ,
28
27
TypeVar ,
48
47
49
48
from pymc .blocking import DictToArrayBijection , RaveledVars
50
49
from pymc .data import GenTensorVariable , is_minibatch
51
- from pymc .distributions .transforms import _default_transform
50
+ from pymc .distributions .transforms import ChainedTransform , _default_transform
52
51
from pymc .exceptions import (
53
52
BlockModelAccessError ,
54
53
ImputationWarning ,
58
57
)
59
58
from pymc .initial_point import make_initial_point_fn
60
59
from pymc .logprob .basic import transformed_conditional_logp
60
+ from pymc .logprob .transforms import Transform
61
61
from pymc .logprob .utils import ParameterValueError , replace_rvs_by_values
62
62
from pymc .model_graph import model_to_graphviz
63
63
from pymc .pytensorf import (
@@ -1214,7 +1214,16 @@ def set_data(
1214
1214
shared_object .set_value (values )
1215
1215
1216
1216
def register_rv (
1217
- self , rv_var , name , observed = None , total_size = None , dims = None , transform = UNSET , initval = None
1217
+ self ,
1218
+ rv_var ,
1219
+ name ,
1220
+ * ,
1221
+ observed = None ,
1222
+ total_size = None ,
1223
+ dims = None ,
1224
+ default_transform = UNSET ,
1225
+ transform = UNSET ,
1226
+ initval = None ,
1218
1227
):
1219
1228
"""Register an (un)observed random variable with the model.
1220
1229
@@ -1229,8 +1238,10 @@ def register_rv(
1229
1238
upscales logp of variable with ``coef = total_size/var.shape[0]``
1230
1239
dims : tuple
1231
1240
Dimension names for the variable.
1241
+ default_transform
1242
+ A default transform for the random variable in log-likelihood space.
1232
1243
transform
1233
- A transform for the random variable in log-likelihood space .
1244
+ Additional transform which may be applied after default transform .
1234
1245
initval
1235
1246
The initial value of the random variable.
1236
1247
@@ -1255,7 +1266,7 @@ def register_rv(
1255
1266
if total_size is not None :
1256
1267
raise ValueError ("total_size can only be passed to observed RVs" )
1257
1268
self .free_RVs .append (rv_var )
1258
- self .create_value_var (rv_var , transform )
1269
+ self .create_value_var (rv_var , transform = transform , default_transform = default_transform )
1259
1270
self .add_named_variable (rv_var , dims )
1260
1271
self .set_initval (rv_var , initval )
1261
1272
else :
@@ -1278,7 +1289,9 @@ def register_rv(
1278
1289
1279
1290
# `rv_var` is potentially changed by `make_obs_var`,
1280
1291
# for example into a new graph for imputation of missing data.
1281
- rv_var = self .make_obs_var (rv_var , observed , dims , transform , total_size )
1292
+ rv_var = self .make_obs_var (
1293
+ rv_var , observed , dims , default_transform , transform , total_size
1294
+ )
1282
1295
1283
1296
return rv_var
1284
1297
@@ -1287,7 +1300,8 @@ def make_obs_var(
1287
1300
rv_var : TensorVariable ,
1288
1301
data : np .ndarray ,
1289
1302
dims ,
1290
- transform : Any | None ,
1303
+ default_transform : Transform | None ,
1304
+ transform : Transform | None ,
1291
1305
total_size : int | None ,
1292
1306
) -> TensorVariable :
1293
1307
"""Create a `TensorVariable` for an observed random variable.
@@ -1301,8 +1315,10 @@ def make_obs_var(
1301
1315
The observed data.
1302
1316
dims : tuple
1303
1317
Dimension names for the variable.
1304
- transform : int, optional
1318
+ default_transform
1305
1319
A transform for the random variable in log-likelihood space.
1320
+ transform
1321
+ Additional transform which may be applied after default transform.
1306
1322
1307
1323
Returns
1308
1324
-------
@@ -1339,12 +1355,19 @@ def make_obs_var(
1339
1355
1340
1356
# Register ObservedRV corresponding to observed component
1341
1357
observed_rv .name = f"{ name } _observed"
1342
- self .create_value_var (observed_rv , transform = None , value_var = observed_data )
1358
+ self .create_value_var (
1359
+ observed_rv , transform = transform , default_transform = None , value_var = observed_data
1360
+ )
1343
1361
self .add_named_variable (observed_rv )
1344
1362
self .observed_RVs .append (observed_rv )
1345
1363
1346
1364
# Register FreeRV corresponding to unobserved components
1347
- self .register_rv (unobserved_rv , f"{ name } _unobserved" , transform = transform )
1365
+ self .register_rv (
1366
+ unobserved_rv ,
1367
+ f"{ name } _unobserved" ,
1368
+ transform = transform ,
1369
+ default_transform = default_transform ,
1370
+ )
1348
1371
1349
1372
# Register Deterministic that combines observed and missing
1350
1373
# Note: This can widely increase memory consumption during sampling for large datasets
@@ -1363,14 +1386,21 @@ def make_obs_var(
1363
1386
rv_var .name = name
1364
1387
1365
1388
rv_var .tag .observations = data
1366
- self .create_value_var (rv_var , transform = None , value_var = data )
1389
+ self .create_value_var (
1390
+ rv_var , transform = transform , default_transform = None , value_var = data
1391
+ )
1367
1392
self .add_named_variable (rv_var , dims )
1368
1393
self .observed_RVs .append (rv_var )
1369
1394
1370
1395
return rv_var
1371
1396
1372
1397
def create_value_var (
1373
- self , rv_var : TensorVariable , transform : Any , value_var : Variable | None = None
1398
+ self ,
1399
+ rv_var : TensorVariable ,
1400
+ * ,
1401
+ default_transform : Transform ,
1402
+ transform : Transform ,
1403
+ value_var : Variable | None = None ,
1374
1404
) -> TensorVariable :
1375
1405
"""Create a ``TensorVariable`` that will be used as the random
1376
1406
variable's "value" in log-likelihood graphs.
@@ -1385,7 +1415,11 @@ def create_value_var(
1385
1415
----------
1386
1416
rv_var : TensorVariable
1387
1417
1388
- transform : Any
1418
+ default_transform: Transform
1419
+ A transform for the random variable in log-likelihood space.
1420
+
1421
+ transform: Transform
1422
+ Additional transform which may be applied after default transform.
1389
1423
1390
1424
value_var : Variable, optional
1391
1425
@@ -1396,11 +1430,25 @@ def create_value_var(
1396
1430
1397
1431
# Make the value variable a transformed value variable,
1398
1432
# if there's an applicable transform
1399
- if transform is UNSET :
1433
+ if transform is None and default_transform is UNSET :
1434
+ default_transform = None
1435
+ warnings .warn (
1436
+ "To disable default transform, please use default_transform=None"
1437
+ " instead of transform=None. Setting transform to None will"
1438
+ " not have any effect in future." ,
1439
+ UserWarning ,
1440
+ )
1441
+
1442
+ if default_transform is UNSET :
1400
1443
if rv_var .owner is None :
1401
- transform = None
1444
+ default_transform = None
1402
1445
else :
1403
- transform = _default_transform (rv_var .owner .op , rv_var )
1446
+ default_transform = _default_transform (rv_var .owner .op , rv_var )
1447
+
1448
+ if transform is UNSET :
1449
+ transform = default_transform
1450
+ elif transform is not None and default_transform is not None :
1451
+ transform = ChainedTransform ([default_transform , transform ])
1404
1452
1405
1453
if value_var is None :
1406
1454
if transform is None :
0 commit comments