idanmoradarthas / DataScienceUtils / build 21105664715
18 Jan 2026 03:58AM UTC. Coverage: 86.972% (-9.8%) from 96.765%.
push · github · idanmoradarthas · linting

2 of 2 new or added lines in 1 file covered (100.0%).
102 existing lines in 2 files are now uncovered.
721 of 829 relevant lines covered (86.97%).
10.44 hits per line.

Source File: /ds_utils/preprocess.py (75.23% covered)

"""Data preprocessing utilities."""

from typing import Callable, List, Optional, Union
import warnings

from matplotlib import axes, dates, pyplot as plt, ticker
import numpy as np
from numpy.random import RandomState
import pandas as pd
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
import seaborn as sns
from sklearn.base import TransformerMixin
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import mutual_info_classif
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder

from ds_utils.math_utils import safe_percentile


def _plot_clean_violin_distribution(
    series: pd.Series, include_outliers: bool, outlier_iqr_multiplier: float, ax: Optional[axes.Axes] = None, **kwargs
) -> axes.Axes:
    """Plot a violin distribution for a numeric series with optional outlier trimming.

    When ``include_outliers`` is False, values outside the IQR fence are removed
    before plotting. The fence is defined as
    [Q1 - k * IQR, Q3 + k * IQR], where ``k`` is ``outlier_iqr_multiplier``, and
    the bounds are clipped to the observed min/max of the series.

    :param series: Numeric series to visualize. NA handling is expected upstream.
    :param include_outliers: Whether to include values outside the IQR fence.
    :param outlier_iqr_multiplier: Multiplier ``k`` used to compute the IQR fence.
    :param ax: Matplotlib Axes to draw on. If None, callers should provide one upstream.
    :param kwargs: Additional keyword arguments passed to ``seaborn.violinplot``.
    :return: The Axes object with the violin plot.
    """
    if include_outliers:
        series_plot = series.copy()
    else:
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        min_series_value = series.min()
        max_series_value = series.max()
        iqr = q3 - q1
        lower_bound = max(min_series_value, q1 - outlier_iqr_multiplier * iqr)
        upper_bound = min(max_series_value, q3 + outlier_iqr_multiplier * iqr)
        series_plot = series[(series >= lower_bound) & (series <= upper_bound)].copy()

    sns.violinplot(y=series_plot, hue=None, legend=False, ax=ax, **kwargs)

    ax.set_xticks([])
    ax.set_ylabel("Values")
    ax.grid(axis="y", linestyle="--", alpha=0.7)

    return ax


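# Illustrative sketch (not part of the original module): demonstrates the IQR
# fence described in _plot_clean_violin_distribution. The helper name
# _demo_iqr_fence and the sample values are hypothetical.
def _demo_iqr_fence(k: float = 1.5) -> pd.Series:
    s = pd.Series([1.0, 2.0, 2.5, 3.0, 3.5, 100.0])
    q1, q3 = s.quantile(0.25), s.quantile(0.75)
    iqr = q3 - q1
    # Fence bounds are clipped to the observed min/max, as in the function above.
    lower = max(s.min(), q1 - k * iqr)
    upper = min(s.max(), q3 + k * iqr)
    return s[(s >= lower) & (s <= upper)]  # 100.0 falls outside the fence
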

def _plot_datetime_heatmap(feature_series: pd.Series, first_day_of_week: str, ax: axes.Axes, **kwargs) -> axes.Axes:
    """Plot a 2D heatmap for datetime features showing day-of-week vs year-week patterns.

    :param feature_series: The datetime series to visualize.
    :param first_day_of_week: First day of the week for the heatmap X-axis.
    :param ax: Matplotlib Axes to draw on.
    :param kwargs: Additional keyword arguments passed to seaborn's heatmap function.
    :return: The Axes object with the heatmap.
    """
    # Validate first_day_of_week parameter
    valid_days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
    if first_day_of_week not in valid_days:
        raise ValueError(f"first_day_of_week must be one of {valid_days}, got '{first_day_of_week}'")

    # Create day of week order starting with first_day_of_week
    day_index = valid_days.index(first_day_of_week)
    day_order = valid_days[day_index:] + valid_days[:day_index]

    # Create DataFrame with date, day of week, year, and week number
    df = (
        feature_series.to_frame("date")
        .assign(
            day_of_week=lambda x: x["date"].dt.day_name(),
            year=lambda x: x["date"].dt.year,
            week_number=lambda x: x["date"].dt.isocalendar().week,
        )
        .assign(year_week=lambda x: x["year"].astype(str) + "-W" + x["week_number"].astype(str).str.zfill(2))
        .groupby(["year_week", "day_of_week"])
        .size()
        .unstack(fill_value=0)
    )

    # Ensure all days of the week are present as columns, reordered according to day_order
    for day in day_order:
        if day not in df.columns:
            df[day] = 0

    # Reorder columns to match day_order (columns = day of week, rows = year-week)
    df = df.reindex(columns=day_order)

    # Create heatmap with annotations to show numbers in cells
    sns.heatmap(df, cmap="Blues", ax=ax, annot=True, fmt="d", **kwargs)
    ax.set_xlabel("Day of Week")
    ax.set_ylabel("Year-Week")

    return ax


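# Illustrative sketch (not part of the original module): builds the same
# year-week by day-of-week count table the heatmap above is drawn from. The
# helper name and the sample dates are hypothetical.
def _demo_year_week_table() -> pd.DataFrame:
    dates_series = pd.Series(pd.to_datetime(["2026-01-05", "2026-01-06", "2026-01-12"]))
    return (
        dates_series.to_frame("date")
        .assign(
            day_of_week=lambda x: x["date"].dt.day_name(),
            year=lambda x: x["date"].dt.year,
            week_number=lambda x: x["date"].dt.isocalendar().week,
        )
        .assign(year_week=lambda x: x["year"].astype(str) + "-W" + x["week_number"].astype(str).str.zfill(2))
        .groupby(["year_week", "day_of_week"])
        .size()
        .unstack(fill_value=0)  # rows: year-week, columns: day of week
    )
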

def _copy_series_or_keep_top_10(series: pd.Series) -> pd.Series:
    """Return a plot-friendly copy of ``series``.

    Boolean series are mapped to the strings "True"/"False"; series with more
    than 10 unique values keep their 10 most frequent values and collapse the
    rest into an "Other values" bucket; everything else is returned unchanged.
    """
    if pd.api.types.is_bool_dtype(series):
        return series.map({True: "True", False: "False"})
    if len(series.unique()) > 10:
        top10 = series.value_counts().nlargest(10).index
        return series.map(lambda x: x if x in top10 else "Other values")
    return series


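# Illustrative sketch (not part of the original module): shows the top-10
# bucketing performed above; the helper name and sample data are hypothetical.
def _demo_top_10_bucketing() -> pd.Series:
    s = pd.Series([str(i) for i in range(12)] + ["0", "0"])  # 12 unique labels
    # Values outside the 10 most frequent collapse into the "Other values" bucket.
    return _copy_series_or_keep_top_10(s)
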

def _plot_count_bar(
    value_counts: pd.Series, order: Optional[Union[List[str], str]], show_counts: bool, ax: axes.Axes, **kwargs
) -> axes.Axes:
    """Plot a bar chart for categorical data with optional ordering and count labels.

    :param value_counts: Series containing value counts to plot
    :param order: Order specification for categories (None, string, or list)
    :param show_counts: Whether to display count values on top of bars
    :param ax: Axes to draw on
    :param kwargs: Additional arguments passed to ax.bar
    :return: The Axes object with the bar plot
    """
    # Apply ordering based on the order parameter
    if order is None:
        value_counts = value_counts.sort_index()
    elif isinstance(order, str):
        if order == "count_desc":
            value_counts = value_counts.sort_values(ascending=False)
        elif order == "count_asc":
            value_counts = value_counts.sort_values(ascending=True)
        elif order == "alpha_asc":
            value_counts = value_counts.sort_index(ascending=True)
        elif order == "alpha_desc":
            value_counts = value_counts.sort_index(ascending=False)
        else:
            raise ValueError(
                f"Invalid order string: '{order}'. Must be one of: 'count_desc', 'count_asc', 'alpha_asc', 'alpha_desc'"
            )
    elif isinstance(order, list):
        # Filter to only include categories present in the data
        valid_order = [cat for cat in order if cat in value_counts.index]
        # Add any missing categories from value_counts that weren't in order
        missing_cats = [cat for cat in value_counts.index if cat not in valid_order]
        full_order = valid_order + missing_cats
        value_counts = value_counts.reindex(full_order)

    # Create bar plot using matplotlib
    bars = ax.bar(range(len(value_counts)), value_counts.values, **kwargs)
    ax.set_xticks(range(len(value_counts)))
    ax.set_xticklabels(value_counts.index)
    ax.set_ylabel("Count")

    # Add count labels if requested
    if show_counts:
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2.0,
                height,
                f"{int(height):,}",
                ha="center",
                va="bottom",
                fontweight="bold",
            )

    return ax


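# Illustrative sketch (not part of the original module): the four string modes
# accepted by _plot_count_bar's ``order`` parameter, shown on a plain Series of
# counts; the helper name and sample data are hypothetical.
def _demo_count_bar_orders() -> dict:
    counts = pd.Series({"b": 3, "a": 1, "c": 2})
    return {
        "count_desc": counts.sort_values(ascending=False).index.tolist(),  # ['b', 'c', 'a']
        "count_asc": counts.sort_values(ascending=True).index.tolist(),  # ['a', 'c', 'b']
        "alpha_asc": counts.sort_index(ascending=True).index.tolist(),  # ['a', 'b', 'c']
        "alpha_desc": counts.sort_index(ascending=False).index.tolist(),  # ['c', 'b', 'a']
    }
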

def visualize_feature(
    series: pd.Series,
    remove_na: bool = False,
    *,
    include_outliers: bool = True,
    outlier_iqr_multiplier: float = 1.5,
    first_day_of_week: str = "Monday",
    show_counts: bool = True,
    order: Optional[Union[List[str], str]] = None,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Visualize a pandas Series using an appropriate plot based on dtype.

    Behavior by dtype:
    - Float: draw a violin distribution. If ``include_outliers`` is False, values
      outside the IQR fence [Q1 - k*IQR, Q3 + k*IQR] with ``k=outlier_iqr_multiplier``
      are trimmed prior to plotting.
    - Datetime: draw a 2D heatmap showing day-of-week vs year-week patterns. The heatmap
      displays counts of records for each day of the week (X-axis) and year-week combination
      (Y-axis), making weekly and yearly patterns immediately visible.
    - Object/categorical/bool/int: draw a count plot. Extremely high-cardinality
      series may be reduced to their top categories internally.

    :param series: The data series to visualize.
    :param remove_na: If True, plot with NA values removed; otherwise include them.
    :param include_outliers: Whether to include outliers for float features.
    :param outlier_iqr_multiplier: IQR multiplier used to trim outliers for float features.
    :param first_day_of_week: First day of the week for the heatmap X-axis. Must be one of
                              "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday".
                              Default is "Monday".
    :param show_counts: If True, display count values on top of bars in count plots. Default is True.
    :param order: Order to plot categorical levels in count plots. Can be:

                  - None: Use default sorting (index order after value_counts)
                  - "count_desc": Sort by count in descending order (most frequent first)
                  - "count_asc": Sort by count in ascending order (least frequent first)
                  - "alpha_asc": Sort alphabetically in ascending order
                  - "alpha_desc": Sort alphabetically in descending order
                  - List: Explicit list of category names in desired order

                  Only applies to categorical/object/bool/int features.
    :param ax: Axes in which to draw the plot. If None, a new one is created.
    :param kwargs: Extra keyword arguments forwarded to the underlying plotting function
                   (``seaborn.violinplot``, ``seaborn.heatmap``, or ``matplotlib.pyplot.bar``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        _, ax = plt.subplots()

    feature_series = series.dropna() if remove_na else series

    if pd.api.types.is_float_dtype(feature_series):
        ax = _plot_clean_violin_distribution(feature_series, include_outliers, outlier_iqr_multiplier, ax, **kwargs)
    elif pd.api.types.is_datetime64_any_dtype(feature_series):
        ax = _plot_datetime_heatmap(feature_series, first_day_of_week, ax, **kwargs)
        labels = ax.get_xticklabels()
    else:
        series_to_plot = _copy_series_or_keep_top_10(feature_series)
        value_counts = series_to_plot.value_counts(dropna=remove_na).sort_index()
        ax = _plot_count_bar(value_counts, order, show_counts, ax, **kwargs)
        labels = ax.get_xticklabels()

    if not ax.get_title():
        ax.set_title(f"{feature_series.name} ({feature_series.dtype})")
        # Only set empty xlabel for non-datetime plots
        if not pd.api.types.is_datetime64_any_dtype(feature_series):
            ax.set_xlabel("")

    # Skip tick relabeling for float (violin) plots where x-ticks are hidden
    # Also skip for datetime plots as they handle their own labels
    if not pd.api.types.is_float_dtype(feature_series) and not pd.api.types.is_datetime64_any_dtype(feature_series):
        ticks_loc = ax.get_xticks()
        ax.xaxis.set_major_locator(ticker.FixedLocator(ticks_loc))
        ax.set_xticklabels(labels, rotation=45, ha="right")

    return ax


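# Illustrative sketch (not part of the original module): typical dtype-driven
# dispatch of visualize_feature; the sample data is hypothetical.
def _demo_visualize_feature() -> None:
    values = pd.Series([1.2, 3.4, 2.2, 250.0], name="amount")
    # Float dtype: violin plot, with the far outlier trimmed by the IQR fence.
    visualize_feature(values, include_outliers=False)
    labels = pd.Series(["a", "b", "a", "c"], name="label")
    # Object dtype: count plot, most frequent category first.
    visualize_feature(labels, order="count_desc")
    plt.show()
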

def get_correlated_features(
    correlation_matrix: pd.DataFrame, features: List[str], target_feature: str, threshold: float = 0.95
) -> pd.DataFrame:
    """Find pairs of features correlated above a threshold.

    Extracts a DataFrame with each qualifying pair's correlation and each member's
    correlation to the target feature.

    :param correlation_matrix: The correlation matrix.
    :param features: List of feature names to analyze.
    :param target_feature: Name of the target feature.
    :param threshold: Correlation threshold (default 0.95).
    :return: DataFrame with correlations and correlation to the target feature.
    """
    target_corr = correlation_matrix[target_feature]
    features_corr = correlation_matrix.loc[features, features]
    corr_matrix = features_corr.where(np.triu(np.ones(features_corr.shape), k=1).astype(bool))
    corr_matrix = corr_matrix[~np.isnan(corr_matrix)].stack().reset_index()
    corr_matrix = corr_matrix[corr_matrix[0].abs() >= threshold]

    if corr_matrix.empty:
        warnings.warn(f"Correlation threshold {threshold} was too high. An empty frame was returned", UserWarning)
        return pd.DataFrame(
            columns=["level_0", "level_1", "level_0_level_1_corr", "level_0_target_corr", "level_1_target_corr"]
        )

    corr_matrix["level_0_target_corr"] = target_corr[corr_matrix["level_0"]].values
    corr_matrix["level_1_target_corr"] = target_corr[corr_matrix["level_1"]].values
    corr_matrix = corr_matrix.rename({0: "level_0_level_1_corr"}, axis=1).reset_index(drop=True)
    return corr_matrix


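# Illustrative sketch (not part of the original module): get_correlated_features
# on a tiny correlation matrix; the column names and values are hypothetical.
def _demo_correlated_features() -> pd.DataFrame:
    corr = pd.DataFrame(
        [[1.0, 0.97, 0.2], [0.97, 1.0, 0.1], [0.2, 0.1, 1.0]],
        index=["x1", "x2", "y"],
        columns=["x1", "x2", "y"],
    )
    # Only the (x1, x2) pair clears the 0.95 threshold; the result also carries
    # each feature's correlation with the target column "y".
    return get_correlated_features(corr, ["x1", "x2"], "y", threshold=0.95)
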

def visualize_correlations(correlation_matrix: pd.DataFrame, *, ax: Optional[axes.Axes] = None, **kwargs) -> axes.Axes:
    """Visualize pairwise correlations of columns from a precomputed correlation matrix.

    `Original Seaborn code <https://seaborn.pydata.org/examples/many_pairwise_correlations.html>`_.

    :param correlation_matrix: The correlation matrix.
    :param ax: Axes in which to draw the plot. If None, a new one is created.
    :param kwargs: Additional keyword arguments passed to seaborn's heatmap function.
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        _, ax = plt.subplots()

    mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
    sns.heatmap(correlation_matrix, mask=mask, annot=True, fmt=".3f", ax=ax, **kwargs)
    return ax


def plot_correlation_dendrogram(
    correlation_matrix: pd.DataFrame,
    cluster_distance_method: Union[str, Callable] = "average",
    *,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Plot a dendrogram of the correlation matrix, showing hierarchically the most correlated variables.

    `Original XAI code <https://github.com/EthicalML/XAI>`_.

    :param correlation_matrix: The correlation matrix.
    :param cluster_distance_method: Method for calculating the distance between newly formed clusters.
                                    `Read more here <https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html>`_
    :param ax: Axes in which to draw the plot. If None, a new one is created.
    :param kwargs: Additional keyword arguments forwarded to ``Axes.set``.
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        _, ax = plt.subplots()

    corr_condensed = squareform(1 - correlation_matrix)
    z = linkage(corr_condensed, method=cluster_distance_method)
    ax.set(**kwargs)
    dendrogram(z, labels=correlation_matrix.columns.tolist(), orientation="left", ax=ax)
    return ax


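# Illustrative sketch (not part of the original module): the correlation-to-
# distance conversion used above. 1 - corr turns perfectly correlated pairs into
# distance 0, and squareform collapses the symmetric, zero-diagonal matrix into
# the condensed vector that scipy's linkage expects. Names are hypothetical.
def _demo_corr_to_condensed() -> np.ndarray:
    corr = pd.DataFrame(
        [[1.0, 0.9, 0.1], [0.9, 1.0, 0.3], [0.1, 0.3, 1.0]],
        index=list("abc"),
        columns=list("abc"),
    )
    return squareform(1 - corr)  # array([0.1, 0.9, 0.7]) for pairs (a,b), (a,c), (b,c)
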

def plot_features_interaction(
    data: pd.DataFrame,
    feature_1: str,
    feature_2: str,
    *,
    remove_na: bool = False,
    include_outliers: bool = True,
    outlier_iqr_multiplier: float = 1.5,
    show_ratios: bool = False,
    ax: Optional[axes.Axes] = None,
    **kwargs,
) -> axes.Axes:
    """Plot the joint distribution between two features using type-aware defaults.

    Behavior by dtypes of ``feature_1`` and ``feature_2``:
    - If both are numeric: scatter plot.
    - If one is datetime and the other numeric: line/scatter over time.
    - If both are datetime: scatter plot with complete cases.
    - If both are categorical-like: heatmap of the category crosstab (counts or proportions).
    - If one is categorical-like and the other numeric: violin plot by category.

    For the categorical-vs-numeric case, you can optionally trim outliers from the
    numeric feature using an IQR fence [Q1 - k*IQR, Q3 + k*IQR], where ``k`` is
    controlled by ``outlier_iqr_multiplier``.

    When ``remove_na`` is False, missing values are visualized:
    - Numeric vs Numeric: marginal rug plots showing missing values
    - Numeric vs Datetime: missing numeric values shown as markers on x-axis,
      missing datetime values shown as rug plot on right margin
    - Datetime vs Datetime: complete cases shown as scatter plot, missing values
      shown as rug plots on margins (x-axis for missing feature_2, y-axis for missing feature_1)
    - Categorical vs Numeric: missing numeric values shown with rug plots per category
    - Categorical vs Categorical: missing values included as "Missing" category
    - Categorical/Boolean vs Datetime: missing categorical values added as "Missing" category,
      missing datetime values shown as a separate violin at the edge of the plot

    :param data: The input DataFrame where each feature is a column.
    :param feature_1: Name of the first feature.
    :param feature_2: Name of the second feature.
    :param remove_na: If False (default), keep all data and visualize missingness patterns.
                      If True, remove rows where either feature is NA before plotting.
    :param include_outliers: Whether to include values outside the IQR fence for
                             categorical-vs-numeric violin plots (default True).
    :param outlier_iqr_multiplier: Multiplier ``k`` for the IQR fence when trimming
                                   outliers in categorical-vs-numeric plots (default 1.5).
    :param show_ratios: If True, display ratios (proportions) instead of absolute counts
                        for categorical vs categorical plots. Only applies when both
                        features are categorical-like (default False).
    :param ax: Axes in which to draw the plot. If None, a new one is created.
    :param kwargs: Additional keyword arguments forwarded to the underlying plotting
                   functions (e.g., ``seaborn.violinplot``, ``Axes.scatter``, ``Axes.plot``).
    :return: The Axes object with the plot drawn onto it.
    """
    if ax is None:
        _, ax = plt.subplots()

    if remove_na:
        plot_data = data[[feature_1, feature_2]].dropna()
    else:
        plot_data = data[[feature_1, feature_2]].copy()

    dtype1 = data[feature_1].dtype
    dtype2 = data[feature_2].dtype

    if _is_categorical_like(dtype1):
        ax = _plot_categorical_feature1(
            feature_1,
            feature_2,
            plot_data,
            dtype2,
            include_outliers,
            outlier_iqr_multiplier,
            show_ratios,
            remove_na,
            ax,
            **kwargs,
        )
    elif pd.api.types.is_datetime64_any_dtype(dtype1):
        ax = _plot_datetime_feature1(feature_1, feature_2, plot_data, dtype2, remove_na, ax, **kwargs)
    elif _is_categorical_like(dtype2):
        ax = _plot_categorical_vs_numeric(
            feature_2, feature_1, plot_data, outlier_iqr_multiplier, include_outliers, remove_na, ax, **kwargs
        )
    elif pd.api.types.is_datetime64_any_dtype(dtype2):
        ax = _plot_datetime_vs_numeric(feature_2, feature_1, plot_data, remove_na, ax, **kwargs)
    else:
        ax = _plot_numeric_features(feature_1, feature_2, plot_data, remove_na, ax, **kwargs)

    return ax


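# Illustrative sketch (not part of the original module): dtype-driven dispatch
# of plot_features_interaction; the DataFrame contents are hypothetical.
def _demo_features_interaction() -> None:
    df = pd.DataFrame(
        {
            "price": [1.0, 2.5, 3.0, None],
            "size": [10.0, 20.0, 15.0, 30.0],
            "city": ["a", "b", "a", "b"],
        }
    )
    # Numeric vs numeric: scatter plot, with the missing "price" value marked on a rug.
    plot_features_interaction(df, "price", "size")
    # Categorical vs numeric: one violin per city.
    plot_features_interaction(df, "city", "price")
    plt.show()
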

def _is_categorical_like(dtype):
    """Check if the dtype is categorical-like (categorical, boolean, or object)."""
    return (
        isinstance(dtype, pd.CategoricalDtype)
        or pd.api.types.is_bool_dtype(dtype)
        or pd.api.types.is_object_dtype(dtype)
    )


def _plot_categorical_feature1(
    categorical_feature,
    feature_2,
    data,
    dtype2,
    include_outliers,
    outlier_iqr_multiplier,
    show_ratios,
    remove_na,
    ax,
    **kwargs,
):
    """Plot when the first feature is categorical-like."""
    if _is_categorical_like(dtype2):
        ax = _plot_categorical_vs_categorical(
            categorical_feature, feature_2, data, show_ratios, remove_na, ax, **kwargs
        )
    elif pd.api.types.is_datetime64_any_dtype(dtype2):
        ax = _plot_categorical_vs_datetime(categorical_feature, feature_2, data, remove_na, ax, **kwargs)
    else:
        ax = _plot_categorical_vs_numeric(
            categorical_feature,
            feature_2,
            data,
            outlier_iqr_multiplier,
            include_outliers,
            remove_na,
            ax,
            **kwargs,
        )
    return ax


def _plot_datetime_vs_numeric(datetime_feature, other_feature, data, remove_na, ax, **kwargs):
    """Plot datetime vs numeric feature.

    When remove_na is False, missing values are handled as follows:
    - Missing numeric values: shown as markers on the x-axis (y=0 or bottom of plot)
    - Missing datetime values: shown as a rug plot on the right margin
    - Different colors/markers distinguish the two types of missingness
    """
    # Get complete cases for main plot
    complete_data = data.dropna(subset=[datetime_feature, other_feature])

    if len(complete_data) > 0:
        ax.plot(complete_data[datetime_feature], complete_data[other_feature], **kwargs)

    ax.set_xlabel(datetime_feature)
    ax.set_ylabel(other_feature)

    # Handle missing values if not removed
    # Skip missing value visualization if both features are the same column
    if not remove_na and datetime_feature != other_feature:
        has_plotted_missing = False

        # Cases where datetime is present but numeric is missing
        missing_numeric = data[data[other_feature].isna() & data[datetime_feature].notna()]
        if len(missing_numeric) > 0:
            # Filter out any rows where datetime_feature is also NaN (shouldn't happen due to filter, but be safe)
            missing_numeric_clean = missing_numeric[missing_numeric[datetime_feature].notna()]
            if len(missing_numeric_clean) > 0:
                y_min = ax.get_ylim()[0] if len(complete_data) > 0 else 0
                ax.scatter(
                    missing_numeric_clean[datetime_feature],
                    [y_min] * len(missing_numeric_clean),
                    marker="|",
                    s=100,
                    alpha=0.6,
                    color="red",
                    label=f"{other_feature} missing",
                    zorder=5,
                )
                has_plotted_missing = True

        # Cases where numeric is present but datetime is missing
        missing_datetime = data[data[datetime_feature].isna() & data[other_feature].notna()]
        if len(missing_datetime) > 0:
            # Get the x-axis range to place rug marks at the right edge
            if len(complete_data) > 0:
                x_min, x_max = ax.get_xlim()
                y_min, y_max = ax.get_ylim()
            else:
                # If no complete data, set reasonable defaults
                x_min = dates.date2num(pd.Timestamp.now() - pd.Timedelta(days=30))
                x_max = dates.date2num(pd.Timestamp.now())
                y_min = missing_datetime[other_feature].min()
                y_max = missing_datetime[other_feature].max()
                if y_min == y_max:
                    y_min -= 1
                    y_max += 1
                ax.set_xlim(x_min, x_max)
                ax.set_ylim(y_min, y_max)

            # Plot rug marks for missing datetime values at the right edge
            # Use a small offset from the right edge to make them visible
            x_range = x_max - x_min
            x_rug = x_max + x_range * 0.02  # 2% offset from right edge
            ax.scatter(
                [x_rug] * len(missing_datetime),
                missing_datetime[other_feature],
                marker="_",
                s=100,
                alpha=0.6,
                color="orange",
                label=f"{datetime_feature} missing",
                zorder=5,
            )

            # Extend xlim slightly to accommodate the rug plot
            ax.set_xlim(x_min, x_max + x_range * 0.05)
            has_plotted_missing = True

        # Add legend if we plotted any missing values
        if has_plotted_missing:
            ax.legend(loc="best", framealpha=0.9)

    return ax


def _plot_datetime_vs_datetime(datetime_feature_1, datetime_feature_2, data, remove_na, ax, **kwargs):
    """Plot when both features are datetime.

    When remove_na is False, missing values are handled as follows:
    - Complete cases: shown as line/scatter plot
    - Missing datetime_feature_2 values: shown as rug plot on x-axis (bottom margin)
    - Missing datetime_feature_1 values: shown as rug plot on y-axis (left margin)
    - Different colors/markers distinguish the two types of missingness
    """
    # Get complete cases for main plot
    complete_data = data.dropna(subset=[datetime_feature_1, datetime_feature_2])

    if len(complete_data) > 0:
        # Use scatter plot for datetime vs datetime (can also use line plot)
        ax.scatter(complete_data[datetime_feature_1], complete_data[datetime_feature_2], **kwargs)

    ax.set_xlabel(datetime_feature_1)
    ax.set_ylabel(datetime_feature_2)

    # Format both axes as datetime
    ax.xaxis.set_major_formatter(dates.DateFormatter("%Y-%m-%d"))
    ax.yaxis.set_major_formatter(dates.DateFormatter("%Y-%m-%d"))
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right")
    plt.setp(ax.yaxis.get_majorticklabels(), rotation=0)

    # Handle missing values if not removed
    if not remove_na and datetime_feature_1 != datetime_feature_2:
        has_plotted_missing = False

        # Cases where datetime_feature_1 is present but datetime_feature_2 is missing
        missing_f2 = data[data[datetime_feature_2].isna() & data[datetime_feature_1].notna()]
        if len(missing_f2) > 0:
            # Get y-axis limits to place rug marks at the bottom
            if len(complete_data) > 0:
                y_min, y_max = ax.get_ylim()
            else:
                # If no complete data, use range from available datetime_feature_2 values
                available_f2 = data[datetime_feature_2].dropna()
                if len(available_f2) > 0:
                    y_min = dates.date2num(available_f2.min())
                    y_max = dates.date2num(available_f2.max())
                    if y_min == y_max:
                        y_min -= 1
                        y_max += 1
                else:
                    # Fallback to default range if no datetime_feature_2 values available
                    y_min = dates.date2num(pd.Timestamp.now() - pd.Timedelta(days=30))
                    y_max = dates.date2num(pd.Timestamp.now())
                ax.set_ylim(y_min, y_max)

            # Place rug marks slightly below the bottom of the plot
            y_range = y_max - y_min
            y_rug = y_min - y_range * 0.02  # 2% offset below bottom

            ax.scatter(
                missing_f2[datetime_feature_1],
                [y_rug] * len(missing_f2),
                marker="|",
                s=100,
                alpha=0.6,
                color="red",
                label=f"{datetime_feature_2} missing",
                zorder=5,
            )
            # Extend ylim slightly to accommodate the rug plot
            ax.set_ylim(y_min - y_range * 0.05, y_max)
            has_plotted_missing = True

        # Cases where datetime_feature_2 is present but datetime_feature_1 is missing
        missing_f1 = data[data[datetime_feature_1].isna() & data[datetime_feature_2].notna()]
        if len(missing_f1) > 0:
            # Get x-axis limits to place rug marks on the left
            if len(complete_data) > 0:
                x_min, x_max = ax.get_xlim()
            else:
                # If no complete data, use range from available datetime_feature_1 values
                available_f1 = data[datetime_feature_1].dropna()
                if len(available_f1) > 0:
                    x_min = dates.date2num(available_f1.min())
                    x_max = dates.date2num(available_f1.max())
                    if x_min == x_max:
                        x_min -= 1
                        x_max += 1
                else:
                    # Fallback to default range if no datetime_feature_1 values available
                    x_min = dates.date2num(pd.Timestamp.now() - pd.Timedelta(days=30))
                    x_max = dates.date2num(pd.Timestamp.now())
                ax.set_xlim(x_min, x_max)

            # Place rug marks slightly to the left of the plot
            x_range = x_max - x_min
            x_rug = x_min - x_range * 0.02  # 2% offset to the left

            ax.scatter(
                [x_rug] * len(missing_f1),
                missing_f1[datetime_feature_2],
                marker="_",
                s=100,
                alpha=0.6,
                color="orange",
                label=f"{datetime_feature_1} missing",
                zorder=5,
            )
            # Extend xlim slightly to accommodate the rug plot
            ax.set_xlim(x_min - x_range * 0.05, x_max)
            has_plotted_missing = True

        # Add legend if we plotted any missing values
        if has_plotted_missing:
            ax.legend(loc="best", framealpha=0.9)

    return ax


def _plot_datetime_feature1(datetime_feature, feature_2, data, dtype2, remove_na, ax, **kwargs):
    """Plot when the first feature is datetime."""
    if _is_categorical_like(dtype2):
        ax = _plot_categorical_vs_datetime(feature_2, datetime_feature, data, remove_na, ax, **kwargs)
    elif pd.api.types.is_datetime64_any_dtype(dtype2):
        # Both features are datetime - use specialized datetime vs datetime plot
        ax = _plot_datetime_vs_datetime(datetime_feature, feature_2, data, remove_na, ax, **kwargs)
    else:
        ax = _plot_datetime_vs_numeric(datetime_feature, feature_2, data, remove_na, ax, **kwargs)
    return ax


def _plot_numeric_features(feature_1, feature_2, data, remove_na, ax, **kwargs):
    """Plot when both features are numeric.

    If remove_na is False, adds marginal rug plots showing where missing values occur.
    """
    # Get complete cases for main scatter plot
    complete_data = data.dropna(subset=[feature_1, feature_2])

    ax.scatter(complete_data[feature_1], complete_data[feature_2], **kwargs)
    ax.set_xlabel(feature_1)
    ax.set_ylabel(feature_2)

    # Add marginal rug plots for missing values if not removed
    if not remove_na:
        # Cases where feature_1 is present but feature_2 is missing
        missing_f2 = data[data[feature_2].isna() & data[feature_1].notna()]
        if len(missing_f2) > 0:
            y_min = ax.get_ylim()[0]
            ax.scatter(
                missing_f2[feature_1],
                [y_min] * len(missing_f2),
                marker="|",
                s=100,
                alpha=0.5,
                color="red",
                label=f"{feature_2} missing",
            )

        # Cases where feature_2 is present but feature_1 is missing
        missing_f1 = data[data[feature_1].isna() & data[feature_2].notna()]
        if len(missing_f1) > 0:
            x_min = ax.get_xlim()[0]
            ax.scatter(
                [x_min] * len(missing_f1),
                missing_f1[feature_2],
                marker="_",
                s=100,
                alpha=0.5,
                color="orange",
                label=f"{feature_1} missing",
            )

        # Add legend if there are any missing values
        if len(missing_f2) > 0 or len(missing_f1) > 0:
            ax.legend(loc="best", framealpha=0.9)

    return ax


def _plot_categorical_vs_categorical(feature_1, feature_2, data, show_ratios, remove_na, ax, **kwargs):
    """Plot when both features are categorical-like.

    When remove_na is False, missing values are handled by:
    - Adding a "Missing" category for any NaN values in either feature
    - Including these in the crosstab/heatmap display

    When remove_na is True, rows with missing values in either feature are excluded.
    """
    dup_df = pd.DataFrame()
    dup_df[feature_1] = _copy_series_or_keep_top_10(data[feature_1])
    dup_df[feature_2] = _copy_series_or_keep_top_10(data[feature_2])

    # Handle missing values based on remove_na parameter
    if not remove_na:
        # Replace NaN with "Missing" category for both features
        if dup_df[feature_1].isna().any():
            dup_df[feature_1] = dup_df[feature_1].fillna("Missing")
        if dup_df[feature_2].isna().any():
            dup_df[feature_2] = dup_df[feature_2].fillna("Missing")

        # Create crosstab with all values (including "Missing" categories)
        crosstab = pd.crosstab(dup_df[feature_1], dup_df[feature_2], dropna=False)
    else:
        # Remove rows where either feature is missing
        dup_df = dup_df.dropna(subset=[feature_1, feature_2])
        crosstab = pd.crosstab(dup_df[feature_1], dup_df[feature_2], dropna=True)

    if show_ratios:
        total = crosstab.sum().sum()
        crosstab_display = crosstab / total
        fmt = ".3f"
    else:
        crosstab_display = crosstab
        fmt = "d"

    sns.heatmap(crosstab_display, annot=True, fmt=fmt, ax=ax, **kwargs)
    ax.set_xlabel(feature_2)
    ax.set_ylabel(feature_1)

    if show_ratios:
        ax.set_title(f"{feature_1} vs {feature_2} (Proportions)")

    return ax


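# Illustrative sketch (not part of the original module): the crosstab-with-
# "Missing" handling used above, normalized to proportions; the sample data and
# helper name are hypothetical.
def _demo_crosstab_proportions() -> pd.DataFrame:
    f1 = pd.Series(["a", "a", "b", None]).fillna("Missing")
    f2 = pd.Series(["x", "y", "x", "y"])
    crosstab = pd.crosstab(f1, f2, dropna=False)
    return crosstab / crosstab.sum().sum()  # each cell is a share of all rows
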

def _plot_categorical_vs_datetime(categorical_feature, datetime_feature, data, remove_na, ax, **kwargs):
    """Plot when one feature is categorical-like and the other is datetime.

    When remove_na is False, missing values are handled as follows:
    - Missing categorical values: added as "Missing" category (creates an extra violin)
    - Missing datetime values: shown as a separate violin at the edge of the plot
    """
    dup_df = pd.DataFrame()
    dup_df[categorical_feature] = _copy_series_or_keep_top_10(data[categorical_feature])

    # Handle missing categorical values by adding "Missing" category
    if not remove_na and dup_df[categorical_feature].isna().any():
        dup_df[categorical_feature] = dup_df[categorical_feature].fillna("Missing")

    # Initialize variables for missing datetime handling
    missing_datetime_value = None
    has_missing_datetime = False

    # Convert datetime to numeric, handling missing values
    if not remove_na:
        # For missing datetime values, we'll use a special marker value
        # First, convert non-missing datetimes to numeric
        datetime_numeric = data[datetime_feature].apply(lambda x: dates.date2num(x) if pd.notna(x) else np.nan)

        # Get the range of valid datetime values to place "Missing" at the edge
        valid_datetime_numeric = datetime_numeric.dropna()
        has_missing_datetime = datetime_numeric.isna().any()

        if len(valid_datetime_numeric) > 0:
            datetime_min = valid_datetime_numeric.min()
            datetime_max = valid_datetime_numeric.max()
            datetime_range = datetime_max - datetime_min
            # Place "Missing" at the right edge, slightly offset (at least 1 day or 10% of range)
            missing_datetime_value = datetime_max + max(datetime_range * 0.1, 1.0)
        elif has_missing_datetime:
            # If no valid datetimes but we have missing ones, use a default value
            missing_datetime_value = dates.date2num(pd.Timestamp.now())

        # Replace NaN datetime values with the special marker if we have missing values
        if has_missing_datetime and missing_datetime_value is not None:
            dup_df[datetime_feature] = datetime_numeric.fillna(missing_datetime_value)
        else:
            dup_df[datetime_feature] = datetime_numeric
    else:
        # Remove rows where either feature is missing
        dup_df = dup_df.dropna(subset=[categorical_feature])
        datetime_numeric = data[datetime_feature].apply(dates.date2num)
        dup_df[datetime_feature] = datetime_numeric
        dup_df = dup_df.dropna(subset=[datetime_feature])

    # Create violin plot with all data (complete + missing datetime)
    chart = sns.violinplot(x=datetime_feature, y=categorical_feature, data=dup_df, ax=ax, **kwargs)

    # Format x-axis ticks for datetime
    ticks_loc = chart.get_xticks()

    if not remove_na and has_missing_datetime and missing_datetime_value is not None:
        # Check if we have missing datetime data
        missing_datetime_data = dup_df[dup_df[datetime_feature] == missing_datetime_value]
        if len(missing_datetime_data) > 0:
            # Separate valid datetime ticks from the missing datetime position
            # Use a threshold to identify the missing datetime position
            valid_ticks = [t for t in ticks_loc if abs(t - missing_datetime_value) > 0.1]
            # Add the missing datetime position if it's not already in ticks
            if not any(abs(t - missing_datetime_value) < 0.1 for t in ticks_loc):
                valid_ticks.append(missing_datetime_value)
            valid_ticks = sorted(valid_ticks)

            chart.xaxis.set_major_locator(ticker.FixedLocator(valid_ticks))
            tick_labels = [
                dates.num2date(t).strftime("%Y-%m-%d %H:%M") if abs(t - missing_datetime_value) > 0.1 else "Missing"
                for t in valid_ticks
            ]
            chart.set_xticklabels(tick_labels, rotation=45, ha="right")
        else:
            # No missing datetime data, use standard formatting
            chart.xaxis.set_major_locator(ticker.FixedLocator(ticks_loc))
            chart.set_xticklabels(chart.get_xticklabels(), rotation=45, ha="right")
            ax.xaxis.set_major_formatter(_convert_numbers_to_dates)
    else:
        # Standard datetime formatting
        chart.xaxis.set_major_locator(ticker.FixedLocator(ticks_loc))
        chart.set_xticklabels(chart.get_xticklabels(), rotation=45, ha="right")
        ax.xaxis.set_major_formatter(_convert_numbers_to_dates)

    ax.set_xlabel(datetime_feature)
    ax.set_ylabel(categorical_feature)

    return ax


def _plot_categorical_vs_numeric(
    categorical_feature,
    numeric_feature,
    data,
    outlier_iqr_multiplier,
    include_outliers,
    remove_na,
    ax,
    **kwargs,
):
    """Plot when the first feature is categorical-like and the second is numeric.

    Renders a violin plot of the numeric feature for each category. When
    ``include_outliers`` is False, numeric values outside the IQR fence
    [Q1 - k*IQR, Q3 + k*IQR] are trimmed, where ``k`` is ``outlier_iqr_multiplier``.

    When ``remove_na`` is False, missing values are handled as follows:
    - Missing categorical values get a "Missing" category
    - Missing numeric values are shown with rug plots at the bottom of each category
    """
    dup_df = pd.DataFrame()
    dup_df[categorical_feature] = _copy_series_or_keep_top_10(data[categorical_feature])
    dup_df[numeric_feature] = data[numeric_feature]

    # Handle missing categorical values by adding "Missing" category
    if not remove_na and dup_df[categorical_feature].isna().any():
        dup_df[categorical_feature] = dup_df[categorical_feature].fillna("Missing")

    # Apply outlier filtering if requested
    if include_outliers:
        df_plot = dup_df.copy()
    else:
        q1 = dup_df[numeric_feature].quantile(0.25)
        q3 = dup_df[numeric_feature].quantile(0.75)
        min_series_value = dup_df[numeric_feature].min()
        max_series_value = dup_df[numeric_feature].max()
        iqr = q3 - q1
        lower_bound = max(min_series_value, q1 - outlier_iqr_multiplier * iqr)
        upper_bound = min(max_series_value, q3 + outlier_iqr_multiplier * iqr)
        df_plot = dup_df[(dup_df[numeric_feature] >= lower_bound) & (dup_df[numeric_feature] <= upper_bound)].copy()

    # Create main violin plot (only with non-NA numeric values)
    df_plot_complete = df_plot.dropna(subset=[numeric_feature])
    sns.violinplot(
        x=categorical_feature, y=numeric_feature, hue=categorical_feature, data=df_plot_complete, ax=ax, **kwargs
    )

    # If remove_na is False, add rug plots for missing numeric values
    if not remove_na:
        missing_numeric = df_plot[df_plot[numeric_feature].isna()]
        if len(missing_numeric) > 0:
            # Get the y-axis limits to place rug marks at the bottom
            y_min = ax.get_ylim()[0]

            # Get unique categories and their x-axis positions
            categories = df_plot_complete[categorical_feature].unique()
            cat_to_pos = {cat: i for i, cat in enumerate(categories)}

            # Plot rug marks for each category that has missing numeric values
            for cat in missing_numeric[categorical_feature].unique():
                if cat in cat_to_pos:
                    count = len(missing_numeric[missing_numeric[categorical_feature] == cat])
                    x_pos = cat_to_pos[cat]

                    # Add small horizontal jitter for visibility when there are multiple missing values
                    jitter = np.random.uniform(-0.1, 0.1, count)

                    ax.scatter(
                        [x_pos] * count + jitter,
                        [y_min] * count,
                        marker="|",
                        s=100,
                        alpha=0.6,
                        color="red",
                        linewidths=2,
                        label=f"{numeric_feature} missing"
                        if cat == missing_numeric[categorical_feature].unique()[0]
                        else "",
                    )

            # Add legend if we plotted any missing values
            if len(missing_numeric) > 0:
                ax.legend(loc="best", framealpha=0.9)

    ax.set_xlabel(categorical_feature.replace("_", " ").title())
    ax.set_ylabel(numeric_feature.replace("_", " ").title())
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    return ax


@plt.FuncFormatter
def _convert_numbers_to_dates(x, pos):
    return dates.num2date(x).strftime("%Y-%m-%d %H:%M")


def extract_statistics_dataframe_per_label(df: pd.DataFrame, feature_name: str, label_name: str) -> pd.DataFrame:
    """Calculate comprehensive statistical metrics for a specified feature grouped by label.

    This function computes various statistical measures for a given numerical feature, broken down by unique
    values in the specified label column. The statistics include count, null count,
    mean, min/max values and multiple percentiles.

    :param df: Input pandas DataFrame containing the data
    :param feature_name: Name of the column to calculate statistics on
    :param label_name: Name of the column to group by
    :return: DataFrame with statistical metrics for each unique label value, with columns:
            - count: Number of non-null observations
            - null_count: Number of null values
            - mean: Average value
            - min: Minimum value
            - 1_percentile: 1st percentile
            - 5_percentile: 5th percentile
            - 25_percentile: 25th percentile
            - median: 50th percentile
            - 75_percentile: 75th percentile
            - 95_percentile: 95th percentile
            - 99_percentile: 99th percentile
            - max: Maximum value

    :raises KeyError: If feature_name or label_name is not found in DataFrame
    :raises TypeError: If feature_name column is not numeric
    """
    if feature_name not in df.columns:
        raise KeyError(f"Feature column '{feature_name}' not found in DataFrame")
    if label_name not in df.columns:
        raise KeyError(f"Label column '{label_name}' not found in DataFrame")
    if not pd.api.types.is_numeric_dtype(df[feature_name]):
        raise TypeError(f"Feature column '{feature_name}' must be numeric")

    # Define percentile functions with consistent naming

    def percentile_1(x):
        return safe_percentile(x, 1)

    def percentile_5(x):
        return safe_percentile(x, 5)

    def percentile_25(x):
        return safe_percentile(x, 25)

    def percentile_75(x):
        return safe_percentile(x, 75)

    def percentile_95(x):
        return safe_percentile(x, 95)

    def percentile_99(x):
        return safe_percentile(x, 99)

    return df.groupby([label_name], observed=True)[feature_name].agg(
        [
            ("count", "count"),
            ("null_count", lambda x: x.isnull().sum()),
            ("mean", "mean"),
            ("min", "min"),
            ("1_percentile", percentile_1),
            ("5_percentile", percentile_5),
            ("25_percentile", percentile_25),
            ("median", "median"),
            ("75_percentile", percentile_75),
            ("95_percentile", percentile_95),
            ("99_percentile", percentile_99),
            ("max", "max"),
        ]
    )


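# Illustrative sketch (not part of the original module): per-label statistics on
# a tiny frame; the column names and data are hypothetical.
def _demo_statistics_per_label() -> pd.DataFrame:
    df = pd.DataFrame({"amount": [1.0, 2.0, None, 4.0], "label": ["a", "a", "b", "b"]})
    # One row per label value; columns follow the docstring above
    # (count, null_count, mean, min, percentiles, median, max).
    return extract_statistics_dataframe_per_label(df, "amount", "label")
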

def compute_mutual_information(
    df: pd.DataFrame,
    features: List[str],
    label_col: str,
    *,
    n_neighbors: int = 3,
    random_state: Optional[Union[int, RandomState]] = None,
    n_jobs: Optional[int] = None,
    numerical_imputer: TransformerMixin = SimpleImputer(strategy="mean"),
    discrete_imputer: TransformerMixin = SimpleImputer(strategy="most_frequent"),
    discrete_encoder: TransformerMixin = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
) -> pd.DataFrame:
    """Compute mutual information scores between features and a target label.

    This function calculates mutual information scores for specified features with respect to a target
    label column. Features are automatically categorized as numerical or discrete (boolean/categorical)
    and preprocessed accordingly before computing mutual information.

    Mutual information measures the mutual dependence between two variables; higher scores indicate
    stronger relationships between the feature and the target label.

    :param df: Input pandas DataFrame containing the features and label
    :param features: List of column names to compute mutual information for
    :param label_col: Name of the target label column
    :param n_neighbors: Number of neighbors to use for MI estimation for continuous variables. Higher values
                        reduce variance of the estimation, but could introduce a bias.
    :param random_state: Random state for reproducible results. Can be int or RandomState instance
    :param n_jobs: The number of jobs to use for computing the mutual information. The parallelization is done
                   on the columns. `None` means 1 unless in a `joblib.parallel_backend` context. ``-1`` means
                   using all processors.
    :param numerical_imputer: Sklearn-compatible transformer for numerical features (default: mean imputation)
    :param discrete_imputer: Sklearn-compatible transformer for discrete features (default: most frequent imputation)
    :param discrete_encoder: Sklearn-compatible transformer for encoding discrete features (default: ordinal encoding
                            with unknown value handling)
    :return: DataFrame with columns 'feature_name' and 'mi_score', sorted by MI score (descending)

    :raises KeyError: If any feature or label_col is not found in DataFrame
    :raises ValueError: If features list is empty or label_col contains only null values
    """
    # Input validation
    if not features:
        raise ValueError("features list cannot be empty")

    if label_col not in df.columns:
        raise KeyError(f"Label column '{label_col}' not found in DataFrame")

    missing_features = [f for f in features if f not in df.columns]
    if missing_features:
        raise KeyError(f"Features not found in DataFrame: {missing_features}")

    if df[label_col].isnull().all():
        raise ValueError(f"Label column '{label_col}' contains only null values")

    # Identify feature types
    numerical_features = df[features].select_dtypes(include=[np.number]).columns.tolist()
    boolean_features = df[features].select_dtypes(include=[bool]).columns.tolist()
    categorical_features = df[features].select_dtypes(include=["object", "category"]).columns.tolist()

    # Create preprocessing pipelines
    numerical_transformer = Pipeline(steps=[("imputer", numerical_imputer)], memory=None, verbose=False)
    discrete_transformer = Pipeline(
        steps=[("imputer", discrete_imputer), ("encoder", discrete_encoder)], memory=None, verbose=False
    )

    # Setup column transformer
    transformers = []
    if numerical_features:
        transformers.append(("num", numerical_transformer, numerical_features))
    if boolean_features or categorical_features:
        transformers.append(("discrete", discrete_transformer, boolean_features + categorical_features))

    preprocessor = ColumnTransformer(
        transformers=transformers,
        remainder="drop",  # Drop any features not explicitly handled
        sparse_threshold=0,
        n_jobs=n_jobs,
        transformer_weights=None,
        verbose=False,
        verbose_feature_names_out=True,
    )

    # Create discrete features mask for mutual_info_classif
    discrete_features_mask = [False] * len(numerical_features) + [True] * (
        len(boolean_features) + len(categorical_features)
    )

    # Create ordered feature names list matching the preprocessed data
    ordered_feature_names = numerical_features + boolean_features + categorical_features

    # Apply preprocessing
    x_preprocessed = preprocessor.fit_transform(df[ordered_feature_names])
    y = df[label_col]

    # Compute mutual information scores
    mi_scores = mutual_info_classif(
        X=x_preprocessed,
        y=y,
        n_neighbors=n_neighbors,
        copy=True,
        random_state=random_state,
        n_jobs=n_jobs,
        discrete_features=discrete_features_mask,
    )

    # Create results DataFrame
    mi_df = pd.DataFrame({"feature_name": ordered_feature_names, "mi_score": mi_scores})

    return mi_df.sort_values(by="mi_score", ascending=False).reset_index(drop=True)
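

# Illustrative sketch (not part of the original module): a mixed-dtype call to
# compute_mutual_information; the column names and data are hypothetical.
def _demo_mutual_information() -> pd.DataFrame:
    df = pd.DataFrame(
        {
            "age": [23, 35, 41, 29, 52, 36],
            "city": ["a", "b", "a", "b", "a", "b"],
            "target": [0, 1, 1, 0, 1, 0],
        }
    )
    # "age" is treated as numerical, "city" as discrete; the result is sorted
    # by mi_score in descending order.
    return compute_mutual_information(df, ["age", "city"], "target", random_state=42)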