• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

idanmoradarthas / DataScienceUtils / 18126146414

30 Sep 2025 10:01AM UTC coverage: 98.63% (-1.4%) from 100.0%
18126146414

push

github

idanmoradarthas
Update images

576 of 584 relevant lines covered (98.63%)

11.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.17
/ds_utils/preprocess.py
1
"""Data preprocessing utilities."""
2

3
from typing import Callable, List, Optional, Union
12✔
4
import warnings
12✔
5

6
from matplotlib import axes, dates, pyplot as plt, ticker
12✔
7
import numpy as np
12✔
8
from numpy.random import RandomState
12✔
9
import pandas as pd
12✔
10
from scipy.cluster.hierarchy import dendrogram, linkage
12✔
11
from scipy.spatial.distance import squareform
12✔
12
import seaborn as sns
12✔
13
from sklearn.base import TransformerMixin
12✔
14
from sklearn.compose import ColumnTransformer
12✔
15
from sklearn.feature_selection import mutual_info_classif
12✔
16
from sklearn.impute import SimpleImputer
12✔
17
from sklearn.pipeline import Pipeline
12✔
18
from sklearn.preprocessing import OrdinalEncoder
12✔
19

20
from ds_utils.math_utils import safe_percentile
12✔
21

22

23
def _plot_clean_violin_distribution(
    series: pd.Series, include_outliers: bool, outlier_iqr_multiplier: float, ax: Optional[axes.Axes] = None, **kwargs
) -> axes.Axes:
    """Plot a violin distribution for a numeric series with optional outlier trimming.

    When ``include_outliers`` is False, values outside the IQR fence are removed
    before plotting. The fence is defined as
    [Q1 - k * IQR, Q3 + k * IQR], where ``k`` is ``outlier_iqr_multiplier``, and
    the bounds are clipped to the observed min/max of the series. Because the fence
    is applied with ``>=``/``<=`` comparisons, NA values are also dropped in that case.

    :param series: Numeric series to visualize. NA handling is expected upstream.
    :param include_outliers: Whether to include values outside the IQR fence.
    :param outlier_iqr_multiplier: Multiplier ``k`` used to compute the IQR fence.
    :param ax: Matplotlib Axes to draw on. If None, seaborn picks the Axes to draw on and
               that Axes is the one styled and returned.
    :param kwargs: Additional keyword arguments passed to ``seaborn.violinplot``.
    :return: The Axes object with the violin plot.
    """
    if include_outliers:
        series_plot = series.copy()
    else:
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        iqr = q3 - q1
        lower_bound = max(series.min(), q1 - outlier_iqr_multiplier * iqr)
        upper_bound = min(series.max(), q3 + outlier_iqr_multiplier * iqr)
        series_plot = series[(series >= lower_bound) & (series <= upper_bound)].copy()

    # Keep the Axes seaborn actually drew on, so ax=None no longer crashes on the calls below.
    ax = sns.violinplot(y=series_plot, hue=None, legend=False, ax=ax, **kwargs)

    # A single violin has no meaningful x positions.
    ax.set_xticks([])
    ax.set_ylabel("Values")
    ax.grid(axis="y", linestyle="--", alpha=0.7)

    return ax
59

60

61
def visualize_feature(
    series: pd.Series,
    remove_na: bool = False,
    *,
    include_outliers: bool = True,
    outlier_iqr_multiplier: float = 1.5,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Visualize a pandas Series with a chart chosen from its dtype.

    Chart by dtype:
    - Float: violin distribution. With ``include_outliers=False``, values outside the IQR
      fence [Q1 - k*IQR, Q3 + k*IQR] (``k=outlier_iqr_multiplier``) are trimmed first.
    - Datetime: line plot of per-timestamp value counts, sorted by time.
    - Anything else (object/categorical/bool/int): count plot. Series with more than 10
      distinct values are reduced to their 10 most frequent values plus "Other values".

    If the Axes has no title yet, it is titled "<name> (<dtype>)" and its x-label is cleared.
    For non-float charts the x tick labels are rotated 45 degrees; datetime ticks are
    rendered as dates.

    :param series: The data series to visualize.
    :param remove_na: If True, drop NA values before plotting; otherwise pass them through.
    :param include_outliers: Whether to keep outliers for float features.
    :param outlier_iqr_multiplier: IQR multiplier used to trim outliers for float features.
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Extra keyword arguments forwarded to the underlying plotting function
                   (``seaborn.violinplot``, ``Series.plot``, or ``seaborn.countplot``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    feature_series = series
    if remove_na:
        feature_series = series.dropna()

    is_float = pd.api.types.is_float_dtype(feature_series)
    is_datetime = pd.api.types.is_datetime64_any_dtype(feature_series)

    if is_float:
        ax = _plot_clean_violin_distribution(feature_series, include_outliers, outlier_iqr_multiplier, ax, **kwargs)
    elif is_datetime:
        counts_over_time = feature_series.value_counts().sort_index()
        counts_over_time.plot(kind="line", ax=ax, **kwargs)
        tick_labels = ax.get_xticks()
    else:
        sns.countplot(x=_copy_series_or_keep_top_10(feature_series), ax=ax, **kwargs)
        tick_labels = ax.get_xticklabels()

    if not ax.get_title():
        ax.set_title(f"{feature_series.name} ({feature_series.dtype})")
        ax.set_xlabel("")

    if not is_float:
        # Violin plots hide their x ticks, so only the other chart kinds get rotated labels.
        # Pinning the current positions first keeps set_xticklabels from warning about a non-fixed locator.
        current_ticks = ax.get_xticks()
        ax.xaxis.set_major_locator(ticker.FixedLocator(current_ticks))
        ax.set_xticklabels(tick_labels, rotation=45, ha="right")

        if is_datetime:
            # Tick positions are Matplotlib date numbers; show them as readable dates.
            ax.xaxis.set_major_formatter(_convert_numbers_to_dates)

    return ax
117

118

119
def get_correlated_features(
    correlation_matrix: pd.DataFrame, features: List[str], target_feature: str, threshold: float = 0.95
) -> pd.DataFrame:
    """Calculate features correlated above a threshold with target correlations.

    Every unordered pair of ``features`` is taken from the strict upper triangle of
    ``correlation_matrix.loc[features, features]``. A pair is reported when its absolute
    correlation is at least ``threshold``, together with each member's correlation to
    ``target_feature``. Pairs whose correlation is NaN are skipped. Rows are ordered by the
    pair's position in the matrix (row-major).

    :param correlation_matrix: The correlation matrix, labelled by feature on both axes.
    :param features: List of feature names to analyze.
    :param target_feature: Name of the target feature; must be a column of ``correlation_matrix``.
    :param threshold: Correlation threshold (default 0.95), compared against the absolute correlation.
    :return: DataFrame with columns ``level_0``, ``level_1``, ``level_0_level_1_corr``,
             ``level_0_target_corr`` and ``level_1_target_corr`` (one row per pair, index 0..n-1).
             If no pair reaches the threshold, a ``UserWarning`` is emitted and an empty frame
             with those columns is returned.
    """
    output_columns = ["level_0", "level_1", "level_0_level_1_corr", "level_0_target_corr", "level_1_target_corr"]

    target_corr = correlation_matrix[target_feature]
    features_corr = correlation_matrix.loc[features, features]

    # Strict upper triangle, row-major: each unordered pair once, diagonal excluded.
    # Selecting by position (instead of stack().reset_index()) keeps the output column names
    # independent of any names set on the matrix's index/columns.
    row_pos, col_pos = np.triu_indices(features_corr.shape[0], k=1)
    pair_corr = features_corr.to_numpy(dtype=float)[row_pos, col_pos]
    with np.errstate(invalid="ignore"):
        # NaN compares False, so NaN correlations never pass the threshold.
        keep = np.abs(pair_corr) >= threshold

    if not keep.any():
        warnings.warn(f"Correlation threshold {threshold} was too high. An empty frame was returned", UserWarning)
        return pd.DataFrame(columns=output_columns)

    corr_pairs = pd.DataFrame(
        {
            "level_0": features_corr.index.to_numpy()[row_pos[keep]],
            "level_1": features_corr.columns.to_numpy()[col_pos[keep]],
            "level_0_level_1_corr": pair_corr[keep],
        }
    )
    # Label-based lookup (.loc) so integer-like feature names are never treated as positions.
    corr_pairs["level_0_target_corr"] = target_corr.loc[corr_pairs["level_0"]].to_numpy()
    corr_pairs["level_1_target_corr"] = target_corr.loc[corr_pairs["level_1"]].to_numpy()
    return corr_pairs[output_columns]
149

150

151
def visualize_correlations(correlation_matrix: pd.DataFrame, *, ax: Optional[axes.Axes] = None, **kwargs) -> axes.Axes:
    """Visualize a precomputed correlation matrix as an annotated lower-triangle heatmap.

    `Original Seaborn code <https://seaborn.pydata.org/examples/many_pairwise_correlations.html>`_.

    The upper triangle, including the diagonal, is masked because it mirrors the lower
    triangle. By default, cells are annotated with their value at three decimals.

    :param correlation_matrix: The correlation matrix to draw.
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Additional keyword arguments passed to seaborn's heatmap function. They take
                   precedence over the defaults ``mask`` (upper triangle), ``annot=True`` and
                   ``fmt=".3f"``.
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        _, ax = plt.subplots()

    # Defaults first, caller kwargs last: overriding e.g. annot no longer raises
    # "got multiple values for keyword argument".
    heatmap_kwargs = {
        "mask": np.triu(np.ones_like(correlation_matrix, dtype=bool)),
        "annot": True,
        "fmt": ".3f",
        **kwargs,
    }
    sns.heatmap(correlation_matrix, ax=ax, **heatmap_kwargs)
    return ax
167

168

169
def plot_correlation_dendrogram(
    correlation_matrix: pd.DataFrame,
    cluster_distance_method: Union[str, Callable] = "average",
    *,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Plot a dendrogram that groups the variables of a correlation matrix hierarchically.

    `Original XAI code <https://github.com/EthicalML/XAI>`_.

    The distance between two variables is ``1 - correlation``. The matrix is converted to a
    condensed distance vector with ``scipy.spatial.distance.squareform``, clustered with
    ``scipy.cluster.hierarchy.linkage`` and drawn with ``orientation="left"``.

    :param correlation_matrix: The correlation matrix. Its columns label the dendrogram leaves.
    :param cluster_distance_method: Method for calculating the distance between newly formed clusters.
                                    `Read more here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html>`_
    :param ax: Axes in which to draw the plot. If None, a new figure and Axes are created.
    :param kwargs: Axes properties applied through ``Axes.set`` (for example ``title``) before drawing.
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    distance_matrix = 1 - correlation_matrix
    linkage_matrix = linkage(squareform(distance_matrix), method=cluster_distance_method)
    leaf_labels = correlation_matrix.columns.tolist()

    ax.set(**kwargs)
    dendrogram(linkage_matrix, labels=leaf_labels, orientation="left", ax=ax)
    return ax
195

196

197
def plot_features_interaction(
    data: pd.DataFrame,
    feature_1: str,
    feature_2: str,
    *,
    include_outliers: bool = True,
    outlier_iqr_multiplier: float = 1.5,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Plot the joint distribution of two features, choosing the chart from their dtypes.

    The dtype of ``feature_1`` is inspected first, then that of ``feature_2``:

    - ``feature_1`` categorical-like: overlaid histograms if ``feature_2`` is also
      categorical-like, a violin plot over time if it is datetime, otherwise a violin plot of
      ``feature_2`` per category.
    - ``feature_1`` datetime: a violin plot over time if ``feature_2`` is categorical-like,
      otherwise a line plot of ``feature_2`` over time.
    - ``feature_2`` categorical-like: a violin plot of ``feature_1`` per category.
    - ``feature_2`` datetime: a line plot of ``feature_1`` over time.
    - Otherwise (both numeric): a scatter plot.

    For the categorical-vs-numeric violin plots, outliers of the numeric feature can be trimmed
    with the IQR fence [Q1 - k*IQR, Q3 + k*IQR], where ``k`` is ``outlier_iqr_multiplier``.

    :param data: The input DataFrame where each feature is a column.
    :param feature_1: Name of the first feature.
    :param feature_2: Name of the second feature.
    :param include_outliers: Whether to include values outside the IQR fence for
                             categorical-vs-numeric violin plots (default True).
    :param outlier_iqr_multiplier: Multiplier ``k`` for the IQR fence when trimming
                                   outliers in categorical-vs-numeric plots (default 1.5).
    :param ax: Axes in which to draw the plot. If None, a new one is created.
    :param kwargs: Additional keyword arguments forwarded to the underlying plotting
                   functions (e.g., ``seaborn.violinplot``, ``Axes.scatter``, ``Axes.plot``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        ax = plt.subplots()[1]

    first_dtype = data[feature_1].dtype
    second_dtype = data[feature_2].dtype
    is_datetime = pd.api.types.is_datetime64_any_dtype

    if _is_categorical_like(first_dtype):
        _plot_categorical_feature1(
            feature_1, feature_2, data, second_dtype, include_outliers, outlier_iqr_multiplier, ax, **kwargs
        )
    elif is_datetime(first_dtype):
        _plot_datetime_feature1(feature_1, feature_2, data, second_dtype, ax, **kwargs)
    elif _is_categorical_like(second_dtype):
        # The helper expects the categorical column first, so the roles are swapped here.
        _plot_categorical_vs_numeric(
            feature_2, feature_1, data, outlier_iqr_multiplier, include_outliers, ax, **kwargs
        )
    elif is_datetime(second_dtype):
        # Time always goes on the x-axis.
        _plot_xy(feature_2, feature_1, data, ax, **kwargs)
    else:
        _plot_numeric_features(feature_1, feature_2, data, ax, **kwargs)

    return ax
258

259

260
def _is_categorical_like(dtype):
12✔
261
    """Check if the dtype is categorical-like (categorical, boolean, or object)."""
262
    return (
12✔
263
        isinstance(dtype, pd.CategoricalDtype)
264
        or pd.api.types.is_bool_dtype(dtype)
265
        or pd.api.types.is_object_dtype(dtype)
266
    )
267

268

269
def _plot_categorical_feature1(
    categorical_feature,
    feature_2,
    data,
    dtype2,
    include_outliers,
    outlier_iqr_multiplier,
    ax,
    **kwargs,
):
    """Dispatch the plot for a categorical-like first feature, based on ``dtype2``.

    - categorical-like second feature: ``_plot_categorical_vs_categorical``
    - datetime second feature: ``_plot_categorical_vs_datetime``
    - anything else (treated as numeric): ``_plot_categorical_vs_numeric``
    """
    if _is_categorical_like(dtype2):
        _plot_categorical_vs_categorical(categorical_feature, feature_2, data, ax, **kwargs)
        return

    if pd.api.types.is_datetime64_any_dtype(dtype2):
        _plot_categorical_vs_datetime(categorical_feature, feature_2, data, ax, **kwargs)
        return

    # Keyword arguments guard against mixing up the two outlier-related parameters.
    _plot_categorical_vs_numeric(
        categorical_feature=categorical_feature,
        numeric_feature=feature_2,
        data=data,
        outlier_iqr_multiplier=outlier_iqr_multiplier,
        include_outliers=include_outliers,
        ax=ax,
        **kwargs,
    )
294

295

296
def _plot_xy(datetime_feature, other_feature, data, ax, **kwargs):
    """Draw ``other_feature`` against ``datetime_feature`` with ``Axes.plot`` and label both axes.

    :param datetime_feature: Column placed on the x-axis (callers pass the datetime column).
    :param other_feature: Column placed on the y-axis.
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Extra keyword arguments forwarded to ``Axes.plot``.
    """
    x_values = data[datetime_feature]
    y_values = data[other_feature]
    ax.plot(x_values, y_values, **kwargs)
    ax.set_xlabel(datetime_feature)
    ax.set_ylabel(other_feature)
300

301

302
def _plot_datetime_feature1(datetime_feature, feature_2, data, dtype2, ax, **kwargs):
    """Dispatch the plot for a datetime first feature, based on ``dtype2``.

    A categorical-like second feature gets a violin plot over time. Anything else is drawn
    as a line against time.
    """
    if not _is_categorical_like(dtype2):
        _plot_xy(datetime_feature, feature_2, data, ax, **kwargs)
        return

    # The categorical/datetime helper takes the categorical column first.
    _plot_categorical_vs_datetime(feature_2, datetime_feature, data, ax, **kwargs)
308

309

310
def _plot_numeric_features(feature_1, feature_2, data, ax, **kwargs):
    """Scatter ``feature_2`` (y) against ``feature_1`` (x) and label the axes with the column names.

    :param kwargs: Extra keyword arguments forwarded to ``Axes.scatter``.
    """
    x_values, y_values = data[feature_1], data[feature_2]
    ax.scatter(x_values, y_values, **kwargs)
    ax.set_xlabel(feature_1)
    ax.set_ylabel(feature_2)
315

316

317
def _plot_categorical_vs_categorical(feature_1, feature_2, data, ax, **kwargs):
    """Plot when both features are categorical-like.

    Both columns are first passed through ``_copy_series_or_keep_top_10``, which turns booleans
    into strings and limits high-cardinality columns to their top 10 values. Then, for every
    distinct ``feature_1`` value, a histogram of the matching ``feature_2`` values is drawn,
    with all histograms overlaid on ``ax``. The x-axis therefore shows ``feature_2`` values and
    the legend lists ``feature_1`` values.

    :param feature_1: Column whose values define the histogram groups (legend entries).
    :param feature_2: Column whose values are binned on the x-axis.
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Extra keyword arguments forwarded to ``Axes.hist``.
    """
    dup_df = pd.DataFrame()
    dup_df[feature_1] = _copy_series_or_keep_top_10(data[feature_1])
    dup_df[feature_2] = _copy_series_or_keep_top_10(data[feature_2])
    group_feature_1 = dup_df[feature_1].unique().tolist()
    ax.hist(
        [dup_df.loc[dup_df[feature_1] == value, feature_2] for value in group_feature_1],
        label=group_feature_1,
        **kwargs,
    )
    # The bins hold feature_2 values and each legend entry is a feature_1 value;
    # the previous code had these two labels swapped.
    ax.set_xlabel(feature_2)
    ax.legend(title=feature_1)
330

331

332
def _plot_categorical_vs_datetime(categorical_feature, datetime_feature, data, ax, **kwargs):
    """Plot a categorical-like feature against a datetime feature as a violin plot.

    Time runs along the x-axis and the categories along the y-axis. The categorical feature
    name comes first and the datetime feature name second, regardless of the order the user
    gave. Timestamps are converted to Matplotlib date numbers for plotting. The categories are
    reduced with ``_copy_series_or_keep_top_10``. The x tick labels are rotated and formatted
    back into dates.

    :param categorical_feature: Name of the categorical-like column (y-axis).
    :param datetime_feature: Name of the datetime column (x-axis).
    :param data: DataFrame holding both columns.
    :param ax: Axes to draw on.
    :param kwargs: Extra keyword arguments forwarded to ``seaborn.violinplot``.
    """
    plot_frame = pd.DataFrame()
    plot_frame[datetime_feature] = data[datetime_feature].apply(dates.date2num)
    plot_frame[categorical_feature] = _copy_series_or_keep_top_10(data[categorical_feature])

    drawn_ax = sns.violinplot(x=datetime_feature, y=categorical_feature, data=plot_frame, ax=ax, **kwargs)

    # Pin the current tick positions before relabeling them.
    fixed_positions = drawn_ax.get_xticks()
    drawn_ax.xaxis.set_major_locator(ticker.FixedLocator(fixed_positions))
    drawn_ax.set_xticklabels(drawn_ax.get_xticklabels(), rotation=45, ha="right")

    # Render the numeric date positions as readable timestamps.
    ax.xaxis.set_major_formatter(_convert_numbers_to_dates)
347

348

349
def _plot_categorical_vs_numeric(
    categorical_feature,
    numeric_feature,
    data,
    outlier_iqr_multiplier,
    include_outliers,
    ax,
    **kwargs,
):
    """Draw a violin plot of a numeric feature for each category of a categorical-like feature.

    Callers always pass the categorical-like column first, whichever of the user's two
    features it was. The categorical column is reduced with ``_copy_series_or_keep_top_10``.
    When ``include_outliers`` is False, rows whose numeric value falls outside the IQR fence
    [Q1 - k*IQR, Q3 + k*IQR] are dropped, where ``k`` is ``outlier_iqr_multiplier``. The fence
    is computed over the whole numeric column, not per category.

    :param categorical_feature: Name of the categorical-like column (x-axis and hue).
    :param numeric_feature: Name of the numeric column (y-axis).
    :param data: DataFrame holding both columns.
    :param outlier_iqr_multiplier: IQR multiplier ``k`` for the trimming fence.
    :param include_outliers: If False, trim values outside the fence before plotting.
    :param ax: Axes to draw on.
    :param kwargs: Extra keyword arguments forwarded to ``seaborn.violinplot``.
    :return: The Axes that was drawn on.
    """
    dup_df = pd.DataFrame()
    dup_df[categorical_feature] = _copy_series_or_keep_top_10(data[categorical_feature])
    dup_df[numeric_feature] = data[numeric_feature]

    if include_outliers:
        df_plot = dup_df.copy()
    else:
        numeric_values = dup_df[numeric_feature]
        q1 = numeric_values.quantile(0.25)
        q3 = numeric_values.quantile(0.75)
        iqr = q3 - q1
        lower_bound = max(numeric_values.min(), q1 - outlier_iqr_multiplier * iqr)
        upper_bound = min(numeric_values.max(), q3 + outlier_iqr_multiplier * iqr)
        df_plot = dup_df[(numeric_values >= lower_bound) & (numeric_values <= upper_bound)].copy()

    sns.violinplot(x=categorical_feature, y=numeric_feature, hue=categorical_feature, data=df_plot, ax=ax, **kwargs)

    # Column labels are not necessarily strings (e.g. integer column names), so coerce before prettifying.
    ax.set_xlabel(str(categorical_feature).replace("_", " ").title())
    ax.set_ylabel(str(numeric_feature).replace("_", " ").title())
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    return ax
386

387

388
def _copy_series_or_keep_top_10(series: pd.Series) -> pd.Series:
12✔
389
    if series.dtype == bool:
12✔
390
        return series.map({True: "True", False: "False"})
12✔
391
    if len(series.unique()) > 10:
12✔
392
        top10 = series.value_counts().nlargest(10).index
12✔
393
        return series.map(lambda x: x if x in top10 else "Other values")
12✔
394
    return series
12✔
395

396

397
@plt.FuncFormatter
def _convert_numbers_to_dates(x, pos):
    """Render a Matplotlib date number as a ``YYYY-MM-DD HH:MM`` tick label.

    The decorator wraps this function in a ``FuncFormatter``, so the module-level name refers
    to the formatter object that is passed to ``Axis.set_major_formatter``. ``pos`` (the tick
    index) is unused.
    """
    as_datetime = dates.num2date(x)
    return as_datetime.strftime("%Y-%m-%d %H:%M")
400

401

402
def extract_statistics_dataframe_per_label(df: pd.DataFrame, feature_name: str, label_name: str) -> pd.DataFrame:
    """Calculate statistical metrics for a numeric feature, grouped by label value.

    ``df`` is grouped by ``label_name`` with ``observed=True`` (only label categories that occur
    are reported), and ``feature_name`` is summarized in every group. Percentiles are computed
    with ``safe_percentile``. No standard deviation is computed.

    :param df: Input pandas DataFrame containing the data
    :param feature_name: Name of the numeric column to calculate statistics on
    :param label_name: Name of the column to group by
    :return: DataFrame with one row per label value and these columns, in order:
            - count: Number of non-null observations
            - null_count: Number of null values
            - mean: Average value
            - min: Minimum value
            - 1_percentile: 1st percentile
            - 5_percentile: 5th percentile
            - 25_percentile: 25th percentile
            - median: 50th percentile
            - 75_percentile: 75th percentile
            - 95_percentile: 95th percentile
            - 99_percentile: 99th percentile
            - max: Maximum value

    :raises KeyError: If feature_name or label_name is not found in DataFrame
    :raises TypeError: If feature_name column is not numeric
    """
    if feature_name not in df.columns:
        raise KeyError(f"Feature column '{feature_name}' not found in DataFrame")
    if label_name not in df.columns:
        raise KeyError(f"Label column '{label_name}' not found in DataFrame")
    if not pd.api.types.is_numeric_dtype(df[feature_name]):
        raise TypeError(f"Feature column '{feature_name}' must be numeric")

    def _percentile_aggregator(q):
        """Build a one-argument aggregation returning the ``q``-th percentile of a group."""

        def aggregate(values):
            return safe_percentile(values, q)

        return aggregate

    # (output column name, aggregation) pairs; the order here is the column order of the result.
    aggregations = [
        ("count", "count"),
        ("null_count", lambda values: values.isnull().sum()),
        ("mean", "mean"),
        ("min", "min"),
        ("1_percentile", _percentile_aggregator(1)),
        ("5_percentile", _percentile_aggregator(5)),
        ("25_percentile", _percentile_aggregator(25)),
        ("median", "median"),
        ("75_percentile", _percentile_aggregator(75)),
        ("95_percentile", _percentile_aggregator(95)),
        ("99_percentile", _percentile_aggregator(99)),
        ("max", "max"),
    ]
    grouped_feature = df.groupby([label_name], observed=True)[feature_name]
    return grouped_feature.agg(aggregations)
472

473

474
def compute_mutual_information(
    df: pd.DataFrame,
    features: List[str],
    label_col: str,
    *,
    n_neighbors: int = 3,
    random_state: Optional[Union[int, RandomState]] = None,
    n_jobs: Optional[int] = None,
    numerical_imputer: TransformerMixin = SimpleImputer(strategy="mean"),
    discrete_imputer: TransformerMixin = SimpleImputer(strategy="most_frequent"),
    discrete_encoder: TransformerMixin = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
) -> pd.DataFrame:
    """Compute mutual information scores between features and a target label.

    This function calculates mutual information scores for the specified features with respect
    to a target label column. Features are categorized by dtype as numerical (``np.number``),
    boolean, or categorical (``object``/``category``) and preprocessed accordingly before mutual
    information is computed. Boolean and categorical features are treated as discrete. Features
    of any other dtype (for example datetimes) cannot be scored. They are left out of the result,
    and a ``UserWarning`` names them.

    Mutual information measures the mutual dependence between two variables - higher scores indicate
    stronger relationships between the feature and the target label.

    :param df: Input pandas DataFrame containing the features and label
    :param features: List of column names to compute mutual information for
    :param label_col: Name of the target label column
    :param n_neighbors: Number of neighbors to use for MI estimation for continuous variables. Higher values
                        reduce variance of the estimation, but could introduce a bias.
    :param random_state: Random state for reproducible results. Can be int or RandomState instance
    :param n_jobs: The number of jobs to use for computing the mutual information. The parallelization is done
                   on the columns. `None` means 1 unless in a `joblib.parallel_backend` context. ``-1`` means
                   using all processors.
    :param numerical_imputer: Sklearn-compatible transformer for numerical features (default: mean imputation)
    :param discrete_imputer: Sklearn-compatible transformer for discrete features (default: most frequent imputation)
    :param discrete_encoder: Sklearn-compatible transformer for encoding discrete features (default: ordinal encoding
                            with unknown value handling)
    :return: DataFrame with columns 'feature_name' and 'mi_score', sorted by MI score (descending)

    :raises KeyError: If any feature or label_col is not found in DataFrame
    :raises ValueError: If the features list is empty, label_col contains only null values, or none of
                        the features has a supported dtype
    """
    # Input validation
    if not features:
        raise ValueError("features list cannot be empty")

    if label_col not in df.columns:
        raise KeyError(f"Label column '{label_col}' not found in DataFrame")

    missing_features = [f for f in features if f not in df.columns]
    if missing_features:
        raise KeyError(f"Features not found in DataFrame: {missing_features}")

    if df[label_col].isnull().all():
        raise ValueError(f"Label column '{label_col}' contains only null values")

    # Identify feature types
    numerical_features = df[features].select_dtypes(include=[np.number]).columns.tolist()
    boolean_features = df[features].select_dtypes(include=[bool]).columns.tolist()
    categorical_features = df[features].select_dtypes(include=["object", "category"]).columns.tolist()

    # Anything not selected above would be dropped by the ColumnTransformer (remainder="drop")
    # and would silently vanish from the result, so report it explicitly.
    scorable_features = set(numerical_features + boolean_features + categorical_features)
    unsupported_features = [f for f in features if f not in scorable_features]
    if not scorable_features:
        raise ValueError(
            "None of the features has a supported dtype (numeric, boolean, object or category): "
            f"{unsupported_features}"
        )
    if unsupported_features:
        warnings.warn(
            f"Features with unsupported dtypes were excluded from the mutual information result: "
            f"{unsupported_features}",
            UserWarning,
        )

    # Create preprocessing pipelines.
    # NOTE(review): the default transformer instances are shared across calls; this relies on
    # ColumnTransformer cloning them before fitting — confirm against the sklearn version in use.
    numerical_transformer = Pipeline(steps=[("imputer", numerical_imputer)], memory=None, verbose=False)
    discrete_transformer = Pipeline(
        steps=[("imputer", discrete_imputer), ("encoder", discrete_encoder)], memory=None, verbose=False
    )

    # Setup column transformer
    transformers = []
    if numerical_features:
        transformers.append(("num", numerical_transformer, numerical_features))
    if boolean_features or categorical_features:
        transformers.append(("discrete", discrete_transformer, boolean_features + categorical_features))

    preprocessor = ColumnTransformer(
        transformers=transformers,
        remainder="drop",  # Drop any features not explicitly handled
        sparse_threshold=0,
        n_jobs=n_jobs,
        transformer_weights=None,
        verbose=False,
        verbose_feature_names_out=True,
    )

    # Discrete mask for mutual_info_classif; its order matches the transformer order above
    # (numerical block first, then boolean + categorical).
    discrete_features_mask = [False] * len(numerical_features) + [True] * (
        len(boolean_features) + len(categorical_features)
    )

    # Create ordered feature names list matching the preprocessed data
    ordered_feature_names = numerical_features + boolean_features + categorical_features

    # Apply preprocessing
    x_preprocessed = preprocessor.fit_transform(df[ordered_feature_names])
    y = df[label_col]

    # Compute mutual information scores
    mi_scores = mutual_info_classif(
        X=x_preprocessed,
        y=y,
        n_neighbors=n_neighbors,
        copy=True,
        random_state=random_state,
        n_jobs=n_jobs,
        discrete_features=discrete_features_mask,
    )

    # Create results DataFrame
    mi_df = pd.DataFrame({"feature_name": ordered_feature_names, "mi_score": mi_scores})

    return mi_df.sort_values(by="mi_score", ascending=False).reset_index(drop=True)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc