|
13 | 13 | from smac.runhistory.runhistory import RunHistory, StatusType
|
14 | 14 |
|
15 | 15 | from autoPyTorch.api.base_task import BaseTask
|
16 |
| -from autoPyTorch.api.results_manager import ResultsManager, STATUS2MSG, cost2metric |
| 16 | +from autoPyTorch.api.results_manager import ResultsManager, SearchResults, STATUS2MSG, cost2metric |
17 | 17 | from autoPyTorch.metrics import accuracy, balanced_accuracy, log_loss
|
18 | 18 |
|
19 | 19 |
|
@@ -92,6 +92,23 @@ def _check_metric_dict(metric_dict, status_types):
|
92 | 92 | for s, isnan in zip(status_types, np.isnan(vals))])
|
93 | 93 |
|
94 | 94 |
|
| 95 | +def test_extract_results_from_run_history(): |
| 96 | + # test the raise error for the `status_msg is None` |
| 97 | + run_history = RunHistory() |
| 98 | + cs = ConfigurationSpace() |
| 99 | + config = Configuration(cs, {}) |
| 100 | + run_history.add( |
| 101 | + config=config, |
| 102 | + cost=0.0, |
| 103 | + time=1.0, |
| 104 | + status=StatusType.CAPPED, |
| 105 | + ) |
| 106 | + with pytest.raises(ValueError) as excinfo: |
| 107 | + SearchResults(metric=accuracy, scoring_functions=[], run_history=run_history) |
| 108 | + |
| 109 | + assert excinfo._excinfo[0] == ValueError |
| 110 | + |
| 111 | + |
95 | 112 | def test_search_results_sprint_statistics():
|
96 | 113 | api = BaseTask()
|
97 | 114 | for method in ['get_search_results', 'sprint_statistics', 'get_incumbent_results']:
|
|
0 commit comments