
ENH: Add a way to specify mode for check_start_vals #7481

@lucianopaz

Description


Before

No response

After

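```python
import pymc as pm

with pm.Model() as m:
    ...

# Proposed API: choose the compilation mode used for the start-value check
m.check_start_vals(..., mode="JAX")
```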

Context for the issue:

Sometimes we deploy PyMC in environments where we only care about GPU sampling via JAX, so there shouldn't be a need for a C compiler. This works fine for `sample_prior_predictive`, `sample` and `sample_posterior_predictive`, because each of them can run using JAX or any other mode that doesn't need C. The problem comes when `sample` calls `model.check_start_vals`: that call uses the default mode, and in the absence of a compiler the default mode falls back to pure-Python code, which can take a very long time to run. It would be great to have a way to set the compilation mode for `check_start_vals` as well.
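To make the request concrete, here is a toy sketch of the plumbing it asks for: the start-value check accepts an optional `mode` and hands it to the compile step, and the sampling entry point forwards its own compilation options to that check so both use the same backend. Everything in it (`ToyModel`, `toy_compile`, `toy_sample`, `DEFAULT_MODE`, and the `"C"`/`"JAX"` strings) is hypothetical; nothing here calls PyMC or PyTensor, and the "compiler" only records which mode it was given.

```python
import math
from typing import Callable, Dict, List, Optional

Point = Dict[str, float]
LogpTerm = Callable[[Point], float]

# Hypothetical default mode; stands in for a backend that needs a C compiler.
DEFAULT_MODE = "C"


class ToyCompiled:
    """Toy stand-in for a compiled function: remembers its mode, then calls fn."""

    def __init__(self, fn: LogpTerm, mode: str) -> None:
        self.fn = fn
        self.mode = mode

    def __call__(self, point: Point) -> float:
        return self.fn(point)


def toy_compile(fn: LogpTerm, mode: Optional[str] = None) -> ToyCompiled:
    # Use the default mode only when the caller did not choose one.
    return ToyCompiled(fn, mode if mode is not None else DEFAULT_MODE)


class ToyModel:
    def __init__(self, logp_terms: Dict[str, LogpTerm]) -> None:
        self.logp_terms = logp_terms
        self.modes_used: List[str] = []  # records the mode of every compile

    def check_start_vals(self, start: Point, mode: Optional[str] = None) -> None:
        """Raise ValueError if any log-probability term is non-finite at `start`."""
        bad = {}
        for name, term in self.logp_terms.items():
            compiled = toy_compile(term, mode=mode)
            self.modes_used.append(compiled.mode)
            value = compiled(start)
            if not math.isfinite(value):
                bad[name] = value
        if bad:
            raise ValueError(f"Bad initial energy: {bad}")


def toy_sample(model: ToyModel, start: Point,
               compile_kwargs: Optional[dict] = None) -> str:
    # Forward the caller's compilation options to the start-value check too.
    model.check_start_vals(start, **(compile_kwargs or {}))
    return "sampling would start here"


def normal_logp(point: Point) -> float:
    # Log-density of a standard normal evaluated at point["x"].
    x = point["x"]
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


m = ToyModel({"x": normal_logp})
toy_sample(m, {"x": 0.0}, compile_kwargs={"mode": "JAX"})
print(m.modes_used)  # ['JAX']
```

Defaulting `mode` to `None` and resolving it inside the compile step keeps existing calls unchanged, while letting users without a C compiler opt into another backend for the check as well as for sampling.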

Metadata


- Assignees: No one assigned
- Type: No type
- Projects: No projects
- Milestone: No milestone
- Relationships: None yet
- Development: No branches or pull requests
