Skip to content

Commit 68fc77f

Browse files
[Bug] Fix random halt problems on traditional pipelines (#147)
* [feat] Fix random halt problems on traditional pipelines * Documentation update * Fix flake * Flake due to kernel pca errors
1 parent 7bcde56 commit 68fc77f

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

autoPyTorch/evaluation/abstract_evaluator.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,23 @@
5454

5555

5656
class MyTraditionalTabularClassificationPipeline(BaseEstimator):
57+
"""
58+
A wrapper class that holds a pipeline for traditional classification.
59+
Estimators like CatBoost and Random Forest are considered traditional machine
60+
learning models and are fitted before neural architecture search.
61+
62+
This class is an interface to fit a pipeline containing a traditional machine
63+
learning model, and is the final object that is stored for inference.
64+
65+
Attributes:
66+
dataset_properties (Dict[str, Any]):
67+
A dictionary containing dataset specific information
68+
random_state (Optional[Union[int, np.random.RandomState]]):
69+
Object that contains a seed and allows for reproducible results
70+
init_params (Optional[Dict]):
71+
An optional dictionary that is passed to the pipeline's steps. It serves
72+
a similar function as the kwargs
73+
"""
5774
def __init__(self, config: str,
5875
dataset_properties: Dict[str, Any],
5976
random_state: Optional[Union[int, np.random.RandomState]] = None,
@@ -98,6 +115,21 @@ def get_default_pipeline_options() -> Dict[str, Any]:
98115

99116

100117
class DummyClassificationPipeline(DummyClassifier):
118+
"""
119+
A wrapper class that holds a pipeline for dummy classification.
120+
121+
A wrapper over DummyClassifier of scikit-learn. This estimator is considered the
122+
worst performing model. In case of failure, at least this model will be fitted.
123+
124+
Attributes:
125+
dataset_properties (Dict[str, Any]):
126+
A dictionary containing dataset specific information
127+
random_state (Optional[Union[int, np.random.RandomState]]):
128+
Object that contains a seed and allows for reproducible results
129+
init_params (Optional[Dict]):
130+
An optional dictionary that is passed to the pipeline's steps. It serves
131+
a similar function as the kwargs
132+
"""
101133
def __init__(self, config: Configuration,
102134
random_state: Optional[Union[int, np.random.RandomState]] = None,
103135
init_params: Optional[Dict] = None
@@ -148,6 +180,21 @@ def get_default_pipeline_options() -> Dict[str, Any]:
148180

149181

150182
class DummyRegressionPipeline(DummyRegressor):
183+
"""
184+
A wrapper class that holds a pipeline for dummy regression.
185+
186+
A wrapper over DummyRegressor of scikit-learn. This estimator is considered the
187+
worst performing model. In case of failure, at least this model will be fitted.
188+
189+
Attributes:
190+
dataset_properties (Dict[str, Any]):
191+
A dictionary containing dataset specific information
192+
random_state (Optional[Union[int, np.random.RandomState]]):
193+
Object that contains a seed and allows for reproducible results
194+
init_params (Optional[Dict]):
195+
An optional dictionary that is passed to the pipeline's steps. It serves
196+
a similar function as the kwargs
197+
"""
151198
def __init__(self, config: Configuration,
152199
random_state: Optional[Union[int, np.random.RandomState]] = None,
153200
init_params: Optional[Dict] = None) -> None:
@@ -351,7 +398,7 @@ def _get_pipeline(self) -> BaseEstimator:
351398
if isinstance(self.configuration, int):
352399
pipeline = self.pipeline_class(config=self.configuration,
353400
random_state=np.random.RandomState(self.seed),
354-
init_params=self.fit_dictionary)
401+
init_params=self._init_params)
355402
elif isinstance(self.configuration, Configuration):
356403
pipeline = self.pipeline_class(config=self.configuration,
357404
dataset_properties=self.dataset_properties,
@@ -364,7 +411,7 @@ def _get_pipeline(self) -> BaseEstimator:
364411
pipeline = self.pipeline_class(config=self.configuration,
365412
dataset_properties=self.dataset_properties,
366413
random_state=np.random.RandomState(self.seed),
367-
init_params=self.fit_dictionary)
414+
init_params=self._init_params)
368415
else:
369416
raise ValueError("Invalid configuration entered")
370417
return pipeline

test/test_pipeline/components/preprocessing/test_feature_preprocessor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import flaky
2+
13
import numpy as np
24

35
import pytest
@@ -51,6 +53,7 @@ def test_feature_preprocessor(self, fit_dictionary_tabular, preprocessor):
5153
transformed = column_transformer.transform(X['X_train'])
5254
assert isinstance(transformed, np.ndarray)
5355

56+
@flaky.flaky(max_runs=3)
5457
def test_pipeline_fit_include(self, fit_dictionary_tabular, preprocessor):
5558
"""
5659
This test ensures that a tabular classification

0 commit comments

Comments
 (0)