• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7348001748

28 Dec 2023 01:15PM UTC coverage: 99.104% (-0.09%) from 99.196%
7348001748

push

github

web-flow
Apply more ruff rules (#251)

38 of 38 new or added lines in 7 files covered. (100.0%)

3 existing lines in 2 files now uncovered.

1660 of 1675 relevant lines covered (99.1%)

0.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.49
/upsetplot/plotting.py
1
import typing
2✔
2
import warnings
2✔
3

4
import matplotlib
2✔
5
import numpy as np
2✔
6
import pandas as pd
2✔
7
from matplotlib import colors, patches
2✔
8
from matplotlib import pyplot as plt
2✔
9

10
from . import util
2✔
11
from .reformat import _get_subset_mask, query
2✔
12

13
# ``get_renderer`` lives in ``matplotlib.tight_layout``, whose import may raise
# ImportError on matplotlib versions > 3.5.2.  Record whether it succeeded so
# callers can fall back to the default renderer instead of failing at import.
try:
    from matplotlib.tight_layout import get_renderer
except ImportError:
    RENDERER_IMPORTED = False
else:
    RENDERER_IMPORTED = True
1✔
20

21

22
def _process_data(
    df,
    *,
    sort_by,
    sort_categories_by,
    subset_size,
    sum_over,
    min_subset_size=None,
    max_subset_size=None,
    min_degree=None,
    max_degree=None,
    reverse=False,
    include_empty_subsets=False,
):
    """Run :func:`query` and tag every data row with its subset's position.

    All keyword arguments except ``reverse`` are forwarded to :func:`query`.
    ``reverse`` flips both the order of the returned subset sizes and the
    ``_bin`` numbering, so that rows still point at the right subset.

    Returns
    -------
    total
        ``results.total`` as returned by :func:`query`.
    df
        ``results.data`` with an added ``_bin`` column holding, for each row,
        the position of that row's subset within ``agg``.
    agg
        ``results.subset_sizes``, reversed when ``reverse`` is true.
    category_totals
        ``results.category_totals`` as returned by :func:`query`.
    """
    results = query(
        df,
        sort_by=sort_by,
        sort_categories_by=sort_categories_by,
        subset_size=subset_size,
        sum_over=sum_over,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        min_degree=min_degree,
        max_degree=max_degree,
        include_empty_subsets=include_empty_subsets,
    )

    df = results.data
    agg = results.subset_sizes

    def _pack_binary(X):
        """Pack each row of binary membership flags into one integer key."""
        X = pd.DataFrame(X)
        # use objects if arbitrary precision integers are needed
        dtype = np.object_ if X.shape[1] > 62 else np.uint64
        out = pd.Series(0, index=X.index, dtype=dtype)
        for _, col in X.items():
            out *= 2
            # Cast before adding: uint64 + int64 promotes to float64 in
            # NumPy, which silently loses exactness beyond 53 bits and could
            # make distinct subsets collide.
            out += col.astype(dtype)
        return out

    # add '_bin' to df indicating index in agg
    df_packed = _pack_binary(df.index.to_frame())
    data_packed = _pack_binary(agg.index.to_frame())
    positions = np.arange(len(data_packed))
    if reverse:
        positions = positions[::-1]
    df["_bin"] = df_packed.map(pd.Series(positions, index=data_packed))
    if reverse:
        agg = agg[::-1]

    return results.total, df, agg, results.category_totals
2✔
75

76

77
def _multiply_alpha(c, mult):
    """Return color ``c`` as a hex string whose alpha is scaled by ``mult``.

    The alpha channel is kept in the output (``keep_alpha=True``).
    """
    *rgb, alpha = colors.to_rgba(c)
    return colors.to_hex((*rgb, alpha * mult), keep_alpha=True)
2✔
81

82

83
class _Transposed:
2✔
84
    """Wrap an object in order to transpose some plotting operations
85

86
    Attributes of obj will be mapped.
87
    Keyword arguments when calling obj will be mapped.
88

89
    The mapping is not recursive: callable attributes need to be _Transposed
90
    again.
91
    """
92

93
    def __init__(self, obj):
2✔
94
        self.__obj = obj
2✔
95

96
    def __getattr__(self, key):
2✔
97
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))
2✔
98

99
    def __call__(self, *args, **kwargs):
2✔
100
        return self.__obj(
2✔
101
            *args, **{self._NAME_TRANSPOSE.get(k, k): v for k, v in kwargs.items()}
2✔
102
        )
103

104
    _NAME_TRANSPOSE = {
2✔
105
        "width": "height",
2✔
106
        "height": "width",
2✔
107
        "hspace": "wspace",
2✔
108
        "wspace": "hspace",
2✔
109
        "hlines": "vlines",
2✔
110
        "vlines": "hlines",
2✔
111
        "bar": "barh",
2✔
112
        "barh": "bar",
2✔
113
        "xaxis": "yaxis",
2✔
114
        "yaxis": "xaxis",
2✔
115
        "left": "bottom",
2✔
116
        "right": "top",
2✔
117
        "top": "right",
2✔
118
        "bottom": "left",
2✔
119
        "sharex": "sharey",
2✔
120
        "sharey": "sharex",
2✔
121
        "get_figwidth": "get_figheight",
2✔
122
        "get_figheight": "get_figwidth",
2✔
123
        "set_figwidth": "set_figheight",
2✔
124
        "set_figheight": "set_figwidth",
2✔
125
        "set_xlabel": "set_ylabel",
2✔
126
        "set_ylabel": "set_xlabel",
2✔
127
        "set_xlim": "set_ylim",
2✔
128
        "set_ylim": "set_xlim",
2✔
129
        "get_xlim": "get_ylim",
2✔
130
        "get_ylim": "get_xlim",
2✔
131
        "set_autoscalex_on": "set_autoscaley_on",
2✔
132
        "set_autoscaley_on": "set_autoscalex_on",
2✔
133
    }
134

135

136
def _transpose(obj):
    """Transpose a name, or wrap any non-string object in ``_Transposed``.

    Strings are translated via ``_Transposed._NAME_TRANSPOSE`` and returned
    unchanged when they have no counterpart.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
2✔
140

141

142
def _identity(obj):
2✔
143
    return obj
2✔
144

145

146
class UpSet:
2✔
147
    """Manage the data and drawing for a basic UpSet plot
148

149
    Primary public method is :meth:`plot`.
150

151
    Parameters
152
    ----------
153
    data : pandas.Series or pandas.DataFrame
154
        Elements associated with categories (a DataFrame), or the size of each
155
        subset of categories (a Series).
156
        Should have MultiIndex where each level is binary,
157
        corresponding to category membership.
158
        If a DataFrame, `sum_over` must be a string or False.
159
    orientation : {'horizontal' (default), 'vertical'}
160
        If horizontal, intersections are listed from left to right.
161
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
162
               'input', '-input'}
163
        If 'cardinality', subsets are listed from largest to smallest.
164
        If 'degree', they are listed in order of the number of categories
165
        intersected. If 'input', the order they appear in the data input is
166
        used.
167
        Prefix with '-' to reverse the ordering.
168

169
        Note this affects ``subset_sizes`` but not ``data``.
170
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
171
        Whether to sort the categories by total cardinality, or leave them
172
        in the input data's provided order (order of index levels).
173
        Prefix with '-' to reverse the ordering.
174
    subset_size : {'auto', 'count', 'sum'}
175
        Configures how to calculate the size of a subset. Choices are:
176

177
        'auto' (default)
178
            If `data` is a DataFrame, count the number of rows in each group,
179
            unless `sum_over` is specified.
180
            If `data` is a Series with at most one row for each group, use
181
            the value of the Series. If `data` is a Series with more than one
182
            row per group, raise a ValueError.
183
        'count'
184
            Count the number of rows in each group.
185
        'sum'
186
            Sum the value of the `data` Series, or the DataFrame field
187
            specified by `sum_over`.
188
    sum_over : str or None
189
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
190
        sum of the specified field in the `data` DataFrame. If a Series, only
191
        None is supported and its value is summed.
192
    min_subset_size : int, optional
193
        Minimum size of a subset to be shown in the plot. All subsets with
194
        a size smaller than this threshold will be omitted from plotting.
195
        Size may be a sum of values, see `subset_size`.
196

197
        .. versionadded:: 0.5
198
    max_subset_size : int, optional
199
        Maximum size of a subset to be shown in the plot. All subsets with
200
        a size greater than this threshold will be omitted from plotting.
201

202
        .. versionadded:: 0.5
203
    min_degree : int, optional
204
        Minimum degree of a subset to be shown in the plot.
205

206
        .. versionadded:: 0.5
207
    max_degree : int, optional
208
        Maximum degree of a subset to be shown in the plot.
209

210
        .. versionadded:: 0.5
211
    facecolor : 'auto' or matplotlib color or float
212
        Color for bar charts and active dots. Defaults to black if
213
        axes.facecolor is a light color, otherwise white.
214

215
        .. versionchanged:: 0.6
216
            Before 0.6, the default was 'black'
217
    other_dots_color : matplotlib color or float
218
        Color for shading of inactive dots, or opacity (between 0 and 1)
219
        applied to facecolor.
220

221
        .. versionadded:: 0.6
222
    shading_color : matplotlib color or float
223
        Color for shading of odd rows in matrix and totals, or opacity (between
224
        0 and 1) applied to facecolor.
225

226
        .. versionadded:: 0.6
227
    with_lines : bool
228
        Whether to show lines joining dots in the matrix, to mark multiple
229
        categories being intersected.
230
    element_size : float or None
231
        Side length in pt. If None, size is estimated to fit the figure.
232
    intersection_plot_elements : int
233
        The intersections plot should be large enough to fit this many matrix
234
        elements. Set to 0 to disable intersection size bars.
235

236
        .. versionchanged:: 0.4
237
            Setting to 0 is handled.
238
    totals_plot_elements : int
239
        The totals plot should be large enough to fit this many matrix
240
        elements. Use totals_plot_elements=0 to disable the totals plot.
241
    show_counts : bool or str, default=False
242
        Whether to label the intersection size bars with the cardinality
243
        of the intersection. When a string, this formats the number.
244
        For example, '{:d}' is equivalent to True.
245
        Note that, for legacy reasons, if the string does not contain '{',
246
        it will be interpreted as a C-style format string, such as '%d'.
247
    show_percentages : bool or str, default=False
248
        Whether to label the intersection size bars with the percentage
249
        of the intersection relative to the total dataset.
250
        When a string, this formats the number representing a fraction of
251
        samples.
252
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
253
        This may be applied with or without show_counts.
254

255
        .. versionadded:: 0.4
256
    include_empty_subsets : bool (default=False)
257
        If True, all possible category combinations will be shown as subsets,
258
        even when some are not present in data.
259
    """
260

261
    _default_figsize = (10, 6)
2✔
262

263
    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        self._horizontal = orientation == "horizontal"
        # Vertical plots reuse the horizontal drawing code with names swapped.
        self._reorient = _identity if self._horizontal else _transpose

        # Only the literal string "auto" selects the automatic color; checking
        # the type first avoids an elementwise comparison for array colors.
        if isinstance(facecolor, str) and facecolor == "auto":
            bgcolor = matplotlib.rcParams.get("axes.facecolor", "white")
            r, g, b, a = colors.to_rgba(bgcolor)
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor

        def _as_color(spec):
            # A bare number is an opacity applied to facecolor (numbers are
            # never valid matplotlib colors); anything else is used as given.
            is_opacity = isinstance(
                spec, (int, float, np.integer, np.floating)
            ) and not isinstance(spec, bool)
            return _multiply_alpha(facecolor, spec) if is_opacity else spec

        self._shading_color = _as_color(shading_color)
        self._other_dots_color = _as_color(other_dots_color)
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        # Intersection-size bars are the default subset plot; 0 disables them.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append(
                {
                    "type": "default",
                    "id": "intersections",
                    "elements": intersection_plot_elements,
                }
            )
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections, self.totals) = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        # One independent style dict per intersection (see style_subsets).
        self.subset_styles = [
            {"facecolor": facecolor} for _ in range(len(self.intersections))
        ]
        self.subset_legend = []  # pairs of (style, label)
2✔
336

337
    def _swapaxes(self, x, y):
2✔
338
        if self._horizontal:
2✔
339
            return x, y
2✔
340
        return y, x
2✔
341

342
    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Updates the style of selected subsets' bars and matrix dots

        Some parameters choose which subsets are affected; the others are
        :class:`matplotlib.patches.Patch` attributes applied to them. ``label``
        additionally registers a legend entry for the resulting style.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend entry is added for this style. Labels that
            share an identical style are merged into one entry joined by "; ".
        """
        requested = (
            ("facecolor", facecolor),
            ("edgecolor", edgecolor),
            ("hatch", hatch),
            ("linewidth", linewidth),
            ("linestyle", linestyle),
        )
        # Only attributes that were explicitly given override existing styles.
        style = {}
        for attr, setting in requested:
            if setting is not None:
                style[attr] = setting

        selected = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for position in np.flatnonzero(selected):
            self.subset_styles[position].update(style)

        if label is None:
            return

        style.setdefault("facecolor", self._facecolor)
        for position, (known_style, known_label) in enumerate(self.subset_legend):
            if known_style != style:
                continue
            if known_label != label:
                self.subset_legend[position] = (style, known_label + "; " + label)
            break
        else:
            self.subset_legend.append((style, label))
2✔
424

425
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
2✔
426
        ax = self._reorient(ax)
2✔
427
        ax.set_autoscalex_on(False)
2✔
428
        data_df = pd.DataFrame(data)
2✔
429
        if self._horizontal:
2✔
430
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
2✔
431

432
        # TODO: colors should be broadcastable to data_df shape
433
        if callable(colors):
2✔
434
            colors = colors(range(data_df.shape[1]))
2✔
435
        elif isinstance(colors, (str, type(None))):
2✔
436
            colors = [colors] * len(data_df)
2✔
437

438
        if self._horizontal:
2✔
439
            colors = reversed(colors)
2✔
440

441
        x = np.arange(len(data_df))
2✔
442
        cum_y = None
2✔
443
        all_rects = []
2✔
444
        for (name, y), color in zip(data_df.items(), colors):
2✔
445
            rects = ax.bar(
2✔
446
                x,
2✔
447
                y,
2✔
448
                0.5,
2✔
449
                cum_y,
2✔
450
                color=color,
2✔
451
                zorder=10,
2✔
452
                label=name if use_labels else None,
2✔
453
                align="center",
2✔
454
            )
455
            cum_y = y if cum_y is None else cum_y + y
2✔
456
            all_rects.extend(rects)
2✔
457

458
        self._label_sizes(ax, rects, "top" if self._horizontal else "right")
2✔
459

460
        ax.xaxis.set_visible(False)
2✔
461
        for x in ["top", "bottom", "right"]:
2✔
462
            ax.spines[self._reorient(x)].set_visible(False)
2✔
463

464
        tick_axis = ax.yaxis
2✔
465
        tick_axis.grid(True)
2✔
466
        ax.set_ylabel(title)
2✔
467
        return all_rects
2✔
468

469
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
2✔
470
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
2✔
471
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
2✔
472
        if sum_over is None and "_value" in df.columns:
2✔
473
            data = gb["_value"].sum()
474
        elif sum_over is None:
2✔
475
            data = gb.size()
2✔
476
        else:
477
            data = gb[sum_over].sum()
2✔
478
        data = data.unstack(by).fillna(0)
2✔
479
        if isinstance(colors, str):
2✔
480
            colors = matplotlib.cm.get_cmap(colors)
2✔
481
        elif isinstance(colors, typing.Mapping):
2✔
482
            colors = data.columns.map(colors).values
2✔
483
            if pd.isna(colors).any():
2✔
484
                raise KeyError(
485
                    "Some labels mapped by colors: %r"
486
                    % data.columns[pd.isna(colors)].tolist()
487
                )
488

489
        self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True)
2✔
490

491
        handles, labels = ax.get_legend_handles_labels()
2✔
492
        if self._horizontal:
2✔
493
            # Make legend order match visual stack order
494
            ax.legend(reversed(handles), reversed(labels))
2✔
495
        else:
496
            ax.legend()
2✔
497

498
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None):
2✔
499
        """Add a stacked bar chart over subsets when :func:`plot` is called.
500

501
        Used to plot categorical variable distributions within each subset.
502

503
        .. versionadded:: 0.6
504

505
        Parameters
506
        ----------
507
        by : str
508
            Column name within the dataframe for color coding the stacked bars,
509
            containing discrete or categorical values.
510
        sum_over : str, optional
511
            Ordinarily the bars will chart the size of each group. sum_over
512
            may specify a column which will be summed to determine the size
513
            of each bar.
514
        colors : Mapping, list-like, str or callable, optional
515
            The facecolors to use for bars corresponding to each discrete
516
            label, specified as one of:
517

518
            Mapping
519
                Maps from label to matplotlib-compatible color specification.
520
            list-like
521
                A list of matplotlib colors to apply to labels in order.
522
            str
523
                The name of a matplotlib colormap name.
524
            callable
525
                When called with the number of labels, this should return a
526
                list-like of that many colors.  Matplotlib colormaps satisfy
527
                this callable API.
528
            None
529
                Uses the matplotlib default colormap.
530
        elements : int, default=3
531
            Size of the axes counted in number of matrix elements.
532
        title : str, optional
533
            The axis title labelling bar length.
534

535
        Returns
536
        -------
537
        None
538
        """
539
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
540
        #                        list of labels}
541
        self._subset_plots.append(
2✔
542
            {
2✔
543
                "type": "stacked_bars",
2✔
544
                "by": by,
2✔
545
                "sum_over": sum_over,
2✔
546
                "colors": colors,
2✔
547
                "title": title,
2✔
548
                "id": "extra%d" % len(self._subset_plots),
2✔
549
                "elements": elements,
2✔
550
            }
551
        )
552

553
    def add_catplot(self, kind, value=None, elements=3, **kw):
2✔
554
        """Add a seaborn catplot over subsets when :func:`plot` is called.
555

556
        Parameters
557
        ----------
558
        kind : str
559
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
560
        value : str, optional
561
            Column name for the value to plot (i.e. y if
562
            orientation='horizontal'), required if `data` is a DataFrame.
563
        elements : int, default=3
564
            Size of the axes counted in number of matrix elements.
565
        **kw : dict
566
            Additional keywords to pass to :func:`seaborn.catplot`.
567

568
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
569
            and 'orient', so these are prohibited keys in `kw`.
570

571
        Returns
572
        -------
573
        None
574
        """
575
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
576
        if value is None:
577
            if "_value" not in self._df.columns:
578
                raise ValueError(
579
                    "value cannot be set if data is a Series. " "Got %r" % value
580
                )
581
        else:
582
            if value not in self._df.columns:
583
                raise ValueError("value %r is not a column in data" % value)
584
        self._subset_plots.append(
585
            {
586
                "type": "catplot",
587
                "value": value,
588
                "kind": kind,
589
                "id": "extra%d" % len(self._subset_plots),
590
                "elements": elements,
591
                "kw": kw,
592
            }
593
        )
594

595
    def _check_value(self, value):
2✔
596
        if value is None and "_value" in self._df.columns:
597
            value = "_value"
598
        elif value is None:
599
            raise ValueError("value can only be None when data is a Series")
600
        return value
601

602
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw ``seaborn.<kind>plot`` of ``value`` per subset (``_bin``) on ax.

        The subset position goes on x and ``value`` on y for horizontal plots,
        and the reverse otherwise. Afterwards the subset axis and the
        top/bottom/right spines are hidden, and a value-axis grid is shown.
        """
        value = self._check_value(value)
        if self._horizontal:
            orient, x_col, y_col = "v", "_bin", value
        else:
            orient, x_col, y_col = "h", value, "_bin"
        plot_kw = {**kw, "orient": orient, "x": x_col, "y": y_col}
        import seaborn

        plot_kw["ax"] = ax
        plot_func = getattr(seaborn, kind + "plot")
        plot_func(data=self._df, **plot_kw)

        ax = self._reorient(ax)
        # The implicit Series column name is not meaningful as a label.
        if value == "_value":
            ax.set_ylabel("")

        ax.xaxis.set_visible(False)
        for side in ("top", "bottom", "right"):
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
629

630
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : matplotlib Figure, optional
            Figure to lay out; defaults to ``plt.gcf()``. When
            ``element_size`` is set, the figure's width and height are resized
            so that each matrix element gets that many points.

        Returns
        -------
        dict
            Maps "matrix", "shading", "totals" (None when the totals plot is
            disabled) and each subset plot's "id" to a SubplotSpec, and "gs"
            to the underlying GridSpec.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        # Look up the renderer once; both extent measurements below use it.
        window_extent_args = {}
        if RENDERER_IMPORTED:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                window_extent_args["renderer"] = get_renderer(fig)

        # Determine text size to determine figure size / spacing
        text_kw = {"size": matplotlib.rcParams["xtick.labelsize"]}
        # adding "x" ensures a margin
        t = fig.text(
            0,
            0,
            "\n".join(str(label) + "x" for label in self.totals.index.values),
            **text_kw,
        )
        textw = t.get_window_extent(**window_extent_args).width
        t.remove()

        figw = self._reorient(fig.get_window_extent(**window_extent_args)).width

        sizes = np.asarray([p["elements"] for p in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = n_inters + self._totals_plot_elements
        if self._element_size is None:
            colw = (figw - textw) / non_text_nelems
        else:
            # Display units per inch of figure; element_size is in points
            # (72 points per inch).
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(
            *self._swapaxes(
                # An empty ``sizes`` array is float, so its sum is 0.0;
                # ``or 0`` keeps the row count an integer.
                n_cats + (sizes.sum() or 0),
                n_inters + text_nelems + self._totals_plot_elements,
            ),
            hspace=1,
        )
        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": None
                if self._totals_plot_elements == 0
                else gridspec[-n_cats:, : self._totals_plot_elements],
                "gs": gridspec,
            }
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots[::-1]
            ):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": None
                if self._totals_plot_elements == 0
                else gridspec[: self._totals_plot_elements, :n_cats],
                "gs": gridspec,
            }
            cumsizes = np.cumsum(sizes)
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots
            ):
                out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats]
        return out
714

715
    def plot_matrix(self, ax):
716
        """Plot the matrix of intersection indicators onto ax"""
717
        ax = self._reorient(ax)
2✔
718
        data = self.intersections
2✔
719
        n_cats = data.index.nlevels
2✔
720

721
        inclusion = data.index.to_frame().values
2✔
722

723
        # Prepare styling
724
        styles = [
2✔
725
            [
2✔
726
                self.subset_styles[i]
2✔
727
                if inclusion[i, j]
2✔
728
                else {"facecolor": self._other_dots_color, "linewidth": 0}
2✔
729
                for j in range(n_cats)
2✔
730
            ]
731
            for i in range(len(data))
2✔
732
        ]
733
        styles = sum(styles, [])  # flatten nested list
2✔
734
        style_columns = {
2✔
735
            "facecolor": "facecolors",
2✔
736
            "edgecolor": "edgecolors",
2✔
737
            "linewidth": "linewidths",
2✔
738
            "linestyle": "linestyles",
2✔
739
            "hatch": "hatch",
2✔
740
        }
741
        styles = (
2✔
742
            pd.DataFrame(styles)
2✔
743
            .reindex(columns=style_columns.keys())
2✔
744
            .astype(
1✔
745
                {
2✔
746
                    "facecolor": "O",
2✔
747
                    "edgecolor": "O",
2✔
748
                    "linewidth": float,
2✔
749
                    "linestyle": "O",
2✔
750
                    "hatch": "O",
2✔
751
                }
752
            )
753
        )
754
        styles["linewidth"].fillna(1, inplace=True)
2✔
755
        styles["facecolor"].fillna(self._facecolor, inplace=True)
2✔
756
        styles["edgecolor"].fillna(styles["facecolor"], inplace=True)
2✔
757
        styles["linestyle"].fillna("solid", inplace=True)
2✔
758
        del styles["hatch"]  # not supported in matrix (currently)
2✔
759

760
        x = np.repeat(np.arange(len(data)), n_cats)
2✔
761
        y = np.tile(np.arange(n_cats), len(data))
2✔
762

763
        # Plot dots
764
        if self._element_size is not None:  # noqa
2✔
765
            s = (self._element_size * 0.35) ** 2
2✔
766
        else:
767
            # TODO: make s relative to colw
768
            s = 200
2✔
769
        ax.scatter(
2✔
770
            *self._swapaxes(x, y),
2✔
771
            s=s,
2✔
772
            zorder=10,
2✔
773
            **styles.rename(columns=style_columns),
2✔
774
        )
775

776
        # Plot lines
777
        if self._with_lines:
2✔
778
            idx = np.flatnonzero(inclusion)
2✔
779
            line_data = (
2✔
780
                pd.Series(y[idx], index=x[idx])
2✔
781
                .groupby(level=0)
2✔
782
                .aggregate(["min", "max"])
2✔
783
            )
784
            colors = pd.Series(
2✔
785
                [
2✔
786
                    style.get("edgecolor", style.get("facecolor", self._facecolor))
2✔
787
                    for style in self.subset_styles
2✔
788
                ],
789
                name="color",
2✔
790
            )
791
            line_data = line_data.join(colors)
2✔
792
            ax.vlines(
2✔
793
                line_data.index.values,
2✔
794
                line_data["min"],
2✔
795
                line_data["max"],
2✔
796
                lw=2,
2✔
797
                colors=line_data["color"],
2✔
798
                zorder=5,
2✔
799
            )
800

801
        # Ticks and axes
802
        tick_axis = ax.yaxis
2✔
803
        tick_axis.set_ticks(np.arange(n_cats))
2✔
804
        tick_axis.set_ticklabels(
2✔
805
            data.index.names, rotation=0 if self._horizontal else -90
2✔
806
        )
807
        ax.xaxis.set_ticks([])
2✔
808
        ax.tick_params(axis="both", which="both", length=0)
2✔
809
        if not self._horizontal:
2✔
810
            ax.yaxis.set_ticks_position("top")
2✔
811
        ax.set_frame_on(False)
2✔
812
        ax.set_xlim(-0.5, x[-1] + 0.5, auto=False)
2✔
813
        ax.grid(False)
2✔
814

815
    def plot_intersections(self, ax):
        """Plot bars indicating intersection size.

        Bars are drawn by ``self._plot_bars`` and then restyled one by one,
        pairing each bar (in order) with an entry of ``self.subset_styles``.
        A style without an ``edgecolor`` falls back to its own ``facecolor``,
        or to ``self._facecolor`` when that is also absent.  If
        ``self.subset_legend`` is non-empty, a legend with one patch per
        ``(style, label)`` entry is added to ``ax``.
        """
        bars = self._plot_bars(
            ax, self.intersections, title="Intersection size", colors=self._facecolor
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            # Work on a copy so entries of self.subset_styles are not modified.
            resolved = subset_style.copy()
            fallback_edge = resolved.get("facecolor", self._facecolor)
            resolved.setdefault("edgecolor", fallback_edge)
            for prop_name, prop_value in resolved.items():
                setter = getattr(bar, f"set_{prop_name}")
                setter(prop_value)

        if not self.subset_legend:
            return
        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**legend_style) for legend_style in legend_styles]
        ax.legend(handles, legend_labels)

831
    def _label_sizes(self, ax, rects, where):
832
        if not self._show_counts and not self._show_percentages:
833
            return
834
        if self._show_counts is True:
835
            count_fmt = "{:.0f}"
836
        else:
837
            count_fmt = self._show_counts
838
            if "{" not in count_fmt:
839
                count_fmt = util.to_new_pos_format(count_fmt)
840

841
        pct_fmt = "{:.1%}" if self._show_percentages is True else self._show_percentages
842

843
        if count_fmt and pct_fmt:
844
            if where == "top":
845
                fmt = f"{count_fmt}\n({pct_fmt})"
846
            else:
847
                fmt = f"{count_fmt} ({pct_fmt})"
848

849
            def make_args(val):
850
                return val, val / self.total
851
        elif count_fmt:
852
            fmt = count_fmt
853

854
            def make_args(val):
855
                return (val,)
856
        else:
857
            fmt = pct_fmt
858

859
            def make_args(val):
860
                return (val / self.total,)
861

862
        if where == "right":
863
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
864
            for rect in rects:
865
                width = rect.get_width() + rect.get_x()
866
                ax.text(
867
                    width + margin,
868
                    rect.get_y() + rect.get_height() * 0.5,
869
                    fmt.format(*make_args(width)),
870
                    ha="left",
871
                    va="center",
872
                )
873
        elif where == "left":
874
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
875
            for rect in rects:
876
                width = rect.get_width() + rect.get_x()
877
                ax.text(
878
                    width + margin,
879
                    rect.get_y() + rect.get_height() * 0.5,
880
                    fmt.format(*make_args(width)),
881
                    ha="right",
882
                    va="center",
883
                )
884
        elif where == "top":
885
            margin = 0.01 * abs(np.diff(ax.get_ylim()))
886
            for rect in rects:
887
                height = rect.get_height() + rect.get_y()
888
                ax.text(
889
                    rect.get_x() + rect.get_width() * 0.5,
890
                    height + margin,
891
                    fmt.format(*make_args(height)),
892
                    ha="center",
893
                    va="bottom",
894
                )
895
        else:
896
            raise NotImplementedError("unhandled where: %r" % where)
897

898
    def plot_totals(self, ax):
        """Plot bars indicating total set size.

        One bar of thickness 0.5 is drawn per entry of ``self.totals`` on the
        reoriented axes (``self._reorient(ax)``) and labelled via
        ``self._label_sizes``.  In horizontal mode the x limits of the
        original ``ax`` are set to ``(max total, 0)``, i.e. the axis runs
        right-to-left.  The top/left/right spines (after reorientation),
        y tick labels and ticks, y grid and axes background are hidden; the
        x grid is shown.
        """
        outer_ax = ax
        ax = self._reorient(outer_ax)
        positions = np.arange(len(self.totals.index.values))
        bars = ax.barh(
            positions,
            self.totals,
            0.5,
            color=self._facecolor,
            align="center",
        )
        label_side = "left" if self._horizontal else "top"
        self._label_sizes(ax, bars, label_side)

        largest_total = self.totals.max()
        if self._horizontal:
            # Inverted limits: bars grow from the right edge towards the left.
            outer_ax.set_xlim(largest_total, 0)

        for side in ("top", "left", "right"):
            ax.spines[self._reorient(side)].set_visible(False)
        ax.yaxis.set_visible(True)
        ax.yaxis.set_ticklabels([])
        ax.yaxis.set_ticks([])
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)

923
    def plot_shading(self, ax):
        """Shade every other category row and strip all axes decoration.

        For rows 0, 2, 4, ... (up to ``len(self.totals)``) a rectangle of
        extent 1 x 0.8, centred on the row index, is added to ``ax`` with
        ``self._shading_color`` and zorder 0; ``self._swapaxes`` handles the
        plot orientation.  Frame, ticks, tick labels and grid are then
        switched off.
        """
        row_count = len(self.totals)
        for row in range(0, row_count, 2):
            band = plt.Rectangle(
                self._swapaxes(0, row - 0.4),
                *self._swapaxes(1, 0.8),
                facecolor=self._shading_color,
                lw=0,
                zorder=0,
            )
            ax.add_patch(band)

        ax.set_frame_on(False)
        ax.tick_params(
            axis="both",
            which="both",
            left=False,
            right=False,
            bottom=False,
            top=False,
            labelbottom=False,
            labelleft=False,
        )
        ax.grid(False)
        for clear in (ax.set_xticks, ax.set_yticks, ax.set_xticklabels, ax.set_yticklabels):
            clear([])

951
    def plot(self, fig=None):
        """Draw all parts of the plot onto fig or a new figure

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw on.  Defaults to a new figure of size
            ``self._default_figsize``.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Always has keys 'matrix', 'shading' and 'totals' ('totals' is None
            when the grid has no totals cell), plus one key per subset plot
            id (e.g. 'intersections').

        Raises
        ------
        ValueError
            If a subset plot has a type that is neither "default" nor a key
            of ``PLOT_TYPES``.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)

        shading_ax = fig.add_subplot(specs["shading"])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(specs["matrix"], sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_spec = specs["totals"]
        totals_ax = None
        if totals_spec is not None:
            totals_ax = self._reorient(fig.add_subplot)(totals_spec, sharey=matrix_ax)
            self.plot_totals(totals_ax)

        subplots = {"matrix": matrix_ax, "shading": shading_ax, "totals": totals_ax}

        for subset_plot in self._subset_plots:
            ax = self._reorient(fig.add_subplot)(
                specs[subset_plot["id"]], sharex=matrix_ax
            )
            plot_type = subset_plot["type"]
            if plot_type == "default":
                self.plot_intersections(ax)
            elif plot_type in self.PLOT_TYPES:
                # Everything except the bookkeeping keys is forwarded as kwargs.
                extra_kwargs = subset_plot.copy()
                for bookkeeping_key in ("type", "elements", "id"):
                    del extra_kwargs[bookkeeping_key]
                self.PLOT_TYPES[plot_type](self, ax, **extra_kwargs)
            else:
                raise ValueError("Unknown subset plot type: %r" % plot_type)
            subplots[subset_plot["id"]] = ax
        return subplots

995
    # Registry of subset-plot type names -> drawing functions.  ``plot()``
    # invokes the selected entry as ``func(self, ax, **extra_kwargs)``.
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )

1000
    def _repr_html_(self):
        """IPython rich-display hook: draw into a new figure, return its HTML.

        The figure is created with ``plt.figure`` at ``self._default_figsize``
        and is not closed here; the return value is whatever the figure's own
        ``_repr_html_`` produces.
        """
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html

1005

1006
def plot(data, fig=None, **kwargs):
    """Make an UpSet plot of data on fig

    Convenience wrapper equivalent to ``UpSet(data, **kwargs).plot(fig)``.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        As returned by :meth:`UpSet.plot`: keys 'matrix', 'shading',
        'totals', plus one key per subset plot (e.g. 'intersections').
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc