@@ -363,6 +363,11 @@ class BudgetOptimizer(BaseModel):
363
363
),
364
364
)
365
365
366
+ compile_kwargs : dict | None = Field (
367
+ default = None ,
368
+ description = "Keyword arguments for the model compilation. Specially usefull to pass compilation mode" ,
369
+ )
370
+
366
371
model_config = ConfigDict (arbitrary_types_allowed = True )
367
372
368
373
DEFAULT_MINIMIZE_KWARGS : ClassVar [dict ] = {
@@ -684,10 +689,24 @@ def _compile_objective_and_grad(self):
684
689
)
685
690
objective_grad = pt .grad (rewrite_pregrad (objective ), budgets_flat )
686
691
687
- objective_and_grad_func = function ([budgets_flat ], [objective , objective_grad ])
692
+ if self .compile_kwargs and (self .compile_kwargs ["mode" ]).lower () == "jax" :
693
+ # Use PyMC's JAX infrastructure for robust compilation
694
+ from pymc .sampling .jax import get_jaxified_graph
695
+
696
+ objective_and_grad_func = get_jaxified_graph (
697
+ inputs = [budgets_flat ],
698
+ outputs = [objective , objective_grad ],
699
+ ** {k : v for k , v in self .compile_kwargs .items () if k != "mode" } or {},
700
+ )
701
+ else :
702
+ # Standard PyTensor compilation
703
+ objective_and_grad_func = function (
704
+ [budgets_flat ], [objective , objective_grad ], ** self .compile_kwargs or {}
705
+ )
688
706
689
707
# Avoid repeated input validation for performance
690
- objective_and_grad_func .trust_input = True
708
+ if hasattr (objective_and_grad_func , "trust_input" ):
709
+ objective_and_grad_func .trust_input = True
691
710
692
711
self ._compiled_functions [self .utility_function ] = {
693
712
"objective_and_grad" : objective_and_grad_func ,
0 commit comments