
Commit a11caf4

ravinkohli, franchuterivera, mlindauer, and frank-hutter authored
[ADD] Add column transformer (#305)
* Match paper libraries-versions
* Update README.md
* Update README.md
* Update README.md
* [FIX] master branch README (#209)
* Enable github actions (#273)
* Update README.md
* Create CITATION.cff
* Added column transformer, changed requirements and added tests
* remove redundant lines
* Remove unwanted change made
* Fix bug in test api and dummy forward pass
* Fix silly bugs
* increase time to pass test
* remove parallel capabilities of traditional learners to resolve bug in docs building
* almost fixed
* Add documentation for tabularfeaturevalidator
* fix flake
* fix silly bug
* address comment from shuhei
* rename enc_columns to transformed_columns in the remaining places
* fix bug in test
* fix mypy
* add shuhei's suggestion

Co-authored-by: chico <[email protected]>
Co-authored-by: Marius Lindauer <[email protected]>
Co-authored-by: Frank <[email protected]>
Co-authored-by: Francisco Rivera Valverde <[email protected]>
1 parent 9002937 commit a11caf4

File tree: 13 files changed, +188 -137 lines

autoPyTorch/api/base_task.py
Lines changed: 2 additions & 1 deletion

@@ -1322,7 +1322,8 @@ def get_incumbent_results(
             if not include_traditional:
                 # traditional classifiers have trainer_configuration in their additional info
                 run_history_data = dict(
-                    filter(lambda elem: elem[1].additional_info is not None and elem[1].
+                    filter(lambda elem: elem[1].status == StatusType.SUCCESS and elem[1].
+                           additional_info is not None and elem[1].
                            additional_info['configuration_origin'] != 'traditional',
                            run_history_data.items()))
                 run_history_data = dict(
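For context, the tightened filter now also drops runs that did not finish successfully, so crashed or timed-out runs can no longer be picked as incumbents. A rough equivalent of the filtered dictionary, written as a comprehension (a sketch only; the StatusType import path shown is an assumption and may differ by SMAC version):

import pandas  # placeholder to emphasise this is standalone illustration, not repository code
from smac.tae import StatusType  # assumed import path for SMAC's run status enum

filtered = {
    run_key: run_value
    for run_key, run_value in run_history_data.items()
    if run_value.status == StatusType.SUCCESS                 # new in this commit: only successful runs
    and run_value.additional_info is not None                 # unchanged: the run must carry additional info
    and run_value.additional_info['configuration_origin'] != 'traditional'  # unchanged: skip traditional learners
}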

autoPyTorch/data/base_feature_validator.py
Lines changed: 3 additions & 7 deletions

@@ -35,11 +35,6 @@ class BaseFeatureValidator(BaseEstimator):
             List of the column types found by this estimator during fit.
         data_type (str):
             Class name of the data type provided during fit.
-        encoder (typing.Optional[BaseEstimator])
-            Host a encoder object if the data requires transformation (for example,
-            if provided a categorical column in a pandas DataFrame)
-        enc_columns (typing.List[str])
-            List of columns that were encoded.
     """
     def __init__(self,
                  logger: typing.Optional[typing.Union[PicklableClientLogger, logging.Logger

@@ -51,8 +46,8 @@ def __init__(self,
         self.dtypes = []  # type: typing.List[str]
         self.column_order = []  # type: typing.List[str]

-        self.encoder = None  # type: typing.Optional[BaseEstimator]
-        self.enc_columns = []  # type: typing.List[str]
+        self.column_transformer = None  # type: typing.Optional[BaseEstimator]
+        self.transformed_columns = []  # type: typing.List[str]

         self.logger: typing.Union[
             PicklableClientLogger, logging.Logger

@@ -61,6 +56,7 @@ def __init__(self,
         # Required for dataset properties
         self.num_features = None  # type: typing.Optional[int]
         self.categories = []  # type: typing.List[typing.List[int]]
+
         self.categorical_columns: typing.List[int] = []
         self.numerical_columns: typing.List[int] = []
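The effect on downstream code is only a change of attribute names; a hypothetical snippet, where validator stands for any fitted feature validator:

# before: validator.encoder, validator.enc_columns
# after this commit:
fitted_transformer = validator.column_transformer   # sklearn ColumnTransformer, or None if nothing to encode
categorical_names = validator.transformed_columns   # columns routed through the transformer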

autoPyTorch/data/tabular_feature_validator.py
Lines changed: 127 additions & 98 deletions

@@ -1,5 +1,6 @@
 import functools
 import typing
+from typing import Dict, List

 import numpy as np
@@ -13,11 +14,108 @@
 from sklearn.base import BaseEstimator
 from sklearn.compose import ColumnTransformer
 from sklearn.exceptions import NotFittedError
+from sklearn.impute import SimpleImputer
+from sklearn.pipeline import make_pipeline

 from autoPyTorch.data.base_feature_validator import BaseFeatureValidator, SUPPORTED_FEAT_TYPES


+def _create_column_transformer(
+    preprocessors: Dict[str, List[BaseEstimator]],
+    categorical_columns: List[str],
+) -> ColumnTransformer:
+    """
+    Given a dictionary of preprocessors, this function
+    creates a sklearn column transformer with appropriate
+    columns associated with their preprocessors.
+
+    Args:
+        preprocessors (Dict[str, List[BaseEstimator]]):
+            Dictionary containing list of numerical and categorical preprocessors.
+        categorical_columns (List[str]):
+            List of names of categorical columns
+
+    Returns:
+        ColumnTransformer
+    """
+
+    categorical_pipeline = make_pipeline(*preprocessors['categorical'])
+
+    return ColumnTransformer([
+        ('categorical_pipeline', categorical_pipeline, categorical_columns)],
+        remainder='passthrough'
+    )
+
+
+def get_tabular_preprocessors() -> Dict[str, List[BaseEstimator]]:
+    """
+    This function creates a Dictionary containing a list
+    of numerical and categorical preprocessors
+
+    Returns:
+        Dict[str, List[BaseEstimator]]
+    """
+    preprocessors: Dict[str, List[BaseEstimator]] = dict()
+
+    # Categorical Preprocessors
+    onehot_encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value',
+                                                  unknown_value=-1)
+    categorical_imputer = SimpleImputer(strategy='constant', copy=False)
+
+    preprocessors['categorical'] = [categorical_imputer, onehot_encoder]
+
+    return preprocessors
+
+
 class TabularFeatureValidator(BaseFeatureValidator):
+    """
+    A subclass of `BaseFeatureValidator` made for tabular data.
+    It ensures that the dataset provided is of the expected format.
+    Subsequently, it preprocesses the data by fitting a column
+    transformer.
+
+    Attributes:
+        categories (List[List[str]]):
+            List for which an element at each index is a
+            list containing the categories for the respective
+            categorical column.
+        transformed_columns (List[str])
+            List of columns that were transformed.
+        column_transformer (Optional[BaseEstimator])
+            Hosts an imputer and an encoder object if the data
+            requires transformation (for example, if provided a
+            categorical column in a pandas DataFrame)
+        column_order (List[str]):
+            List of the features stored in the order that
+            was fitted.
+        numerical_columns (List[int]):
+            List of indices of numerical columns
+        categorical_columns (List[int]):
+            List of indices of categorical columns
+    """
+    @staticmethod
+    def _comparator(cmp1: str, cmp2: str) -> int:
+        """Order so that categorical columns come left and numerical columns come right
+
+        Args:
+            cmp1 (str): First variable to compare
+            cmp2 (str): Second variable to compare
+
+        Raises:
+            ValueError: if the values of the variables to compare
+            are not in 'categorical' or 'numerical'
+
+        Returns:
+            int: either 0, -1 or 1
+        """
+        choices = ['categorical', 'numerical']
+        if cmp1 not in choices or cmp2 not in choices:
+            raise ValueError('The comparator for the column order only accepts {}, '
+                             'but got {} and {}'.format(choices, cmp1, cmp2))
+
+        idx1, idx2 = choices.index(cmp1), choices.index(cmp2)
+        return idx1 - idx2
+
     def _fit(
         self,
         X: SUPPORTED_FEAT_TYPES,
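Taken together, the two new helpers build the whole categorical preprocessing step: impute missing values with a constant, then ordinal-encode, routing only the categorical columns through that pipeline. A minimal self-contained sketch of the same recipe (the toy DataFrame and column names are made up for illustration):

import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder

# Same recipe as get_tabular_preprocessors(): impute missing categoricals with a
# constant, then ordinal-encode, mapping categories unseen at fit time to -1.
categorical_pipeline = make_pipeline(
    SimpleImputer(strategy='constant', copy=False),
    OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1),
)

# As in _create_column_transformer: only the categorical columns go through the
# pipeline, everything else is passed through untouched.
column_transformer = ColumnTransformer(
    [('categorical_pipeline', categorical_pipeline, ['color'])],
    remainder='passthrough',
)

X = pd.DataFrame({
    'color': pd.Series(['red', 'blue', None], dtype='category'),
    'size': [1.0, 2.0, 3.0],
})

# First output column: encoded 'color' (the NaN becomes its own imputed category);
# second column: 'size', passed through unchanged.
print(column_transformer.fit_transform(X))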
@@ -60,51 +158,38 @@ def _fit(
         if not X.select_dtypes(include='object').empty:
             X = self.infer_objects(X)

-        self.enc_columns, self.feat_type = self._get_columns_to_encode(X)
+        self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)

-        if len(self.enc_columns) > 0:
-            X = self.impute_nan_in_categories(X)
+        assert self.feat_type is not None

-            self.encoder = ColumnTransformer(
-                [
-                    ("encoder",
-                     preprocessing.OrdinalEncoder(
-                         handle_unknown='use_encoded_value',
-                         unknown_value=-1,
-                     ), self.enc_columns)],
-                remainder="passthrough"
+        if len(self.transformed_columns) > 0:
+
+            preprocessors = get_tabular_preprocessors()
+            self.column_transformer = _create_column_transformer(
+                preprocessors=preprocessors,
+                categorical_columns=self.transformed_columns,
             )

             # Mypy redefinition
-            assert self.encoder is not None
-            self.encoder.fit(X)
-
-            # The column transformer reoders the feature types - we therefore need to change
-            # it as well
-            # This means columns are shifted to the right
-            def comparator(cmp1: str, cmp2: str) -> int:
-                if (
-                    cmp1 == 'categorical' and cmp2 == 'categorical'
-                    or cmp1 == 'numerical' and cmp2 == 'numerical'
-                ):
-                    return 0
-                elif cmp1 == 'categorical' and cmp2 == 'numerical':
-                    return -1
-                elif cmp1 == 'numerical' and cmp2 == 'categorical':
-                    return 1
-                else:
-                    raise ValueError((cmp1, cmp2))
+            assert self.column_transformer is not None
+            self.column_transformer.fit(X)

+            # The column transformer reorders the feature types
+            # therefore, we need to change the order of columns as well
+            # This means categorical columns are shifted to the left
             self.feat_type = sorted(
                 self.feat_type,
-                key=functools.cmp_to_key(comparator)
+                key=functools.cmp_to_key(self._comparator)
             )

+            encoded_categories = self.column_transformer.\
+                named_transformers_['categorical_pipeline'].\
+                named_steps['ordinalencoder'].categories_
             self.categories = [
                 # We fit an ordinal encoder, where all categorical
                 # columns are shifted to the left
                 list(range(len(cat)))
-                for cat in self.encoder.transformers_[0][1].categories_
+                for cat in encoded_categories
             ]

         for i, type_ in enumerate(self.feat_type):
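To see the reordering concretely, here is a minimal sketch of the sort that _fit applies, using the same comparator as the diff above (the feature-type list is a made-up example):

import functools

def _comparator(cmp1: str, cmp2: str) -> int:
    # Categorical entries sort before numerical ones, mirroring how the
    # ColumnTransformer shifts the encoded categorical columns to the left.
    choices = ['categorical', 'numerical']
    if cmp1 not in choices or cmp2 not in choices:
        raise ValueError('The comparator for the column order only accepts {}, '
                         'but got {} and {}'.format(choices, cmp1, cmp2))
    return choices.index(cmp1) - choices.index(cmp2)

feat_type = ['numerical', 'categorical', 'numerical', 'categorical']
print(sorted(feat_type, key=functools.cmp_to_key(_comparator)))
# ['categorical', 'categorical', 'numerical', 'numerical']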
@@ -158,7 +243,7 @@ def transform(
         self._check_data(X)

         # Pandas related transformations
-        if hasattr(X, "iloc") and self.encoder is not None:
+        if hasattr(X, "iloc") and self.column_transformer is not None:
             if np.any(pd.isnull(X)):
                 # After above check it means that if there is a NaN
                 # the whole column must be NaN

@@ -167,11 +252,7 @@ def transform(
                     if X[column].isna().all():
                         X[column] = pd.to_numeric(X[column])

-            # We also need to fillna on the transformation
-            # in case test data is provided
-            X = self.impute_nan_in_categories(X)
-
-            X = self.encoder.transform(X)
+            X = self.column_transformer.transform(X)

         # Sparse related transformations
         # Not all sparse format support index sorting
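One behavioural point at transform time: because the encoder is built with handle_unknown='use_encoded_value' and unknown_value=-1, categories that appear only in test data are mapped to -1 instead of raising. A small sketch of that behaviour (toy data, not from the repository):

import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

enc = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
enc.fit(pd.DataFrame({'color': ['blue', 'red']}))

# 'green' was never seen during fit, so it is encoded as -1 rather than erroring.
print(enc.transform(pd.DataFrame({'color': ['red', 'green']})))
# [[ 1.]
#  [-1.]]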
@@ -245,7 +326,7 @@ def _check_data(

         # Define the column to be encoded here as the feature validator is fitted once
         # per estimator
-        enc_columns, _ = self._get_columns_to_encode(X)
+        self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)

         column_order = [column for column in X.columns]
         if len(self.column_order) > 0:
@@ -282,13 +363,17 @@ def _get_columns_to_encode(
             A set of features that are going to be validated (type and dimensionality
             checks) and a encoder fitted in the case the data needs encoding
         Returns:
-            enc_columns (List[str]):
+            transformed_columns (List[str]):
                 Columns to encode, if any
             feat_type:
                 Type of each column numerical/categorical
         """
+
+        if len(self.transformed_columns) > 0 and self.feat_type is not None:
+            return self.transformed_columns, self.feat_type
+
         # Register if a column needs encoding
-        enc_columns = []
+        transformed_columns = []

         # Also, register the feature types for the estimator
         feat_type = []

@@ -297,7 +382,7 @@ def _get_columns_to_encode(
         for i, column in enumerate(X.columns):
             if X[column].dtype.name in ['category', 'bool']:

-                enc_columns.append(column)
+                transformed_columns.append(column)
                 feat_type.append('categorical')
             # Move away from np.issubdtype as it causes
             # TypeError: data type not understood in certain pandas types

@@ -339,7 +424,7 @@ def _get_columns_to_encode(
             )
         else:
             feat_type.append('numerical')
-        return enc_columns, feat_type
+        return transformed_columns, feat_type

     def list_to_dataframe(
         self,
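A minimal sketch of the dtype check that decides which columns end up in transformed_columns: pandas 'category' and 'bool' columns are treated as categorical, everything else falls through to the numerical checks (toy DataFrame for illustration):

import pandas as pd

X = pd.DataFrame({
    'color': pd.Series(['red', 'blue'], dtype='category'),
    'flag': [True, False],
    'size': [1.0, 2.0],
})

transformed_columns = [c for c in X.columns if X[c].dtype.name in ['category', 'bool']]
print(transformed_columns)  # ['color', 'flag']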
@@ -429,59 +514,3 @@ def infer_objects(self, X: pd.DataFrame) -> pd.DataFrame:
         self.object_dtype_mapping = {column: X[column].dtype for column in X.columns}
         self.logger.debug(f"Infer Objects: {self.object_dtype_mapping}")
         return X
-
-    def impute_nan_in_categories(self, X: pd.DataFrame) -> pd.DataFrame:
-        """
-        impute missing values before encoding,
-        remove once sklearn natively supports
-        it in ordinal encoding. Sklearn issue:
-        "https://github.com/scikit-learn/scikit-learn/issues/17123)"
-
-        Arguments:
-            X (pd.DataFrame):
-                data to be interpreted.
-
-        Returns:
-            pd.DataFrame
-        """
-
-        # To be on the safe side, map always to the same missing
-        # value per column
-        if not hasattr(self, 'dict_nancol_to_missing'):
-            self.dict_missing_value_per_col: typing.Dict[str, typing.Any] = {}
-
-        # First make sure that we do not alter the type of the column which cause:
-        # TypeError: '<' not supported between instances of 'int' and 'str'
-        # in the encoding
-        for column in self.enc_columns:
-            if X[column].isna().any():
-                if column not in self.dict_missing_value_per_col:
-                    try:
-                        float(X[column].dropna().values[0])
-                        can_cast_as_number = True
-                    except Exception:
-                        can_cast_as_number = False
-                    if can_cast_as_number:
-                        # In this case, we expect to have a number as category
-                        # it might be string, but its value represent a number
-                        missing_value: typing.Union[str, int] = '-1' if isinstance(X[column].dropna().values[0],
-                                                                                   str) else -1
-                    else:
-                        missing_value = 'Missing!'
-
-                    # Make sure this missing value is not seen before
-                    # Do this check for categorical columns
-                    # else modify the value
-                    if hasattr(X[column], 'cat'):
-                        while missing_value in X[column].cat.categories:
-                            if isinstance(missing_value, str):
-                                missing_value += '0'
-                            else:
-                                missing_value += missing_value
-                    self.dict_missing_value_per_col[column] = missing_value
-
-                # Convert the frame in place
-                X[column].cat.add_categories([self.dict_missing_value_per_col[column]],
-                                             inplace=True)
-                X.fillna({column: self.dict_missing_value_per_col[column]}, inplace=True)
-        return X
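With this hand-rolled method removed, missing categorical values are handled by the SimpleImputer stage of the new categorical pipeline instead. A minimal sketch of the replacement behaviour (toy data; the constant fill value shown is sklearn's default for non-numeric input):

import pandas as pd
from sklearn.impute import SimpleImputer

X = pd.DataFrame({'color': pd.Series(['red', None, 'blue'], dtype='category')})

# strategy='constant' on non-numeric data fills NaNs with the string
# 'missing_value' by default, so the downstream OrdinalEncoder sees a
# regular category instead of a NaN.
imputer = SimpleImputer(strategy='constant', copy=False)
print(imputer.fit_transform(X))
# [['red']
#  ['missing_value']
#  ['blue']]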

autoPyTorch/pipeline/components/setup/network_backbone/InceptionTimeBackbone.py
Lines changed: 1 addition & 1 deletion

@@ -79,7 +79,7 @@ def __init__(self, n_res_inputs: int, n_outputs: int):
     def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
         shortcut = self.shortcut(res)
         shortcut = self.bn(shortcut)
-        x += shortcut
+        x = x + shortcut
         return torch.relu(x)
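The switch from x += shortcut to x = x + shortcut avoids modifying x in place; if autograd still needs the original tensor for the backward pass, the in-place version can raise a runtime error, whereas the out-of-place addition allocates a new tensor. A small hypothetical reproduction of that failure mode (not code from the repository):

import torch

a = torch.randn(3, requires_grad=True)
x = torch.relu(a)        # relu saves its output for the backward pass

x += torch.ones(3)       # in-place: overwrites the tensor relu needs for backward
try:
    x.sum().backward()
except RuntimeError as err:
    print('in-place version fails:', err)

x = torch.relu(a)
x = x + torch.ones(3)    # out-of-place: the saved tensor stays intact
x.sum().backward()
print(a.grad)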

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
 {
-    "n_estimators" : 300,
-    "n_jobs" : -1
+    "n_estimators" : 300
 }
Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
 {
-    "weights" : "uniform",
-    "n_jobs" : -1
+    "weights" : "uniform"
 }

autoPyTorch/pipeline/components/setup/traditional_ml/estimator_configs/lgb.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,5 @@
55
"min_data_in_leaf" : 3,
66
"feature_fraction" : 0.9,
77
"boosting_type" : "gbdt",
8-
"learning_rate" : 0.03,
9-
"num_threads" : -1
8+
"learning_rate" : 0.03
109
}
Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
 {
-    "n_estimators" : 300,
-    "n_jobs" : -1
+    "n_estimators" : 300
 }
Lines changed: 0 additions & 1 deletion

@@ -1,3 +1,2 @@
 {
-    "n_jobs" : -1
 }
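These estimator-config diffs all drop the n_jobs / num_threads entries, matching the commit-message bullet "remove parallel capabilities of traditional learners to resolve bug in docs building": the traditional learners now run with their libraries' default, non-parallel settings.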
