Skip to content

Commit e2ca250

Browse files
authored
GPU support in BudgetOptimizer (#1937)
* Introduce jax_backend boolean to enable jax-based optimizer computation * Jaxify _compiled_functions * Add extensive jax_backend tests * BudgetOptimizer now accepts compile_kwargs dict * Adjust wrapper to BudgetOptimizer signature * Update tests
1 parent a49cba6 commit e2ca250

File tree

3 files changed

+105
-17
lines changed

3 files changed

+105
-17
lines changed

pymc_marketing/mmm/budget_optimizer.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,11 @@ class BudgetOptimizer(BaseModel):
363363
),
364364
)
365365

366+
compile_kwargs: dict | None = Field(
367+
default=None,
368+
description="Keyword arguments for the model compilation. Specially usefull to pass compilation mode",
369+
)
370+
366371
model_config = ConfigDict(arbitrary_types_allowed=True)
367372

368373
DEFAULT_MINIMIZE_KWARGS: ClassVar[dict] = {
@@ -684,10 +689,24 @@ def _compile_objective_and_grad(self):
684689
)
685690
objective_grad = pt.grad(rewrite_pregrad(objective), budgets_flat)
686691

687-
objective_and_grad_func = function([budgets_flat], [objective, objective_grad])
692+
if self.compile_kwargs and (self.compile_kwargs["mode"]).lower() == "jax":
693+
# Use PyMC's JAX infrastructure for robust compilation
694+
from pymc.sampling.jax import get_jaxified_graph
695+
696+
objective_and_grad_func = get_jaxified_graph(
697+
inputs=[budgets_flat],
698+
outputs=[objective, objective_grad],
699+
**{k: v for k, v in self.compile_kwargs.items() if k != "mode"} or {},
700+
)
701+
else:
702+
# Standard PyTensor compilation
703+
objective_and_grad_func = function(
704+
[budgets_flat], [objective, objective_grad], **self.compile_kwargs or {}
705+
)
688706

689707
# Avoid repeated input validation for performance
690-
objective_and_grad_func.trust_input = True
708+
if hasattr(objective_and_grad_func, "trust_input"):
709+
objective_and_grad_func.trust_input = True
691710

692711
self._compiled_functions[self.utility_function] = {
693712
"objective_and_grad": objective_and_grad_func,

pymc_marketing/mmm/multidimensional.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,13 @@ def create_sample_kwargs(
21462146
class MultiDimensionalBudgetOptimizerWrapper(OptimizerCompatibleModelWrapper):
21472147
"""Wrapper for the BudgetOptimizer to handle multi-dimensional model."""
21482148

2149-
def __init__(self, model: MMM, start_date: str, end_date: str):
2149+
def __init__(
2150+
self,
2151+
model: MMM,
2152+
start_date: str,
2153+
end_date: str,
2154+
compile_kwargs: dict | None = None,
2155+
):
21502156
self.model_class = model
21512157
self.start_date = start_date
21522158
self.end_date = end_date
@@ -2158,6 +2164,7 @@ def __init__(self, model: MMM, start_date: str, end_date: str):
21582164
include_carryover=False,
21592165
)
21602166
self.num_periods = len(self.zero_data[self.model_class.date_column].unique())
2167+
self.compile_kwargs = compile_kwargs
21612168
# Adding missing dependencies for compatibility with BudgetOptimizer
21622169
self._channel_scales = 1.0
21632170

@@ -2256,6 +2263,7 @@ def optimize_budget(
22562263
budgets_to_optimize=budgets_to_optimize,
22572264
budget_distribution_over_period=budget_distribution_over_period,
22582265
model=self, # Pass the wrapper instance itself to the BudgetOptimizer
2266+
compile_kwargs=self.compile_kwargs,
22592267
)
22602268

22612269
return allocator.allocate_budget(

0 commit comments

Comments
 (0)