• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

quaquel / EMAworkbench / 18273881593

06 Oct 2025 07:52AM UTC coverage: 88.703% (+0.04%) from 88.664%
18273881593

Pull #422

github

web-flow
Merge be4b8d92c into 592d0cd98
Pull Request #422: ruff fixes

69 of 93 new or added lines in 23 files covered. (74.19%)

29 existing lines in 8 files now uncovered.

7852 of 8852 relevant lines covered (88.7%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.06
/ema_workbench/analysis/plotting_util.py
1
"""Plotting utility functions."""
2

3
import copy
1✔
4
import enum
1✔
5

6
import matplotlib as mpl
1✔
7
import matplotlib.gridspec as gridspec
1✔
8
import matplotlib.pyplot as plt
1✔
9
import numpy as np
1✔
10
import pandas as pd
1✔
11
import scipy.stats as stats
1✔
12
import seaborn as sns
1✔
13

14
from ..util import EMAError, get_module_logger
1✔
15

16
# .. codeauthor:: jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
17

18
__all__ = ["COLOR_LIST", "Density", "LegendEnum", "PlotType"]
1✔
19

20
_logger = get_module_logger(__name__)
1✔
21

22
COLOR_LIST = sns.color_palette()
1✔
23
"""Default color list"""
1✔
24
sns.set_palette(COLOR_LIST)
1✔
25

26
TIME = "TIME"
1✔
27
"""Default key for time"""
1✔
28

29

30
# ==============================================================================
31
# actual plotting functions
32
# ==============================================================================
33

34

35
class Density(enum.Enum):
    """Kinds of density plot that can be requested from the plotting helpers."""

    KDE = "kde"
    """show the density as a kernel density estimate."""

    HIST = "hist"
    """show the density as a histogram."""

    BOXPLOT = "boxplot"
    """show the density as a boxplot."""

    VIOLIN = "violin"
    """show the density as a violin plot, i.e., a Gaussian density
    estimate combined with a boxplot."""

    BOXENPLOT = "boxenplot"
    """show the density as a boxenplot."""
53

54

55
class LegendEnum(enum.Enum):
    """Legend styles understood by :func:`make_legend`."""

    # each member selects the kind of proxy artist drawn in the legend
    LINE = "line"
    PATCH = "patch"
    SCATTER = "scatter"
62

63

64
class PlotType(enum.Enum):
    """Kinds of time series plot."""

    ENVELOPE = "envelope"
    """show envelopes only."""

    LINES = "lines"
    """show individual lines only."""

    ENV_LIN = "env_lin"
    """show envelopes together with lines."""
75

76

77
def plot_envelope(ax, j, time, value, fill=False):
    """Draw the envelope (minimum and maximum over axis 0) of ``value``.

    Parameters
    ----------
    ax : axes instance
         the axes to draw on
    j : int
        index passed to :func:`get_color` to select the color
    time : ndarray
           the x values
    value : ndarray
            the data; its minimum and maximum over axis 0 are drawn
    fill : bool
           if True, shade the band between minimum and maximum, otherwise
           draw the two bounds as lines


    """
    lower = np.min(value, axis=0)
    upper = np.max(value, axis=0)
    color = get_color(j)

    if not fill:
        # outline only: one line per bound
        for bound in (lower, upper):
            ax.plot(time, bound, c=color)
        return

    ax.fill_between(time, lower, upper, facecolor=color, alpha=0.3)
103

104

105
def plot_histogram(ax, values, log):
    """Draw a horizontal, normalized histogram with 11 bins.

    Parameters
    ----------
    ax : axes instance
    values : ndarray or list of ndarray
             a list gets one color per entry, anything else a single color
    log : bool
          forwarded to ``ax.hist``; if False, the x ticks are reduced to
          zero and the upper x bound

    Returns
    -------
    the return value of ``ax.hist``


    """
    colors = (
        [get_color(position) for position, _ in enumerate(values)]
        if isinstance(values, list)
        else get_color(0)
    )

    hist_result = ax.hist(
        values,
        bins=11,
        orientation="horizontal",
        histtype="bar",
        density=True,
        color=colors,
        log=log,
    )

    if log:
        return hist_result

    upper_bound = ax.get_xbound()[1]
    ax.set_xticks([0, upper_bound])
    return hist_result
132

133

134
def plot_kde(ax, values, log):
    """Draw one kernel density estimate per entry in ``values``.

    Parameters
    ----------
    ax : axes instance
    values : iterable of ndarray
             each entry is passed to :func:`determine_kde`
    log : bool
          if True the x axis is set to log scale, otherwise the x ticks are
          reduced to zero and the current upper limit


    """
    for position, series in enumerate(values):
        line_color = get_color(position)
        density, support = determine_kde(series)
        ax.plot(density, support, c=line_color, ms=1, markevery=20)

        if log:
            ax.set_xscale("log")
            continue

        # ticks are set first; the labels are read from the limits afterwards
        upper_tick = ax.get_xaxis().get_view_interval()[1]
        ax.set_xticks([0, upper_tick])
        tick_labels = [f"{0:.2g}", f"{ax.get_xlim()[1]:.2g}"]
        ax.set_xticklabels(tick_labels)
156

157

158
def plot_boxplots(ax, values, log, group_labels=None):
    """Draw one boxplot per group using seaborn.

    Parameters
    ----------
    ax : axes instance
    values : iterable of ndarray
             one entry per group
    log : bool
          not supported; a warning is logged if True
    group_labels : list of str, optional
                   labels for the groups; a single empty label is used if
                   omitted or empty


    """
    if log:
        _logger.warning("log option ignored for boxplot")

    labels = group_labels if group_labels else [""]

    # long-format frame: column 0 holds the data, "id_var" the group label
    frames = []
    for label, series in zip(labels, values):
        frame = pd.DataFrame(series)
        frame["id_var"] = label
        frames.append(frame)
    long_data = pd.concat(frames)

    sns.boxplot(x="id_var", y=0, data=long_data, order=labels, ax=ax)
190

191

192
def plot_violinplot(ax, values, log, group_labels=None):
    """Draw one violin per group using seaborn.

    Parameters
    ----------
    ax : axes instance
    values : iterable of ndarray
             one entry per group
    log : bool
          not supported; a warning is logged if True
    group_labels : list of str, optional
                   labels for the groups; a single empty label is used if
                   omitted or empty

    """
    if log:
        _logger.warning("log option ignored for violin plot")

    labels = group_labels if group_labels else [""]

    by_label = dict(zip(labels, values))
    # melt each group separately so groups may differ in length
    melted = [
        pd.melt(pd.DataFrame({label: series})) for label, series in by_label.items()
    ]
    long_data = pd.concat(melted)

    sns.violinplot(x="variable", y="value", data=long_data, order=labels, ax=ax)
215

216

217
def plot_boxenplot(ax, values, log, group_labels=None):
    """Helper function for plotting boxenplots on axes.

    Parameters
    ----------
    ax : axes instance
    values : iterable of ndarray
             one entry per group
    log : bool
          not supported; a warning is logged if True
    group_labels : list of str, optional
                   labels for the groups; a single empty label is used if
                   omitted or empty

    """
    if log:
        # the message used to say "violin plot" (copy-paste error)
        _logger.warning("log option ignored for boxenplot")
    if not group_labels:
        group_labels = [""]

    # wide frame with one column per group, melted to long format for seaborn
    data = pd.DataFrame.from_records(dict(zip(group_labels, values)))
    data = pd.melt(data)

    sns.boxenplot(x="variable", y="value", data=data, order=group_labels, ax=ax)
237

238

239
def group_density(
    ax_d, density, outcomes, outcome_to_plot, group_labels, log=False, index=-1
):
    """Draw a density plot for grouped data.

    Parameters
    ----------
    ax_d : axes instance
           the axes for the density plot
    density : {HIST, BOXPLOT, VIOLIN, KDE, BOXENPLOT}
    outcomes :  dict
                maps each group label to a dict of outcomes
    outcome_to_plot : str
    group_labels : list of str
    log : bool, optional
    index : int, optional
            the column of the outcome array that is plotted

    Raises
    ------
    EMAError
        if density is unknown

    """
    values = [outcomes[label][outcome_to_plot][:, index] for label in group_labels]

    # (density type, plotting helper, whether the helper takes group labels)
    plotters = (
        (Density.HIST, plot_histogram, False),
        (Density.BOXPLOT, plot_boxplots, True),
        (Density.VIOLIN, plot_violinplot, True),
        (Density.KDE, plot_kde, False),
        (Density.BOXENPLOT, plot_boxenplot, True),
    )
    for kind, plotter, takes_labels in plotters:
        if density == kind:
            if takes_labels:
                plotter(ax_d, values, log, group_labels=group_labels)
            else:
                plotter(ax_d, values, log)
            break
    else:
        raise EMAError(f"Unknown density plot type: {density}")

    for clear_label in (ax_d.set_xlabel, ax_d.set_ylabel):
        clear_label("")
277

278

279
def simple_density(density, value, ax_d, ax, log):
    """Draw a density plot of the last column of ``value``.

    Parameters
    ----------
    density : {HIST, BOXPLOT, VIOLIN, KDE, BOXENPLOT}
    value : ndarray
            2-d array; only ``value[:, -1]`` is plotted
    ax_d : axes instance
           the axes for the density plot
    ax : axes instance
         the axes whose y view interval is copied to ``ax_d``
    log : bool

    Raises
    ------
    EMAError
        if density is unknown

    """
    # (density type, plotting helper, whether the data is wrapped in a list)
    plotters = (
        (Density.KDE, plot_kde, True),
        (Density.HIST, plot_histogram, False),
        (Density.BOXPLOT, plot_boxplots, True),
        (Density.VIOLIN, plot_violinplot, True),
        (Density.BOXENPLOT, plot_boxenplot, True),
    )
    for kind, plotter, as_list in plotters:
        if density == kind:
            end_state = value[:, -1]
            plotter(ax_d, [end_state] if as_list else end_state, log)
            break
    else:
        raise EMAError(f"Unknown density plot type: {density}")

    # the interval is read again before set_ylim, as setting the view
    # interval may have altered it
    low, high = ax.get_yaxis().get_view_interval()
    ax_d.get_yaxis().set_view_interval(low, high)

    low, high = ax.get_yaxis().get_view_interval()
    ax_d.set_ylim(bottom=low, top=high)

    ax_d.set_xlabel("")
    ax_d.set_ylabel("")
314

315

316
def simple_kde(outcomes, outcomes_to_show, colormap, log, minima, maxima):
    """Draw, per outcome, a heatmap of the kernel density estimate over time.

    Parameters
    ----------
    outcomes : dict
               maps outcome names to 2-d arrays
    outcomes_to_show : list of str
    colormap : str
    log : bool
          if True, ``log(x + 1)`` of the normalized density is shown
    minima : dict
             lower bound of the KDE support per outcome
    maxima : dict
             upper bound of the KDE support per outcome

    Returns
    -------
    figure
    dict
        maps outcome names to their axes

    """
    size_kde = 100
    fig, axes_grid = plt.subplots(len(outcomes_to_show), squeeze=False)

    axes_dict = {}
    for outcome_to_plot, ax in zip(outcomes_to_show, axes_grid[:, 0]):
        axes_dict[outcome_to_plot] = ax

        outcome = outcomes[outcome_to_plot]
        nr_of_steps = outcome.shape[1]
        kde_over_time = np.zeros(shape=(size_kde, nr_of_steps))
        ymin = minima[outcome_to_plot]
        ymax = maxima[outcome_to_plot]

        # one column per time step, each scaled to a maximum of 1
        for step in range(nr_of_steps):
            density = determine_kde(outcome[:, step], size_kde, ymin, ymax)[0]
            density = density / np.max(density)
            kde_over_time[:, step] = np.log(density + 1) if log else density

        # rows are reversed before plotting
        sns.heatmap(kde_over_time[::-1, :], ax=ax, cmap=colormap, cbar=True)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xlabel("time")
        ax.set_ylabel(outcome_to_plot)

    return fig, axes_dict
361

362

363
def make_legend(categories, ax, ncol=3, legend_type=LegendEnum.LINE, alpha=1):
    """Helper function responsible for making the legend.

    Parameters
    ----------
    categories : str or tuple
                 the categories in the legend
    ax : axes instance
         the axes with which the legend is associated
    ncol : int
           the number of columns to use
    legend_type : {LINE, SCATTER, PATCH}
                  whether the legend is linked to lines, patches, or scatter
                  plots
    alpha : float
            the alpha of the artists

    Raises
    ------
    EMAError
        if legend_type is unknown

    """
    some_identifiers = []
    labels = []
    for i, category in enumerate(categories):
        color = get_color(i)

        if legend_type == LegendEnum.LINE:
            artist = plt.Line2D([0, 1], [0, 1], color=color, alpha=alpha)
        elif legend_type == LegendEnum.SCATTER:
            # TODO work around, should be a proper proxy artist for scatter
            # legends
            artist = mpl.lines.Line2D([0], [0], linestyle="none", c=color, marker="o")
        elif legend_type == LegendEnum.PATCH:
            artist = plt.Rectangle(
                (0, 0), 1, 1, edgecolor=color, facecolor=color, alpha=alpha
            )
        else:
            # without this branch `artist` would be unbound below
            raise EMAError(f"Unknown legend type: {legend_type}")

        some_identifiers.append(artist)
        # tuples are formatted as "lower - upper"
        label = (
            "%.2f - %.2f" % category if isinstance(category, tuple) else category  # noqa: UP031
        )
        labels.append(str(label))

    ax.legend(
        some_identifiers,
        labels,
        ncol=ncol,
        loc=3,
        borderaxespad=0.1,
        mode="expand",
        bbox_to_anchor=(0.0, 1.1, 1.0, 0.102),
    )
422

423

424
def determine_kde(data, size_kde=1000, ymin=None, ymax=None):
    """Helper function responsible for performing a KDE.

    Parameters
    ----------
    data : ndarray
    size_kde : int, optional
               number of points at which the KDE is evaluated
    ymin : float, optional
           lower bound of the evaluation grid, defaults to the minimum of data
    ymax : float, optional
           upper bound of the evaluation grid, defaults to the maximum of data

    Returns
    -------
    ndarray
        x values for kde
    ndarray
        y values for kde

    .. note:: x and y values are based on rotation as used in density
              plots for end states. If the KDE cannot be computed, a warning
              is logged and the x values are all zero.


    """
    # compare with None explicitly: a bound of 0 is a valid value
    if ymin is None:
        ymin = np.min(data)
    if ymax is None:
        ymax = np.max(data)

    kde_y = np.linspace(ymin, ymax, size_kde)

    try:
        kde_x = stats.gaussian_kde(data)
        kde_x = kde_x.evaluate(kde_y)
    except Exception as e:
        # deliberate best effort: fall back to zeros instead of failing
        _logger.warning(f"error in determine_kde: {e}")
        kde_x = np.zeros(kde_y.shape)

    return kde_x, kde_y
467

468

469
def filter_scalar_outcomes(outcomes):
    """Drop all outcomes that are not time series from the outcomes.

    An outcome is kept only if its array has at least two dimensions; an
    info message is logged for each outcome that is dropped.

    Parameters
    ----------
    outcomes : dict

    Returns
    -------
    dict
        the filtered outcomes


    """
    time_series = {}
    for key, value in outcomes.items():
        if value.ndim >= 2:
            time_series[key] = value
            continue
        _logger.info(f"outcome {key} not shown because it is not time series data")
    return time_series
490

491

492
def determine_time_dimension(outcomes):
    """Find the time dimension in the outcomes, or create one.

    If the outcomes contain the key "TIME", its first row is used and the
    key is removed from ``outcomes``. Otherwise an integer range is created
    from the number of columns of the first 2-d outcome.

    Parameters
    ----------
    outcomes : dict

    Returns
    -------
    ndarray
        the time dimension, or None if none could be determined
    dict
        the outcomes


    """
    time = None
    try:
        time = outcomes["TIME"]
        time = time[0, :]
        outcomes.pop("TIME")
    except KeyError:
        # no explicit time: count the columns of the first 2-d outcome
        first_series = next(
            (entry for entry in outcomes.values() if entry.ndim == 2), None
        )
        if first_series is not None:
            time = np.arange(0, first_series.shape[1])

    if time is None:
        _logger.info("no time dimension found in results")
    return time, outcomes
520

521

522
def group_results(
    experiments, outcomes, group_by, grouping_specifiers, grouping_labels
):
    """Helper function that splits the experiments and outcomes into groups.

    Each element in the returned dictionary contains the experiments
    and outcomes for a particular group, the key is the grouping label.

    Parameters
    ----------
    experiments : DataFrame
    outcomes : dict
    group_by : str
               The column in the experiments array to which the grouping
               specifiers apply. If the name is 'index' it is assumed that the
               grouping specifiers are valid indices for numpy.ndarray.
    grouping_specifiers : iterable
                    An iterable of grouping specifiers. A grouping
                    specifier is a unique identifier in case of grouping by
                    categorical uncertainties. It is a tuple in case of
                    grouping by a parameter uncertainty. In this case, the code
                    treats the tuples as half open intervals, apart from the
                    last entry, which is treated as closed on both sides.
                    In case of 'index', each specifier should be a
                    valid index for numpy.ndarray.
    grouping_labels : iterable
                    The label for each grouping specifier, in the same order.
                    The labels are the keys of the returned dictionary.

    Returns
    -------
    dict
        A dictionary with a tuple of the experiments and outcomes for each
        group, the grouping label is used as key

    .. note:: In case of grouping by parameter uncertainty, the traversal
              assumes half open intervals, where the upper limit of each
              interval is open, except for the last interval which is closed.

    """
    groups = {}
    if group_by != "index":
        column_to_group_by = experiments.loc[:, group_by]

    for position, (label, specifier) in enumerate(
        zip(grouping_labels, grouping_specifiers)
    ):
        if isinstance(specifier, tuple):
            # the grouping is a continuous uncertainty
            lower_limit, upper_limit = specifier
            above_lower = column_to_group_by >= lower_limit

            # Use the loop position to detect the last interval: a lookup
            # by value picks the first match, which is wrong for duplicates.
            if position == len(grouping_specifiers) - 1:
                # last interval is closed on both sides
                logical = above_lower & (column_to_group_by <= upper_limit)
            else:
                logical = above_lower & (column_to_group_by < upper_limit)
        elif group_by == "index":
            # the grouping is based on indices
            logical = specifier
        else:
            # the grouping is an integer or categorical uncertainty
            logical = column_to_group_by == specifier

        group_outcomes = {key: value[logical] for key, value in outcomes.items()}
        groups[label] = (experiments.loc[logical, :], group_outcomes)

    return groups
595

596

597
def make_continuous_grouping_specifiers(array, nr_of_groups=5):
    """Helper function for discretizing a continuous array.

    By default, the array is split into 5 equally wide intervals.

    Parameters
    ----------
    array : ndarray
            a 1-d array that is to be turned into discrete intervals.
    nr_of_groups : int, optional
                   the number of intervals, must be at least 1

    Returns
    -------
    list of tuples
        list of tuples with the lower and upper bound of the intervals.

    Raises
    ------
    ValueError
        if nr_of_groups is smaller than 1


    .. note:: this code only produces intervals. :func:`group_results` uses
              these intervals in half-open fashion, apart from the last
              interval: [a, b), [b,c), [c,d]. That is, both the end point
              and the start point of the range of the continuous array are
              included.

    """
    if nr_of_groups < 1:
        raise ValueError(f"nr_of_groups must be at least 1, not {nr_of_groups}")

    minimum = np.min(array)
    maximum = np.max(array)

    # linspace returns the end points exactly, so the first lower bound is
    # the minimum and the last upper bound is the maximum. Computing the
    # upper bound as minimum + step * nr_of_groups can miss the maximum by
    # a rounding error, which would drop it from the closed last interval.
    edges = np.linspace(minimum, maximum, nr_of_groups + 1)
    return list(zip(edges[:-1], edges[1:]))
628

629

630
def prepare_pairs_data(
    experiments,
    outcomes,
    outcomes_to_show=None,
    group_by=None,
    grouping_specifiers=None,
    point_in_time=-1,
    filter_scalar=True,
):
    """Helper function to prepare the data for pairs plotting.

    Parameters
    ----------
    experiments : DataFrame
    outcomes : dict
    outcomes_to_show : list of str, optional. Both None and an empty list indicate that all outcomes should be shown.
    group_by : str, optional
    grouping_specifiers : iterable, optional
    point_in_time : int or float, optional
                    -1 (default) selects the last column of each time series.
                    Any other value is looked up on the time axis. None
                    disables the selection, the outcomes are returned as is.
    filter_scalar : bool, optional

    Returns
    -------
    tuple
        the experiments, the outcomes, the outcomes to show, and the
        grouping labels

    Raises
    ------
    TypeError
        if outcomes_to_show is neither None nor a list
    ValueError
        if exactly one outcome is given, or if point_in_time is not a
        value on the time axis

    """
    if outcomes_to_show is not None:
        if not isinstance(outcomes_to_show, list):
            raise TypeError(
                f"For pair-wise plotting multiple outcomes need to be provided.\n"
                f"outcomes_to_show must be a list of strings or None, instead of a {type(outcomes_to_show)}"
            )
        if len(outcomes_to_show) == 1:
            raise ValueError(
                f"Only {len(outcomes_to_show)} outcome provided, at least two are needed for pair-wise plotting."
            )

    experiments, outcomes, outcomes_to_show, time, grouping_labels = prepare_data(
        experiments,
        None,
        outcomes,
        outcomes_to_show,
        group_by,
        grouping_specifiers,
        filter_scalar,
    )

    def filter_outcomes(outcomes, column):
        """Select a single column from each 2-d outcome, keep others as is."""
        return {
            key: value[:, column] if len(value.shape) == 2 else value
            for key, value in outcomes.items()
        }

    # explicit None check: 0 is a legitimate point in time
    if point_in_time is None:
        return experiments, outcomes, outcomes_to_show, grouping_labels

    if point_in_time != -1:
        # translate the value on the time axis into a scalar column index,
        # so the selected data is 1-d just as for the default of -1
        matches = np.flatnonzero(time == point_in_time)
        if matches.size == 0:
            raise ValueError(
                f"point_in_time {point_in_time} is not a value on the time axis"
            )
        point_in_time = int(matches[0])

    if group_by:
        # grouped outcomes map each group label to a dict of outcomes
        outcomes = {
            key: filter_outcomes(value, point_in_time)
            for key, value in outcomes.items()
        }
    else:
        outcomes = filter_outcomes(outcomes, point_in_time)
    return experiments, outcomes, outcomes_to_show, grouping_labels
694

695

696
def prepare_data(
    experiments,
    experiments_to_show,
    outcomes,
    outcomes_to_show=None,
    group_by=None,
    grouping_specifiers=None,
    filter_scalar=True,
):
    """Helper function for preparing datasets prior to plotting.

    The experiments and outcomes passed in are copied, not modified.

    Parameters
    ----------
    experiments : DataFrame
    experiments_to_show : ndarray
                          index used to select a subset of the experiments
                          and outcomes; None selects everything
    outcomes : dict
    outcomes_to_show : str or list of str, optional
                       if omitted or empty, all outcomes are shown
    group_by : str, optional
               name of the column to group by, or 'index'
    grouping_specifiers : iterable, optional
                          if omitted or empty, the specifiers are inferred
                          from the column to group by
    filter_scalar : bool, optional
                    if True, outcomes that are not time series are removed

    Returns
    -------
    tuple
        the experiments, the (possibly grouped) outcomes, the outcomes to
        show, the time dimension, and the grouping labels

    Raises
    ------
    EMAError
        if group_by is 'index' and no grouping specifiers are provided

    """
    experiments = experiments.copy()
    outcomes = copy.deepcopy(outcomes)

    if experiments_to_show is not None:
        experiments = experiments.loc[experiments_to_show, :]
        outcomes = {k: v[experiments_to_show] for k, v in outcomes.items()}

    time, outcomes = determine_time_dimension(outcomes)

    # remove outcomes that are not to be shown
    if outcomes_to_show:
        if isinstance(outcomes_to_show, str):
            outcomes_to_show = [outcomes_to_show]
        outcomes = {entry: outcomes[entry] for entry in outcomes_to_show}

    # filter the outcomes to exclude scalar values
    if filter_scalar:
        outcomes = filter_scalar_outcomes(outcomes)
    if not outcomes_to_show:
        outcomes_to_show = list(outcomes.keys())

    if not group_by:
        return experiments, outcomes, outcomes_to_show, time, []

    # group the data
    if not grouping_specifiers:
        # no grouping specifier, so infer from the data
        if group_by == "index":
            raise EMAError(
                "No grouping specifiers provided while trying to group on index"
            )

        column_to_group_by = experiments[group_by]
        if column_to_group_by.dtype in (object, "category"):
            grouping_specifiers = set(column_to_group_by)
        else:
            # Rely on the default number of groups: grouping_specifiers is
            # empty or None here and must not be passed as the group count.
            grouping_specifiers = make_continuous_grouping_specifiers(
                column_to_group_by
            )
        grouping_labels = grouping_specifiers = sorted(grouping_specifiers)
    elif isinstance(grouping_specifiers, str):
        grouping_specifiers = [grouping_specifiers]
        grouping_labels = grouping_specifiers
    elif isinstance(grouping_specifiers, dict):
        # the sorted keys are the labels, the values the specifiers
        grouping_labels = sorted(grouping_specifiers.keys())
        grouping_specifiers = [grouping_specifiers[key] for key in grouping_labels]
    else:
        grouping_labels = grouping_specifiers

    grouped = group_results(
        experiments, outcomes, group_by, grouping_specifiers, grouping_labels
    )
    # keep only the outcomes of each group, drop the grouped experiments
    outcomes = {key: value[1] for key, value in grouped.items()}

    return experiments, outcomes, outcomes_to_show, time, grouping_labels
786

787

788
def do_titles(ax, titles, outcome):
    """Set the title of an ax from a mapping of outcome names to titles.

    Nothing is set if ``titles`` is not a dict. With an empty dict, or if
    the outcome is missing from the dict (a warning is logged in that
    case), the outcome name itself is used as title.

    Parameters
    ----------
    ax : axes instance
    titles : dict
             a dict which maps outcome names to titles
    outcome : str
              the outcome plotted in the ax.

    """
    if not isinstance(titles, dict):
        return

    if not titles:
        ax.set_title(outcome)
        return

    try:
        ax.set_title(titles[outcome])
    except KeyError:
        _logger.warning(
            f"KeyError in do_titles, no title provided for outcome `{outcome}`"
        )
        ax.set_title(outcome)
811

812

813
def do_ylabels(ax, ylabels, outcome):
    """Set the y label of an ax from a mapping of outcome names to labels.

    Nothing is set if ``ylabels`` is not a dict. With an empty dict, or if
    the outcome is missing from the dict (a warning is logged in that
    case), the outcome name itself is used as y label.

    Parameters
    ----------
    ax : axes instance
    ylabels : dict
              a dict which maps outcome names to y labels
    outcome : str
              the outcome plotted in the ax.

    """
    if not isinstance(ylabels, dict):
        return

    if not ylabels:
        ax.set_ylabel(outcome)
        return

    try:
        ax.set_ylabel(ylabels[outcome])
    except KeyError:
        _logger.warning(
            f"KeyError in do_ylabels, no ylabel provided for outcome `{outcome}`"
        )
        ax.set_ylabel(outcome)
836

837

838
def make_grid(outcomes_to_show, density=False):
    """Create the figure and the grid that locates and sizes all axes.

    The grid has one row per outcome. With ``density`` it has a second,
    narrower column (width ratio 4:1).

    Parameters
    ----------
    outcomes_to_show : list of str
                       the list of outcomes to show
    density : bool, optional

    Returns
    -------
    figure
    GridSpec

    """
    nr_of_rows = len(outcomes_to_show)
    grid = (
        gridspec.GridSpec(nr_of_rows, 2, width_ratios=[4, 1])
        if density
        else gridspec.GridSpec(nr_of_rows, 1)
    )
    grid.update(wspace=0.1, hspace=0.4)

    return plt.figure(), grid
857

858

859
def get_color(index):
    """Return the color for ``index``, cycling over the color list.

    The index wraps around, so any number of items can be colored even if
    there are more items than entries in ``COLOR_LIST``.
    """
    return COLOR_LIST[index % len(COLOR_LIST)]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc