Closed
Description
Before
No response
After
with pm.Model() as m:
    ...
m.check_start_vals(..., mode="JAX")
Context for the issue:
Sometimes we deploy pymc in environments where we only care about GPU sampling via JAX, so there shouldn't be a need for a C compiler. This mostly works, because we can run sample_prior_predictive, sample and sample_posterior_predictive using JAX or any other mode that doesn't need C. The problem comes when sample calls model.check_start_vals: this uses the default mode, and in the absence of a compiler the Python fallback code can take a very long time to run. It would be great to have a way to set the compilation mode for check_start_vals as well.