Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 30 additions & 25 deletions pymc_marketing/mmm/additive_effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ def set_data(self, mmm, model, X):
- In `set_data`, update the data variables when dates/dims change.
"""

from typing import Any, Protocol
from abc import ABC, abstractmethod
from typing import Annotated, Any, Protocol

import pandas as pd
import pymc as pm
import xarray as xr
from pydantic import BaseModel, InstanceOf
from pydantic import BaseModel, Field, InstanceOf, PlainValidator, WithJsonSchema
from pymc_extras.prior import create_dim_handler
from pytensor import tensor as pt

Expand All @@ -131,35 +132,31 @@ def model(self) -> pm.Model:
"""The PyMC model."""


class MuEffect(Protocol):
"""Protocol for arbitrary additive mu effect."""
class MuEffect(ABC, BaseModel):
"""Abstract base class for arbitrary additive mu effects.

All mu_effects must inherit from this Pydantic BaseModel to ensure proper
serialization and deserialization when saving/loading MMM models.
"""

@abstractmethod
def create_data(self, mmm: Model) -> None:
"""Create the required data in the model."""

@abstractmethod
def create_effect(self, mmm: Model) -> pt.TensorVariable:
"""Create the additive effect in the model."""

@abstractmethod
def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
"""Set the data for new predictions."""


class FourierEffect:
class FourierEffect(MuEffect):
"""Fourier seasonality additive effect for MMM."""

def __init__(self, fourier: FourierBase, date_dim_name: str = "date"):
"""Initialize the Fourier effect.

Parameters
----------
fourier : FourierBase
The FourierBase instance to use for the effect.
date_dim_name : str, optional
The name of the date dimension in the model, by default "date".

"""
self.fourier = fourier
self.date_dim_name: str = date_dim_name
fourier: InstanceOf[FourierBase]
date_dim_name: str = Field("date")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to supply the date_dim_name here, as the mutltidimensional MMM has the date_column attribute?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, because MuEffect is just an interface that has nothing to do with an MMM class.
So either we have this explicit date_dim_name or we need to introduce coupling with the MMM class.


def create_data(self, mmm: Model) -> None:
"""Create the required data in the model.
Expand Down Expand Up @@ -256,7 +253,14 @@ def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
pm.set_data(new_data=new_data, model=model)


class LinearTrendEffect:
_Timestamp = Annotated[
pd.Timestamp,
PlainValidator(lambda x: pd.Timestamp(x)),
WithJsonSchema({"type": "date-time"}),
]


class LinearTrendEffect(MuEffect):
"""Wrapper for LinearTrend to use with MMM's MuEffect protocol.

This class adapts the LinearTrend component to be used as an additive effect
Expand All @@ -268,6 +272,8 @@ class LinearTrendEffect:
The LinearTrend instance to wrap.
prefix : str
The prefix to use for variables in the model.
date_dim_name : str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same date_dim_name question as with FourierEffect.

The name of the date dimension in the model.

Examples
--------
Expand Down Expand Up @@ -357,11 +363,10 @@ class MockMMM:

"""

def __init__(self, trend: LinearTrend, prefix: str, date_dim_name: str = "date"):
self.trend = trend
self.prefix = prefix
self.linear_trend_first_date: pd.Timestamp
self.date_dim_name: str = date_dim_name
trend: InstanceOf[LinearTrend]
prefix: str
date_dim_name: str = Field("date")
linear_trend_first_date: _Timestamp | None = Field(None, init=False)

def create_data(self, mmm: Model) -> None:
"""Create the required data in the model.
Expand Down Expand Up @@ -439,7 +444,7 @@ def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None:
pm.set_data({f"{self.prefix}_t": t}, model=model)


class EventAdditiveEffect(BaseModel):
class EventAdditiveEffect(MuEffect):
"""Event effect class for the MMM.

Parameters
Expand Down
6 changes: 3 additions & 3 deletions tests/mmm/test_additive_effect.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_fourier_effect(
dims,
coords,
) -> None:
effect = FourierEffect(fourier)
effect = FourierEffect(fourier=fourier)

mmm = create_mock_mmm(
dims=dims,
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_fourier_effect_multidimensional(
prefix = "weekly"
prior = Prior("Laplace", mu=0, b=0.1, dims=prior_dims)
fourier = WeeklyFourier(n_order=10, prefix=prefix, prior=prior)
fourier_effect = FourierEffect(fourier)
fourier_effect = FourierEffect(fourier=fourier)

with mmm.model:
fourier_effect.create_data(mmm)
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_linear_trend_effect(
) -> None:
prefix = "linear_trend"
effect = LinearTrendEffect(
LinearTrend(priors=priors, dims=linear_trend_dims),
trend=LinearTrend(priors=priors, dims=linear_trend_dims),
prefix=prefix,
)

Expand Down
Loading