• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7347497868

28 Dec 2023 12:15PM UTC coverage: 99.018% (-0.2%) from 99.199%
7347497868

Pull #250

github

jnothman
Catch deprecation warnings
Pull Request #250: Remove use of LooseVersion

3 of 3 new or added lines in 2 files covered. (100.0%)

9 existing lines in 2 files now uncovered.

1715 of 1732 relevant lines covered (99.02%)

0.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.76
/upsetplot/plotting.py
1
from __future__ import print_function, division, absolute_import
1✔
2

3
import typing
1✔
4

5
import numpy as np
1✔
6
import pandas as pd
1✔
7
import matplotlib
1✔
8
from matplotlib import pyplot as plt
1✔
9
from matplotlib import colors
1✔
10
from matplotlib import patches
1✔
11
import warnings
1✔
12

13
from .reformat import query, _get_subset_mask
1✔
14
from . import util
1✔
15

16
# prevents ImportError on matplotlib versions >3.5.2
17
try:
    # get_renderer is only importable from matplotlib.tight_layout on some
    # matplotlib releases; record whether it is available so callers can
    # decide whether to pass an explicit renderer.
    from matplotlib.tight_layout import get_renderer
except ImportError:
    RENDERER_IMPORTED = False
else:
    RENDERER_IMPORTED = True
1✔
23

24

25
# Type alias: a dict keyed by the name of a sub-plot, whose values are
# matplotlib Axes.
PlotReturnType = typing.Dict[
    typing.Literal["matrix", "intersections", "totals", "shading"],
    matplotlib.axes.Axes,
]
28

29

30
def _process_data(
    df,
    *,
    sort_by,
    sort_categories_by,
    subset_size,
    sum_over,
    min_subset_size=None,
    max_subset_size=None,
    min_degree=None,
    max_degree=None,
    reverse=False,
    include_empty_subsets=False,
):
    """Query the data and tag every row with the position of its subset.

    All filtering and sorting arguments are forwarded unchanged to
    ``query``. The returned data gains a ``_bin`` column holding, for each
    row, the positional index of its subset within the returned subset
    sizes. When ``reverse`` is true, both the subset sizes and the
    positions stored in ``_bin`` are reversed.

    Returns
    -------
    tuple
        ``(total, df, agg, totals)``: the sum of all subset sizes, the
        data with the ``_bin`` column, the subset sizes, and the category
        totals.
    """
    results = query(
        df,
        sort_by=sort_by,
        sort_categories_by=sort_categories_by,
        subset_size=subset_size,
        sum_over=sum_over,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        min_degree=min_degree,
        max_degree=max_degree,
        include_empty_subsets=include_empty_subsets,
    )

    df = results.data
    agg = results.subset_sizes
    totals = results.category_totals
    total = agg.sum()

    # add '_bin' to df indicating index in agg
    # XXX: ugly!
    def _pack_binary(X):
        # Fold the columns of X into a single number per row, treating each
        # column as one binary digit (first column most significant).
        frame = pd.DataFrame(X)
        # use objects if arbitrary precision integers are needed
        dtype = np.object_ if frame.shape[1] > 62 else np.uint64
        packed = pd.Series(0, index=frame.index, dtype=dtype)
        for _, column in frame.items():
            packed *= 2
            packed += column
        return packed

    df_packed = _pack_binary(df.index.to_frame())
    data_packed = _pack_binary(agg.index.to_frame())

    positions = np.arange(len(data_packed))
    if reverse:
        positions = positions[::-1]
    position_of_subset = pd.Series(positions, index=data_packed)
    df["_bin"] = pd.Series(df_packed).map(position_of_subset)

    if reverse:
        agg = agg[::-1]

    return total, df, agg, totals
1✔
85

86

87
def _multiply_alpha(c, mult):
    """Return color ``c`` as a hex string (with alpha), alpha scaled by ``mult``."""
    red, green, blue, alpha = colors.to_rgba(c)
    scaled = (red, green, blue, alpha * mult)
    return colors.to_hex(scaled, keep_alpha=True)
1✔
91

92

93
class _Transposed:
1✔
94
    """Wrap an object in order to transpose some plotting operations
95

96
    Attributes of obj will be mapped.
97
    Keyword arguments when calling obj will be mapped.
98

99
    The mapping is not recursive: callable attributes need to be _Transposed
100
    again.
101
    """
102

103
    def __init__(self, obj):
1✔
104
        self.__obj = obj
1✔
105

106
    def __getattr__(self, key):
1✔
107
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))
1✔
108

109
    def __call__(self, *args, **kwargs):
1✔
110
        return self.__obj(
1✔
111
            *args, **{self._NAME_TRANSPOSE.get(k, k): v for k, v in kwargs.items()}
1✔
112
        )
113

114
    _NAME_TRANSPOSE = {
1✔
115
        "width": "height",
1✔
116
        "height": "width",
1✔
117
        "hspace": "wspace",
1✔
118
        "wspace": "hspace",
1✔
119
        "hlines": "vlines",
1✔
120
        "vlines": "hlines",
1✔
121
        "bar": "barh",
1✔
122
        "barh": "bar",
1✔
123
        "xaxis": "yaxis",
1✔
124
        "yaxis": "xaxis",
1✔
125
        "left": "bottom",
1✔
126
        "right": "top",
1✔
127
        "top": "right",
1✔
128
        "bottom": "left",
1✔
129
        "sharex": "sharey",
1✔
130
        "sharey": "sharex",
1✔
131
        "get_figwidth": "get_figheight",
1✔
132
        "get_figheight": "get_figwidth",
1✔
133
        "set_figwidth": "set_figheight",
1✔
134
        "set_figheight": "set_figwidth",
1✔
135
        "set_xlabel": "set_ylabel",
1✔
136
        "set_ylabel": "set_xlabel",
1✔
137
        "set_xlim": "set_ylim",
1✔
138
        "set_ylim": "set_xlim",
1✔
139
        "get_xlim": "get_ylim",
1✔
140
        "get_ylim": "get_xlim",
1✔
141
        "set_autoscalex_on": "set_autoscaley_on",
1✔
142
        "set_autoscaley_on": "set_autoscalex_on",
1✔
143
    }
144

145

146
def _transpose(obj):
    """Transpose a name, or wrap an object so its names are transposed.

    A string is translated through ``_Transposed._NAME_TRANSPOSE`` (and
    returned unchanged when it has no entry); any other object is wrapped
    in a ``_Transposed`` proxy.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
1✔
150

151

152
def _identity(obj):
1✔
153
    return obj
1✔
154

155

156
class UpSet:
1✔
157
    """Manage the data and drawing for a basic UpSet plot
158

159
    Primary public method is :meth:`plot`.
160

161
    Parameters
162
    ----------
163
    data : pandas.Series or pandas.DataFrame
164
        Elements associated with categories (a DataFrame), or the size of each
165
        subset of categories (a Series).
166
        Should have MultiIndex where each level is binary,
167
        corresponding to category membership.
168
        If a DataFrame, `sum_over` must be a string or False.
169
    orientation : {'horizontal' (default), 'vertical'}
170
        If horizontal, intersections are listed from left to right.
171
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
172
               'input', '-input'}
173
        If 'cardinality', subset are listed from largest to smallest.
174
        If 'degree', they are listed in order of the number of categories
175
        intersected. If 'input', the order they appear in the data input is
176
        used.
177
        Prefix with '-' to reverse the ordering.
178

179
        Note this affects ``subset_sizes`` but not ``data``.
180
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
181
        Whether to sort the categories by total cardinality, or leave them
182
        in the input data's provided order (order of index levels).
183
        Prefix with '-' to reverse the ordering.
184
    subset_size : {'auto', 'count', 'sum'}
185
        Configures how to calculate the size of a subset. Choices are:
186

187
        'auto' (default)
188
            If `data` is a DataFrame, count the number of rows in each group,
189
            unless `sum_over` is specified.
190
            If `data` is a Series with at most one row for each group, use
191
            the value of the Series. If `data` is a Series with more than one
192
            row per group, raise a ValueError.
193
        'count'
194
            Count the number of rows in each group.
195
        'sum'
196
            Sum the value of the `data` Series, or the DataFrame field
197
            specified by `sum_over`.
198
    sum_over : str or None
199
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
200
        sum of the specified field in the `data` DataFrame. If a Series, only
201
        None is supported and its value is summed.
202
    min_subset_size : int, optional
203
        Minimum size of a subset to be shown in the plot. All subsets with
204
        a size smaller than this threshold will be omitted from plotting.
205
        Size may be a sum of values, see `subset_size`.
206

207
        .. versionadded:: 0.5
208
    max_subset_size : int, optional
209
        Maximum size of a subset to be shown in the plot. All subsets with
210
        a size greater than this threshold will be omitted from plotting.
211

212
        .. versionadded:: 0.5
213
    min_degree : int, optional
214
        Minimum degree of a subset to be shown in the plot.
215

216
        .. versionadded:: 0.5
217
    max_degree : int, optional
218
        Maximum degree of a subset to be shown in the plot.
219

220
        .. versionadded:: 0.5
221
    facecolor : 'auto' or matplotlib color or float
222
        Color for bar charts and active dots. Defaults to black if
223
        axes.facecolor is a light color, otherwise white.
224

225
        .. versionchanged:: 0.6
226
            Before 0.6, the default was 'black'
227
    other_dots_color : matplotlib color or float
228
        Color for shading of inactive dots, or opacity (between 0 and 1)
229
        applied to facecolor.
230

231
        .. versionadded:: 0.6
232
    shading_color : matplotlib color or float
233
        Color for shading of odd rows in matrix and totals, or opacity (between
234
        0 and 1) applied to facecolor.
235

236
        .. versionadded:: 0.6
237
    with_lines : bool
238
        Whether to show lines joining dots in the matrix, to mark multiple
239
        categories being intersected.
240
    element_size : float or None
241
        Side length in pt. If None, size is estimated to fit figure
242
    intersection_plot_elements : int
243
        The intersections plot should be large enough to fit this many matrix
244
        elements. Set to 0 to disable intersection size bars.
245

246
        .. versionchanged:: 0.4
247
            Setting to 0 is handled.
248
    totals_plot_elements : int
249
        The totals plot should be large enough to fit this many matrix
250
        elements.
251
    show_counts : bool or str, default=False
252
        Whether to label the intersection size bars with the cardinality
253
        of the intersection. When a string, this formats the number.
254
        For example, '{:d}' is equivalent to True.
255
        Note that, for legacy reasons, if the string does not contain '{',
256
        it will be interpreted as a C-style format string, such as '%d'.
257
    show_percentages : bool or str, default=False
258
        Whether to label the intersection size bars with the percentage
259
        of the intersection relative to the total dataset.
260
        When a string, this formats the number representing a fraction of
261
        samples.
262
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
263
        This may be applied with or without show_counts.
264

265
        .. versionadded:: 0.4
266
    include_empty_subsets : bool (default=False)
267
        If True, all possible category combinations will be shown as subsets,
268
        even when some are not present in data.
269
    """
270

271
    _default_figsize = (10, 6)
1✔
272

273
    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        # Orientation decides whether plotting calls pass through untouched
        # (horizontal) or have their names transposed (anything else).
        horizontal = orientation == "horizontal"
        self._horizontal = horizontal
        self._reorient = _identity if horizontal else _transpose

        if facecolor == "auto":
            # Choose black on a light axes background and white on a dark
            # one, judging by the HSV value scaled by the alpha channel.
            background = matplotlib.rcParams.get("axes.facecolor", "white")
            red, green, blue, alpha = colors.to_rgba(background)
            lightness = colors.rgb_to_hsv((red, green, blue))[-1] * alpha
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor

        def _color_or_opacity(spec):
            # A float is an opacity applied to facecolor; anything else is
            # used as given.
            if isinstance(spec, float):
                return _multiply_alpha(facecolor, spec)
            return spec

        self._shading_color = _color_or_opacity(shading_color)
        self._other_dots_color = _color_or_opacity(other_dots_color)

        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements

        # The default intersection-size plot is only registered when it has
        # a non-zero size.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append(
                {
                    "type": "default",
                    "id": "intersections",
                    "elements": intersection_plot_elements,
                }
            )
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        processed = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        (self.total, self._df, self.intersections, self.totals) = processed

        # One independent style dict per subset, so each can be restyled
        # separately later.
        self.subset_styles = [
            {"facecolor": facecolor} for _ in range(len(self.intersections))
        ]
        self.subset_legend = []  # pairs of (style, label)
1✔
346

347
    def _swapaxes(self, x, y):
1✔
348
        if self._horizontal:
1✔
349
            return x, y
1✔
350
        return y, x
1✔
351

352
    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Restyle the bars and matrix dots of a selection of subsets

        The first six parameters choose which subsets are affected. The
        remaining ones, except ``label``, are
        :class:`matplotlib.patches.Patch` attributes to apply to them;
        ``label`` requests a legend entry for the style.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added
        """
        # Keep only the style attributes the caller actually supplied.
        requested = (
            ("facecolor", facecolor),
            ("edgecolor", edgecolor),
            ("hatch", hatch),
            ("linewidth", linewidth),
            ("linestyle", linestyle),
        )
        style = {key: value for key, value in requested if value is not None}

        mask = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for idx in np.flatnonzero(mask):
            self.subset_styles[idx].update(style)

        if label is None:
            return

        # Legend entries always carry a facecolor.
        style.setdefault("facecolor", self._facecolor)
        for pos, (known_style, known_label) in enumerate(self.subset_legend):
            if known_style != style:
                continue
            # Same style already in the legend: merge the labels rather than
            # adding a duplicate entry.
            if known_label != label:
                self.subset_legend[pos] = (style, known_label + "; " + label)
            return
        self.subset_legend.append((style, label))
1✔
434

435
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
        """Draw (optionally stacked) bars, one bar position per row of data

        Parameters
        ----------
        ax : Axes to draw on; it is passed through ``self._reorient`` first.
        data : anything accepted by ``pd.DataFrame``
            Each column becomes one layer of the stack, drawn on top of the
            layers before it.
        title : label set on the bar-length axis.
        colors : callable, str, None or iterable, optional
            A callable is called with ``range(n_layers)``; a str or None is
            repeated for every layer; anything else is iterated, one entry
            per layer.
        use_labels : bool
            Whether each layer is labelled with its column name.

        Returns
        -------
        list
            The rectangles of every layer, in drawing order.
        """
        ax = self._reorient(ax)
        ax.set_autoscalex_on(False)
        data_df = pd.DataFrame(data)
        if self._horizontal:
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack

        n_layers = data_df.shape[1]

        # TODO: colors should be broadcastable to data_df shape
        if callable(colors):
            colors = colors(range(n_layers))
        elif isinstance(colors, (str, type(None))):
            # One colour per layer (column). Sizing this by the number of
            # rows would make the zip below drop layers whenever there are
            # more layers than bar positions.
            colors = [colors] * n_layers

        if self._horizontal:
            colors = reversed(colors)

        positions = np.arange(len(data_df))
        cum_y = None
        all_rects = []
        rects = []  # stays empty when there is no layer to draw
        for (name, y), color in zip(data_df.items(), colors):
            rects = ax.bar(
                positions,
                y,
                0.5,
                cum_y,
                color=color,
                zorder=10,
                label=name if use_labels else None,
                align="center",
            )
            cum_y = y if cum_y is None else cum_y + y
            all_rects.extend(rects)

        # Only the last layer drawn is passed on for size labels.
        self._label_sizes(ax, rects, "top" if self._horizontal else "right")

        ax.xaxis.set_visible(False)
        for side in ["top", "bottom", "right"]:
            ax.spines[self._reorient(side)].set_visible(False)

        tick_axis = ax.yaxis
        tick_axis.grid(True)
        ax.set_ylabel(title)
        return all_rects
1✔
478

479
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
        """Draw bars per subset, stacked by the values of column ``by``

        Parameters
        ----------
        ax : Axes to draw on.
        by : column whose distinct values become the stack layers.
        sum_over : column to sum for the bar length, or None. With None,
            the ``_value`` column is summed if the data has one, otherwise
            rows are counted.
        colors : str, Mapping or anything accepted by ``_plot_bars``
            A str names a matplotlib colormap; a Mapping maps each value of
            ``by`` to a colour.
        title : label for the bar-length axis.

        Raises
        ------
        KeyError
            If ``colors`` is a Mapping lacking an entry for some value of
            ``by``.
        """
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
        if sum_over is None and "_value" in df.columns:
            data = gb["_value"].sum()
        elif sum_over is None:
            data = gb.size()
        else:
            data = gb[sum_over].sum()
        # One row per subset, one column per value of `by`.
        data = data.unstack(by).fillna(0)
        if isinstance(colors, str):
            # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and
            # removed in 3.9; pyplot.get_cmap remains available.
            colors = plt.get_cmap(colors)
        elif isinstance(colors, typing.Mapping):
            colors = data.columns.map(colors).values
            if pd.isna(colors).any():
                # NOTE(review): the labels listed are the ones *without* a
                # colour in the mapping; message kept as-is for
                # compatibility.
                raise KeyError(
                    "Some labels mapped by colors: %r"
                    % data.columns[pd.isna(colors)].tolist()
                )

        self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True)

        handles, labels = ax.get_legend_handles_labels()
        if self._horizontal:
            # Make legend order match visual stack order
            ax.legend(reversed(handles), reversed(labels))
        else:
            ax.legend()
1✔
507

508
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None):
1✔
509
        """Add a stacked bar chart over subsets when :func:`plot` is called.
510

511
        Used to plot categorical variable distributions within each subset.
512

513
        .. versionadded:: 0.6
514

515
        Parameters
516
        ----------
517
        by : str
518
            Column name within the dataframe for color coding the stacked bars,
519
            containing discrete or categorical values.
520
        sum_over : str, optional
521
            Ordinarily the bars will chart the size of each group. sum_over
522
            may specify a column which will be summed to determine the size
523
            of each bar.
524
        colors : Mapping, list-like, str or callable, optional
525
            The facecolors to use for bars corresponding to each discrete
526
            label, specified as one of:
527

528
            Mapping
529
                Maps from label to matplotlib-compatible color specification.
530
            list-like
531
                A list of matplotlib colors to apply to labels in order.
532
            str
533
                The name of a matplotlib colormap name.
534
            callable
535
                When called with the number of labels, this should return a
536
                list-like of that many colors.  Matplotlib colormaps satisfy
537
                this callable API.
538
            None
539
                Uses the matplotlib default colormap.
540
        elements : int, default=3
541
            Size of the axes counted in number of matrix elements.
542
        title : str, optional
543
            The axis title labelling bar length.
544

545
        Returns
546
        -------
547
        None
548
        """
549
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
550
        #                        list of labels}
551
        self._subset_plots.append(
1✔
552
            {
1✔
553
                "type": "stacked_bars",
1✔
554
                "by": by,
1✔
555
                "sum_over": sum_over,
1✔
556
                "colors": colors,
1✔
557
                "title": title,
1✔
558
                "id": "extra%d" % len(self._subset_plots),
1✔
559
                "elements": elements,
1✔
560
            }
561
        )
562

563
    def add_catplot(self, kind, value=None, elements=3, **kw):
1✔
564
        """Add a seaborn catplot over subsets when :func:`plot` is called.
565

566
        Parameters
567
        ----------
568
        kind : str
569
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
570
        value : str, optional
571
            Column name for the value to plot (i.e. y if
572
            orientation='horizontal'), required if `data` is a DataFrame.
573
        elements : int, default=3
574
            Size of the axes counted in number of matrix elements.
575
        **kw : dict
576
            Additional keywords to pass to :func:`seaborn.catplot`.
577

578
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
579
            and 'orient', so these are prohibited keys in `kw`.
580

581
        Returns
582
        -------
583
        None
584
        """
585
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
586
        if value is None:
587
            if "_value" not in self._df.columns:
588
                raise ValueError(
589
                    "value cannot be set if data is a Series. " "Got %r" % value
590
                )
591
        else:
592
            if value not in self._df.columns:
593
                raise ValueError("value %r is not a column in data" % value)
594
        self._subset_plots.append(
595
            {
596
                "type": "catplot",
597
                "value": value,
598
                "kind": kind,
599
                "id": "extra%d" % len(self._subset_plots),
600
                "elements": elements,
601
                "kw": kw,
602
            }
603
        )
604

605
    def _check_value(self, value):
1✔
606
        if value is None and "_value" in self._df.columns:
607
            value = "_value"
608
        elif value is None:
609
            raise ValueError("value can only be None when data is a Series")
610
        return value
611

612
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of ``value`` per subset onto ax

        ``kw`` is copied, then completed with the 'orient', 'x', 'y' and
        'ax' keywords before being passed to seaborn.
        """
        df = self._df
        value = self._check_value(value)
        kw = kw.copy()
        # Subsets ("_bin") run along x when horizontal and along y otherwise.
        if self._horizontal:
            kw.update(orient="v", x="_bin", y=value)
        else:
            kw.update(orient="h", x=value, y="_bin")
        import seaborn

        kw["ax"] = ax
        plot_func = getattr(seaborn, kind + "plot")
        plot_func(data=df, **kw)

        ax = self._reorient(ax)
        if value == "_value":
            ax.set_ylabel("")

        ax.xaxis.set_visible(False)
        for side in ["top", "bottom", "right"]:
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
639

640
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : matplotlib figure, optional
            Figure to lay out. Defaults to the current figure
            (``plt.gcf()``).

        Returns
        -------
        dict
            Maps "matrix", "shading" and "totals", plus the ``id`` of every
            registered subset plot, to a slice of the grid; "gs" maps to
            the grid itself.

        Notes
        -----
        When an element size is set, the figure's width and height are
        changed to fit the grid.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        def _window_extent_kwargs(figure):
            # Keyword arguments for get_window_extent(): an explicit
            # renderer when get_renderer could be imported, nothing
            # otherwise.
            if not RENDERER_IMPORTED:
                return {}
            # catch_warnings() takes keyword arguments only, so the filter
            # that silences DeprecationWarning is installed inside the
            # context instead of being passed positionally.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                return {"renderer": get_renderer(figure)}

        # Determine text size to determine figure size / spacing
        text_kw = {"size": matplotlib.rcParams["xtick.labelsize"]}
        # adding "x" ensures a margin
        t = fig.text(
            0,
            0,
            "\n".join(str(label) + "x" for label in self.totals.index.values),
            **text_kw,
        )
        textw = t.get_window_extent(**_window_extent_kwargs(fig)).width
        t.remove()

        figw = self._reorient(
            fig.get_window_extent(**_window_extent_kwargs(fig))
        ).width

        sizes = np.asarray([p["elements"] for p in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = len(self.intersections) + self._totals_plot_elements
        if self._element_size is None:
            # Fit the columns into the space left over by the label text.
            colw = (figw - textw) / non_text_nelems
        else:
            # element_size is in points (1/72 inch); render_ratio converts
            # figure-width units into window-extent units.
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))
        # print('textw', textw, 'figw', figw, 'colw', colw,
        #       'ncols', figw/colw, 'text_nelems', text_nelems)

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(
            *self._swapaxes(
                n_cats + (sizes.sum() or 0),
                n_inters + text_nelems + self._totals_plot_elements,
            ),
            hspace=1,
        )
        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": gridspec[-n_cats:, : self._totals_plot_elements],
                "gs": gridspec,
            }
            # Subset plots are stacked in reverse registration order,
            # starting from the first grid row.
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots[::-1]
            ):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": gridspec[: self._totals_plot_elements, :n_cats],
                "gs": gridspec,
            }
            # Subset plots occupy the columns after the category columns.
            cumsizes = np.cumsum(sizes)
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots
            ):
                out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats]
        return out
720

721
    def plot_matrix(self, ax):
        """Plot the matrix of intersection indicators onto ax

        One dot is drawn for every (subset, category) pair. Dots for
        categories included in the subset use that subset's style; the
        others use the inactive-dot colour. When lines are enabled, the
        included dots of each subset are joined by a line.
        """
        ax = self._reorient(ax)
        data = self.intersections
        n_cats = data.index.nlevels

        inclusion = data.index.to_frame().values

        # Prepare styling: one style dict per dot, subset-major, matching
        # the order of the x / y coordinates built below.
        inactive_style = {"facecolor": self._other_dots_color, "linewidth": 0}
        styles = [
            self.subset_styles[i] if inclusion[i, j] else inactive_style
            for i in range(len(data))
            for j in range(n_cats)
        ]
        style_columns = {
            "facecolor": "facecolors",
            "edgecolor": "edgecolors",
            "linewidth": "linewidths",
            "linestyle": "linestyles",
            "hatch": "hatch",
        }
        styles = (
            pd.DataFrame(styles)
            .reindex(columns=style_columns.keys())
            .astype(
                {
                    "facecolor": "O",
                    "edgecolor": "O",
                    "linewidth": float,
                    "linestyle": "O",
                    "hatch": "O",
                }
            )
        )
        # Assign the filled columns back: fillna(inplace=True) on a selected
        # column is chained assignment, which pandas warns about and which
        # does not update the frame under Copy-on-Write.
        styles["linewidth"] = styles["linewidth"].fillna(1)
        styles["facecolor"] = styles["facecolor"].fillna(self._facecolor)
        styles["edgecolor"] = styles["edgecolor"].fillna(styles["facecolor"])
        styles["linestyle"] = styles["linestyle"].fillna("solid")
        del styles["hatch"]  # not supported in matrix (currently)

        x = np.repeat(np.arange(len(data)), n_cats)
        y = np.tile(np.arange(n_cats), len(data))

        # Plot dots
        if self._element_size is not None:
            s = (self._element_size * 0.35) ** 2
        else:
            # TODO: make s relative to colw
            s = 200
        ax.scatter(
            *self._swapaxes(x, y),
            s=s,
            zorder=10,
            **styles.rename(columns=style_columns),
        )

        # Plot lines
        if self._with_lines:
            idx = np.flatnonzero(inclusion)
            # Lowest and highest included category position per subset.
            line_data = (
                pd.Series(y[idx], index=x[idx])
                .groupby(level=0)
                .aggregate(["min", "max"])
            )
            line_colors = pd.Series(
                [
                    style.get("edgecolor", style.get("facecolor", self._facecolor))
                    for style in self.subset_styles
                ],
                name="color",
            )
            line_data = line_data.join(line_colors)
            ax.vlines(
                line_data.index.values,
                line_data["min"],
                line_data["max"],
                lw=2,
                colors=line_data["color"],
                zorder=5,
            )

        # Ticks and axes
        tick_axis = ax.yaxis
        tick_axis.set_ticks(np.arange(n_cats))
        tick_axis.set_ticklabels(
            data.index.names, rotation=0 if self._horizontal else -90
        )
        ax.xaxis.set_ticks([])
        ax.tick_params(axis="both", which="both", length=0)
        if not self._horizontal:
            ax.yaxis.set_ticks_position("top")
        ax.set_frame_on(False)
        ax.set_xlim(-0.5, x[-1] + 0.5, auto=False)
        ax.grid(False)
1✔
820

821
    def plot_intersections(self, ax):
        """Draw the intersection-size bar chart on ``ax``.

        The bars are produced by ``self._plot_bars`` using the default face
        colour and are then restyled one by one from ``self.subset_styles``.
        When ``self.subset_legend`` is non-empty, a legend built from its
        ``(style, label)`` pairs is attached to ``ax``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw the bars on.
        """
        bars = self._plot_bars(
            ax, self.intersections, title="Intersection size", colors=self._facecolor
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            # Work on a copy so the entries of self.subset_styles stay untouched.
            props = subset_style.copy()
            if "edgecolor" not in props:
                # An unspecified edge colour follows the face colour.
                props["edgecolor"] = props.get("facecolor", self._facecolor)
            for prop_name, prop_value in props.items():
                setter = getattr(bar, "set_" + prop_name)
                setter(prop_value)

        if not self.subset_legend:
            return

        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**legend_style) for legend_style in legend_styles]
        ax.legend(handles, legend_labels)
836

837
    def _label_sizes(self, ax, rects, where):
838
        if not self._show_counts and not self._show_percentages:
839
            return
840
        if self._show_counts is True:
841
            count_fmt = "{:.0f}"
842
        else:
843
            count_fmt = self._show_counts
844
            if "{" not in count_fmt:
845
                count_fmt = util.to_new_pos_format(count_fmt)
846

847
        if self._show_percentages is True:
848
            pct_fmt = "{:.1%}"
849
        else:
850
            pct_fmt = self._show_percentages
851

852
        if count_fmt and pct_fmt:
853
            if where == "top":
854
                fmt = "%s\n(%s)" % (count_fmt, pct_fmt)
855
            else:
856
                fmt = "%s (%s)" % (count_fmt, pct_fmt)
857

858
            def make_args(val):
859
                return val, val / self.total
860
        elif count_fmt:
861
            fmt = count_fmt
862

863
            def make_args(val):
864
                return (val,)
865
        else:
866
            fmt = pct_fmt
867

868
            def make_args(val):
869
                return (val / self.total,)
870

871
        if where == "right":
872
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
873
            for rect in rects:
874
                width = rect.get_width() + rect.get_x()
875
                ax.text(
876
                    width + margin,
877
                    rect.get_y() + rect.get_height() * 0.5,
878
                    fmt.format(*make_args(width)),
879
                    ha="left",
880
                    va="center",
881
                )
882
        elif where == "left":
883
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
884
            for rect in rects:
885
                width = rect.get_width() + rect.get_x()
886
                ax.text(
887
                    width + margin,
888
                    rect.get_y() + rect.get_height() * 0.5,
889
                    fmt.format(*make_args(width)),
890
                    ha="right",
891
                    va="center",
892
                )
893
        elif where == "top":
894
            margin = 0.01 * abs(np.diff(ax.get_ylim()))
895
            for rect in rects:
896
                height = rect.get_height() + rect.get_y()
897
                ax.text(
898
                    rect.get_x() + rect.get_width() * 0.5,
899
                    height + margin,
900
                    fmt.format(*make_args(height)),
901
                    ha="center",
902
                    va="bottom",
903
                )
904
        else:
905
            raise NotImplementedError("unhandled where: %r" % where)
906

907
    def plot_totals(self, ax):
        """Draw the total size of each category as bars on ``ax``.

        The axes are passed through ``self._reorient`` before drawing, one
        bar is drawn per entry of ``self.totals``, and the bars are labelled
        via ``self._label_sizes``.  Afterwards the axes are decluttered: three
        spines, the y ticks and the background patch are hidden, and only the
        x grid is kept.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw the totals on.
        """
        original_ax = ax
        ax = self._reorient(ax)

        positions = np.arange(len(self.totals.index.values))
        bars = ax.barh(
            positions, self.totals, 0.5, color=self._facecolor, align="center"
        )
        label_side = "left" if self._horizontal else "top"
        self._label_sizes(ax, bars, label_side)

        largest_total = self.totals.max()
        if self._horizontal:
            # Limits are given high-to-low, which reverses the x axis.
            original_ax.set_xlim(largest_total, 0)

        for side in ("top", "left", "right"):
            ax.spines[self._reorient(side)].set_visible(False)

        # Keep the y axis itself but strip its tick labels and ticks.
        ax.yaxis.set_visible(True)
        ax.yaxis.set_ticklabels([])
        ax.yaxis.set_ticks([])

        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)
1✔
931

932
    def plot_shading(self, ax):
        """Shade every second row on ``ax`` and remove all axes decoration.

        One rectangle is added for each even index into ``self.totals``,
        filled with ``self._shading_color``.  The frame, ticks, tick labels
        and grid of ``ax`` are then switched off.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw the background bands on.
        """
        # alternating row shading (XXX: use add_patch(Rectangle)?)
        n_rows = len(self.totals)
        for row in range(0, n_rows, 2):
            # Before any axis swap the band starts at (0, row - 0.4) and is
            # 1 wide by 0.8 high, i.e. centred on ``row``.
            anchor = self._swapaxes(0, row - 0.4)
            extent = self._swapaxes(1, 0.8)
            band = plt.Rectangle(
                anchor,
                *extent,
                facecolor=self._shading_color,
                lw=0,
                zorder=0,
            )
            ax.add_patch(band)

        ax.set_frame_on(False)
        hidden = dict.fromkeys(
            ("left", "right", "bottom", "top", "labelbottom", "labelleft"), False
        )
        ax.tick_params(axis="both", which="both", **hidden)
        ax.grid(False)
        for clear in (
            ax.set_xticks,
            ax.set_yticks,
            ax.set_xticklabels,
            ax.set_yticklabels,
        ):
            clear([])
1✔
959

960
    def plot(self, fig=None) -> PlotReturnType:
        """Render every part of the UpSet plot onto a figure.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw on.  A new figure of the default size is created
            when omitted.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Always holds 'matrix', 'shading' and 'totals'.  Each configured
            subset plot (e.g. 'intersections') is added under its own id.

        Raises
        ------
        ValueError
            If a subset plot's type is neither "default" nor a key of
            ``PLOT_TYPES``.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)

        def add_axes(spec_key, **shared):
            # Every panel except the shading goes through the reoriented
            # ``add_subplot``.
            return self._reorient(fig.add_subplot)(specs[spec_key], **shared)

        shading_ax = fig.add_subplot(specs["shading"])
        self.plot_shading(shading_ax)

        matrix_ax = add_axes("matrix", sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_ax = add_axes("totals", sharey=matrix_ax)
        self.plot_totals(totals_ax)

        axes_by_name = {
            "matrix": matrix_ax,
            "shading": shading_ax,
            "totals": totals_ax,
        }

        for subset_plot in self._subset_plots:
            ax = add_axes(subset_plot["id"], sharex=matrix_ax)
            plot_type = subset_plot["type"]
            if plot_type == "default":
                self.plot_intersections(ax)
            elif plot_type in self.PLOT_TYPES:
                # Everything left after removing the bookkeeping keys is
                # forwarded to the plotting function as keyword arguments.
                extra = subset_plot.copy()
                for consumed in ("type", "elements", "id"):
                    del extra[consumed]
                self.PLOT_TYPES[plot_type](self, ax, **extra)
            else:
                raise ValueError("Unknown subset plot type: %r" % plot_type)
            axes_by_name[subset_plot["id"]] = ax
        return axes_by_name
1✔
998

999
    # Registry of non-default subset plot types.  ``plot`` looks up a subset
    # plot's "type" here and calls the mapped function as
    # ``func(self, ax, **kw)``, where ``kw`` is the plot spec without its
    # "type", "elements" and "id" keys.
    PLOT_TYPES = {
1✔
1000
        "catplot": _plot_catplot,
1✔
1001
        "stacked_bars": _plot_stacked_bars,
1✔
1002
    }
1003

1004
    def _repr_html_(self):
        """Plot onto a new default-sized figure and return its HTML repr.

        Returns whatever the figure's own ``_repr_html_`` produces.
        NOTE(review): that result presumably depends on the active
        matplotlib backend -- confirm.
        """
        canvas = plt.figure(figsize=self._default_figsize)
        self.plot(fig=canvas)
        html = canvas._repr_html_()
        return html
1008

1009

1010
def plot(data, fig=None, **kwargs) -> PlotReturnType:
    """Draw an UpSet plot of ``data`` in a single call.

    Convenience wrapper that builds an :class:`UpSet` from ``data`` and
    ``kwargs`` and immediately plots it onto ``fig``.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Figure to draw on.  Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc