Skip to content

Commit 2fbe1c6

Browse files
committed
Fix mypy errors
1 parent aa01330 commit 2fbe1c6

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+179
-173
lines changed

.pre-commit-config.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,22 @@ repos:
33
rev: v0.761
44
hooks:
55
- id: mypy
6-
args: [--show-error-codes]
7-
name: mypy AutoPyTorch
6+
args: [--show-error-codes,
7+
--warn-redundant-casts,
8+
--warn-return-any,
9+
--warn-unreachable,
10+
]
811
files: autoPyTorch/.*
12+
exclude: autoPyTorch/ensemble/
913
- repo: https://gitlab.com/pycqa/flake8
1014
rev: 3.8.3
1115
hooks:
1216
- id: flake8
13-
name: flake8 AutoPyTorch
14-
files: autoPyTorch/.*
1517
additional_dependencies:
1618
- flake8-print==3.1.4
1719
- flake8-import-order
20+
name: flake8 autoPyTorch
21+
files: autoPyTorch/.*
1822
- id: flake8
19-
name: flake8 tests
20-
files: test/.*
21-
additional_dependencies:
22-
- flake8-print==3.1.4
23-
- flake8-import-order
23+
name: flake8 test
24+
files: test/.*

autoPyTorch/api/base_task.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import unittest.mock
1313
import warnings
1414
from abc import abstractmethod
15-
from typing import Any, Callable, Dict, List, Optional, Union, cast
15+
from typing import Any, Callable, Dict, List, Optional, Union
1616

1717
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
1818

@@ -34,10 +34,10 @@
3434
STRING_TO_OUTPUT_TYPES,
3535
STRING_TO_TASK_TYPES,
3636
)
37+
from autoPyTorch.data.base_validator import BaseInputValidator
3738
from autoPyTorch.datasets.base_dataset import BaseDataset
3839
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
3940
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
40-
from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection
4141
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
4242
from autoPyTorch.evaluation.abstract_evaluator import fit_and_suppress_warnings
4343
from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
@@ -187,7 +187,7 @@ def __init__(
187187

188188
self.stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event]
189189

190-
self._dask_client = None
190+
self._dask_client: Optional[dask.distributed.Client] = None
191191

192192
self.search_space_updates = search_space_updates
193193
if search_space_updates is not None:
@@ -196,6 +196,8 @@ def __init__(
196196
raise ValueError("Expected search space updates to be of instance"
197197
" HyperparameterSearchSpaceUpdates got {}".format(type(self.search_space_updates)))
198198

199+
self.InputValidator: Optional[BaseInputValidator] = None
200+
199201
@abstractmethod
200202
def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline:
201203
"""
@@ -697,6 +699,7 @@ def _search(
697699
precision: int = 32,
698700
disable_file_output: List = [],
699701
load_models: bool = True,
702+
dask_client: Optional[dask.Distributed.Client] = None
700703
) -> 'BaseTask':
701704
"""
702705
Search for the best pipeline configuration for the given dataset.
@@ -828,10 +831,11 @@ def _search(
828831
# If no dask client was provided, we create one, so that we can
829832
# start an ensemble process in parallel to the SMBO optimization
830833
if (
831-
self._dask_client is None and (self.ensemble_size > 0 or self.n_jobs is not None and self.n_jobs > 1)
834+
dask_client is None and (self.ensemble_size > 0 or self.n_jobs > 1)
832835
):
833836
self._create_dask_client()
834837
else:
838+
self._dask_client = dask_client
835839
self._is_dask_client_internally_created = False
836840

837841
# Handle time resource allocation
@@ -1177,7 +1181,6 @@ def predict(
11771181

11781182
# Mypy assert
11791183
assert self.ensemble_ is not None, "Load models should error out if no ensemble"
1180-
self.ensemble_ = cast(Union[SingleBest, EnsembleSelection], self.ensemble_)
11811184

11821185
if isinstance(self.resampling_strategy, HoldoutValTypes):
11831186
models = self.models_
@@ -1266,15 +1269,17 @@ def get_models_with_weights(self) -> List:
12661269
self._load_models()
12671270

12681271
assert self.ensemble_ is not None
1269-
return self.ensemble_.get_models_with_weights(self.models_)
1272+
models_with_weights: List = self.ensemble_.get_models_with_weights(self.models_)
1273+
return models_with_weights
12701274

12711275
def show_models(self) -> str:
12721276
df = []
12731277
for weight, model in self.get_models_with_weights():
12741278
representation = model.get_pipeline_representation()
12751279
representation.update({'Weight': weight})
12761280
df.append(representation)
1277-
return pd.DataFrame(df).to_markdown()
1281+
models_markdown: str = pd.DataFrame(df).to_markdown()
1282+
return models_markdown
12781283

12791284
def _print_debug_info_to_log(self) -> None:
12801285
"""

autoPyTorch/data/base_target_validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def fit(
9595
np.shape(y_test)
9696
))
9797
if isinstance(y_train, pd.DataFrame):
98-
y_train = typing.cast(pd.DataFrame, y_train)
9998
y_test = typing.cast(pd.DataFrame, y_test)
10099
if y_train.columns.tolist() != y_test.columns.tolist():
101100
raise ValueError(

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def transform(
145145
X = self.numpy_array_to_pandas(X)
146146

147147
if hasattr(X, "iloc") and not scipy.sparse.issparse(X):
148-
X = typing.cast(pd.DataFrame, X)
149148
if np.any(pd.isnull(X)):
150149
for column in X.columns:
151150
if X[column].isna().all():

autoPyTorch/data/tabular_target_validator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,9 @@ def _check_data(
194194
A set of features whose dimensionality and data type is going to be checked
195195
"""
196196

197-
if not isinstance(
198-
y, (np.ndarray, pd.DataFrame, list, pd.Series)) and not scipy.sparse.issparse(y):
197+
if not (isinstance( # type: ignore[misc]
198+
y, (np.ndarray, pd.DataFrame, typing.List, pd.Series))
199+
and scipy.sparse.issparse(y)):
199200
raise ValueError("AutoPyTorch only supports Numpy arrays, Pandas DataFrames,"
200201
" pd.Series, sparse data and Python Lists as targets, yet, "
201202
"the provided input is of type {}".format(

autoPyTorch/datasets/base_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
202202
return X, Y
203203

204204
def __len__(self) -> int:
205-
return self.train_tensors[0].shape[0]
205+
return int(self.train_tensors[0].shape[0])
206206

207207
def _get_indices(self) -> np.ndarray:
208208
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))

autoPyTorch/ensemble/ensemble_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
ensemble_nbest: int,
5959
max_models_on_disc: Union[float, int],
6060
seed: int,
61-
precision: int,
61+
precision: Union[int, str],
6262
max_iterations: Optional[int],
6363
read_at_most: int,
6464
ensemble_memory_limit: Optional[int],
@@ -270,7 +270,7 @@ def fit_and_return_ensemble(
270270
ensemble_nbest: int,
271271
max_models_on_disc: Union[float, int],
272272
seed: int,
273-
precision: int,
273+
precision: Union[int, str],
274274
memory_limit: Optional[int],
275275
read_at_most: int,
276276
random_state: int,

autoPyTorch/ensemble/ensemble_selection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ def _fit(
149149
if len(predictions) == 1:
150150
break
151151

152-
self.indices_ = order
153-
self.trajectory_ = trajectory
154-
self.train_loss_ = trajectory[-1]
152+
self.indices_: List[int] = order
153+
self.trajectory_: List[float] = trajectory
154+
self.train_loss_: float = trajectory[-1]
155155

156156
def _calculate_weights(self) -> None:
157157
"""

autoPyTorch/evaluation/tae.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def fit_predict_try_except_decorator(
3535
ta: typing.Callable,
36-
queue: multiprocessing.Queue, cost_for_crash: float, **kwargs: typing.Any) -> None:
36+
queue: multiprocessing.Queue, cost_for_crash: float, **kwargs: typing.Any) -> typing.Optional[typing.Any]:
3737
try:
3838
return ta(queue=queue, **kwargs)
3939
except Exception as e:
@@ -54,6 +54,7 @@ def fit_predict_try_except_decorator(
5454
'status': StatusType.CRASHED,
5555
'final_queue_element': True}, block=True)
5656
queue.close()
57+
return None
5758

5859

5960
def get_cost_of_crash(metric: autoPyTorchMetric) -> float:
@@ -146,13 +147,15 @@ def __init__(
146147
self.exclude = exclude
147148
self.disable_file_output = disable_file_output
148149
self.init_params = init_params
150+
151+
self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
152+
149153
self.pipeline_config: typing.Dict[str, typing.Union[int, str, float]] = dict()
150154
if pipeline_config is None:
151155
pipeline_config = replace_string_bool_to_bool(json.load(open(
152156
os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
153157
self.pipeline_config.update(pipeline_config)
154158

155-
self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
156159
self.logger_port = logger_port
157160
if self.logger_port is None:
158161
self.logger: typing.Union[logging.Logger, PicklableClientLogger] = logging.getLogger("TAE")
@@ -236,7 +239,8 @@ def run_wrapper(
236239
run_info = run_info._replace(cutoff=int(np.ceil(run_info.cutoff)))
237240

238241
self.logger.info("Starting to evaluate configuration %s" % run_info.config.config_id)
239-
return super().run_wrapper(run_info=run_info)
242+
run_info, run_value = super().run_wrapper(run_info=run_info)
243+
return run_info, run_value
240244

241245
def run(
242246
self,

autoPyTorch/optimizer/smbo.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,11 @@ def __init__(self,
195195

196196
self.search_space_updates = search_space_updates
197197

198-
dataset_name_ = "" if dataset_name is None else dataset_name
199198
if logger_port is None:
200199
self.logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
201200
else:
202201
self.logger_port = logger_port
203-
logger_name = '%s(%d):%s' % (self.__class__.__name__, self.seed, ":" + dataset_name_)
202+
logger_name = '%s(%d):%s' % (self.__class__.__name__, self.seed, ":" + self.dataset_name)
204203
self.logger = get_named_client_logger(name=logger_name,
205204
port=self.logger_port)
206205
self.logger.info("initialised {}".format(self.__class__.__name__))

0 commit comments

Comments
 (0)