@@ -436,7 +436,6 @@ def __init__(self, *args, ar_order, constant_term, **kwargs):
436
436
437
437
def update(self, node: Node):
    """Return the RNG update mapping for the noise RV.

    The noise RNG is supplied as the last input of the node, and the
    advanced RNG state is emitted as the first output; mapping one to
    the other lets successive draws move the generator forward.
    """
    rng_input = node.inputs[-1]
    next_rng_state = node.outputs[0]
    return {rng_input: next_rng_state}
441
440
442
441
@@ -658,13 +657,13 @@ def step(*args):
658
657
ar_ = pt .concatenate ([init_ , innov_ .T ], axis = - 1 )
659
658
660
659
ar_op = AutoRegressiveRV (
661
- inputs = [rhos_ , sigma_ , init_ , steps_ ],
660
+ inputs = [rhos_ , sigma_ , init_ , steps_ , noise_rng ],
662
661
outputs = [noise_next_rng , ar_ ],
663
662
ar_order = ar_order ,
664
663
constant_term = constant_term ,
665
664
)
666
665
667
- ar = ar_op (rhos , sigma , init_dist , steps )
666
+ ar = ar_op (rhos , sigma , init_dist , steps , noise_rng )
668
667
return ar
669
668
670
669
@@ -731,7 +730,6 @@ class GARCH11RV(SymbolicRandomVariable):
731
730
732
731
def update(self, node: Node):
    """Return the RNG update mapping for the noise RV.

    The noise RNG is supplied as the last input of the node, and the
    advanced RNG state is emitted as the first output; mapping one to
    the other lets successive draws move the generator forward.
    """
    rng_input = node.inputs[-1]
    next_rng_state = node.outputs[0]
    return {rng_input: next_rng_state}
736
734
737
735
@@ -797,7 +795,6 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
797
795
# In this case the size of the init_dist depends on the parameters shape
798
796
batch_size = pt .broadcast_shape (omega , alpha_1 , beta_1 , initial_vol )
799
797
init_dist = change_dist_size (init_dist , batch_size )
800
- # initial_vol = initial_vol * pt.ones(batch_size)
801
798
802
799
# Create OpFromGraph representing random draws from GARCH11 process
803
800
# Variables with underscore suffix are dummy inputs into the OpFromGraph
@@ -819,7 +816,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
819
816
820
817
(y_t , _ ), innov_updates_ = pytensor .scan (
821
818
fn = step ,
822
- outputs_info = [init_ , initial_vol_ * pt .ones ( batch_size )],
819
+ outputs_info = [init_ , pt .broadcast_to ( initial_vol_ . astype ( "floatX" ), init_ . shape )],
823
820
non_sequences = [omega_ , alpha_1_ , beta_1_ , noise_rng ],
824
821
n_steps = steps_ ,
825
822
strict = True ,
@@ -831,11 +828,11 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
831
828
)
832
829
833
830
garch11_op = GARCH11RV (
834
- inputs = [omega_ , alpha_1_ , beta_1_ , initial_vol_ , init_ , steps_ ],
831
+ inputs = [omega_ , alpha_1_ , beta_1_ , initial_vol_ , init_ , steps_ , noise_rng ],
835
832
outputs = [noise_next_rng , garch11_ ],
836
833
)
837
834
838
- garch11 = garch11_op (omega , alpha_1 , beta_1 , initial_vol , init_dist , steps )
835
+ garch11 = garch11_op (omega , alpha_1 , beta_1 , initial_vol , init_dist , steps , noise_rng )
839
836
return garch11
840
837
841
838
@@ -891,14 +888,13 @@ class EulerMaruyamaRV(SymbolicRandomVariable):
891
888
ndim_supp = 1
892
889
_print_name = ("EulerMaruyama" , "\\ operatorname{EulerMaruyama}" )
893
890
894
def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs):
    """Create an Euler-Maruyama symbolic RV op.

    Parameters
    ----------
    dt : float
        Time-discretization step of the Euler-Maruyama scheme.
    sde_fn : Callable
        User-supplied SDE function (presumably returning the drift and
        diffusion terms — confirm against the caller's contract).
    """
    # Stash the solver configuration on the op before delegating the
    # remaining inputs/outputs wiring to the parent constructor.
    self.dt = dt
    self.sde_fn = sde_fn
    super().__init__(*args, **kwargs)
898
895
899
896
def update(self, node: Node):
    """Return the RNG update mapping for the noise RV.

    The noise RNG is supplied as the last input of the node, and the
    advanced RNG state is emitted as the first output; mapping one to
    the other lets successive draws move the generator forward.
    """
    rng_input = node.inputs[-1]
    next_rng_state = node.outputs[0]
    return {rng_input: next_rng_state}
903
899
904
900
@@ -1010,14 +1006,14 @@ def step(*prev_args):
1010
1006
)
1011
1007
1012
1008
eulermaruyama_op = EulerMaruyamaRV (
1013
- inputs = [init_ , steps_ , * sde_pars_ ],
1009
+ inputs = [init_ , steps_ , * sde_pars_ , noise_rng ],
1014
1010
outputs = [noise_next_rng , sde_out_ ],
1015
1011
dt = dt ,
1016
1012
sde_fn = sde_fn ,
1017
1013
signature = f"(),(s),{ ',' .join ('()' for _ in sde_pars_ )} ->(),(t)" ,
1018
1014
)
1019
1015
1020
- eulermaruyama = eulermaruyama_op (init_dist , steps , * sde_pars )
1016
+ eulermaruyama = eulermaruyama_op (init_dist , steps , * sde_pars , noise_rng )
1021
1017
return eulermaruyama
1022
1018
1023
1019
0 commit comments