12
12
import unittest .mock
13
13
import warnings
14
14
from abc import abstractmethod
15
- from typing import Any , Callable , Dict , List , Optional , Union , cast
15
+ from typing import Any , Callable , Dict , List , Optional , Union
16
16
17
17
from ConfigSpace .configuration_space import Configuration , ConfigurationSpace
18
18
34
34
STRING_TO_OUTPUT_TYPES ,
35
35
STRING_TO_TASK_TYPES ,
36
36
)
37
+ from autoPyTorch .data .base_validator import BaseInputValidator
37
38
from autoPyTorch .datasets .base_dataset import BaseDataset
38
39
from autoPyTorch .datasets .resampling_strategy import CrossValTypes , HoldoutValTypes
39
40
from autoPyTorch .ensemble .ensemble_builder import EnsembleBuilderManager
40
- from autoPyTorch .ensemble .ensemble_selection import EnsembleSelection
41
41
from autoPyTorch .ensemble .singlebest_ensemble import SingleBest
42
42
from autoPyTorch .evaluation .abstract_evaluator import fit_and_suppress_warnings
43
43
from autoPyTorch .evaluation .tae import ExecuteTaFuncWithQueue , get_cost_of_crash
@@ -187,7 +187,7 @@ def __init__(
187
187
188
188
self .stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event]
189
189
190
- self ._dask_client = None
190
+ self ._dask_client : Optional [ dask . distributed . Client ] = None
191
191
192
192
self .search_space_updates = search_space_updates
193
193
if search_space_updates is not None :
@@ -196,6 +196,8 @@ def __init__(
196
196
raise ValueError ("Expected search space updates to be of instance"
197
197
" HyperparameterSearchSpaceUpdates got {}" .format (type (self .search_space_updates )))
198
198
199
+ self .InputValidator : Optional [BaseInputValidator ] = None
200
+
199
201
@abstractmethod
200
202
def build_pipeline (self , dataset_properties : Dict [str , Any ]) -> BasePipeline :
201
203
"""
@@ -697,6 +699,7 @@ def _search(
697
699
precision : int = 32 ,
698
700
disable_file_output : List = [],
699
701
load_models : bool = True ,
702
+ dask_client : Optional [dask .distributed .Client ] = None
700
703
) -> 'BaseTask' :
701
704
"""
702
705
Search for the best pipeline configuration for the given dataset.
@@ -828,10 +831,11 @@ def _search(
828
831
# If no dask client was provided, we create one, so that we can
829
832
# start a ensemble process in parallel to smbo optimize
830
833
if (
831
- self . _dask_client is None and (self .ensemble_size > 0 or self . n_jobs is not None and self .n_jobs > 1 )
834
+ dask_client is None and (self .ensemble_size > 0 or self .n_jobs > 1 )
832
835
):
833
836
self ._create_dask_client ()
834
837
else :
838
+ self ._dask_client = dask_client
835
839
self ._is_dask_client_internally_created = False
836
840
837
841
# Handle time resource allocation
@@ -1177,7 +1181,6 @@ def predict(
1177
1181
1178
1182
# Mypy assert
1179
1183
assert self .ensemble_ is not None , "Load models should error out if no ensemble"
1180
- self .ensemble_ = cast (Union [SingleBest , EnsembleSelection ], self .ensemble_ )
1181
1184
1182
1185
if isinstance (self .resampling_strategy , HoldoutValTypes ):
1183
1186
models = self .models_
@@ -1266,15 +1269,17 @@ def get_models_with_weights(self) -> List:
1266
1269
self ._load_models ()
1267
1270
1268
1271
assert self .ensemble_ is not None
1269
- return self .ensemble_ .get_models_with_weights (self .models_ )
1272
+ models_with_weights : List = self .ensemble_ .get_models_with_weights (self .models_ )
1273
+ return models_with_weights
1270
1274
1271
1275
def show_models (self ) -> str :
1272
1276
df = []
1273
1277
for weight , model in self .get_models_with_weights ():
1274
1278
representation = model .get_pipeline_representation ()
1275
1279
representation .update ({'Weight' : weight })
1276
1280
df .append (representation )
1277
- return pd .DataFrame (df ).to_markdown ()
1281
+ models_markdown : str = pd .DataFrame (df ).to_markdown ()
1282
+ return models_markdown
1278
1283
1279
1284
def _print_debug_info_to_log (self ) -> None :
1280
1285
"""
0 commit comments