2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_build_from_yml_example.ipynb
@@ -3105,7 +3105,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
"version": "3.12.11"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_case_study.ipynb
@@ -9892,7 +9892,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
"version": "3.12.11"
}
},
"nbformat": 4,
1,636 changes: 932 additions & 704 deletions docs/source/notebooks/mmm/mmm_components.ipynb

Large diffs are not rendered by default.

211 changes: 193 additions & 18 deletions pymc_marketing/mmm/budget_optimizer.py
@@ -204,7 +204,7 @@ def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:

import warnings
from collections.abc import Sequence
-from typing import Any, ClassVar, Protocol, runtime_checkable
+from typing import Any, ClassVar, Protocol, cast, runtime_checkable

import numpy as np
import pytensor.tensor as pt
@@ -225,7 +225,7 @@ def _set_predictors_for_optimization(self, num_periods: int) -> pm.Model:
compile_constraints_for_scipy,
)
from pymc_marketing.mmm.utility import UtilityFunctionType, average_response
-from pymc_marketing.pytensor_utils import extract_response_distribution
+from pymc_marketing.pytensor_utils import extract_response_distribution, merge_models


def optimizer_xarray_builder(value, **kwargs):
@@ -394,24 +394,36 @@ def __init__(self, **data):
        self._budget_shape = tuple(len(coord) for coord in self._budget_coords.values())

        # 4. Ensure that we only optimize over non-zero channels
-        if self.budgets_to_optimize is None:
-            # If no mask is provided, we optimize all channels
-            self.budgets_to_optimize = (
-                self.mmm_model.idata.posterior.channel_contribution.mean(
-                    ("chain", "draw", "date")
-                ).astype(bool)
-            )
-        else:
-            # If a mask is provided, ensure it has the correct shape
-            expected_mask = self.mmm_model.idata.posterior.channel_contribution.mean(
-                ("chain", "draw", "date")
-            ).astype(bool)
-
-            # Check if we are asking to optimize over channels that are not present in the model
-            if np.any(self.budgets_to_optimize.values > expected_mask.values):
-                raise ValueError(
-                    "budgets_to_optimize mask contains True values at coordinates where the model has no "
-                    "information."
-                )
+        def _should_enforce_mask_validation() -> bool:
+            # Duck-typing: allow wrappers to opt-in by providing this attribute
+            return bool(
+                getattr(self.mmm_model, "enforce_budget_mask_validation", False)
+            )
+
+        if _should_enforce_mask_validation():
+            if self.budgets_to_optimize is None:
+                self.budgets_to_optimize = (
+                    self.mmm_model.idata.posterior.channel_contribution.mean(
+                        ("chain", "draw", "date")
+                    ).astype(bool)
+                )
+            else:
+                expected_mask = (
+                    self.mmm_model.idata.posterior.channel_contribution.mean(
+                        ("chain", "draw", "date")
+                    ).astype(bool)
+                )
+                if np.any(self.budgets_to_optimize.values > expected_mask.values):
+                    raise ValueError(
+                        "budgets_to_optimize mask contains True values at coordinates "
+                        "where the model has no information."
+                    )
+        else:
+            if self.budgets_to_optimize is None:
+                self.budgets_to_optimize = xr.DataArray(
+                    np.ones(self._budget_shape, dtype=bool),
+                    dims=self._budget_dims,
+                    coords=self._budget_coords,
+                )

        size_budgets = self.budgets_to_optimize.sum().item()
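Note: the mask handling above is now opt-in. `BudgetOptimizer` only validates `budgets_to_optimize` against the posterior's `channel_contribution` when the wrapped model exposes `enforce_budget_mask_validation = True`; otherwise a missing mask defaults to the full budget grid. A minimal sketch of building an explicit mask follows; the dims, coords, and the `wrapper` object are illustrative placeholders, not part of this diff:

```python
import numpy as np
import xarray as xr

# Boolean mask over the budget dims; True marks cells the optimizer may vary.
budgets_to_optimize = xr.DataArray(
    np.array([[True, True, False], [True, False, True]]),
    dims=("geo", "channel"),
    coords={"geo": ["north", "south"], "channel": ["tv", "radio", "social"]},
)

# With a wrapper that sets enforce_budget_mask_validation = True, any True cell
# at a coordinate the model has no information about raises ValueError.
# (Constructor call sketched and commented out; check the class signature.)
# optimizer = BudgetOptimizer(model=wrapper, num_periods=8,
#                             budgets_to_optimize=budgets_to_optimize)
```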
@@ -893,3 +905,166 @@ def track_progress(xk):

else:
raise MinimizeException(f"Optimization failed: {result.message}")


class MultiModelWrapper(OptimizerCompatibleModelWrapper):
"""Wrapper that merges multiple OptimizerCompatibleModelWrapper models.

- Keeps a persistent merged model for optimization via `_set_predictors_for_optimization`.
    - Exposes a dynamically merged `model` property for inspection (non-persistent).
"""

def __init__(
self,
models: list[OptimizerCompatibleModelWrapper],
prefixes: list[str] | None = None,
merge_on: str | None = "channel_data",
use_every_n_draw: int = 1,
) -> None:
        if not models:
            raise ValueError("At least one model is required.")

self._channel_scales = 1.0
self.models = models
self.num_models = len(models)

# Auto-generate prefixes if not provided - ALL models get prefixes
if prefixes is None:
self.prefixes = [f"model{i + 1}" for i in range(self.num_models)]
else:
if len(prefixes) != len(models):
raise ValueError(
f"Number of prefixes ({len(prefixes)}) must match number of models ({len(models)})"
)
self.prefixes = prefixes

self.merge_on = merge_on
self.use_every_n_draw = use_every_n_draw

# Use first model as primary for attributes
self.primary_model = models[0]
self.num_periods = getattr(self.primary_model, "num_periods", None)

# Merge idata from all models with appropriate prefixes
self._merge_idata()

if hasattr(self.primary_model, "adstock"):
self.adstock = self.primary_model.adstock

        # Merged wrappers do not opt in to BudgetOptimizer's mask validation
        self.enforce_budget_mask_validation = False

# Persistent merged model used for optimization
self._persistent_merged_model: Model | None = None
self._persistent_num_periods: int | None = None

def _merge_idata(self) -> None:
if self.num_models == 1:
idata = self.models[0].idata.isel(
draw=slice(None, None, self.use_every_n_draw)
)
if self.prefixes[0]:
idata = self._prefix_idata(idata, self.prefixes[0])
self.idata = idata
return

merged_idata = None
for i, model in enumerate(self.models):
prefix = self.prefixes[i]
idata_i = model.idata.isel(
draw=slice(None, None, self.use_every_n_draw)
).copy()
if prefix:
idata_i = self._prefix_idata(idata_i, prefix)

if merged_idata is None:
merged_idata = idata_i
else:
for group in ("posterior", "constant_data", "observed_data"):
if group in idata_i:
if group in merged_idata:
merged_idata[group] = xr.merge(
[merged_idata[group], idata_i[group]]
)
else:
merged_idata[group] = idata_i[group]

self.idata = merged_idata

def _prefix_idata(self, idata, prefix: str):
shared_vars = {"chain", "draw", "__obs__"}
if self.merge_on:
shared_vars.add(self.merge_on)

shared_dims = set(shared_vars)
if (
self.merge_on
and "constant_data" in idata
and self.merge_on in idata.constant_data
):
merge_dims = list(idata.constant_data[self.merge_on].dims)
shared_dims.update(merge_dims)

prefixed_idata = idata.copy()
for group in ("posterior", "constant_data", "observed_data"):
if group in prefixed_idata:
rename_dict = {}
for var in prefixed_idata[group].data_vars:
if var not in shared_vars and not var.startswith(f"{prefix}_"):
rename_dict[var] = f"{prefix}_{var}"
for dim in prefixed_idata[group].dims:
if dim not in shared_dims and not dim.startswith(f"{prefix}_"):
rename_dict[dim] = f"{prefix}_{dim}"
if rename_dict:
prefixed_idata[group] = prefixed_idata[group].rename(rename_dict)

return prefixed_idata

def _set_predictors_for_optimization(self, num_periods: int) -> Model:
# If we already built a persistent model for this horizon, reuse it
if (
self._persistent_merged_model is not None
and self._persistent_num_periods == int(num_periods)
):
return self._persistent_merged_model

# Build per-model optimization models
pymc_models = [
m._set_predictors_for_optimization(num_periods=num_periods)
for m in self.models
]
if self.num_models == 1:
self._persistent_merged_model = freeze_dims_and_data(pymc_models[0])
else:
self._persistent_merged_model = merge_models(
models=pymc_models, prefixes=self.prefixes, merge_on=self.merge_on
)

self._persistent_num_periods = int(num_periods)
return self._persistent_merged_model

@property
def model(self) -> Model:
"""Return the merged PyMC model.

If a persistent optimization model exists, return it. Otherwise, try to lazily
construct it using the known number of periods. As a fallback, merge the
underlying training models from each wrapper (non-persistent).
"""
# If a persistent optimization model exists, expose it for mutation
if self._persistent_merged_model is not None:
return self._persistent_merged_model

# If we know the number of periods, lazily build the persistent model now
if self.num_periods is not None:
return self._set_predictors_for_optimization(int(self.num_periods))

# Fallback: dynamic merged training models (non-persistent)
# Obtain each wrapper's training model dynamically; not all wrappers statically expose `.model`.
# Cast to Any first to avoid mypy attr-defined errors for Protocol wrappers.
pymc_models = [cast(Any, model).model for model in self.models]
if self.num_models == 1:
return pymc_models[0]
return merge_models(
models=pymc_models, prefixes=self.prefixes, merge_on=self.merge_on
)
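For context, a minimal usage sketch of the wrapper. Hedged: `mmm_a` and `mmm_b` stand in for two already-fitted, optimizer-compatible wrappers and are not defined in this diff:

```python
from pymc_marketing.mmm.budget_optimizer import MultiModelWrapper

combined = MultiModelWrapper(
    models=[mmm_a, mmm_b],
    prefixes=["brand", "performance"],  # one prefix per model
    merge_on="channel_data",            # shared input variable both models consume
    use_every_n_draw=10,                # thin posterior draws to keep merged idata small
)

# Non-shared variables and dims are prefixed per model, e.g.
# "brand_channel_contribution" vs "performance_channel_contribution",
# while chain/draw and the merge_on dims remain shared.
list(combined.idata.posterior.data_vars)
```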
18 changes: 18 additions & 0 deletions pymc_marketing/mmm/multidimensional.py
@@ -205,6 +205,7 @@
)
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.model_graph import deterministics_to_flat
+from pymc_marketing.pytensor_utils import MaskedDist

PYMC_MARKETING_ISSUE = "https://github.com/pymc-labs/pymc-marketing/issues/new"
warning_msg = (
@@ -1489,6 +1490,20 @@ def _posterior_predictive_data_transformation(
X, include_last_observations
)

+        # If the model likelihood was masked during training, require that OOS inputs
+        # start at the same minimum date as the training data. Supplying only a test
+        # horizon (i.e., a later min date) is not compatible with masked likelihoods.
+        likelihood_cfg = self.model_config.get("likelihood")
+        if isinstance(likelihood_cfg, MaskedDist):
+            training_min = pd.to_datetime(self.model_coords["date"]).min()
+            input_min = pd.to_datetime(X[self.date_column]).min()
+            if pd.Timestamp(input_min) != pd.Timestamp(training_min):
+                raise ValueError(
+                    "Out-of-sample with masked likelihood requires X to start at the training min date; "
+                    f"got {pd.Timestamp(input_min).date()} != {pd.Timestamp(training_min).date()}. "
+                    "Provide full X from training start or use an unmasked likelihood."
+                )

dataarrays = []
if include_last_observations:
last_obs = self.xarray_dataset.isel(date=slice(-self.adstock.l_max, None))
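A short illustration of the new guard. Sketch only: `mmm` fitted with a `MaskedDist` likelihood, a `"date"` date column, and the frames `X_full`/`X_test` are assumed, not part of this diff:

```python
import pandas as pd

train_start = pd.Timestamp("2023-01-02")  # min date the model was trained on

# Starts after the training min date -> the check above raises ValueError:
X_bad = X_test[pd.to_datetime(X_test["date"]) >= "2024-01-01"]

# Starts at the training min date -> accepted; pass the full history and
# slice the predictions to the horizon you care about afterwards:
X_good = X_full[pd.to_datetime(X_full["date"]) >= train_start]
# mmm.sample_posterior_predictive(X_good)
```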
@@ -2146,6 +2161,9 @@ def create_sample_kwargs(
class MultiDimensionalBudgetOptimizerWrapper(OptimizerCompatibleModelWrapper):
"""Wrapper for the BudgetOptimizer to handle multi-dimensional model."""

+    # Signal to BudgetOptimizer that this wrapper should enforce budget mask validation
+    enforce_budget_mask_validation = True

def __init__(self, model: MMM, start_date: str, end_date: str):
self.model_class = model
self.start_date = start_date
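The flag is consumed by duck typing rather than by the protocol itself, so any wrapper can opt in (or stay silent) without changing `OptimizerCompatibleModelWrapper`. A self-contained sketch of the lookup `BudgetOptimizer` performs (class names here are hypothetical):

```python
class _OptedIn:
    enforce_budget_mask_validation = True  # class-level flag, as in the wrapper above

class _Silent:
    pass  # no attribute -> validation is skipped

# Mirrors the getattr call in BudgetOptimizer.__init__:
assert bool(getattr(_OptedIn(), "enforce_budget_mask_validation", False)) is True
assert bool(getattr(_Silent(), "enforce_budget_mask_validation", False)) is False
```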