• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7353712529

29 Dec 2023 04:13AM UTC coverage: 99.275%. Remained the same
7353712529

push

github

jnothman
Merge remote-tracking branch 'origin/master' into style_categories

1 of 1 new or added line in 1 file covered. (100.0%)

1 existing line in 1 file now uncovered.

1779 of 1792 relevant lines covered (99.27%)

1.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/upsetplot/plotting.py
1
import typing
2✔
2
import warnings
2✔
3

4
import matplotlib
2✔
5
import numpy as np
2✔
6
import pandas as pd
2✔
7
from matplotlib import colors, patches
2✔
8
from matplotlib import pyplot as plt
2✔
9

10
from . import util
2✔
11
from .reformat import _get_subset_mask, query
2✔
12

13
# prevents ImportError on matplotlib versions >3.5.2
14
try:
2✔
15
    from matplotlib.tight_layout import get_renderer
2✔
16

UNCOV
17
    RENDERER_IMPORTED = True
1✔
18
except ImportError:
1✔
19
    RENDERER_IMPORTED = False
1✔
20

21

22
def _pack_binary(X):
    """Pack each row of binary indicator columns into a single integer.

    Columns are read left to right as most- to least-significant bit, so a
    row ``(1, 0, 1)`` packs to ``0b101 == 5``.

    Parameters
    ----------
    X : pandas.DataFrame or array-like
        One column per category, holding 0/1 or boolean membership flags.

    Returns
    -------
    pandas.Series
        Indexed like ``X``. Uses ``uint64`` for up to 62 columns and Python
        integers (object dtype) beyond that, so no bits are lost.
    """
    X = pd.DataFrame(X)
    # use objects if arbitrary precision integers are needed
    dtype = np.object_ if X.shape[1] > 62 else np.uint64
    out = pd.Series(0, index=X.index, dtype=dtype)
    for _, col in X.items():
        out *= 2
        # Cast before adding: uint64 combined with a signed integer column
        # (such as 0/1 int64 index levels) is promoted to float64, which
        # silently drops low bits once more than 53 categories are packed.
        out += col.astype(dtype)
    return out


def _process_data(
    df,
    *,
    sort_by,
    sort_categories_by,
    subset_size,
    sum_over,
    min_subset_size=None,
    max_subset_size=None,
    max_subset_rank=None,
    min_degree=None,
    max_degree=None,
    reverse=False,
    include_empty_subsets=False,
):
    """Run :func:`query` on ``df`` and tag each row with its subset position.

    All keyword arguments except ``reverse`` are forwarded to ``query``.
    ``reverse`` flips the order of subsets (UpSet passes
    ``reverse=not horizontal``).

    Returns
    -------
    total
        ``results.total`` from ``query``.
    df : pandas.DataFrame
        ``results.data`` with an added ``"_bin"`` column giving the position
        of each row's subset within ``agg``.
    agg : pandas.Series
        ``results.subset_sizes``, reversed when ``reverse`` is true.
    category_totals
        ``results.category_totals`` from ``query``.
    """
    results = query(
        df,
        sort_by=sort_by,
        sort_categories_by=sort_categories_by,
        subset_size=subset_size,
        sum_over=sum_over,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        max_subset_rank=max_subset_rank,
        min_degree=min_degree,
        max_degree=max_degree,
        include_empty_subsets=include_empty_subsets,
    )

    df = results.data
    agg = results.subset_sizes

    # Identify each subset by its packed membership bits, then map every
    # row of df to the position of its subset in agg.
    df_packed = _pack_binary(df.index.to_frame())
    data_packed = _pack_binary(agg.index.to_frame())
    positions = np.arange(len(data_packed))[:: -1 if reverse else 1]
    df["_bin"] = df_packed.map(pd.Series(positions, index=data_packed))
    if reverse:
        agg = agg[::-1]

    return results.total, df, agg, results.category_totals
77

78

79
def _multiply_alpha(c, mult):
    """Return color ``c`` as an RGBA hex string with its alpha scaled by ``mult``."""
    *rgb, alpha = colors.to_rgba(c)
    return colors.to_hex((*rgb, alpha * mult), keep_alpha=True)
83

84

85
class _Transposed:
    """Proxy that swaps x/y-flavoured names when forwarding to a wrapped object.

    Attribute lookups and keyword-argument names are translated through
    ``_NAME_TRANSPOSE`` (e.g. ``width`` <-> ``height``, ``bar`` <-> ``barh``)
    before reaching the wrapped object; unknown names pass through unchanged.

    Translation is one level deep only: attributes fetched through the proxy
    are returned as-is, so callables must be wrapped again if their keyword
    arguments should be transposed too.
    """

    def __init__(self, obj):
        self.__obj = obj

    def __getattr__(self, key):
        # Only reached for names not found on the proxy itself.
        target_name = self._NAME_TRANSPOSE.get(key, key)
        return getattr(self.__obj, target_name)

    def __call__(self, *args, **kwargs):
        translated = {}
        for name, value in kwargs.items():
            translated[self._NAME_TRANSPOSE.get(name, name)] = value
        return self.__obj(*args, **translated)

    # Symmetric table: every entry's value maps back to its key.
    _NAME_TRANSPOSE = {
        "width": "height",
        "height": "width",
        "hspace": "wspace",
        "wspace": "hspace",
        "hlines": "vlines",
        "vlines": "hlines",
        "bar": "barh",
        "barh": "bar",
        "xaxis": "yaxis",
        "yaxis": "xaxis",
        "left": "bottom",
        "right": "top",
        "top": "right",
        "bottom": "left",
        "sharex": "sharey",
        "sharey": "sharex",
        "get_figwidth": "get_figheight",
        "get_figheight": "get_figwidth",
        "set_figwidth": "set_figheight",
        "set_figheight": "set_figwidth",
        "set_xlabel": "set_ylabel",
        "set_ylabel": "set_xlabel",
        "set_xlim": "set_ylim",
        "set_ylim": "set_xlim",
        "get_xlim": "get_ylim",
        "get_ylim": "get_xlim",
        "set_autoscalex_on": "set_autoscaley_on",
        "set_autoscaley_on": "set_autoscalex_on",
    }
136

137

138
def _transpose(obj):
    """Transpose ``obj`` for vertical plotting.

    A string is treated as a name and looked up in the transposition table
    (unknown names come back unchanged); any other object is wrapped in a
    :class:`_Transposed` proxy.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
142

143

144
def _identity(obj):
    """Return ``obj`` unchanged.

    Used in place of :func:`_transpose` when the plot is horizontal.
    """
    return obj
146

147

148
class UpSet:
2✔
149
    """Manage the data and drawing for a basic UpSet plot
150

151
    Primary public method is :meth:`plot`.
152

153
    Parameters
154
    ----------
155
    data : pandas.Series or pandas.DataFrame
156
        Elements associated with categories (a DataFrame), or the size of each
157
        subset of categories (a Series).
158
        Should have MultiIndex where each level is binary,
159
        corresponding to category membership.
160
        If a DataFrame, `sum_over` must be a string or False.
161
    orientation : {'horizontal' (default), 'vertical'}
162
        If horizontal, intersections are listed from left to right.
163
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
164
               'input', '-input'}
165
        If 'cardinality', subset are listed from largest to smallest.
166
        If 'degree', they are listed in order of the number of categories
167
        intersected. If 'input', the order they appear in the data input is
168
        used.
169
        Prefix with '-' to reverse the ordering.
170

171
        Note this affects ``subset_sizes`` but not ``data``.
172
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
173
        Whether to sort the categories by total cardinality, or leave them
174
        in the input data's provided order (order of index levels).
175
        Prefix with '-' to reverse the ordering.
176
    subset_size : {'auto', 'count', 'sum'}
177
        Configures how to calculate the size of a subset. Choices are:
178

179
        'auto' (default)
180
            If `data` is a DataFrame, count the number of rows in each group,
181
            unless `sum_over` is specified.
182
            If `data` is a Series with at most one row for each group, use
183
            the value of the Series. If `data` is a Series with more than one
184
            row per group, raise a ValueError.
185
        'count'
186
            Count the number of rows in each group.
187
        'sum'
188
            Sum the value of the `data` Series, or the DataFrame field
189
            specified by `sum_over`.
190
    sum_over : str or None
191
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
192
        sum of the specified field in the `data` DataFrame. If a Series, only
193
        None is supported and its value is summed.
194
    min_subset_size : int, optional
195
        Minimum size of a subset to be shown in the plot. All subsets with
196
        a size smaller than this threshold will be omitted from plotting.
197
        Size may be a sum of values, see `subset_size`.
198

199
        .. versionadded:: 0.5
200
    max_subset_size : int, optional
201
        Maximum size of a subset to be shown in the plot. All subsets with
202
        a size greater than this threshold will be omitted from plotting.
203

204
        .. versionadded:: 0.5
205
    max_subset_rank : int, optional
206
        Limit to the top N ranked subsets in descending order of size.
207
        All tied subsets are included.
208

209
        .. versionadded:: 0.9
210
    min_degree : int, optional
211
        Minimum degree of a subset to be shown in the plot.
212

213
        .. versionadded:: 0.5
214
    max_degree : int, optional
215
        Maximum degree of a subset to be shown in the plot.
216

217
        .. versionadded:: 0.5
218
    facecolor : 'auto' or matplotlib color or float
219
        Color for bar charts and active dots. Defaults to black if
220
        axes.facecolor is a light color, otherwise white.
221

222
        .. versionchanged:: 0.6
223
            Before 0.6, the default was 'black'
224
    other_dots_color : matplotlib color or float
225
        Color for shading of inactive dots, or opacity (between 0 and 1)
226
        applied to facecolor.
227

228
        .. versionadded:: 0.6
229
    shading_color : matplotlib color or float
230
        Color for shading of odd rows in matrix and totals, or opacity (between
231
        0 and 1) applied to facecolor.
232

233
        .. versionadded:: 0.6
234
    with_lines : bool
235
        Whether to show lines joining dots in the matrix, to mark multiple
236
        categories being intersected.
237
    element_size : float or None
238
        Side length in pt. If None, size is estimated to fit figure
239
    intersection_plot_elements : int
240
        The intersections plot should be large enough to fit this many matrix
241
        elements. Set to 0 to disable intersection size bars.
242

243
        .. versionchanged:: 0.4
244
            Setting to 0 is handled.
245
    totals_plot_elements : int
246
        The totals plot should be large enough to fit this many matrix
247
        elements. Use totals_plot_elements=0 to disable the totals plot.
248
    show_counts : bool or str, default=False
249
        Whether to label the intersection size bars with the cardinality
250
        of the intersection. When a string, this formats the number.
251
        For example, '{:d}' is equivalent to True.
252
        Note that, for legacy reasons, if the string does not contain '{',
253
        it will be interpreted as a C-style format string, such as '%d'.
254
    show_percentages : bool or str, default=False
255
        Whether to label the intersection size bars with the percentage
256
        of the intersection relative to the total dataset.
257
        When a string, this formats the number representing a fraction of
258
        samples.
259
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
260
        This may be applied with or without show_counts.
261

262
        .. versionadded:: 0.4
263
    include_empty_subsets : bool (default=False)
264
        If True, all possible category combinations will be shown as subsets,
265
        even when some are not present in data.
266
    """
267

268
    _default_figsize = (10, 6)
2✔
269
    DPI = 100  # standard matplotlib value
2✔
270

271
    def __init__(
        self,
        data,
        orientation="horizontal",
        sort_by="degree",
        sort_categories_by="cardinality",
        subset_size="auto",
        sum_over=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor="auto",
        other_dots_color=0.18,
        shading_color=0.05,
        with_lines=True,
        element_size=32,
        intersection_plot_elements=6,
        totals_plot_elements=2,
        show_counts="",
        show_percentages=False,
        include_empty_subsets=False,
    ):
        # Parameters are documented in the class docstring.
        self._horizontal = orientation == "horizontal"
        # Drawing code is written for the horizontal layout; vertical plots
        # route axes, figures and names through _transpose instead.
        self._reorient = _identity if self._horizontal else _transpose
        # Only a string can be the "auto" sentinel. Comparing an array-like
        # color (e.g. an RGB numpy array) to "auto" with == may produce an
        # array, whose truth value is ambiguous and raises ValueError.
        if isinstance(facecolor, str) and facecolor == "auto":
            bgcolor = matplotlib.rcParams.get("axes.facecolor", "white")
            r, g, b, a = colors.to_rgba(bgcolor)
            # HSV value of the background, weighted by its alpha.
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = "black" if lightness >= 0.5 else "white"
        self._facecolor = facecolor
        # A float is an opacity applied to facecolor; anything else is used
        # directly as a color.
        self._shading_color = (
            _multiply_alpha(facecolor, shading_color)
            if isinstance(shading_color, float)
            else shading_color
        )
        self._other_dots_color = (
            _multiply_alpha(facecolor, other_dots_color)
            if isinstance(other_dots_color, float)
            else other_dots_color
        )
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        self._subset_plots = [
            {
                "type": "default",
                "id": "intersections",
                "elements": intersection_plot_elements,
            }
        ]
        if not intersection_plot_elements:
            self._subset_plots.pop()
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections, self.totals) = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets,
        )
        self.category_styles = {}
        # One style dict per subset, later updated by style_subsets().
        self.subset_styles = [
            {"facecolor": facecolor} for _ in range(len(self.intersections))
        ]
        self.subset_legend = []  # pairs of (style, label)
347

348
    def _swapaxes(self, x, y):
2✔
349
        if self._horizontal:
2✔
350
            return x, y
2✔
351
        return y, x
2✔
352

353
    def style_subsets(
        self,
        present=None,
        absent=None,
        min_subset_size=None,
        max_subset_size=None,
        max_subset_rank=None,
        min_degree=None,
        max_degree=None,
        facecolor=None,
        edgecolor=None,
        hatch=None,
        linewidth=None,
        linestyle=None,
        label=None,
    ):
        """Updates the style of selected subsets' bars and matrix dots

        Parameters are either used to select subsets, or to style them with
        attributes of :class:`matplotlib.patches.Patch`, apart from label,
        which adds a legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        max_subset_rank : int, optional
            Limit to the top N ranked subsets in descending order of size.
            All tied subsets are included.

            .. versionadded:: 0.9
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added. Calls whose resulting styles
            are identical share one legend entry, with their distinct labels
            joined by "; ".
        """
        style = {
            "facecolor": facecolor,
            "edgecolor": edgecolor,
            "hatch": hatch,
            "linewidth": linewidth,
            "linestyle": linestyle,
        }
        # Only explicitly given attributes override existing subset styles.
        style = {k: v for k, v in style.items() if v is not None}
        mask = _get_subset_mask(
            self.intersections,
            present=present,
            absent=absent,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            max_subset_rank=max_subset_rank,
            min_degree=min_degree,
            max_degree=max_degree,
        )
        for idx in np.flatnonzero(mask):
            self.subset_styles[idx].update(style)

        if label is not None:
            # The legend swatch needs a facecolor even if none was given.
            if "facecolor" not in style:
                style["facecolor"] = self._facecolor
            for i, (other_style, other_label) in enumerate(self.subset_legend):
                if other_style == style:
                    # Merge into the existing entry, but never repeat a label
                    # that is already part of the merged text.
                    if label != other_label and label not in other_label.split(
                        "; "
                    ):
                        self.subset_legend[i] = (style, other_label + "; " + label)
                    break
            else:
                self.subset_legend.append((style, label))
442

443
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
        """Draw bars, stacked when ``data`` has several columns, onto ``ax``.

        Parameters
        ----------
        ax : matplotlib Axes
            Target axes; reoriented via ``self._reorient`` for vertical plots.
        data : pandas.Series or pandas.DataFrame
            One bar position per row; each column is one stacked layer.
        title : str
            Label for the value axis.
        colors : None, str, callable or list-like, optional
            One color (or None) for every layer, a callable receiving
            ``range(n_layers)``, or one color per layer.
        use_labels : bool
            Whether to label each layer with its column name (for legends).

        Returns
        -------
        list
            The individual bar artists of every layer.
        """
        ax = self._reorient(ax)
        ax.set_autoscalex_on(False)
        data_df = pd.DataFrame(data)
        if self._horizontal:
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack

        n_layers = data_df.shape[1]
        # TODO: colors should be broadcastable to data_df shape
        if callable(colors):
            colors = colors(range(n_layers))
        elif isinstance(colors, (str, type(None))):
            # One color per layer (column). Sizing this by the number of rows
            # made zip() below silently drop layers whenever there were fewer
            # subsets than layers.
            colors = [colors] * n_layers

        if self._horizontal:
            colors = reversed(colors)

        x = np.arange(len(data_df))
        cum_y = None
        all_rects = []
        rects = []  # stays empty when there are no layers to draw
        for (name, y), color in zip(data_df.items(), colors):
            rects = ax.bar(
                x,
                y,
                0.5,
                cum_y,
                color=color,
                zorder=10,
                label=name if use_labels else None,
                align="center",
            )
            cum_y = y if cum_y is None else cum_y + y
            all_rects.extend(rects)

        # Only the top-most layer is passed for labelling; presumably
        # _label_sizes reads each bar's far edge (the stack total) -- TODO
        # confirm against _label_sizes.
        self._label_sizes(ax, rects, "top" if self._horizontal else "right")

        ax.xaxis.set_visible(False)
        for x in ["top", "bottom", "right"]:
            ax.spines[self._reorient(x)].set_visible(False)

        tick_axis = ax.yaxis
        tick_axis.grid(True)
        ax.set_ylabel(title)
        return all_rects
486

487
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
        """Draw per-subset bars stacked by the discrete values of column ``by``.

        Bar sizes are the sum of ``sum_over`` when given, else the sum of the
        ``"_value"`` column when present, else the row count per group.
        ``colors`` is resolved as described in :meth:`add_stacked_bars`.

        Raises
        ------
        KeyError
            If ``colors`` is a mapping that lacks some label of ``by``.
        """
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
        if sum_over is None and "_value" in df.columns:
            data = gb["_value"].sum()
        elif sum_over is None:
            data = gb.size()
        else:
            data = gb[sum_over].sum()
        # One row per subset bin, one column per label of ``by``.
        data = data.unstack(by).fillna(0)
        if isinstance(colors, str):
            # matplotlib.cm.get_cmap was removed in matplotlib 3.9;
            # pyplot.get_cmap is available across versions.
            colors = plt.get_cmap(colors)
        elif isinstance(colors, typing.Mapping):
            colors = data.columns.map(colors).values
            if pd.isna(colors).any():
                raise KeyError(
                    "Some labels are not mapped by colors: %r"
                    % data.columns[pd.isna(colors)].tolist()
                )

        self._plot_bars(ax, data=data, colors=colors, title=title, use_labels=True)

        handles, labels = ax.get_legend_handles_labels()
        if self._horizontal:
            # Make legend order match visual stack order
            ax.legend(handles[::-1], labels[::-1])
        else:
            ax.legend()
515

516
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3, title=None):
2✔
517
        """Add a stacked bar chart over subsets when :func:`plot` is called.
518

519
        Used to plot categorical variable distributions within each subset.
520

521
        .. versionadded:: 0.6
522

523
        Parameters
524
        ----------
525
        by : str
526
            Column name within the dataframe for color coding the stacked bars,
527
            containing discrete or categorical values.
528
        sum_over : str, optional
529
            Ordinarily the bars will chart the size of each group. sum_over
530
            may specify a column which will be summed to determine the size
531
            of each bar.
532
        colors : Mapping, list-like, str or callable, optional
533
            The facecolors to use for bars corresponding to each discrete
534
            label, specified as one of:
535

536
            Mapping
537
                Maps from label to matplotlib-compatible color specification.
538
            list-like
539
                A list of matplotlib colors to apply to labels in order.
540
            str
541
                The name of a matplotlib colormap name.
542
            callable
543
                When called with the number of labels, this should return a
544
                list-like of that many colors.  Matplotlib colormaps satisfy
545
                this callable API.
546
            None
547
                Uses the matplotlib default colormap.
548
        elements : int, default=3
549
            Size of the axes counted in number of matrix elements.
550
        title : str, optional
551
            The axis title labelling bar length.
552

553
        Returns
554
        -------
555
        None
556
        """
557
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
558
        #                        list of labels}
559
        self._subset_plots.append(
2✔
560
            {
2✔
561
                "type": "stacked_bars",
2✔
562
                "by": by,
2✔
563
                "sum_over": sum_over,
2✔
564
                "colors": colors,
2✔
565
                "title": title,
2✔
566
                "id": "extra%d" % len(self._subset_plots),
2✔
567
                "elements": elements,
2✔
568
            }
569
        )
570

571
    def add_catplot(self, kind, value=None, elements=3, **kw):
2✔
572
        """Add a seaborn catplot over subsets when :func:`plot` is called.
573

574
        Parameters
575
        ----------
576
        kind : str
577
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
578
        value : str, optional
579
            Column name for the value to plot (i.e. y if
580
            orientation='horizontal'), required if `data` is a DataFrame.
581
        elements : int, default=3
582
            Size of the axes counted in number of matrix elements.
583
        **kw : dict
584
            Additional keywords to pass to :func:`seaborn.catplot`.
585

586
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
587
            and 'orient', so these are prohibited keys in `kw`.
588

589
        Returns
590
        -------
591
        None
592
        """
593
        assert not set(kw.keys()) & {"ax", "data", "x", "y", "orient"}
594
        if value is None:
595
            if "_value" not in self._df.columns:
596
                raise ValueError(
597
                    "value cannot be set if data is a Series. " "Got %r" % value
598
                )
599
        else:
600
            if value not in self._df.columns:
601
                raise ValueError("value %r is not a column in data" % value)
602
        self._subset_plots.append(
603
            {
604
                "type": "catplot",
605
                "value": value,
606
                "kind": kind,
607
                "id": "extra%d" % len(self._subset_plots),
608
                "elements": elements,
609
                "kw": kw,
610
            }
611
        )
612

613
    def _check_value(self, value):
2✔
614
        if value is None and "_value" in self._df.columns:
615
            value = "_value"
616
        elif value is None:
617
            raise ValueError("value can only be None when data is a Series")
618
        return value
619

620
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw ``seaborn.<kind>plot`` of ``value`` per subset onto ``ax``.

        Subsets are taken from the ``"_bin"`` column of ``self._df``, and
        orientation follows the plot layout. ``kw`` is not modified.
        """
        value = self._check_value(value)
        if self._horizontal:
            axis_kw = {"orient": "v", "x": "_bin", "y": value}
        else:
            axis_kw = {"orient": "h", "x": value, "y": "_bin"}
        import seaborn

        plot_kw = {**kw, **axis_kw, "ax": ax}
        plot_fn = getattr(seaborn, kind + "plot")
        plot_fn(data=self._df, **plot_kw)

        ax = self._reorient(ax)
        if value == "_value":
            # Hide the internal column name from the axis label.
            ax.set_ylabel("")

        ax.xaxis.set_visible(False)
        for side in ("top", "bottom", "right"):
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
647

648
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : matplotlib Figure, optional
            Figure to lay out; defaults to ``plt.gcf()``. When ``element_size``
            is set, the figure's width and height are adjusted to fit.

        Returns
        -------
        dict
            SubplotSpecs keyed by "matrix", "shading", "totals" (None when the
            totals plot is disabled) and by each subset plot's "id", plus the
            GridSpec itself under "gs".
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        # Look up the renderer once and reuse it for both extent measurements
        # below; the lookup may be costly (it can involve drawing the figure).
        # get_renderer is only used when its module-level import succeeded.
        window_extent_args = {}
        if RENDERER_IMPORTED:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                window_extent_args["renderer"] = get_renderer(fig)

        # Determine text size to determine figure size / spacing.
        # Category names are drawn as tick labels on the matrix's (reoriented)
        # y axis, i.e. y tick labels when horizontal and x tick labels when
        # vertical, so measure them with the matching font size.
        labelsize_key = "ytick.labelsize" if self._horizontal else "xtick.labelsize"
        text_kw = {"size": matplotlib.rcParams[labelsize_key]}
        # adding "x" ensures a margin
        t = fig.text(
            0,
            0,
            "\n".join(str(label) + "x" for label in self.totals.index.values),
            **text_kw,
        )
        textw = t.get_window_extent(**window_extent_args).width
        t.remove()

        figw = self._reorient(fig.get_window_extent(**window_extent_args)).width

        sizes = np.asarray([p["elements"] for p in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = len(self.intersections) + self._totals_plot_elements
        if self._element_size is None:
            colw = (figw - textw) / non_text_nelems
        else:
            # Convert element_size (points, 72 per inch) to rendered units and
            # resize the figure so every element gets exactly that size.
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) / render_ratio)

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(
            *self._swapaxes(
                # With no subset plots, sizes is an empty float array summing
                # to 0.0; "or 0" turns that into an int (presumably so the
                # GridSpec dimensions stay integral).
                n_cats + (sizes.sum() or 0),
                n_inters + text_nelems + self._totals_plot_elements,
            ),
            hspace=1,
        )
        if self._horizontal:
            out = {
                "matrix": gridspec[-n_cats:, -n_inters:],
                "shading": gridspec[-n_cats:, :],
                "totals": None
                if self._totals_plot_elements == 0
                else gridspec[-n_cats:, : self._totals_plot_elements],
                "gs": gridspec,
            }
            # Subset plots are stacked above the matrix, first plot on top.
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots[::-1]
            ):
                out[plot["id"]] = gridspec[start:stop, -n_inters:]
        else:
            out = {
                "matrix": gridspec[-n_inters:, :n_cats],
                "shading": gridspec[:, :n_cats],
                "totals": None
                if self._totals_plot_elements == 0
                else gridspec[: self._totals_plot_elements, :n_cats],
                "gs": gridspec,
            }
            # Subset plots are placed to the right of the matrix.
            cumsizes = np.cumsum(sizes)
            for start, stop, plot in zip(
                np.hstack([[0], cumsizes]), cumsizes, self._subset_plots
            ):
                out[plot["id"]] = gridspec[-n_inters:, start + n_cats : stop + n_cats]
        return out
732

733
    def plot_matrix(self, ax):
        """Plot the matrix of intersection indicators onto ax

        One dot is drawn per (subset, category) pair. A dot whose subset
        includes the category uses that subset's entry in
        ``self.subset_styles``. Other dots use ``self._other_dots_color``
        with zero line width. If ``self._with_lines`` is set, a vertical
        line joins the included dots of each subset.
        """
        ax = self._reorient(ax)
        data = self.intersections
        n_cats = data.index.nlevels

        # Membership table: row i is subset i, column j is category j
        # (truthy where subset i includes category j).
        inclusion = data.index.to_frame().values

        # One style dict per dot, in row-major (subset, category) order so
        # that it lines up with ``x``/``y`` below.
        excluded_style = {"facecolor": self._other_dots_color, "linewidth": 0}
        styles = [
            self.subset_styles[i] if inclusion[i, j] else excluded_style
            for i in range(len(data))
            for j in range(n_cats)
        ]
        # Maps style keys to the corresponding (plural) scatter kwargs.
        style_columns = {
            "facecolor": "facecolors",
            "edgecolor": "edgecolors",
            "linewidth": "linewidths",
            "linestyle": "linestyles",
            "hatch": "hatch",
        }
        styles = (
            pd.DataFrame(styles)
            .reindex(columns=style_columns.keys())
            .astype(
                {
                    "facecolor": "O",
                    "edgecolor": "O",
                    "linewidth": float,
                    "linestyle": "O",
                    "hatch": "O",
                }
            )
        )
        # Assign the filled columns back. Chained ``fillna(..., inplace=True)``
        # on a column is deprecated by pandas and is a silent no-op under
        # Copy-on-Write, which would leave NaN styles for scatter.
        styles["linewidth"] = styles["linewidth"].fillna(1)
        styles["facecolor"] = styles["facecolor"].fillna(self._facecolor)
        styles["edgecolor"] = styles["edgecolor"].fillna(styles["facecolor"])
        styles["linestyle"] = styles["linestyle"].fillna("solid")
        styles = styles.drop(columns="hatch")  # not supported in matrix (currently)

        x = np.repeat(np.arange(len(data)), n_cats)
        y = np.tile(np.arange(n_cats), len(data))

        # Plot dots
        if self._element_size is not None:  # noqa
            s = (self._element_size * 0.35) ** 2
        else:
            # TODO: make s relative to colw
            s = 200
        ax.scatter(
            *self._swapaxes(x, y),
            s=s,
            zorder=10,
            **styles.rename(columns=style_columns),
        )

        # Plot lines
        if self._with_lines:
            # Flat indices of included dots; these index x/y directly because
            # both use the same row-major order as ``inclusion``.
            idx = np.flatnonzero(inclusion)
            line_data = (
                pd.Series(y[idx], index=x[idx])
                .groupby(level=0)
                .aggregate(["min", "max"])
            )
            # Each subset's line takes its edgecolor, falling back to its
            # facecolor, then to the default facecolor. Named so as not to
            # shadow the module-level ``matplotlib.colors`` import.
            line_colors = pd.Series(
                [
                    style.get("edgecolor", style.get("facecolor", self._facecolor))
                    for style in self.subset_styles
                ],
                name="color",
            )
            line_data = line_data.join(line_colors)
            ax.vlines(
                line_data.index.values,
                line_data["min"],
                line_data["max"],
                lw=2,
                colors=line_data["color"],
                zorder=5,
            )

        # Ticks and axes
        tick_axis = ax.yaxis
        tick_axis.set_ticks(np.arange(n_cats))
        tick_axis.set_ticklabels(
            data.index.names, rotation=0 if self._horizontal else -90
        )
        ax.xaxis.set_visible(False)
        ax.tick_params(axis="both", which="both", length=0)
        if not self._horizontal:
            ax.yaxis.set_ticks_position("top")
        ax.set_frame_on(False)
        ax.set_xlim(-0.5, x[-1] + 0.5, auto=False)
        ax.grid(False)
2✔
832

833
    def plot_intersections(self, ax):
        """Plot bars indicating intersection size

        Each bar is then restyled from the matching entry of
        ``self.subset_styles``. An edgecolor defaults to the bar's facecolor
        (or ``self._facecolor``). If ``self.subset_legend`` is non-empty, a
        legend of ``(patch_kwargs, label)`` pairs is added to ``ax``.
        """
        bars = self._plot_bars(
            ax,
            self.intersections,
            title="Intersection size",
            colors=self._facecolor,
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            # Work on a copy so the shared style dicts are never mutated.
            resolved = dict(subset_style)
            if "edgecolor" not in resolved:
                resolved["edgecolor"] = resolved.get("facecolor", self._facecolor)
            for prop_name, prop_value in resolved.items():
                setter = getattr(bar, "set_" + prop_name)
                setter(prop_value)

        if not self.subset_legend:
            return
        legend_kwargs, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**kwargs) for kwargs in legend_kwargs]
        ax.legend(handles, legend_labels)
848

849
    def _label_sizes(self, ax, rects, where):
        """Annotate each bar in ``rects`` with its size and/or percentage.

        What is shown is controlled by ``self._show_counts`` and
        ``self._show_percentages``. Each may be ``True`` (use a default
        format), a format string, or falsy (disabled). A count format
        without ``{`` is treated as a legacy %-style format and converted
        with ``util.to_new_pos_format``. Percentages are relative to
        ``self.total``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes holding the bars; labels are added with ``ax.text``.
        rects : iterable of bar patches
            Bars to label. The labelled value is the bar's far end:
            ``x + width`` for "left"/"right", ``y + height`` for "top".
        where : {"left", "right", "top"}
            Placement of the label relative to the bar end. "right" labels
            are left-aligned, "left" labels are right-aligned, and "top"
            labels are centred above the bar.

        Raises
        ------
        NotImplementedError
            If ``where`` is unsupported (only when some label is enabled).
        """
        show_counts = self._show_counts
        show_percentages = self._show_percentages
        if not show_counts and not show_percentages:
            return

        if show_counts is True:
            count_fmt = "{:.0f}"
        elif not show_counts:
            # Counts disabled (e.g. show_counts=False) but percentages are on.
            # A falsy non-string must not reach the "{" membership test below,
            # which raises TypeError for a bool.
            count_fmt = ""
        else:
            count_fmt = show_counts
            if "{" not in count_fmt:
                # Legacy %-style format string, e.g. "%d".
                count_fmt = util.to_new_pos_format(count_fmt)

        pct_fmt = "{:.1%}" if show_percentages is True else show_percentages

        if count_fmt and pct_fmt:
            # Stack count above percentage for "top"; inline otherwise.
            sep = "\n" if where == "top" else " "
            fmt = f"{count_fmt}{sep}({pct_fmt})"

            def make_args(val):
                return val, val / self.total

        elif count_fmt:
            fmt = count_fmt

            def make_args(val):
                return (val,)

        else:
            fmt = pct_fmt

            def make_args(val):
                return (val / self.total,)

        if where in ("left", "right"):
            x_lo, x_hi = ax.get_xlim()
            # Offset labels from the bar end by 1% of the axis span. This is
            # kept as a scalar: ``np.diff`` would yield a length-1 array, which
            # matplotlib coerces to a scalar when drawing text (deprecated for
            # ndim > 0 arrays since NumPy 1.25).
            margin = 0.01 * abs(x_hi - x_lo)
            ha = "left" if where == "right" else "right"
            for rect in rects:
                end = rect.get_width() + rect.get_x()
                ax.text(
                    end + margin,
                    rect.get_y() + rect.get_height() * 0.5,
                    fmt.format(*make_args(end)),
                    ha=ha,
                    va="center",
                )
        elif where == "top":
            y_lo, y_hi = ax.get_ylim()
            margin = 0.01 * abs(y_hi - y_lo)
            for rect in rects:
                end = rect.get_height() + rect.get_y()
                ax.text(
                    rect.get_x() + rect.get_width() * 0.5,
                    end + margin,
                    fmt.format(*make_args(end)),
                    ha="center",
                    va="bottom",
                )
        else:
            raise NotImplementedError("unhandled where: %r" % where)
915

916
    def plot_totals(self, ax):
        """Plot bars indicating total set size

        Draws one bar per category in ``self.totals``. Each bar is labelled
        via ``self._label_sizes`` and restyled from the ``bar_*`` entries of
        ``self.category_styles``. The edgecolor defaults to the bar's
        facecolor. In horizontal mode the x-axis is inverted so that bars
        extend from right (0) to left.
        """
        outer_ax = ax
        ax = self._reorient(ax)
        category_values = self.totals.index.values
        bars = ax.barh(
            np.arange(len(category_values)),
            self.totals,
            0.5,
            color=self._facecolor,
            align="center",
        )
        label_side = "left" if self._horizontal else "top"
        self._label_sizes(ax, bars, label_side)

        prefix = "bar_"
        for category, bar in zip(category_values, bars):
            configured = self.category_styles.get(category, {})
            bar_props = {}
            for key, value in configured.items():
                if key.startswith(prefix):
                    bar_props[key[len(prefix) :]] = value
            if "edgecolor" not in bar_props:
                bar_props["edgecolor"] = bar_props.get("facecolor", self._facecolor)
            for prop_name, prop_value in bar_props.items():
                getattr(bar, "set_" + prop_name)(prop_value)

        largest = self.totals.max()
        if self._horizontal:
            outer_ax.set_xlim(largest, 0)
        for side in ["top", "left", "right"]:
            ax.spines[self._reorient(side)].set_visible(False)
        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)
2✔
948

949
    def plot_shading(self, ax):
        """Draw a background rectangle for each category row on ``ax``.

        Even-indexed rows are filled with ``self._shading_color`` and
        odd-indexed rows are fully transparent, unless overridden by
        ``shading_*`` entries in ``self.category_styles``. All ticks, tick
        labels, grid lines and the frame are then removed from ``ax``.
        """
        transparent = (0.0, 0.0, 0.0, 0.0)
        prefix = "shading_"
        for row, category in enumerate(self.totals.index):
            if row % 2:
                fallback_fill = transparent
            else:
                fallback_fill = self._shading_color

            overrides = {}
            for key, value in self.category_styles.get(category, {}).items():
                if key.startswith(prefix):
                    overrides[key[len(prefix) :]] = value

            # An outline is drawn (width 1) only when an edgecolor is given,
            # unless an explicit linewidth overrides this.
            edge_default = 1 if overrides.get("edgecolor") else 0
            linewidth = overrides.get("linewidth", edge_default)
            # Horizontal inset: linewidth scaled by the default figure width
            # times DPI, so a drawn edge is pulled in from the axes' sides.
            pad = linewidth / (self._default_figsize[0] * self.DPI)

            background = plt.Rectangle(
                self._swapaxes(pad, row - 0.4),
                *self._swapaxes(1 - pad * 3, 0.8),
                facecolor=overrides.get("facecolor", fallback_fill),
                edgecolor=overrides.get("edgecolor", None),
                ls=overrides.get("linestyle", "-"),
                lw=linewidth,
                zorder=0,
            )
            ax.add_patch(background)

        ax.set_frame_on(False)
        ax.tick_params(
            axis="both",
            which="both",
            left=False,
            right=False,
            bottom=False,
            top=False,
            labelbottom=False,
            labelleft=False,
        )
        ax.grid(False)
        for clear in (ax.set_xticks, ax.set_yticks, ax.set_xticklabels, ax.set_yticklabels):
            clear([])
2✔
995

996
    def style_categories(
        self,
        category_names,
        *,
        bar_facecolor=None,
        bar_hatch=None,
        bar_edgecolor=None,
        bar_linewidth=None,
        bar_linestyle=None,
        shading_facecolor=None,
        shading_edgecolor=None,
        shading_linewidth=None,
        shading_linestyle=None,
    ):
        """Updates the style of the categories.

        Select one or more categories by name, and style their bar in the
        totals plot and/or their row shading. Styles accumulate: each call
        merges the options that are not None into any style previously set
        for the same category. Later calls therefore override earlier ones
        key by key, and omitted options are left untouched.

        Parameters
        ----------
        category_names : str or list of str
            A category name, or a list of category names, to which the
            style is applied.
        bar_facecolor : str or RGBA matplotlib color tuple, optional
            Override the default facecolor of the category's bar in the
            totals plot.
        bar_hatch : str, optional
            Hatch pattern for the category's bar in the totals plot.
        bar_edgecolor : str or matplotlib color, optional
            Edgecolor for the category's total bar. Defaults to the bar's
            facecolor.
        bar_linewidth : float, optional
            Line width in points for the total bar's edge.
        bar_linestyle : str, optional
            Line style for the total bar's edge.
        shading_facecolor : str or RGBA matplotlib color tuple, optional
            Override the default alternating background shading of the
            category's row.
        shading_edgecolor : str or matplotlib color, optional
            Edgecolor of the category's shading rectangle. When set, an
            edge of width 1 is drawn unless ``shading_linewidth`` is given.
        shading_linewidth : float, optional
            Line width in points for the shading rectangle's edge.
        shading_linestyle : str, optional
            Line style for the shading rectangle's edge.
        """
        if isinstance(category_names, str):
            category_names = [category_names]
        style = {
            "bar_facecolor": bar_facecolor,
            "bar_hatch": bar_hatch,
            "bar_edgecolor": bar_edgecolor,
            "bar_linewidth": bar_linewidth,
            "bar_linestyle": bar_linestyle,
            "shading_facecolor": shading_facecolor,
            "shading_edgecolor": shading_edgecolor,
            "shading_linewidth": shading_linewidth,
            "shading_linestyle": shading_linestyle,
        }
        # Only explicitly passed options override earlier styling.
        style = {k: v for k, v in style.items() if v is not None}
        for category_name in category_names:
            self.category_styles.setdefault(category_name, {}).update(style)
1053

1054
    def plot(self, fig=None):
        """Draw all parts of the plot onto fig or a new figure

        The shading axes is created first. The matrix shares its y-axis with
        the shading, the totals (if any) share theirs with the matrix, and
        every subset plot shares its x-axis with the matrix.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Defaults to a new figure.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'shading', 'totals' (None when no totals plot
            is drawn), plus the id of each subset plot, e.g. 'intersections'.

        Raises
        ------
        ValueError
            If a subset plot's type is neither "default" nor a key of
            ``PLOT_TYPES``.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)

        shading_ax = fig.add_subplot(specs["shading"])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(specs["matrix"], sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_spec = specs["totals"]
        totals_ax = None
        if totals_spec is not None:
            totals_ax = self._reorient(fig.add_subplot)(totals_spec, sharey=matrix_ax)
            self.plot_totals(totals_ax)

        axes = {"matrix": matrix_ax, "shading": shading_ax, "totals": totals_ax}

        for subplot in self._subset_plots:
            subplot_ax = self._reorient(fig.add_subplot)(
                specs[subplot["id"]], sharex=matrix_ax
            )
            kind = subplot["type"]
            if kind == "default":
                self.plot_intersections(subplot_ax)
            elif kind in self.PLOT_TYPES:
                # Everything except the bookkeeping keys is forwarded as
                # keyword arguments to the registered plotting function.
                options = dict(subplot)
                for reserved in ("type", "elements", "id"):
                    del options[reserved]
                self.PLOT_TYPES[kind](self, subplot_ax, **options)
            else:
                raise ValueError("Unknown subset plot type: %r" % kind)
            axes[subplot["id"]] = subplot_ax
        return axes
2✔
1097

1098
    # Registry of subset plot types other than "default": maps a subset
    # plot's "type" to the function that ``plot`` calls as
    # ``func(self, ax, **options)``.
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )
1102

1103
    def _repr_html_(self):
        """Return an HTML representation for rich display.

        Draws the plot onto a new figure of the default size and delegates
        to that figure's own ``_repr_html_``.
        """
        html_figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=html_figure)
        return html_figure._repr_html_()
1107

1108

1109
def plot(data, fig=None, **kwargs):
    """Make an UpSet plot of data on fig

    Convenience wrapper for ``UpSet(data, **kwargs).plot(fig)``.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc