• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 504

pending completion
504

push

travis-ci-com

web-flow
Add include_empty_subsets (#203)

17 of 17 new or added lines in 3 files covered. (100.0%)

1171 of 1201 relevant lines covered (97.5%)

2.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.21
/upsetplot/reformat.py
1
from __future__ import print_function, division, absolute_import
3✔
2

3
try:
3✔
4
    import typing
3✔
5
except ImportError:
×
6
    import collections as typing
×
7

8
import numpy as np
3✔
9
import pandas as pd
3✔
10

11

12
def _aggregate_data(df, subset_size, sum_over):
3✔
13
    """
14
    Returns
15
    -------
16
    df : DataFrame
17
        full data frame
18
    aggregated : Series
19
        aggregates
20
    """
21
    _SUBSET_SIZE_VALUES = ['auto', 'count', 'sum']
3✔
22
    if subset_size not in _SUBSET_SIZE_VALUES:
3✔
23
        raise ValueError('subset_size should be one of %s. Got %r'
×
24
                         % (_SUBSET_SIZE_VALUES, subset_size))
25
    if df.ndim == 1:
3✔
26
        # Series
27
        input_name = df.name
3✔
28
        df = pd.DataFrame({'_value': df})
3✔
29

30
        if subset_size == 'auto' and not df.index.is_unique:
3✔
31
            raise ValueError('subset_size="auto" cannot be used for a '
3✔
32
                             'Series with non-unique groups.')
33
        if sum_over is not None:
3✔
34
            raise ValueError('sum_over is not applicable when the input is a '
3✔
35
                             'Series')
36
        if subset_size == 'count':
3✔
37
            sum_over = False
3✔
38
        else:
39
            sum_over = '_value'
3✔
40
    else:
41
        # DataFrame
42
        if sum_over is False:
3✔
43
            raise ValueError('Unsupported value for sum_over: False')
3✔
44
        elif subset_size == 'auto' and sum_over is None:
3✔
45
            sum_over = False
3✔
46
        elif subset_size == 'count':
3✔
47
            if sum_over is not None:
3✔
48
                raise ValueError('sum_over cannot be set if subset_size=%r' %
3✔
49
                                 subset_size)
50
            sum_over = False
3✔
51
        elif subset_size == 'sum':
3✔
52
            if sum_over is None:
3✔
53
                raise ValueError('sum_over should be a field name if '
3✔
54
                                 'subset_size="sum" and a DataFrame is '
55
                                 'provided.')
56

57
    gb = df.groupby(level=list(range(df.index.nlevels)), sort=False)
3✔
58
    if sum_over is False:
3✔
59
        aggregated = gb.size()
3✔
60
        aggregated.name = 'size'
3✔
61
    elif hasattr(sum_over, 'lower'):
3✔
62
        aggregated = gb[sum_over].sum()
3✔
63
    else:
64
        raise ValueError('Unsupported value for sum_over: %r' % sum_over)
×
65

66
    if aggregated.name == '_value':
3✔
67
        aggregated.name = input_name
3✔
68

69
    return df, aggregated
3✔
70

71

72
def _check_index(df):
3✔
73
    # check all indices are boolean
74
    if not all(set([True, False]) >= set(level)
3✔
75
               for level in df.index.levels):
76
        raise ValueError('The DataFrame has values in its index that are not '
3✔
77
                         'boolean')
78
    df = df.copy(deep=False)
3✔
79
    # XXX: this may break if input is not MultiIndex
80
    kw = {'levels': [x.astype(bool) for x in df.index.levels],
3✔
81
          'names': df.index.names,
82
          }
83
    if hasattr(df.index, 'codes'):
3✔
84
        # compat for pandas <= 0.20
85
        kw['codes'] = df.index.codes
3✔
86
    else:
87
        kw['labels'] = df.index.labels
×
88
    df.index = pd.MultiIndex(**kw)
3✔
89
    return df
3✔
90

91

92
def _scalar_to_list(val):
3✔
93
    if not isinstance(val, (typing.Sequence, set)) or isinstance(val, str):
3✔
94
        val = [val]
3✔
95
    return val
3✔
96

97

98
def _get_subset_mask(agg, min_subset_size, max_subset_size,
3✔
99
                     min_degree, max_degree,
100
                     present, absent):
101
    """Get a mask over subsets based on size, degree or category presence"""
102
    subset_mask = True
3✔
103
    if min_subset_size is not None:
3✔
104
        subset_mask = np.logical_and(subset_mask, agg >= min_subset_size)
3✔
105
    if max_subset_size is not None:
3✔
106
        subset_mask = np.logical_and(subset_mask, agg <= max_subset_size)
3✔
107
    if (min_degree is not None and min_degree >= 0) or max_degree is not None:
3✔
108
        degree = agg.index.to_frame().sum(axis=1)
3✔
109
        if min_degree is not None:
3✔
110
            subset_mask = np.logical_and(subset_mask, degree >= min_degree)
3✔
111
        if max_degree is not None:
3✔
112
            subset_mask = np.logical_and(subset_mask, degree <= max_degree)
3✔
113
    if present is not None:
3✔
114
        for col in _scalar_to_list(present):
3✔
115
            subset_mask = np.logical_and(
3✔
116
                subset_mask,
117
                agg.index.get_level_values(col).values)
118
    if absent is not None:
3✔
119
        for col in _scalar_to_list(absent):
3✔
120
            exclude_mask = np.logical_not(
3✔
121
                agg.index.get_level_values(col).values)
122
            subset_mask = np.logical_and(subset_mask, exclude_mask)
3✔
123
    return subset_mask
3✔
124

125

126
def _filter_subsets(df, agg,
                    min_subset_size, max_subset_size,
                    min_degree, max_degree,
                    present, absent):
    """Restrict ``df`` and ``agg`` to the subsets matching the criteria.

    Returns the pair ``(df, agg)``. Both are returned untouched when the
    computed mask is the literal ``True``; otherwise ``agg`` is masked and
    ``df`` keeps only the rows whose index value remains in ``agg``.
    """
    criteria = {
        'min_subset_size': min_subset_size,
        'max_subset_size': max_subset_size,
        'min_degree': min_degree,
        'max_degree': max_degree,
        'present': present,
        'absent': absent,
    }
    keep = _get_subset_mask(agg, **criteria)

    if keep is not True:
        agg = agg[keep]
        df = df[df.index.isin(agg.index)]
    return df, agg
3✔
143

144

145
class QueryResult:
    """Container for reformatted data and aggregates

    Attributes
    ----------
    data : DataFrame
        Selected samples. The index is a MultiIndex with one boolean level for
        each category.
    subsets : dict[frozenset, DataFrame]
        Dataframes for each intersection of categories. Each key is the
        frozenset of the category names that are True for that subset; the
        subset belonging to no category is keyed by ``frozenset()``.
    subset_sizes : Series
        Total size of each selected subset as a series. The index is as
        for `data`.
    category_totals : Series
        Total size of each category, regardless of selection.
    """
    def __init__(self, data, subset_sizes, category_totals):
        self.data = data
        self.subset_sizes = subset_sizes
        self.category_totals = category_totals

    def __repr__(self):
        return ("QueryResult(data={data}, subset_sizes={subset_sizes}, "
                "category_totals={category_totals})".format(**vars(self)))

    @property
    def subsets(self):
        categories = list(self.data.index.names)
        grouped = self.data.groupby(level=list(range(len(categories))),
                                    sort=False)
        result = {}
        for key, subset_data in grouped:
            if not isinstance(key, tuple):
                # some pandas versions yield a scalar key for a single level
                key = (key,)
            # Select names by truth value. The group key is a tuple of
            # booleans and must not be used as integer positions.
            members = frozenset(name for name, member in zip(categories, key)
                                if member)
            result[members] = subset_data
        return result


def query(data, present=None, absent=None,
          min_subset_size=None, max_subset_size=None,
          min_degree=None, max_degree=None,
          sort_by='degree', sort_categories_by='cardinality',
          subset_size='auto', sum_over=None, include_empty_subsets=False):
    """Transform and filter a categorised dataset

    Retrieve the set of items and totals corresponding to subsets of interest.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Elements associated with categories (a DataFrame), or the size of each
        subset of categories (a Series).
        Should have MultiIndex where each level is binary,
        corresponding to category membership.
        If a DataFrame, `sum_over` may be None or the name of a column.
    present : str or list of str, optional
        Category or categories that must be present in subsets to be
        reported.
    absent : str or list of str, optional
        Category or categories that must not be present in subsets to be
        reported.
    min_subset_size : int, optional
        Minimum size of a subset to be reported. All subsets with
        a size smaller than this threshold will be omitted from
        subset_sizes and data.
        Size may be a sum of values, see `subset_size`.
    max_subset_size : int, optional
        Maximum size of a subset to be reported.
    min_degree : int, optional
        Minimum degree of a subset to be reported.
    max_degree : int, optional
        Maximum degree of a subset to be reported.
    sort_by : {'cardinality', 'degree', None}
        If 'cardinality', subset are listed from largest to smallest.
        If 'degree', they are listed in order of the number of categories
        intersected. If None, the order they appear in the data input is
        used.

        Note this affects ``subset_sizes`` but not ``data``.
    sort_categories_by : {'cardinality', None}
        Whether to sort the categories by total cardinality, or leave them
        in the provided order.
    subset_size : {'auto', 'count', 'sum'}
        Configures how to calculate the size of a subset. Choices are:

        'auto' (default)
            If `data` is a DataFrame, count the number of rows in each group,
            unless `sum_over` is specified.
            If `data` is a Series with at most one row for each group, use
            the value of the Series. If `data` is a Series with more than one
            row per group, raise a ValueError.
        'count'
            Count the number of rows in each group.
        'sum'
            Sum the value of the `data` Series, or the DataFrame field
            specified by `sum_over`.
    sum_over : str or None
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
        sum of the specified field in the `data` DataFrame. If a Series, only
        None is supported and its value is summed.
    include_empty_subsets : bool (default=False)
        If True, all possible category combinations will be returned in
        subset_sizes, even when some are not present in data.
        Supported for at most 10 categories.

    Returns
    -------
    QueryResult
        Including filtered ``data``, filtered and sorted ``subset_sizes`` and
        overall ``category_totals``.

    Raises
    ------
    ValueError
        If ``sort_by`` or ``sort_categories_by`` is not a supported value,
        if ``include_empty_subsets`` is requested for more than 10
        categories, or if the input data, ``subset_size`` and ``sum_over``
        are rejected while aggregating or checking the index.

    Examples
    --------
    >>> from upsetplot import query, generate_samples
    >>> data = generate_samples(n_samples=20)
    >>> result = query(data, present="cat1", max_subset_size=4)
    >>> result.category_totals
    cat1    14
    cat2     4
    cat0     0
    dtype: int64
    >>> result.subset_sizes
    cat1  cat2  cat0
    True  True  False    3
    Name: size, dtype: int64
    >>> result.data
                     index     value
    cat1 cat2 cat0
    True True False      0  2.04...
              False      2  2.05...
              False     10  2.55...
    >>>
    >>> # Sorting:
    >>> query(data, min_degree=1, sort_by="degree").subset_sizes
    cat1   cat2   cat0
    True   False  False    11
    False  True   False     1
    True   True   False     3
    Name: size, dtype: int64
    >>> query(data, min_degree=1, sort_by="cardinality").subset_sizes
    cat1   cat2   cat0
    True   False  False    11
           True   False     3
    False  True   False     1
    Name: size, dtype: int64
    >>>
    >>> # Getting each subset's data
    >>> result = query(data)
    >>> result.subsets[frozenset({"cat2"})]
                index     value
    cat1  cat2 cat0
    False True False      3  1.333795
    >>> result.subsets[frozenset()]
                        index     value
    cat1  cat2  cat0
    False False False      5  0.918174
                False      8  1.948521
                False      9  1.086599
                False     13  1.105696
                False     19  1.339895
    """

    data, agg = _aggregate_data(data, subset_size, sum_over)
    data = _check_index(data)
    # Category totals are taken before any filtering, so they describe the
    # whole input regardless of the selection made below.
    totals = [agg[agg.index.get_level_values(name).values.astype(bool)].sum()
              for name in agg.index.names]
    totals = pd.Series(totals, index=agg.index.names)

    if include_empty_subsets:
        nlevels = len(agg.index.levels)
        if nlevels > 10:
            raise ValueError(
                "include_empty_subsets is supported for at most 10 categories")
        # Start from every True/False combination with size 0, then overwrite
        # the combinations that are present in the input.
        new_agg = pd.Series(0,
                            index=pd.MultiIndex.from_product(
                                [[False, True]] * nlevels,
                                names=agg.index.names),
                            dtype=agg.dtype,
                            name=agg.name)
        new_agg.update(agg)
        agg = new_agg

    data, agg = _filter_subsets(data, agg,
                                min_subset_size=min_subset_size,
                                max_subset_size=max_subset_size,
                                min_degree=min_degree,
                                max_degree=max_degree,
                                present=present, absent=absent)

    # sort:
    if sort_categories_by == 'cardinality':
        totals.sort_values(ascending=False, inplace=True)
    elif sort_categories_by is not None:
        raise ValueError('Unknown sort_categories_by: %r' % sort_categories_by)
    data = data.reorder_levels(totals.index.values)
    agg = agg.reorder_levels(totals.index.values)

    if sort_by == 'cardinality':
        agg = agg.sort_values(ascending=False)
    elif sort_by == 'degree':
        # order by number of member categories, ties broken by the reversed
        # membership tuple
        index_tuples = sorted(agg.index,
                              key=lambda x: (sum(x),) + tuple(reversed(x)))
        agg = agg.reindex(pd.MultiIndex.from_tuples(index_tuples,
                                                    names=agg.index.names))
    elif sort_by is None:
        pass
    else:
        raise ValueError('Unknown sort_by: %r' % sort_by)

    return QueryResult(data=data, subset_sizes=agg, category_totals=totals)
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc