• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

martineberlein / debugging-benchmark / 8171647283

06 Mar 2024 11:57AM UTC coverage: 72.671% (+2.0%) from 70.662%
8171647283

Pull #26

github

web-flow
Merge 08a98172c into f838bc853
Pull Request #26: Release 0.2.0

376 of 465 new or added lines in 22 files covered. (80.86%)

1 existing line in 1 file now uncovered.

1646 of 2265 relevant lines covered (72.67%)

0.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/debugging_framework/feature_extractor.py
1
import logging
×
2
from typing import List, Set, Type, Optional, Any, Tuple
×
3
from abc import ABC, abstractmethod
×
4
import warnings
×
5

6
import numpy as np
×
7
from pandas import DataFrame
×
8
from grammar_graph.gg import GrammarGraph
×
9
import shap
×
10

11
from lightgbm import LGBMClassifier
×
12
from sklearn import preprocessing
×
13
from sklearn.ensemble import RandomForestClassifier
×
14
from sklearn.tree import DecisionTreeClassifier
×
15

NEW
16
from debugging_framework.feature_collector import (
×
17
    Feature,
18
    FeatureFactory,
19
    DEFAULT_FEATURE_TYPES,
20
)
21
from debugging_framework.input import Input
×
22
from debugging_framework.oracle import OracleResult
×
NEW
23
from debugging_framework.types import Grammar
×
24

25
# Suppress the specific SHAP warning
26
# Register two "ignore" filters. `message` is treated by the warnings module
# as a regular expression matched against the start of the warning text.
# The first pattern targets the SHAP/LightGBM output-format notice; the second
# is presumably aimed at LightGBM's training chatter (NOTE(review): confirm
# that message is actually routed through the `warnings` machinery).
warnings.filterwarnings(
    action="ignore",
    message="LightGBM binary classifier with TreeExplainer shap values output has changed to a list of ndarray",
)
warnings.filterwarnings(
    action="ignore",
    message="No further splits with positive gain, best gain: -inf",
)
33

34

35
class RelevantFeatureLearner(ABC):
×
36
    def __init__(
×
37
        self,
38
        grammar: Grammar,
39
        feature_types: Optional[List[Type[Feature]]] = None,
40
        top_n: int = 3,
41
        threshold: float = 0.01,
42
        prune_parent_correlation: bool = True,
43
    ):
44
        self.grammar = grammar
×
45
        self.features = self.construct_features(feature_types or DEFAULT_FEATURE_TYPES)
×
46
        self.top_n = top_n
×
47
        self.threshold = threshold
×
48
        self.graph = GrammarGraph.from_grammar(grammar)
×
49
        self.prune_parent_correlation = prune_parent_correlation
×
50

51
    def construct_features(self, feature_types: List[Type[Feature]]) -> List[Feature]:
×
52
        return FeatureFactory(self.grammar).build(feature_types)
×
53

54
    def learn(
×
55
        self, test_input: Set[Input]
56
    ) -> Tuple[Set[Feature], Set[Feature], Set[Feature]]:
57
        if not test_input:
×
58
            raise ValueError(
×
59
                "Input set for learning relevant features must not be empty."
60
            )
61

62
        x_train, y_train = self.get_learning_data(test_input)
×
63
        primary_features = set(self.get_relevant_features(test_input, x_train, y_train))
×
64
        logging.info(f"Determined {primary_features} as most relevant.")
×
65
        correlated_features = self.find_correlated_features(x_train, primary_features)
×
66

67
        return (
×
68
            primary_features,
69
            correlated_features - primary_features,
70
            set(self.features) - primary_features.union(correlated_features),
71
        )
72

73
    def find_correlated_features(
×
74
        self, x_train: DataFrame, primary_features: Set[Feature]
75
    ) -> Set[Feature]:
76
        correlation_matrix = x_train.corr(method="spearman")
×
77

78
        correlated_features = {
×
79
            feature
80
            for primary in primary_features
81
            for feature, value in correlation_matrix[primary].items()
82
            if abs(value) > 0.7
83
            and self.determine_correlating_parent_non_terminal(primary, feature)
84
        }
85
        logging.info(f"Added Features: {correlated_features} due to high correlation.")
×
86
        return correlated_features
×
87

88
    def determine_correlating_parent_non_terminal(
×
89
        self, primary_feature: Feature, correlating_feature: Feature
90
    ) -> bool:
91
        if (
×
92
            self.prune_parent_correlation
93
            and self.graph.reachable(
94
                primary_feature.non_terminal, correlating_feature.non_terminal
95
            )
96
            and not (
97
                self.graph.reachable(
98
                    correlating_feature.non_terminal, primary_feature.non_terminal
99
                )
100
            )
101
            and not correlating_feature.non_terminal == "<start>"
102
        ):
103
            return False
×
104
        return True
×
105

106
    @abstractmethod
×
107
    def get_relevant_features(
×
108
        self, test_inputs: Set[Input], x_train: DataFrame, y_train: List[int]
109
    ) -> List[Feature]:
110
        raise NotImplementedError()
×
111

112
    @staticmethod
×
113
    def map_result(result: OracleResult) -> int:
×
114
        match result:
×
115
            case OracleResult.PASSING:
×
116
                return 0
×
117
            case OracleResult.FAILING:
×
118
                return 1
×
119
            case _:
×
120
                return -1
×
121

122
    def get_learning_data(self, test_inputs: Set[Input]) -> Tuple[DataFrame, List[int]]:
×
123
        records = [
×
124
            {
125
                feature: inp.features.get_feature_value(feature)
126
                for feature in self.features
127
            }
128
            for inp in test_inputs
129
            if inp.oracle != OracleResult.UNDEFINED  #
130
        ]
131

132
        df = DataFrame.from_records(records).replace(-np.inf, -(2**32))
×
133
        labels = [
×
134
            self.map_result(inp.oracle)
135
            for inp in test_inputs
136
            if inp.oracle != OracleResult.UNDEFINED
137
        ]
138

139
        return df.drop(columns=df.columns[df.nunique() == 1]), labels
×
140

141

142
class SKLearFeatureRelevanceLearner(RelevantFeatureLearner, ABC):
    """Relevance learner backed by a classifier exposing ``feature_importances_``.

    Concrete subclasses only supply :meth:`fit`.
    """

    def get_features(self, x_train: DataFrame, classifier) -> List[Feature]:
        """Return up to ``self.top_n`` columns ranked by descending importance.

        Columns whose importance is below ``self.threshold`` are left out.
        """
        ranked = sorted(
            zip(x_train.columns, classifier.feature_importances_),
            key=lambda pair: pair[1],
            reverse=True,
        )

        selected = []
        for feature, importance in ranked:
            if importance >= self.threshold:
                selected.append(feature)

        return selected[: self.top_n]

    @abstractmethod
    def fit(self, x_train: DataFrame, y_train: List[int]) -> Any:
        """Train and return a classifier; must be provided by subclasses."""
        raise NotImplementedError()

    def get_relevant_features(
        self, test_inputs: Set[Input], x_train: DataFrame, y_train: List[int]
    ) -> List[Feature]:
        """Fit the classifier on the training data and rank its features."""
        return self.get_features(x_train, self.fit(x_train, y_train))
×
168

169

170
class DecisionTreeRelevanceLearner(SKLearFeatureRelevanceLearner):
    """Relevance learner using a single scikit-learn decision tree."""

    def fit(self, x_train: DataFrame, y_train: List[int]) -> Any:
        """Train a ``DecisionTreeClassifier`` (``random_state=0``) and return it."""
        # scikit-learn's ``fit`` returns the fitted estimator itself.
        return DecisionTreeClassifier(random_state=0).fit(x_train, y_train)
×
175

176

177
class RandomForestRelevanceLearner(SKLearFeatureRelevanceLearner):
    """Relevance learner using a small scikit-learn random forest."""

    def fit(self, x_train: DataFrame, y_train: List[int]) -> Any:
        """Train a 10-tree ``RandomForestClassifier`` (``random_state=0``) and return it."""
        forest = RandomForestClassifier(random_state=0, n_estimators=10)
        # scikit-learn's ``fit`` returns the fitted estimator itself.
        return forest.fit(x_train, y_train)
×
182

183

184
class GradientBoostingTreeRelevanceLearner(SKLearFeatureRelevanceLearner):
    """Relevance learner using a LightGBM gradient-boosted classifier."""

    def fit(self, x_train: DataFrame, y_train: List[int]) -> Any:
        """Train a binary ``LGBMClassifier`` (depth 5, 1000 estimators) and return it."""
        booster = LGBMClassifier(
            objective="binary",
            n_estimators=1000,
            max_depth=5,
        )
        booster.fit(x_train, y_train)
        return booster
×
189

190

191
class SHAPRelevanceLearner(RelevantFeatureLearner):
    """Relevance learner that ranks features by mean absolute SHAP value.

    A tree-based classifier (obtained from ``classifier_type``) is fitted on
    the learning data and explained with ``shap.TreeExplainer``.
    """

    def __init__(
        self,
        grammar: Grammar,
        top_n: int = 3,
        feature_types: Optional[List[Type[Feature]]] = None,
        classifier_type: Optional[
            Type[SKLearFeatureRelevanceLearner]
        ] = GradientBoostingTreeRelevanceLearner,
        normalize_data: bool = False,
        show_beeswarm_plot: bool = False,
    ):
        """Create the learner.

        Args:
            grammar: Grammar the features are constructed from.
            top_n: Maximum number of features to report.
            feature_types: Feature classes to build (``None`` = defaults).
            classifier_type: Learner class whose ``fit`` trains the model to
                explain; ``None`` selects the gradient-boosting learner.
            normalize_data: Min-max scale the data before fitting/explaining.
            show_beeswarm_plot: Show a SHAP summary plot while learning.
        """
        super().__init__(grammar, top_n=top_n, feature_types=feature_types)
        # The parameter is annotated Optional, so treat None as "use the default"
        # instead of failing with "'NoneType' object is not callable".
        if classifier_type is None:
            classifier_type = GradientBoostingTreeRelevanceLearner
        self.classifier = classifier_type(self.grammar)
        self.show_beeswarm_plot = show_beeswarm_plot
        self.normalize_data = normalize_data

    def get_relevant_features(
        self, test_inputs: Set[Input], x_train: DataFrame, y_train: List[int]
    ) -> List[Feature]:
        """Return up to ``self.top_n`` features ranked by SHAP importance."""
        # The explainer must see data on the same scale the model was trained
        # on, so the (optionally) normalized frame is used for every step.
        # Without normalization this is x_train itself.
        x_model = self.normalize_learning_data(x_train)
        classifier = self.classifier.fit(x_model, y_train)
        shap_values = self.get_shap_values(classifier, x_model)
        if self.show_beeswarm_plot:
            self.display_beeswarm_plot(shap_values, x_model)
        return self.get_sorted_features_by_importance(shap_values, x_model)[
            : self.top_n
        ]

    def normalize_learning_data(self, data: DataFrame):
        """Return ``data`` min-max scaled if ``self.normalize_data``, else unchanged."""
        if not self.normalize_data:
            return data
        normalized = preprocessing.MinMaxScaler().fit_transform(data)
        # Keep both the column labels and the row index of the input frame.
        return DataFrame(normalized, columns=data.columns, index=data.index)

    @staticmethod
    def get_shap_values(classifier, x_train):
        """Compute SHAP values for ``x_train`` with a ``shap.TreeExplainer``."""
        explainer = shap.TreeExplainer(classifier)
        return explainer.shap_values(x_train)

    @staticmethod
    def _positive_class_shap_values(shap_values) -> np.ndarray:
        """Return a 2-D (samples x features) array for the class at index 1.

        Accepts the list-of-arrays layout (one array per class) this module
        was written against, as well as plain ndarrays.
        NOTE(review): the ndarray layouts below follow what newer SHAP releases
        are understood to return - confirm against the pinned SHAP version.
        """
        if isinstance(shap_values, (list, tuple)):
            # One array per class; fall back to the only entry if there is one.
            chosen = shap_values[1] if len(shap_values) > 1 else shap_values[0]
            return np.asarray(chosen)

        values = np.asarray(shap_values)
        if values.ndim == 3:
            # Assumed layout: (samples, features, outputs).
            return values[:, :, 1] if values.shape[2] > 1 else values[:, :, 0]
        # Already a single (samples, features) matrix.
        return values

    @staticmethod
    def get_sorted_features_by_importance(
        shap_values, x_train: DataFrame
    ) -> List[Feature]:
        """Return the columns of ``x_train`` sorted by descending mean |SHAP|."""
        class_values = SHAPRelevanceLearner._positive_class_shap_values(shap_values)
        mean_shap_values = np.abs(class_values).mean(axis=0)
        sorted_indices = mean_shap_values.argsort()[::-1]
        return x_train.columns[sorted_indices].tolist()

    @staticmethod
    def display_beeswarm_plot(shap_values, x_train):
        """Show ``shap.summary_plot`` for the class-1 SHAP values."""
        class_values = SHAPRelevanceLearner._positive_class_shap_values(shap_values)
        return shap.summary_plot(class_values, x_train.astype("float"))
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc