• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

quaquel / EMAworkbench / 17440226021

03 Sep 2025 04:46PM UTC coverage: 83.291% (+3.0%) from 80.3%
17440226021

push

github

web-flow
Update ci.yml (#406)

7238 of 8690 relevant lines covered (83.29%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.06
/ema_workbench/analysis/plotting_util.py
1
"""Plotting utility functions."""
2

3
import copy
1✔
4
import enum
1✔
5

6
import matplotlib as mpl
1✔
7
import matplotlib.gridspec as gridspec
1✔
8
import matplotlib.pyplot as plt
1✔
9
import numpy as np
1✔
10
import pandas as pd
1✔
11
import scipy.stats as stats
1✔
12
import seaborn as sns
1✔
13

14
from ..util import EMAError, get_module_logger
1✔
15

16
# .. codeauthor:: jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
17

18
__all__ = ["COLOR_LIST", "Density", "LegendEnum", "PlotType"]

_logger = get_module_logger(__name__)

#: Default color list: the current seaborn palette, which is also set as the
#: active seaborn palette.
COLOR_LIST = sns.color_palette()
sns.set_palette(COLOR_LIST)

#: Default key for time.
TIME = "TIME"


# ==============================================================================
# plotting helper functions
# ==============================================================================
33

34

35
class Density(enum.Enum):
    """Enum for the different types of density plots."""

    #: Density plotted as a kernel density estimate.
    KDE = "kde"

    #: Density plotted as a histogram.
    HIST = "hist"

    #: Density plotted as a boxplot.
    BOXPLOT = "boxplot"

    #: Density plotted as a violin plot, which combines a Gaussian density
    #: estimate with a boxplot.
    VIOLIN = "violin"

    #: Density plotted as a boxenplot.
    BOXENPLOT = "boxenplot"
53

54

55
class LegendEnum(enum.Enum):
    """Enum for the styles of proxy artist used in legends."""

    LINE = "line"  # legend entries drawn as line segments
    PATCH = "patch"  # legend entries drawn as filled rectangles
    SCATTER = "scatter"  # legend entries drawn as single markers
62

63

64
class PlotType(enum.Enum):
    """Enum for the different types of time series plots."""

    #: Plot envelopes.
    ENVELOPE = "envelope"

    #: Plot individual lines.
    LINES = "lines"

    #: Plot envelopes together with lines.
    ENV_LIN = "env_lin"
75

76

77
def plot_envelope(ax, j, time, value, fill=False):
    """Draw the envelope (pointwise minimum and maximum) of a set of series.

    Parameters
    ----------
    ax : axes instance
         the axes to draw on
    j : int
        position used to pick the color via :func:`get_color`
    time : ndarray
           x values shared by all series
    value : ndarray
            the series; minimum and maximum are taken over axis 0
    fill : bool, optional
           if True, shade the area between minimum and maximum (alpha 0.3);
           otherwise draw the two bounding lines

    """
    color = get_color(j)
    lower, upper = np.min(value, axis=0), np.max(value, axis=0)

    if not fill:
        ax.plot(time, lower, c=color)
        ax.plot(time, upper, c=color)
        return

    ax.fill_between(time, lower, upper, facecolor=color, alpha=0.3)
103

104

105
def plot_histogram(ax, values, log):
    """Draw a horizontal, density-normalized histogram with 11 bins.

    Parameters
    ----------
    ax : axes instance
         the axes to draw on
    values : ndarray or list of ndarray
             a single sample, or a list with one sample per group
    log : bool
          passed on as the ``log`` option of ``ax.hist``; when False only the
          origin and the upper x bound get a tick

    Returns
    -------
    the return value of ``ax.hist``

    """
    if isinstance(values, list):
        # one color per group, in the order of the list
        color = list(map(get_color, range(len(values))))
    else:
        color = get_color(0)

    hist_result = ax.hist(
        values,
        bins=11,
        orientation="horizontal",
        histtype="bar",
        density=True,
        color=color,
        log=log,
    )

    if not log:
        ax.set_xticks([0, ax.get_xbound()[1]])
    return hist_result
132

133

134
def plot_kde(ax, values, log):
    """Helper function, responsible for plotting a KDE per group.

    Each entry of ``values`` is turned into a kernel density estimate with
    :func:`determine_kde` and drawn as a line (density along x, value along
    y) in the color :func:`get_color` returns for its position.

    Parameters
    ----------
    ax : axes instance
         the axes to draw on
    values : iterable of ndarray
             one 1-d sample per group
    log : bool
          if True, use a logarithmic x axis; otherwise only label 0 and the
          upper x limit

    """
    plotted = False
    for j, value in enumerate(values):
        kde_x, kde_y = determine_kde(value)
        ax.plot(kde_x, kde_y, c=get_color(j), ms=1, markevery=20)
        plotted = True

    if not plotted:
        # nothing was drawn, so there is no axis to format
        return

    # Format the x axis once, after every group has been drawn, so that the
    # ticks reflect the final view limits.
    if log:
        ax.set_xscale("log")
    else:
        upper = ax.get_xaxis().get_view_interval()[1]
        ax.set_xticks([0, upper])
        ax.set_xticklabels([f"{0:.2g}", f"{ax.get_xlim()[1]:.2g}"])
156

157

158
def plot_boxplots(ax, values, log, group_labels=None):
    """Draw one boxplot per group on ``ax`` with seaborn.

    Parameters
    ----------
    ax : axes instance
         the axes to draw on
    values : list of ndarray
             one sample per group, paired with ``group_labels`` by position
    log : bool
          ignored; a warning is logged when True
    group_labels : list of str, optional
                   the label of each group; defaults to a single unnamed group

    """
    if log:
        _logger.warning("log option ignored for boxplot")

    labels = group_labels if group_labels else [""]

    # long format: the sample in column 0, its group label in ``id_var``
    data = pd.concat(
        [
            pd.DataFrame(samples).assign(id_var=label)
            for label, samples in zip(labels, values)
        ]
    )

    sns.boxplot(x="id_var", y=0, data=data, order=labels, ax=ax)
190

191

192
def plot_violinplot(ax, values, log, group_labels=None):
    """Draw one violin per group on ``ax`` with seaborn.

    Parameters
    ----------
    ax : axes instance
         the axes to draw on
    values : list of ndarray
             one sample per group, paired with ``group_labels`` by position
    log : bool
          ignored; a warning is logged when True
    group_labels : list of str, optional
                   the label of each group; defaults to a single unnamed group

    """
    if log:
        _logger.warning("log option ignored for violin plot")

    labels = group_labels or [""]

    # groups sharing a label collapse into one (the last sample wins)
    by_label = dict(zip(labels, values))
    long_form = pd.concat(
        [pd.melt(pd.DataFrame({label: samples})) for label, samples in by_label.items()]
    )

    sns.violinplot(x="variable", y="value", data=long_form, order=labels, ax=ax)
215

216

217
def plot_boxenplot(ax, values, log, group_labels=None):
1✔
218
    """Helper function for plotting boxenplot plots on axes.
219

220
    Parameters
221
    ----------
222
    ax : axes instance
223
    values : ndarray
224
    log : bool
225
    group_labels : list of str, optional
226

227
    """
228
    if log:
1✔
229
        _logger.warning("log option ignored for violin plot")
1✔
230
    if not group_labels:
1✔
231
        group_labels = [""]
1✔
232

233
    data = pd.DataFrame.from_records(dict(zip(group_labels, values)))
1✔
234
    data = pd.melt(data)
1✔
235

236
    sns.boxenplot(x="variable", y="value", data=data, order=group_labels, ax=ax)
1✔
237

238

239
def group_density(
    ax_d, density, outcomes, outcome_to_plot, group_labels, log=False, index=-1
):
    """Draw a density plot of one outcome for every group on ``ax_d``.

    Parameters
    ----------
    ax_d : axes instance
           the axes that receives the density plot
    density : {Density.HIST, Density.BOXPLOT, Density.VIOLIN, Density.KDE, Density.BOXENPLOT}
    outcomes : dict
               maps each group label to the outcomes of that group
    outcome_to_plot : str
    group_labels : list of str
    log : bool, optional
    index : int, optional
            the column of each outcome array to use; -1 (the default) is
            the last column

    Raises
    ------
    EMAError
        if density is unknown

    """
    samples_per_group = [
        outcomes[label][outcome_to_plot][:, index] for label in group_labels
    ]

    if density == Density.HIST:
        plot_histogram(ax_d, samples_per_group, log)
    elif density == Density.KDE:
        plot_kde(ax_d, samples_per_group, log)
    elif density == Density.BOXPLOT:
        plot_boxplots(ax_d, samples_per_group, log, group_labels=group_labels)
    elif density == Density.VIOLIN:
        plot_violinplot(ax_d, samples_per_group, log, group_labels=group_labels)
    elif density == Density.BOXENPLOT:
        plot_boxenplot(ax_d, samples_per_group, log, group_labels=group_labels)
    else:
        raise EMAError(f"Unknown density plot type: {density}")

    # the density axes carry no axis labels
    ax_d.set_xlabel("")
    ax_d.set_ylabel("")
277

278

279
def simple_density(density, value, ax_d, ax, log):
    """Draw the density of the last column of ``value`` on ``ax_d``.

    Afterwards the y range of ``ax_d`` is aligned with that of ``ax``.

    Parameters
    ----------
    density : {Density.KDE, Density.HIST, Density.BOXPLOT, Density.VIOLIN, Density.BOXENPLOT}
    value : ndarray
            2-d array; only its last column is used
    ax_d : axes instance
           the axes receiving the density plot
    ax : axes instance
         the axes whose y range ``ax_d`` is aligned with
    log : bool

    Raises
    ------
    EMAError
        if density is unknown

    """
    if density == Density.HIST:
        plot_histogram(ax_d, value[:, -1], log)
    elif density == Density.KDE:
        plot_kde(ax_d, [value[:, -1]], log)
    elif density == Density.BOXPLOT:
        plot_boxplots(ax_d, [value[:, -1]], log)
    elif density == Density.VIOLIN:
        plot_violinplot(ax_d, [value[:, -1]], log)
    elif density == Density.BOXENPLOT:
        plot_boxenplot(ax_d, [value[:, -1]], log)
    else:
        raise EMAError(f"Unknown density plot type: {density}")

    main_y_axis = ax.get_yaxis()
    # The interval is read again after set_view_interval: if the two axes
    # share their y axis, that call may also change ax's interval.
    ax_d.get_yaxis().set_view_interval(*main_y_axis.get_view_interval())
    bottom, top = main_y_axis.get_view_interval()
    ax_d.set_ylim(bottom=bottom, top=top)

    ax_d.set_xlabel("")
    ax_d.set_ylabel("")
314

315

316
def simple_kde(outcomes, outcomes_to_show, colormap, log, minima, maxima):
    """Helper function for generating a density heatmap over time.

    One axes is created per outcome, stacked vertically. For every column
    of an outcome, a KDE over 100 points between the outcome's minimum and
    maximum is computed and scaled so that its peak is 1. The columns are
    then shown side by side as a heatmap.

    Parameters
    ----------
    outcomes : dict
    outcomes_to_show : list of str
    colormap : str
    log : bool
          if True, each scaled KDE is transformed with log(x + 1)
    minima : dict
             lower bound of the KDE grid per outcome
    maxima : dict
             upper bound of the KDE grid per outcome

    Returns
    -------
    Figure
    dict
        maps each outcome name to its axes

    """
    size_kde = 100
    fig, axes = plt.subplots(len(outcomes_to_show), squeeze=False)
    axes = axes[:, 0]

    axes_dict = {}

    for outcome_to_plot, ax in zip(outcomes_to_show, axes):
        axes_dict[outcome_to_plot] = ax

        outcome = outcomes[outcome_to_plot]
        ymin = minima[outcome_to_plot]
        ymax = maxima[outcome_to_plot]

        kde_over_time = np.zeros(shape=(size_kde, outcome.shape[1]))
        for j in range(outcome.shape[1]):
            kde_x = determine_kde(outcome[:, j], size_kde, ymin, ymax)[0]

            # determine_kde returns all zeros when the KDE cannot be
            # computed; dividing by a zero peak would fill the column with
            # NaN, so only rescale when there is a positive peak.
            peak = np.max(kde_x)
            if peak > 0:
                kde_x = kde_x / peak

            if log:
                kde_x = np.log(kde_x + 1)
            kde_over_time[:, j] = kde_x

        # rows are reversed so that the highest grid value is the top row
        sns.heatmap(kde_over_time[::-1, :], ax=ax, cmap=colormap, cbar=True)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xlabel("time")
        ax.set_ylabel(outcome_to_plot)

    return fig, axes_dict
361

362

363
def make_legend(categories, ax, ncol=3, legend_type=LegendEnum.LINE, alpha=1):
    """Helper function responsible for making the legend.

    One proxy artist is created per category. Its color is the one that
    :func:`get_color` returns for the category's position.

    Parameters
    ----------
    categories : iterable of str or tuple
                 the categories in the legend; a tuple is labelled as
                 "lower - upper" with two decimals
    ax : axes instance
         the axes with which the legend is associated
    ncol : int
           the number of columns to use
    legend_type : {LegendEnum.LINE, LegendEnum.SCATTER, LegendEnum.PATCH}
                  whether the legend is linked to lines, patches, or scatter
                  plots
    alpha : float
            the alpha of the line and patch artists

    Raises
    ------
    EMAError
        if legend_type is not a supported LegendEnum member (and there is at
        least one category)

    """
    artists = []
    labels = []
    for i, category in enumerate(categories):
        color = get_color(i)

        if legend_type == LegendEnum.LINE:
            artist = plt.Line2D([0, 1], [0, 1], color=color, alpha=alpha)
        elif legend_type == LegendEnum.SCATTER:
            # TODO work around, should be a proper proxy artist for scatter
            # legends; a marker-only line stands in for a scatter point
            artist = mpl.lines.Line2D([0], [0], linestyle="none", c=color, marker="o")
        elif legend_type == LegendEnum.PATCH:
            artist = plt.Rectangle(
                (0, 0), 1, 1, edgecolor=color, facecolor=color, alpha=alpha
            )
        else:
            # no proxy artist exists for an unknown style; fail explicitly
            raise EMAError(f"Unknown legend type: {legend_type}")

        artists.append(artist)
        label = "%.2f - %.2f" % category if isinstance(category, tuple) else category  # noqa: UP031
        labels.append(str(label))

    # lower-left anchored, full-width legend placed above the axes
    ax.legend(
        artists,
        labels,
        ncol=ncol,
        loc=3,
        borderaxespad=0.1,
        mode="expand",
        bbox_to_anchor=(0.0, 1.1, 1.0, 0.102),
    )
420

421

422
def determine_kde(data, size_kde=1000, ymin=None, ymax=None):
    """Helper function responsible for performing a KDE.

    A Gaussian KDE of ``data`` is evaluated on ``size_kde`` evenly spaced
    points between ``ymin`` and ``ymax``.

    Parameters
    ----------
    data : ndarray
    size_kde : int, optional
               number of evaluation points
    ymin : float, optional
           lower end of the evaluation grid; defaults to the minimum of data
    ymax : float, optional
           upper end of the evaluation grid; defaults to the maximum of data

    Returns
    -------
    ndarray
        x values for kde: the estimated density at each grid point (all
        zeros if the KDE could not be computed)
    ndarray
        y values for kde: the evaluation grid

    .. note:: x and y values are based on rotation as used in density
              plots for end states.

    """
    # Explicit None checks: a bound of 0 is a legitimate value and must not
    # be replaced by the extremes of the data.
    if ymin is None:
        ymin = np.min(data)
    if ymax is None:
        ymax = np.max(data)

    kde_y = np.linspace(ymin, ymax, size_kde)

    try:
        kde_x = stats.gaussian_kde(data).evaluate(kde_y)
    except Exception as e:
        # best effort: a failed KDE should not abort the whole plot
        _logger.warning(f"error in determine_kde: {e}")
        kde_x = np.zeros(kde_y.shape)

    return kde_x, kde_y
465

466

467
def filter_scalar_outcomes(outcomes):
    """Return only the outcomes that are time series (at least 2-d).

    Every dropped outcome is reported with an info log message.

    Parameters
    ----------
    outcomes : dict
               maps outcome names to arrays

    Returns
    -------
    dict
        a new dict with the outcomes whose ``ndim`` is at least 2, in the
        original order

    """
    time_series = {}
    for name, data in outcomes.items():
        if data.ndim >= 2:
            time_series[name] = data
            continue
        _logger.info(f"outcome {name} not shown because it is not time series data")
    return time_series
488

489

490
def determine_time_dimension(outcomes):
    """Helper function for determining or creating the time dimension.

    If ``outcomes`` contains a ``"TIME"`` entry, its first row is the time
    dimension and the entry is removed from ``outcomes`` in place.
    Otherwise the time dimension is ``0 .. n-1``, where ``n`` is the number
    of columns of the first 2-d outcome. If there is no such outcome, time
    is None and an info message is logged.

    Parameters
    ----------
    outcomes : dict

    Returns
    -------
    ndarray or None
        the time dimension
    dict
        the same dict object that was passed in

    """
    if "TIME" in outcomes:
        time = outcomes["TIME"][0, :]
        del outcomes["TIME"]
    else:
        two_dimensional = (value for value in outcomes.values() if value.ndim == 2)
        first = next(two_dimensional, None)
        time = None if first is None else np.arange(0, first.shape[1])

    if time is None:
        _logger.info("no time dimension found in results")
    return time, outcomes
518

519

520
def group_results(
    experiments, outcomes, group_by, grouping_specifiers, grouping_labels
):
    """Split the experiments and outcomes into groups.

    Parameters
    ----------
    experiments : DataFrame
    outcomes : dict
    group_by : str
               The column in the experiments to which the grouping
               specifiers apply. If the name is 'index', it is assumed that
               the grouping specifiers are valid indices for numpy.ndarray.
    grouping_specifiers : iterable
                          One specifier per group. For a categorical or
                          integer column, a specifier is a value of that
                          column. For a continuous column it is a
                          (lower, upper) tuple; these are treated as
                          half-open intervals [lower, upper), apart from the
                          last one (by position), which is closed on both
                          sides. For 'index', each specifier is a valid index
                          for numpy.ndarray.
    grouping_labels : iterable of str
                      the key under which each group is stored, paired with
                      grouping_specifiers by position

    Returns
    -------
    dict
        maps each grouping label to a tuple of (experiments of the group,
        dict with the outcomes of the group)

    """
    # Materialize once, so that len() and the positional "last interval"
    # check work for any iterable and are not confused by repeated
    # specifiers.
    grouping_specifiers = list(grouping_specifiers)
    last_position = len(grouping_specifiers) - 1

    if group_by != "index":
        column_to_group_by = experiments.loc[:, group_by]

    groups = {}
    for position, (label, specifier) in enumerate(
        zip(grouping_labels, grouping_specifiers)
    ):
        if isinstance(specifier, tuple):
            # the grouping is a continuous uncertainty
            lower_limit, upper_limit = specifier
            above_lower = column_to_group_by >= lower_limit
            if position == last_position:
                # the last interval also includes its upper limit
                below_upper = column_to_group_by <= upper_limit
            else:
                below_upper = column_to_group_by < upper_limit
            logical = above_lower & below_upper
        elif group_by == "index":
            # the grouping is based on indices
            logical = specifier
        else:
            # the grouping is an integer or categorical uncertainty
            logical = column_to_group_by == specifier

        group_outcomes = {key: value[logical] for key, value in outcomes.items()}
        groups[label] = (experiments.loc[logical, :], group_outcomes)

    return groups
593

594

595
def make_continuous_grouping_specifiers(array, nr_of_groups=5):
    """Helper function for discretizing a continuous array.

    By default, the array is split into 5 equally wide intervals.

    Parameters
    ----------
    array : ndarray
            a 1-d array that is to be turned into discrete intervals.
    nr_of_groups : int, optional
                   the number of intervals, at least 1

    Returns
    -------
    list of tuples
        list of tuples with the lower and upper bound of the intervals.
        Consecutive intervals share their boundary; the first lower bound is
        exactly the minimum and the last upper bound exactly the maximum of
        the array.

    Raises
    ------
    ValueError
        if nr_of_groups is smaller than 1

    .. note:: this code only produces intervals. :func:`group_results` uses
              these intervals in half-open fashion, apart from the last
              interval: [a, b), [b,c), [c,d]. That is, both the end point
              and the start point of the range of the continuous array are
              included.

    """
    if nr_of_groups < 1:
        raise ValueError(f"nr_of_groups must be at least 1, got {nr_of_groups}")

    minimum = np.min(array)
    maximum = np.max(array)

    # linspace pins the first edge to minimum and the last edge to maximum,
    # so the closed last interval always contains the maximum, even when
    # minimum + step * nr_of_groups would differ from it by rounding error.
    edges = np.linspace(minimum, maximum, nr_of_groups + 1)
    return list(zip(edges[:-1], edges[1:]))
626

627

628
def prepare_pairs_data(
    experiments,
    outcomes,
    outcomes_to_show=None,
    group_by=None,
    grouping_specifiers=None,
    point_in_time=-1,
    filter_scalar=True,
):
    """Helper function to prepare the data for pairs plotting.

    Parameters
    ----------
    experiments : DataFrame
    outcomes : dict
    outcomes_to_show : list of str, optional
                       Both None and an empty list indicate that all outcomes
                       should be shown.
    group_by : str, optional
    grouping_specifiers : iterable, optional
    point_in_time : int or float, optional
                    -1 (the default) selects the last column of every 2-d
                    outcome; any other value is looked up in the time
                    dimension. None keeps the full time series.
    filter_scalar : bool, optional

    Returns
    -------
    tuple
        experiments, outcomes, outcomes_to_show, grouping_labels

    Raises
    ------
    TypeError
        if outcomes_to_show is neither None nor a list
    ValueError
        if outcomes_to_show contains a single outcome
    EMAError
        if point_in_time does not occur in the time dimension

    """
    if outcomes_to_show is not None:
        if not isinstance(outcomes_to_show, list):
            raise TypeError(
                f"For pair-wise plotting multiple outcomes need to be provided.\n"
                f"outcomes_to_show must be a list of strings or None, instead of a {type(outcomes_to_show)}"
            )
        elif len(outcomes_to_show) == 1:
            raise ValueError(
                f"Only {len(outcomes_to_show)} outcome provided, at least two are needed for pair-wise plotting."
            )

    experiments, outcomes, outcomes_to_show, time, grouping_labels = prepare_data(
        experiments,
        None,
        outcomes,
        outcomes_to_show,
        group_by,
        grouping_specifiers,
        filter_scalar,
    )

    def filter_outcomes(outcomes, point_in_time):
        """Take column ``point_in_time`` of every 2-d outcome; keep the rest."""
        new_outcomes = {}
        for key, value in outcomes.items():
            if len(value.shape) == 2:
                new_outcomes[key] = value[:, point_in_time]
            else:
                new_outcomes[key] = value
        return new_outcomes

    # explicit None check, so that a time value of 0 can be selected as well
    if point_in_time is not None:
        if point_in_time != -1:
            # translate the requested time value into a column index; a
            # plain integer is needed so the selection yields 1-d arrays
            matches = np.flatnonzero(np.asarray(time) == point_in_time)
            if matches.size == 0:
                raise EMAError(
                    f"point_in_time {point_in_time} not found in the time dimension"
                )
            point_in_time = int(matches[0])

        if group_by:
            # outcomes maps each group label to that group's outcomes
            outcomes = {
                key: filter_outcomes(value, point_in_time)
                for key, value in outcomes.items()
            }
        else:
            outcomes = filter_outcomes(outcomes, point_in_time)
    return experiments, outcomes, outcomes_to_show, grouping_labels
692

693

694
def prepare_data(
    experiments,
    experiments_to_show,
    outcomes,
    outcomes_to_show=None,
    group_by=None,
    grouping_specifiers=None,
    filter_scalar=True,
):
    """Helper function for preparing datasets prior to plotting.

    The experiments and outcomes are copied first, so the caller's objects
    are not modified.

    Parameters
    ----------
    experiments : DataFrame
    experiments_to_show : ndarray, optional
                          the rows of experiments and outcomes to keep
    outcomes : dict
    outcomes_to_show : str or list of str, optional
                       defaults to all remaining outcomes
    group_by : str, optional
               column of experiments to group on, or 'index'
    grouping_specifiers : iterable, optional
                          if not given, they are inferred from the group_by
                          column (its unique values for object/category
                          columns, equally wide intervals otherwise)
    filter_scalar : bool, optional
                    if True, drop outcomes that are not time series

    Returns
    -------
    tuple
        experiments, outcomes, outcomes_to_show, time, grouping_labels.
        When grouping, outcomes maps each grouping label to the outcomes of
        that group.

    Raises
    ------
    EMAError
        if grouping on 'index' without grouping specifiers

    """
    experiments = experiments.copy()
    outcomes = copy.deepcopy(outcomes)

    if experiments_to_show is not None:
        experiments = experiments.loc[experiments_to_show, :]

        for k, v in outcomes.items():
            outcomes[k] = v[experiments_to_show]

    time, outcomes = determine_time_dimension(outcomes)

    # remove outcomes that are not to be shown
    if outcomes_to_show:
        if isinstance(outcomes_to_show, str):
            outcomes_to_show = [outcomes_to_show]
        outcomes = {entry: outcomes[entry] for entry in outcomes_to_show}

    # filter the outcomes to exclude scalar values
    if filter_scalar:
        outcomes = filter_scalar_outcomes(outcomes)
    if not outcomes_to_show:
        outcomes_to_show = list(outcomes.keys())

    # group the data if desired
    if group_by:
        if not grouping_specifiers:
            # no grouping specifier, so infer from the data
            if group_by == "index":
                raise EMAError(
                    "No grouping specifiers provided while trying to group on index"
                )
            column_to_group_by = experiments[group_by]
            if column_to_group_by.dtype in (object, "category"):
                grouping_specifiers = set(column_to_group_by)
            else:
                # use the default number of intervals: grouping_specifiers
                # is empty here and must not be passed as nr_of_groups
                grouping_specifiers = make_continuous_grouping_specifiers(
                    column_to_group_by
                )
            grouping_labels = grouping_specifiers = sorted(grouping_specifiers)
        elif isinstance(grouping_specifiers, str):
            grouping_specifiers = [grouping_specifiers]
            grouping_labels = grouping_specifiers
        elif isinstance(grouping_specifiers, dict):
            grouping_labels = sorted(grouping_specifiers.keys())
            grouping_specifiers = [grouping_specifiers[key] for key in grouping_labels]
        else:
            grouping_labels = grouping_specifiers

        groups = group_results(
            experiments, outcomes, group_by, grouping_specifiers, grouping_labels
        )

        # keep only the outcomes of each group
        outcomes = {label: group[1] for label, group in groups.items()}
    else:
        grouping_labels = []

    return experiments, outcomes, outcomes_to_show, time, grouping_labels
784

785

786
def do_titles(ax, titles, outcome):
    """Set the title of ``ax`` for ``outcome``.

    Nothing happens unless ``titles`` is a dict. An empty dict uses the
    outcome name as the title. If a non-empty dict lacks ``outcome``, a
    warning is logged and the outcome name is used instead.

    Parameters
    ----------
    ax : axes instance
    titles : dict
             a dict which maps outcome names to titles
    outcome : str
              the outcome plotted in the ax.

    """
    if not isinstance(titles, dict):
        return

    title = outcome
    if titles:
        try:
            title = titles[outcome]
        except KeyError:
            _logger.warning(
                f"KeyError in do_titles, no title provided for outcome `{outcome}`"
            )
    ax.set_title(title)
809

810

811
def do_ylabels(ax, ylabels, outcome):
    """Set the y label of ``ax`` for ``outcome``.

    Nothing happens unless ``ylabels`` is a dict. An empty dict uses the
    outcome name as the label. If a non-empty dict lacks ``outcome``, a
    warning is logged and the outcome name is used instead.

    Parameters
    ----------
    ax : axes instance
    ylabels : dict
              a dict which maps outcome names to y labels
    outcome : str
              the outcome plotted in the ax.

    """
    if not isinstance(ylabels, dict):
        return

    ylabel = outcome
    if ylabels:
        try:
            ylabel = ylabels[outcome]
        except KeyError:
            _logger.warning(
                f"KeyError in do_ylabels, no ylabel provided for outcome `{outcome}`"
            )
    ax.set_ylabel(ylabel)
834

835

836
def make_grid(outcomes_to_show, density=False):
    """Create a new figure plus the grid that lays out its axes.

    The grid has one row per outcome. When ``density`` is truthy, it gets
    a second, narrower column (width ratio 4:1) for density plots.

    Parameters
    ----------
    outcomes_to_show : list of str
                       the list of outcomes to show
    density : bool, optional
              whether to reserve a column for density plots

    Returns
    -------
    Figure
        a newly created figure
    GridSpec
        the grid, with wspace 0.1 and hspace 0.4

    """
    n_rows = len(outcomes_to_show)
    grid = (
        gridspec.GridSpec(n_rows, 2, width_ratios=[4, 1])
        if density
        else gridspec.GridSpec(n_rows, 1)
    )
    grid.update(wspace=0.1, hspace=0.4)

    return plt.figure(), grid
855

856

857
def get_color(index):
    """Return the color at ``index`` in COLOR_LIST, wrapping around its end.

    This allows any number of items to be colored, even more than there
    are colors in the list.
    """
    return COLOR_LIST[index % len(COLOR_LIST)]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc