Skip to content

Commit de748b6

Browse files
Always return dictionary from data_info
1 parent 07c6ab4 commit de748b6

File tree

2 files changed

+104
-38
lines changed

2 files changed

+104
-38
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ def __init__(
254254
mode=mode,
255255
)
256256

257+
self._needs_exog_data = exog_state_names is not None and len(exog_state_names) > 0
258+
257259
# Save counts of the number of parameters in each category
258260
self.param_counts = {
259261
"x0": k_states * (1 - self.stationary_initialization),
@@ -337,7 +339,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
337339

338340
@property
339341
def data_info(self) -> dict[str, dict[str, Any]]:
340-
info = None
342+
info = {}
341343

342344
if isinstance(self.exog_state_names, list):
343345
info = {

tests/statespace/models/test_VARMAX.py

Lines changed: 101 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,14 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
188188
def test_impulse_response(parameters, varma_mod, idata, rng):
189189
irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters)
190190

191-
assert not np.any(np.isnan(irf.irf.values))
191+
assert np.isfinite(irf.irf.values).all()
192+
193+
194+
def test_forecast(varma_mod, idata, rng):
195+
forecast = varma_mod.forecast(idata.prior, periods=10, random_seed=rng)
196+
197+
assert np.isfinite(forecast.forecast_latent.values).all()
198+
assert np.isfinite(forecast.forecast_observed.values).all()
192199

193200

194201
class TestVARMAXWithExogenous:
@@ -436,42 +443,8 @@ def test_create_varmax_with_exogenous_raises_if_args_disagree(self, data):
436443
stationary_initialization=False,
437444
)
438445

439-
@pytest.mark.parametrize(
440-
"k_exog, exog_state_names",
441-
[
442-
(2, None),
443-
(None, ["foo", "bar"]),
444-
(None, {"y1": ["a", "b"], "y2": ["c"]}),
445-
],
446-
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
447-
)
448-
@pytest.mark.filterwarnings("ignore::UserWarning")
449-
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
450-
endog_names = ["y1", "y2", "y3"]
451-
n_obs = 50
452-
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
453-
454-
y = rng.normal(size=(n_obs, len(endog_names)))
455-
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)
456-
457-
if isinstance(exog_state_names, dict):
458-
exog_data = {
459-
f"{name}_exogenous_data": pd.DataFrame(
460-
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
461-
columns=exog_names,
462-
index=time_idx,
463-
)
464-
for name, exog_names in exog_state_names.items()
465-
}
466-
else:
467-
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
468-
exog_data = {
469-
"exogenous_data": pd.DataFrame(
470-
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
471-
columns=exog_names,
472-
index=time_idx,
473-
)
474-
}
446+
def _build_varmax(self, df, k_exog, exog_state_names, exog_data):
447+
endog_names = df.columns.values.tolist()
475448

476449
mod = BayesianVARMAX(
477450
endog_names=endog_names,
@@ -512,6 +485,47 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
512485

513486
mod.build_statespace_graph(data=df)
514487

488+
return mod, m
489+
490+
@pytest.mark.parametrize(
491+
"k_exog, exog_state_names",
492+
[
493+
(2, None),
494+
(None, ["foo", "bar"]),
495+
(None, {"y1": ["a", "b"], "y2": ["c"]}),
496+
],
497+
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
498+
)
499+
@pytest.mark.filterwarnings("ignore::UserWarning")
500+
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
501+
endog_names = ["y1", "y2", "y3"]
502+
n_obs = 50
503+
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
504+
505+
y = rng.normal(size=(n_obs, len(endog_names)))
506+
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)
507+
508+
if isinstance(exog_state_names, dict):
509+
exog_data = {
510+
f"{name}_exogenous_data": pd.DataFrame(
511+
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
512+
columns=exog_names,
513+
index=time_idx,
514+
)
515+
for name, exog_names in exog_state_names.items()
516+
}
517+
else:
518+
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
519+
exog_data = {
520+
"exogenous_data": pd.DataFrame(
521+
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
522+
columns=exog_names,
523+
index=time_idx,
524+
)
525+
}
526+
527+
mod, m = self._build_varmax(df, k_exog, exog_state_names, exog_data)
528+
515529
with freeze_dims_and_data(m):
516530
prior = pm.sample_prior_predictive(
517531
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
@@ -543,3 +557,53 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
543557
obs_intercept.append(np.zeros_like(obs_intercept[0]))
544558

545559
np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)
560+
561+
@pytest.mark.filterwarnings("ignore::UserWarning")
562+
def test_forecast_with_exog(self, rng):
563+
endog_names = ["y1", "y2", "y3"]
564+
n_obs = 50
565+
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
566+
567+
y = rng.normal(size=(n_obs, len(endog_names)))
568+
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)
569+
570+
mod, m = self._build_varmax(
571+
df,
572+
k_exog=2,
573+
exog_state_names=None,
574+
exog_data={
575+
"exogenous_data": pd.DataFrame(
576+
rng.normal(size=(n_obs, 2)).astype(floatX),
577+
columns=["exogenous_0", "exogenous_1"],
578+
index=time_idx,
579+
)
580+
},
581+
)
582+
583+
with freeze_dims_and_data(m):
584+
prior = pm.sample_prior_predictive(
585+
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
586+
)
587+
588+
with pytest.raises(
589+
ValueError,
590+
match="This model was fit using exogenous data. Forecasting cannot be performed "
591+
"without providing scenario data",
592+
):
593+
mod.forecast(prior.prior, periods=10, random_seed=rng)
594+
595+
forecast = mod.forecast(
596+
prior.prior,
597+
periods=10,
598+
random_seed=rng,
599+
scenario={
600+
"exogenous_data": pd.DataFrame(
601+
rng.normal(size=(10, 2)).astype(floatX),
602+
columns=["exogenous_0", "exogenous_1"],
603+
index=pd.date_range(start=df.index[-1], periods=10, freq="D"),
604+
)
605+
},
606+
)
607+
608+
assert np.isfinite(forecast.forecast_latent.values).all()
609+
assert np.isfinite(forecast.forecast_observed.values).all()

0 commit comments

Comments
 (0)