Column Transformer
Real-world datasets often mix numerical and categorical columns. sorix provides the ColumnTransformer class, which applies different preprocessing steps to different subsets of columns of the same dataset in a single operation.
1. ColumnTransformer API
The ColumnTransformer class takes a list of transformers, where each entry is a tuple containing:
- A descriptive name for the transformation (it also becomes the prefix of the output feature names, e.g. num_scal_age).
- The transformer object (scaler, encoder, etc.).
- The list of columns to which the transformation should be applied.
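Conceptually, a column transformer routes each column subset to its own transformer, fits it, and concatenates the outputs side by side. The plain-Python sketch below illustrates that idea under stated assumptions: MiniScaler, MiniOneHot, and MiniColumnTransformer are hypothetical names used only for illustration and are not sorix's implementation. The name_column feature-naming scheme mirrors the output shown later on this page.

```python
import statistics


class MiniScaler:
    """Hypothetical standard scaler: (x - mean) / std for each column."""

    def fit(self, cols):
        # cols maps column name -> list of numbers; store population mean and std
        self.stats = {c: (statistics.fmean(v), statistics.pstdev(v)) for c, v in cols.items()}
        return self

    def transform(self, cols):
        # return one output column (a list) per input column
        out = []
        for c, values in cols.items():
            mean, std = self.stats[c]
            out.append([(x - mean) / std if std else 0.0 for x in values])
        return out

    def feature_names(self):
        return list(self.stats)


class MiniOneHot:
    """Hypothetical one-hot encoder: one 0/1 column per sorted category."""

    def fit(self, cols):
        self.categories = {c: sorted(set(v)) for c, v in cols.items()}
        return self

    def transform(self, cols):
        out = []
        for c, values in cols.items():
            for cat in self.categories[c]:
                out.append([1.0 if v == cat else 0.0 for v in values])
        return out

    def feature_names(self):
        return [f"{c}_{cat}" for c, cats in self.categories.items() for cat in cats]


class MiniColumnTransformer:
    """Hypothetical container: routes column subsets to transformers, then concatenates."""

    def __init__(self, transformers):
        self.transformers = transformers  # list of (name, transformer, columns) tuples

    def fit_transform(self, data):
        out_columns = []
        for _name, tr, cols in self.transformers:
            subset = {c: data[c] for c in cols}  # select only this transformer's columns
            out_columns.extend(tr.fit(subset).transform(subset))
        # transpose the list of output columns into a list of rows
        return [list(row) for row in zip(*out_columns)]

    def get_feature_names(self):
        # prefix each transformer's feature names with its descriptive name
        return [f"{name}_{f}" for name, tr, _cols in self.transformers for f in tr.feature_names()]


data = {
    "age": [25, 30, 45, 50, 22, 35],
    "city": ["NYC", "SF", "NYC", "LA", "SF", "LA"],
}
mct = MiniColumnTransformer([
    ("num_scal", MiniScaler(), ["age"]),
    ("cat_enc", MiniOneHot(), ["city"]),
])
rows = mct.fit_transform(data)
print(len(rows), len(rows[0]))  # 6 4
print(mct.get_feature_names())
# ['num_scal_age', 'cat_enc_city_LA', 'cat_enc_city_NYC', 'cat_enc_city_SF']
```

The key design point is that each transformer only ever sees its own columns, so its fitted state stays independent of the others; the container's job is just routing and concatenation.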
Practical Example
Let's assume we have a dataset with numerical (e.g., age, income) and categorical (e.g., city, gender) columns.
# Uncomment the next line and run this cell to install sorix
#!pip install 'sorix @ git+https://github.com/Mitchell-Mirano/sorix.git@main'
import numpy as np
import pandas as pd
import sorix
from sorix.preprocessing import ColumnTransformer, StandardScaler, OneHotEncoder
# Create sample data with mixed types
data = {
    'age': [25, 30, 45, 50, 22, 35],
    'income': [50000, 60000, 75000, 80000, 45000, 70000],
    'city': ['NYC', 'SF', 'NYC', 'LA', 'SF', 'LA'],
    'gender': ['M', 'F', 'M', 'F', 'NB', 'M']
}
X = pd.DataFrame(data)
X
# Define column transformations
ct = ColumnTransformer([
    ('num_scal', StandardScaler(), ['age', 'income']),
    ('cat_enc', OneHotEncoder(), ['city', 'gender'])
])
# Apply the transformations
X_processed = ct.fit_transform(X)
print("Processed data shape:", X_processed.shape)
print("Feature names after ColumnTransformer:")
print(ct.get_features_names())
Processed data shape: (6, 8)
Feature names after ColumnTransformer:
['num_scal_age', 'num_scal_income', 'cat_enc_city_LA', 'cat_enc_city_NYC', 'cat_enc_city_SF', 'cat_enc_gender_F', 'cat_enc_gender_M', 'cat_enc_gender_NB']
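The shape (6, 8) follows from simple counting: each scaled numeric column yields one output column, and each one-hot encoded column yields one column per distinct category. The short plain-Python check below reproduces that count from the sample data; it assumes, as the printed feature names suggest, that OneHotEncoder creates exactly one column per category seen during fitting.

```python
data = {
    "age": [25, 30, 45, 50, 22, 35],
    "income": [50000, 60000, 75000, 80000, 45000, 70000],
    "city": ["NYC", "SF", "NYC", "LA", "SF", "LA"],
    "gender": ["M", "F", "M", "F", "NB", "M"],
}
numeric = ["age", "income"]
categorical = ["city", "gender"]

n_rows = len(data["age"])
# one output column per scaled numeric feature
n_numeric_out = len(numeric)
# one output column per distinct category in each categorical feature (3 cities + 3 genders)
n_categorical_out = sum(len(set(data[c])) for c in categorical)

print((n_rows, n_numeric_out + n_categorical_out))  # (6, 8)
```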
2. Persistence: Saving and Loading
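Saving the fitted ColumnTransformer is crucial because it encapsulates the state of all its internal scalers and encoders (for example, the statistics learned by StandardScaler and the category-to-column mapping learned by OneHotEncoder). If any component is refitted or its internal mapping changes, the output columns no longer line up with what the downstream model was trained on, and the pipeline breaks at inference time.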
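To see why a refitted component breaks inference, consider a one-hot mapping rebuilt from an inference batch instead of being reloaded. The plain-Python sketch below uses a hypothetical fit_categories helper (not part of sorix) that sorts the categories it sees, mirroring the alphabetical order of the feature names above. The training mapping produces three city columns, while a mapping refitted on a small batch without "LA" produces only two, and "SF" lands in a different position.

```python
def fit_categories(values):
    """Hypothetical helper: map each distinct category to a column index (sorted order)."""
    return {cat: i for i, cat in enumerate(sorted(set(values)))}


def one_hot(value, mapping):
    """Encode a single value as a one-hot list using a fixed mapping."""
    row = [0] * len(mapping)
    row[mapping[value]] = 1
    return row


train_mapping = fit_categories(["NYC", "SF", "NYC", "LA", "SF", "LA"])
refit_mapping = fit_categories(["SF", "NYC"])  # inference batch that happens to lack "LA"

print(train_mapping)                 # {'LA': 0, 'NYC': 1, 'SF': 2}
print(refit_mapping)                 # {'NYC': 0, 'SF': 1}
print(one_hot("SF", train_mapping))  # [0, 0, 1]
print(one_hot("SF", refit_mapping))  # [0, 1] -> wrong width and wrong position
```

Reusing the saved training mapping at inference keeps every row in the layout the model was trained on.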
In sorix, we recommend using the .sor extension for all saved artifacts.
A. Direct Saving with sorix.save (Recommended)
You can save the entire fitted pipeline, including every internal transformer, to a single file with sorix.save and restore it with sorix.load.
# Save the entire ColumnTransformer pipeline using .sor extension
sorix.save(ct, 'full_pipeline.sor')
# Load it back
loaded_ct = sorix.load('full_pipeline.sor')
# Verify with original data
assert np.allclose(ct.transform(X), loaded_ct.transform(X))
print("ColumnTransformer pipeline successfully saved and reloaded (.sor)!")
ColumnTransformer pipeline successfully saved and reloaded (.sor)!
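B. Recursive State Dictionary Pattern
The state_dict() and load_state_dict() methods handle the internal transformer states recursively: state_dict() collects the state of every internal transformer into a single dictionary, and load_state_dict() writes that state back into a pipeline with the same architecture. This dictionary can also be saved to a .sor file with sorix.save.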
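The recursive pattern can be sketched in plain Python, assuming each leaf transformer reports and restores its own fitted values while the container nests them under the transformer names. Leaf, Container, and the dictionary layout below are hypothetical illustrations; the exact structure returned by sorix's state_dict() may differ. The value 34.5 is simply the mean of the sample ages.

```python
import copy


class Leaf:
    """Hypothetical fitted transformer holding a flat dict of learned values."""

    def __init__(self, **state):
        self.state = dict(state)

    def state_dict(self):
        return copy.deepcopy(self.state)

    def load_state_dict(self, state):
        self.state = copy.deepcopy(state)


class Container:
    """Hypothetical column transformer that delegates state to its children by name."""

    def __init__(self, transformers):
        self.transformers = transformers  # list of (name, transformer, columns) tuples

    def state_dict(self):
        # recurse into each child and nest its state under the child's name
        return {name: tr.state_dict() for name, tr, _cols in self.transformers}

    def load_state_dict(self, state):
        # names must match: a missing key means a different architecture
        for name, tr, _cols in self.transformers:
            if name not in state:
                raise KeyError(f"no saved state for transformer {name!r}")
            tr.load_state_dict(state[name])


fitted = Container([
    ("num_scal", Leaf(mean=[34.5]), ["age"]),
    ("cat_enc", Leaf(categories=["LA", "NYC", "SF"]), ["city"]),
])
state = fitted.state_dict()
print(state)
# {'num_scal': {'mean': [34.5]}, 'cat_enc': {'categories': ['LA', 'NYC', 'SF']}}

# a fresh, unfitted container with the same names receives the saved state
fresh = Container([
    ("num_scal", Leaf(), ["age"]),
    ("cat_enc", Leaf(), ["city"]),
])
fresh.load_state_dict(state)
print(fresh.state_dict() == state)  # True
```

Keying the nested state by transformer name is what makes "same architecture" a requirement: the fresh pipeline must expose the same names so each piece of state finds its owner.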
# 1. Extract the recursive state dictionary
params_dict = ct.state_dict()
# 2. Save the dictionary with sorix.save (.sor extension)
sorix.save(params_dict, 'pipeline_params.sor')
# 3. Load the dictionary back
loaded_params = sorix.load('pipeline_params.sor')
# 4. Load the saved state into a fresh, unfitted pipeline with the same architecture
new_ct = ColumnTransformer([
    ('num_scal', StandardScaler(), ['age', 'income']),
    ('cat_enc', OneHotEncoder(), ['city', 'gender'])
])
new_ct.load_state_dict(loaded_params)
# 5. Verify results
assert np.allclose(ct.transform(X), new_ct.transform(X))
print("Full ColumnTransformer state successfully saved and reloaded!")
Full ColumnTransformer state successfully saved and reloaded!