Skip to content

Commit 1329880

Browse files
committed
refactor(training): extract utility functions for problem type and ID detection
Moved common logic for detecting problem types and ID columns into a new utils module. This improves code organization and reusability across preprocessing and EDA components.

Modified files (3):
- backend/training/eda.py: Refactored to use utility functions
- backend/training/preprocessor.py: Updated to utilize shared utilities
- backend/training/utils.py: Added new utility functions for detection logic
1 parent 83f13ac commit 1329880

3 files changed

Lines changed: 207 additions & 140 deletions

File tree

backend/training/eda.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import pandas as pd
22
import numpy as np
33
from typing import List, Tuple
4-
import re
4+
5+
# Import shared utilities
6+
from .utils import detect_problem_type, is_id_column
57

68

79
def generate_eda_report(df: pd.DataFrame, target_column: str, output_path: str):
@@ -33,13 +35,6 @@ def generate_eda_report(df: pd.DataFrame, target_column: str, output_path: str):
3335
class EDAReportGenerator:
3436
"""Generate comprehensive EDA report with CSS-only visualizations"""
3537

36-
# Common ID column patterns
37-
ID_PATTERNS = [
38-
r'^id$', r'_id$', r'^id_', r'^uuid$', r'^guid$',
39-
r'order.*id', r'customer.*id', r'user.*id', r'transaction.*id',
40-
r'^index$', r'^row.*num', r'^serial', r'^record.*id',
41-
]
42-
4338
def __init__(self, df: pd.DataFrame, target_column: str):
4439
self.df = df
4540
self.target_column = target_column
@@ -53,38 +48,12 @@ def __init__(self, df: pd.DataFrame, target_column: str):
5348
self._analyze_columns()
5449

5550
def _detect_problem_type(self) -> str:
56-
"""Detect if classification or regression"""
57-
# Guard against empty target
58-
if len(self.target) == 0:
59-
return 'classification' # Default fallback
60-
61-
if pd.api.types.is_numeric_dtype(self.target):
62-
unique_ratio = self.target.nunique() / len(self.target)
63-
if unique_ratio < 0.05 or self.target.nunique() < 20:
64-
return 'classification'
65-
return 'regression'
66-
return 'classification'
51+
"""Detect if classification or regression using shared utility."""
52+
return detect_problem_type(self.target)
6753

6854
def _is_id_column(self, col_name: str, series: pd.Series) -> bool:
69-
"""Check if column is likely an ID"""
70-
col_lower = col_name.lower().strip()
71-
for pattern in self.ID_PATTERNS:
72-
if re.search(pattern, col_lower):
73-
return True
74-
75-
# Check if all unique and sequential
76-
if pd.api.types.is_numeric_dtype(series):
77-
if series.nunique() == len(series):
78-
sorted_vals = series.sort_values()
79-
if (sorted_vals.diff().dropna() == 1).all():
80-
return True
81-
82-
# High cardinality string column
83-
if series.dtype == 'object':
84-
if series.nunique() / len(series) > 0.95:
85-
return True
86-
87-
return False
55+
"""Check if column is likely an ID using shared utility."""
56+
return is_id_column(col_name, series)
8857

8958
def _analyze_columns(self):
9059
"""Analyze and categorize columns"""

backend/training/preprocessor.py

Lines changed: 13 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,23 @@
11
import pandas as pd
22
import numpy as np
3-
import re
43
from sklearn.model_selection import train_test_split
54
from sklearn.preprocessing import LabelEncoder, StandardScaler
65
from typing import Tuple, List
76

87
# Feature-engine for robust feature selection
98
from feature_engine.selection import DropConstantFeatures, DropDuplicateFeatures
109

10+
# Import shared utilities
11+
from .utils import (
12+
detect_problem_type,
13+
is_id_column,
14+
is_high_cardinality_categorical,
15+
)
16+
1117

1218
class AutoPreprocessor:
1319
"""Automatic data preprocessing for AutoML"""
1420

15-
# Common patterns for ID/identifier columns (case insensitive)
16-
ID_PATTERNS = [
17-
r'^id$',
18-
r'_id$',
19-
r'^id_',
20-
r'_id_',
21-
r'^uuid$',
22-
r'^guid$',
23-
r'order.*id',
24-
r'customer.*id',
25-
r'user.*id',
26-
r'transaction.*id',
27-
r'product.*id',
28-
r'session.*id',
29-
r'^index$',
30-
r'^row.*num',
31-
r'^serial',
32-
r'^record.*id',
33-
]
34-
3521
def __init__(self, target_column: str):
3622
self.target_column = target_column
3723
self.label_encoders = {}
@@ -41,65 +27,6 @@ def __init__(self, target_column: str):
4127
self.categorical_columns = []
4228
self.dropped_columns = [] # Track dropped columns for reporting
4329

44-
def detect_id_column(self, col_name: str, series: pd.Series) -> bool:
45-
"""
46-
Detect if a column is likely an ID/identifier column.
47-
Uses both name patterns and data characteristics.
48-
"""
49-
col_lower = col_name.lower().strip()
50-
51-
# Check name patterns
52-
for pattern in self.ID_PATTERNS:
53-
if re.search(pattern, col_lower):
54-
return True
55-
56-
# Check data characteristics for numeric columns
57-
if pd.api.types.is_numeric_dtype(series):
58-
n_unique = series.nunique()
59-
n_total = len(series)
60-
61-
# If all values are unique and sequential, likely an ID
62-
if n_unique == n_total:
63-
# Check if values are sequential integers
64-
if series.dtype in ['int64', 'int32', 'int']:
65-
sorted_vals = series.sort_values()
66-
is_sequential = (sorted_vals.diff().dropna() == 1).all()
67-
if is_sequential:
68-
return True
69-
70-
# Check for string columns that look like IDs (high cardinality)
71-
if series.dtype == 'object':
72-
n_unique = series.nunique()
73-
n_total = len(series)
74-
75-
# If almost all values are unique, likely an ID
76-
if n_unique / n_total > 0.95:
77-
# Additional check: IDs often have consistent format
78-
sample = series.dropna().head(100)
79-
# Check if values look like codes/IDs (alphanumeric patterns)
80-
if sample.apply(lambda x: bool(re.match(r'^[A-Za-z0-9\-_]+$', str(x)))).mean() > 0.9:
81-
return True
82-
83-
return False
84-
85-
def detect_constant_column(self, series: pd.Series) -> bool:
86-
"""Detect if a column has only one unique value (constant)"""
87-
return series.nunique() <= 1
88-
89-
def detect_high_cardinality_categorical(self, series: pd.Series, threshold: float = 0.5) -> bool:
90-
"""
91-
Detect categorical columns with too many unique values.
92-
These often don't generalize well and can cause overfitting.
93-
"""
94-
if series.dtype != 'object':
95-
return False
96-
97-
n_unique = series.nunique()
98-
n_total = len(series)
99-
100-
# If more than 50% unique values, too high cardinality
101-
return n_unique / n_total > threshold
102-
10330
def detect_useless_columns_with_feature_engine(self, df: pd.DataFrame) -> Tuple[List[str], dict]:
10431
"""
10532
Use feature-engine to detect constant and duplicate columns.
@@ -160,14 +87,14 @@ def detect_useless_columns(self, df: pd.DataFrame) -> List[str]:
16087
series = df[col]
16188

16289
# Check for ID columns (name patterns + data characteristics)
163-
if self.detect_id_column(col, series):
90+
if is_id_column(col, series):
16491
useless_cols.append(col)
16592
reasons[col] = "identifier/ID column"
16693
continue
16794

16895
# Check for high cardinality categorical
16996
if series.dtype == 'object':
170-
if self.detect_high_cardinality_categorical(series, threshold=0.5):
97+
if is_high_cardinality_categorical(series, threshold=0.5):
17198
useless_cols.append(col)
17299
reasons[col] = f"high cardinality categorical ({series.nunique()} unique values)"
173100
continue
@@ -182,25 +109,9 @@ def detect_useless_columns(self, df: pd.DataFrame) -> List[str]:
182109
self.dropped_columns = useless_cols
183110
return useless_cols
184111

185-
def detect_problem_type(self, y: pd.Series) -> str:
186-
"""Detect if problem is classification or regression"""
187-
# Guard against empty target
188-
if len(y) == 0:
189-
return 'classification' # Default fallback
190-
191-
# Check if target is numeric
192-
if pd.api.types.is_numeric_dtype(y):
193-
# If numeric, check unique values ratio
194-
unique_ratio = y.nunique() / len(y)
195-
196-
# If less than 5% unique values or less than 20 unique values, likely classification
197-
if unique_ratio < 0.05 or y.nunique() < 20:
198-
return 'classification'
199-
else:
200-
return 'regression'
201-
else:
202-
# Non-numeric target is classification
203-
return 'classification'
112+
def _detect_problem_type(self, y: pd.Series) -> str:
113+
"""Detect if problem is classification or regression using shared utility."""
114+
return detect_problem_type(y)
204115

205116
def handle_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
206117
"""Handle missing values in the dataset"""
@@ -265,7 +176,7 @@ def preprocess(
265176
print(f"✂️ Removed {len(cols_to_drop)} column(s): {cols_to_drop}")
266177

267178
# Detect problem type
268-
problem_type = self.detect_problem_type(y)
179+
problem_type = self._detect_problem_type(y)
269180
print(f"Detected problem type: {problem_type}")
270181

271182
# Handle missing values

0 commit comments

Comments
 (0)