• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7342943552

28 Dec 2023 12:13AM UTC coverage: 83.549% (-14.0%) from 97.551%
7342943552

push

github

web-flow
Fix warning due to styling dtypes, and fix column dtype test failure (#238)


Fixes #225

6 of 6 new or added lines in 2 files covered. (100.0%)

312 existing lines in 7 files now uncovered.

1681 of 2012 relevant lines covered (83.55%)

1.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.81
/upsetplot/plotting.py
1
from __future__ import print_function, division, absolute_import
2✔
2

3
try:
2✔
4
    import typing
2✔
5
except ImportError:
×
6
    import collections as typing
×
7

8
import numpy as np
2✔
9
import pandas as pd
2✔
10
import matplotlib
2✔
11
from matplotlib import pyplot as plt
2✔
12
from matplotlib import colors
2✔
13
from matplotlib import patches
2✔
14

15
from .reformat import query, _get_subset_mask
2✔
16
from . import util
2✔
17

18
# prevents ImportError on matplotlib versions >3.5.2
19
try:
2✔
20
    from matplotlib.tight_layout import get_renderer
2✔
21

22
    RENDERER_IMPORTED = True
1✔
23
except ImportError:
1✔
24
    RENDERER_IMPORTED = False
1✔
25

26

27
def _process_data(df, *, sort_by, sort_categories_by, subset_size,
                  sum_over, min_subset_size=None, max_subset_size=None,
                  min_degree=None, max_degree=None, reverse=False,
                  include_empty_subsets=False):
    """Run ``query`` on the input and tag each row with its subset position.

    All keyword arguments except ``reverse`` are forwarded unchanged to
    ``query``.

    Parameters
    ----------
    df : object accepted by ``query``
        The raw input data.
    reverse : bool
        If True, subset positions are numbered from the end and the returned
        subset sizes are reversed.

    Returns
    -------
    total, df, agg, totals
        ``total`` is the sum of the subset sizes; ``df`` is the queried data
        with an added ``'_bin'`` column holding the position of each row's
        subset within ``agg``; ``agg`` is the subset sizes; ``totals`` is the
        category totals reported by ``query``.
    """
    queried = query(df, sort_by=sort_by,
                    sort_categories_by=sort_categories_by,
                    subset_size=subset_size, sum_over=sum_over,
                    min_subset_size=min_subset_size,
                    max_subset_size=max_subset_size,
                    min_degree=min_degree, max_degree=max_degree,
                    include_empty_subsets=include_empty_subsets)

    df = queried.data
    agg = queried.subset_sizes
    totals = queried.category_totals
    total = agg.sum()

    def _pack_binary(X):
        # Fold the columns into one integer per row, first column most
        # significant, so that each membership pattern gets a single key.
        frame = pd.DataFrame(X)
        if frame.shape[1] > 62:
            # use objects if arbitrary precision integers are needed
            dtype = np.object_
        else:
            dtype = np.uint64
        packed = pd.Series(0, index=frame.index, dtype=dtype)
        for _, column in frame.items():
            packed *= 2
            packed += column
        return packed

    # add '_bin' to df indicating index in agg
    row_codes = _pack_binary(df.index.to_frame())
    subset_codes = _pack_binary(agg.index.to_frame())
    positions = np.arange(len(subset_codes))
    if reverse:
        positions = positions[::-1]
    code_to_position = pd.Series(positions, index=subset_codes)
    df['_bin'] = pd.Series(row_codes).map(code_to_position)
    if reverse:
        agg = agg[::-1]

    return total, df, agg, totals
2✔
64

65

66
def _multiply_alpha(c, mult):
    """Return color ``c`` as a hex string with its alpha scaled by ``mult``.

    ``c`` is converted with ``colors.to_rgba``; the result keeps the alpha
    channel (``keep_alpha=True``).
    """
    red, green, blue, alpha = colors.to_rgba(c)
    scaled = (red, green, blue, alpha * mult)
    return colors.to_hex(scaled, keep_alpha=True)
2✔
70

71

72
class _Transposed:
    """Wrap an object in order to transpose some plotting operations

    Attributes of obj will be mapped.
    Keyword arguments when calling obj will be mapped.

    The mapping is not recursive: callable attributes need to be _Transposed
    again.
    """

    def __init__(self, obj):
        self.__obj = obj

    def __getattr__(self, key):
        # __getattr__ only runs when normal lookup fails.  If the wrapped
        # object itself is missing (instance created without __init__, e.g.
        # by copy.copy or unpickling), delegating would re-enter this method
        # forever; report the attribute as absent instead.
        if key == '_Transposed__obj':
            raise AttributeError(key)
        return getattr(self.__obj, self._NAME_TRANSPOSE.get(key, key))

    def __call__(self, *args, **kwargs):
        # Positional arguments pass through; keyword names are transposed.
        return self.__obj(*args, **{self._NAME_TRANSPOSE.get(k, k): v
                                    for k, v in kwargs.items()})

    # Names that are exchanged for their transposed counterpart; any name
    # not listed here is passed through unchanged.
    _NAME_TRANSPOSE = {
        'width': 'height',
        'height': 'width',
        'hspace': 'wspace',
        'wspace': 'hspace',
        'hlines': 'vlines',
        'vlines': 'hlines',
        'bar': 'barh',
        'barh': 'bar',
        'xaxis': 'yaxis',
        'yaxis': 'xaxis',
        'left': 'bottom',
        'right': 'top',
        'top': 'right',
        'bottom': 'left',
        'sharex': 'sharey',
        'sharey': 'sharex',
        'get_figwidth': 'get_figheight',
        'get_figheight': 'get_figwidth',
        'set_figwidth': 'set_figheight',
        'set_figheight': 'set_figwidth',
        'set_xlabel': 'set_ylabel',
        'set_ylabel': 'set_xlabel',
        'set_xlim': 'set_ylim',
        'set_ylim': 'set_xlim',
        'get_xlim': 'get_ylim',
        'get_ylim': 'get_xlim',
        'set_autoscalex_on': 'set_autoscaley_on',
        'set_autoscaley_on': 'set_autoscalex_on',
    }
×
122

123

124
def _transpose(obj):
    """Transpose a name, or wrap an object so its operations are transposed.

    A string is looked up in ``_Transposed._NAME_TRANSPOSE`` and returned
    unchanged when it has no entry; anything else is wrapped in
    ``_Transposed``.
    """
    if not isinstance(obj, str):
        return _Transposed(obj)
    return _Transposed._NAME_TRANSPOSE.get(obj, obj)
2✔
128

129

130
def _identity(obj):
    """Return ``obj`` itself, unchanged."""
    return obj
2✔
132

133

134
class UpSet:
2✔
135
    """Manage the data and drawing for a basic UpSet plot
136

137
    Primary public method is :meth:`plot`.
138

139
    Parameters
140
    ----------
141
    data : pandas.Series or pandas.DataFrame
142
        Elements associated with categories (a DataFrame), or the size of each
143
        subset of categories (a Series).
144
        Should have MultiIndex where each level is binary,
145
        corresponding to category membership.
146
        If a DataFrame, `sum_over` must be a string or False.
147
    orientation : {'horizontal' (default), 'vertical'}
148
        If horizontal, intersections are listed from left to right.
149
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
150
               'input', '-input'}
151
        If 'cardinality', subset are listed from largest to smallest.
152
        If 'degree', they are listed in order of the number of categories
153
        intersected. If 'input', the order they appear in the data input is
154
        used.
155
        Prefix with '-' to reverse the ordering.
156

157
        Note this affects ``subset_sizes`` but not ``data``.
158
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
159
        Whether to sort the categories by total cardinality, or leave them
160
        in the input data's provided order (order of index levels).
161
        Prefix with '-' to reverse the ordering.
162
    subset_size : {'auto', 'count', 'sum'}
163
        Configures how to calculate the size of a subset. Choices are:
164

165
        'auto' (default)
166
            If `data` is a DataFrame, count the number of rows in each group,
167
            unless `sum_over` is specified.
168
            If `data` is a Series with at most one row for each group, use
169
            the value of the Series. If `data` is a Series with more than one
170
            row per group, raise a ValueError.
171
        'count'
172
            Count the number of rows in each group.
173
        'sum'
174
            Sum the value of the `data` Series, or the DataFrame field
175
            specified by `sum_over`.
176
    sum_over : str or None
177
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
178
        sum of the specified field in the `data` DataFrame. If a Series, only
179
        None is supported and its value is summed.
180
    min_subset_size : int, optional
181
        Minimum size of a subset to be shown in the plot. All subsets with
182
        a size smaller than this threshold will be omitted from plotting.
183
        Size may be a sum of values, see `subset_size`.
184

185
        .. versionadded:: 0.5
186
    max_subset_size : int, optional
187
        Maximum size of a subset to be shown in the plot. All subsets with
188
        a size greater than this threshold will be omitted from plotting.
189

190
        .. versionadded:: 0.5
191
    min_degree : int, optional
192
        Minimum degree of a subset to be shown in the plot.
193

194
        .. versionadded:: 0.5
195
    max_degree : int, optional
196
        Maximum degree of a subset to be shown in the plot.
197

198
        .. versionadded:: 0.5
199
    facecolor : 'auto' or matplotlib color or float
200
        Color for bar charts and active dots. Defaults to black if
201
        axes.facecolor is a light color, otherwise white.
202

203
        .. versionchanged:: 0.6
204
            Before 0.6, the default was 'black'
205
    other_dots_color : matplotlib color or float
206
        Color for shading of inactive dots, or opacity (between 0 and 1)
207
        applied to facecolor.
208

209
        .. versionadded:: 0.6
210
    shading_color : matplotlib color or float
211
        Color for shading of odd rows in matrix and totals, or opacity (between
212
        0 and 1) applied to facecolor.
213

214
        .. versionadded:: 0.6
215
    with_lines : bool
216
        Whether to show lines joining dots in the matrix, to mark multiple
217
        categories being intersected.
218
    element_size : float or None
219
        Side length in pt. If None, size is estimated to fit figure
220
    intersection_plot_elements : int
221
        The intersections plot should be large enough to fit this many matrix
222
        elements. Set to 0 to disable intersection size bars.
223

224
        .. versionchanged:: 0.4
225
            Setting to 0 is handled.
226
    totals_plot_elements : int
227
        The totals plot should be large enough to fit this many matrix
228
        elements.
229
    show_counts : bool or str, default=False
230
        Whether to label the intersection size bars with the cardinality
231
        of the intersection. When a string, this formats the number.
232
        For example, '{:d}' is equivalent to True.
233
        Note that, for legacy reasons, if the string does not contain '{',
234
        it will be interpreted as a C-style format string, such as '%d'.
235
    show_percentages : bool or str, default=False
236
        Whether to label the intersection size bars with the percentage
237
        of the intersection relative to the total dataset.
238
        When a string, this formats the number representing a fraction of
239
        samples.
240
        For example, '{:.1%}' is the default, formatting .123 as 12.3%.
241
        This may be applied with or without show_counts.
242

243
        .. versionadded:: 0.4
244
    include_empty_subsets : bool (default=False)
245
        If True, all possible category combinations will be shown as subsets,
246
        even when some are not present in data.
247
    """
248
    _default_figsize = (10, 6)
2✔
249

250
    def __init__(self, data, orientation='horizontal', sort_by='degree',
                 sort_categories_by='cardinality',
                 subset_size='auto', sum_over=None,
                 min_subset_size=None, max_subset_size=None,
                 min_degree=None, max_degree=None,
                 facecolor='auto', other_dots_color=.18, shading_color=.05,
                 with_lines=True, element_size=32,
                 intersection_plot_elements=6, totals_plot_elements=2,
                 show_counts='', show_percentages=False,
                 include_empty_subsets=False):
        """Store styling options and process ``data`` into subsets.

        The parameters are described in the class docstring.
        """

        def _is_opacity(value):
            # Any real number is read as an opacity applied to facecolor,
            # including ints and NumPy scalars.  bool is excluded because it
            # is an int subclass but never a meaningful opacity.
            return (isinstance(value, (int, float, np.integer, np.floating))
                    and not isinstance(value, bool))

        self._horizontal = orientation == 'horizontal'
        self._reorient = _identity if self._horizontal else _transpose
        # Only compare strings with 'auto': array-like colours would make
        # the equality test ambiguous.
        if isinstance(facecolor, str) and facecolor == 'auto':
            bgcolor = matplotlib.rcParams.get('axes.facecolor', 'white')
            r, g, b, a = colors.to_rgba(bgcolor)
            lightness = colors.rgb_to_hsv((r, g, b))[-1] * a
            facecolor = 'black' if lightness >= .5 else 'white'
        self._facecolor = facecolor
        self._shading_color = (_multiply_alpha(facecolor, shading_color)
                               if _is_opacity(shading_color)
                               else shading_color)
        self._other_dots_color = (_multiply_alpha(facecolor, other_dots_color)
                                  if _is_opacity(other_dots_color)
                                  else other_dots_color)
        self._with_lines = with_lines
        self._element_size = element_size
        self._totals_plot_elements = totals_plot_elements
        # The default intersection-size bars are omitted when they are
        # given zero elements.
        self._subset_plots = []
        if intersection_plot_elements:
            self._subset_plots.append({'type': 'default',
                                       'id': 'intersections',
                                       'elements': intersection_plot_elements})
        self._show_counts = show_counts
        self._show_percentages = show_percentages

        (self.total, self._df, self.intersections,
         self.totals) = _process_data(
            data,
            sort_by=sort_by,
            sort_categories_by=sort_categories_by,
            subset_size=subset_size,
            sum_over=sum_over,
            min_subset_size=min_subset_size,
            max_subset_size=max_subset_size,
            min_degree=min_degree,
            max_degree=max_degree,
            # vertical plots list subsets in the opposite direction
            reverse=not self._horizontal,
            include_empty_subsets=include_empty_subsets)
        # One independent style dict per subset, so style_subsets can
        # update them individually.
        self.subset_styles = [{"facecolor": facecolor}
                              for _ in range(len(self.intersections))]
        self.subset_legend = []  # pairs of (style, label)
2✔
302

303
    def _swapaxes(self, x, y):
2✔
304
        if self._horizontal:
2✔
305
            return x, y
2✔
306
        return y, x
2✔
307

308
    def style_subsets(self, present=None, absent=None,
                      min_subset_size=None, max_subset_size=None,
                      min_degree=None, max_degree=None,
                      facecolor=None, edgecolor=None, hatch=None,
                      linewidth=None, linestyle=None, label=None):
        """Restyle the bars and matrix dots of the selected subsets

        The first six parameters choose which subsets are affected; the
        remaining ones are :class:`matplotlib.patches.Patch` attributes to
        apply to them, except ``label``, which requests a legend entry.

        Parameters
        ----------
        present : str or list of str, optional
            Category or categories a subset must include to be styled.
        absent : str or list of str, optional
            Category or categories a subset must not include to be styled.
        min_subset_size : int, optional
            Smallest subset size to be styled.
        max_subset_size : int, optional
            Largest subset size to be styled.
        min_degree : int, optional
            Smallest subset degree to be styled.
        max_degree : int, optional
            Largest subset degree to be styled.

        facecolor : str or matplotlib color, optional
            Replaces the default UpSet facecolor for the selected subsets.
        edgecolor : str or matplotlib color, optional
            Edge color for bars, dots, and the line between dots.
        hatch : str, optional
            Hatch pattern. Applied to intersection size bars, but not to
            matrix dots.
        linewidth : int, optional
            Width of edges, in points.
        linestyle : str, optional
            Line style of edges.

        label : str, optional
            When given, a legend entry is recorded for this style.
        """
        candidates = (("facecolor", facecolor), ("edgecolor", edgecolor),
                      ("hatch", hatch),
                      ("linewidth", linewidth), ("linestyle", linestyle))
        # Only the attributes that were actually given are applied.
        style = {key: value for key, value in candidates
                 if value is not None}
        selected = _get_subset_mask(self.intersections,
                                    present=present, absent=absent,
                                    min_subset_size=min_subset_size,
                                    max_subset_size=max_subset_size,
                                    min_degree=min_degree,
                                    max_degree=max_degree)
        for position in np.flatnonzero(selected):
            self.subset_styles[position].update(style)

        if label is None:
            return

        # Legend entries always carry a facecolor.
        style.setdefault("facecolor", self._facecolor)
        for position, (known_style, known_label) in enumerate(
                self.subset_legend):
            if known_style != style:
                continue
            # An identical style is already in the legend: merge the labels
            # rather than adding a duplicate entry.
            if known_label != label:
                self.subset_legend[position] = (style,
                                                known_label + '; ' + label)
            return
        self.subset_legend.append((style, label))
2✔
373

374
    def _plot_bars(self, ax, data, title, colors=None, use_labels=False):
2✔
375
        ax = self._reorient(ax)
2✔
376
        ax.set_autoscalex_on(False)
2✔
377
        data_df = pd.DataFrame(data)
2✔
378
        if self._horizontal:
2✔
379
            data_df = data_df.loc[:, ::-1]  # reverse: top row is top of stack
2✔
380

381
        # TODO: colors should be broadcastable to data_df shape
382
        if callable(colors):
2✔
383
            colors = colors(range(data_df.shape[1]))
2✔
384
        elif isinstance(colors, (str, type(None))):
2✔
385
            colors = [colors] * len(data_df)
2✔
386

387
        if self._horizontal:
2✔
388
            colors = reversed(colors)
2✔
389

390
        x = np.arange(len(data_df))
2✔
391
        cum_y = None
2✔
392
        all_rects = []
2✔
393
        for (name, y), color in zip(data_df.items(), colors):
2✔
394
            rects = ax.bar(x, y, .5, cum_y,
2✔
395
                           color=color, zorder=10,
2✔
396
                           label=name if use_labels else None,
2✔
397
                           align='center')
2✔
398
            cum_y = y if cum_y is None else cum_y + y
2✔
399
            all_rects.extend(rects)
2✔
400

401
        self._label_sizes(ax, rects, 'top' if self._horizontal else 'right')
2✔
402

403
        ax.xaxis.set_visible(False)
2✔
404
        for x in ['top', 'bottom', 'right']:
2✔
405
            ax.spines[self._reorient(x)].set_visible(False)
2✔
406

407
        tick_axis = ax.yaxis
2✔
408
        tick_axis.grid(True)
2✔
409
        ax.set_ylabel(title)
2✔
410
        return all_rects
2✔
411

412
    def _plot_stacked_bars(self, ax, by, sum_over, colors, title):
        """Draw bars per subset, stacked by the values of column ``by``

        Parameters
        ----------
        ax : Axes
            Axes to draw on.
        by : str
            Column of ``self._df`` whose values define the stack layers.
        sum_over : str or None
            Column to sum for the bar lengths. If None, the ``'_value'``
            column is summed when present, otherwise rows are counted.
        colors : str, Mapping, or other value accepted by ``_plot_bars``
            A str names a matplotlib colormap; a Mapping maps each label to
            a color.
        title : str
            Axis title, forwarded to ``_plot_bars``.

        Raises
        ------
        ValueError
            If ``colors`` is a str that is not a registered colormap name.
        KeyError
            If ``colors`` is a Mapping lacking a color for some label.
        """
        df = self._df.set_index("_bin").set_index(by, append=True, drop=False)
        gb = df.groupby(level=list(range(df.index.nlevels)), sort=True)
        if sum_over is None and "_value" in df.columns:
            data = gb["_value"].sum()
        elif sum_over is None:
            data = gb.size()
        else:
            data = gb[sum_over].sum()
        data = data.unstack(by).fillna(0)
        if isinstance(colors, str):
            # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and
            # later removed; prefer the colormap registry when it exists.
            registry = getattr(matplotlib, "colormaps", None)
            if registry is None:
                colors = matplotlib.cm.get_cmap(colors)
            else:
                try:
                    colors = registry[colors]
                except KeyError as exc:
                    # get_cmap raised ValueError for unknown names; keep
                    # that exception type for callers.
                    raise ValueError(
                        "%r is not a known colormap name" % colors) from exc
        elif isinstance(colors, typing.Mapping):
            colors = data.columns.map(colors).values
            if pd.isna(colors).any():
                raise KeyError("Some labels not mapped by colors: %r" %
                               data.columns[pd.isna(colors)].tolist())

        self._plot_bars(ax, data=data, colors=colors, title=title,
                        use_labels=True)

        handles, labels = ax.get_legend_handles_labels()
        if self._horizontal:
            # Make legend order match visual stack order
            ax.legend(reversed(handles), reversed(labels))
        else:
            ax.legend()
2✔
439

440
    def add_stacked_bars(self, by, sum_over=None, colors=None, elements=3,
1✔
441
                         title=None):
2✔
UNCOV
442
        """Add a stacked bar chart over subsets when :func:`plot` is called.
×
443

UNCOV
444
        Used to plot categorical variable distributions within each subset.
×
445

UNCOV
446
        .. versionadded:: 0.6
×
447

UNCOV
448
        Parameters
×
UNCOV
449
        ----------
×
UNCOV
450
        by : str
×
UNCOV
451
            Column name within the dataframe for color coding the stacked bars,
×
UNCOV
452
            containing discrete or categorical values.
×
UNCOV
453
        sum_over : str, optional
×
UNCOV
454
            Ordinarily the bars will chart the size of each group. sum_over
×
UNCOV
455
            may specify a column which will be summed to determine the size
×
UNCOV
456
            of each bar.
×
UNCOV
457
        colors : Mapping, list-like, str or callable, optional
×
UNCOV
458
            The facecolors to use for bars corresponding to each discrete
×
UNCOV
459
            label, specified as one of:
×
460

UNCOV
461
            Mapping
×
UNCOV
462
                Maps from label to matplotlib-compatible color specification.
×
UNCOV
463
            list-like
×
UNCOV
464
                A list of matplotlib colors to apply to labels in order.
×
UNCOV
465
            str
×
UNCOV
466
                The name of a matplotlib colormap name.
×
UNCOV
467
            callable
×
UNCOV
468
                When called with the number of labels, this should return a
×
UNCOV
469
                list-like of that many colors.  Matplotlib colormaps satisfy
×
UNCOV
470
                this callable API.
×
UNCOV
471
            None
×
UNCOV
472
                Uses the matplotlib default colormap.
×
UNCOV
473
        elements : int, default=3
×
UNCOV
474
            Size of the axes counted in number of matrix elements.
×
UNCOV
475
        title : str, optional
×
UNCOV
476
            The axis title labelling bar length.
×
477

UNCOV
478
        Returns
×
UNCOV
479
        -------
×
UNCOV
480
        None
×
UNCOV
481
        """
×
482
        # TODO: allow sort_by = {"lexical", "sum_squares", "rev_sum_squares",
483
        #                        list of labels}
484
        self._subset_plots.append({'type': 'stacked_bars',
2✔
485
                                   'by': by,
2✔
486
                                   'sum_over': sum_over,
2✔
487
                                   'colors': colors,
2✔
488
                                   'title': title,
2✔
489
                                   'id': 'extra%d' % len(self._subset_plots),
2✔
490
                                   'elements': elements})
2✔
491

492
    def add_catplot(self, kind, value=None, elements=3, **kw):
2✔
493
        """Add a seaborn catplot over subsets when :func:`plot` is called.
494

495
        Parameters
496
        ----------
497
        kind : str
498
            One of {"point", "bar", "strip", "swarm", "box", "violin", "boxen"}
499
        value : str, optional
500
            Column name for the value to plot (i.e. y if
501
            orientation='horizontal'), required if `data` is a DataFrame.
502
        elements : int, default=3
503
            Size of the axes counted in number of matrix elements.
504
        **kw : dict
505
            Additional keywords to pass to :func:`seaborn.catplot`.
506

507
            Our implementation automatically determines 'ax', 'data', 'x', 'y'
508
            and 'orient', so these are prohibited keys in `kw`.
509

510
        Returns
511
        -------
512
        None
513
        """
UNCOV
514
        assert not set(kw.keys()) & {'ax', 'data', 'x', 'y', 'orient'}
×
UNCOV
515
        if value is None:
×
UNCOV
516
            if '_value' not in self._df.columns:
×
UNCOV
517
                raise ValueError('value cannot be set if data is a Series. '
×
UNCOV
518
                                 'Got %r' % value)
×
UNCOV
519
        else:
×
UNCOV
520
            if value not in self._df.columns:
×
UNCOV
521
                raise ValueError('value %r is not a column in data' % value)
×
UNCOV
522
        self._subset_plots.append({'type': 'catplot',
×
UNCOV
523
                                   'value': value,
×
UNCOV
524
                                   'kind': kind,
×
UNCOV
525
                                   'id': 'extra%d' % len(self._subset_plots),
×
UNCOV
526
                                   'elements': elements,
×
UNCOV
527
                                   'kw': kw})
×
528

529
    def _check_value(self, value):
2✔
UNCOV
530
        if value is None and '_value' in self._df.columns:
×
UNCOV
531
            value = '_value'
×
UNCOV
532
        elif value is None:
×
533
            raise ValueError('value can only be None when data is a Series')
×
UNCOV
534
        return value
×
535

536
    def _plot_catplot(self, ax, value, kind, kw):
        """Draw the seaborn ``<kind>plot`` of ``value`` per subset onto ``ax``

        ``kw`` is copied, then extended with 'orient', 'x', 'y' and 'ax'
        before being passed to the seaborn function together with
        ``data=self._df``.
        """
        value = self._check_value(value)
        options = kw.copy()
        # Subsets ('_bin') run along x when horizontal, along y otherwise.
        if self._horizontal:
            options.update(orient='v', x='_bin', y=value)
        else:
            options.update(orient='h', x=value, y='_bin')
        import seaborn
        options['ax'] = ax
        draw = getattr(seaborn, kind + 'plot')
        draw(data=self._df, **options)

        ax = self._reorient(ax)
        if value == '_value':
            # the '_value' column name is not shown as an axis label
            ax.set_ylabel('')

        ax.xaxis.set_visible(False)
        for side in ['top', 'bottom', 'right']:
            ax.spines[self._reorient(side)].set_visible(False)

        ax.yaxis.grid(True)
×
562

563
    def make_grid(self, fig=None):
        """Build the SubplotSpec of every Axes, leaving room for label text

        Parameters
        ----------
        fig : Figure, optional
            Figure to lay out; defaults to ``plt.gcf()``. When
            ``self._element_size`` is set, the figure is resized.

        Returns
        -------
        dict
            Maps 'matrix', 'shading', 'totals' and the id of each subset
            plot to a SubplotSpec, and 'gs' to the GridSpec itself.
        """
        n_cats = len(self.totals)
        n_inters = len(self.intersections)

        if fig is None:
            fig = plt.gcf()

        def _extent_kwargs(figure):
            # A renderer is only passed when get_renderer was importable.
            if RENDERER_IMPORTED:
                return {"renderer": get_renderer(figure)}
            return {}

        # Determine text size to determine figure size / spacing.
        # Adding "x" to each label ensures a margin.
        label_text = '\n'.join(str(label) + "x"
                               for label in self.totals.index.values)
        probe = fig.text(0, 0, label_text,
                         size=matplotlib.rcParams['xtick.labelsize'])
        textw = probe.get_window_extent(**_extent_kwargs(fig)).width
        probe.remove()

        fig_extent = fig.get_window_extent(**_extent_kwargs(fig))
        figw = self._reorient(fig_extent).width

        sizes = np.asarray([spec['elements'] for spec in self._subset_plots])
        fig = self._reorient(fig)

        n_totals = self._totals_plot_elements
        non_text_nelems = len(self.intersections) + n_totals
        if self._element_size is None:
            # Fit the elements into whatever width the text leaves over.
            colw = (figw - textw) / non_text_nelems
        else:
            # Element size is given in points (1/72 inch): resize the
            # figure to hold every element plus the label text.
            render_ratio = figw / fig.get_figwidth()
            colw = self._element_size / 72 * render_ratio
            figw = colw * (non_text_nelems + np.ceil(textw / colw) + 1)
            fig.set_figwidth(figw / render_ratio)
            fig.set_figheight((colw * (n_cats + sizes.sum())) /
                              render_ratio)

        text_nelems = int(np.ceil(figw / colw - non_text_nelems))

        grid_cls = self._reorient(matplotlib.gridspec.GridSpec)
        n_rows = n_cats + (sizes.sum() or 0)
        n_cols = n_inters + text_nelems + n_totals
        gridspec = grid_cls(*self._swapaxes(n_rows, n_cols), hspace=1)

        if self._horizontal:
            out = {'matrix': gridspec[-n_cats:, -n_inters:],
                   'shading': gridspec[-n_cats:, :],
                   'totals': gridspec[-n_cats:, :n_totals],
                   'gs': gridspec}
            # Subset plots are stacked above the matrix, last added on top.
            ordered_plots = self._subset_plots[::-1]
            cumsizes = np.cumsum(sizes[::-1])
            starts = np.hstack([[0], cumsizes])
            for start, stop, plot in zip(starts, cumsizes, ordered_plots):
                out[plot['id']] = gridspec[start:stop, -n_inters:]
        else:
            out = {'matrix': gridspec[-n_inters:, :n_cats],
                   'shading': gridspec[:, :n_cats],
                   'totals': gridspec[:n_totals, :n_cats],
                   'gs': gridspec}
            # Subset plots are placed after the category columns.
            cumsizes = np.cumsum(sizes)
            starts = np.hstack([[0], cumsizes])
            for start, stop, plot in zip(starts, cumsizes,
                                         self._subset_plots):
                out[plot['id']] = \
                    gridspec[-n_inters:, start + n_cats:stop + n_cats]
        return out
2✔
633

634
    def plot_matrix(self, ax):
        """Plot the matrix of intersection indicators onto ax

        Draws one dot per (subset, category) pair, styled from
        ``self.subset_styles`` where the category is in the subset and with
        ``self._other_dots_color`` elsewhere. When ``self._with_lines`` is
        set, a line joins the first and last included category of each
        subset.
        """
        ax = self._reorient(ax)
        data = self.intersections
        n_cats = data.index.nlevels

        inclusion = data.index.to_frame().values

        # Prepare styling: one entry per dot, subset by subset.  A single
        # flat comprehension avoids the quadratic sum(list_of_lists, []).
        inactive_style = {"facecolor": self._other_dots_color, "linewidth": 0}
        styles = [
            self.subset_styles[i] if inclusion[i, j] else inactive_style
            for i in range(len(data))
            for j in range(n_cats)
        ]
        style_columns = {"facecolor": "facecolors",
                         "edgecolor": "edgecolors",
                         "linewidth": "linewidths",
                         "linestyle": "linestyles",
                         "hatch": "hatch"}
        styles = (pd.DataFrame(styles)
                  .reindex(columns=style_columns.keys())
                  .astype({"facecolor": 'O',
                           "edgecolor": 'O',
                           "linewidth": float,
                           "linestyle": 'O',
                           "hatch": 'O'}))
        # Assign the filled columns back: an in-place fillna on a selected
        # column is chained assignment, which pandas warns about and which
        # has no effect on the frame under Copy-on-Write.
        styles["linewidth"] = styles["linewidth"].fillna(1)
        styles["facecolor"] = styles["facecolor"].fillna(self._facecolor)
        styles["edgecolor"] = styles["edgecolor"].fillna(styles["facecolor"])
        styles["linestyle"] = styles["linestyle"].fillna("solid")
        del styles["hatch"]  # not supported in matrix (currently)

        x = np.repeat(np.arange(len(data)), n_cats)
        y = np.tile(np.arange(n_cats), len(data))

        # Plot dots
        if self._element_size is not None:
            s = (self._element_size * .35) ** 2
        else:
            # TODO: make s relative to colw
            s = 200
        ax.scatter(*self._swapaxes(x, y), s=s, zorder=10,
                   **styles.rename(columns=style_columns))

        # Plot lines
        if self._with_lines:
            idx = np.flatnonzero(inclusion)
            # lowest and highest included category of each subset
            line_data = (pd.Series(y[idx], index=x[idx])
                         .groupby(level=0)
                         .aggregate(['min', 'max']))
            line_colors = pd.Series([
                style.get("edgecolor", style.get("facecolor", self._facecolor))
                for style in self.subset_styles],
                name="color")
            line_data = line_data.join(line_colors)
            ax.vlines(line_data.index.values,
                      line_data['min'], line_data['max'],
                      lw=2, colors=line_data["color"],
                      zorder=5)

        # Ticks and axes
        tick_axis = ax.yaxis
        tick_axis.set_ticks(np.arange(n_cats))
        tick_axis.set_ticklabels(data.index.names,
                                 rotation=0 if self._horizontal else -90)
        ax.xaxis.set_visible(False)
        ax.tick_params(axis='both', which='both', length=0)
        if not self._horizontal:
            ax.yaxis.set_ticks_position('top')
        ax.set_frame_on(False)
        ax.set_xlim(-.5, x[-1] + .5, auto=False)
        ax.grid(False)
2✔
709

710
    def plot_intersections(self, ax):
        """Draw the intersection-size bars, styled per subset

        Each bar receives the matching entry of ``self.subset_styles``; when
        a style has no edge colour, the edge takes the style's face colour
        (or ``self._facecolor``).  A legend is added if
        ``self.subset_legend`` is non-empty.
        """
        bars = self._plot_bars(ax, self.intersections,
                               title='Intersection size',
                               colors=self._facecolor)
        for subset_style, bar in zip(self.subset_styles, bars):
            # Work on a copy so the stored style is left untouched.
            resolved = subset_style.copy()
            if "edgecolor" not in resolved:
                resolved["edgecolor"] = resolved.get("facecolor",
                                                     self._facecolor)
            for prop, value in resolved.items():
                setter = getattr(bar, "set_" + prop)
                setter(value)

        if not self.subset_legend:
            return

        legend_styles, labels = zip(*self.subset_legend)
        handles = []
        for legend_style in legend_styles:
            handles.append(patches.Patch(**legend_style))
        ax.legend(handles, labels)
727

728
    def _label_sizes(self, ax, rects, where):
        """Annotate bars with their size as a count and/or a percentage

        Does nothing when both ``self._show_counts`` and
        ``self._show_percentages`` are falsy.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes the bars were drawn on; labels are added with ``ax.text``.
        rects : iterable of rectangle-like artists
            Bars to label.  Only ``get_x``, ``get_y``, ``get_width`` and
            ``get_height`` are used.
        where : {'right', 'left', 'top'}
            Side of the bar on which the label is placed.

        Raises
        ------
        NotImplementedError
            If ``where`` is not one of the supported values.
        """
        if not self._show_counts and not self._show_percentages:
            return
        if self._show_counts is True:
            count_fmt = "{:.0f}"
        else:
            count_fmt = self._show_counts
            # Only convert real format strings: a falsy value (False, '')
            # means "no counts", and ``'{' not in False`` would raise
            # TypeError when only percentages were requested.
            if count_fmt and '{' not in count_fmt:
                count_fmt = util.to_new_pos_format(count_fmt)

        if self._show_percentages is True:
            pct_fmt = "{:.1%}"
        else:
            pct_fmt = self._show_percentages

        if count_fmt and pct_fmt:
            if where == 'top':
                # Stack the two values so the label stays narrow above a bar.
                fmt = '%s\n(%s)' % (count_fmt, pct_fmt)
            else:
                fmt = '%s (%s)' % (count_fmt, pct_fmt)

            def make_args(val):
                return val, val / self.total
        elif count_fmt:
            fmt = count_fmt

            def make_args(val):
                return val,
        else:
            fmt = pct_fmt

            def make_args(val):
                return val / self.total,

        if where in ('right', 'left'):
            # Scalar margin of 1% of the axis span (np.diff would give a
            # 1-element array and make every text coordinate an array).
            low, high = ax.get_xlim()
            margin = 0.01 * abs(high - low)
            # The text is anchored on the side facing the bar end.
            ha = 'left' if where == 'right' else 'right'
            for rect in rects:
                width = rect.get_width() + rect.get_x()
                ax.text(width + margin,
                        rect.get_y() + rect.get_height() * .5,
                        fmt.format(*make_args(width)),
                        ha=ha, va='center')
        elif where == 'top':
            low, high = ax.get_ylim()
            margin = 0.01 * abs(high - low)
            for rect in rects:
                height = rect.get_height() + rect.get_y()
                ax.text(rect.get_x() + rect.get_width() * .5,
                        height + margin,
                        fmt.format(*make_args(height)),
                        ha='center', va='bottom')
        else:
            raise NotImplementedError('unhandled where: %r' % where)
788

789
    def plot_totals(self, ax):
        """Draw one bar per category showing its total size

        The bars are drawn on the reoriented axes and labelled through
        ``self._label_sizes``.  In the horizontal layout the x-limits of the
        axes passed in are set to run from the largest total down to zero.
        """
        given_ax = ax
        ax = self._reorient(ax)

        positions = np.arange(len(self.totals.index.values))
        bars = ax.barh(positions, self.totals,
                       .5, color=self._facecolor, align='center')
        label_side = 'left' if self._horizontal else 'top'
        self._label_sizes(ax, bars, label_side)

        largest = self.totals.max()
        if self._horizontal:
            given_ax.set_xlim(largest, 0)

        # Hide every spine except the one along the value axis.
        for side in ('top', 'left', 'right'):
            spine = ax.spines[self._reorient(side)]
            spine.set_visible(False)

        ax.yaxis.set_visible(False)
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.patch.set_visible(False)
807

808
    def plot_shading(self, ax):
        """Draw the alternating background bands behind the category rows

        Every second category (starting with the first) gets a rectangle
        filled with ``self._shading_color``; the axes are then stripped of
        frame, grid, ticks and tick labels.
        """
        # alternating row shading (XXX: use add_patch(Rectangle)?)
        n_rows = len(self.totals)
        row = 0
        while row < n_rows:
            corner = self._swapaxes(0, row - .4)
            extent = self._swapaxes(1, .8)
            band = plt.Rectangle(corner, *extent,
                                 facecolor=self._shading_color,
                                 lw=0, zorder=0)
            ax.add_patch(band)
            row += 2

        ax.set_frame_on(False)
        hidden = dict(left=False, right=False, bottom=False, top=False,
                      labelbottom=False, labelleft=False)
        ax.tick_params(axis='both', which='both', **hidden)
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xticklabels([])
        ax.set_yticklabels([])
830

831
    def plot(self, fig=None):
        """Render every component of the plot on fig (or on a new figure)

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw on.  A new figure sized with
            ``self._default_figsize`` is created when omitted.

        Returns
        -------
        subplots : dict of matplotlib.axes.Axes
            Keys are 'matrix', 'shading', 'totals' and the id of each
            subset plot drawn.

        Raises
        ------
        ValueError
            If a registered subset plot has an unknown type.
        """
        if fig is None:
            fig = plt.figure(figsize=self._default_figsize)
        specs = self.make_grid(fig)

        # The shading goes first; the matrix shares its y axis, and the
        # totals share the matrix's y axis.
        shading_ax = fig.add_subplot(specs['shading'])
        self.plot_shading(shading_ax)

        matrix_ax = self._reorient(fig.add_subplot)(specs['matrix'],
                                                    sharey=shading_ax)
        self.plot_matrix(matrix_ax)

        totals_ax = self._reorient(fig.add_subplot)(specs['totals'],
                                                    sharey=matrix_ax)
        self.plot_totals(totals_ax)

        out = dict(matrix=matrix_ax, shading=shading_ax, totals=totals_ax)

        for subset_plot in self._subset_plots:
            ax = self._reorient(fig.add_subplot)(specs[subset_plot['id']],
                                                 sharex=matrix_ax)
            kind = subset_plot['type']
            if kind == 'default':
                self.plot_intersections(ax)
            elif kind in self.PLOT_TYPES:
                # Everything except the bookkeeping keys is forwarded to
                # the plotting function as keyword arguments.
                kw = subset_plot.copy()
                for bookkeeping_key in ('type', 'elements', 'id'):
                    del kw[bookkeeping_key]
                draw = self.PLOT_TYPES[kind]
                draw(self, ax, **kw)
            else:
                raise ValueError('Unknown subset plot type: %r' % kind)
            out[subset_plot['id']] = ax
        return out
874

875
    # Dispatch table used by plot(): maps a subset plot's 'type' to the
    # function that draws it (called as ``func(self, ax, **kw)``).
    PLOT_TYPES = dict(
        catplot=_plot_catplot,
        stacked_bars=_plot_stacked_bars,
    )
879

880
    def _repr_html_(self):
        """Draw the plot on a fresh figure and return that figure's HTML repr
        """
        figure = plt.figure(figsize=self._default_figsize)
        self.plot(fig=figure)
        html = figure._repr_html_()
        return html
884

885

886
def plot(data, fig=None, **kwargs):
    """Draw an UpSet plot of data onto fig

    Convenience wrapper that builds an :class:`UpSet` from ``data`` and
    ``kwargs`` and calls its ``plot`` method.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Values for each set to plot.
        Should have multi-index where each level is binary,
        corresponding to set membership.
        If a DataFrame, `sum_over` must be a string or False.
    fig : matplotlib.figure.Figure, optional
        Defaults to a new figure.
    kwargs
        Other arguments for :class:`UpSet`

    Returns
    -------
    subplots : dict of matplotlib.axes.Axes
        Keys are 'matrix', 'intersections', 'totals', 'shading'
    """
    upset = UpSet(data, **kwargs)
    return upset.plot(fig)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc