• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7344254321

28 Dec 2023 03:58AM UTC coverage: 98.586% (+15.0%) from 83.549%
7344254321

push

github

web-flow
Format with black/ruff (#240)

844 of 848 new or added lines in 8 files covered. (99.53%)

4 existing lines in 3 files now uncovered.

1534 of 1556 relevant lines covered (98.59%)

0.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.53
/upsetplot/plotting.py
1
from __future__ import print_function, division, absolute_import
2✔
2

3
try:
2✔
4
    import typing
2✔
5
except ImportError:
×
6
    import collections as typing
×
7

8
import numpy as np
2✔
9
import pandas as pd
2✔
10
import matplotlib
2✔
11
from matplotlib import pyplot as plt
2✔
12
from matplotlib import colors
2✔
13
from matplotlib import patches
2✔
14

15
from .reformat import query, _get_subset_mask
2✔
16
from . import util
2✔
17

18
# prevents ImportError on matplotlib versions >3.5.2
19
try:
2✔
20
    from matplotlib.tight_layout import get_renderer
2✔
21

22
    RENDERER_IMPORTED = True
1✔
UNCOV
23
except ImportError:
1✔
UNCOV
24
    RENDERER_IMPORTED = False
1✔
25

26

27
def _process_data(
    df,
    *,
    sort_by,
    sort_categories_by,
    subset_size,
    sum_over,
    min_subset_size=None,
    max_subset_size=None,
    min_degree=None,
    max_degree=None,
    reverse=False,
    include_empty_subsets=False,
):
    """Run ``query`` on ``df`` and tag every row with its subset position.

    All keyword arguments except ``reverse`` are forwarded to ``query``.

    Parameters
    ----------
    reverse : bool
        If True, subset positions are numbered from the end and the returned
        subset sizes are reversed to match.

    Returns
    -------
    total
        Sum of all subset sizes.
    df
        ``results.data`` with an added ``"_bin"`` column giving, for each
        row, the position of its subset in the returned subset sizes.
    agg
        Subset sizes (reversed when ``reverse`` is True).
    totals
        ``results.category_totals``.
    """
    results = query(
        df,
        sort_by=sort_by,
        sort_categories_by=sort_categories_by,
        subset_size=subset_size,
        sum_over=sum_over,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        min_degree=min_degree,
        max_degree=max_degree,
        include_empty_subsets=include_empty_subsets,
    )

    data = results.data
    subset_sizes = results.subset_sizes
    category_totals = results.category_totals
    total = subset_sizes.sum()

    # add '_bin' to data indicating index in subset_sizes
    # XXX: ugly!
    def _pack_binary(membership):
        """Fold each row of indicator columns into a single integer code."""
        frame = pd.DataFrame(membership)
        # use objects if arbitrary precision integers are needed
        if frame.shape[1] > 62:
            dtype = np.object_
        else:
            dtype = np.uint64
        packed = pd.Series(0, index=frame.index, dtype=dtype)
        for _, column in frame.items():
            packed *= 2
            packed += column
        return packed

    row_codes = _pack_binary(data.index.to_frame())
    subset_codes = _pack_binary(subset_sizes.index.to_frame())

    positions = np.arange(len(subset_codes))
    if reverse:
        positions = positions[::-1]
    code_to_position = pd.Series(positions, index=subset_codes)
    data["_bin"] = pd.Series(row_codes).map(code_to_position)

    if reverse:
        subset_sizes = subset_sizes[::-1]

    return total, data, subset_sizes, category_totals
2✔
82

83

84
def _multiply_alpha(c, mult):
    """Return colour ``c`` as a hex string (with alpha), alpha scaled by ``mult``."""
    red, green, blue, alpha = colors.to_rgba(c)
    scaled = (red, green, blue, alpha * mult)
    return colors.to_hex(scaled, keep_alpha=True)
2✔
88

89

90
class _Transposed:
    """Wrap an object in order to transpose some plotting operations

    Attributes of obj will be mapped.
    Keyword arguments when calling obj will be mapped.

    The mapping is not recursive: callable attributes need to be _Transposed
    again.
    """

    # Each name maps to its counterpart on the other axis; names that are not
    # listed are passed through unchanged.
    _NAME_TRANSPOSE = {
        "width": "height",
        "height": "width",
        "hspace": "wspace",
        "wspace": "hspace",
        "hlines": "vlines",
        "vlines": "hlines",
        "bar": "barh",
        "barh": "bar",
        "xaxis": "yaxis",
        "yaxis": "xaxis",
        "left": "bottom",
        "right": "top",
        "top": "right",
        "bottom": "left",
        "sharex": "sharey",
        "sharey": "sharex",
        "get_figwidth": "get_figheight",
        "get_figheight": "get_figwidth",
        "set_figwidth": "set_figheight",
        "set_figheight": "set_figwidth",
        "set_xlabel": "set_ylabel",
        "set_ylabel": "set_xlabel",
        "set_xlim": "set_ylim",
        "set_ylim": "set_xlim",
        "get_xlim": "get_ylim",
        "get_ylim": "get_xlim",
        "set_autoscalex_on": "set_autoscaley_on",
        "set_autoscaley_on": "set_autoscalex_on",
    }

    def __init__(self, obj):
        self.__obj = obj

    def __getattr__(self, key):
        """Look up the transposed name of ``key`` on the wrapped object."""
        # __getattr__ only runs when normal lookup fails.  If the wrapped
        # object itself is missing (instance created without __init__, e.g.
        # while being copied or unpickled), delegating would re-enter this
        # method forever; report a plain AttributeError instead.
        if key == "_Transposed__obj":
            raise AttributeError(key)
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))

    def __call__(self, *args, **kwargs):
        """Call the wrapped object with transposed keyword-argument names."""
        return self.__obj(
            *args, **{self._NAME_TRANSPOSE.get(k, k): v for k, v in kwargs.items()}
        )
141

142

143
def _transpose(obj):
    """Transpose a name, or wrap any other object for transposed access.

    A string is translated through ``_Transposed._NAME_TRANSPOSE`` (and
    returned unchanged when it has no counterpart); anything else is wrapped
    in ``_Transposed``.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    name_map = _Transposed._NAME_TRANSPOSE
    return name_map.get(obj, obj)
2✔
147

148

149
def _identity(obj):
    """Return ``obj`` unchanged."""
    return obj
2✔
151

152

153
class UpSet:
2✔
154
    """Manage the data and drawing for a basic UpSet plot
155

156
    Primary public method is :meth:`plot`.
157

158
    Parameters
159
    ----------
160
    data : pandas.Series or pandas.DataFrame
161
        Elements associated with categories (a DataFrame), or the size of each
162
        subset of categories (a Series).
163
        Should have MultiIndex where each level is binary,
164
        corresponding to category membership.
165
        If a DataFrame, `sum_over` must be a string or False.
166
    orientation : {'horizontal' (default), 'vertical'}
167
        If horizontal, intersections are listed from left to right.
168
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
169
               'input', '-input'}
170
        If 'cardinality', subset are listed from largest to smallest.
171
        If 'degree', they are listed in order of the number of categories
172
        intersected. If 'input', the order they appear in the data input is
173
        used.
174
        Prefix with '-' to reverse the ordering.
175

176
        Note this affects ``subset_sizes`` but not ``data``.
177
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
178
        Whether to sort the categories by total cardinality, or leave them
179
        in the input data's provided order (order of index levels).
180
        Prefix with '-' to reverse the ordering.
181
    subset_size : {'auto', 'count', 'sum'}
182
        Configures how to calculate the size of a subset. Choices are:
183

184
        'auto' (default)
185
            If `data` is a DataFrame, count the number of rows in each group,
186
            unless `sum_over` is specified.
187
            If `data` is a Series with at most one row for each group, use
188
            the value of the Series. If `data` is a Series with more than one
189
            row per group, raise a ValueError.
190
        'count'
191
            Count the number of rows in each group.
192
        'sum'
193
            Sum the value of the `data` Series, or the DataFrame field
194
            specified by `sum_over`.
195
    sum_over : str or None
196
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
197
        sum of the specified field in the `data` DataFrame. If a Series, only
198
        None is supported and its value is summed.
199
    min_subset_size : int, optional
200
        Minimum size of a subset to be shown in the plot. All subsets with
201
        a size smaller than this threshold will be omitted from plotting.
202
        Size may be a sum of values, see `subset_size`.
203

204
        .. versionadded:: 0.5
205
    max_subset_size : int, optional
206
        Maximum size of a subset to be shown in the plot. All subsets with
207
        a size greater than this threshold will be omitted from plotting.
208

209
        .. versionadded:: 0.5
210
    min_degree : int, optional
211
        Minimum degree of a subset to be shown in the plot.
212

213
        .. versionadded:: 0.5
214
    max_degree : int, optional
215
        Maximum degree of a subset to be shown in the plot.
216

217
        .. versionadded:: 0.5
218
    facecolor : 'auto' or matplotlib color or float
219
        Color for bar charts and active dots. Defaults to black if
220
        axes.facecolor is a light color, otherwise white.
221

222
        .. versionchanged:: 0.6
223
            Before 0.6, the default was 'black'
224
    other_dots_color : matplotlib color or float
225
        Color for shading of inactive dots, or opacity (between 0 and 1)
226
        applied to facecolor.
227

228
        .. versionadded:: 0.6
229
    shading_color : matplotlib color or float
230
        Color for shading of odd rows in matrix and totals, or opacity (between
231
        0 and 1) applied to facecolor.
232

233
        .. versionadded:: 0.6
234
    with_lines : bool
235
        Whether to show lines joining dots in the matrix, to mark multiple
236
        categories being intersected.
237
    element_size : float or None
238
        Side length in pt. If None, size is estimated to fit figure
239
    intersection_plot_elements : int
240
        The intersections plot should be large enough to fit this many matrix
241
        elements. Set to 0 to disable intersection size bars.
242

243
        .. versionchanged:: 0.4
244
            Setting to 0 is handled.
245
    totals_plot_elements : int
246
        The totals plot should be large enough to fit this many matrix
247
        elements.
248
    show_counts : bool or str, default=False
249
        Whether to label the intersection size bars with the cardinality
250
        of the intersection. When a string, this formats the number.
251
        For example, '{:d}' is equivalent to True.
252
        Note that, for legacy reasons, if the string does not contain '{',
253
        it will be interpreted as a C-style format string, such as '%d'.
254
    show_percentages : bool or str, default=False
255
        Whether to label the intersection size bars with the percentage
256
        of the intersection relative to the total dataset.
257
        When a string, this formats the number representing a fraction of
258
        samples.
259
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
260
        This may be applied with or without show_counts.
261

262
        .. versionadded:: 0.4
263
    include_empty_subsets : bool (default=False)
264
        If True, all possible category combinations will be shown as subsets,
265
        even when some are not present in data.
266
    """
267

268
    _default_figsize = (10, 6)
2✔
269

270
    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        """Resolve colours, store plot options and process ``data``.

        Parameters are described in the class docstring.
        """

        def _is_opacity(value):
            # Any real number is an opacity to apply to facecolor rather than
            # a colour: this covers ints such as 0 or 1 and NumPy scalars as
            # well as Python floats.  bool is excluded although it is an int.
            numeric = (int, float, np.integer, np.floating)
            return isinstance(value, numeric) and not isinstance(value, bool)

        self._horizontal = orientation == "horizontal"
        self._reorient = _identity if self._horizontal else _transpose

        # Only a string can be the "auto" sentinel; checking the type first
        # avoids an elementwise comparison when facecolor is an array.
        if isinstance(facecolor, str) and facecolor == "auto":
            bgcolor = matplotlib.rcParams.get("axes.facecolor", "white")
            r, g, b, a = colors.to_rgba(bgcolor)
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor

        if _is_opacity(shading_color):
            shading_color = _multiply_alpha(facecolor, shading_color)
        self._shading_color = shading_color

        if _is_opacity(other_dots_color):
            other_dots_color = _multiply_alpha(facecolor, other_dots_color)
        self._other_dots_color = other_dots_color

        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements

        # The intersection-size bars are the default subset plot; a falsy
        # element count disables them.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append(
                {
                    "type": "default",
                    "id": "intersections",
                    "elements": intersection_plot_elements,
                }
            )
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections, self.totals) = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        # One independent style dict per subset so each can be restyled.
        self.subset_styles = [
            {"facecolor": facecolor} for _ in range(len(self.intersections))
        ]
        self.subset_legend = []  # pairs of (style, label)
2✔
343

344
    def _swapaxes(self, x, y):
2✔
345
        if self._horizontal:
2✔
346
            return x, y
2✔
347
        return y, x
2✔
348

349
    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Restyle the bars and matrix dots of the selected subsets.

        The first group of parameters selects subsets; the second group are
        :class:`matplotlib.patches.Patch` attributes applied to them.
        ``label`` additionally registers a legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added
        """
        requested = (
            ("facecolor", facecolor),
            ("edgecolor", edgecolor),
            ("hatch", hatch),
            ("linewidth", linewidth),
            ("linestyle", linestyle),
        )
        # Only attributes that were actually given override existing styles.
        style = {key: val for key, val in requested if val is not None}

        selected = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for position in np.flatnonzero(selected):
            self.subset_styles[position].update(style)

        if label is None:
            return

        # Legend entries always carry a facecolor.
        style.setdefault("facecolor", self._facecolor)
        for position, (known_style, known_label) in enumerate(self.subset_legend):
            if known_style != style:
                continue
            # Same style already in the legend: merge labels unless identical.
            if known_label != label:
                self.subset_legend[position] = (style, known_label + "; " + label)
            return
        self.subset_legend.append((style, label))
2✔
431

432
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
2✔
433
        ax = self._reorient(ax)
2✔
434
        ax.set_autoscalex_on(False)
2✔
435
        data_df = pd.DataFrame(data)
2✔
436
        if self._horizontal:
2✔
437
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
2✔
438

439
        # TODO: colors should be broadcastable to data_df shape
440
        if callable(colors):
2✔
441
            colors = colors(range(data_df.shape[1]))
2✔
442
        elif isinstance(colors, (str, type(None))):
2✔
443
            colors = [colors] * len(data_df)
2✔
444

445
        if self._horizontal:
2✔
446
            colors = reversed(colors)
2✔
447

448
        x = np.arange(len(data_df))
2✔
449
        cum_y = None
2✔
450
        all_rects = []
2✔
451
        for (name, y), color in zip(data_df.items(), colors):
2✔
452
            rects = ax.bar(
2✔
453
                x,
2✔
454
                y,
2✔
455
                0.5,
2✔
456
                cum_y,
2✔
457
                color=color,
2✔
458
                zorder=10,
2✔
459
                label=name if use_labels else None,
2✔
460
                align="center",
2✔
461
            )
462
            cum_y = y if cum_y is None else cum_y + y
2✔
463
            all_rects.extend(rects)
2✔
464

465
        self._label_sizes(ax, rects, "top" if self._horizontal else "right")
2✔
466

467
        ax.xaxis.set_visible(False)
2✔
468
        for x in ["top", "bottom", "right"]:
2✔
469
            ax.spines[self._reorient(x)].set_visible(False)
2✔
470

471
        tick_axis = ax.yaxis
2✔
472
        tick_axis.grid(True)
2✔
473
        ax.set_ylabel(title)
2✔
474
        return all_rects
2✔
475

476
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
        """Plot, per subset, bars stacked by the values of column ``by``.

        Bar lengths are the ``"_value"`` sum when that column exists and
        ``sum_over`` is None, the group size when ``sum_over`` is None
        otherwise, and the sum of column ``sum_over`` when it is given.

        Raises
        ------
        KeyError
            If ``colors`` is a mapping that lacks a colour for some label.
        """
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
        if sum_over is None and "_value" in df.columns:
            data = gb["_value"].sum()
        elif sum_over is None:
            data = gb.size()
        else:
            data = gb[sum_over].sum()
        # One row per subset, one column per label of `by`.
        data = data.unstack(by).fillna(0)
        if isinstance(colors, str):
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
            # pyplot.get_cmap resolves colormap names in old and new releases.
            colors = plt.get_cmap(colors)
        elif isinstance(colors, typing.Mapping):
            colors = data.columns.map(colors).values
            if pd.isna(colors).any():
                raise KeyError(
                    "Some labels are not mapped by colors: %r"
                    % data.columns[pd.isna(colors)].tolist()
                )

        self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True)

        handles, labels = ax.get_legend_handles_labels()
        if self._horizontal:
            # Make legend order match visual stack order
            ax.legend(reversed(handles), reversed(labels))
        else:
            ax.legend()
2✔
504

505
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None):
2✔
506
        """Add a stacked bar chart over subsets when :func:`plot` is called.
507

508
        Used to plot categorical variable distributions within each subset.
509

510
        .. versionadded:: 0.6
511

512
        Parameters
513
        ----------
514
        by : str
515
            Column name within the dataframe for color coding the stacked bars,
516
            containing discrete or categorical values.
517
        sum_over : str, optional
518
            Ordinarily the bars will chart the size of each group. sum_over
519
            may specify a column which will be summed to determine the size
520
            of each bar.
521
        colors : Mapping, list-like, str or callable, optional
522
            The facecolors to use for bars corresponding to each discrete
523
            label, specified as one of:
524

525
            Mapping
526
                Maps from label to matplotlib-compatible color specification.
527
            list-like
528
                A list of matplotlib colors to apply to labels in order.
529
            str
530
                The name of a matplotlib colormap name.
531
            callable
532
                When called with the number of labels, this should return a
533
                list-like of that many colors.  Matplotlib colormaps satisfy
534
                this callable API.
535
            None
536
                Uses the matplotlib default colormap.
537
        elements : int, default=3
538
            Size of the axes counted in number of matrix elements.
539
        title : str, optional
540
            The axis title labelling bar length.
541

542
        Returns
543
        -------
544
        None
545
        """
546
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
547
        #                        list of labels}
548
        self._subset_plots.append(
2✔
549
            {
1✔
550
                "type": "stacked_bars",
2✔
551
                "by": by,
2✔
552
                "sum_over": sum_over,
2✔
553
                "colors": colors,
2✔
554
                "title": title,
2✔
555
                "id": "extra%d" % len(self._subset_plots),
2✔
556
                "elements": elements,
2✔
557
            }
558
        )
559

560
    def add_catplot(self, kind, value=None, elements=3, **kw):
2✔
561
        """Add a seaborn catplot over subsets when :func:`plot` is called.
562

563
        Parameters
564
        ----------
565
        kind : str
566
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
567
        value : str, optional
568
            Column name for the value to plot (i.e. y if
569
            orientation='horizontal'), required if `data` is a DataFrame.
570
        elements : int, default=3
571
            Size of the axes counted in number of matrix elements.
572
        **kw : dict
573
            Additional keywords to pass to :func:`seaborn.catplot`.
574

575
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
576
            and 'orient', so these are prohibited keys in `kw`.
577

578
        Returns
579
        -------
580
        None
581
        """
582
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
583
        if value is None:
584
            if "_value" not in self._df.columns:
585
                raise ValueError(
586
                    "value cannot be set if data is a Series. " "Got %r" % value
587
                )
588
        else:
589
            if value not in self._df.columns:
590
                raise ValueError("value %r is not a column in data" % value)
591
        self._subset_plots.append(
592
            {
593
                "type": "catplot",
594
                "value": value,
595
                "kind": kind,
596
                "id": "extra%d" % len(self._subset_plots),
597
                "elements": elements,
598
                "kw": kw,
599
            }
600
        )
601

602
    def _check_value(self, value):
2✔
603
        if value is None and "_value" in self._df.columns:
604
            value = "_value"
605
        elif value is None:
606
            raise ValueError("value can only be None when data is a Series")
607
        return value
608

609
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of ``value`` per subset onto ``ax``.

        ``kw`` is not modified; the axis-related keywords are added to a copy.
        """
        frame = self._df
        value = self._check_value(value)

        # Subsets ("_bin") run along x for horizontal plots, along y otherwise.
        if self._horizontal:
            layout = {"orient": "v", "x": "_bin", "y": value}
        else:
            layout = {"orient": "h", "x": value, "y": "_bin"}
        plot_kw = kw.copy()
        plot_kw.update(layout)

        import seaborn

        plot_kw["ax"] = ax
        plot_func = getattr(seaborn, kind + "plot")
        plot_func(data=frame, **plot_kw)

        ax = self._reorient(ax)
        if value == "_value":
            ax.set_ylabel("")

        ax.xaxis.set_visible(False)
        for side in ("top", "bottom", "right"):
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
636

637
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : Figure, optional
            Figure to lay out; defaults to ``plt.gcf()``. When
            ``element_size`` is not None the figure is resized in place.

        Returns
        -------
        dict
            Maps "matrix", "shading" and "totals" to slices of the grid,
            "gs" to the GridSpec itself, and the id of every registered
            subset plot to its slice.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        # Determine text size to determine figure size / spacing
        text_kw = {"size": matplotlib.rcParams["xtick.labelsize"]}
        # adding "x" ensures a margin
        t = fig.text(
            0,
            0,
            "\n".join(str(label) + "x" for label in self.totals.index.values),
            **text_kw,
        )
        # Pass an explicit renderer only where get_renderer could be imported.
        window_extent_args = {}
        if RENDERER_IMPORTED:
            window_extent_args["renderer"] = get_renderer(fig)
        # Width of the temporary text block holding all category labels.
        textw = t.get_window_extent(**window_extent_args).width
        t.remove()

        window_extent_args = {}
        if RENDERER_IMPORTED:
            window_extent_args["renderer"] = get_renderer(fig)
        # Figure extent along the subset axis (width is transposed to height
        # for vertical plots).
        figw = self._reorient(fig.get_window_extent(**window_extent_args)).width

        # Number of matrix elements requested by each subset plot.
        sizes = np.asarray([p["elements"] for p in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = len(self.intersections) + self._totals_plot_elements
        if self._element_size is None:
            # Fit: share whatever the labels leave over between the elements.
            colw = (figw - textw) / non_text_nelems
        else:
            # Fixed element size: resize the figure to fit the grid.
            render_ratio = figw / fig.get_figwidth()
            # presumably 72 converts element_size from points to inches
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

        # Number of grid cells reserved for the category labels.
        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(
            *self._swapaxes(
                # `or 0` swaps a zero sum for the int 0
                n_cats + (sizes.sum() or 0),
                n_inters + text_nelems + self._totals_plot_elements,
            ),
            hspace=1,
        )
        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": gridspec[-n_cats:, : self._totals_plot_elements],
                "gs": gridspec,
            }
            # Subset plots are assigned row ranges from the top of the grid,
            # iterating the registered plots in reverse order.
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots[::-1]
            ):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": gridspec[: self._totals_plot_elements, :n_cats],
                "gs": gridspec,
            }
            # Subset plots are assigned column ranges after the n_cats matrix
            # columns, in registration order.
            cumsizes = np.cumsum(sizes)
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots
            ):
                out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats]
        return out
715

716
    def plot_matrix(self, ax):
717
        """Plot the matrix of intersection indicators onto ax"""
718
        ax = self._reorient(ax)
2✔
719
        data = self.intersections
2✔
720
        n_cats = data.index.nlevels
2✔
721

722
        inclusion = data.index.to_frame().values
2✔
723

724
        # Prepare styling
725
        styles = [
1✔
726
            [
2✔
727
                self.subset_styles[i]
2✔
728
                if inclusion[i, j]
1✔
729
                else {"facecolor": self._other_dots_color, "linewidth": 0}
1✔
730
                for j in range(n_cats)
2✔
731
            ]
732
            for i in range(len(data))
2✔
733
        ]
734
        styles = sum(styles, [])  # flatten nested list
2✔
735
        style_columns = {
1✔
736
            "facecolor": "facecolors",
2✔
737
            "edgecolor": "edgecolors",
2✔
738
            "linewidth": "linewidths",
2✔
739
            "linestyle": "linestyles",
2✔
740
            "hatch": "hatch",
2✔
741
        }
742
        styles = (
1✔
743
            pd.DataFrame(styles)
2✔
744
            .reindex(columns=style_columns.keys())
2✔
745
            .astype(
1✔
746
                {
1✔
747
                    "facecolor": "O",
2✔
748
                    "edgecolor": "O",
2✔
749
                    "linewidth": float,
2✔
750
                    "linestyle": "O",
2✔
751
                    "hatch": "O",
2✔
752
                }
753
            )
754
        )
755
        styles["linewidth"].fillna(1, inplace=True)
2✔
756
        styles["facecolor"].fillna(self._facecolor, inplace=True)
2✔
757
        styles["edgecolor"].fillna(styles["facecolor"], inplace=True)
2✔
758
        styles["linestyle"].fillna("solid", inplace=True)
2✔
759
        del styles["hatch"]  # not supported in matrix (currently)
2✔
760

761
        x = np.repeat(np.arange(len(data)), n_cats)
2✔
762
        y = np.tile(np.arange(n_cats), len(data))
2✔
763

764
        # Plot dots
765
        if self._element_size is not None:
2✔
766
            s = (self._element_size * 0.35) ** 2
2✔
767
        else:
768
            # TODO: make s relative to colw
769
            s = 200
2✔
770
        ax.scatter(
2✔
771
            *self._swapaxes(x, y),
2✔
772
            s=s,
2✔
773
            zorder=10,
2✔
774
            **styles.rename(columns=style_columns),
2✔
775
        )
776

777
        # Plot lines
778
        if self._with_lines:
2✔
779
            idx = np.flatnonzero(inclusion)
2✔
780
            line_data = (
1✔
781
                pd.Series(y[idx], index=x[idx])
2✔
782
                .groupby(level=0)
2✔
783
                .aggregate(["min", "max"])
2✔
784
            )
785
            colors = pd.Series(
2✔
786
                [
1✔
787
                    style.get("edgecolor", style.get("facecolor", self._facecolor))
2✔
788
                    for style in self.subset_styles
2✔
789
                ],
790
                name="color",
2✔
791
            )
792
            line_data = line_data.join(colors)
2✔
793
            ax.vlines(
2✔
794
                line_data.index.values,
2✔
795
                line_data["min"],
2✔
796
                line_data["max"],
2✔
797
                lw=2,
2✔
798
                colors=line_data["color"],
2✔
799
                zorder=5,
2✔
800
            )
801

802
        # Ticks and axes
803
        tick_axis = ax.yaxis
2✔
804
        tick_axis.set_ticks(np.arange(n_cats))
2✔
805
        tick_axis.set_ticklabels(
2✔
806
            data.index.names, rotation=0 if self._horizontal else -90
2✔
807
        )
808
        ax.xaxis.set_visible(False)
2✔
809
        ax.tick_params(axis="both", which="both", length=0)
2✔
810
        if not self._horizontal:
2✔
811
            ax.yaxis.set_ticks_position("top")
2✔
812
        ax.set_frame_on(False)
2✔
813
        ax.set_xlim(-0.5, x[-1] + 0.5, auto=False)
2✔
814
        ax.grid(False)
2✔
815

816
    def plot_intersections(self, ax):
        """Draw the intersection-size bar chart on ``ax``.

        The bars come from ``self._plot_bars`` using the default face colour
        and are then restyled one at a time from ``self.subset_styles``.  If
        ``self.subset_legend`` is non-empty, a legend built from patch
        handles is attached to ``ax``.
        """
        bars = self._plot_bars(
            ax, self.intersections, title="Intersection size", colors=self._facecolor
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            # Copy first: the edge-colour default added below must not leak
            # back into the stored subset style.
            props = subset_style.copy()
            fallback_edge = props.get("facecolor", self._facecolor)
            props.setdefault("edgecolor", fallback_edge)
            for prop, value in props.items():
                # Each style key maps onto the bar's ``set_<key>`` method.
                setter = getattr(bar, "set_" + prop)
                setter(value)

        if not self.subset_legend:
            return

        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**style_kwargs) for style_kwargs in legend_styles]
        ax.legend(handles, legend_labels)
831

832
    def _label_sizes(self, ax, rects, where):
833
        if not self._show_counts and not self._show_percentages:
834
            return
835
        if self._show_counts is True:
836
            count_fmt = "{:.0f}"
837
        else:
838
            count_fmt = self._show_counts
839
            if "{" not in count_fmt:
840
                count_fmt = util.to_new_pos_format(count_fmt)
841

842
        if self._show_percentages is True:
843
            pct_fmt = "{:.1%}"
844
        else:
845
            pct_fmt = self._show_percentages
846

847
        if count_fmt and pct_fmt:
848
            if where == "top":
849
                fmt = "%s\n(%s)" % (count_fmt, pct_fmt)
850
            else:
851
                fmt = "%s (%s)" % (count_fmt, pct_fmt)
852

853
            def make_args(val):
854
                return val, val / self.total
855
        elif count_fmt:
856
            fmt = count_fmt
857

858
            def make_args(val):
859
                return (val,)
860
        else:
861
            fmt = pct_fmt
862

863
            def make_args(val):
864
                return (val / self.total,)
865

866
        if where == "right":
867
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
868
            for rect in rects:
869
                width = rect.get_width() + rect.get_x()
870
                ax.text(
871
                    width + margin,
872
                    rect.get_y() + rect.get_height() * 0.5,
873
                    fmt.format(*make_args(width)),
874
                    ha="left",
875
                    va="center",
876
                )
877
        elif where == "left":
878
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
879
            for rect in rects:
880
                width = rect.get_width() + rect.get_x()
881
                ax.text(
882
                    width + margin,
883
                    rect.get_y() + rect.get_height() * 0.5,
884
                    fmt.format(*make_args(width)),
885
                    ha="right",
886
                    va="center",
887
                )
888
        elif where == "top":
889
            margin = 0.01 * abs(np.diff(ax.get_ylim()))
890
            for rect in rects:
891
                height = rect.get_height() + rect.get_y()
892
                ax.text(
893
                    rect.get_x() + rect.get_width() * 0.5,
894
                    height + margin,
895
                    fmt.format(*make_args(height)),
896
                    ha="center",
897
                    va="bottom",
898
                )
899
        else:
900
            raise NotImplementedError("unhandled where: %r" % where)
901

902
    def plot_totals(self, ax):
        """Draw the bar chart of ``self.totals`` on ``ax``.

        Bars are drawn with ``barh`` on the reoriented axes and labelled via
        ``self._label_sizes``.  All spines except the one left after
        reorienting "top", "left" and "right" are hidden, as are the y axis
        and the background patch; only x grid lines stay visible.
        """
        original_ax = ax
        ax = self._reorient(ax)

        positions = np.arange(len(self.totals.index.values))
        bars = ax.barh(
            positions,
            self.totals,
            0.5,
            color=self._facecolor,
            align="center",
        )
        label_side = "left" if self._horizontal else "top"
        self._label_sizes(ax, bars, label_side)

        largest_total = self.totals.max()
        if self._horizontal:
            # Limits are passed high-to-low, i.e. the x axis runs reversed.
            original_ax.set_xlim(largest_total, 0)

        for side in ["top", "left", "right"]:
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)
2✔
924

925
    def plot_shading(self, ax):
        """Add a shaded band behind every other row and strip ``ax`` bare.

        One rectangle is added for each even index of ``self.totals``; the
        frame, grid, ticks and tick labels of ``ax`` are then removed.
        """
        # alternating row shading (XXX: use add_patch(Rectangle)?)
        for row in range(len(self.totals))[::2]:
            anchor = self._swapaxes(0, row - 0.4)
            extent = self._swapaxes(1, 0.8)
            band = plt.Rectangle(
                anchor,
                *extent,
                facecolor=self._shading_color,
                lw=0,
                zorder=0,
            )
            ax.add_patch(band)

        ax.set_frame_on(False)

        # Switch off every tick mark and tick label on both axes.
        hidden = dict.fromkeys(
            ("left", "right", "bottom", "top", "labelbottom", "labelleft"), False
        )
        ax.tick_params(axis="both", which="both", **hidden)

        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
2✔
952

953
    def plot(self, fig=None):
        """Render every component of the plot onto ``fig``.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Target figure.  When omitted, a new figure sized by
            ``self._default_figsize`` is created.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Always contains 'matrix', 'shading' and 'totals', plus one entry
            per subset plot keyed by that plot's id (e.g. 'intersections').

        Raises
        ------
        ValueError
            If a subset plot has a type that is neither "default" nor a key
            of ``PLOT_TYPES``.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        grid = self.make_grid(fig)

        # Shading goes first; the matrix shares its y axis, and the totals
        # share the matrix's y axis in turn.
        shading_ax = fig.add_subplot(grid["shading"])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(grid["matrix"], sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_ax = self._reorient(fig.add_subplot)(grid["totals"], sharey=matrix_ax)
        self.plot_totals(totals_ax)

        axes_by_name = {
            "matrix": matrix_ax,
            "shading": shading_ax,
            "totals": totals_ax,
        }

        for subset_plot in self._subset_plots:
            add_subplot = self._reorient(fig.add_subplot)
            plot_id = subset_plot["id"]
            ax = add_subplot(grid[plot_id], sharex=matrix_ax)

            kind = subset_plot["type"]
            if kind == "default":
                self.plot_intersections(ax)
            elif kind not in self.PLOT_TYPES:
                raise ValueError("Unknown subset plot type: %r" % kind)
            else:
                # Everything except the bookkeeping keys is forwarded to the
                # plotting function as keyword arguments.
                extra = subset_plot.copy()
                for reserved in ("type", "elements", "id"):
                    del extra[reserved]
                self.PLOT_TYPES[kind](self, ax, **extra)

            axes_by_name[plot_id] = ax

        return axes_by_name
2✔
991

992
    # Registry of the non-default subset plot types: the "type" value of a
    # subset plot selects the function here, which ``plot`` calls as
    # ``func(self, ax, **kw)``.
    PLOT_TYPES = {"catplot": _plot_catplot, "stacked_bars": _plot_stacked_bars}
996

997
    def _repr_html_(self):
        """Plot onto a fresh default-sized figure and return its HTML repr."""
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html
1001

1002

1003
def plot(data, fig=None, **kwargs):
    """Make an UpSet plot of data on fig

    Convenience wrapper that builds an :class:`UpSet` from ``data`` and
    ``kwargs`` and immediately draws it.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc