• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7364335499

30 Dec 2023 01:16PM UTC coverage: 99.263%. Remained the same
7364335499

Pull #263

github

jnothman
eol
Pull Request #263: Improve docstring for totals_plot_elements=0

1750 of 1763 relevant lines covered (99.26%)

1.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/upsetplot/plotting.py
1
import typing
2✔
2
import warnings
2✔
3

4
import matplotlib
2✔
5
import numpy as np
2✔
6
import pandas as pd
2✔
7
from matplotlib import colors, patches
2✔
8
from matplotlib import pyplot as plt
2✔
9

10
from . import util
2✔
11
from .reformat import _get_subset_mask, query
2✔
12

13
# prevents ImportError on matplotlib versions >3.5.2
14
try:
2✔
15
    from matplotlib.tight_layout import get_renderer
2✔
16

17
    RENDERER_IMPORTED = True
1✔
18
except ImportError:
1✔
19
    RENDERER_IMPORTED = False
1✔
20

21

22
def _process_data(
    df,
    *,
    sort_by,
    sort_categories_by,
    subset_size,
    sum_over,
    min_subset_size=None,
    max_subset_size=None,
    max_subset_rank=None,
    min_degree=None,
    max_degree=None,
    reverse=False,
    include_empty_subsets=False,
):
    """Query the input data and tag each row with its subset position.

    All filtering/sorting parameters are forwarded unchanged to
    :func:`query`. ``reverse`` additionally flips the order of the returned
    subset sizes and of the ``_bin`` numbering.

    Returns
    -------
    total
        ``results.total`` from :func:`query`.
    data : pandas.DataFrame
        ``results.data`` with an added ``_bin`` column holding, for each row,
        the position of that row's subset within the returned subset sizes.
    subset_sizes
        ``results.subset_sizes``, reversed when ``reverse`` is true.
    category_totals
        ``results.category_totals`` from :func:`query`.
    """
    results = query(
        df,
        sort_by=sort_by,
        sort_categories_by=sort_categories_by,
        subset_size=subset_size,
        sum_over=sum_over,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        max_subset_rank=max_subset_rank,
        min_degree=min_degree,
        max_degree=max_degree,
        include_empty_subsets=include_empty_subsets,
    )

    data = results.data
    subset_sizes = results.subset_sizes

    def _pack_binary(frame):
        # Fold each row of membership flags into a single integer key
        # (first column is the most significant bit).
        frame = pd.DataFrame(frame)
        # Fall back to Python ints (object dtype) when there are too many
        # flags for a fixed-width integer.
        key_dtype = np.object_ if frame.shape[1] > 62 else np.uint64
        packed = pd.Series(0, index=frame.index, dtype=key_dtype)
        for _, flags in frame.items():
            packed *= 2
            packed += flags
        return packed

    row_keys = _pack_binary(data.index.to_frame())
    subset_keys = _pack_binary(subset_sizes.index.to_frame())

    positions = np.arange(len(subset_keys))
    if reverse:
        positions = positions[::-1]
    position_of_key = pd.Series(positions, index=subset_keys)
    data["_bin"] = pd.Series(row_keys).map(position_of_key)

    if reverse:
        subset_sizes = subset_sizes[::-1]

    return results.total, data, subset_sizes, results.category_totals
77

78

79
def _multiply_alpha(c, mult):
    """Return color *c* as an RGBA hex string with its alpha scaled by *mult*."""
    *rgb, alpha = colors.to_rgba(c)
    return colors.to_hex((*rgb, alpha * mult), keep_alpha=True)
83

84

85
class _Transposed:
2✔
86
    """Wrap an object in order to transpose some plotting operations
87

88
    Attributes of obj will be mapped.
89
    Keyword arguments when calling obj will be mapped.
90

91
    The mapping is not recursive: callable attributes need to be _Transposed
92
    again.
93
    """
94

95
    def __init__(self, obj):
2✔
96
        self.__obj = obj
2✔
97

98
    def __getattr__(self, key):
2✔
99
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))
2✔
100

101
    def __call__(self, *args, **kwargs):
2✔
102
        return self.__obj(
2✔
103
            *args, **{self._NAME_TRANSPOSE.get(k, k): v for k, v in kwargs.items()}
2✔
104
        )
105

106
    _NAME_TRANSPOSE = {
2✔
107
        "width": "height",
2✔
108
        "height": "width",
2✔
109
        "hspace": "wspace",
2✔
110
        "wspace": "hspace",
2✔
111
        "hlines": "vlines",
2✔
112
        "vlines": "hlines",
2✔
113
        "bar": "barh",
2✔
114
        "barh": "bar",
2✔
115
        "xaxis": "yaxis",
2✔
116
        "yaxis": "xaxis",
2✔
117
        "left": "bottom",
2✔
118
        "right": "top",
2✔
119
        "top": "right",
2✔
120
        "bottom": "left",
2✔
121
        "sharex": "sharey",
2✔
122
        "sharey": "sharex",
2✔
123
        "get_figwidth": "get_figheight",
2✔
124
        "get_figheight": "get_figwidth",
2✔
125
        "set_figwidth": "set_figheight",
2✔
126
        "set_figheight": "set_figwidth",
2✔
127
        "set_xlabel": "set_ylabel",
2✔
128
        "set_ylabel": "set_xlabel",
2✔
129
        "set_xlim": "set_ylim",
2✔
130
        "set_ylim": "set_xlim",
2✔
131
        "get_xlim": "get_ylim",
2✔
132
        "get_ylim": "get_xlim",
2✔
133
        "set_autoscalex_on": "set_autoscaley_on",
2✔
134
        "set_autoscaley_on": "set_autoscalex_on",
2✔
135
    }
136

137

138
def _transpose(obj):
    """Transpose *obj*: translate it if it is a name, otherwise wrap it.

    Strings are looked up in ``_Transposed._NAME_TRANSPOSE`` (unknown names
    are returned unchanged); any other object is wrapped in a
    :class:`_Transposed` proxy.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
142

143

144
def _identity(obj):
2✔
145
    return obj
2✔
146

147

148
class UpSet:
2✔
149
    """Manage the data and drawing for a basic UpSet plot
150

151
    Primary public method is :meth:`plot`.
152

153
    Parameters
154
    ----------
155
    data : pandas.Series or pandas.DataFrame
156
        Elements associated with categories (a DataFrame), or the size of each
157
        subset of categories (a Series).
158
        Should have MultiIndex where each level is binary,
159
        corresponding to category membership.
160
        If a DataFrame, `sum_over` must be a string or False.
161
    orientation : {'horizontal' (default), 'vertical'}
162
        If horizontal, intersections are listed from left to right.
163
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
164
               'input', '-input'}
165
        If 'cardinality', subset are listed from largest to smallest.
166
        If 'degree', they are listed in order of the number of categories
167
        intersected. If 'input', the order they appear in the data input is
168
        used.
169
        Prefix with '-' to reverse the ordering.
170

171
        Note this affects ``subset_sizes`` but not ``data``.
172
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
173
        Whether to sort the categories by total cardinality, or leave them
174
        in the input data's provided order (order of index levels).
175
        Prefix with '-' to reverse the ordering.
176
    subset_size : {'auto', 'count', 'sum'}
177
        Configures how to calculate the size of a subset. Choices are:
178

179
        'auto' (default)
180
            If `data` is a DataFrame, count the number of rows in each group,
181
            unless `sum_over` is specified.
182
            If `data` is a Series with at most one row for each group, use
183
            the value of the Series. If `data` is a Series with more than one
184
            row per group, raise a ValueError.
185
        'count'
186
            Count the number of rows in each group.
187
        'sum'
188
            Sum the value of the `data` Series, or the DataFrame field
189
            specified by `sum_over`.
190
    sum_over : str or None
191
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
192
        sum of the specified field in the `data` DataFrame. If a Series, only
193
        None is supported and its value is summed.
194
    min_subset_size : int, optional
195
        Minimum size of a subset to be shown in the plot. All subsets with
196
        a size smaller than this threshold will be omitted from plotting.
197
        Size may be a sum of values, see `subset_size`.
198

199
        .. versionadded:: 0.5
200
    max_subset_size : int, optional
201
        Maximum size of a subset to be shown in the plot. All subsets with
202
        a size greater than this threshold will be omitted from plotting.
203

204
        .. versionadded:: 0.5
205
    max_subset_rank : int, optional
206
        Limit to the top N ranked subsets in descending order of size.
207
        All tied subsets are included.
208

209
        .. versionadded:: 0.9
210
    min_degree : int, optional
211
        Minimum degree of a subset to be shown in the plot.
212

213
        .. versionadded:: 0.5
214
    max_degree : int, optional
215
        Maximum degree of a subset to be shown in the plot.
216

217
        .. versionadded:: 0.5
218
    facecolor : 'auto' or matplotlib color or float
219
        Color for bar charts and active dots. Defaults to black if
220
        axes.facecolor is a light color, otherwise white.
221

222
        .. versionchanged:: 0.6
223
            Before 0.6, the default was 'black'
224
    other_dots_color : matplotlib color or float
225
        Color for shading of inactive dots, or opacity (between 0 and 1)
226
        applied to facecolor.
227

228
        .. versionadded:: 0.6
229
    shading_color : matplotlib color or float
230
        Color for shading of odd rows in matrix and totals, or opacity (between
231
        0 and 1) applied to facecolor.
232

233
        .. versionadded:: 0.6
234
    with_lines : bool
235
        Whether to show lines joining dots in the matrix, to mark multiple
236
        categories being intersected.
237
    element_size : float or None
238
        Side length in pt. If None, size is estimated to fit figure
239
    intersection_plot_elements : int
240
        The intersections plot should be large enough to fit this many matrix
241
        elements. Set to 0 to disable intersection size bars.
242

243
        .. versionchanged:: 0.4
244
            Setting to 0 is handled.
245
    totals_plot_elements : int
246
        The totals plot should be large enough to fit this many matrix
247
        elements. Set to 0 to disable the totals plot.
248

249
        .. versionchanged:: 0.9
250
            Setting to 0 is handled.
251
    show_counts : bool or str, default=False
252
        Whether to label the intersection size bars with the cardinality
253
        of the intersection. When a string, this formats the number.
254
        For example, '{:d}' is equivalent to True.
255
        Note that, for legacy reasons, if the string does not contain '{',
256
        it will be interpreted as a C-style format string, such as '%d'.
257
    show_percentages : bool or str, default=False
258
        Whether to label the intersection size bars with the percentage
259
        of the intersection relative to the total dataset.
260
        When a string, this formats the number representing a fraction of
261
        samples.
262
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
263
        This may be applied with or without show_counts.
264

265
        .. versionadded:: 0.4
266
    include_empty_subsets : bool (default=False)
267
        If True, all possible category combinations will be shown as subsets,
268
        even when some are not present in data.
269
    """
270

271
    _default_figsize = (10, 6)
2✔
272

273
    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        # Parameters are documented in the class docstring.
        horizontal = orientation == "horizontal"
        self._horizontal = horizontal
        self._reorient = _identity if horizontal else _transpose

        if facecolor == "auto":
            # Choose a foreground contrasting with the axes background.
            background = matplotlib.rcParams.get("axes.facecolor", "white")
            red, green, blue, alpha = colors.to_rgba(background)
            lightness = colors.rgb_to_hsv((red, green, blue))[-1] * alpha
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor

        # A float is an opacity applied to facecolor; anything else is used
        # directly as a color.
        if isinstance(shading_color, float):
            shading_color = _multiply_alpha(facecolor, shading_color)
        self._shading_color = shading_color
        if isinstance(other_dots_color, float):
            other_dots_color = _multiply_alpha(facecolor, other_dots_color)
        self._other_dots_color = other_dots_color

        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements

        # The intersection-size bars are registered as a subset plot, and
        # omitted entirely when given zero elements.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append(
                {
                    "type": "default",
                    "id": "intersections",
                    "elements": intersection_plot_elements,
                }
            )
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        processed = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        self.total, self._df, self.intersections, self.totals = processed

        n_subsets = len(self.intersections)
        self.subset_styles = [{"facecolor": facecolor} for _ in range(n_subsets)]
        self.subset_legend = []  # (style, label) pairs
348

349
    def _swapaxes(self, x, y):
2✔
350
        if self._horizontal:
2✔
351
            return x, y
2✔
352
        return y, x
2✔
353

354
    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Updates the style of selected subsets' bars and matrix dots

        Selection parameters choose which subsets are affected; style
        parameters are :class:`matplotlib.patches.Patch` attributes applied
        to them. ``label``, if given, records a legend entry for the style.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        max_subset_rank : int, optional
            Limit to the top N ranked subsets in descending order of size.
            All tied subsets are included.

            .. versionadded:: 0.9
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend entry is recorded for this style. Labels of
            identical styles are merged, separated by "; ".
        """
        # Only explicitly-given style attributes override existing ones.
        style = {
            key: value
            for key, value in (
                ("facecolor", facecolor),
                ("edgecolor", edgecolor),
                ("hatch", hatch),
                ("linewidth", linewidth),
                ("linestyle", linestyle),
            )
            if value is not None
        }
        mask = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for position in np.flatnonzero(mask):
            self.subset_styles[position].update(style)

        if label is None:
            return

        # The legend swatch needs a facecolor even if none was given.
        style.setdefault("facecolor", self._facecolor)
        for position, (existing_style, existing_label) in enumerate(
            self.subset_legend
        ):
            if existing_style != style:
                continue
            if existing_label != label:
                self.subset_legend[position] = (
                    style,
                    existing_label + "; " + label,
                )
            break
        else:
            self.subset_legend.append((style, label))
443

444
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
2✔
445
        ax = self._reorient(ax)
2✔
446
        ax.set_autoscalex_on(False)
2✔
447
        data_df = pd.DataFrame(data)
2✔
448
        if self._horizontal:
2✔
449
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
2✔
450

451
        # TODO: colors should be broadcastable to data_df shape
452
        if callable(colors):
2✔
453
            colors = colors(range(data_df.shape[1]))
2✔
454
        elif isinstance(colors, (str, type(None))):
2✔
455
            colors = [colors] * len(data_df)
2✔
456

457
        if self._horizontal:
2✔
458
            colors = reversed(colors)
2✔
459

460
        x = np.arange(len(data_df))
2✔
461
        cum_y = None
2✔
462
        all_rects = []
2✔
463
        for (name, y), color in zip(data_df.items(), colors):
2✔
464
            rects = ax.bar(
2✔
465
                x,
2✔
466
                y,
2✔
467
                0.5,
2✔
468
                cum_y,
2✔
469
                color=color,
2✔
470
                zorder=10,
2✔
471
                label=name if use_labels else None,
2✔
472
                align="center",
2✔
473
            )
474
            cum_y = y if cum_y is None else cum_y + y
2✔
475
            all_rects.extend(rects)
2✔
476

477
        self._label_sizes(ax, rects, "top" if self._horizontal else "right")
2✔
478

479
        ax.xaxis.set_visible(False)
2✔
480
        for x in ["top", "bottom", "right"]:
2✔
481
            ax.spines[self._reorient(x)].set_visible(False)
2✔
482

483
        tick_axis = ax.yaxis
2✔
484
        tick_axis.grid(True)
2✔
485
        ax.set_ylabel(title)
2✔
486
        return all_rects
2✔
487

488
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
2✔
489
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
2✔
490
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
2✔
491
        if sum_over is None and "_value" in df.columns:
2✔
492
            data = gb["_value"].sum()
493
        elif sum_over is None:
2✔
494
            data = gb.size()
2✔
495
        else:
496
            data = gb[sum_over].sum()
2✔
497
        data = data.unstack(by).fillna(0)
2✔
498
        if isinstance(colors, str):
2✔
499
            colors = matplotlib.cm.get_cmap(colors)
2✔
500
        elif isinstance(colors, typing.Mapping):
2✔
501
            colors = data.columns.map(colors).values
2✔
502
            if pd.isna(colors).any():
2✔
503
                raise KeyError(
504
                    "Some labels mapped by colors: %r"
505
                    % data.columns[pd.isna(colors)].tolist()
506
                )
507

508
        self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True)
2✔
509

510
        handles, labels = ax.get_legend_handles_labels()
2✔
511
        if self._horizontal:
2✔
512
            # Make legend order match visual stack order
513
            ax.legend(reversed(handles), reversed(labels))
2✔
514
        else:
515
            ax.legend()
2✔
516

517
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None):
2✔
518
        """Add a stacked bar chart over subsets when :func:`plot` is called.
519

520
        Used to plot categorical variable distributions within each subset.
521

522
        .. versionadded:: 0.6
523

524
        Parameters
525
        ----------
526
        by : str
527
            Column name within the dataframe for color coding the stacked bars,
528
            containing discrete or categorical values.
529
        sum_over : str, optional
530
            Ordinarily the bars will chart the size of each group. sum_over
531
            may specify a column which will be summed to determine the size
532
            of each bar.
533
        colors : Mapping, list-like, str or callable, optional
534
            The facecolors to use for bars corresponding to each discrete
535
            label, specified as one of:
536

537
            Mapping
538
                Maps from label to matplotlib-compatible color specification.
539
            list-like
540
                A list of matplotlib colors to apply to labels in order.
541
            str
542
                The name of a matplotlib colormap name.
543
            callable
544
                When called with the number of labels, this should return a
545
                list-like of that many colors.  Matplotlib colormaps satisfy
546
                this callable API.
547
            None
548
                Uses the matplotlib default colormap.
549
        elements : int, default=3
550
            Size of the axes counted in number of matrix elements.
551
        title : str, optional
552
            The axis title labelling bar length.
553

554
        Returns
555
        -------
556
        None
557
        """
558
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
559
        #                        list of labels}
560
        self._subset_plots.append(
2✔
561
            {
2✔
562
                "type": "stacked_bars",
2✔
563
                "by": by,
2✔
564
                "sum_over": sum_over,
2✔
565
                "colors": colors,
2✔
566
                "title": title,
2✔
567
                "id": "extra%d" % len(self._subset_plots),
2✔
568
                "elements": elements,
2✔
569
            }
570
        )
571

572
    def add_catplot(self, kind, value=None, elements=3, **kw):
2✔
573
        """Add a seaborn catplot over subsets when :func:`plot` is called.
574

575
        Parameters
576
        ----------
577
        kind : str
578
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
579
        value : str, optional
580
            Column name for the value to plot (i.e. y if
581
            orientation='horizontal'), required if `data` is a DataFrame.
582
        elements : int, default=3
583
            Size of the axes counted in number of matrix elements.
584
        **kw : dict
585
            Additional keywords to pass to :func:`seaborn.catplot`.
586

587
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
588
            and 'orient', so these are prohibited keys in `kw`.
589

590
        Returns
591
        -------
592
        None
593
        """
594
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
595
        if value is None:
596
            if "_value" not in self._df.columns:
597
                raise ValueError(
598
                    "value cannot be set if data is a Series. " "Got %r" % value
599
                )
600
        else:
601
            if value not in self._df.columns:
602
                raise ValueError("value %r is not a column in data" % value)
603
        self._subset_plots.append(
604
            {
605
                "type": "catplot",
606
                "value": value,
607
                "kind": kind,
608
                "id": "extra%d" % len(self._subset_plots),
609
                "elements": elements,
610
                "kw": kw,
611
            }
612
        )
613

614
    def _check_value(self, value):
2✔
615
        if value is None and "_value" in self._df.columns:
616
            value = "_value"
617
        elif value is None:
618
            raise ValueError("value can only be None when data is a Series")
619
        return value
620

621
    def _plot_catplot(self, ax, value, kind, kw):
2✔
622
        df = self._df
623
        value = self._check_value(value)
624
        kw = kw.copy()
625
        if self._horizontal:
626
            kw["orient"] = "v"
627
            kw["x"] = "_bin"
628
            kw["y"] = value
629
        else:
630
            kw["orient"] = "h"
631
            kw["x"] = value
632
            kw["y"] = "_bin"
633
        import seaborn
634

635
        kw["ax"] = ax
636
        getattr(seaborn, kind + "plot")(data=df, **kw)
637

638
        ax = self._reorient(ax)
639
        if value == "_value":
640
            ax.set_ylabel("")
641

642
        ax.xaxis.set_visible(False)
643
        for x in ["top", "bottom", "right"]:
644
            ax.spines[self._reorient(x)].set_visible(False)
645

646
        tick_axis = ax.yaxis
647
        tick_axis.grid(True)
648

649
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Returns
        -------
        dict
            Maps "matrix", "shading", "totals" (None when the totals plot is
            disabled), "gs" (the GridSpec itself) and each subset plot's "id"
            to its region of the grid.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        def renderer_kwargs(figure):
            # When get_renderer could be imported, pass an explicit renderer
            # to get_window_extent, silencing its DeprecationWarning.
            kwargs = {}
            if RENDERER_IMPORTED:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", DeprecationWarning)
                    kwargs["renderer"] = get_renderer(figure)
            return kwargs

        # Measure category label width; the appended "x" ensures a margin.
        label_text = "\n".join(
            str(label) + "x" for label in self.totals.index.values
        )
        probe = fig.text(
            0, 0, label_text, size=matplotlib.rcParams["xtick.labelsize"]
        )
        textw = probe.get_window_extent(**renderer_kwargs(fig)).width
        probe.remove()

        figw = self._reorient(fig.get_window_extent(**renderer_kwargs(fig))).width

        sizes = np.asarray([plot["elements"] for plot in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = n_inters + self._totals_plot_elements
        if self._element_size is None:
            colw = (figw - textw) / non_text_nelems
        else:
            # element_size is in points (1/72 inch); resize the figure to fit.
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(
            *self._swapaxes(
                # "or 0" turns the float 0.0 sum of an empty array into an int
                n_cats + (sizes.sum() or 0),
                n_inters + text_nelems + self._totals_plot_elements,
            ),
            hspace=1,
        )
        n_totals = self._totals_plot_elements
        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": None if n_totals == 0 else gridspec[-n_cats:, :n_totals],
                "gs": gridspec,
            }
            # Subset plots are stacked above the matrix, last-added on top.
            stops = np.cumsum(sizes[::-1])
            starts = np.hstack([[0], stops])
            for start, stop, plot in zip(starts, stops, self._subset_plots[::-1]):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": None if n_totals == 0 else gridspec[:n_totals, :n_cats],
                "gs": gridspec,
            }
            # Subset plots sit to the right of the matrix, in added order.
            stops = np.cumsum(sizes)
            starts = np.hstack([[0], stops])
            for start, stop, plot in zip(starts, stops, self._subset_plots):
                out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats]
        return out
733

734
    def plot_matrix(self, ax):
        """Plot the matrix of intersection indicators onto ax

        Each row of ``self.intersections`` becomes one column of dots, with
        one dot per category (index level). A dot whose category belongs to
        the intersection is styled with that intersection's entry in
        ``self.subset_styles``. Other dots are drawn in
        ``self._other_dots_color`` with zero line width. When
        ``self._with_lines`` is set, the included dots of each intersection
        are joined by a vertical line.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw onto. It is passed through ``self._reorient`` first.
        """
        ax = self._reorient(ax)
        data = self.intersections
        n_cats = data.index.nlevels

        # One row per intersection, one membership column per category.
        inclusion = data.index.to_frame().values

        # Prepare styling: one style dict per (intersection, category) dot,
        # in intersection-major order to match the x/y coordinates below.
        styles = [
            self.subset_styles[i]
            if inclusion[i, j]
            else {"facecolor": self._other_dots_color, "linewidth": 0}
            for i in range(len(data))
            for j in range(n_cats)
        ]
        style_columns = {
            "facecolor": "facecolors",
            "edgecolor": "edgecolors",
            "linewidth": "linewidths",
            "linestyle": "linestyles",
            "hatch": "hatch",
        }
        styles = (
            pd.DataFrame(styles)
            .reindex(columns=style_columns.keys())
            .astype(
                {
                    "facecolor": "O",
                    "edgecolor": "O",
                    "linewidth": float,
                    "linestyle": "O",
                    "hatch": "O",
                }
            )
        )
        # Assign the filled columns back rather than calling
        # ``styles[col].fillna(..., inplace=True)``. That form is chained
        # assignment: it warns under pandas Copy-on-Write and no longer
        # updates ``styles`` once Copy-on-Write is the default.
        styles["linewidth"] = styles["linewidth"].fillna(1)
        styles["facecolor"] = styles["facecolor"].fillna(self._facecolor)
        # Edge colour defaults to the (already filled) face colour.
        styles["edgecolor"] = styles["edgecolor"].fillna(styles["facecolor"])
        styles["linestyle"] = styles["linestyle"].fillna("solid")
        del styles["hatch"]  # not supported in matrix (currently)

        # x: intersection position; y: category position (one per dot).
        x = np.repeat(np.arange(len(data)), n_cats)
        y = np.tile(np.arange(n_cats), len(data))

        # Plot dots
        if self._element_size is not None:  # noqa
            s = (self._element_size * 0.35) ** 2
        else:
            # TODO: make s relative to colw
            s = 200
        ax.scatter(
            *self._swapaxes(x, y),
            s=s,
            zorder=10,
            **styles.rename(columns=style_columns),
        )

        # Plot lines
        if self._with_lines:
            idx = np.flatnonzero(inclusion)
            # Lowest and highest included category per intersection.
            line_data = (
                pd.Series(y[idx], index=x[idx])
                .groupby(level=0)
                .aggregate(["min", "max"])
            )
            # Named so as not to shadow the ``matplotlib.colors`` module.
            line_colors = pd.Series(
                [
                    style.get("edgecolor", style.get("facecolor", self._facecolor))
                    for style in self.subset_styles
                ],
                name="color",
            )
            line_data = line_data.join(line_colors)
            ax.vlines(
                line_data.index.values,
                line_data["min"],
                line_data["max"],
                lw=2,
                colors=line_data["color"],
                zorder=5,
            )

        # Ticks and axes
        tick_axis = ax.yaxis
        tick_axis.set_ticks(np.arange(n_cats))
        tick_axis.set_ticklabels(
            data.index.names, rotation=0 if self._horizontal else -90
        )
        ax.xaxis.set_visible(False)
        ax.tick_params(axis="both", which="both", length=0)
        if not self._horizontal:
            ax.yaxis.set_ticks_position("top")
        ax.set_frame_on(False)
        ax.set_xlim(-0.5, x[-1] + 0.5, auto=False)
        ax.grid(False)
2✔
833

834
    def plot_intersections(self, ax):
        """Plot bars indicating intersection size

        The bars are drawn by ``self._plot_bars`` in ``self._facecolor``.
        Each bar is then restyled from its entry in ``self.subset_styles``:
        every key ``k`` is applied through the bar's ``set_<k>`` method, and
        the edge colour falls back to the face colour when it is not given.
        If ``self.subset_legend`` is non-empty, a legend of patches is added.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw the bars onto.
        """
        bars = self._plot_bars(
            ax, self.intersections, title="Intersection size", colors=self._facecolor
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            props = subset_style.copy()
            if "edgecolor" not in props:
                props["edgecolor"] = props.get("facecolor", self._facecolor)
            for prop_name, prop_value in props.items():
                setter = getattr(bar, "set_" + prop_name)
                setter(prop_value)

        if not self.subset_legend:
            return
        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**patch_kw) for patch_kw in legend_styles]
        ax.legend(handles, legend_labels)
849

850
    def _label_sizes(self, ax, rects, where):
        """Write a count and/or percentage label beside each bar in rects

        Labelling is controlled by ``self._show_counts`` and
        ``self._show_percentages``:

        - ``True`` selects a default format (``"{:.0f}"`` for counts,
          ``"{:.1%}"`` for percentages).
        - A string is used as the format.
        - A false value disables that part of the label.

        A count format without ``{`` is converted with
        ``util.to_new_pos_format``. A percentage is the bar's end value
        divided by ``self.total``. Nothing is drawn if neither is enabled.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes holding the bars; labels are added with ``ax.text``.
        rects : iterable of matplotlib.patches.Rectangle
            The bars to label.
        where : {"right", "left", "top"}
            Side of each bar on which its label is placed.

        Raises
        ------
        NotImplementedError
            If labels are enabled and ``where`` is not a supported value.
        """
        if not self._show_counts and not self._show_percentages:
            return
        if self._show_counts is True:
            count_fmt = "{:.0f}"
        elif self._show_counts:
            count_fmt = self._show_counts
            if "{" not in count_fmt:
                # Presumably an old-style "%" format; convert it for str.format.
                count_fmt = util.to_new_pos_format(count_fmt)
        else:
            # Only percentages requested. Testing ``"{" in`` a non-string
            # (e.g. False) would raise TypeError.
            count_fmt = None

        pct_fmt = "{:.1%}" if self._show_percentages is True else self._show_percentages

        if count_fmt and pct_fmt:
            if where == "top":
                fmt = f"{count_fmt}\n({pct_fmt})"
            else:
                fmt = f"{count_fmt} ({pct_fmt})"

            def make_args(val):
                return val, val / self.total

        elif count_fmt:
            fmt = count_fmt

            def make_args(val):
                return (val,)

        else:
            fmt = pct_fmt

            def make_args(val):
                return (val / self.total,)

        if where in ("right", "left"):
            # Compute a scalar margin. ``np.diff`` returns a length-1 array,
            # which would make every label's x coordinate an ndarray.
            xmin, xmax = ax.get_xlim()
            margin = 0.01 * abs(xmax - xmin)
            ha = "left" if where == "right" else "right"
            for rect in rects:
                width = rect.get_width() + rect.get_x()
                ax.text(
                    width + margin,
                    rect.get_y() + rect.get_height() * 0.5,
                    fmt.format(*make_args(width)),
                    ha=ha,
                    va="center",
                )
        elif where == "top":
            ymin, ymax = ax.get_ylim()
            margin = 0.01 * abs(ymax - ymin)
            for rect in rects:
                height = rect.get_height() + rect.get_y()
                ax.text(
                    rect.get_x() + rect.get_width() * 0.5,
                    height + margin,
                    fmt.format(*make_args(height)),
                    ha="center",
                    va="bottom",
                )
        else:
            raise NotImplementedError("unhandled where: %r" % where)
916

917
    def plot_totals(self, ax):
        """Plot bars indicating total set size

        Draws one bar per entry of ``self.totals`` (via ``barh`` on the
        reoriented axes) and labels it with ``self._label_sizes``. In
        horizontal mode the x-axis is reversed. Spines, y-axis and
        background patch are then hidden, keeping only x grid lines.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw onto; drawing happens on ``self._reorient(ax)``.
        """
        oriented_ax = self._reorient(ax)
        positions = np.arange(len(self.totals.index.values))
        bars = oriented_ax.barh(
            positions,
            self.totals,
            0.5,
            color=self._facecolor,
            align="center",
        )
        label_side = "left" if self._horizontal else "top"
        self._label_sizes(oriented_ax, bars, label_side)

        largest = self.totals.max()
        if self._horizontal:
            # Reverse the (un-reoriented) x-axis: largest total at the left.
            ax.set_xlim(largest, 0)
        for side in ("top", "left", "right"):
            oriented_ax.spines[self._reorient(side)].set_visible(False)
        oriented_ax.yaxis.set_visible(False)
        oriented_ax.xaxis.grid(True)
        oriented_ax.yaxis.grid(False)
        oriented_ax.patch.set_visible(False)
2✔
939

940
    def plot_shading(self, ax):
        """Shade every other category row and strip all axes decoration

        For rows 0, 2, 4, ... of ``self.totals`` a rectangle is added. In
        unswapped coordinates it spans x in [0, 1] and y in
        [row - 0.4, row + 0.4], filled with ``self._shading_color``. Frame,
        ticks, tick labels and grid are then turned off.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw the shading onto.
        """
        # XXX: consider whether the row bands could be drawn more simply.
        n_rows = len(self.totals)
        for row in range(0, n_rows, 2):
            band = plt.Rectangle(
                self._swapaxes(0, row - 0.4),
                *self._swapaxes(1, 0.8),
                facecolor=self._shading_color,
                lw=0,
                zorder=0,
            )
            ax.add_patch(band)

        ax.set_frame_on(False)
        ax.tick_params(
            axis="both",
            which="both",
            left=False,
            right=False,
            bottom=False,
            top=False,
            labelbottom=False,
            labelleft=False,
        )
        ax.grid(False)
        for clear in (
            ax.set_xticks,
            ax.set_yticks,
            ax.set_xticklabels,
            ax.set_yticklabels,
        ):
            clear([])
2✔
967

968
    def plot(self, fig=None):
        """Draw all parts of the plot onto fig or a new figure

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Defaults to a new figure of size ``self._default_figsize``.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'shading', 'totals' and the id of each subset
            plot (the default one is typically 'intersections'). The value
            for 'totals' is None when the grid has no totals panel.

        Raises
        ------
        ValueError
            If a subset plot has a type that is neither "default" nor a key
            of ``PLOT_TYPES``.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)

        shading_ax = fig.add_subplot(specs["shading"])
        self.plot_shading(shading_ax)

        add_oriented = self._reorient(fig.add_subplot)
        matrix_ax = add_oriented(specs["matrix"], sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_spec = specs["totals"]
        if totals_spec is not None:
            totals_ax = add_oriented(totals_spec, sharey=matrix_ax)
            self.plot_totals(totals_ax)
        else:
            totals_ax = None
        out = {"matrix": matrix_ax, "shading": shading_ax, "totals": totals_ax}

        for subset_plot in self._subset_plots:
            plot_id = subset_plot["id"]
            sub_ax = add_oriented(specs[plot_id], sharex=matrix_ax)
            plot_type = subset_plot["type"]
            if plot_type == "default":
                self.plot_intersections(sub_ax)
            elif plot_type in self.PLOT_TYPES:
                # Remaining keys are forwarded as keyword arguments.
                extra_kw = subset_plot.copy()
                for reserved in ("type", "elements", "id"):
                    del extra_kw[reserved]
                self.PLOT_TYPES[plot_type](self, sub_ax, **extra_kw)
            else:
                raise ValueError("Unknown subset plot type: %r" % plot_type)
            out[plot_id] = sub_ax
        return out
2✔
1011

1012
    # Maps a subset plot's "type" to the function that draws it. ``plot``
    # calls it as ``func(self, ax, **kw)``, where ``kw`` is the subset plot
    # spec minus its "type", "elements" and "id" keys. The "default" type
    # is handled separately by ``plot_intersections``.
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )
1016

1017
    def _repr_html_(self):
        """Return the HTML representation of a freshly drawn plot

        A new figure of size ``self._default_figsize`` is created, drawn on
        with ``self.plot`` and its own ``_repr_html_`` result is returned.
        """
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html
1021

1022

1023
def plot(data, fig=None, **kwargs):
    """Make an UpSet plot of data on fig

    Convenience wrapper equivalent to ``UpSet(data, **kwargs).plot(fig)``.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'shading', 'totals' and the id of each subset
        plot (the default one is typically 'intersections'). The value for
        'totals' may be None.
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc