Skip to content

Commit a04ba08

Browse files
committed
add shuhei's suggestion
1 parent efde1d5 commit a04ba08

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

autoPyTorch/data/tabular_feature_validator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def _fit(
158158
if not X.select_dtypes(include='object').empty:
159159
X = self.infer_objects(X)
160160

161-
if len(self.transformed_columns) == 0 and self.feat_type is None:
162-
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
161+
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
162+
163163
assert self.feat_type is not None
164164

165165
if len(self.transformed_columns) > 0:
@@ -326,8 +326,7 @@ def _check_data(
326326

327327
# Define the column to be encoded here as the feature validator is fitted once
328328
# per estimator
329-
if len(self.transformed_columns) == 0 and self.feat_type is None:
330-
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
329+
self.transformed_columns, self.feat_type = self._get_columns_to_encode(X)
331330

332331
column_order = [column for column in X.columns]
333332
if len(self.column_order) > 0:
@@ -369,6 +368,10 @@ def _get_columns_to_encode(
369368
feat_type:
370369
Type of each column numerical/categorical
371370
"""
371+
372+
if len(self.transformed_columns) > 0 and self.feat_type is not None:
373+
return self.transformed_columns, self.feat_type
374+
372375
# Register if a column needs encoding
373376
transformed_columns = []
374377

0 commit comments

Comments
 (0)