• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7344254321

28 Dec 2023 03:58AM UTC coverage: 98.586% (+15.0%) from 83.549%
7344254321

push

github

web-flow
Format with black/ruff (#240)

844 of 848 new or added lines in 8 files covered. (99.53%)

4 existing lines in 3 files now uncovered.

1534 of 1556 relevant lines covered (98.59%)

0.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.48
/upsetplot/reformat.py
1
from __future__ import print_function, division, absolute_import
2✔
2

3
try:
2✔
4
    import typing
2✔
5
except ImportError:
×
6
    import collections as typing
×
7

8
import numpy as np
2✔
9
import pandas as pd
2✔
10

11

12
def _aggregate_data(df, subset_size, sum_over):
2✔
13
    """
14
    Returns
15
    -------
16
    df : DataFrame
17
        full data frame
18
    aggregated : Series
19
        aggregates
20
    """
21
    _SUBSET_SIZE_VALUES = ["auto", "count", "sum"]
2✔
22
    if subset_size not in _SUBSET_SIZE_VALUES:
2✔
23
        raise ValueError(
24
            "subset_size should be one of %s. Got %r"
25
            % (_SUBSET_SIZE_VALUES, subset_size)
26
        )
27
    if df.ndim == 1:
2✔
28
        # Series
29
        input_name = df.name
2✔
30
        df = pd.DataFrame({"_value": df})
2✔
31

32
        if subset_size == "auto" and not df.index.is_unique:
2✔
33
            raise ValueError(
2✔
34
                'subset_size="auto" cannot be used for a '
2✔
35
                "Series with non-unique groups."
36
            )
37
        if sum_over is not None:
2✔
38
            raise ValueError("sum_over is not applicable when the input is a " "Series")
2✔
39
        if subset_size == "count":
2✔
40
            sum_over = False
2✔
41
        else:
×
42
            sum_over = "_value"
2✔
43
    else:
×
44
        # DataFrame
45
        if sum_over is False:
2✔
46
            raise ValueError("Unsupported value for sum_over: False")
2✔
47
        elif subset_size == "auto" and sum_over is None:
2✔
48
            sum_over = False
2✔
49
        elif subset_size == "count":
2✔
50
            if sum_over is not None:
2✔
51
                raise ValueError(
2✔
52
                    "sum_over cannot be set if subset_size=%r" % subset_size
2✔
53
                )
54
            sum_over = False
2✔
55
        elif subset_size == "sum":
2✔
56
            if sum_over is None:
2✔
57
                raise ValueError(
2✔
58
                    "sum_over should be a field name if "
2✔
59
                    'subset_size="sum" and a DataFrame is '
60
                    "provided."
61
                )
62

63
    gb = df.groupby(level=list(range(df.index.nlevels)), sort=False)
2✔
64
    if sum_over is False:
2✔
65
        aggregated = gb.size()
2✔
66
        aggregated.name = "size"
2✔
67
    elif hasattr(sum_over, "lower"):
2✔
68
        aggregated = gb[sum_over].sum()
2✔
69
    else:
×
NEW
70
        raise ValueError("Unsupported value for sum_over: %r" % sum_over)
×
71

72
    if aggregated.name == "_value":
2✔
73
        aggregated.name = input_name
2✔
74

75
    return df, aggregated
2✔
76

77

78
def _check_index(df):
2✔
79
    # check all indices are boolean
80
    if not all(set([True, False]) >= set(level) for level in df.index.levels):
2✔
81
        raise ValueError(
2✔
82
            "The DataFrame has values in its index that are not " "boolean"
2✔
83
        )
84
    df = df.copy(deep=False)
2✔
85
    # XXX: this may break if input is not MultiIndex
86
    kw = {
1✔
87
        "levels": [x.astype(bool) for x in df.index.levels],
2✔
88
        "names": df.index.names,
2✔
89
    }
90
    if hasattr(df.index, "codes"):
2✔
91
        # compat for pandas <= 0.20
92
        kw["codes"] = df.index.codes
2✔
93
    else:
×
NEW
94
        kw["labels"] = df.index.labels
×
95
    df.index = pd.MultiIndex(**kw)
2✔
96
    return df
2✔
97

98

99
def _scalar_to_list(val):
2✔
100
    if not isinstance(val, (typing.Sequence, set)) or isinstance(val, str):
2✔
101
        val = [val]
2✔
102
    return val
2✔
103

104

105
def _get_subset_mask(
2✔
106
    agg, min_subset_size, max_subset_size, min_degree, max_degree, present, absent
107
):
108
    """Get a mask over subsets based on size, degree or category presence"""
109
    subset_mask = True
2✔
110
    if min_subset_size is not None:
2✔
111
        subset_mask = np.logical_and(subset_mask, agg >= min_subset_size)
2✔
112
    if max_subset_size is not None:
2✔
113
        subset_mask = np.logical_and(subset_mask, agg <= max_subset_size)
2✔
114
    if (min_degree is not None and min_degree >= 0) or max_degree is not None:
2✔
115
        degree = agg.index.to_frame().sum(axis=1)
2✔
116
        if min_degree is not None:
2✔
117
            subset_mask = np.logical_and(subset_mask, degree >= min_degree)
2✔
118
        if max_degree is not None:
2✔
119
            subset_mask = np.logical_and(subset_mask, degree <= max_degree)
2✔
120
    if present is not None:
2✔
121
        for col in _scalar_to_list(present):
2✔
122
            subset_mask = np.logical_and(
2✔
123
                subset_mask, agg.index.get_level_values(col).values
2✔
124
            )
125
    if absent is not None:
2✔
126
        for col in _scalar_to_list(absent):
2✔
127
            exclude_mask = np.logical_not(agg.index.get_level_values(col).values)
2✔
128
            subset_mask = np.logical_and(subset_mask, exclude_mask)
2✔
129
    return subset_mask
2✔
130

131

132
def _filter_subsets(
    df, agg, min_subset_size, max_subset_size, min_degree, max_degree, present, absent
):
    """Restrict ``df`` and ``agg`` to subsets matching the given criteria."""
    keep = _get_subset_mask(
        agg,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        min_degree=min_degree,
        max_degree=max_degree,
        present=present,
        absent=absent,
    )

    # No criteria given: the mask is the literal True identity.
    if keep is True:
        return df, agg

    filtered_agg = agg[keep]
    filtered_df = df[df.index.isin(filtered_agg.index)]
    return filtered_df, filtered_agg
151

152

153
class QueryResult:
    """Container for reformatted data and aggregates

    Attributes
    ----------
    data : DataFrame
        Selected samples. The index is a MultiIndex with one boolean level for
        each category.
    subsets : dict[frozenset, DataFrame]
        Dataframes for each intersection of categories.
    subset_sizes : Series
        Total size of each selected subset as a series. The index is as
        for `data`.
    category_totals : Series
        Total size of each category, regardless of selection.
    """

    def __init__(self, data, subset_sizes, category_totals):
        self.data = data
        self.subset_sizes = subset_sizes
        self.category_totals = category_totals

    def __repr__(self):
        # BUG FIX: the format template previously lacked the closing ")",
        # producing an unbalanced repr string.
        return (
            "QueryResult(data={data}, subset_sizes={subset_sizes}, "
            "category_totals={category_totals})".format(**vars(self))
        )

    @property
    def subsets(self):
        # Map each groupby key (a tuple of booleans, one per category) to
        # the rows of that intersection.
        # NOTE(review): `take(mask)` treats the boolean tuple as integer
        # indices (0/1), not as a mask; this matches the behaviour pinned
        # by the module's doctests, but verify it is intended.
        categories = np.asarray(self.data.index.names)
        return {
            frozenset(categories.take(mask)): subset_data
            for mask, subset_data in self.data.groupby(
                level=list(range(len(categories))), sort=False
            )
        }
190

191

192
def query(
    data,
    present=None,
    absent=None,
    min_subset_size=None,
    max_subset_size=None,
    min_degree=None,
    max_degree=None,
    sort_by="degree",
    sort_categories_by="cardinality",
    subset_size="auto",
    sum_over=None,
    include_empty_subsets=False,
):
    """Transform and filter a categorised dataset

    Retrieve the set of items and totals corresponding to subsets of interest.

    Parameters
    ----------
    data : pandas.Series or pandas.DataFrame
        Elements associated with categories (a DataFrame), or the size of each
        subset of categories (a Series).
        Should have MultiIndex where each level is binary,
        corresponding to category membership.
        If a DataFrame, `sum_over` must be a string or False.
    present : str or list of str, optional
        Category or categories that must be present in subsets for styling.
    absent : str or list of str, optional
        Category or categories that must not be present in subsets for
        styling.
    min_subset_size : int, optional
        Minimum size of a subset to be reported. All subsets with
        a size smaller than this threshold will be omitted from
        category_totals and data.
        Size may be a sum of values, see `subset_size`.
    max_subset_size : int, optional
        Maximum size of a subset to be reported.
    min_degree : int, optional
        Minimum degree of a subset to be reported.
    max_degree : int, optional
        Maximum degree of a subset to be reported.
    sort_by : {'cardinality', 'degree', '-cardinality', '-degree',
               'input', '-input'}
        If 'cardinality', subset are listed from largest to smallest.
        If 'degree', they are listed in order of the number of categories
        intersected. If 'input', the order they appear in the data input is
        used.
        Prefix with '-' to reverse the ordering.

        Note this affects ``subset_sizes`` but not ``data``.
    sort_categories_by : {'cardinality', '-cardinality', 'input', '-input'}
        Whether to sort the categories by total cardinality, or leave them
        in the input data's provided order (order of index levels).
        Prefix with '-' to reverse the ordering.
    subset_size : {'auto', 'count', 'sum'}
        Configures how to calculate the size of a subset. Choices are:

        'auto' (default)
            If `data` is a DataFrame, count the number of rows in each group,
            unless `sum_over` is specified.
            If `data` is a Series with at most one row for each group, use
            the value of the Series. If `data` is a Series with more than one
            row per group, raise a ValueError.
        'count'
            Count the number of rows in each group.
        'sum'
            Sum the value of the `data` Series, or the DataFrame field
            specified by `sum_over`.
    sum_over : str or None
        If `subset_size='sum'` or `'auto'`, then the intersection size is the
        sum of the specified field in the `data` DataFrame. If a Series, only
        None is supported and its value is summed.
    include_empty_subsets : bool (default=False)
        If True, all possible category combinations will be returned in
        subset_sizes, even when some are not present in data.

    Returns
    -------
    QueryResult
        Including filtered ``data``, filtered and sorted ``subset_sizes`` and
        overall ``category_totals``.

    Examples
    --------
    >>> from upsetplot import query, generate_samples
    >>> data = generate_samples(n_samples=20)
    >>> result = query(data, present="cat1", max_subset_size=4)
    >>> result.category_totals
    cat1    14
    cat2     4
    cat0     0
    dtype: int64
    >>> result.subset_sizes
    cat1  cat2  cat0
    True  True  False    3
    Name: size, dtype: int64
    >>> result.data
                     index     value
    cat1 cat2 cat0
    True True False      0  2.04...
              False      2  2.05...
              False     10  2.55...
    >>>
    >>> # Sorting:
    >>> query(data, min_degree=1, sort_by="degree").subset_sizes
    cat1   cat2   cat0
    True   False  False    11
    False  True   False     1
    True   True   False     3
    Name: size, dtype: int64
    >>> query(data, min_degree=1, sort_by="cardinality").subset_sizes
    cat1   cat2   cat0
    True   False  False    11
           True   False     3
    False  True   False     1
    Name: size, dtype: int64
    >>>
    >>> # Getting each subset's data
    >>> result = query(data)
    >>> result.subsets[frozenset({"cat1", "cat2"})]
                index     value
    cat1  cat2 cat0
    False True False      3  1.333795
    >>> result.subsets[frozenset({"cat1"})]
                        index     value
    cat1  cat2  cat0
    False False False      5  0.918174
                False      8  1.948521
                False      9  1.086599
                False     13  1.105696
                False     19  1.339895
    """

    data, agg = _aggregate_data(data, subset_size, sum_over)
    data = _check_index(data)
    # Category totals are computed before any subset filtering, so they
    # always reflect the full input data.
    totals = [
        agg[agg.index.get_level_values(name).values.astype(bool)].sum()
        for name in agg.index.names
    ]
    totals = pd.Series(totals, index=agg.index.names)

    if include_empty_subsets:
        nlevels = len(agg.index.levels)
        # 2**nlevels rows are materialised below; cap to avoid blow-up.
        if nlevels > 10:
            raise ValueError(
                "include_empty_subsets is supported for at most 10 categories"
            )
        new_agg = pd.Series(
            0,
            index=pd.MultiIndex.from_product(
                [[False, True]] * nlevels, names=agg.index.names
            ),
            dtype=agg.dtype,
            name=agg.name,
        )
        new_agg.update(agg)
        agg = new_agg

    data, agg = _filter_subsets(
        data,
        agg,
        min_subset_size=min_subset_size,
        max_subset_size=max_subset_size,
        min_degree=min_degree,
        max_degree=max_degree,
        present=present,
        absent=absent,
    )

    # sort categories (index levels):
    if sort_categories_by in ("cardinality", "-cardinality"):
        totals.sort_values(ascending=sort_categories_by[:1] == "-", inplace=True)
    elif sort_categories_by == "-input":
        totals = totals[::-1]
    elif sort_categories_by in (None, "input"):
        pass
    else:
        raise ValueError("Unknown sort_categories_by: %r" % sort_categories_by)
    data = data.reorder_levels(totals.index.values)
    agg = agg.reorder_levels(totals.index.values)

    # sort subsets:
    if sort_by in ("cardinality", "-cardinality"):
        agg = agg.sort_values(ascending=sort_by[:1] == "-")
    elif sort_by in ("degree", "-degree"):
        # Sort by degree, tie-breaking on the reversed membership tuple.
        index_tuples = sorted(
            agg.index,
            key=lambda x: (sum(x),) + tuple(reversed(x)),
            reverse=sort_by[:1] == "-",
        )
        agg = agg.reindex(
            pd.MultiIndex.from_tuples(index_tuples, names=agg.index.names)
        )
    elif sort_by == "-input":
        # BUG FIX: removed stray debugging print() calls that dumped the
        # aggregate series to stdout on every '-input' query.
        agg = agg[::-1]
    elif sort_by in (None, "input"):
        pass
    else:
        raise ValueError("Unknown sort_by: %r" % sort_by)

    return QueryResult(data=data, subset_sizes=agg, category_totals=totals)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc