• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 492

pending completion
492

push

travis-ci-com

web-flow
use diabetes dataset instead of boston (#200)

1119 of 1139 relevant lines covered (98.24%)

2.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.41
/upsetplot/plotting.py
1
from __future__ import print_function, division, absolute_import
3✔
2

3
try:
3✔
4
    import typing
3✔
5
except ImportError:
×
6
    import collections as typing
×
7

8
import numpy as np
3✔
9
import pandas as pd
3✔
10
import matplotlib
3✔
11
from matplotlib import pyplot as plt
3✔
12
from matplotlib import colors
3✔
13
from matplotlib import patches
3✔
14

15
from .reformat import query, _get_subset_mask
3✔
16

17
# prevents ImportError on matplotlib versions >3.5.2
18
try:
3✔
19
    from matplotlib.tight_layout import get_renderer
3✔
20

21
    RENDERER_IMPORTED = True
3✔
22
except ImportError:
×
23
    RENDERER_IMPORTED = False
×
24

25

26
def _process_data(df, sort_by, sort_categories_by, subset_size,
                  sum_over, min_subset_size=None, max_subset_size=None,
                  min_degree=None, max_degree=None, reverse=False):
    """Query the data and tag every row with the position of its subset.

    Filtering, sorting and aggregation are delegated to :func:`query`.
    A ``'_bin'`` column is then added to the queried data, giving for each
    row the positional index of its subset within the subset sizes.

    Parameters
    ----------
    df, sort_by, sort_categories_by, subset_size, sum_over, \
min_subset_size, max_subset_size, min_degree, max_degree
        Passed through unchanged to :func:`query`.
    reverse : bool
        If true, the subset sizes are returned in reverse order and the
        ``'_bin'`` positions are numbered to match that reversed order.

    Returns
    -------
    total
        Sum of all subset sizes.
    df
        ``results.data`` with the extra ``'_bin'`` column.
    agg
        ``results.subset_sizes`` (reversed when ``reverse`` is true).
    totals
        ``results.category_totals``.
    """
    def _encode_membership(frame):
        # Read each row as a binary number, first column most significant.
        frame = pd.DataFrame(frame)
        code = 0
        for _, column in frame.items():
            code = code * 2 + column
        return code

    results = query(df, sort_by=sort_by,
                    sort_categories_by=sort_categories_by,
                    subset_size=subset_size, sum_over=sum_over,
                    min_subset_size=min_subset_size,
                    max_subset_size=max_subset_size,
                    min_degree=min_degree, max_degree=max_degree)

    df = results.data
    agg = results.subset_sizes
    totals = results.category_totals
    total = agg.sum()

    # XXX: ugly!  Match rows to subsets through their packed membership code.
    row_codes = _encode_membership(df.index.to_frame())
    subset_codes = _encode_membership(agg.index.to_frame())

    positions = np.arange(len(subset_codes))
    if reverse:
        positions = positions[::-1]
    code_to_position = pd.Series(positions, index=subset_codes)
    df['_bin'] = pd.Series(row_codes).map(code_to_position)

    if reverse:
        agg = agg[::-1]

    return total, df, agg, totals
3✔
59

60

61
def _multiply_alpha(c, mult):
    """Return color ``c`` as a hex string with its alpha scaled by ``mult``.

    ``c`` is anything accepted by ``matplotlib.colors.to_rgba``; the result
    comes from ``matplotlib.colors.to_hex`` with ``keep_alpha=True``.
    """
    red, green, blue, alpha = colors.to_rgba(c)
    scaled = (red, green, blue, alpha * mult)
    return colors.to_hex(scaled, keep_alpha=True)
3✔
65

66

67
class _Transposed:
    """Wrap an object in order to transpose some plotting operations

    Attributes of obj will be mapped.
    Keyword arguments when calling obj will be mapped.

    The mapping is not recursive: callable attributes need to be _Transposed
    again.

    Names are translated through ``_NAME_TRANSPOSE``; names that are not
    listed there are passed through unchanged.
    """

    def __init__(self, obj):
        self.__obj = obj

    def __getattr__(self, key):
        # __getattr__ only runs once normal lookup has failed.  If the
        # wrapped object itself is missing (e.g. on an instance that copy or
        # pickle created without calling __init__), delegating would look
        # up the same missing attribute again and recurse without bound.
        if key == '_Transposed__obj':
            raise AttributeError(key)
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))

    def __call__(self, *args, **kwargs):
        # Positional arguments are forwarded as-is; only keyword names are
        # transposed.
        return self.__obj(*args, **{self._NAME_TRANSPOSE.get(k, k): v
                                    for k, v in kwargs.items()})

    _NAME_TRANSPOSE = {
        'width': 'height',
        'height': 'width',
        'hspace': 'wspace',
        'wspace': 'hspace',
        'hlines': 'vlines',
        'vlines': 'hlines',
        'bar': 'barh',
        'barh': 'bar',
        'xaxis': 'yaxis',
        'yaxis': 'xaxis',
        'left': 'bottom',
        'right': 'top',
        'top': 'right',
        'bottom': 'left',
        'sharex': 'sharey',
        'sharey': 'sharex',
        'get_figwidth': 'get_figheight',
        'get_figheight': 'get_figwidth',
        'set_figwidth': 'set_figheight',
        'set_figheight': 'set_figwidth',
        'set_xlabel': 'set_ylabel',
        'set_ylabel': 'set_xlabel',
        'set_xlim': 'set_ylim',
        'set_ylim': 'set_xlim',
        'get_xlim': 'get_ylim',
        'get_ylim': 'get_xlim',
        'set_autoscalex_on': 'set_autoscaley_on',
        'set_autoscaley_on': 'set_autoscalex_on',
    }
117

118

119
def _transpose(obj):
    """Transpose a name, or wrap an object so that its API is transposed.

    Strings are translated through ``_Transposed._NAME_TRANSPOSE`` (unknown
    names are returned unchanged); anything else is wrapped in
    :class:`_Transposed`.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
3✔
123

124

125
def _identity(obj):
    """Return ``obj`` unchanged (the no-op counterpart of ``_transpose``)."""
    return obj
3✔
127

128

129
class UpSet:
3✔
130
    """Manage the data and drawing for a basic UpSet plot
131

132
    Primary public method is :meth:`plot`.
133

134
    Parameters
135
    ----------
136
    data : pandas.Series or pandas.DataFrame
137
        Elements associated with categories (a DataFrame), or the size of each
138
        subset of categories (a Series).
139
        Should have MultiIndex where each level is binary,
140
        corresponding to category membership.
141
        If a DataFrame, `sum_over` must be a string or False.
142
    orientation : {'horizontal' (default), 'vertical'}
143
        If horizontal, intersections are listed from left to right.
144
    sort_by : {'cardinality', 'degree', None}
145
        If 'cardinality', subset are listed from largest to smallest.
146
        If 'degree', they are listed in order of the number of categories
147
        intersected. If None, the order they appear in the data input is
148
        used.
149

150
        .. versionchanged:: 0.5
151
            Setting None was added.
152
    sort_categories_by : {'cardinality', None}
153
        Whether to sort the categories by total cardinality, or leave them
154
        in the provided order.
155

156
        .. versionadded:: 0.3
157
    subset_size : {'auto', 'count', 'sum'}
158
        Configures how to calculate the size of a subset. Choices are:
159

160
        'auto' (default)
161
            If `data` is a DataFrame, count the number of rows in each group,
162
            unless `sum_over` is specified.
163
            If `data` is a Series with at most one row for each group, use
164
            the value of the Series. If `data` is a Series with more than one
165
            row per group, raise a ValueError.
166
        'count'
167
            Count the number of rows in each group.
168
        'sum'
169
            Sum the value of the `data` Series, or the DataFrame field
170
            specified by `sum_over`.
171
    sum_over : str or None
172
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
173
        sum of the specified field in the `data` DataFrame. If a Series, only
174
        None is supported and its value is summed.
175
    min_subset_size : int, optional
176
        Minimum size of a subset to be shown in the plot. All subsets with
177
        a size smaller than this threshold will be omitted from plotting.
178
        Size may be a sum of values, see `subset_size`.
179

180
        .. versionadded:: 0.5
181
    max_subset_size : int, optional
182
        Maximum size of a subset to be shown in the plot. All subsets with
183
        a size greater than this threshold will be omitted from plotting.
184

185
        .. versionadded:: 0.5
186
    min_degree : int, optional
187
        Minimum degree of a subset to be shown in the plot.
188

189
        .. versionadded:: 0.5
190
    max_degree : int, optional
191
        Maximum degree of a subset to be shown in the plot.
192

193
        .. versionadded:: 0.5
194
    facecolor : 'auto' or matplotlib color or float
195
        Color for bar charts and active dots. Defaults to black if
196
        axes.facecolor is a light color, otherwise white.
197

198
        .. versionchanged:: 0.6
199
            Before 0.6, the default was 'black'
200
    other_dots_color : matplotlib color or float
201
        Color for shading of inactive dots, or opacity (between 0 and 1)
202
        applied to facecolor.
203

204
        .. versionadded:: 0.6
205
    shading_color : matplotlib color or float
206
        Color for shading of odd rows in matrix and totals, or opacity (between
207
        0 and 1) applied to facecolor.
208

209
        .. versionadded:: 0.6
210
    with_lines : bool
211
        Whether to show lines joining dots in the matrix, to mark multiple
212
        categories being intersected.
213
    element_size : float or None
214
        Side length in pt. If None, size is estimated to fit figure
215
    intersection_plot_elements : int
216
        The intersections plot should be large enough to fit this many matrix
217
        elements. Set to 0 to disable intersection size bars.
218

219
        .. versionchanged:: 0.4
220
            Setting to 0 is handled.
221
    totals_plot_elements : int
222
        The totals plot should be large enough to fit this many matrix
223
        elements.
224
    show_counts : bool or str, default=False
225
        Whether to label the intersection size bars with the cardinality
226
        of the intersection. When a string, this formats the number.
227
        For example, '%d' is equivalent to True.
228
    show_percentages : bool, default=False
229
        Whether to label the intersection size bars with the percentage
230
        of the intersection relative to the total dataset.
231
        This may be applied with or without show_counts.
232

233
        .. versionadded:: 0.4
234
    """
235
    _default_figsize = (10, 6)
3✔
236

237
    def __init__(self, data, orientation='horizontal', sort_by='degree',
                 sort_categories_by='cardinality',
                 subset_size='auto', sum_over=None,
                 min_subset_size=None, max_subset_size=None,
                 min_degree=None, max_degree=None,
                 facecolor='auto', other_dots_color=.18, shading_color=.05,
                 with_lines=True, element_size=32,
                 intersection_plot_elements=6, totals_plot_elements=2,
                 show_counts='', show_percentages=False):
        # Parameters are described in the class docstring.

        # Orientation: vertical plots reuse the horizontal drawing code
        # through a transposing wrapper.
        horizontal = orientation == 'horizontal'
        self._horizontal = horizontal
        self._reorient = _identity if horizontal else _transpose

        # Colours
        if facecolor == 'auto':
            # Contrast with the axes background: black on light, white on
            # dark, judged by alpha-weighted HSV value.
            background = matplotlib.rcParams.get('axes.facecolor', 'white')
            red, green, blue, alpha = colors.to_rgba(background)
            brightness = colors.rgb_to_hsv((red, green, blue))[-1] * alpha
            facecolor = 'black' if brightness >= .5 else 'white'
        self._facecolor = facecolor

        def _resolve_color(spec):
            # A float is an opacity applied to facecolor; anything else is
            # used as given.
            if isinstance(spec, float):
                return _multiply_alpha(facecolor, spec)
            return spec

        self._shading_color = _resolve_color(shading_color)
        self._other_dots_color = _resolve_color(other_dots_color)

        # Layout and labelling options
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        # The intersection size bars are omitted when given zero elements.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append({'type': 'default',
                                       'id': 'intersections',
                                       'elements': intersection_plot_elements})

        # Data
        processed = _process_data(data,
                                  sort_by=sort_by,
                                  sort_categories_by=sort_categories_by,
                                  subset_size=subset_size,
                                  sum_over=sum_over,
                                  min_subset_size=min_subset_size,
                                  max_subset_size=max_subset_size,
                                  min_degree=min_degree,
                                  max_degree=max_degree,
                                  reverse=not horizontal)
        self.total, self._df, self.intersections, self.totals = processed

        # One independent style dict per subset, so each can be restyled.
        self.subset_styles = [{"facecolor": facecolor}
                              for _ in range(len(self.intersections))]
        self.subset_legend = []  # pairs of (style, label)
3✔
286

287
    def _swapaxes(self, x, y):
3✔
288
        if self._horizontal:
3✔
289
            return x, y
3✔
290
        return y, x
3✔
291

292
    def style_subsets(self, present=None, absent=None,
                      min_subset_size=None, max_subset_size=None,
                      min_degree=None, max_degree=None,
                      facecolor=None, edgecolor=None, hatch=None,
                      linewidth=None, linestyle=None, label=None):
        """Updates the style of selected subsets' bars and matrix dots

        Each parameter either selects subsets or sets a style attribute of
        :class:`matplotlib.patches.Patch` on them; ``label`` instead adds a
        legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories that must be present in subsets for styling.
        absent : str or list of str, optional
            Category or categories that must not be present in subsets for
            styling.
        min_subset_size : int, optional
            Minimum size of a subset to be styled.
        max_subset_size : int, optional
            Maximum size of a subset to be styled.
        min_degree : int, optional
            Minimum degree of a subset to be styled.
        max_degree : int, optional
            Maximum degree of a subset to be styled.

        facecolor : str or matplotlib color, optional
            Override the default UpSet facecolor for selected subsets.
        edgecolor : str or matplotlib color, optional
            Set the edgecolor for bars, dots, and the line between dots.
        hatch : str, optional
            Set the hatch. This will apply to intersection size bars, but not
            to matrix dots.
        linewidth : int, optional
            Line width in points for edges.
        linestyle : str, optional
            Line style for edges.

        label : str, optional
            If provided, a legend will be added
        """
        # Only attributes that were explicitly given take part in styling.
        requested = (("facecolor", facecolor), ("edgecolor", edgecolor),
                     ("hatch", hatch),
                     ("linewidth", linewidth), ("linestyle", linestyle))
        style = {attr: val for attr, val in requested if val is not None}

        selected = _get_subset_mask(self.intersections,
                                    present=present, absent=absent,
                                    min_subset_size=min_subset_size,
                                    max_subset_size=max_subset_size,
                                    min_degree=min_degree,
                                    max_degree=max_degree)
        for position in np.flatnonzero(selected):
            self.subset_styles[position].update(style)

        if label is None:
            return

        # The legend patch always needs a face colour.
        if "facecolor" not in style:
            style["facecolor"] = self._facecolor

        # An identical style already in the legend shares its entry: the
        # labels are joined unless the label is the same.
        for position, (known_style, known_label) in \
                enumerate(self.subset_legend):
            if known_style == style:
                if known_label != label:
                    self.subset_legend[position] = (
                        style, known_label + '; ' + label)
                return
        self.subset_legend.append((style, label))
3✔
357

358
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
3✔
359
        ax = self._reorient(ax)
3✔
360
        ax.set_autoscalex_on(False)
3✔
361
        data_df = pd.DataFrame(data)
3✔
362
        if self._horizontal:
3✔
363
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
3✔
364

365
        # TODO: colors should be broadcastable to data_df shape
366
        if callable(colors):
3✔
367
            colors = colors(range(data_df.shape[1]))
3✔
368
        elif isinstance(colors, (str, type(None))):
3✔
369
            colors = [colors] * len(data_df)
3✔
370

371
        if self._horizontal:
3✔
372
            colors = reversed(colors)
3✔
373

374
        x = np.arange(len(data_df))
3✔
375
        cum_y = None
3✔
376
        all_rects = []
3✔
377
        for (name, y), color in zip(data_df.items(), colors):
3✔
378
            rects = ax.bar(x, y, .5, cum_y,
3✔
379
                           color=color, zorder=10,
380
                           label=name if use_labels else None,
381
                           align='center')
382
            cum_y = y if cum_y is None else cum_y + y
3✔
383
            all_rects.extend(rects)
3✔
384

385
        self._label_sizes(ax, rects, 'top' if self._horizontal else 'right')
3✔
386

387
        ax.xaxis.set_visible(False)
3✔
388
        for x in ['top', 'bottom', 'right']:
3✔
389
            ax.spines[self._reorient(x)].set_visible(False)
3✔
390

391
        tick_axis = ax.yaxis
3✔
392
        tick_axis.grid(True)
3✔
393
        ax.set_ylabel(title)
3✔
394
        return all_rects
3✔
395

396
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
        """Draw bars for each subset, stacked by the values of column ``by``.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw on.
        by : str
            Column of the data whose values define the layers of each stack.
        sum_over : str or None
            Column summed to give each layer's size. If None, the
            ``'_value'`` column is summed when present, otherwise rows are
            counted.
        colors : Mapping, list-like, str, callable or None
            A str names a matplotlib colormap; a Mapping maps each label of
            ``by`` to a colour; other values are passed on to
            ``_plot_bars`` unchanged.
        title : str
            Label for the value axis.

        Raises
        ------
        KeyError
            If ``colors`` is a Mapping that has no entry for some label.
        """
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
        if sum_over is None and "_value" in df.columns:
            data = gb["_value"].sum()
        elif sum_over is None:
            data = gb.size()
        else:
            data = gb[sum_over].sum()
        # Rows are subsets, columns are labels; absent combinations count 0.
        data = data.unstack(by).fillna(0)

        if isinstance(colors, str):
            # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and
            # removed in 3.9; pyplot.get_cmap exists across versions.
            colors = plt.get_cmap(colors)
        elif isinstance(colors, typing.Mapping):
            colors = data.columns.map(colors).values
            unmapped = pd.isna(colors)
            if unmapped.any():
                raise KeyError("Some labels not mapped by colors: %r" %
                               data.columns[unmapped].tolist())

        self._plot_bars(ax, data=data, colors=colors, title=title,
                        use_labels=True)

        handles, labels = ax.get_legend_handles_labels()
        if self._horizontal:
            # Make legend order match visual stack order
            ax.legend(reversed(handles), reversed(labels))
        else:
            ax.legend()
3✔
423

424
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3,
3✔
425
                         title=None):
426
        """Add a stacked bar chart over subsets when :func:`plot` is called.
427

428
        Used to plot categorical variable distributions within each subset.
429

430
        .. versionadded:: 0.6
431

432
        Parameters
433
        ----------
434
        by : str
435
            Column name within the dataframe for color coding the stacked bars,
436
            containing discrete or categorical values.
437
        sum_over : str, optional
438
            Ordinarily the bars will chart the size of each group. sum_over
439
            may specify a column which will be summed to determine the size
440
            of each bar.
441
        colors : Mapping, list-like, str or callable, optional
442
            The facecolors to use for bars corresponding to each discrete
443
            label, specified as one of:
444

445
            Mapping
446
                Maps from label to matplotlib-compatible color specification.
447
            list-like
448
                A list of matplotlib colors to apply to labels in order.
449
            str
450
                The name of a matplotlib colormap name.
451
            callable
452
                When called with the number of labels, this should return a
453
                list-like of that many colors.  Matplotlib colormaps satisfy
454
                this callable API.
455
            None
456
                Uses the matplotlib default colormap.
457
        elements : int, default=3
458
            Size of the axes counted in number of matrix elements.
459
        title : str, optional
460
            The axis title labelling bar length.
461

462
        Returns
463
        -------
464
        None
465
        """
466
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
467
        #                        list of labels}
468
        self._subset_plots.append({'type': 'stacked_bars',
3✔
469
                                   'by': by,
470
                                   'sum_over': sum_over,
471
                                   'colors': colors,
472
                                   'title': title,
473
                                   'id': 'extra%d' % len(self._subset_plots),
474
                                   'elements': elements})
475

476
    def add_catplot(self, kind, value=None, elements=3, **kw):
3✔
477
        """Add a seaborn catplot over subsets when :func:`plot` is called.
478

479
        Parameters
480
        ----------
481
        kind : str
482
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
483
        value : str, optional
484
            Column name for the value to plot (i.e. y if
485
            orientation='horizontal'), required if `data` is a DataFrame.
486
        elements : int, default=3
487
            Size of the axes counted in number of matrix elements.
488
        **kw : dict
489
            Additional keywords to pass to :func:`seaborn.catplot`.
490

491
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
492
            and 'orient', so these are prohibited keys in `kw`.
493

494
        Returns
495
        -------
496
        None
497
        """
498
        assert not set(kw.keys()) & {'ax', 'data', 'x', 'y', 'orient'}
2✔
499
        if value is None:
2✔
500
            if '_value' not in self._df.columns:
2✔
501
                raise ValueError('value cannot be set if data is a Series. '
2✔
502
                                 'Got %r' % value)
503
        else:
504
            if value not in self._df.columns:
2✔
505
                raise ValueError('value %r is not a column in data' % value)
2✔
506
        self._subset_plots.append({'type': 'catplot',
2✔
507
                                   'value': value,
508
                                   'kind': kind,
509
                                   'id': 'extra%d' % len(self._subset_plots),
510
                                   'elements': elements,
511
                                   'kw': kw})
512

513
    def _check_value(self, value):
3✔
514
        if value is None and '_value' in self._df.columns:
2✔
515
            value = '_value'
2✔
516
        elif value is None:
2✔
517
            raise ValueError('value can only be None when data is a Series')
×
518
        return value
2✔
519

520
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of ``value`` per subset on ``ax``.

        ``kw`` is not modified; 'orient', 'x', 'y' and 'ax' are set on a copy
        before it is passed to seaborn.
        """
        value = self._check_value(value)

        # Subsets ('_bin') run along x when horizontal, along y otherwise.
        if self._horizontal:
            layout = {'orient': 'v', 'x': '_bin', 'y': value}
        else:
            layout = {'orient': 'h', 'x': value, 'y': '_bin'}
        plot_kw = kw.copy()
        plot_kw.update(layout)

        import seaborn
        plot_kw['ax'] = ax
        plot_func = getattr(seaborn, kind + 'plot')
        plot_func(data=self._df, **plot_kw)

        ax = self._reorient(ax)
        if value == '_value':
            # The implicit column name is not a meaningful axis label.
            ax.set_ylabel('')

        ax.xaxis.set_visible(False)
        for side in ['top', 'bottom', 'right']:
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
2✔
546

547
    def make_grid(self, fig=None):
        """Get a SubplotSpec for each Axes, accounting for label text width

        Parameters
        ----------
        fig : matplotlib Figure, optional
            Defaults to the current figure (``plt.gcf()``). When
            ``element_size`` is set, the figure is resized.

        Returns
        -------
        dict
            Grid positions under the keys 'matrix', 'shading' and 'totals',
            one entry per subset plot keyed by its id, and the grid itself
            under 'gs'.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        def _extent_kwargs():
            # A renderer is passed explicitly only where get_renderer could
            # be imported.
            if RENDERER_IMPORTED:
                return {"renderer": get_renderer(fig)}
            return {}

        # Determine text size to determine figure size / spacing
        # adding "x" ensures a margin
        label_block = '\n'.join(str(label) + "x"
                                for label in self.totals.index.values)
        probe = fig.text(0, 0, label_block,
                         size=matplotlib.rcParams['xtick.labelsize'])
        textw = probe.get_window_extent(**_extent_kwargs()).width
        probe.remove()

        fig_extent = fig.get_window_extent(**_extent_kwargs())
        figw = self._reorient(fig_extent).width

        sizes = np.asarray([p['elements'] for p in self._subset_plots])
        oriented_fig = self._reorient(fig)

        non_text_nelems = n_inters + self._totals_plot_elements
        if self._element_size is None:
            colw = (figw - textw) / non_text_nelems
        else:
            render_ratio = figw / oriented_fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            oriented_fig.set_figwidth(figw / render_ratio)
            oriented_fig.set_figheight((colw * (n_cats + sizes.sum())) /
                                       render_ratio)

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        GS = self._reorient(matplotlib.gridspec.GridSpec)
        n_rows = n_cats + (sizes.sum() or 0)
        n_cols = n_inters + text_nelems + self._totals_plot_elements
        gridspec = GS(*self._swapaxes(n_rows, n_cols), hspace=1)

        if self._horizontal:
            out = {'matrix': gridspec[-n_cats:, -n_inters:],
                   'shading': gridspec[-n_cats:, :],
                   'totals': gridspec[-n_cats:, :self._totals_plot_elements],
                   'gs': gridspec}
            # Subset plots are stacked above the matrix, last one first.
            stops = np.cumsum(sizes[::-1])
            starts = np.hstack([[0], stops])
            for start, stop, plot in zip(starts, stops,
                                         self._subset_plots[::-1]):
                out[plot['id']] = gridspec[start:stop, -n_inters:]
        else:
            out = {'matrix': gridspec[-n_inters:, :n_cats],
                   'shading': gridspec[:, :n_cats],
                   'totals': gridspec[:self._totals_plot_elements, :n_cats],
                   'gs': gridspec}
            # Subset plots are placed after the matrix columns, in order.
            stops = np.cumsum(sizes)
            starts = np.hstack([[0], stops])
            for start, stop, plot in zip(starts, stops, self._subset_plots):
                out[plot['id']] = \
                    gridspec[-n_inters:, start + n_cats:stop + n_cats]
        return out
3✔
617

618
    def plot_matrix(self, ax):
3✔
619
        """Plot the matrix of intersection indicators onto ax
620
        """
621
        ax = self._reorient(ax)
3✔
622
        data = self.intersections
3✔
623
        n_cats = data.index.nlevels
3✔
624

625
        inclusion = data.index.to_frame().values
3✔
626

627
        # Prepare styling
628
        styles = [
3✔
629
            [
630
                self.subset_styles[i]
631
                if inclusion[i, j]
632
                else {"facecolor": self._other_dots_color, "linewidth": 0}
633
                for j in range(n_cats)
634
            ]
635
            for i in range(len(data))
636
        ]
637
        styles = sum(styles, [])  # flatten nested list
3✔
638
        style_columns = {"facecolor": "facecolors",
3✔
639
                         "edgecolor": "edgecolors",
640
                         "linewidth": "linewidths",
641
                         "linestyle": "linestyles",
642
                         "hatch": "hatch"}
643
        styles = pd.DataFrame(styles).reindex(columns=style_columns.keys())
3✔
644
        styles["linewidth"].fillna(1, inplace=True)
3✔
645
        styles["facecolor"].fillna(self._facecolor, inplace=True)
3✔
646
        styles["edgecolor"].fillna(styles["facecolor"], inplace=True)
3✔
647
        styles["linestyle"].fillna("solid", inplace=True)
3✔
648
        del styles["hatch"]  # not supported in matrix (currently)
3✔
649

650
        x = np.repeat(np.arange(len(data)), n_cats)
3✔
651
        y = np.tile(np.arange(n_cats), len(data))
3✔
652

653
        # Plot dots
654
        if self._element_size is not None:
3✔
655
            s = (self._element_size * .35) ** 2
3✔
656
        else:
657
            # TODO: make s relative to colw
658
            s = 200
3✔
659
        ax.scatter(*self._swapaxes(x, y), s=s, zorder=10,
3✔
660
                   **styles.rename(columns=style_columns))
661

662
        # Plot lines
663
        if self._with_lines:
3✔
664
            idx = np.flatnonzero(inclusion)
3✔
665
            line_data = (pd.Series(y[idx], index=x[idx])
3✔
666
                         .groupby(level=0)
667
                         .aggregate(['min', 'max']))
668
            colors = pd.Series([
3✔
669
                style.get("edgecolor", style.get("facecolor", self._facecolor))
670
                for style in self.subset_styles],
671
                name="color")
672
            line_data = line_data.join(colors)
3✔
673
            ax.vlines(line_data.index.values,
3✔
674
                      line_data['min'], line_data['max'],
675
                      lw=2, colors=line_data["color"],
676
                      zorder=5)
677

678
        # Ticks and axes
679
        tick_axis = ax.yaxis
3✔
680
        tick_axis.set_ticks(np.arange(n_cats))
3✔
681
        tick_axis.set_ticklabels(data.index.names,
3✔
682
                                 rotation=0 if self._horizontal else -90)
683
        ax.xaxis.set_visible(False)
3✔
684
        ax.tick_params(axis='both', which='both', length=0)
3✔
685
        if not self._horizontal:
3✔
686
            ax.yaxis.set_ticks_position('top')
3✔
687
        ax.set_frame_on(False)
3✔
688
        ax.set_xlim(-.5, x[-1] + .5, auto=False)
3✔
689
        ax.grid(False)
3✔
690

691
    def plot_intersections(self, ax):
        """Draw one bar per intersection, sized by the intersection's size

        Every bar is then restyled from the matching entry of
        ``self.subset_styles``; when ``self.subset_legend`` is non-empty a
        legend built from its (style, label) pairs is added to ``ax``.
        """
        bars = self._plot_bars(
            ax,
            self.intersections,
            title='Intersection size',
            colors=self._facecolor,
        )
        for subset_style, bar in zip(self.subset_styles, bars):
            # Fill in defaults on a copy, leaving the stored style untouched
            props = subset_style.copy()
            fallback_edge = props.get("facecolor", self._facecolor)
            props.setdefault("edgecolor", fallback_edge)
            for prop_name, prop_value in props.items():
                # e.g. "facecolor" is applied through bar.set_facecolor(...)
                setter = getattr(bar, "set_" + prop_name)
                setter(prop_value)

        if not self.subset_legend:
            return
        legend_styles, legend_labels = zip(*self.subset_legend)
        handles = [patches.Patch(**kw) for kw in legend_styles]
        ax.legend(handles, legend_labels)
708

709
    def _label_sizes(self, ax, rects, where):
        """Write each bar's size next to it as a count and/or percentage

        Does nothing when both ``self._show_counts`` and
        ``self._show_percentages`` are falsy.  Either setting may be ``True``
        to use the default format (``"%d"`` for counts, ``"%.1f%%"`` for
        percentages) or a %-style format string to use instead.
        Percentages are computed relative to ``self.total``.

        Parameters
        ----------
        ax : axes object
            Must provide ``get_xlim``, ``get_ylim`` and ``text``
            (presumably a matplotlib Axes).
        rects : iterable
            Bars to label; each must provide ``get_x``, ``get_y``,
            ``get_width`` and ``get_height``.
        where : {'right', 'left', 'top'}
            Side of the bar on which the label is anchored.

        Raises
        ------
        NotImplementedError
            If labels are enabled and `where` is not one of the handled
            values.
        """
        if not self._show_counts and not self._show_percentages:
            return

        # True selects the default format; anything else is used as given
        count_fmt = "%d" if self._show_counts is True else self._show_counts
        if self._show_percentages is True:
            pct_fmt = "%.1f%%"
        else:
            pct_fmt = self._show_percentages

        if count_fmt and pct_fmt:
            # Labels above a bar put the percentage on a second line
            if where == 'top':
                fmt = '%s\n(%s)' % (count_fmt, pct_fmt)
            else:
                fmt = '%s (%s)' % (count_fmt, pct_fmt)

            def make_args(val):
                return val, 100 * val / self.total
        elif count_fmt:
            fmt = count_fmt

            def make_args(val):
                return val,
        else:
            fmt = pct_fmt

            def make_args(val):
                return 100 * val / self.total,

        # The margin is 1% of the axis extent.  It is computed as a plain
        # scalar: np.diff() on the limits would give a 1-element array, and
        # passing that on as a text coordinate relies on array-to-scalar
        # conversion, which newer NumPy releases deprecate.
        if where in ('right', 'left'):
            lower, upper = ax.get_xlim()
            margin = 0.01 * abs(upper - lower)
            # NOTE(review): for 'left' the offset is still added; presumably
            # the x axis is inverted by the caller so this moves the label
            # leftwards -- confirm.
            h_align = 'left' if where == 'right' else 'right'
            for rect in rects:
                width = rect.get_width() + rect.get_x()
                ax.text(width + margin,
                        rect.get_y() + rect.get_height() * .5,
                        fmt % make_args(width),
                        ha=h_align, va='center')
        elif where == 'top':
            lower, upper = ax.get_ylim()
            margin = 0.01 * abs(upper - lower)
            for rect in rects:
                height = rect.get_height() + rect.get_y()
                ax.text(rect.get_x() + rect.get_width() * .5,
                        height + margin,
                        fmt % make_args(height),
                        ha='center', va='bottom')
        else:
            raise NotImplementedError('unhandled where: %r' % where)
766

767
    def plot_totals(self, ax):
        """Draw one bar per category showing that category's total size
        """
        target_ax = self._reorient(ax)
        positions = np.arange(len(self.totals.index.values))
        bars = target_ax.barh(positions, self.totals, .5,
                              color=self._facecolor, align='center')
        label_side = 'left' if self._horizontal else 'top'
        self._label_sizes(target_ax, bars, label_side)

        largest = self.totals.max()
        if self._horizontal:
            # Limits run from the largest total down to 0; they are set on
            # the axes as passed in, not on the reoriented wrapper
            ax.set_xlim(largest, 0)
        for side in ('top', 'left', 'right'):
            target_ax.spines[self._reorient(side)].set_visible(False)
        target_ax.yaxis.set_visible(False)
        target_ax.xaxis.grid(True)
        target_ax.yaxis.grid(False)
        target_ax.patch.set_visible(False)
785

786
    def plot_shading(self, ax):
        """Shade every other category band and strip all axes decoration
        """
        # One rectangle per even-numbered category
        # (XXX: use add_patch(Rectangle)?)
        for position in range(0, len(self.totals), 2):
            corner = self._swapaxes(0, position - .4)
            extent = self._swapaxes(1, .8)
            band = plt.Rectangle(corner, *extent,
                                 facecolor=self._shading_color,
                                 lw=0, zorder=0)
            ax.add_patch(band)

        ax.set_frame_on(False)
        hidden = dict(left=False, right=False, bottom=False, top=False,
                      labelbottom=False, labelleft=False)
        ax.tick_params(axis='both', which='both', **hidden)
        ax.grid(False)
        # Remove tick locations first, then their labels
        for clear in (ax.set_xticks, ax.set_yticks,
                      ax.set_xticklabels, ax.set_yticklabels):
            clear([])
808

809
    def plot(self, fig=None):
        """Render every component of the plot onto a figure

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw on.  A new figure of the default size is created
            when omitted.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'shading' and 'totals', plus the id of each
            subset plot (e.g. 'intersections').
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        grid = self.make_grid(fig)

        # The shading axes are created first; matrix and totals share its
        # y axis (directly or transitively)
        shading_ax = fig.add_subplot(grid['shading'])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(grid['matrix'],
                                                    sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_ax = self._reorient(fig.add_subplot)(grid['totals'],
                                                    sharey=matrix_ax)
        self.plot_totals(totals_ax)

        subplots = {'matrix': matrix_ax,
                    'shading': shading_ax,
                    'totals': totals_ax}

        for subset_plot in self._subset_plots:
            plot_id = subset_plot['id']
            ax = self._reorient(fig.add_subplot)(grid[plot_id],
                                                 sharex=matrix_ax)
            plot_type = subset_plot['type']
            if plot_type == 'default':
                self.plot_intersections(ax)
            elif plot_type not in self.PLOT_TYPES:
                raise ValueError('Unknown subset plot type: %r' % plot_type)
            else:
                # Everything except the bookkeeping keys is forwarded to the
                # plotting function as keyword arguments
                kw = subset_plot.copy()
                for bookkeeping_key in ('type', 'elements', 'id'):
                    del kw[bookkeeping_key]
                draw = self.PLOT_TYPES[plot_type]
                draw(self, ax, **kw)
            subplots[plot_id] = ax
        return subplots
852

853
    # Dispatch table used by plot(): maps a subset plot's 'type' to the
    # function that draws it, invoked as func(self, ax, **kw)
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )
857

858
    def _repr_html_(self):
        """Draw the plot on a new default-sized figure and return that
        figure's own ``_repr_html_()`` result
        """
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html
862

863

864
def plot(data, fig=None, **kwargs):
    """Draw an UpSet plot of data onto fig

    Convenience wrapper that builds an :class:`UpSet` from `data` and
    `kwargs` and calls its ``plot`` method.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Figure to draw on.  Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc