• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

quaquel / EMAworkbench / 18214982978

03 Oct 2025 06:39AM UTC coverage: 88.703% (+0.04%) from 88.664%
18214982978

Pull #422

github

web-flow
Merge fe026872f into 592d0cd98
Pull Request #422: ruff fixes

53 of 73 new or added lines in 16 files covered. (72.6%)

2 existing lines in 2 files now uncovered.

7852 of 8852 relevant lines covered (88.7%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.14
/ema_workbench/analysis/cart.py
1
"""A scenario discovery oriented implementation of CART.
2

3
It essentially is a wrapper around scikit-learn's version of CART, with some scenario
4
discovery specific functionality.
5

6
"""
7

8
import contextlib
1✔
9
import io
1✔
10
import math
1✔
11
from io import StringIO
1✔
12

13
import matplotlib.image as mpimg
1✔
14
import matplotlib.pyplot as plt
1✔
15
import numpy as np
1✔
16
import pandas as pd
1✔
17
from sklearn import tree
1✔
18

19
from ema_workbench.util.ema_exceptions import EMAError
1✔
20

21
from ..util import get_module_logger
1✔
22
from . import scenario_discovery_util as sdutil
1✔
23

24
# Created on May 22, 2015
25
#
26
# .. codeauthor:: jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
27

28

29
__all__ = ["CART"]
1✔
30
_logger = get_module_logger(__name__)
1✔
31

32

33
class CART(sdutil.OutputFormatterMixin):
    """CART algorithm.

    Can be used in a manner similar to PRIM. It provides access
    to the underlying tree, but it can also show the boxes described by the
    tree in a table or graph form similar to prim.

    Parameters
    ----------
    x : DataFrame
    y : 1D ndarray
    mass_min : float, optional
               a value between 0 and 1 indicating the minimum fraction
               of data points in a terminal leaf. Defaults to 0.05,
               identical to prim.
    mode : {BINARY, CLASSIFICATION, REGRESSION}
           indicates the mode in which CART is used. Binary indicates
           binary classification, classification is multiclass, and regression
           is regression.

    Attributes
    ----------
    boxes : list
            list of DataFrame box lims
    stats : list
            list of dicts with stats

    Notes
    -----
    This class is a wrapper around scikit-learn's CART algorithm. It provides
    an interface to CART that is more oriented towards scenario discovery, and
    shares some methods with PRIM.

    ``build_tree`` has to be called before ``boxes``, ``stats`` or
    ``show_tree`` are used, because those read the fitted tree from
    ``self.clf``.

    See Also
    --------
    :mod:`prim`

    """

    # separator placed between column name and category in the dummy columns
    sep = "!?!"

    def __init__(self, x, y, mass_min=0.05, mode=sdutil.RuleInductionType.BINARY):
        """Store the data and build the dummy-encoded feature matrix.

        A column named ``scenario`` is dropped from ``x`` if it is present.
        """
        with contextlib.suppress(KeyError):
            x = x.drop(["scenario"], axis=1)

        self.x = x
        self.y = y
        self.mass_min = mass_min
        self.mode = mode

        # we need to transform the DataFrame into a ndarray
        # we use dummy variables for each category in case of categorical
        # variables. Integers are treated as floats
        dummies = pd.get_dummies(self.x, prefix_sep=self.sep)

        # per non-numeric column: str(category) -> original category object,
        # used to translate a dummy column name back into the real category
        self.dummiesmap = {}
        for column, values in x.select_dtypes(exclude=np.number).items():
            mapping = {str(entry): entry for entry in values.unique()}
            self.dummiesmap[column] = mapping

        self.feature_names = dummies.columns.values.tolist()
        self._x = dummies.values
        self._boxes = None
        self._stats = None

    @property
    def boxes(self):
        """Return a list with the box limits for each terminal leaf.

        For every leaf of the fitted tree the path from the root is
        reconstructed and each split on that path is applied to a copy of
        the initial box. The result is cached until ``build_tree`` is
        called again.

        Returns
        -------
        list with boxlims for each terminal leaf

        """
        if self._boxes:
            return self._boxes

        # based on
        # http://stackoverflow.com/questions/20224526/how-to-extract-the-
        # decision-rules-from-scikit-learn-decision-tree
        assert self.clf

        tree_ = self.clf.tree_
        left = tree_.children_left
        right = tree_.children_right
        threshold = tree_.threshold
        feature = tree_.feature

        # Map every child node to (parent, side) once, instead of scanning
        # the child arrays again for every step of every path. A tree that
        # consists of only a root node yields an empty map and therefore an
        # empty path, rather than a failed lookup.
        parents = {}
        for node, (left_child, right_child) in enumerate(zip(left, right)):
            if left_child == -1:
                # leaf node, has no children
                continue
            parents[int(left_child)] = (node, "l")
            parents[int(right_child)] = (node, "r")

        def trace(leaf):
            """Return the splits on the path from the root down to leaf."""
            lineage = []
            child = int(leaf)
            while child in parents:
                parent, split = parents[child]
                # feature names are only resolved for internal nodes; leaf
                # entries of tree_.feature are negative placeholders in
                # scikit-learn and must not be used as an index
                lineage.append(
                    (
                        parent,
                        split,
                        threshold[parent],
                        self.feature_names[feature[parent]],
                    )
                )
                child = parent
            lineage.reverse()
            return lineage

        # get ids of leaf nodes
        leafs = np.argwhere(left == -1)[:, 0]

        box_init = sdutil._make_box(self.x)
        boxes = []
        for leaf in leafs:
            box = box_init.copy()
            for _, direction, value, unc in trace(leaf):
                if direction == "l":
                    if unc in box_init.columns:
                        # left branch: feature <= threshold, so tighten the
                        # upper limit
                        box.loc[1, unc] = value
                    else:
                        # dummy column: left branch means "not this category"
                        unc, cat = unc.split(self.sep)

                        # Work on a copy of the set. DataFrame.copy() does
                        # not copy Python objects stored in cells, so the
                        # set in this cell is shared with box_init and with
                        # every other box; discarding in place would leak
                        # the restriction into all of them.
                        cats = set(box.loc[0, unc])
                        cats.discard(self.dummiesmap[unc][cat])
                        box.loc[:, unc] = [set(cats), set(cats)]
                else:
                    if unc in box_init.columns:
                        # right branch: feature > threshold, so tighten the
                        # lower limit
                        # NOTE(review): only int32 columns are rounded up;
                        # other integer dtypes keep the fractional
                        # threshold - confirm this is intended
                        if box[unc].dtype == np.int32:
                            value = math.ceil(value)
                        box.loc[0, unc] = value
                    # NOTE(review): the right branch of a dummy column
                    # (sample has this category) leaves the box
                    # unrestricted - confirm whether it should be limited
                    # to that single category

            boxes.append(box)
        self._boxes = boxes
        return self._boxes

    @property
    def stats(self):
        """Return a list with the scenario discovery statistics for each terminal leaf.

        The statistics depend on ``mode`` and are cached until
        ``build_tree`` is called again.

        Returns
        -------
        list with scenario discovery statistics for each terminal leaf

        """
        if self._stats:
            return self._stats

        boxes = self.boxes

        box_init = sdutil._make_box(self.x)
        boxstat_method = self._boxstat_methods[self.mode]

        # collect in a local list and only cache the complete result, so an
        # exception halfway through cannot leave a truncated cache behind
        stats = []
        for box in boxes:
            stats.append(boxstat_method(self, box, box_init))

        self._stats = stats
        return self._stats

    def _binary_stats(self, box, box_init):
        """Return coverage, density, res dim and mass for a box (binary mode)."""
        indices = sdutil._in_box(self.x, box)

        y_in_box = self.y[indices]
        box_coi = np.sum(y_in_box)

        boxstats = {
            "coverage": box_coi / np.sum(self.y),
            "density": box_coi / y_in_box.shape[0],
            "res dim": sdutil._determine_nr_restricted_dims(box, box_init),
            "mass": y_in_box.shape[0] / self.y.shape[0],
        }
        return boxstats

    def _regression_stats(self, box, box_init):
        """Return mean, mass and res dim for a box (regression mode)."""
        indices = sdutil._in_box(self.x, box)

        y_in_box = self.y[indices]

        boxstats = {
            "mean": np.mean(y_in_box),
            "mass": y_in_box.shape[0] / self.y.shape[0],
            "res dim": sdutil._determine_nr_restricted_dims(box, box_init),
        }
        return boxstats

    def _classification_stats(self, box, box_init):
        """Return gini, mass, box composition and res dim for a box.

        ``box_composition`` holds the number of points in the box for each
        class, ordered by the sorted class labels found in ``self.y``.
        """
        indices = sdutil._in_box(self.x, box)

        y_in_box = self.y[indices]
        classes = set(self.y)
        classes = sorted(classes)

        counts = [y_in_box[y_in_box == ci].shape[0] for ci in classes]

        # gini impurity: 1 - sum of squared class fractions in the box
        total_gini = 0
        for count in counts:
            total_gini += (count / y_in_box.shape[0]) ** 2
        gini = 1 - total_gini

        boxstats = {
            "gini": gini,
            "mass": y_in_box.shape[0] / self.y.shape[0],
            "box_composition": counts,
            "res dim": sdutil._determine_nr_restricted_dims(box, box_init),
        }

        return boxstats

    # dispatch table used by ``stats``; the values are the plain functions
    # defined above, so they are called with the instance passed explicitly
    _boxstat_methods = {
        sdutil.RuleInductionType.BINARY: _binary_stats,
        sdutil.RuleInductionType.REGRESSION: _regression_stats,
        sdutil.RuleInductionType.CLASSIFICATION: _classification_stats,
    }

    def build_tree(self):
        """Train CART on the data.

        A regressor is fitted in regression mode, a classifier otherwise.
        Cached boxes and stats from an earlier fit are discarded.
        """
        # scikit-learn requires an integer min_samples_leaf of at least 1;
        # for small data sets mass_min * n truncates to 0, so clamp it
        min_samples = max(1, int(self.mass_min * self.x.shape[0]))

        if self.mode == sdutil.RuleInductionType.REGRESSION:
            self.clf = tree.DecisionTreeRegressor(min_samples_leaf=min_samples)
        else:
            self.clf = tree.DecisionTreeClassifier(min_samples_leaf=min_samples)
        self.clf.fit(self._x, self.y)

        # boxes and stats are derived from the fitted tree, so results
        # cached for a previous fit are stale now
        self._boxes = None
        self._stats = None

    def show_tree(self, mplfig=True, format="png"):
        """Return a png (defaults) or svg of the tree.

        On Windows, graphviz needs to be installed with conda.

        Parameters
        ----------
        mplfig : bool, optional
                 if true (default) returns a matplotlib figure with the tree,
                 otherwise, it returns the output as bytes. Only used when
                 format is 'png'.
        format : {'png', 'svg'}, default 'png'
                 Gives a format of the output.

        Raises
        ------
        EMAError
            if pydot returns more than one graph for the tree
        TypeError
            if format is neither 'png' nor 'svg'

        """
        assert self.clf
        import pydot  # noqa: PLC0415 dirty hack for read the docs

        dot_data = StringIO()
        tree.export_graphviz(
            self.clf, out_file=dot_data, feature_names=self.feature_names
        )
        dot_data = dot_data.getvalue()  # .encode('ascii') # @UndefinedVariable
        graphs = pydot.graph_from_dot_data(dot_data)

        # FIXME:: pydot now always returns a list, used to be either a
        # singleton or a list. This is a stopgap which might be sufficient
        # but just in case, we raise an error if assumption of len==1 does
        # not hold
        if len(graphs) > 1:
            raise EMAError(
                f"Expected a single tree for visualization, but found {len(graphs)} trees."
            )

        graph = graphs[0]

        if format == "png":
            img = graph.create_png()
            if mplfig:
                fig, ax = plt.subplots()
                ax.imshow(mpimg.imread(io.BytesIO(img)))
                ax.axis("off")
                return fig
        elif format == "svg":
            img = graph.create_svg()
        else:
            raise TypeError(f"format must be 'png' or 'svg' (instead of {format}).")

        return img
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc