|
36 | 36 | from arviz.data.base import make_attrs
|
37 | 37 | from pytensor.graph.basic import Variable
|
38 | 38 | from rich.console import Console
|
| 39 | +from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn |
39 | 40 | from rich.theme import Theme
|
40 | 41 | from threadpoolctl import threadpool_limits
|
41 | 42 | from typing_extensions import Protocol
|
@@ -1075,16 +1076,28 @@ def _sample(
|
1075 | 1076 | )
|
1076 | 1077 | _pbar_data = {"chain": chain, "divergences": 0}
|
1077 | 1078 | _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
|
1078 |
| - with CustomProgress( |
1079 |
| - console=Console(theme=progressbar_theme), disable=not progressbar |
1080 |
| - ) as progress: |
| 1079 | + |
| 1080 | + progress = CustomProgress( |
| 1081 | + "[progress.description]{task.description}", |
| 1082 | + BarColumn(), |
| 1083 | + "[progress.percentage]{task.percentage:>3.0f}%", |
| 1084 | + TimeRemainingColumn(), |
| 1085 | + TextColumn("/"), |
| 1086 | + TimeElapsedColumn(), |
| 1087 | + console=Console(theme=progressbar_theme), |
| 1088 | + disable=not progressbar, |
| 1089 | + ) |
| 1090 | + |
| 1091 | + with progress: |
1081 | 1092 | try:
|
1082 |
| - task = progress.add_task(_desc.format(**_pbar_data), total=draws) |
| 1093 | + task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws) |
1083 | 1094 | for it, diverging in enumerate(sampling_gen):
|
1084 | 1095 | if it >= skip_first and diverging:
|
1085 | 1096 | _pbar_data["divergences"] += 1
|
1086 |
| - progress.update(task) |
1087 |
| - progress.update(task, refresh=True, advance=1, completed=True) |
| 1097 | + progress.update(task, description=_desc.format(**_pbar_data), completed=it) |
| 1098 | + progress.update( |
| 1099 | + task, description=_desc.format(**_pbar_data), completed=draws, refresh=True |
| 1100 | + ) |
1088 | 1101 | except KeyboardInterrupt:
|
1089 | 1102 | pass
|
1090 | 1103 |
|
|
0 commit comments