@@ -188,7 +188,14 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
188
188
def test_impulse_response (parameters , varma_mod , idata , rng ):
189
189
irf = varma_mod .impulse_response_function (idata .prior , random_seed = rng , ** parameters )
190
190
191
- assert not np .any (np .isnan (irf .irf .values ))
191
+ assert np .isfinite (irf .irf .values ).all ()
192
+
193
+
194
+ def test_forecast (varma_mod , idata , rng ):
195
+ forecast = varma_mod .forecast (idata .prior , periods = 10 , random_seed = rng )
196
+
197
+ assert np .isfinite (forecast .forecast_latent .values ).all ()
198
+ assert np .isfinite (forecast .forecast_observed .values ).all ()
192
199
193
200
194
201
class TestVARMAXWithExogenous :
@@ -436,42 +443,8 @@ def test_create_varmax_with_exogenous_raises_if_args_disagree(self, data):
436
443
stationary_initialization = False ,
437
444
)
438
445
439
- @pytest .mark .parametrize (
440
- "k_exog, exog_state_names" ,
441
- [
442
- (2 , None ),
443
- (None , ["foo" , "bar" ]),
444
- (None , {"y1" : ["a" , "b" ], "y2" : ["c" ]}),
445
- ],
446
- ids = ["k_exog_int" , "exog_state_names_list" , "exog_state_names_dict" ],
447
- )
448
- @pytest .mark .filterwarnings ("ignore::UserWarning" )
449
- def test_varmax_with_exog (self , rng , k_exog , exog_state_names ):
450
- endog_names = ["y1" , "y2" , "y3" ]
451
- n_obs = 50
452
- time_idx = pd .date_range (start = "2020-01-01" , periods = n_obs , freq = "D" )
453
-
454
- y = rng .normal (size = (n_obs , len (endog_names )))
455
- df = pd .DataFrame (y , columns = endog_names , index = time_idx ).astype (floatX )
456
-
457
- if isinstance (exog_state_names , dict ):
458
- exog_data = {
459
- f"{ name } _exogenous_data" : pd .DataFrame (
460
- rng .normal (size = (n_obs , len (exog_names ))).astype (floatX ),
461
- columns = exog_names ,
462
- index = time_idx ,
463
- )
464
- for name , exog_names in exog_state_names .items ()
465
- }
466
- else :
467
- exog_names = exog_state_names or [f"exogenous_{ i } " for i in range (k_exog )]
468
- exog_data = {
469
- "exogenous_data" : pd .DataFrame (
470
- rng .normal (size = (n_obs , k_exog or len (exog_state_names ))).astype (floatX ),
471
- columns = exog_names ,
472
- index = time_idx ,
473
- )
474
- }
446
+ def _build_varmax (self , df , k_exog , exog_state_names , exog_data ):
447
+ endog_names = df .columns .values .tolist ()
475
448
476
449
mod = BayesianVARMAX (
477
450
endog_names = endog_names ,
@@ -512,6 +485,47 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
512
485
513
486
mod .build_statespace_graph (data = df )
514
487
488
+ return mod , m
489
+
490
+ @pytest .mark .parametrize (
491
+ "k_exog, exog_state_names" ,
492
+ [
493
+ (2 , None ),
494
+ (None , ["foo" , "bar" ]),
495
+ (None , {"y1" : ["a" , "b" ], "y2" : ["c" ]}),
496
+ ],
497
+ ids = ["k_exog_int" , "exog_state_names_list" , "exog_state_names_dict" ],
498
+ )
499
+ @pytest .mark .filterwarnings ("ignore::UserWarning" )
500
+ def test_varmax_with_exog (self , rng , k_exog , exog_state_names ):
501
+ endog_names = ["y1" , "y2" , "y3" ]
502
+ n_obs = 50
503
+ time_idx = pd .date_range (start = "2020-01-01" , periods = n_obs , freq = "D" )
504
+
505
+ y = rng .normal (size = (n_obs , len (endog_names )))
506
+ df = pd .DataFrame (y , columns = endog_names , index = time_idx ).astype (floatX )
507
+
508
+ if isinstance (exog_state_names , dict ):
509
+ exog_data = {
510
+ f"{ name } _exogenous_data" : pd .DataFrame (
511
+ rng .normal (size = (n_obs , len (exog_names ))).astype (floatX ),
512
+ columns = exog_names ,
513
+ index = time_idx ,
514
+ )
515
+ for name , exog_names in exog_state_names .items ()
516
+ }
517
+ else :
518
+ exog_names = exog_state_names or [f"exogenous_{ i } " for i in range (k_exog )]
519
+ exog_data = {
520
+ "exogenous_data" : pd .DataFrame (
521
+ rng .normal (size = (n_obs , k_exog or len (exog_state_names ))).astype (floatX ),
522
+ columns = exog_names ,
523
+ index = time_idx ,
524
+ )
525
+ }
526
+
527
+ mod , m = self ._build_varmax (df , k_exog , exog_state_names , exog_data )
528
+
515
529
with freeze_dims_and_data (m ):
516
530
prior = pm .sample_prior_predictive (
517
531
draws = 10 , random_seed = rng , compile_kwargs = {"mode" : "JAX" }
@@ -543,3 +557,53 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
543
557
obs_intercept .append (np .zeros_like (obs_intercept [0 ]))
544
558
545
559
np .testing .assert_allclose (beta_dot_data , np .stack (obs_intercept , axis = - 1 ), atol = 1e-2 )
560
+
561
+ @pytest .mark .filterwarnings ("ignore::UserWarning" )
562
+ def test_forecast_with_exog (self , rng ):
563
+ endog_names = ["y1" , "y2" , "y3" ]
564
+ n_obs = 50
565
+ time_idx = pd .date_range (start = "2020-01-01" , periods = n_obs , freq = "D" )
566
+
567
+ y = rng .normal (size = (n_obs , len (endog_names )))
568
+ df = pd .DataFrame (y , columns = endog_names , index = time_idx ).astype (floatX )
569
+
570
+ mod , m = self ._build_varmax (
571
+ df ,
572
+ k_exog = 2 ,
573
+ exog_state_names = None ,
574
+ exog_data = {
575
+ "exogenous_data" : pd .DataFrame (
576
+ rng .normal (size = (n_obs , 2 )).astype (floatX ),
577
+ columns = ["exogenous_0" , "exogenous_1" ],
578
+ index = time_idx ,
579
+ )
580
+ },
581
+ )
582
+
583
+ with freeze_dims_and_data (m ):
584
+ prior = pm .sample_prior_predictive (
585
+ draws = 10 , random_seed = rng , compile_kwargs = {"mode" : "JAX" }
586
+ )
587
+
588
+ with pytest .raises (
589
+ ValueError ,
590
+ match = "This model was fit using exogenous data. Forecasting cannot be performed "
591
+ "without providing scenario data" ,
592
+ ):
593
+ mod .forecast (prior .prior , periods = 10 , random_seed = rng )
594
+
595
+ forecast = mod .forecast (
596
+ prior .prior ,
597
+ periods = 10 ,
598
+ random_seed = rng ,
599
+ scenario = {
600
+ "exogenous_data" : pd .DataFrame (
601
+ rng .normal (size = (10 , 2 )).astype (floatX ),
602
+ columns = ["exogenous_0" , "exogenous_1" ],
603
+ index = pd .date_range (start = df .index [- 1 ], periods = 10 , freq = "D" ),
604
+ )
605
+ },
606
+ )
607
+
608
+ assert np .isfinite (forecast .forecast_latent .values ).all ()
609
+ assert np .isfinite (forecast .forecast_observed .values ).all ()
0 commit comments