34
34
35
35
from arviz import InferenceData , dict_to_dataset
36
36
from arviz .data .base import make_attrs
37
- from fastprogress .fastprogress import progress_bar
38
37
from pytensor .graph .basic import Variable
38
+ from rich .console import Console
39
+ from rich .progress import Progress
40
+ from rich .theme import Theme
39
41
from typing_extensions import Protocol , TypeAlias
40
42
41
43
import pymc as pm
65
67
RandomSeed ,
66
68
RandomState ,
67
69
_get_seeds_per_chain ,
70
+ default_progress_theme ,
68
71
drop_warning_stat ,
69
72
get_untransformed_name ,
70
73
is_transformed_name ,
@@ -377,6 +380,7 @@ def sample(
377
380
cores : Optional [int ] = None ,
378
381
random_seed : RandomState = None ,
379
382
progressbar : bool = True ,
383
+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
380
384
step = None ,
381
385
var_names : Optional [Sequence [str ]] = None ,
382
386
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -406,6 +410,7 @@ def sample(
406
410
cores : Optional [int ] = None ,
407
411
random_seed : RandomState = None ,
408
412
progressbar : bool = True ,
413
+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
409
414
step = None ,
410
415
var_names : Optional [Sequence [str ]] = None ,
411
416
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -435,6 +440,7 @@ def sample(
435
440
cores : Optional [int ] = None ,
436
441
random_seed : RandomState = None ,
437
442
progressbar : bool = True ,
443
+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
438
444
step = None ,
439
445
var_names : Optional [Sequence [str ]] = None ,
440
446
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -761,6 +767,7 @@ def sample(
761
767
"tune" : tune ,
762
768
"var_names" : var_names ,
763
769
"progressbar" : progressbar ,
770
+ "progressbar_theme" : progressbar_theme ,
764
771
"model" : model ,
765
772
"cores" : cores ,
766
773
"callback" : callback ,
@@ -983,6 +990,7 @@ def _sample(
983
990
trace : IBaseTrace ,
984
991
tune : int ,
985
992
model : Optional [Model ] = None ,
993
+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
986
994
callback = None ,
987
995
** kwargs ,
988
996
) -> None :
@@ -1010,6 +1018,8 @@ def _sample(
1010
1018
tune : int
1011
1019
Number of iterations to tune.
1012
1020
model : Model (optional if in ``with`` context)
1021
+ progressbar_theme : Theme
1022
+ Optional custom theme for the progress bar.
1013
1023
"""
1014
1024
skip_first = kwargs .get ("skip_first" , 0 )
1015
1025
@@ -1026,19 +1036,16 @@ def _sample(
1026
1036
)
1027
1037
_pbar_data = {"chain" : chain , "divergences" : 0 }
1028
1038
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1029
- if progressbar :
1030
- sampling = progress_bar (sampling_gen , total = draws , display = progressbar )
1031
- sampling .comment = _desc .format (** _pbar_data )
1032
- else :
1033
- sampling = sampling_gen
1034
- try :
1035
- for it , diverging in enumerate (sampling ):
1036
- if it >= skip_first and diverging :
1037
- _pbar_data ["divergences" ] += 1
1038
- if progressbar :
1039
- sampling .comment = _desc .format (** _pbar_data )
1040
- except KeyboardInterrupt :
1041
- pass
1039
+ with Progress (console = Console (theme = progressbar_theme )) as progress :
1040
+ try :
1041
+ task = progress .add_task (_desc .format (** _pbar_data ), total = draws , visible = progressbar )
1042
+ for it , diverging in enumerate (sampling_gen ):
1043
+ if it >= skip_first and diverging :
1044
+ _pbar_data ["divergences" ] += 1
1045
+ progress .update (task , advance = 1 )
1046
+ progress .update (task , advance = 1 , completed = True )
1047
+ except KeyboardInterrupt :
1048
+ pass
1042
1049
1043
1050
1044
1051
def _iter_sample (
@@ -1131,6 +1138,7 @@ def _mp_sample(
1131
1138
random_seed : Sequence [RandomSeed ],
1132
1139
start : Sequence [PointType ],
1133
1140
progressbar : bool = True ,
1141
+ progressbar_theme : Optional [Theme ] = default_progress_theme ,
1134
1142
traces : Sequence [IBaseTrace ],
1135
1143
model : Optional [Model ] = None ,
1136
1144
callback : Optional [SamplingIteratorCallback ] = None ,
@@ -1158,6 +1166,8 @@ def _mp_sample(
1158
1166
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
1159
1167
progressbar : bool
1160
1168
Whether or not to display a progress bar in the command line.
1169
+ progressbar_theme : Theme
1170
+ Optional custom theme for the progress bar.
1161
1171
traces
1162
1172
Recording backends for each chain.
1163
1173
model : Model (optional if in ``with`` context)
@@ -1182,6 +1192,7 @@ def _mp_sample(
1182
1192
start_points = start ,
1183
1193
step_method = step ,
1184
1194
progressbar = progressbar ,
1195
+ progressbar_theme = progressbar_theme ,
1185
1196
mp_ctx = mp_ctx ,
1186
1197
)
1187
1198
try :
0 commit comments