• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 498

pending completion
498

push

travis-ci-com

web-flow
fix tests for subsets (#201)

1119 of 1159 relevant lines covered (96.55%)

2.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.44
/upsetplot/plotting.py
1
from __future__ import print_function, division, absolute_import
3✔
2

3
try:
3✔
4
    import typing
3✔
5
except ImportError:
×
6
    import collections as typing
×
7

8
import numpy as np
3✔
9
import pandas as pd
3✔
10
import matplotlib
3✔
11
from matplotlib import pyplot as plt
3✔
12
from matplotlib import colors
3✔
13
from matplotlib import patches
3✔
14

15
from .reformat import query, _get_subset_mask
3✔
16
from . import util
3✔
17

18
# prevents ImportError on matplotlib versions >3.5.2
19
try:
3✔
20
    from matplotlib.tight_layout import get_renderer
3✔
21

22
    RENDERER_IMPORTED = True
3✔
23
except ImportError:
×
24
    RENDERER_IMPORTED = False
×
25

26

27
def _process_data(df, sort_by, sort_categories_by, subset_size,
                  sum_over, min_subset_size=None, max_subset_size=None,
                  min_degree=None, max_degree=None, reverse=False):
    """Run the subset query and tag every row with its subset's position.

    All arguments except ``reverse`` are forwarded to ``query``. A ``'_bin'``
    column is added to the queried data, holding for each row the position
    of its subset within the subset sizes. When ``reverse`` is true both
    the positions and the returned subset sizes are reversed.

    Returns
    -------
    tuple
        ``(total, data, subset_sizes, category_totals)`` where ``total`` is
        the sum of the subset sizes.
    """
    results = query(df, sort_by=sort_by, sort_categories_by=sort_categories_by,
                    subset_size=subset_size, sum_over=sum_over,
                    min_subset_size=min_subset_size,
                    max_subset_size=max_subset_size,
                    min_degree=min_degree, max_degree=max_degree)

    data = results.data
    subset_sizes = results.subset_sizes
    category_totals = results.category_totals
    total = subset_sizes.sum()

    # add '_bin' to data indicating index in subset_sizes
    # XXX: ugly!
    def _pack_binary(index_frame):
        # Fold the membership columns into a single integer key per row;
        # the first column ends up as the most significant bit.
        frame = pd.DataFrame(index_frame)
        packed = 0
        for _, column in frame.items():
            packed = packed * 2 + column
        return packed

    row_keys = _pack_binary(data.index.to_frame())
    subset_keys = _pack_binary(subset_sizes.index.to_frame())

    positions = np.arange(len(subset_keys))
    if reverse:
        positions = positions[::-1]
    key_to_position = pd.Series(positions, index=subset_keys)
    data['_bin'] = pd.Series(row_keys).map(key_to_position)

    if reverse:
        subset_sizes = subset_sizes[::-1]

    return total, data, subset_sizes, category_totals
3✔
60

61

62
def _multiply_alpha(c, mult):
    """Return colour ``c`` as a hex string with its alpha scaled by ``mult``."""
    red, green, blue, alpha = colors.to_rgba(c)
    scaled = (red, green, blue, alpha * mult)
    return colors.to_hex(scaled, keep_alpha=True)
3✔
66

67

68
class _Transposed:
3✔
69
    """Wrap an object in order to transpose some plotting operations
70

71
    Attributes of obj will be mapped.
72
    Keyword arguments when calling obj will be mapped.
73

74
    The mapping is not recursive: callable attributes need to be _Transposed
75
    again.
76
    """
77

78
    def __init__(self, obj):
3✔
79
        self.__obj = obj
3✔
80

81
    def __getattr__(self, key):
3✔
82
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))
3✔
83

84
    def __call__(self, *args, **kwargs):
3✔
85
        return self.__obj(*args, **{self._NAME_TRANSPOSE.get(k, k): v
3✔
86
                                    for k, v in kwargs.items()})
87

88
    _NAME_TRANSPOSE = {
3✔
89
        'width': 'height',
90
        'height': 'width',
91
        'hspace': 'wspace',
92
        'wspace': 'hspace',
93
        'hlines': 'vlines',
94
        'vlines': 'hlines',
95
        'bar': 'barh',
96
        'barh': 'bar',
97
        'xaxis': 'yaxis',
98
        'yaxis': 'xaxis',
99
        'left': 'bottom',
100
        'right': 'top',
101
        'top': 'right',
102
        'bottom': 'left',
103
        'sharex': 'sharey',
104
        'sharey': 'sharex',
105
        'get_figwidth': 'get_figheight',
106
        'get_figheight': 'get_figwidth',
107
        'set_figwidth': 'set_figheight',
108
        'set_figheight': 'set_figwidth',
109
        'set_xlabel': 'set_ylabel',
110
        'set_ylabel': 'set_xlabel',
111
        'set_xlim': 'set_ylim',
112
        'set_ylim': 'set_xlim',
113
        'get_xlim': 'get_ylim',
114
        'get_ylim': 'get_xlim',
115
        'set_autoscalex_on': 'set_autoscaley_on',
116
        'set_autoscaley_on': 'set_autoscalex_on',
117
    }
118

119

120
def _transpose(obj):
    """Transpose a name (if ``obj`` is a str) or wrap ``obj`` in _Transposed."""
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
3✔
124

125

126
def _identity(obj):
3✔
127
    return obj
3✔
128

129

130
class UpSet:
3✔
131
    """Manage the data and drawing for a basic UpSet plot
132

133
    Primary public method is :meth:`plot`.
134

135
    Parameters
136
    ----------
137
    data : pandas.Series or pandas.DataFrame
138
        Elements associated with categories (a DataFrame), or the size of each
139
        subset of categories (a Series).
140
        Should have MultiIndex where each level is binary,
141
        corresponding to category membership.
142
        If a DataFrame, `sum_over` must be a string or False.
143
    orientation : {'horizontal' (default), 'vertical'}
144
        If horizontal, intersections are listed from left to right.
145
    sort_by : {'cardinality', 'degree', None}
146
        If 'cardinality', subsets are listed from largest to smallest.
147
        If 'degree', they are listed in order of the number of categories
148
        intersected. If None, the order they appear in the data input is
149
        used.
150

151
        .. versionchanged:: 0.5
152
            Setting None was added.
153
    sort_categories_by : {'cardinality', None}
154
        Whether to sort the categories by total cardinality, or leave them
155
        in the provided order.
156

157
        .. versionadded:: 0.3
158
    subset_size : {'auto', 'count', 'sum'}
159
        Configures how to calculate the size of a subset. Choices are:
160

161
        'auto' (default)
162
            If `data` is a DataFrame, count the number of rows in each group,
163
            unless `sum_over` is specified.
164
            If `data` is a Series with at most one row for each group, use
165
            the value of the Series. If `data` is a Series with more than one
166
            row per group, raise a ValueError.
167
        'count'
168
            Count the number of rows in each group.
169
        'sum'
170
            Sum the value of the `data` Series, or the DataFrame field
171
            specified by `sum_over`.
172
    sum_over : str or None
173
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
174
        sum of the specified field in the `data` DataFrame. If a Series, only
175
        None is supported and its value is summed.
176
    min_subset_size : int, optional
177
        Minimum size of a subset to be shown in the plot. All subsets with
178
        a size smaller than this threshold will be omitted from plotting.
179
        Size may be a sum of values, see `subset_size`.
180

181
        .. versionadded:: 0.5
182
    max_subset_size : int, optional
183
        Maximum size of a subset to be shown in the plot. All subsets with
184
        a size greater than this threshold will be omitted from plotting.
185

186
        .. versionadded:: 0.5
187
    min_degree : int, optional
188
        Minimum degree of a subset to be shown in the plot.
189

190
        .. versionadded:: 0.5
191
    max_degree : int, optional
192
        Maximum degree of a subset to be shown in the plot.
193

194
        .. versionadded:: 0.5
195
    facecolor : 'auto' or matplotlib color or float
196
        Color for bar charts and active dots. Defaults to black if
197
        axes.facecolor is a light color, otherwise white.
198

199
        .. versionchanged:: 0.6
200
            Before 0.6, the default was 'black'
201
    other_dots_color : matplotlib color or float
202
        Color for shading of inactive dots, or opacity (between 0 and 1)
203
        applied to facecolor.
204

205
        .. versionadded:: 0.6
206
    shading_color : matplotlib color or float
207
        Color for shading of odd rows in matrix and totals, or opacity (between
208
        0 and 1) applied to facecolor.
209

210
        .. versionadded:: 0.6
211
    with_lines : bool
212
        Whether to show lines joining dots in the matrix, to mark multiple
213
        categories being intersected.
214
    element_size : float or None
215
        Side length in pt. If None, size is estimated to fit figure
216
    intersection_plot_elements : int
217
        The intersections plot should be large enough to fit this many matrix
218
        elements. Set to 0 to disable intersection size bars.
219

220
        .. versionchanged:: 0.4
221
            Setting to 0 is handled.
222
    totals_plot_elements : int
223
        The totals plot should be large enough to fit this many matrix
224
        elements.
225
    show_counts : bool or str, default=False
226
        Whether to label the intersection size bars with the cardinality
227
        of the intersection. When a string, this formats the number.
228
        For example, '{:d}' is equivalent to True.
229
        Note that, for legacy reasons, if the string does not contain '{',
230
        it will be interpreted as a C-style format string, such as '%d'.
231
    show_percentages : bool or str, default=False
232
        Whether to label the intersection size bars with the percentage
233
        of the intersection relative to the total dataset.
234
        When a string, this formats the number representing a fraction of
235
        samples.
236
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
237
        This may be applied with or without show_counts.
238

239
        .. versionadded:: 0.4
240
    """
241
    _default_figsize = (10, 6)
3✔
242

243
    def __init__(self, data, orientation='horizontal', sort_by='degree',
                 sort_categories_by='cardinality',
                 subset_size='auto', sum_over=None,
                 min_subset_size=None, max_subset_size=None,
                 min_degree=None, max_degree=None,
                 facecolor='auto', other_dots_color=.18, shading_color=.05,
                 with_lines=True, element_size=32,
                 intersection_plot_elements=6, totals_plot_elements=2,
                 show_counts='', show_percentages=False):
        """Process ``data`` and record the plot configuration.

        Parameters are described in the class docstring.
        """
        self._horizontal = orientation == 'horizontal'
        self._reorient = _identity if self._horizontal else _transpose

        # Only a string can be the 'auto' sentinel. Colours may also be
        # tuples or arrays, for which ``== 'auto'`` is elementwise/ambiguous.
        if isinstance(facecolor, str) and facecolor == 'auto':
            bgcolor = matplotlib.rcParams.get('axes.facecolor', 'white')
            r, g, b, a = colors.to_rgba(bgcolor)
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = 'black' if lightness >= .5 else 'white'
        self._facecolor = facecolor

        def _color_or_opacity(value):
            # A bare number is an opacity applied to facecolor; anything
            # else is taken to be a colour. Ints and NumPy scalars count as
            # numbers too, so e.g. ``shading_color=0`` behaves like ``0.0``.
            is_number = (isinstance(value, (int, float, np.number))
                         and not isinstance(value, bool))
            if is_number:
                return _multiply_alpha(facecolor, value)
            return value

        self._shading_color = _color_or_opacity(shading_color)
        self._other_dots_color = _color_or_opacity(other_dots_color)
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements

        # The intersection size bars are the first subset plot, unless
        # disabled by setting intersection_plot_elements to 0.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append({'type': 'default',
                                       'id': 'intersections',
                                       'elements': intersection_plot_elements})
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections,
         self.totals) = _process_data(data,
                                      sort_by=sort_by,
                                      sort_categories_by=sort_categories_by,
                                      subset_size=subset_size,
                                      sum_over=sum_over,
                                      min_subset_size=min_subset_size,
                                      max_subset_size=max_subset_size,
                                      min_degree=min_degree,
                                      max_degree=max_degree,
                                      reverse=not self._horizontal)
        # One independent style dict per subset, so that style_subsets can
        # update them individually.
        self.subset_styles = [{"facecolor": facecolor}
                              for _ in range(len(self.intersections))]
        self.subset_legend = []  # pairs of (style, label)
3✔
292

293
    def _swapaxes(self, x, y):
3✔
294
        if self._horizontal:
3✔
295
            return x, y
3✔
296
        return y, x
3✔
297

298
    def style_subsets(self, present=None, absent=None,
                      min_subset_size=None, max_subset_size=None,
                      min_degree=None, max_degree=None,
                      facecolor=None, edgecolor=None, hatch=None,
                      linewidth=None, linestyle=None, label=None):
        """Updates the style of selected subsets' bars and matrix dots

        Some parameters select which subsets are affected; the others are
        :class:`matplotlib.patches.Patch` attributes applied to them. The
        exception is label, which adds a legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added
        """
        requested = (("facecolor", facecolor), ("edgecolor", edgecolor),
                     ("hatch", hatch),
                     ("linewidth", linewidth), ("linestyle", linestyle))
        # Only explicitly given attributes override the existing styles.
        style = {attr: val for attr, val in requested if val is not None}

        selected = _get_subset_mask(self.intersections,
                                    present=present, absent=absent,
                                    min_subset_size=min_subset_size,
                                    max_subset_size=max_subset_size,
                                    min_degree=min_degree,
                                    max_degree=max_degree)
        for position in np.flatnonzero(selected):
            self.subset_styles[position].update(style)

        if label is None:
            return

        # Legend entries always carry a facecolor.
        style.setdefault("facecolor", self._facecolor)
        for pos, (known_style, known_label) in enumerate(self.subset_legend):
            if known_style != style:
                continue
            # An identical style is already in the legend: share the entry,
            # joining the labels unless the label is the same.
            if known_label != label:
                self.subset_legend[pos] = (style,
                                           known_label + '; ' + label)
            return
        self.subset_legend.append((style, label))
3✔
363

364
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
3✔
365
        ax = self._reorient(ax)
3✔
366
        ax.set_autoscalex_on(False)
3✔
367
        data_df = pd.DataFrame(data)
3✔
368
        if self._horizontal:
3✔
369
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
3✔
370

371
        # TODO: colors should be broadcastable to data_df shape
372
        if callable(colors):
3✔
373
            colors = colors(range(data_df.shape[1]))
3✔
374
        elif isinstance(colors, (str, type(None))):
3✔
375
            colors = [colors] * len(data_df)
3✔
376

377
        if self._horizontal:
3✔
378
            colors = reversed(colors)
3✔
379

380
        x = np.arange(len(data_df))
3✔
381
        cum_y = None
3✔
382
        all_rects = []
3✔
383
        for (name, y), color in zip(data_df.items(), colors):
3✔
384
            rects = ax.bar(x, y, .5, cum_y,
3✔
385
                           color=color, zorder=10,
386
                           label=name if use_labels else None,
387
                           align='center')
388
            cum_y = y if cum_y is None else cum_y + y
3✔
389
            all_rects.extend(rects)
3✔
390

391
        self._label_sizes(ax, rects, 'top' if self._horizontal else 'right')
3✔
392

393
        ax.xaxis.set_visible(False)
3✔
394
        for x in ['top', 'bottom', 'right']:
3✔
395
            ax.spines[self._reorient(x)].set_visible(False)
3✔
396

397
        tick_axis = ax.yaxis
3✔
398
        tick_axis.grid(True)
3✔
399
        ax.set_ylabel(title)
3✔
400
        return all_rects
3✔
401

402
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
3✔
403
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
3✔
404
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
3✔
405
        if sum_over is None and "_value" in df.columns:
3✔
406
            data = gb["_value"].sum()
×
407
        elif sum_over is None:
3✔
408
            data = gb.size()
3✔
409
        else:
410
            data = gb[sum_over].sum()
3✔
411
        data = data.unstack(by).fillna(0)
3✔
412
        if isinstance(colors, str):
3✔
413
            colors = matplotlib.cm.get_cmap(colors)
3✔
414
        elif isinstance(colors, typing.Mapping):
3✔
415
            colors = data.columns.map(colors).values
3✔
416
            if pd.isna(colors).any():
3✔
417
                raise KeyError("Some labels mapped by colors: %r" %
×
418
                               data.columns[pd.isna(colors)].tolist())
419

420
        self._plot_bars(ax, data=data, colors=colors, title=title,
3✔
421
                        use_labels=True)
422

423
        handles, labels = ax.get_legend_handles_labels()
3✔
424
        if self._horizontal:
3✔
425
            # Make legend order match visual stack order
426
            ax.legend(reversed(handles), reversed(labels))
3✔
427
        else:
428
            ax.legend()
3✔
429

430
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3,
3✔
431
                         title=None):
432
        """Add a stacked bar chart over subsets when :func:`plot` is called.
433

434
        Used to plot categorical variable distributions within each subset.
435

436
        .. versionadded:: 0.6
437

438
        Parameters
439
        ----------
440
        by : str
441
            Column name within the dataframe for color coding the stacked bars,
442
            containing discrete or categorical values.
443
        sum_over : str, optional
444
            Ordinarily the bars will chart the size of each group. sum_over
445
            may specify a column which will be summed to determine the size
446
            of each bar.
447
        colors : Mapping, list-like, str or callable, optional
448
            The facecolors to use for bars corresponding to each discrete
449
            label, specified as one of:
450

451
            Mapping
452
                Maps from label to matplotlib-compatible color specification.
453
            list-like
454
                A list of matplotlib colors to apply to labels in order.
455
            str
456
                The name of a matplotlib colormap name.
457
            callable
458
                When called with the number of labels, this should return a
459
                list-like of that many colors.  Matplotlib colormaps satisfy
460
                this callable API.
461
            None
462
                Uses the matplotlib default colormap.
463
        elements : int, default=3
464
            Size of the axes counted in number of matrix elements.
465
        title : str, optional
466
            The axis title labelling bar length.
467

468
        Returns
469
        -------
470
        None
471
        """
472
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
473
        #                        list of labels}
474
        self._subset_plots.append({'type': 'stacked_bars',
3✔
475
                                   'by': by,
476
                                   'sum_over': sum_over,
477
                                   'colors': colors,
478
                                   'title': title,
479
                                   'id': 'extra%d' % len(self._subset_plots),
480
                                   'elements': elements})
481

482
    def add_catplot(self, kind, value=None, elements=3, **kw):
3✔
483
        """Add a seaborn catplot over subsets when :func:`plot` is called.
484

485
        Parameters
486
        ----------
487
        kind : str
488
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
489
        value : str, optional
490
            Column name for the value to plot (i.e. y if
491
            orientation='horizontal'), required if `data` is a DataFrame.
492
        elements : int, default=3
493
            Size of the axes counted in number of matrix elements.
494
        **kw : dict
495
            Additional keywords to pass to :func:`seaborn.catplot`.
496

497
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
498
            and 'orient', so these are prohibited keys in `kw`.
499

500
        Returns
501
        -------
502
        None
503
        """
504
        assert not set(kw.keys()) & {'ax', 'data', 'x', 'y', 'orient'}
2✔
505
        if value is None:
2✔
506
            if '_value' not in self._df.columns:
2✔
507
                raise ValueError('value cannot be set if data is a Series. '
2✔
508
                                 'Got %r' % value)
509
        else:
510
            if value not in self._df.columns:
2✔
511
                raise ValueError('value %r is not a column in data' % value)
2✔
512
        self._subset_plots.append({'type': 'catplot',
2✔
513
                                   'value': value,
514
                                   'kind': kind,
515
                                   'id': 'extra%d' % len(self._subset_plots),
516
                                   'elements': elements,
517
                                   'kw': kw})
518

519
    def _check_value(self, value):
3✔
520
        if value is None and '_value' in self._df.columns:
2✔
521
            value = '_value'
2✔
522
        elif value is None:
2✔
523
            raise ValueError('value can only be None when data is a Series')
×
524
        return value
2✔
525

526
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of ``value`` per subset onto ``ax``.

        ``kw`` is copied before the 'orient', 'x', 'y' and 'ax' entries are
        filled in, so the caller's dict is left untouched.
        """
        value = self._check_value(value)

        # The subset position ('_bin') runs along the axis shared with the
        # matrix; the plotted value runs along the other one.
        if self._horizontal:
            layout = {'orient': 'v', 'x': '_bin', 'y': value}
        else:
            layout = {'orient': 'h', 'x': value, 'y': '_bin'}
        plot_kw = kw.copy()
        plot_kw.update(layout)

        import seaborn
        plot_kw['ax'] = ax
        plot_func = getattr(seaborn, kind + 'plot')
        plot_func(data=self._df, **plot_kw)

        ax = self._reorient(ax)
        if value == '_value':
            ax.set_ylabel('')

        ax.xaxis.set_visible(False)
        for side in ['top', 'bottom', 'right']:
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
2✔
552

553
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : matplotlib Figure, optional
            Figure to lay out. Defaults to the current figure (``plt.gcf()``).
            When ``element_size`` is set, the figure's width and height are
            changed to fit the grid.

        Returns
        -------
        dict
            Entries 'matrix', 'shading' and 'totals', plus one entry keyed by
            the 'id' of each subset plot, each a slice of the grid; and 'gs',
            the grid itself.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        # Determine text size to determine figure size / spacing
        text_kw = {"size": matplotlib.rcParams['xtick.labelsize']}
        # adding "x" ensures a margin
        # The category labels are drawn temporarily, one per line, purely to
        # measure their rendered width; the text is removed again below.
        t = fig.text(0, 0, '\n'.join(str(label) + "x"
                                     for label in self.totals.index.values),
                     **text_kw)
        # A renderer is only passed explicitly when get_renderer could be
        # imported (see RENDERER_IMPORTED at module level).
        window_extent_args = {}
        if RENDERER_IMPORTED:
            window_extent_args["renderer"] = get_renderer(fig)
        textw = t.get_window_extent(**window_extent_args).width
        t.remove()

        window_extent_args = {}
        if RENDERER_IMPORTED:
            window_extent_args["renderer"] = get_renderer(fig)
        # "Width" here and below means the extent along the axis on which
        # intersections are laid out: _reorient swaps width and height for
        # vertical orientation.
        figw = self._reorient(
            fig.get_window_extent(**window_extent_args)).width

        # Number of matrix elements requested by each subset plot
        sizes = np.asarray([p['elements'] for p in self._subset_plots])
        fig = self._reorient(fig)

        non_text_nelems = len(self.intersections) + self._totals_plot_elements
        if self._element_size is None:
            # Fit the elements into whatever space the labels leave over
            colw = (figw - textw) / non_text_nelems
        else:
            # Fixed element size: resize the figure to fit the grid instead.
            # render_ratio converts figure size units to window extent units;
            # dividing by 72 presumably converts element_size from points to
            # inches -- NOTE(review): confirm units.
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) /
                              render_ratio)

        # Number of grid cells left over for the category labels
        text_nelems = int(np.ceil(figw / colw - non_text_nelems))
        # print('textw', textw, 'figw', figw, 'colw', colw,
        #       'ncols', figw/colw, 'text_nelems', text_nelems)

        # The grid has one cell per matrix element: categories plus subset
        # plots along one axis; intersections, labels and totals along the
        # other. _swapaxes / _reorient transpose this for vertical plots.
        GS = self._reorient(matplotlib.gridspec.GridSpec)
        gridspec = GS(*self._swapaxes(n_cats + (sizes.sum() or 0),
                                      n_inters + text_nelems +
                                      self._totals_plot_elements),
                      hspace=1)
        if self._horizontal:
            out = {'matrix': gridspec[-n_cats:, -n_inters:],
                   'shading': gridspec[-n_cats:, :],
                   'totals': gridspec[-n_cats:, :self._totals_plot_elements],
                   'gs': gridspec}
            # Subset plots are stacked above the matrix, in reverse order of
            # addition from the top, so the first plot is nearest the matrix.
            cumsizes = np.cumsum(sizes[::-1])
            for start, stop, plot in zip(np.hstack([[0], cumsizes]), cumsizes,
                                         self._subset_plots[::-1]):
                out[plot['id']] = gridspec[start:stop, -n_inters:]
        else:
            out = {'matrix': gridspec[-n_inters:, :n_cats],
                   'shading': gridspec[:, :n_cats],
                   'totals': gridspec[:self._totals_plot_elements, :n_cats],
                   'gs': gridspec}
            # Subset plots follow the matrix columns, in order of addition.
            cumsizes = np.cumsum(sizes)
            for start, stop, plot in zip(np.hstack([[0], cumsizes]), cumsizes,
                                         self._subset_plots):
                out[plot['id']] = \
                    gridspec[-n_inters:, start + n_cats:stop + n_cats]
        return out
3✔
623

624
    def plot_matrix(self, ax):
3✔
625
        """Plot the matrix of intersection indicators onto ax
626
        """
627
        ax = self._reorient(ax)
3✔
628
        data = self.intersections
3✔
629
        n_cats = data.index.nlevels
3✔
630

631
        inclusion = data.index.to_frame().values
3✔
632

633
        # Prepare styling
634
        styles = [
3✔
635
            [
636
                self.subset_styles[i]
637
                if inclusion[i, j]
638
                else {"facecolor": self._other_dots_color, "linewidth": 0}
639
                for j in range(n_cats)
640
            ]
641
            for i in range(len(data))
642
        ]
643
        styles = sum(styles, [])  # flatten nested list
3✔
644
        style_columns = {"facecolor": "facecolors",
3✔
645
                         "edgecolor": "edgecolors",
646
                         "linewidth": "linewidths",
647
                         "linestyle": "linestyles",
648
                         "hatch": "hatch"}
649
        styles = pd.DataFrame(styles).reindex(columns=style_columns.keys())
3✔
650
        styles["linewidth"].fillna(1, inplace=True)
3✔
651
        styles["facecolor"].fillna(self._facecolor, inplace=True)
3✔
652
        styles["edgecolor"].fillna(styles["facecolor"], inplace=True)
3✔
653
        styles["linestyle"].fillna("solid", inplace=True)
3✔
654
        del styles["hatch"]  # not supported in matrix (currently)
3✔
655

656
        x = np.repeat(np.arange(len(data)), n_cats)
3✔
657
        y = np.tile(np.arange(n_cats), len(data))
3✔
658

659
        # Plot dots
660
        if self._element_size is not None:
3✔
661
            s = (self._element_size * .35) ** 2
3✔
662
        else:
663
            # TODO: make s relative to colw
664
            s = 200
3✔
665
        ax.scatter(*self._swapaxes(x, y), s=s, zorder=10,
3✔
666
                   **styles.rename(columns=style_columns))
667

668
        # Plot lines
669
        if self._with_lines:
3✔
670
            idx = np.flatnonzero(inclusion)
3✔
671
            line_data = (pd.Series(y[idx], index=x[idx])
3✔
672
                         .groupby(level=0)
673
                         .aggregate(['min', 'max']))
674
            colors = pd.Series([
3✔
675
                style.get("edgecolor", style.get("facecolor", self._facecolor))
676
                for style in self.subset_styles],
677
                name="color")
678
            line_data = line_data.join(colors)
3✔
679
            ax.vlines(line_data.index.values,
3✔
680
                      line_data['min'], line_data['max'],
681
                      lw=2, colors=line_data["color"],
682
                      zorder=5)
683

684
        # Ticks and axes
685
        tick_axis = ax.yaxis
3✔
686
        tick_axis.set_ticks(np.arange(n_cats))
3✔
687
        tick_axis.set_ticklabels(data.index.names,
3✔
688
                                 rotation=0 if self._horizontal else -90)
689
        ax.xaxis.set_visible(False)
3✔
690
        ax.tick_params(axis='both', which='both', length=0)
3✔
691
        if not self._horizontal:
3✔
692
            ax.yaxis.set_ticks_position('top')
3✔
693
        ax.set_frame_on(False)
3✔
694
        ax.set_xlim(-.5, x[-1] + .5, auto=False)
3✔
695
        ax.grid(False)
3✔
696

697
    def plot_intersections(self, ax):
        """Plot bars indicating intersection size

        Each bar is styled with the corresponding entry of
        ``subset_styles``. A legend is added if ``subset_legend`` has entries.
        """
        bars = self._plot_bars(ax, self.intersections,
                               title='Intersection size',
                               colors=self._facecolor)
        for subset_style, bar in zip(self.subset_styles, bars):
            resolved = subset_style.copy()
            # Edges default to the face colour of the same bar.
            if "edgecolor" not in resolved:
                resolved["edgecolor"] = resolved.get("facecolor",
                                                     self._facecolor)
            for attr, val in resolved.items():
                setter = getattr(bar, "set_" + attr)
                setter(val)

        if not self.subset_legend:
            return
        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**patch_style)
                   for patch_style in legend_styles]
        ax.legend(handles, legend_labels)
1✔
714

715
    def _label_sizes(self, ax, rects, where):
3✔
716
        if not self._show_counts and not self._show_percentages:
3✔
717
            return
3✔
718
        if self._show_counts is True:
3✔
719
            count_fmt = "{:.0f}"
3✔
720
        else:
721
            count_fmt = self._show_counts
3✔
722
            if '{' not in count_fmt:
3✔
723
                count_fmt = util.to_new_pos_format(count_fmt)
3✔
724

725
        if self._show_percentages is True:
3✔
726
            pct_fmt = "{:.1%}"
3✔
727
        else:
728
            pct_fmt = self._show_percentages
3✔
729

730
        if count_fmt and pct_fmt:
3✔
731
            if where == 'top':
3✔
732
                fmt = '%s\n(%s)' % (count_fmt, pct_fmt)
3✔
733
            else:
734
                fmt = '%s (%s)' % (count_fmt, pct_fmt)
3✔
735

736
            def make_args(val):
3✔
737
                return val, val / self.total
3✔
738
        elif count_fmt:
3✔
739
            fmt = count_fmt
3✔
740

741
            def make_args(val):
3✔
742
                return val,
3✔
743
        else:
744
            fmt = pct_fmt
3✔
745

746
            def make_args(val):
3✔
747
                return val / self.total,
3✔
748

749
        if where == 'right':
3✔
750
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
3✔
751
            for rect in rects:
3✔
752
                width = rect.get_width() + rect.get_x()
3✔
753
                ax.text(width + margin,
3✔
754
                        rect.get_y() + rect.get_height() * .5,
755
                        fmt.format(*make_args(width)),
756
                        ha='left', va='center')
757
        elif where == 'left':
3✔
758
            margin = 0.01 * abs(np.diff(ax.get_xlim()))
3✔
759
            for rect in rects:
3✔
760
                width = rect.get_width() + rect.get_x()
3✔
761
                ax.text(width + margin,
3✔
762
                        rect.get_y() + rect.get_height() * .5,
763
                        fmt.format(*make_args(width)),
764
                        ha='right', va='center')
765
        elif where == 'top':
3✔
766
            margin = 0.01 * abs(np.diff(ax.get_ylim()))
3✔
767
            for rect in rects:
3✔
768
                height = rect.get_height() + rect.get_y()
3✔
769
                ax.text(rect.get_x() + rect.get_width() * .5,
3✔
770
                        height + margin,
771
                        fmt.format(*make_args(height)),
772
                        ha='center', va='bottom')
773
        else:
774
            raise NotImplementedError('unhandled where: %r' % where)
×
775

776
    def plot_totals(self, ax):
        """Draw one bar per category showing its total size.

        Bars are drawn through the ``self._reorient`` wrapper of ``ax`` and
        labelled via ``self._label_sizes``.  In horizontal layout the x-limits
        of the original axes are set from the largest total down to zero.
        """
        given_ax = ax
        ax = self._reorient(given_ax)

        positions = np.arange(len(self.totals.index.values))
        bars = ax.barh(positions, self.totals,
                       .5, color=self._facecolor, align='center')
        label_side = 'left' if self._horizontal else 'top'
        self._label_sizes(ax, bars, label_side)

        largest = self.totals.max()
        if self._horizontal:
            given_ax.set_xlim(largest, 0)

        # Hide every spine except the one mapped from 'bottom'.
        for side in ('top', 'left', 'right'):
            spine = ax.spines[self._reorient(side)]
            spine.set_visible(False)

        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)
794

795
    def plot_shading(self, ax):
        """Draw the alternating background stripes onto ``ax``.

        A rectangle is added for every second entry of ``self.totals``, then
        the frame, grid, ticks and tick labels of ``ax`` are all removed.
        """
        # alternating row shading (XXX: use add_patch(Rectangle)?)
        for row in range(0, len(self.totals), 2):
            anchor = self._swapaxes(0, row - .4)
            extent = self._swapaxes(1, .8)
            stripe = plt.Rectangle(anchor, *extent,
                                   facecolor=self._shading_color,
                                   lw=0, zorder=0)
            ax.add_patch(stripe)

        ax.set_frame_on(False)

        # Switch off tick marks on all four sides and the tick labels.
        hidden = dict.fromkeys(
            ('left', 'right', 'bottom', 'top', 'labelbottom', 'labelleft'),
            False)
        ax.tick_params(axis='both', which='both', **hidden)

        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
817

818
    def plot(self, fig=None):
        """Render every component of the plot on a figure.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw on.  A new figure sized by
            ``self._default_figsize`` is created when omitted.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'intersections', 'totals', 'shading'
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)

        # Shading comes first; matrix and totals share its y-axis chain.
        shading_ax = fig.add_subplot(specs['shading'])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(specs['matrix'],
                                                    sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_ax = self._reorient(fig.add_subplot)(specs['totals'],
                                                    sharey=matrix_ax)
        self.plot_totals(totals_ax)

        axes_by_name = {'matrix': matrix_ax,
                        'shading': shading_ax,
                        'totals': totals_ax}

        for subset_plot in self._subset_plots:
            ax = self._reorient(fig.add_subplot)(specs[subset_plot['id']],
                                                 sharex=matrix_ax)
            kind = subset_plot['type']
            if kind == 'default':
                self.plot_intersections(ax)
            elif kind in self.PLOT_TYPES:
                # Everything except the bookkeeping keys is forwarded
                # to the plotting function as keyword arguments.
                extra = subset_plot.copy()
                for bookkeeping_key in ('type', 'elements', 'id'):
                    extra.pop(bookkeeping_key)
                draw = self.PLOT_TYPES[subset_plot['type']]
                draw(self, ax, **extra)
            else:
                raise ValueError('Unknown subset plot type: %r' % subset_plot['type'])
            axes_by_name[subset_plot['id']] = ax

        return axes_by_name
861

862
    # Maps a subset plot's 'type' value to the function that draws it;
    # plot() looks up every non-'default' type here.
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )
866

867
    def _repr_html_(self):
        """Draw the plot on a new figure and return that figure's HTML repr."""
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html
871

872

873
def plot(data, fig=None, **kwargs):
    """Draw an UpSet plot of ``data`` in a single call.

    Convenience wrapper that builds an :class:`UpSet` from ``data`` and
    ``kwargs`` and immediately calls its ``plot`` method.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Figure to draw on.  Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc