449 changes: 216 additions & 233 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -31,8 +31,8 @@ rayon = "1.9.0"
 # Keep arrow in sync with nuts-rs requirements
 arrow = { version = "52.0.0", default-features = false, features = ["ffi"] }
 anyhow = "1.0.72"
-itertools = "0.13.0"
-bridgestan = "2.5.0"
+itertools = "0.14.0"
+bridgestan = "2.6.1"
 rand_distr = "0.4.3"
 smallvec = "1.11.0"
 upon = { version = "0.8.1", default-features = false, features = [] }
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ name = "nutpie"
 description = "Sample Stan or PyMC models"
 authors = [{ name = "PyMC Developers", email = "[email protected]" }]
 readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.10,<3.13"
 license = { text = "MIT" }
 classifiers = [
     "Programming Language :: Rust",
@@ -26,6 +26,7 @@ dependencies = [
     "xarray >= 2023.06.0",
     "arviz >= 0.15.0",
 ]
+dynamic = ["version"]

 [project.optional-dependencies]
 stan = ["bridgestan >= 2.4.1"]
2 changes: 1 addition & 1 deletion python/nutpie/__init__.py
@@ -4,4 +4,4 @@
 from nutpie.sample import sample

 __version__: str = _lib.__version__
-__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]
+__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"]
23 changes: 21 additions & 2 deletions python/nutpie/compile_pymc.py
@@ -1,6 +1,7 @@
 import dataclasses
 import itertools
 import warnings
+from collections.abc import Iterable
 from dataclasses import dataclass
 from functools import wraps
 from importlib.util import find_spec
@@ -69,7 +70,7 @@ def seeded_array_fn(seed: SeedType = None):
        for name, shape in zip(names, shapes, strict=True):
            initial_value = initial_value_dict[name]
            n = int(np.prod(initial_value.shape))
-           if initial_value.shape != shape:
+           if tuple(initial_value.shape) != tuple(shape):
                raise ValueError(
                    f"Size of initial value for {name} is {initial_value.shape}, "
                    f"expected {shape}"
@@ -218,6 +219,7 @@ def make_user_data(shared_vars, shared_data):
 def _compile_pymc_model_numba(
     model: "pm.Model",
     pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
+    var_names: Iterable[str] | None = None,
     **kwargs,
 ) -> CompiledPyMCModel:
     if find_spec("numba") is None:
@@ -242,6 +244,7 @@ def _compile_pymc_model_numba(
         compute_grad=True,
         join_expanded=True,
         pymc_initial_point_fn=pymc_initial_point_fn,
+        var_names=var_names,
     )

     expand_fn = expand_fn_pt.vm.jit_fn
@@ -337,6 +340,7 @@ def _compile_pymc_model_jax(
     *,
     gradient_backend=None,
     pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
+    var_names: Iterable[str] | None = None,
     **kwargs,
 ):
     if find_spec("jax") is None:
@@ -366,6 +370,7 @@
         compute_grad=gradient_backend == "pytensor",
         join_expanded=False,
         pymc_initial_point_fn=pymc_initial_point_fn,
+        var_names=var_names,
     )

     logp_fn = logp_fn_pt.vm.jit_fn
@@ -441,6 +446,7 @@ def compile_pymc_model(
     default_initialization_strategy: Literal[
         "support_point", "prior"
     ] = "support_point",
+    var_names: Iterable[str] | None = None,
     **kwargs,
 ) -> CompiledModel:
     """Compile necessary functions for sampling a pymc model.
@@ -464,6 +470,8 @@
     initial_points : dict
         Initial value (strategies) to use instead of what's specified in
         `Model.initial_values`.
+    var_names : list[str] | None
+        A list of variables to store in the trace. If None, store all variables.
     Returns
     -------
     compiled_model : CompiledPyMCModel
@@ -493,13 +501,17 @@
         if gradient_backend == "jax":
             raise ValueError("Gradient backend cannot be jax when using numba backend")
         return _compile_pymc_model_numba(
-            model=model, pymc_initial_point_fn=initial_point_fn, **kwargs
+            model=model,
+            pymc_initial_point_fn=initial_point_fn,
+            var_names=var_names,
+            **kwargs,
         )
     elif backend.lower() == "jax":
         return _compile_pymc_model_jax(
             model=model,
             gradient_backend=gradient_backend,
             pymc_initial_point_fn=initial_point_fn,
+            var_names=var_names,
             **kwargs,
         )
     else:
@@ -542,6 +554,7 @@ def _make_functions(
     compute_grad: bool,
     join_expanded: bool,
     pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
+    var_names: Iterable[str] | None = None,
 ) -> tuple[
     int,
     int,
@@ -568,6 +581,8 @@
     pymc_initial_point_fn: Callable
         Initial point function created by
         pymc.initial_point.make_initial_point_fn
+    var_names:
+        Names of variables to store in the trace. Defaults to all variables.

     Returns
     -------
@@ -673,6 +688,10 @@
     remaining_rvs = [
         var for var in model.unobserved_value_vars if var.name not in joined_names
     ]

+    if var_names is not None:
+        names = set(var_names)
+        remaining_rvs = [var for var in remaining_rvs if var.name in names]
+
     all_names = joined_names + remaining_rvs

     all_names = joined_names.copy()
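For context, a sketch of how the new `var_names` parameter is meant to be used, based on the docstring above and the test added below (the model here is hypothetical; note that free variables such as `a` appear in the trace regardless, since the filter only applies to the remaining variables such as deterministics):

```python
import pymc as pm
import nutpie

with pm.Model() as model:
    a = pm.Normal("a", shape=3)
    b = pm.Deterministic("b", 2 * a)
    c = pm.Deterministic("c", 3 * b)

# Keep only "b" among the derived quantities stored in the trace.
compiled = nutpie.compile_pymc_model(model, var_names=["b"])
trace = nutpie.sample(compiled, chains=1, seed=1)

assert hasattr(trace.posterior, "b")
assert not hasattr(trace.posterior, "c")
```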
38 changes: 19 additions & 19 deletions src/pyfunc.rs
@@ -3,8 +3,8 @@ use std::sync::Arc;
 use anyhow::{anyhow, bail, Context, Result};
 use arrow::{
     array::{
-        Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder,
-        Int64Builder, ListBuilder, PrimitiveBuilder, StructBuilder,
+        Array, ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder,
+        LargeListBuilder, PrimitiveBuilder, StructBuilder,
     },
     datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type},
 };
@@ -16,7 +16,7 @@ use pyo3::{
     Bound, Py, PyAny, PyErr, Python,
 };
 use rand::Rng;
-use rand_distr::{Distribution, StandardNormal, Uniform};
+use rand_distr::{Distribution, Uniform};
 use smallvec::SmallVec;
 use thiserror::Error;

@@ -37,21 +37,21 @@ impl PyVariable {
            ExpandDtype::Float64 {} => DataType::Float64,
            ExpandDtype::Float32 {} => DataType::Float32,
            ExpandDtype::Int64 {} => DataType::Int64,
-           ExpandDtype::BooleanArray { tensor_type } => {
+           ExpandDtype::BooleanArray { tensor_type: _ } => {
                let field = Arc::new(Field::new("item", DataType::Boolean, false));
-               DataType::FixedSizeList(field, tensor_type.size() as i32)
+               DataType::LargeList(field)
            }
            ExpandDtype::ArrayFloat64 { tensor_type: _ } => {
                let field = Arc::new(Field::new("item", DataType::Float64, true));
-               DataType::List(field)
+               DataType::LargeList(field)
            }
-           ExpandDtype::ArrayFloat32 { tensor_type } => {
+           ExpandDtype::ArrayFloat32 { tensor_type: _ } => {
                let field = Arc::new(Field::new("item", DataType::Float32, false));
-               DataType::FixedSizeList(field, tensor_type.size() as i32)
+               DataType::LargeList(field)
            }
-           ExpandDtype::ArrayInt64 { tensor_type } => {
+           ExpandDtype::ArrayInt64 { tensor_type: _ } => {
                let field = Arc::new(Field::new("item", DataType::Int64, false));
-               DataType::FixedSizeList(field, tensor_type.size() as i32)
+               DataType::LargeList(field)
            }
        }
    }
@@ -368,10 +368,10 @@ impl DrawStorage for PyTrace {
                    )?;
                    builder.append_value(value.extract().expect("Return value from expand function could not be converted to int64"))
                },
-               ExpandDtype::BooleanArray { tensor_type} => {
-                   let builder: &mut FixedSizeListBuilder<Box<dyn ArrayBuilder>> =
+               ExpandDtype::BooleanArray { tensor_type } => {
+                   let builder: &mut LargeListBuilder<Box<dyn ArrayBuilder>> =
                        self.builder.field_builder(i).context(
-                           "Builder has incorrect type",
+                           "Builder has incorrect type. Expected LargeListBuilder of Bool",
                        )?;
                    let value_builder = builder.values().as_any_mut().downcast_mut::<BooleanBuilder>().context("Could not downcast builder to boolean type")?;
                    let values: PyReadonlyArray1<bool> = value.extract().context("Could not convert object to array")?;
@@ -383,9 +383,9 @@
                },
                ExpandDtype::ArrayFloat64 { tensor_type } => {
                    //let builder: &mut FixedSizeListBuilder<Box<dyn ArrayBuilder>> =
-                   let builder: &mut ListBuilder<Box<dyn ArrayBuilder>> =
+                   let builder: &mut LargeListBuilder<Box<dyn ArrayBuilder>> =
                        self.builder.field_builder(i).context(
-                           "Builder has incorrect type",
+                           "Builder has incorrect type. Expected LargeListBuilder of Float64",
                        )?;
                    let value_builder = builder.values().as_any_mut().downcast_mut::<PrimitiveBuilder<Float64Type>>().context("Could not downcast builder to float64 type")?;
                    let values: PyReadonlyArray1<f64> = value.extract().context("Could not convert object to array")?;
@@ -396,9 +396,9 @@
                    builder.append(true);
                },
                ExpandDtype::ArrayFloat32 { tensor_type } => {
-                   let builder: &mut FixedSizeListBuilder<Box<dyn ArrayBuilder>> =
+                   let builder: &mut LargeListBuilder<Box<dyn ArrayBuilder>> =
                        self.builder.field_builder(i).context(
-                           "Builder has incorrect type",
+                           "Builder has incorrect type. Expected LargeListBuilder of Float32",
                        )?;
                    let value_builder = builder.values().as_any_mut().downcast_mut::<PrimitiveBuilder<Float32Type>>().context("Could not downcast builder to float32 type")?;
                    let values: PyReadonlyArray1<f32> = value.extract().context("Could not convert object to array")?;
@@ -409,9 +409,9 @@
                    builder.append(true);
                },
                ExpandDtype::ArrayInt64 {tensor_type} => {
-                   let builder: &mut FixedSizeListBuilder<Box<dyn ArrayBuilder>> =
+                   let builder: &mut LargeListBuilder<Box<dyn ArrayBuilder>> =
                        self.builder.field_builder(i).context(
-                           "Builder has incorrect type",
+                           "Builder has incorrect type. Expected LargeListBuilder of Int64",
                        )?;
                    let value_builder = builder.values().as_any_mut().downcast_mut::<PrimitiveBuilder<Int64Type>>().context("Could not downcast builder to i64 type")?;
                    let values: PyReadonlyArray1<i64> = value.extract().context("Could not convert object to array")?;
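The storage change above swaps the fixed-size and 32-bit-offset list builders for `LargeListBuilder` (64-bit offsets, no compile-time element count). A minimal standalone sketch of that builder pattern, assuming the arrow-rs 52 API pinned in Cargo.toml (values are hypothetical):

```rust
use arrow::array::{Array, Float64Builder, LargeListBuilder};

fn main() {
    // One list entry per draw: append the row's values, then close the
    // row with append(true), mirroring the append path in PyTrace.
    let mut builder = LargeListBuilder::new(Float64Builder::new());
    for draw in [[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]] {
        builder.values().append_slice(&draw);
        builder.append(true); // row is valid (non-null)
    }
    let array = builder.finish();
    assert_eq!(array.len(), 2);
}
```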
12 changes: 7 additions & 5 deletions src/pymc.rs
@@ -2,7 +2,8 @@ use std::{ffi::c_void, fmt::Display, sync::Arc};

 use anyhow::{bail, Context, Result};
 use arrow::{
-    array::{Array, FixedSizeListArray, Float64Array, StructArray},
+    array::{Array, Float64Array, LargeListArray, StructArray},
+    buffer::OffsetBuffer,
     datatypes::{DataType, Field, Fields},
 };
 use itertools::{izip, Itertools};
@@ -13,7 +14,6 @@ use pyo3::{
     types::{PyAnyMethods, PyList},
     Bound, Py, PyAny, PyObject, PyResult, Python,
 };
-use rand::{distributions::Uniform, prelude::Distribution};

 use thiserror::Error;

@@ -170,11 +170,13 @@
    fn finalize(self) -> Result<Arc<dyn Array>> {
        let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes)
            .map(|(data, name, size)| {
+               assert!(data.len() % size == 0);
+               let num_arrays = data.len() / size;
                let data = Float64Array::from(data);
                let item_field = Arc::new(Field::new("item", DataType::Float64, false));
-               let array =
-                   FixedSizeListArray::new(item_field.clone(), size as _, Arc::new(data), None);
-               let field = Field::new(name, DataType::FixedSizeList(item_field, size as _), false);
+               let offsets = OffsetBuffer::from_lengths((0..num_arrays).into_iter().map(|_| size));
+               let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None);
+               let field = Field::new(name, DataType::LargeList(item_field), false);
                (Arc::new(field), Arc::new(array) as Arc<dyn Array>)
            })
            .unzip();
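Because `LargeListArray` needs explicit offsets where `FixedSizeListArray` inferred them from the element count, `finalize` now builds an equal-length offset buffer. A standalone sketch of that construction, again assuming arrow-rs 52 (data values are hypothetical):

```rust
use std::sync::Arc;

use arrow::array::{Array, Float64Array, LargeListArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};

fn main() {
    // Two draws of a length-3 variable, stored flat as in PyMcTrace.
    let data = Float64Array::from(vec![0.1, 0.2, 0.3, 1.0, 2.0, 3.0]);
    let size = 3;
    let num_arrays = data.len() / size;

    // Equal row lengths yield offsets [0, 3, 6].
    let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(num_arrays));
    let item_field = Arc::new(Field::new("item", DataType::Float64, false));
    let array = LargeListArray::new(item_field, offsets, Arc::new(data), None);
    assert_eq!(array.len(), 2);
}
```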
47 changes: 47 additions & 0 deletions tests/test_pymc.py
@@ -193,6 +193,53 @@ def test_pymc_model_shared(backend, gradient_backend):
     nutpie.sample(compiled3, chains=1)


+@parameterize_backends
+def test_pymc_var_names(backend, gradient_backend):
+    with pm.Model() as model:
+        mu = pm.Data("mu", -0.1)
+        sigma = pm.Data("sigma", np.ones(3))
+        a = pm.Normal("a", mu=mu, sigma=sigma, shape=3)
+
+        b = pm.Deterministic("b", mu * a)
+        pm.Deterministic("c", mu * b)
+
+    compiled = nutpie.compile_pymc_model(
+        model,
+        backend=backend,
+        gradient_backend=gradient_backend,
+        var_names=None,
+    )
+    trace = nutpie.sample(compiled, chains=1, seed=1)
+
+    # Check that all variables are stored
+    assert hasattr(trace.posterior, "b")
+    assert hasattr(trace.posterior, "c")
+
+    compiled = nutpie.compile_pymc_model(
+        model,
+        backend=backend,
+        gradient_backend=gradient_backend,
+        var_names=[],
+    )
+    trace = nutpie.sample(compiled, chains=1, seed=1)
+
+    # Check that no extra variables are stored
+    assert not hasattr(trace.posterior, "b")
+    assert not hasattr(trace.posterior, "c")
+
+    compiled = nutpie.compile_pymc_model(
+        model,
+        backend=backend,
+        gradient_backend=gradient_backend,
+        var_names=["b"],
+    )
+    trace = nutpie.sample(compiled, chains=1, seed=1)
+
+    # Check that only the requested variables are stored
+    assert hasattr(trace.posterior, "b")
+    assert not hasattr(trace.posterior, "c")
+
+
 @pytest.mark.parametrize(
     ("backend", "gradient_backend"),
     [