• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

rafaelpadilla / 3W / 24912462866

24 Apr 2026 09:21PM UTC coverage: 76.362% (-3.1%) from 79.464%
24912462866

push

github

web-flow
Merge pull request #73 from rafaelpadilla/eduardo/refactor_data_operations

Refactor of data operations, trainers and models.

244 of 339 branches covered (71.98%)

Branch coverage included in aggregate %.

1317 of 1706 new or added lines in 50 files covered. (77.2%)

28 existing lines in 5 files now uncovered.

2124 of 2762 relevant lines covered (76.9%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.94
/toolkit/ThreeWToolkit/preprocessing/impute_missing.py
1
import pandas as pd
1✔
2
from typing import Literal
1✔
3
from pydantic import Field, ValidationInfo, field_validator, PrivateAttr
1✔
4

5
from ..core.base_dataset import BaseDataset
1✔
6
from ..core.base_preprocessing import BasePreprocessing, BasePreprocessingConfig
1✔
7
from ..core.dataset_outputs import DatasetOutputs
1✔
8

9

10
class ImputeMissingConfig(BasePreprocessingConfig):
    """Configuration for the ImputeMissing preprocessing step.

    Cross-field rules enforced by the validators below:

    * ``strategy='constant'`` requires a non-``None`` ``fill_value``.
    * ``strategy='interpolate'`` requires a non-``None`` ``interpolate_method``.

    NOTE: ``strategy`` must stay declared before the fields that depend on it.
    The validators read it from ``info.data``, which only contains fields that
    were validated earlier (i.e. declared earlier).
    """

    strategy: Literal["constant", "mean", "ffill", "bfill", "interpolate"] = Field(
        default="constant",
        description="Imputation strategy to use for filling missing values. Options include:\n"
        "- 'constant': Fill missing values with a specified constant value (requires `fill_value`).\n"
        "- 'mean': Fill missing values with the mean of the column (computed across all events during fit).\n"
        "- 'ffill': Forward-fill missing values using the last valid observation (applied per-event).\n"
        "- 'bfill': Backward-fill missing values using the next valid observation (applied per-event).\n"
        "- 'interpolate': Fill missing values using interpolation (requires `interpolate_method`, applied per-event).",
    )
    fill_value: float | None = Field(
        default=0.0,
        # Implicit string concatenation instead of a backslash continuation:
        # the continuation embedded the next line's indentation in the text.
        description=(
            "The constant value to use for filling missing values when strategy='constant'. "
            "This field is required if strategy is set to 'constant'."
        ),
    )
    # columns: list[str] | None = None
    interpolate_method: Literal["linear", "nearest", "zero"] | None = Field(
        default=None,
        # Field validators are skipped for default values unless this is set.
        # Without it, `strategy='interpolate'` with no method passed validation
        # and check_interpolate_method never ran.
        validate_default=True,
        description=(
            "The interpolation method to use when strategy='interpolate'. "
            "This field is required if strategy is set to 'interpolate'. Options include:\n"
            "- 'linear': Linear interpolation\n"
            "- 'nearest': Nearest-neighbor interpolation\n"
            "- 'zero': Step-wise interpolation (previous value)"
        ),
    )
    # Resolved lazily because ImputeMissing is defined further down in this module.
    _target: type = PrivateAttr(default_factory=lambda: ImputeMissing)

    @field_validator("fill_value")
    @classmethod
    def check_fill_value_for_constant(
        cls, fill_value: float | None, info: ValidationInfo
    ) -> float | None:
        """Reject ``fill_value=None`` when the strategy is ``'constant'``.

        Args:
            fill_value: The value supplied for the field.
            info: Validation context; ``info.data`` holds the already-validated
                ``strategy``.

        Returns:
            The unchanged ``fill_value``.

        Raises:
            ValueError: If ``strategy == 'constant'`` and ``fill_value`` is None.
        """
        strategy = info.data.get("strategy")
        if strategy == "constant" and fill_value is None:
            raise ValueError("You must provide `fill_value` when strategy='constant'")
        return fill_value

    @field_validator("interpolate_method")
    @classmethod
    def check_interpolate_method(
        cls,
        interpolate_method: Literal["linear", "nearest", "zero"] | None,
        info: ValidationInfo,
    ) -> Literal["linear", "nearest", "zero"] | None:
        """Reject a missing ``interpolate_method`` when the strategy is ``'interpolate'``.

        Runs for the default value as well (``validate_default=True`` on the field).

        Args:
            interpolate_method: The value supplied for the field (or its default).
            info: Validation context; ``info.data`` holds the already-validated
                ``strategy``.

        Returns:
            The unchanged ``interpolate_method``.

        Raises:
            ValueError: If ``strategy == 'interpolate'`` and no method is given.
        """
        strategy = info.data.get("strategy")
        if strategy == "interpolate" and interpolate_method is None:
            raise ValueError(
                "You must provide `interpolate_method` when strategy='interpolate'"
            )
        return interpolate_method
1✔
59

60

61
class ImputeMissing(BasePreprocessing):
    """
    A data processing step that handles missing values in signal columns using various imputation strategies.

    Supports global strategies (mean, constant) and time-series strategies
    (ffill, bfill, interpolate) applied per-event. Only the 'mean' strategy needs
    statistics collected across events, which happens in `fit`.

    Attributes:
        config (ImputeMissingConfig): Configuration object containing imputation parameters
        global_average (pd.Series | None): Per-column mean computed by `fit` when
            strategy='mean'; None until then.
    """

    def __init__(
        self,
        config: ImputeMissingConfig,
    ):
        """
        Initialize the ImputeMissing step with the provided configuration.

        Args:
            config (ImputeMissingConfig): Configuration containing strategy, fill_value
                and interpolate_method
        """
        self.config: ImputeMissingConfig = config
        self.global_average: pd.Series | None = None

    def fit(self, data: BaseDataset) -> None:
        """
        Collect the statistics needed for imputation.

        Only the 'mean' strategy needs a pass over the dataset (to compute the
        per-column global average). 'constant', 'ffill', 'bfill' and 'interpolate'
        need no fitted state, so this is a no-op for them.

        Args:
            data (BaseDataset): Iterable dataset whose events expose a `signal` DataFrame
        """
        if self.config.strategy != "mean":
            return  # No global collection needed for the other strategies

        self._compute_global_average(data)

    def transform(self, data: DatasetOutputs) -> DatasetOutputs:
        """
        Execute the missing value imputation on the signal of a single event.

        For time-series strategies (ffill, bfill, interpolate): apply per-event,
        then fill in the opposite direction(s) to cover leading/trailing NaNs.
        For 'mean': use the per-column average computed by `fit`.
        For 'constant': use `config.fill_value`.

        Args:
            data (DatasetOutputs): Event data containing the `signal` DataFrame

        Returns:
            DatasetOutputs: New object with the imputed `signal`; `label` and
            `metadata` are passed through unchanged.

        Raises:
            ValueError: If strategy='mean' and `fit` has not been called, or if
                strategy='interpolate' and no `interpolate_method` is configured.
            RuntimeError: If a column is still entirely NaN after imputation.
        """
        strategy = self.config.strategy
        # astype returns a new object, so the caller's DataFrame is never modified.
        signal = data.signal.astype(float)

        if strategy == "constant":
            signal = signal.fillna(self.config.fill_value)
        elif strategy == "mean":
            if self.global_average is None:
                raise ValueError("Global average not computed. Call fit() first.")
            signal = signal.fillna(self.global_average)
        elif strategy == "interpolate":
            method = self.config.interpolate_method
            if method is None:
                # Previously this case fell through to the bfill branch and silently
                # applied a different strategy than the one requested.
                raise ValueError(
                    "You must provide `interpolate_method` when strategy='interpolate'"
                )
            # interpolate, then fill any NaNs left at the edges
            signal = signal.interpolate(method=method).bfill().ffill()
        elif strategy == "ffill":
            # forward-fill then backward-fill to handle leading NaNs
            signal = signal.ffill().bfill()
        else:  # strategy == "bfill"
            # backward-fill then forward-fill to handle trailing NaNs
            signal = signal.bfill().ffill()

        # A column that is still entirely NaN could not be imputed at all
        # (e.g. no valid observation in the event, or no global average for it).
        if signal.isna().all().any():  # type: ignore
            raise RuntimeError(
                "Imputation failed: some columns still contain all NaN values after imputation. Check your data and imputation strategy."
            )

        return DatasetOutputs(signal=signal, label=data.label, metadata=data.metadata)

    def _compute_global_average(self, data: BaseDataset) -> None:
        """
        Compute the global average (mean) for each signal column across all events.

        The mean is weighted by the number of non-NaN samples: the per-event sums
        and non-NaN counts are accumulated and divided once at the end. A column
        with no valid sample in any event gets a NaN average.

        Args:
            data (BaseDataset): Iterable dataset whose events expose a `signal` DataFrame

        Raises:
            ValueError: If the dataset yields no events.
        """
        per_event_sums: list[pd.Series] = []
        per_event_counts: list[pd.Series] = []
        # Single pass over the dataset: both statistics come from the same event.
        for event in data:
            per_event_sums.append(event.signal.sum())
            per_event_counts.append(event.signal.count())

        if not per_event_sums:
            raise ValueError(
                "Cannot compute the global average: the dataset contains no events."
            )

        # One column per event, one row per signal column; sum across events.
        total_sum = pd.concat(per_event_sums, axis=1).sum(axis=1)
        total_count = pd.concat(per_event_counts, axis=1).sum(axis=1)

        self.global_average = total_sum / total_count
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc