import unittest.mock
import warnings
from abc import abstractmethod
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from ConfigSpace.configuration_space import Configuration, ConfigurationSpace

import dask
+ import dask.distributed

import joblib

from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
- from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
from autoPyTorch.optimizer.smbo import AutoMLSMBO
from autoPyTorch.pipeline.base_pipeline import BasePipeline
- from autoPyTorch.pipeline.components.setup.traditional_ml.classifier_models import get_available_classifiers
+ from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import get_available_traditional_learners
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics
from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool
@@ -198,7 +198,7 @@ def __init__(
        # examples. Nevertheless, multi-process runs
        # have spawn as requirement to reduce the
        # possibility of a deadlock
-         self._dask_client = None
+         self._dask_client: Optional[dask.distributed.Client] = None
        self._multiprocessing_context = 'forkserver'
        if self.n_jobs == 1:
            self._multiprocessing_context = 'fork'
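
For context only (not part of this diff): the strings stored in self._multiprocessing_context are standard-library multiprocessing start methods, which is what the fork/forkserver/spawn trade-off in the comment refers to. A minimal, runnable illustration:

    # Not autoPyTorch code: list which start methods this platform offers.
    # 'fork' and 'forkserver' are POSIX-only; 'spawn' exists everywhere.
    import multiprocessing

    for name in ('fork', 'forkserver', 'spawn'):
        available = name in multiprocessing.get_all_start_methods()
        print(f"{name}: {'available' if available else 'not available on this platform'}")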
@@ -590,7 +590,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs:
        memory_limit = self._memory_limit
        if memory_limit is not None:
            memory_limit = int(math.ceil(memory_limit))
-         available_classifiers = get_available_classifiers()
+         available_classifiers = get_available_traditional_learners()
        dask_futures = []

        total_number_classifiers = len(available_classifiers)
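
The renamed helper feeds the loop below that submits one dask future per traditional learner. A minimal sketch of enumerating it (assumption: it returns a mapping from learner name to implementation, consistent with the len() call above):

    # Sketch only; the dict-like return type is an assumption.
    from autoPyTorch.pipeline.components.setup.traditional_ml.traditional_learner import (
        get_available_traditional_learners,
    )

    available = get_available_traditional_learners()
    print(len(available), 'traditional learners registered')
    for name in available:
        print(' -', name)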
@@ -711,7 +711,8 @@ def _search(
        precision: int = 32,
        disable_file_output: List = [],
        load_models: bool = True,
-         portfolio_selection: Optional[str] = None
+         portfolio_selection: Optional[str] = None,
+         dask_client: Optional[dask.distributed.Client] = None
    ) -> 'BaseTask':
        """
        Search for the best pipeline configuration for the given dataset.
@@ -838,6 +839,8 @@ def _search(
        self._metric = get_metrics(
            names=[optimize_metric], dataset_properties=dataset_properties)[0]

+         self.pipeline_options['optimize_metric'] = optimize_metric
+
        self.search_space = self.get_search_space(dataset)

        budget_config: Dict[str, Union[float, str]] = {}
@@ -855,10 +858,11 @@
        # If no dask client was provided, we create one, so that we can
        # start a ensemble process in parallel to smbo optimize
        if (
-             self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
+             dask_client is None and (self.ensemble_size > 0 or self.n_jobs > 1)
        ):
            self._create_dask_client()
        else:
+             self._dask_client = dask_client
            self._is_dask_client_internally_created = False

        # Handle time resource allocation
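
The branch above either builds an internal client or adopts the caller-supplied one, in which case _is_dask_client_internally_created stays False and shutdown remains the caller's responsibility. A standalone sketch of that decision, with illustrative names rather than the autoPyTorch implementation:

    from typing import Optional

    import dask.distributed


    def pick_dask_client(dask_client: Optional[dask.distributed.Client],
                         ensemble_size: int,
                         n_jobs: int) -> Optional[dask.distributed.Client]:
        # Only create a client when parallel work is actually requested:
        # an ensemble is built or more than one job is allowed.
        if dask_client is None and (ensemble_size > 0 or n_jobs > 1):
            cluster = dask.distributed.LocalCluster(n_workers=n_jobs, processes=False)
            return dask.distributed.Client(cluster)
        # Otherwise reuse the caller's client (possibly None); it is treated
        # as externally owned and is not shut down by the task.
        return dask_client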
@@ -892,21 +896,18 @@
        # ============> Run traditional ml

        if enable_traditional_pipeline:
-             if STRING_TO_TASK_TYPES[self.task_type] in REGRESSION_TASKS:
-                 self._logger.warning("Traditional Pipeline is not enabled for regression. Skipping...")
-             else:
-                 traditional_task_name = 'runTraditional'
-                 self._stopwatch.start_task(traditional_task_name)
-                 elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
-                 # We want time for at least 1 Neural network in SMAC
-                 time_for_traditional = int(
-                     self._time_for_task - elapsed_time - func_eval_time_limit_secs
-                 )
-                 self._do_traditional_prediction(
-                     func_eval_time_limit_secs=func_eval_time_limit_secs,
-                     time_left=time_for_traditional,
-                 )
-                 self._stopwatch.stop_task(traditional_task_name)
+             traditional_task_name = 'runTraditional'
+             self._stopwatch.start_task(traditional_task_name)
+             elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
+             # We want time for at least 1 Neural network in SMAC
+             time_for_traditional = int(
+                 self._time_for_task - elapsed_time - func_eval_time_limit_secs
+             )
+             self._do_traditional_prediction(
+                 func_eval_time_limit_secs=func_eval_time_limit_secs,
+                 time_left=time_for_traditional,
+             )
+             self._stopwatch.stop_task(traditional_task_name)

        # ============> Starting ensemble
        elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
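
A short worked example of the budget computed above (illustrative numbers): the traditional learners get whatever is left of the overall budget minus one function-evaluation limit, so SMAC can still train at least one neural network afterwards.

    # Illustrative numbers only.
    time_for_task = 300              # overall wall-time budget (self._time_for_task)
    elapsed_time = 20                # stopwatch time already spent on this dataset
    func_eval_time_limit_secs = 50   # per-evaluation limit passed into _search

    time_for_traditional = int(time_for_task - elapsed_time - func_eval_time_limit_secs)
    print(time_for_traditional)      # 230 seconds remain for the traditional learners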
@@ -1207,7 +1208,6 @@ def predict(

        # Mypy assert
        assert self.ensemble_ is not None, "Load models should error out if no ensemble"
-         self.ensemble_ = cast(Union[SingleBest, EnsembleSelection], self.ensemble_)

        if isinstance(self.resampling_strategy, HoldoutValTypes):
            models = self.models_
@@ -1316,15 +1316,17 @@ def get_models_with_weights(self) -> List:
            self._load_models()

        assert self.ensemble_ is not None
-         return self.ensemble_.get_models_with_weights(self.models_)
+         models_with_weights: List[Tuple[float, BasePipeline]] = self.ensemble_.get_models_with_weights(self.models_)
+         return models_with_weights

    def show_models(self) -> str:
        df = []
        for weight, model in self.get_models_with_weights():
            representation = model.get_pipeline_representation()
            representation.update({'Weight': weight})
            df.append(representation)
-         return pd.DataFrame(df).to_markdown()
+         models_markdown: str = pd.DataFrame(df).to_markdown()
+         return models_markdown

    def _print_debug_info_to_log(self) -> None:
        """