Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include("aggregators/prioritytable.jl")
include("aggregators/directcr.jl")
include("aggregators/rssacr.jl")
include("aggregators/rdirect.jl")
include("aggregators/extrande.jl")
include("aggregators/coevolve.jl")

# spatial:
Expand Down Expand Up @@ -84,6 +85,7 @@ export Direct, DirectFW, SortingDirect, DirectCR
export BracketData, RSSA
export FRM, FRMFW, NRM
export RSSACR, RDirect
export Extrande
export Coevolve

export get_num_majumps, needs_depgraph, needs_vartojumps_map
Expand Down
13 changes: 12 additions & 1 deletion src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,18 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
"""
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

"""
The Extrande method for simulating variable rate jumps with user-defined bounds
on jumps rates and validity intervals via rejection.

Stochastic Simulation of Biomolecular Networks in Dynamic Environments, Voliotis
M, Thomas P, Grima R, Bowsher CG, PLOS Computational Biology 12(6): e1004923.
(2016); doi.org/10.1371/journal.pcbi.1004923
"""
struct Extrande <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve())
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), Extrande())

# For JumpProblem construction without an aggregator
struct NullAggregator <: AbstractAggregatorAlgorithm end
Expand All @@ -181,6 +191,7 @@ needs_vartojumps_map(aggregator::RSSACR) = true
# true if aggregator supports variable rates
supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false
supports_variablerates(aggregator::Coevolve) = true
supports_variablerates(aggregator::Extrande) = true

is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
Expand Down
117 changes: 117 additions & 0 deletions src/aggregators/extrande.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Define the aggregator.
struct Extrande <: AbstractAggregatorAlgorithm end

"""
Extrande sampling method for jumps with defined rate bounds.
"""

nullaffect!(integrator) = nothing
const NullAffectJump = ConstantRateJump((u, p, t) -> 0.0, nullaffect!)

mutable struct ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG} <:
AbstractSSAJumpAggregator
next_jump::Int
prev_jump::Int
next_jump_time::T
end_time::T
cur_rates::Vector{T}
sum_rate::T
ma_jumps::S
rate_bnds::F3
wds::F4
rates::F1
affects!::F2
save_positions::Tuple{Bool, Bool}
rng::RNG
end

function ExtrandeJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S,
rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG;
rate_bounds::F3, windows::F4,
kwargs...) where {T, S, F1, F2, F3, F4, RNG}
ExtrandeJumpAggregation{T, S, F1, F2, F3, F4, RNG}(nj, nj, njt, et, crs, sr, maj,
rate_bounds, windows, rs, affs!, sps,
rng)
end

############################# Required Functions ##############################
function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; variable_jumps = (), kwargs...)
ma_jumps_ = !isnothing(ma_jumps) ? ma_jumps : ()
rates, affects! = get_jump_info_fwrappers(u, p, t,
(constant_jumps..., variable_jumps..., ma_jumps_...,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't make sense as MassActionJumps don't use the same interface as ConstantRateJumps or VariableRateJumps. So while you can merge the rate functions from the latter two you need to keep track of MassActionJumps separately and handle their rate calculation separately. See what we do in Direct for handling them.

NullAffectJump))
rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t,
(constant_jumps..., variable_jumps..., ma_jumps_...,
NullAffectJump))
build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds,
windows = wnds, kwargs...)
end

# set up a new simulation and calculate the first jump / jump time
function initialize!(p::ExtrandeJumpAggregation, integrator, u, params, t)
p.end_time = integrator.sol.prob.tspan[2]
generate_jumps!(p, integrator, u, params, t)
end

# execute one jump, changing the system state
@inline function execute_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t)
# execute jump
u = update_state!(p, integrator, u)
nothing
end

@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t)
ttnj = typemax(typeof(t))
nextrx = zero(Int)
Wmin = typemax(typeof(t))
Bmax = typemax(typeof(t))

# Calculate the total rate bound and the largest common validity window.
if !isempty(p.rate_bnds)
Bmax = typeof(t)(0.)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Bmax = typeof(t)(0.)
Bmax = zero(t)

@inbounds for i in 1:length(p.wds)
Wmin = min(Wmin, p.wds[i](u, params, t))
Bmax += p.rate_bnds[i](u, params, t)
end
end

# Rejection sampling.
if !isempty(p.rates)
nextrx = length(p.rates)
idx = 1
prop_ttnj = randexp(p.rng) / Bmax
if prop_ttnj < Wmin
fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...)

prev_rate = zero(t)
cur_rates = p.cur_rates
@inbounds for i in idx:length(cur_rates)
cur_rates[i] = cur_rates[i] + prev_rate
prev_rate = cur_rates[i]
end

UBmax = rand(p.rng) * Bmax
ttnj = prop_ttnj
if p.cur_rates[end] ≥ UBmax
nextrx = 1
@inbounds while p.cur_rates[nextrx] < UBmax
nextrx += 1
end
end
else
ttnj = Wmin
end
end

return nextrx, ttnj
end

function generate_jumps!(p::ExtrandeJumpAggregation, integrator, u, params, t)
nextexj, ttnexj = next_extrande_jump(p, u, params, t)
p.next_jump = nextexj
p.next_jump_time = t + ttnexj

nothing
end
24 changes: 24 additions & 0 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,27 @@ function get_jump_info_fwrappers(u, p, t, constant_jumps)

rates, affects!
end

##### helpers for splitting variable rate jumps with rate bounds and without #####

function rate_window_function(jump)
# Assumes that if no window is given the rate bound is valid for all times.
return !(jump.rateinterval isa Nothing) ? jump.rateinterval : (u, p, t) -> Inf
end

function get_va_jump_bound_info_fwrapper(u, p, t, jumps)
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p), typeof(t)}}

if (jumps !== nothing) && !isempty(jumps)
rates = [j isa VariableRateJump ? RateWrapper(j.urate) : RateWrapper(j.rate)
for j in jumps]
wnds = [j isa VariableRateJump ? RateWrapper(rate_window_function(j)) :
RateWrapper((u, p, t) -> Inf) for j in jumps]
else
rates = Vector{RateWrapper}()
wnds = Vector{RateWrapper}()
end

rates, wnds
end
64 changes: 64 additions & 0 deletions test/extrande.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using DiffEqBase, JumpProcesses, OrdinaryDiffEq, Test
using StableRNGs
using Statistics
rng = StableRNG(48572)

f = function (du, u, p, t)
du[1] = 0.0
end

rate = (u, p, t) -> t < 5.0 ? 1.0 : 0.0
rbound = (u, p, t) -> 1.0
rinterval = (u, p, t) -> Inf
affect! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1)
jump = VariableRateJump(rate, affect!; urate = rbound, rateinterval = rinterval)

prob = ODEProblem(f, [0.0], (0.0, 10.0))
jump_prob = JumpProblem(prob, Extrande(), jump; rng = rng)

# Test that process doesn't jump when rate switches to 0.
sol = solve(jump_prob, Tsit5())
@test sol(5.0)[1] == sol[end][1]

# Birth-death process with time-varying birth rates.
Nsims = 1000000
u0 = [10.0]

function runsimulations(jump_prob, testts)
Psamp = zeros(Int, length(testts), Nsims)
for i in 1:Nsims
sol_ = solve(jump_prob, Tsit5())
Psamp[:, i] = getindex.(sol_(testts).u, 1)
end
mean(Psamp, dims = 2)
end

# Variable rate birth jumps.
rateb = (u, p, t) -> (0.1 * sin(t) + 0.2)
ratebbound = (u, p, t) -> 0.3
ratebwindow = (u, p, t) -> Inf
affectb! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1)
jumpb = VariableRateJump(rateb, affectb!; urate = ratebbound, rateinterval = ratebwindow)

# Constant rate death jumps.
rated = (u, p, t) -> u[1] * 0.08
affectd! = (integrator) -> (integrator.u[1] = integrator.u[1] - 1)
jumpd = ConstantRateJump(rated, affectd!)

# Problem definition.
bd_prob = ODEProblem(f, u0, (0.0, 2pi))
jump_bd_prob = JumpProblem(bd_prob, Extrande(), jumpb, jumpd)

test_times = range(1.0, stop = 2pi, length = 3)
means = runsimulations(jump_bd_prob, test_times)

# ODE for the mean.
fu = function (du, u, p, t)
du[1] = (0.1 * sin(t) + 0.2) - (u[1] * 0.08)
end

ode_prob = ODEProblem(fu, u0, (0.0, 2 * pi))
ode_sol = solve(ode_prob, Tsit5())

# Test extrande against the ODE mean.
@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3))
4 changes: 2 additions & 2 deletions test/hawkes_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ h = [Float64[]]

Eλ, Varλ = expected_stats_hawkes_problem(p, tspan)

algs = (Direct(), Coevolve(), Coevolve())
algs = (Direct(), Coevolve(), Coevolve(), Extrande())
uselrate = zeros(Bool, length(algs))
uselrate[3] = true
Nsims = 250
Expand All @@ -122,7 +122,7 @@ for (i, alg) in enumerate(algs)
reset_history!(h)
sols[n] = solve(jump_prob, stepper)
end
if typeof(alg) <: Coevolve
if typeof(alg) <: Union{Coevolve, Extrande}
λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols))
else
cols = length(sols[1].u[1].u)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ using JumpProcesses, DiffEqBase, SafeTestsets
@time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end
@time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end
@time @safetestset "Ficticious Jump " begin include("extrande.jl") end
end