Created
October 28, 2021 19:55
-
-
Save GuiMarthe/5048a0fdf415813577dc465457df26e1 to your computer and use it in GitHub Desktop.
This is an awesome function to get feature names from pipelines! Blatantly stolen from https://johaupt.github.io/scikit-learn/tutorial/python/data%20processing/ml%20pipeline/model%20interpretation/columnTransformer_feature_names.html
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_feature_names(column_transformer):
    """Get feature names from all transformers.

    Walks the transformers of a fitted ``ColumnTransformer`` (or the steps
    of a ``Pipeline``) and collects one name per output feature. Pipelines
    found among the transformers are processed recursively.

    Names taken from a transformer are prefixed with ``"<name>__"``, where
    ``<name>`` is the transformer/step name. Transformers without a
    ``get_feature_names`` method trigger a warning and fall back to the
    prefixed input column names (or nothing, when no columns are known, as
    is the case for the steps of a pipeline).

    Parameters
    ----------
    column_transformer : ColumnTransformer or Pipeline
        Fitted object exposing the private ``_iter`` method.

    Returns
    -------
    feature_names : list of strings
        Names of the features produced by transform.

    Notes
    -----
    NOTE(review): this relies on private attributes (``_iter``,
    ``_df_columns``, ``_n_features``) and on the legacy
    ``get_feature_names`` transformer method; confirm they exist in the
    installed scikit-learn version.
    """
    # Function-scope imports keep the snippet self-contained when it is
    # pasted into a module or notebook that has not imported these names.
    import warnings

    import numpy as np
    from sklearn.pipeline import Pipeline

    def as_column_list(column):
        # A scalar column spec ('age' or 3) must be treated as ONE column;
        # iterating a bare string would otherwise yield its characters.
        if isinstance(column, (str, int)) and not isinstance(column, bool):
            return [column]
        return column

    # Turn lookup into a function for better handling with pipelines later.
    # ``name`` and ``column`` are passed explicitly instead of being read
    # from the enclosing loop's variables.
    def get_names(name, trans, column):
        column = as_column_list(column)
        # >> Original get_feature_names() method
        if trans == 'drop' or (
                hasattr(column, '__len__') and not len(column)):
            return []
        if trans == 'passthrough':
            if hasattr(column_transformer, '_df_columns'):
                if ((not isinstance(column, slice))
                        and all(isinstance(col, str) for col in column)):
                    return list(column)
                return list(column_transformer._df_columns[column])
            indices = np.arange(column_transformer._n_features)
            return ['x%d' % i for i in indices[column]]
        if not hasattr(trans, 'get_feature_names'):
            # >>> Change: return input column names if no method available.
            # Turn error into a warning.
            warnings.warn("Transformer %s (type %s) does not "
                          "provide get_feature_names. "
                          "Will return input column names if available"
                          % (str(name), type(trans).__name__))
            # For transformers without a get_feature_names method, use the
            # input names to the column transformer.
            if column is None:
                return []
            # str() keeps non-string labels (e.g. integer positions) from
            # raising TypeError during concatenation.
            return [name + "__" + str(f) for f in column]
        return [name + "__" + str(f) for f in trans.get_feature_names()]

    # ## Start of processing
    # Allow transformers to be pipelines. Pipeline steps are named
    # differently, so preprocessing is needed: steps carry no column
    # information, hence the ``None`` placeholders.
    if isinstance(column_transformer, Pipeline):
        l_transformers = [(name, trans, None, None)
                          for _, name, trans in column_transformer._iter()]
    else:
        # For column transformers, follow the original method.
        l_transformers = list(column_transformer._iter(fitted=True))

    feature_names = []
    for name, trans, column, _ in l_transformers:
        if isinstance(trans, Pipeline):
            # Recursive call on pipeline.
            _names = get_feature_names(trans)
            # If the pipeline has no transformer that returns names, fall
            # back to the prefixed input columns. ``column`` is None for a
            # pipeline nested directly in a pipeline; return nothing then,
            # matching the no-columns fallback in ``get_names``.
            if not _names and column is not None:
                _names = [name + "__" + str(f)
                          for f in as_column_list(column)]
            feature_names.extend(_names)
        else:
            feature_names.extend(get_names(name, trans, column))
    return feature_names
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment