@@ -158,8 +158,8 @@ def _fit(
158
158
if not X .select_dtypes (include = 'object' ).empty :
159
159
X = self .infer_objects (X )
160
160
161
- if len ( self .transformed_columns ) == 0 and self .feat_type is None :
162
- self . transformed_columns , self . feat_type = self . _get_columns_to_encode ( X )
161
+ self .transformed_columns , self .feat_type = self . _get_columns_to_encode ( X )
162
+
163
163
assert self .feat_type is not None
164
164
165
165
if len (self .transformed_columns ) > 0 :
@@ -326,8 +326,7 @@ def _check_data(
326
326
327
327
# Define the column to be encoded here as the feature validator is fitted once
328
328
# per estimator
329
- if len (self .transformed_columns ) == 0 and self .feat_type is None :
330
- self .transformed_columns , self .feat_type = self ._get_columns_to_encode (X )
329
+ self .transformed_columns , self .feat_type = self ._get_columns_to_encode (X )
331
330
332
331
column_order = [column for column in X .columns ]
333
332
if len (self .column_order ) > 0 :
@@ -369,6 +368,10 @@ def _get_columns_to_encode(
369
368
feat_type:
370
369
Type of each column numerical/categorical
371
370
"""
371
+
372
+ if len (self .transformed_columns ) > 0 and self .feat_type is not None :
373
+ return self .transformed_columns , self .feat_type
374
+
372
375
# Register if a column needs encoding
373
376
transformed_columns = []
374
377
0 commit comments