• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7353735838

29 Dec 2023 04:27AM UTC coverage: 99.263%. Remained the same
7353735838

push

github

web-flow
Update CHANGELOG.rst to mention PR number (#262)

1750 of 1763 relevant lines covered (99.26%)

1.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/upsetplot/plotting.py
1
import typing
2✔
2
import warnings
2✔
3

4
import matplotlib
2✔
5
import numpy as np
2✔
6
import pandas as pd
2✔
7
from matplotlib import colors, patches
2✔
8
from matplotlib import pyplot as plt
2✔
9

10
from . import util
2✔
11
from .reformat import _get_subset_mask, query
2✔
12

13
# prevents ImportError on matplotlib versions >3.5.2
14
try:
2✔
15
    from matplotlib.tight_layout import get_renderer
2✔
16

17
    RENDERER_IMPORTED = True
1✔
18
except ImportError:
1✔
19
    RENDERER_IMPORTED = False
1✔
20

21

22
def _process_data(
    df,
    *,
    sort_by,
    sort_categories_by,
    subset_size,
    sum_over,
    min_subset_size=None,
    max_subset_size=None,
    max_subset_rank=None,
    min_degree=None,
    max_degree=None,
    reverse=False,
    include_empty_subsets=False,
):
    """Run ``query`` on the input and tag every row with its subset position.

    All filtering/sorting keyword arguments are forwarded unchanged to
    :func:`query`. The resulting data frame gains a ``"_bin"`` column holding
    the positional index of each row's subset within the returned subset
    sizes.

    Parameters
    ----------
    df : object accepted by ``query``
    reverse : bool
        If True, the subset sizes are returned in reversed order and the
        ``"_bin"`` positions are numbered to match that reversed order.

    Returns
    -------
    tuple
        ``(results.total, data, subset_sizes, results.category_totals)``
    """
    results = query(
        df,
        sort_by=sort_by,
        sort_categories_by=sort_categories_by,
        subset_size=subset_size,
        sum_over=sum_over,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        max_subset_rank=max_subset_rank,
        min_degree=min_degree,
        max_degree=max_degree,
        include_empty_subsets=include_empty_subsets,
    )

    df = results.data
    agg = results.subset_sizes

    def _pack_binary(X):
        """Fold the columns of X into one integer per row (first column is
        the most significant digit)."""
        frame = pd.DataFrame(X)
        # use objects if arbitrary precision integers are needed
        if frame.shape[1] > 62:
            dtype = np.object_
        else:
            dtype = np.uint64
        packed = pd.Series(0, index=frame.index, dtype=dtype)
        for _, column in frame.items():
            packed *= 2
            packed += column
        return packed

    # XXX: ugly! Membership patterns are packed into integers so that rows of
    # df can be matched against the subsets listed in agg.
    df_packed = _pack_binary(df.index.to_frame())
    data_packed = _pack_binary(agg.index.to_frame())

    positions = np.arange(len(data_packed))
    if reverse:
        positions = positions[::-1]
        agg = agg[::-1]

    lookup = pd.Series(positions, index=data_packed)
    df["_bin"] = pd.Series(df_packed).map(lookup)

    return results.total, df, agg, results.category_totals
2✔
77

78

79
def _multiply_alpha(c, mult):
    """Return colour ``c`` as a hex string (alpha included) with its alpha
    channel multiplied by ``mult``."""
    red, green, blue, alpha = colors.to_rgba(c)
    scaled = (red, green, blue, alpha * mult)
    return colors.to_hex(scaled, keep_alpha=True)
2✔
83

84

85
class _Transposed:
2✔
86
    """Wrap an object in order to transpose some plotting operations
87

88
    Attributes of obj will be mapped.
89
    Keyword arguments when calling obj will be mapped.
90

91
    The mapping is not recursive: callable attributes need to be _Transposed
92
    again.
93
    """
94

95
    def __init__(self, obj):
2✔
96
        self.__obj = obj
2✔
97

98
    def __getattr__(self, key):
2✔
99
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))
2✔
100

101
    def __call__(self, *args, **kwargs):
2✔
102
        return self.__obj(
2✔
103
            *args, **{self._NAME_TRANSPOSE.get(k, k): v for k, v in kwargs.items()}
2✔
104
        )
105

106
    _NAME_TRANSPOSE = {
2✔
107
        "width": "height",
2✔
108
        "height": "width",
2✔
109
        "hspace": "wspace",
2✔
110
        "wspace": "hspace",
2✔
111
        "hlines": "vlines",
2✔
112
        "vlines": "hlines",
2✔
113
        "bar": "barh",
2✔
114
        "barh": "bar",
2✔
115
        "xaxis": "yaxis",
2✔
116
        "yaxis": "xaxis",
2✔
117
        "left": "bottom",
2✔
118
        "right": "top",
2✔
119
        "top": "right",
2✔
120
        "bottom": "left",
2✔
121
        "sharex": "sharey",
2✔
122
        "sharey": "sharex",
2✔
123
        "get_figwidth": "get_figheight",
2✔
124
        "get_figheight": "get_figwidth",
2✔
125
        "set_figwidth": "set_figheight",
2✔
126
        "set_figheight": "set_figwidth",
2✔
127
        "set_xlabel": "set_ylabel",
2✔
128
        "set_ylabel": "set_xlabel",
2✔
129
        "set_xlim": "set_ylim",
2✔
130
        "set_ylim": "set_xlim",
2✔
131
        "get_xlim": "get_ylim",
2✔
132
        "get_ylim": "get_xlim",
2✔
133
        "set_autoscalex_on": "set_autoscaley_on",
2✔
134
        "set_autoscaley_on": "set_autoscalex_on",
2✔
135
    }
136

137

138
def _transpose(obj):
    """Transpose ``obj``: a string is translated through the name table of
    ``_Transposed`` (unknown names are returned unchanged); any other object
    is wrapped in ``_Transposed``."""
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
2✔
142

143

144
def _identity(obj):
2✔
145
    return obj
2✔
146

147

148
class UpSet:
2✔
149
    """Manage the data and drawing for a basic UpSet plot
150

151
    Primary public method is :meth:`plot`.
152

153
    Parameters
154
    ----------
155
    data : pandas.Series or pandas.DataFrame
156
        Elements associated with categories (a DataFrame), or the size of each
157
        subset of categories (a Series).
158
        Should have MultiIndex where each level is binary,
159
        corresponding to category membership.
160
        If a DataFrame, `sum_over` must be a string or False.
161
    orientation : {'horizontal' (default), 'vertical'}
162
        If horizontal, intersections are listed from left to right.
163
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
164
               'input', '-input'}
165
        If 'cardinality', subset are listed from largest to smallest.
166
        If 'degree', they are listed in order of the number of categories
167
        intersected. If 'input', the order they appear in the data input is
168
        used.
169
        Prefix with '-' to reverse the ordering.
170

171
        Note this affects ``subset_sizes`` but not ``data``.
172
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
173
        Whether to sort the categories by total cardinality, or leave them
174
        in the input data's provided order (order of index levels).
175
        Prefix with '-' to reverse the ordering.
176
    subset_size : {'auto', 'count', 'sum'}
177
        Configures how to calculate the size of a subset. Choices are:
178

179
        'auto' (default)
180
            If `data` is a DataFrame, count the number of rows in each group,
181
            unless `sum_over` is specified.
182
            If `data` is a Series with at most one row for each group, use
183
            the value of the Series. If `data` is a Series with more than one
184
            row per group, raise a ValueError.
185
        'count'
186
            Count the number of rows in each group.
187
        'sum'
188
            Sum the value of the `data` Series, or the DataFrame field
189
            specified by `sum_over`.
190
    sum_over : str or None
191
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
192
        sum of the specified field in the `data` DataFrame. If a Series, only
193
        None is supported and its value is summed.
194
    min_subset_size : int, optional
195
        Minimum size of a subset to be shown in the plot. All subsets with
196
        a size smaller than this threshold will be omitted from plotting.
197
        Size may be a sum of values, see `subset_size`.
198

199
        .. versionadded:: 0.5
200
    max_subset_size : int, optional
201
        Maximum size of a subset to be shown in the plot. All subsets with
202
        a size greater than this threshold will be omitted from plotting.
203

204
        .. versionadded:: 0.5
205
    max_subset_rank : int, optional
206
        Limit to the top N ranked subsets in descending order of size.
207
        All tied subsets are included.
208

209
        .. versionadded:: 0.9
210
    min_degree : int, optional
211
        Minimum degree of a subset to be shown in the plot.
212

213
        .. versionadded:: 0.5
214
    max_degree : int, optional
215
        Maximum degree of a subset to be shown in the plot.
216

217
        .. versionadded:: 0.5
218
    facecolor : 'auto' or matplotlib color or float
219
        Color for bar charts and active dots. Defaults to black if
220
        axes.facecolor is a light color, otherwise white.
221

222
        .. versionchanged:: 0.6
223
            Before 0.6, the default was 'black'
224
    other_dots_color : matplotlib color or float
225
        Color for shading of inactive dots, or opacity (between 0 and 1)
226
        applied to facecolor.
227

228
        .. versionadded:: 0.6
229
    shading_color : matplotlib color or float
230
        Color for shading of odd rows in matrix and totals, or opacity (between
231
        0 and 1) applied to facecolor.
232

233
        .. versionadded:: 0.6
234
    with_lines : bool
235
        Whether to show lines joining dots in the matrix, to mark multiple
236
        categories being intersected.
237
    element_size : float or None
238
        Side length in pt. If None, size is estimated to fit figure
239
    intersection_plot_elements : int
240
        The intersections plot should be large enough to fit this many matrix
241
        elements. Set to 0 to disable intersection size bars.
242

243
        .. versionchanged:: 0.4
244
            Setting to 0 is handled.
245
    totals_plot_elements : int
246
        The totals plot should be large enough to fit this many matrix
247
        elements. Use totals_plot_elements=0 to disable the totals plot.
248
    show_counts : bool or str, default=False
249
        Whether to label the intersection size bars with the cardinality
250
        of the intersection. When a string, this formats the number.
251
        For example, '{:d}' is equivalent to True.
252
        Note that, for legacy reasons, if the string does not contain '{',
253
        it will be interpreted as a C-style format string, such as '%d'.
254
    show_percentages : bool or str, default=False
255
        Whether to label the intersection size bars with the percentage
256
        of the intersection relative to the total dataset.
257
        When a string, this formats the number representing a fraction of
258
        samples.
259
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
260
        This may be applied with or without show_counts.
261

262
        .. versionadded:: 0.4
263
    include_empty_subsets : bool (default=False)
264
        If True, all possible category combinations will be shown as subsets,
265
        even when some are not present in data.
266
    """
267

268
    _default_figsize = (10, 6)
2✔
269

270
    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        """Store the styling options and pre-process ``data``.

        See the class docstring for a description of each parameter.
        """
        self._horizontal = orientation == "horizontal"
        self._reorient = _identity if self._horizontal else _transpose
        # Only compare against "auto" when facecolor is a string: an
        # array-like colour would otherwise be compared element-wise and make
        # the ``if`` raise on an ambiguous truth value.
        if isinstance(facecolor, str) and facecolor == "auto":
            bgcolor = matplotlib.rcParams.get("axes.facecolor", "white")
            r, g, b, a = colors.to_rgba(bgcolor)
            # HSV "value" channel scaled by alpha
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor
        # A float is treated as an opacity applied to facecolor; anything else
        # is taken to be a colour in its own right.
        self._shading_color = (
            _multiply_alpha(facecolor, shading_color)
            if isinstance(shading_color, float)
            else shading_color
        )
        self._other_dots_color = (
            _multiply_alpha(facecolor, other_dots_color)
            if isinstance(other_dots_color, float)
            else other_dots_color
        )
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        self._subset_plots = [
            {
                "type": "default",
                "id": "intersections",
                "elements": intersection_plot_elements,
            }
        ]
        if not intersection_plot_elements:
            # intersection_plot_elements=0 disables the intersection size bars
            self._subset_plots.pop()
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections, self.totals) = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        # One independent style dict per subset, so that style_subsets can
        # update them individually.
        self.subset_styles = [
            {"facecolor": facecolor} for _ in range(len(self.intersections))
        ]
        self.subset_legend = []  # pairs of (style, label)
2✔
345

346
    def _swapaxes(self, x, y):
2✔
347
        if self._horizontal:
2✔
348
            return x, y
2✔
349
        return y, x
2✔
350

351
    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Updates the style of selected subsets' bars and matrix dots

        Each parameter either selects subsets or styles them with an
        attribute of :class:`matplotlib.patches.Patch`; the exception is
        ``label``, which adds a legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        max_subset_rank : int, optional
            Limit to the top N ranked subsets in descending order of size.
            All tied subsets are included.

            .. versionadded:: 0.9
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added
        """
        requested = dict(
            facecolor=facecolor,
            edgecolor=edgecolor,
            hatch=hatch,
            linewidth=linewidth,
            linestyle=linestyle,
        )
        # Only attributes that were actually given override existing styling
        style = {key: val for key, val in requested.items() if val is not None}

        selected = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for position in np.flatnonzero(selected):
            self.subset_styles[position].update(style)

        if label is None:
            return

        # Legend entries always carry a facecolor
        style.setdefault("facecolor", self._facecolor)

        # Reuse the first legend entry with an identical style, if any
        match = next(
            (
                i
                for i, (known_style, _) in enumerate(self.subset_legend)
                if known_style == style
            ),
            None,
        )
        if match is None:
            self.subset_legend.append((style, label))
            return
        known_label = self.subset_legend[match][1]
        if known_label != label:
            self.subset_legend[match] = (style, known_label + "; " + label)
2✔
440

441
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
        """Draw one (optionally stacked) bar per subset onto ``ax``.

        Parameters
        ----------
        ax : Axes-like object
            Target axes; re-oriented through ``self._reorient`` so the same
            calls serve horizontal and vertical layouts.
        data : object accepted by :class:`pandas.DataFrame`
            One row per bar; each column is one layer of the stack.
        title : str
            Label for the axis that measures bar length.
        colors : callable, str, None or sequence, optional
            A callable is called with ``range(n_layers)``; a string or None is
            repeated once per layer; anything else is used as given, one
            entry per layer.
        use_labels : bool
            If True, each layer is labelled with its column name.

        Returns
        -------
        list
            The bar artists of every layer, in drawing order.
        """
        ax = self._reorient(ax)
        ax.set_autoscalex_on(False)
        data_df = pd.DataFrame(data)
        if self._horizontal:
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
        n_layers = data_df.shape[1]

        # TODO: colors should be broadcastable to data_df shape
        if callable(colors):
            colors = colors(range(n_layers))
        elif isinstance(colors, (str, type(None))):
            # One entry per column: colors are zipped with the columns below,
            # so sizing this by the number of rows would silently drop layers
            # whenever there are fewer bars than layers.
            colors = [colors] * n_layers

        if self._horizontal:
            # columns were reversed above; keep colors aligned with them
            colors = reversed(colors)

        x = np.arange(len(data_df))
        cum_y = None
        all_rects = []
        rects = None
        for (name, y), color in zip(data_df.items(), colors):
            rects = ax.bar(
                x,
                y,
                0.5,
                cum_y,
                color=color,
                zorder=10,
                label=name if use_labels else None,
                align="center",
            )
            cum_y = y if cum_y is None else cum_y + y
            all_rects.extend(rects)

        # Size labels are attached to the last layer drawn; skip when nothing
        # was drawn at all.
        if rects is not None:
            self._label_sizes(ax, rects, "top" if self._horizontal else "right")

        ax.xaxis.set_visible(False)
        for side in ["top", "bottom", "right"]:
            ax.spines[self._reorient(side)].set_visible(False)

        tick_axis = ax.yaxis
        tick_axis.grid(True)
        ax.set_ylabel(title)
        return all_rects
2✔
484

485
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
        """Draw bars per subset, stacked by the values of column ``by``.

        Parameters
        ----------
        ax : Axes-like object to draw on
        by : str
            Column of the data whose distinct values form the stack layers.
        sum_over : str or None
            Column to sum for bar length. If None, the ``"_value"`` column is
            summed when present, otherwise rows are counted.
        colors : Mapping, list-like, str, callable or None
            A string is looked up as a matplotlib colormap name; a Mapping
            must provide a colour for every label; other values are passed
            through to ``_plot_bars``.
        title : str
            Label for the axis that measures bar length.

        Raises
        ------
        KeyError
            If ``colors`` is a Mapping that lacks an entry for some label.
        ValueError
            If ``colors`` is a string that is not a known colormap name.
        """
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
        if sum_over is None and "_value" in df.columns:
            data = gb["_value"].sum()
        elif sum_over is None:
            data = gb.size()
        else:
            data = gb[sum_over].sum()
        data = data.unstack(by).fillna(0)
        if isinstance(colors, str):
            # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and
            # later removed; prefer the colormap registry where it exists.
            registry = getattr(matplotlib, "colormaps", None)
            if registry is None:
                colors = matplotlib.cm.get_cmap(colors)
            else:
                try:
                    colors = registry[colors]
                except KeyError as exc:
                    # keep the exception type get_cmap used for unknown names
                    raise ValueError(
                        "%r is not a known colormap name" % colors
                    ) from exc
        elif isinstance(colors, typing.Mapping):
            colors = data.columns.map(colors).values
            if pd.isna(colors).any():
                raise KeyError(
                    "Some labels not mapped by colors: %r"
                    % data.columns[pd.isna(colors)].tolist()
                )

        self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True)

        handles, labels = ax.get_legend_handles_labels()
        if self._horizontal:
            # Make legend order match visual stack order
            ax.legend(reversed(handles), reversed(labels))
        else:
            ax.legend()
2✔
513

514
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None):
2✔
515
        """Add a stacked bar chart over subsets when :func:`plot` is called.
516

517
        Used to plot categorical variable distributions within each subset.
518

519
        .. versionadded:: 0.6
520

521
        Parameters
522
        ----------
523
        by : str
524
            Column name within the dataframe for color coding the stacked bars,
525
            containing discrete or categorical values.
526
        sum_over : str, optional
527
            Ordinarily the bars will chart the size of each group. sum_over
528
            may specify a column which will be summed to determine the size
529
            of each bar.
530
        colors : Mapping, list-like, str or callable, optional
531
            The facecolors to use for bars corresponding to each discrete
532
            label, specified as one of:
533

534
            Mapping
535
                Maps from label to matplotlib-compatible color specification.
536
            list-like
537
                A list of matplotlib colors to apply to labels in order.
538
            str
539
                The name of a matplotlib colormap name.
540
            callable
541
                When called with the number of labels, this should return a
542
                list-like of that many colors.  Matplotlib colormaps satisfy
543
                this callable API.
544
            None
545
                Uses the matplotlib default colormap.
546
        elements : int, default=3
547
            Size of the axes counted in number of matrix elements.
548
        title : str, optional
549
            The axis title labelling bar length.
550

551
        Returns
552
        -------
553
        None
554
        """
555
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
556
        #                        list of labels}
557
        self._subset_plots.append(
2✔
558
            {
2✔
559
                "type": "stacked_bars",
2✔
560
                "by": by,
2✔
561
                "sum_over": sum_over,
2✔
562
                "colors": colors,
2✔
563
                "title": title,
2✔
564
                "id": "extra%d" % len(self._subset_plots),
2✔
565
                "elements": elements,
2✔
566
            }
567
        )
568

569
    def add_catplot(self, kind, value=None, elements=3, **kw):
2✔
570
        """Add a seaborn catplot over subsets when :func:`plot` is called.
571

572
        Parameters
573
        ----------
574
        kind : str
575
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
576
        value : str, optional
577
            Column name for the value to plot (i.e. y if
578
            orientation='horizontal'), required if `data` is a DataFrame.
579
        elements : int, default=3
580
            Size of the axes counted in number of matrix elements.
581
        **kw : dict
582
            Additional keywords to pass to :func:`seaborn.catplot`.
583

584
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
585
            and 'orient', so these are prohibited keys in `kw`.
586

587
        Returns
588
        -------
589
        None
590
        """
591
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
592
        if value is None:
593
            if "_value" not in self._df.columns:
594
                raise ValueError(
595
                    "value cannot be set if data is a Series. " "Got %r" % value
596
                )
597
        else:
598
            if value not in self._df.columns:
599
                raise ValueError("value %r is not a column in data" % value)
600
        self._subset_plots.append(
601
            {
602
                "type": "catplot",
603
                "value": value,
604
                "kind": kind,
605
                "id": "extra%d" % len(self._subset_plots),
606
                "elements": elements,
607
                "kw": kw,
608
            }
609
        )
610

611
    def _check_value(self, value):
2✔
612
        if value is None and "_value" in self._df.columns:
613
            value = "_value"
614
        elif value is None:
615
            raise ValueError("value can only be None when data is a Series")
616
        return value
617

618
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of ``value`` per subset onto ``ax``.

        ``kw`` is copied, not modified; the 'orient', 'x', 'y' and 'ax'
        entries are filled in here according to the plot orientation.
        """
        df = self._df
        value = self._check_value(value)

        # Subsets run along x when horizontal, along y when vertical
        if self._horizontal:
            orient, x_field, y_field = "v", "_bin", value
        else:
            orient, x_field, y_field = "h", value, "_bin"
        kw = kw.copy()
        kw.update(orient=orient, x=x_field, y=y_field)

        import seaborn

        kw["ax"] = ax
        plot_func = getattr(seaborn, kind + "plot")
        plot_func(data=df, **kw)

        ax = self._reorient(ax)
        if value == "_value":
            # "_value" is an internal column name; do not show it as a label
            ax.set_ylabel("")

        ax.xaxis.set_visible(False)
        for side in ("top", "bottom", "right"):
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
645

646
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : figure, optional
            Defaults to the current pyplot figure. When ``element_size`` is
            set, the figure is resized to fit the computed grid.

        Returns
        -------
        dict
            Maps "matrix", "shading", "totals" (None when the totals plot is
            disabled) and the id of every subset plot to a SubplotSpec, and
            "gs" to the GridSpec itself.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)
        totals_elems = self._totals_plot_elements

        if fig is None:
            fig = plt.gcf()

        def _extent_kwargs():
            # An explicit renderer is passed only if get_renderer was importable
            if not RENDERER_IMPORTED:
                return {}
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                return {"renderer": get_renderer(fig)}

        # Determine text size to determine figure size / spacing:
        # temporarily draw all category labels and measure their width.
        label_size = matplotlib.rcParams["xtick.labelsize"]
        # adding "x" ensures a margin
        label_text = "\n".join(str(label) + "x" for label in self.totals.index.values)
        probe = fig.text(0, 0, label_text, size=label_size)
        textw = probe.get_window_extent(**_extent_kwargs()).width
        probe.remove()

        figw = self._reorient(fig.get_window_extent(**_extent_kwargs())).width

        sizes = np.asarray([plot["elements"] for plot in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = n_inters + totals_elems
        if self._element_size is not None:
            # Fixed element size: derive the column width, then resize the figure
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)
        else:
            # Fit the elements into whatever width remains beside the labels
            colw = (figw - textw) / non_text_nelems

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        grid_cls = self._reorient(matplotlib.gridspec.GridSpec)
        # ``or 0`` turns the float 0.0 that an empty ``sizes`` sums to into int 0
        n_stack = n_cats + (sizes.sum() or 0)
        n_along = n_inters + text_nelems + totals_elems
        gridspec = grid_cls(*self._swapaxes(n_stack, n_along), hspace=1)

        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": (
                    gridspec[-n_cats:, :totals_elems] if totals_elems != 0 else None
                ),
                "gs": gridspec,
            }
            # Subset plots are stacked above the matrix, last added on top
            stacked_plots = self._subset_plots[::-1]
            cumsizes = np.cumsum(sizes[::-1])
            starts = np.hstack([[0], cumsizes])
            for start, stop, plot in zip(starts, cumsizes, stacked_plots):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": (
                    gridspec[:totals_elems, :n_cats] if totals_elems != 0 else None
                ),
                "gs": gridspec,
            }
            # Subset plots sit beside the matrix, in the order they were added
            cumsizes = np.cumsum(sizes)
            starts = np.hstack([[0], cumsizes])
            for start, stop, plot in zip(starts, cumsizes, self._subset_plots):
                out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats]
        return out
730

731
    def plot_matrix(self, ax):
        """Plot the matrix of intersection indicators onto ax

        One dot is drawn per (subset, category) pair: dots of categories
        belonging to the subset use that subset's style, the others use the
        inactive-dot colour. If ``with_lines`` is set, a line joins the first
        and last active dot of each subset. Hatching is not applied to dots.
        """
        ax = self._reorient(ax)
        data = self.intersections
        n_cats = data.index.nlevels
        n_subsets = len(data)

        inclusion = data.index.to_frame().values

        # Prepare styling: one entry per dot, subset-major order
        # (built flat rather than flattened with the quadratic ``sum(.., [])``)
        inactive_style = {"facecolor": self._other_dots_color, "linewidth": 0}
        styles = [
            self.subset_styles[i] if inclusion[i, j] else inactive_style
            for i in range(n_subsets)
            for j in range(n_cats)
        ]
        style_columns = {
            "facecolor": "facecolors",
            "edgecolor": "edgecolors",
            "linewidth": "linewidths",
            "linestyle": "linestyles",
            "hatch": "hatch",
        }
        styles = (
            pd.DataFrame(styles)
            .reindex(columns=style_columns.keys())
            .astype(
                {
                    "facecolor": "O",
                    "edgecolor": "O",
                    "linewidth": float,
                    "linestyle": "O",
                    "hatch": "O",
                }
            )
        )
        # Assign the filled columns back instead of calling
        # ``fillna(inplace=True)`` on a column selection: that chained
        # in-place form warns on recent pandas and leaves ``styles``
        # untouched under Copy-on-Write.
        styles["linewidth"] = styles["linewidth"].fillna(1)
        styles["facecolor"] = styles["facecolor"].fillna(self._facecolor)
        # must follow the facecolor fill: edges default to the face colour
        styles["edgecolor"] = styles["edgecolor"].fillna(styles["facecolor"])
        styles["linestyle"] = styles["linestyle"].fillna("solid")
        del styles["hatch"]  # not supported in matrix (currently)

        x = np.repeat(np.arange(n_subsets), n_cats)
        y = np.tile(np.arange(n_cats), n_subsets)

        # Plot dots
        if self._element_size is not None:  # noqa
            s = (self._element_size * 0.35) ** 2
        else:
            # TODO: make s relative to colw
            s = 200
        ax.scatter(
            *self._swapaxes(x, y),
            s=s,
            zorder=10,
            **styles.rename(columns=style_columns),
        )

        # Plot lines
        if self._with_lines:
            idx = np.flatnonzero(inclusion)
            # lowest and highest active category of each subset
            line_data = (
                pd.Series(y[idx], index=x[idx])
                .groupby(level=0)
                .aggregate(["min", "max"])
            )
            line_colors = pd.Series(
                [
                    style.get("edgecolor", style.get("facecolor", self._facecolor))
                    for style in self.subset_styles
                ],
                name="color",
            )
            line_data = line_data.join(line_colors)
            ax.vlines(
                line_data.index.values,
                line_data["min"],
                line_data["max"],
                lw=2,
                colors=line_data["color"],
                zorder=5,
            )

        # Ticks and axes
        tick_axis = ax.yaxis
        tick_axis.set_ticks(np.arange(n_cats))
        tick_axis.set_ticklabels(
            data.index.names, rotation=0 if self._horizontal else -90
        )
        ax.xaxis.set_visible(False)
        ax.tick_params(axis="both", which="both", length=0)
        if not self._horizontal:
            ax.yaxis.set_ticks_position("top")
        ax.set_frame_on(False)
        # Same limit as ``x[-1] + 0.5`` but does not index an empty array
        ax.set_xlim(-0.5, n_subsets - 0.5, auto=False)
        ax.grid(False)
2✔
830

831
    def plot_intersections(self, ax):
        """Draw the intersection-size bars on ``ax`` and style each one.

        The bars come from ``self._plot_bars`` using the default face colour.
        Each bar is then restyled from the entry of ``self.subset_styles`` at
        the same position, and a legend is attached when
        ``self.subset_legend`` is non-empty.
        """
        bars = self._plot_bars(
            ax, self.intersections, title="Intersection size", colors=self._facecolor
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            # Work on a copy so the stored per-subset style is never mutated.
            overrides = subset_style.copy()
            if "edgecolor" not in overrides:
                # An unspecified edge follows the bar's own fill colour.
                overrides["edgecolor"] = overrides.get("facecolor", self._facecolor)
            for prop, value in overrides.items():
                setter = getattr(bar, "set_" + prop)
                setter(value)

        if not self.subset_legend:
            return

        # subset_legend holds (style, label) pairs; split them into two tuples.
        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**legend_style) for legend_style in legend_styles]
        ax.legend(handles, legend_labels)
846

847
    def _label_sizes(self, ax, rects, where):
848
        if not self._show_counts and not self._show_percentages:
849
            return
850
        if self._show_counts is True:
851
            count_fmt = "{:.0f}"
852
        else:
853
            count_fmt = self._show_counts
854
            if "{" not in count_fmt:
855
                count_fmt = util.to_new_pos_format(count_fmt)
856

857
        pct_fmt = "{:.1%}" if self._show_percentages is True else self._show_percentages
858

859
        if count_fmt and pct_fmt:
860
            if where == "top":
861
                fmt = f"{count_fmt}\n({pct_fmt})"
862
            else:
863
                fmt = f"{count_fmt} ({pct_fmt})"
864

865
            def make_args(val):
866
                return val, val / self.total
867
        elif count_fmt:
868
            fmt = count_fmt
869

870
            def make_args(val):
871
                return (val,)
872
        else:
873
            fmt = pct_fmt
874

875
            def make_args(val):
876
                return (val / self.total,)
877

878
        if where == "right":
879
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
880
            for rect in rects:
881
                width = rect.get_width() + rect.get_x()
882
                ax.text(
883
                    width + margin,
884
                    rect.get_y() + rect.get_height() * 0.5,
885
                    fmt.format(*make_args(width)),
886
                    ha="left",
887
                    va="center",
888
                )
889
        elif where == "left":
890
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
891
            for rect in rects:
892
                width = rect.get_width() + rect.get_x()
893
                ax.text(
894
                    width + margin,
895
                    rect.get_y() + rect.get_height() * 0.5,
896
                    fmt.format(*make_args(width)),
897
                    ha="right",
898
                    va="center",
899
                )
900
        elif where == "top":
901
            margin = 0.01 * abs(np.diff(ax.get_ylim()))
902
            for rect in rects:
903
                height = rect.get_height() + rect.get_y()
904
                ax.text(
905
                    rect.get_x() + rect.get_width() * 0.5,
906
                    height + margin,
907
                    fmt.format(*make_args(height)),
908
                    ha="center",
909
                    va="bottom",
910
                )
911
        else:
912
            raise NotImplementedError("unhandled where: %r" % where)
913

914
    def plot_totals(self, ax):
        """Draw one bar per category showing its total size.

        ``ax`` is passed through ``self._reorient`` before drawing; the
        un-wrapped axes is kept only to set the x limits directly in the
        horizontal layout.
        """
        raw_ax = ax
        ax = self._reorient(raw_ax)

        n_categories = len(self.totals.index.values)
        bars = ax.barh(
            np.arange(n_categories),
            self.totals,
            0.5,
            color=self._facecolor,
            align="center",
        )
        label_side = "left" if self._horizontal else "top"
        self._label_sizes(ax, bars, label_side)

        largest_total = self.totals.max()
        if self._horizontal:
            # Limits are given high-to-low, which reverses the x axis.
            raw_ax.set_xlim(largest_total, 0)

        # Hide three of the four spines (names mapped through _reorient).
        for side in ("top", "left", "right"):
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)
2✔
936

937
    def plot_shading(self, ax):
        """Draw background bands for every second entry of ``self.totals``.

        Each band is a borderless rectangle which, before ``self._swapaxes``
        is applied, starts at ``(0, i - 0.4)`` and measures ``1 x 0.8``.
        All frame, tick and grid decoration is removed from ``ax``.
        """
        # XXX: use add_patch(Rectangle)?
        for row in range(0, len(self.totals), 2):
            corner = self._swapaxes(0, row - 0.4)
            extent = self._swapaxes(1, 0.8)
            band = plt.Rectangle(
                corner,
                *extent,
                facecolor=self._shading_color,
                lw=0,
                zorder=0,
            )
            ax.add_patch(band)

        ax.set_frame_on(False)

        # Switch off every tick mark and the bottom/left tick labels.
        hidden = {
            "left": False,
            "right": False,
            "bottom": False,
            "top": False,
            "labelbottom": False,
            "labelleft": False,
        }
        ax.tick_params(axis="both", which="both", **hidden)
        ax.grid(False)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
2✔
964

965
    def plot(self, fig=None):
        """Render every component of the plot onto ``fig``.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw on. When omitted, a new figure sized
            ``self._default_figsize`` is created.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Holds 'matrix', 'shading' and 'totals' (``None`` when the grid
            has no totals cell), plus one entry per configured subset plot
            keyed by that plot's id.

        Raises
        ------
        ValueError
            If a subset plot has a type that is neither "default" nor a key
            of ``PLOT_TYPES``.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        grid = self.make_grid(fig)

        # Shading goes first; the matrix then shares its y axis.
        shading_ax = fig.add_subplot(grid["shading"])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(grid["matrix"], sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_ax = None
        if grid["totals"] is not None:
            totals_ax = self._reorient(fig.add_subplot)(
                grid["totals"], sharey=matrix_ax
            )
            self.plot_totals(totals_ax)

        axes_by_name = {
            "matrix": matrix_ax,
            "shading": shading_ax,
            "totals": totals_ax,
        }

        for subset_plot in self._subset_plots:
            ax = self._reorient(fig.add_subplot)(
                grid[subset_plot["id"]], sharex=matrix_ax
            )
            kind = subset_plot["type"]
            if kind == "default":
                self.plot_intersections(ax)
            elif kind in self.PLOT_TYPES:
                # Forward everything except the bookkeeping keys to the plotter.
                extra = subset_plot.copy()
                for reserved in ("type", "elements", "id"):
                    del extra[reserved]
                self.PLOT_TYPES[kind](self, ax, **extra)
            else:
                raise ValueError("Unknown subset plot type: %r" % kind)
            axes_by_name[subset_plot["id"]] = ax

        return axes_by_name
2✔
1008

1009
    # Maps a subset plot's "type" to the function that draws it; plot() calls
    # the entry as ``func(self, ax, **kw)``.
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )
1013

1014
    def _repr_html_(self):
        """Return the HTML repr of a new figure holding this plot."""
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html
1018

1019

1020
def plot(data, fig=None, **kwargs):
    """Build an :class:`UpSet` from ``data`` and draw it on ``fig``.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have a multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Figure to draw on; forwarded unchanged to :meth:`UpSet.plot`, so a
        new figure is used when omitted.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Whatever :meth:`UpSet.plot` returns for the constructed object.
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc