• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jnothman / UpSetPlot / 7342943552

28 Dec 2023 12:13AM UTC coverage: 83.549% (-14.0%) from 97.551%
7342943552

push

github

web-flow
Fix warning due to styling dtyles, and fix column dtype test failure (#238)


Fixes #225

6 of 6 new or added lines in 2 files covered. (100.0%)

312 existing lines in 7 files now uncovered.

1681 of 2012 relevant lines covered (83.55%)

1.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.49
/upsetplot/tests/test_upsetplot.py
1
import io
2✔
2
import itertools
2✔
3

4
import pytest
2✔
5
from pandas.testing import (
2✔
6
    assert_series_equal, assert_frame_equal, assert_index_equal)
7
from numpy.testing import assert_array_equal
2✔
8
import pandas as pd
2✔
9
import numpy as np
2✔
10
import matplotlib.figure
2✔
11
import matplotlib.pyplot as plt
2✔
12
from matplotlib.text import Text
2✔
13
from matplotlib.colors import to_hex
2✔
14
from matplotlib import cm
2✔
15

16
from upsetplot import plot
2✔
17
from upsetplot import UpSet
2✔
18
from upsetplot import generate_counts, generate_samples
2✔
19
from upsetplot.plotting import _process_data
2✔
20

21
# TODO: warnings should raise errors
22

23

24
def is_ascending(seq):
    """Return True when the items of *seq* are in non-decreasing order."""
    in_order = sorted(seq)
    as_given = list(seq)
    return in_order == as_given
27

28

29
def get_all_texts(mpl_artist):
    """Return the non-empty strings of the Text artists found by
    ``mpl_artist.findobj(Text)``, in the order findobj yields them."""
    found = mpl_artist.findobj(Text)
    return [s for s in (t.get_text() for t in found) if s]
32

33

34
@pytest.mark.parametrize('x', [
    generate_counts(),
    generate_counts().iloc[1:-2],
])
@pytest.mark.parametrize(
    'sort_by',
    ['cardinality', 'degree', '-cardinality', '-degree', None,
     'input', '-input'])
@pytest.mark.parametrize(
    'sort_categories_by',
    [None, 'input', '-input', 'cardinality', '-cardinality'])
def test_process_data_series(x, sort_by, sort_categories_by):
    """Check _process_data on Series input under every sort option."""
    assert x.name == 'value'

    # any sum_over value is rejected for a Series, whatever subset_size is
    for subset_size, sum_over in itertools.product(['auto', 'sum', 'count'],
                                                   ['abc', False]):
        with pytest.raises(ValueError, match='sum_over is not applicable'):
            _process_data(x, sort_by=sort_by,
                          sort_categories_by=sort_categories_by,
                          subset_size=subset_size, sum_over=sum_over)

    # shuffle the input so that sorting is actually exercised
    x = x.sample(frac=1., replace=False, random_state=0)

    total, df, intersections, totals = _process_data(
        x, subset_size='auto', sort_by=sort_by,
        sort_categories_by=sort_categories_by, sum_over=None)

    assert total == x.sum()

    assert intersections.name == 'value'
    x_reordered_levels = x.reorder_levels(intersections.index.names)
    x_reordered = x_reordered_levels.reindex(index=intersections.index)
    assert len(x) == len(x_reordered)
    assert x_reordered.index.is_unique
    assert_series_equal(x_reordered, intersections, check_dtype=False)

    if sort_by == 'cardinality':
        assert is_ascending(intersections.values[::-1])
    elif sort_by == '-cardinality':
        assert is_ascending(intersections.values)
    elif sort_by in ('degree', '-degree'):
        # only the degree order is checked
        # TODO: within a same-degree group, the tuple of active names should
        #       be in sort-order
        degree = intersections.index.to_frame().sum(axis=1)
        assert is_ascending(degree if sort_by == 'degree' else degree[::-1])
    else:
        # None / 'input' / '-input': intersections must follow the order of
        # first appearance in the input (reversed for '-input')
        first_position = x_reordered_levels.index.tolist().index
        orig_order = [first_position(key)
                      for key in intersections.index.tolist()]
        descending = sort_by is not None and sort_by.startswith('-')
        assert orig_order == sorted(orig_order, reverse=descending)

    if sort_categories_by == 'cardinality':
        assert is_ascending(totals.values[::-1])
    elif sort_categories_by == '-cardinality':
        assert is_ascending(totals.values)

    assert np.all(totals.index.values == intersections.index.names)

    assert np.all(df.index.names == intersections.index.names)
    assert set(df.columns) == {'_value', '_bin'}
    assert_index_equal(df['_value'].reorder_levels(x.index.names).index,
                       x.index)
    assert_array_equal(df['_value'], x)
    assert_index_equal(intersections.iloc[df['_bin']].index, df.index)
    assert len(df) == len(x)
108

109

110
@pytest.mark.parametrize('x', [
    generate_samples()['value'],
    generate_counts(),
])
def test_subset_size_series(x):
    """subset_size='auto', 'sum' and 'count' agree where they should."""
    kw = dict(sort_by='cardinality', sort_categories_by='cardinality',
              sum_over=None)

    total, df_sum, intersections_sum, totals_sum = _process_data(
        x, subset_size='sum', **kw)
    assert total == intersections_sum.sum()

    if not x.index.is_unique:
        # with repeated index entries, 'auto' must raise
        with pytest.raises(ValueError):
            _process_data(x, subset_size='auto', **kw)
    else:
        # with a unique index, 'auto' must match 'sum'
        total, df, intersections, totals = _process_data(
            x, subset_size='auto', **kw)
        assert total == intersections.sum()
        assert_frame_equal(df, df_sum)
        assert_series_equal(intersections, intersections_sum)
        assert_series_equal(totals, totals_sum)

    # 'count' must match summing a pre-computed per-index count
    total, df_count, intersections_count, totals_count = _process_data(
        x, subset_size='count', **kw)
    assert total == intersections_count.sum()
    counted = x.groupby(level=list(range(len(x.index.levels)))).count()
    total, df, intersections, totals = _process_data(
        counted, subset_size='sum', **kw)
    assert total == intersections.sum()
    assert_series_equal(intersections, intersections_count, check_names=False)
    assert_series_equal(totals, totals_count)
142

143

144
@pytest.mark.parametrize('x', [
    generate_samples()['value'],
])
@pytest.mark.parametrize('sort_by', ['cardinality', 'degree', None])
@pytest.mark.parametrize('sort_categories_by', [None, 'cardinality'])
def test_process_data_frame(x, sort_by, sort_categories_by):
    """Check _process_data on DataFrame input against the Series path.

    Also checks that extra columns are carried through untouched, that the
    result does not depend on the summed column's name or position, and that
    summing a column of ones matches subset_size='count'.
    """
    import warnings

    # shuffle input to test sorting
    x = x.sample(frac=1., replace=False, random_state=0)

    X = pd.DataFrame({'a': x})

    # pytest.warns(None) was deprecated in pytest 7 and removed in pytest 8.
    # Record warnings with the standard library instead; like the old form,
    # this does not fail when warnings are emitted.
    # NOTE(review): if the intent is to forbid warnings here, use
    # warnings.simplefilter('error') instead of 'always'.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter('always')
        total, df, intersections, totals = _process_data(
            X, sort_by=sort_by, sort_categories_by=sort_categories_by,
            sum_over='a', subset_size='auto')
    assert df is not X
    assert total == pytest.approx(intersections.sum())

    # check equivalence to Series
    total1, df1, intersections1, totals1 = _process_data(
        x, sort_by=sort_by, sort_categories_by=sort_categories_by,
        subset_size='sum', sum_over=None)

    assert intersections.name == 'a'
    assert_frame_equal(df, df1.rename(columns={'_value': 'a'}))
    assert_series_equal(intersections, intersections1, check_names=False)
    assert_series_equal(totals, totals1)

    # check effect of extra column
    X = pd.DataFrame({'a': x, 'b': np.arange(len(x))})
    total2, df2, intersections2, totals2 = _process_data(
        X, sort_by=sort_by, sort_categories_by=sort_categories_by,
        sum_over='a', subset_size='auto')
    assert total2 == pytest.approx(intersections2.sum())
    assert_series_equal(intersections, intersections2)
    assert_series_equal(totals, totals2)
    assert_frame_equal(df, df2.drop('b', axis=1))
    assert_array_equal(df2['b'], X['b'])  # disregard levels, tested above

    # check effect not dependent on order/name
    X = pd.DataFrame({'b': np.arange(len(x)), 'c': x})
    total3, df3, intersections3, totals3 = _process_data(
        X, sort_by=sort_by, sort_categories_by=sort_categories_by,
        sum_over='c', subset_size='auto')
    assert total3 == pytest.approx(intersections3.sum())
    assert_series_equal(intersections, intersections3, check_names=False)
    assert intersections.name == 'a'
    assert intersections3.name == 'c'
    assert_series_equal(totals, totals3)
    assert_frame_equal(df.rename(columns={'a': 'c'}), df3.drop('b', axis=1))
    assert_array_equal(df3['b'], X['b'])

    # check subset_size='count': summing a column of ones is equivalent
    X = pd.DataFrame({'b': np.ones(len(x), dtype='int64'), 'c': x})

    total4, df4, intersections4, totals4 = _process_data(
        X, sort_by=sort_by, sort_categories_by=sort_categories_by,
        sum_over='b', subset_size='auto')
    total5, df5, intersections5, totals5 = _process_data(
        X, sort_by=sort_by, sort_categories_by=sort_categories_by,
        subset_size='count', sum_over=None)
    assert total5 == pytest.approx(intersections5.sum())
    assert_series_equal(intersections4, intersections5, check_names=False)
    assert intersections4.name == 'b'
    assert intersections5.name == 'size'
    assert_series_equal(totals4, totals5)
    assert_frame_equal(df4, df5)
211

212

213
@pytest.mark.parametrize('x', [
    generate_samples()['value'],
    generate_counts(),
])
def test_subset_size_frame(x):
    """Check subset_size / sum_over combinations for DataFrame input."""
    kw = {'sort_by': 'cardinality',
          'sort_categories_by': 'cardinality'}
    X = pd.DataFrame({'x': x})
    total_sum, df_sum, intersections_sum, totals_sum = _process_data(
        X, subset_size='sum', sum_over='x', **kw)
    total_count, df_count, intersections_count, totals_count = _process_data(
        X, subset_size='count', sum_over=None, **kw)

    # error cases: sum_over=False is rejected for every subset_size.
    # (A second check after this loop used to reuse the leaked loop
    # variable and only repeated the 'count' case; it has been dropped.)
    for subset_size in ['auto', 'sum', 'count']:
        with pytest.raises(ValueError, match='sum_over'):
            _process_data(
                X, subset_size=subset_size, sum_over=False, **kw)

    # error cases: sum_over incompatible with subset_size
    with pytest.raises(ValueError, match='sum_over should be a field'):
        _process_data(
            X, subset_size='sum', sum_over=None, **kw)
    with pytest.raises(ValueError, match='sum_over cannot be set'):
        _process_data(
            X, subset_size='count', sum_over='x', **kw)

    # check subset_size='auto' with sum_over=str => sum
    total, df, intersections, totals = _process_data(
        X, subset_size='auto', sum_over='x', **kw)
    assert total == intersections.sum()
    assert_frame_equal(df, df_sum)
    assert_series_equal(intersections, intersections_sum)
    assert_series_equal(totals, totals_sum)

    # check subset_size='auto' with sum_over=None => count
    total, df, intersections, totals = _process_data(
        X, subset_size='auto', sum_over=None, **kw)
    assert total == intersections.sum()
    assert_frame_equal(df, df_count)
    assert_series_equal(intersections, intersections_count)
    assert_series_equal(totals, totals_count)
259

260

261
@pytest.mark.parametrize('sort_by', ['cardinality', 'degree'])
@pytest.mark.parametrize('sort_categories_by', [None, 'cardinality'])
def test_not_unique(sort_by, sort_categories_by):
    """Summing per-sample ones matches summing the pre-aggregated counts."""
    kw = dict(sort_by=sort_by, sort_categories_by=sort_categories_by,
              subset_size='sum', sum_over=None)

    aggregated = generate_counts()
    total1, df1, intersections1, totals1 = _process_data(aggregated, **kw)

    unaggregated = generate_samples()['value']
    unaggregated.loc[:] = 1  # every sample contributes 1
    total2, df2, intersections2, totals2 = _process_data(unaggregated, **kw)

    assert_series_equal(intersections1, intersections2, check_dtype=False)
    assert total2 == intersections2.sum()
    assert_series_equal(totals1, totals2, check_dtype=False)
    for frame in (df1, df2):
        assert set(frame.columns) == {'_value', '_bin'}
    assert len(df2) == len(unaggregated)
    assert df2['_bin'].nunique() == len(intersections2)
281

282

283
def test_include_empty_subsets():
    """include_empty_subsets=True lists every category combination."""
    X = generate_counts(n_samples=2, n_categories=3)

    without_empty = UpSet(X, include_empty_subsets=False)
    assert len(without_empty.intersections) <= 2

    with_empty = UpSet(X, include_empty_subsets=True)
    # all 2 ** 3 combinations of 3 categories are present
    assert len(with_empty.intersections) == 2 ** 3
    shared = with_empty.intersections.loc[without_empty.intersections.index]
    assert_series_equal(without_empty.intersections, shared)

    with_empty.plot()  # smoke test
296

297

298
@pytest.mark.parametrize('kw', [
    {'sort_by': 'blah'},
    {'sort_by': True},
    {'sort_categories_by': 'blah'},
    {'sort_categories_by': True},
])
def test_param_validation(kw):
    """Unknown sort options are rejected with ValueError."""
    counts = generate_counts(n_samples=100)
    with pytest.raises(ValueError):
        UpSet(counts, **kw)
306

307

308
@pytest.mark.parametrize('kw', [{},
                                {'element_size': None},
                                {'orientation': 'vertical'},
                                {'intersection_plot_elements': 0},
                                {'facecolor': 'red'},
                                {'shading_color': 'lightgrey',
                                 'other_dots_color': 'pink'}])
def test_plot_smoke_test(kw):
    """plot() renders with assorted options, with or without a figure."""
    fig = matplotlib.figure.Figure()
    X = generate_counts(n_samples=100)
    axes = plot(X, fig, **kw)
    fig.savefig(io.BytesIO(), format='png')

    # the matrix axis spans len(X) data units along the intersection axis
    attr = ('get_xlim'
            if kw.get('orientation', 'horizontal') == 'horizontal'
            else 'get_ylim')
    lim = getattr(axes['matrix'], attr)()
    expected_width = len(X)
    assert expected_width == lim[1] - lim[0]

    # Also check fig is optional: exactly one new pyplot figure is created
    n_nums = len(plt.get_fignums())
    plot(X, **kw)
    assert len(plt.get_fignums()) - n_nums == 1
    new_fig = plt.gcf()
    assert new_fig.axes
    # close it so parametrized runs do not accumulate open pyplot figures
    # (matplotlib warns once too many are open)
    plt.close(new_fig)
333

334

335
@pytest.mark.parametrize('set1',
                         itertools.product([False, True], repeat=2))
@pytest.mark.parametrize('set2',
                         itertools.product([False, True], repeat=2))
def test_two_sets(set1, set2):
    """Regression: processing used to fail when some set had no items."""
    fig = matplotlib.figure.Figure()
    frame = pd.DataFrame({'val': [5, 7], 'set1': set1, 'set2': set2})
    series = frame.set_index(['set1', 'set2'])['val']
    plot(series, fig, subset_size='sum')
346

347

348
def test_vertical():
    """The vertical layout has a smaller width/height ratio than horizontal."""
    X = generate_counts(n_samples=100)

    def figsize_for(orientation):
        fig = matplotlib.figure.Figure()
        UpSet(X, orientation=orientation).make_grid(fig)
        return fig.get_figwidth(), fig.get_figheight()

    horz_width, horz_height = figsize_for('horizontal')
    assert horz_height < horz_width

    vert_width, vert_height = figsize_for('vertical')
    assert horz_width / horz_height > vert_width / vert_height

    # TODO: test axes positions, plot order, bar orientation
365

366

367
def test_element_size():
    """Figure size scales with element_size; element_size=None keeps it."""
    X = generate_counts(n_samples=100)
    figsizes = []
    for element_size in range(10, 50, 5):
        fig = matplotlib.figure.Figure()
        UpSet(X, element_size=element_size).make_grid(fig)
        figsizes.append((fig.get_figwidth(), fig.get_figheight()))

    figwidths, figheights = zip(*figsizes)
    # Absolute width increases
    assert np.all(np.diff(figwidths) > 0)
    aspect = np.divide(figwidths, figheights)
    aspect_steps = np.diff(aspect)
    # Font size stays constant, so aspect ratio decreases
    assert np.all(aspect_steps <= 1e-8)  # allow for near-equality
    # require some significant decrease: at least one step must drop by more
    # than 1e-4 (a "< +1e-4" bound would also accept steps of zero)
    assert np.any(aspect_steps < -1e-4)
    # But doesn't decrease by much
    assert np.all(aspect[:-1] / aspect[1:] < 1.1)

    fig = matplotlib.figure.Figure()
    figsize_before = fig.get_figwidth(), fig.get_figheight()
    UpSet(X, element_size=None).make_grid(fig)
    figsize_after = fig.get_figwidth(), fig.get_figheight()
    assert figsize_before == figsize_after

    # TODO: make sure axes are all within figure
    # TODO: make sure text does not overlap axes, even with element_size=None
393

394

395
def _walk_artists(el):
2✔
396
    children = el.get_children()
2✔
397
    yield el, children
2✔
398
    for ch in children:
2✔
399
        for x in _walk_artists(ch):
2✔
400
            yield x
2✔
401

402

403
def _count_descendants(el):
    """Return the total number of children over *el* and all descendants."""
    total = 0
    for _, kids in _walk_artists(el):
        total += len(kids)
    return total
405

406

407
@pytest.mark.parametrize('orientation', ['horizontal', 'vertical'])
def test_show_counts(orientation):
    """show_counts / show_percentages add formatted labels to the plot."""
    X = generate_counts(n_samples=10000)

    def render(**kw):
        # draw X onto a fresh, non-pyplot figure with the given options
        fig = matplotlib.figure.Figure()
        plot(X, fig, orientation=orientation, **kw)
        return fig

    n_artists_no_sizes = _count_descendants(render())

    fig = render(show_counts=True)
    n_artists_yes_sizes = _count_descendants(fig)
    assert n_artists_yes_sizes - n_artists_no_sizes > 6
    texts = get_all_texts(fig)
    assert '9547' in texts  # set size
    assert '283' in texts   # intersection size

    # every labelling variant adds the same artists, only the text differs
    variants = [
        ({'show_counts': '%0.2g'}, ['9.5e+03', '2.8e+02']),
        ({'show_counts': '{:0.2g}'}, ['9.5e+03', '2.8e+02']),
        ({'show_percentages': True}, ['95.5%', '2.8%']),
        ({'show_percentages': '!{:0.2f}!'}, ['!0.95!', '!0.03!']),
    ]
    for kw, expected_labels in variants:
        fig = render(**kw)
        assert n_artists_yes_sizes == _count_descendants(fig), kw
        texts = get_all_texts(fig)
        for label in expected_labels:
            assert label in texts, (kw, label)

    fig = render(show_counts=True, show_percentages=True)
    assert n_artists_yes_sizes == _count_descendants(fig)
    texts = get_all_texts(fig)
    if orientation == 'vertical':
        assert '9547\n(95.5%)' in texts
        assert '283 (2.8%)' in texts
    else:
        assert '9547 (95.5%)' in texts
        assert '283\n(2.8%)' in texts

    # an invalid format string is rejected; keep only the call expected to
    # raise inside pytest.raises
    fig = matplotlib.figure.Figure()
    with pytest.raises(ValueError):
        plot(X, fig, orientation=orientation, show_counts='%0.2h')
459

460

461
def test_add_catplot():
    """add_catplot validates `value` and plots (requires seaborn)."""
    pytest.importorskip('seaborn')

    # Series input
    counts = generate_counts(n_samples=100)
    upset = UpSet(counts)
    upset.add_catplot('violin')  # smoke test
    fig = matplotlib.figure.Figure()
    upset.plot(fig)

    # can't provide value with Series ...
    with pytest.raises(ValueError):
        upset.add_catplot('violin', value='foo')
    # ... and the rejected call must not break the UpSet's state
    upset.plot(fig)

    # DataFrame input: value is required and must name a column
    named_counts = generate_counts(n_samples=100)
    named_counts.name = 'foo'
    frame = named_counts.to_frame()
    upset = UpSet(frame, subset_size='count')
    with pytest.raises(ValueError):
        upset.add_catplot('violin')
    upset.add_catplot('violin', value='foo')
    with pytest.raises(ValueError):
        upset.add_catplot('violin', value='bar')  # not a known column
    upset.plot(fig)

    # an invalid plot kind is accepted here but fails when plotting
    upset.add_catplot('foobar', value='foo')
    with pytest.raises(AttributeError):
        upset.plot(fig)
494

495

496
def _get_patch_data(axes, is_vertical):
2✔
497
    out = [{"y": patch.get_y(), "x": patch.get_x(),
2✔
498
            "h": patch.get_height(), "w": patch.get_width(),
1✔
499
            "fc": patch.get_facecolor(),
1✔
500
            "ec": patch.get_edgecolor(),
1✔
501
            "lw": patch.get_linewidth(),
1✔
502
            "ls": patch.get_linestyle(),
1✔
503
            "hatch": patch.get_hatch(),
1✔
UNCOV
504
            }
×
505
           for patch in axes.patches]
2✔
506
    if is_vertical:
2✔
507
        out = [{"y": patch["x"], "x": 6.5 - patch["y"],
2✔
508
                "h": patch["w"], "w": patch["h"],
1✔
509
                "fc": patch["fc"],
1✔
510
                "ec": patch["ec"],
1✔
511
                "lw": patch["lw"],
1✔
512
                "ls": patch["ls"],
1✔
513
                "hatch": patch["hatch"],
1✔
UNCOV
514
                }
×
515
               for patch in out]
2✔
516
    return pd.DataFrame(out).sort_values("x").reset_index(drop=True)
2✔
517

518

519
def _get_color_to_label_from_legend(ax):
2✔
520
    handles, labels = ax.get_legend_handles_labels()
2✔
521
    color_to_label = {
2✔
522
        patches[0].get_facecolor(): label
1✔
523
        for patches, label in zip(handles, labels)
2✔
524
    }
525
    return color_to_label
2✔
526

527

528
@pytest.mark.parametrize('orientation', ['horizontal', 'vertical'])
@pytest.mark.parametrize('show_counts', [False, True])
def test_add_stacked_bars(orientation, show_counts):
    """Stacked bars sum to the intersection bars and follow legend order."""
    df = generate_samples()
    df["label"] = (pd.cut(generate_samples().value + np.random.rand() / 2, 3)
                   .cat.codes
                   .map({0: "foo", 1: "bar", 2: "baz"}))

    upset = UpSet(df, show_counts=show_counts, orientation=orientation)
    upset.add_stacked_bars(by="label")
    upset_axes = upset.plot()
    int_axes = upset_axes["intersections"]
    stacked_axes = upset_axes["extra1"]

    is_vertical = orientation == 'vertical'
    int_rects = _get_patch_data(int_axes, is_vertical)
    stacked_rects = _get_patch_data(stacked_axes, is_vertical)

    # bar heights agree between the intersection and stacked plots
    assert_series_equal(int_rects.groupby("x")["h"].sum(),
                        stacked_rects.groupby("x")["h"].sum(),
                        check_dtype=False)
    # count labels agree too (TODO: check coordinate)
    int_texts = [t.get_text() for t in int_axes.texts]
    stacked_texts = [t.get_text() for t in stacked_axes.texts]
    assert int_texts == stacked_texts

    stacked_rects["label"] = stacked_rects["fc"].map(
        _get_color_to_label_from_legend(stacked_axes))
    # each label's stacked total equals its number of samples
    assert_series_equal(stacked_rects.groupby("label")["h"].sum(),
                        df.groupby("label").size(),
                        check_dtype=False, check_names=False)

    legend_texts = stacked_axes.get_legend().get_texts()
    label_order = [t.get_text() for t in legend_texts]
    # label order should be lexicographic
    assert label_order == sorted(label_order)

    # legend order matches the stack: top to bottom when horizontal,
    # left to right when vertical
    for prev, curr in zip(label_order, label_order[1:]):
        prev_y = (stacked_rects.query("label == @prev")
                  .sort_values("x")["y"].values)
        curr_y = (stacked_rects.query("label == @curr")
                  .sort_values("x")["y"].values)
        if orientation == "horizontal":
            assert (prev_y >= curr_y).all()
        else:
            assert (prev_y <= curr_y).all()
581

582

583
@pytest.mark.parametrize("colors, expected", [
    (["blue", "red", "green"], ["blue", "red", "green"]),
    ({"bar": "blue", "baz": "red", "foo": "green"}, ["blue", "red", "green"]),
    ("Pastel1", ["#fbb4ae", "#b3cde3", "#ccebc5"]),
    (cm.viridis, ["#440154", "#440256", "#450457"]),
    (lambda x: cm.Pastel1(x), ["#fbb4ae", "#b3cde3", "#ccebc5"]),
])
def test_add_stacked_bars_colors(colors, expected):
    """Each accepted `colors` form colours labels bar/baz/foo as expected."""
    df = generate_samples()
    df["label"] = (pd.cut(generate_samples().value + np.random.rand() / 2, 3)
                   .cat.codes
                   .map({0: "foo", 1: "bar", 2: "baz"}))

    upset = UpSet(df)
    upset.add_stacked_bars(by="label", colors=colors,
                           title="Count by gender")
    stacked_axes = upset.plot()["extra1"]
    color_to_label = _get_color_to_label_from_legend(stacked_axes)
    label_to_color = {label: color
                      for color, label in color_to_label.items()}
    actual = [to_hex(label_to_color[label]) for label in ["bar", "baz", "foo"]]
    assert actual == [to_hex(color) for color in expected]
606

607

608
@pytest.mark.parametrize('int_sum_over', [False, True])
@pytest.mark.parametrize('stack_sum_over', [False, True])
@pytest.mark.parametrize('show_counts', [False, True])
def test_add_stacked_bars_sum_over(int_sum_over, stack_sum_over, show_counts):
    """A rough test: bars agree iff both plots use the same sum_over."""
    df = generate_samples()
    df["label"] = (pd.cut(generate_samples().value + np.random.rand() / 2, 3)
                   .cat.codes
                   .map({0: "foo", 1: "bar", 2: "baz"}))

    upset = UpSet(df, sum_over="value" if int_sum_over else None,
                  show_counts=show_counts)
    upset.add_stacked_bars(by="label",
                           sum_over="value" if stack_sum_over else None,
                           colors='Pastel1')
    upset_axes = upset.plot()
    int_axes = upset_axes["intersections"]
    stacked_axes = upset_axes["extra1"]

    int_heights = (_get_patch_data(int_axes, is_vertical=False)
                   .groupby("x")["h"].sum())
    stacked_heights = (_get_patch_data(stacked_axes, is_vertical=False)
                       .groupby("x")["h"].sum())
    int_texts = [t.get_text() for t in int_axes.texts]
    stacked_texts = [t.get_text() for t in stacked_axes.texts]

    if int_sum_over != stack_sum_over:
        # different quantities: every bar height differs
        assert (int_heights != stacked_heights).all()
        if show_counts:
            assert int_texts != stacked_texts
    else:
        # same quantity: bar heights and count labels agree
        assert_series_equal(int_heights, stacked_heights, check_dtype=False)
        assert int_texts == stacked_texts
645

646

647
@pytest.mark.parametrize('x', [
    generate_counts(),
])
def test_index_must_be_bool(x):
    """0/1 integer index levels are accepted; other integers are rejected."""
    # The assignment targets previously read ['cat0', 'cat2', 'cat2'], so
    # cat1 was never converted; convert every category level.
    categories = ['cat0', 'cat1', 'cat2']

    # Truthy ints are okay
    x = x.reset_index()
    x[categories] = x[categories].astype(int)
    x = x.set_index(categories).iloc[:, 0]

    UpSet(x)

    # other ints are not
    x = x.reset_index()
    x[categories] = x[categories] + 1
    x = x.set_index(categories).iloc[:, 0]
    with pytest.raises(ValueError, match='not boolean'):
        UpSet(x)
664

665

666
@pytest.mark.parametrize(
    "filter_params, expected",
    [
        ({"min_subset_size": 623},
         {(True, False, False): 884,
          (True, True, False): 1547,
          (True, False, True): 623,
          (True, True, True): 990,
          }),
        ({"min_subset_size": 800, "max_subset_size": 990},
         {(True, False, False): 884,
          (True, True, True): 990,
          }),
        ({"min_degree": 2},
         {(True, True, False): 1547,
          (True, False, True): 623,
          (False, True, True): 258,
          (True, True, True): 990,
          }),
        ({"min_degree": 2, "max_degree": 2},
         {(True, True, False): 1547,
          (True, False, True): 623,
          (False, True, True): 258,
          }),
        ({"max_subset_size": 500, "max_degree": 2},
         {(False, False, False): 220,
          (False, True, False): 335,
          (False, False, True): 143,
          (False, True, True): 258,
          }),
    ]
)
@pytest.mark.parametrize('sort_by', ['cardinality', 'degree'])
def test_filter_subsets(filter_params, expected, sort_by):
    """Subset filters keep exactly the expected intersections."""
    data = generate_samples(seed=0, n_samples=5000, n_categories=3)
    # data =
    #   cat1   cat0   cat2
    #   False  False  False     220
    #   True   False  False     884
    #   False  True   False     335
    #          False  True      143
    #   True   True   False    1547
    #          False  True      623
    #   False  True   True      258
    #   True   True   True      990
    upset_full = UpSet(data, subset_size='auto', sort_by=sort_by)
    upset_filtered = UpSet(data, subset_size='auto', sort_by=sort_by,
                           **filter_params)
    intersections = upset_full.intersections
    df = upset_full._df

    # check integrity of expected, just to be sure
    for key, value in expected.items():
        assert intersections.loc[key] == value

    kept_keys = list(expected)
    subset_intersections = intersections[intersections.index.isin(kept_keys)]
    subset_df = df[df.index.isin(kept_keys)]
    assert len(subset_intersections) < len(intersections)
    assert_series_equal(upset_filtered.intersections, subset_intersections)
    assert_frame_equal(upset_filtered._df.drop("_bin", axis=1),
                       subset_df.drop("_bin", axis=1))
    # category totals should not be affected
    assert_series_equal(upset_full.totals, upset_filtered.totals)
729

730

731
@pytest.mark.parametrize('x', [
    generate_counts(n_categories=3),
    generate_counts(n_categories=8),
    generate_counts(n_categories=15),
])
@pytest.mark.parametrize('orientation', [
    'horizontal',
    'vertical',
])
def test_matrix_plot_margins(x, orientation):
    """Non-regression test for a bug where large whitespace margins appeared
    around the matrix when the number of intersections was large."""
    matrix_ax = plot(x, orientation=orientation)['matrix']

    # Each matrix column should span exactly one data unit along the
    # intersection axis: x for horizontal plots, y for vertical ones.
    if orientation == 'horizontal':
        low, high = matrix_ax.get_xlim()
    else:
        low, high = matrix_ax.get_ylim()
    assert len(x) == high - low
750

751

752
def _make_facecolor_list(colors):
    """Return a list of ``{"facecolor": colour}`` dicts, one per colour."""
    return [dict(facecolor=color) for color in colors]
754

755

756
# Expected ``subset_styles`` for ``generate_counts()`` after a single
# ``style_subsets(..., facecolor="red")`` call on an UpSet built with
# facecolor="blue". There is one entry per subset, and "red" marks the
# subsets selected by the call the name describes, e.g.
# CAT1_NOT2 <-> present="cat1", absent=["cat2"].
CAT1_2_RED_STYLES = _make_facecolor_list(
    "blue blue blue blue red blue blue red".split())
CAT1_RED_STYLES = _make_facecolor_list(
    "blue red blue blue red red blue red".split())
CAT_NOT1_RED_STYLES = _make_facecolor_list(
    "red blue red red blue blue red blue".split())
CAT1_NOT2_RED_STYLES = _make_facecolor_list(
    "blue red blue blue blue red blue blue".split())
CAT_NOT1_2_RED_STYLES = _make_facecolor_list(
    "red blue blue red blue blue blue blue".split())
766

767

768
@pytest.mark.parametrize(
    "kwarg_list,expected_subset_styles,expected_legend",
    [
        # Different forms of including two categories
        ([{"present": ["cat1", "cat2"], "facecolor": "red"}],
         CAT1_2_RED_STYLES, []),
        ([{"present": {"cat1", "cat2"}, "facecolor": "red"}],
         CAT1_2_RED_STYLES, []),
        ([{"present": ("cat1", "cat2"), "facecolor": "red"}],
         CAT1_2_RED_STYLES, []),
        # with legend
        ([{"present": ("cat1", "cat2"), "facecolor": "red", "label": "foo"}],
         CAT1_2_RED_STYLES, [({"facecolor": "red"}, "foo")]),
        # present only cat1
        ([{"present": ("cat1",), "facecolor": "red"}],
         CAT1_RED_STYLES, []),
        ([{"present": "cat1", "facecolor": "red"}],
         CAT1_RED_STYLES, []),
        # Some uses of absent
        ([{"absent": "cat1", "facecolor": "red"}],
         CAT_NOT1_RED_STYLES, []),
        ([{"present": "cat1", "absent": ["cat2"], "facecolor": "red"}],
         CAT1_NOT2_RED_STYLES, []),
        ([{"absent": ["cat2", "cat1"], "facecolor": "red"}],
         CAT_NOT1_2_RED_STYLES, []),
        # min/max args
        ([{"present": ["cat1", "cat2"], "min_degree": 3, "facecolor": "red"}],
         _make_facecolor_list(["blue"] * 7 + ["red"]), []),
        ([{"present": ["cat1", "cat2"], "max_subset_size": 3000,
           "facecolor": "red"}],
         _make_facecolor_list(["blue"] * 7 + ["red"]), []),
        ([{"present": ["cat1", "cat2"], "max_degree": 2, "facecolor": "red"}],
         _make_facecolor_list(["blue"] * 4 + ["red"] + ["blue"] * 3), []),
        ([{"present": ["cat1", "cat2"], "min_subset_size": 3000,
           "facecolor": "red"}],
         _make_facecolor_list(["blue"] * 4 + ["red"] + ["blue"] * 3), []),
        # cat1 _or_ cat2
        ([{"present": "cat1", "facecolor": "red"},
          {"present": "cat2", "facecolor": "red"}],
         _make_facecolor_list(["blue", "red", "red", "blue",
                               "red", "red", "red", "red"]), []),
        # With multiple uses of label
        ([{"present": "cat1", "facecolor": "red", "label": "foo"},
          {"present": "cat2", "facecolor": "red", "label": "bar"}],
         _make_facecolor_list(["blue", "red", "red", "blue",
                               "red", "red", "red", "red"]),
         [({"facecolor": "red"}, "foo; bar")]),
        ([{"present": "cat1", "facecolor": "red", "label": "foo"},
          {"present": "cat2", "facecolor": "red", "label": "foo"}],
         _make_facecolor_list(["blue", "red", "red", "blue",
                               "red", "red", "red", "red"]),
         [({"facecolor": "red"}, "foo")]),
        # With multiple colours, the latest overrides
        ([{"present": "cat1", "facecolor": "red", "label": "foo"},
          {"present": "cat2", "facecolor": "green", "label": "bar"}],
         _make_facecolor_list(["blue", "red", "green", "blue",
                               "green", "red", "green", "green"]),
         [({"facecolor": "red"}, "foo"),
          ({"facecolor": "green"}, "bar")]),
        # Combining multiple style properties
        ([{"present": "cat1", "facecolor": "red", "hatch": "//"},
          {"present": "cat2", "edgecolor": "green", "linestyle": "dotted"}],
         [{"facecolor": "blue"},
          {"facecolor": "red", "hatch": "//"},
          {"facecolor": "blue", "edgecolor": "green", "linestyle": "dotted"},
          {"facecolor": "blue"},
          {"facecolor": "red", "hatch": "//", "edgecolor": "green",
           "linestyle": "dotted"},
          {"facecolor": "red", "hatch": "//"},
          {"facecolor": "blue", "edgecolor": "green",
           "linestyle": "dotted"},
          {"facecolor": "red", "hatch": "//", "edgecolor": "green",
           "linestyle": "dotted"},
          ],
         []),
    ])
def test_style_subsets(kwarg_list, expected_subset_styles, expected_legend):
    """Check that ``style_subsets`` calls, applied in order on top of a blue
    base facecolor, yield the expected per-subset styles and legend entries.
    """
    upset = UpSet(generate_counts(), facecolor="blue")
    for style_kwargs in kwarg_list:
        upset.style_subsets(**style_kwargs)
    assert upset.subset_styles == expected_subset_styles
    assert upset.subset_legend == expected_legend
852

853

854
def _dots_to_dataframe(ax, is_vertical):
    """Tabulate the dots in ``ax.collections[0]`` as a DataFrame.

    The result has one row per dot, with these columns in this order:

    * ``x`` and ``y``: the dot offsets.
    * ``fc_r``..``fc_a``: face RGBA.
    * ``ec_r``..``ec_a``: edge RGBA.
    * ``lw`` and ``hatch``.
    * ``ls_offset``: the linestyle's dash offset as a float.
    * ``ls_seq``: the dash sequence as a tuple, or None when matplotlib
      reports no sequence.

    When ``is_vertical`` is true, the coordinates are swapped and the new
    ``x`` is mirrored as ``7 - x``. This lets vertical plots be compared in
    the same frame as horizontal ones. (7 presumably being the last subset
    index for the 8 subsets of ``generate_counts()`` -- confirm.)
    """
    dots = ax.collections[0]

    table = pd.DataFrame(dots.get_offsets(), columns=["x", "y"])
    for prefix, rgba in (("fc", dots.get_facecolors()),
                         ("ec", dots.get_edgecolors())):
        channel_names = [prefix + "_" + channel for channel in "rgba"]
        # Left join on the row index, one column per colour channel.
        table = table.join(pd.DataFrame(rgba, columns=channel_names))

    table["lw"] = dots.get_linewidths()
    table["ls"] = dots.get_linestyles()
    table["hatch"] = dots.get_hatch()

    # Split each (offset, dash-sequence) linestyle into two comparable
    # columns; tuples make the sequences hashable for drop_duplicates.
    linestyles = table.pop("ls")
    table["ls_offset"] = linestyles.map(lambda style: style[0]).astype(float)
    table["ls_seq"] = linestyles.map(
        lambda style: None if style[1] is None else tuple(style[1]))

    if is_vertical:
        mirrored_x = 7 - table["y"]
        table["y"] = table["x"]
        table["x"] = mirrored_x
    return table
880

881

882
@pytest.mark.parametrize('orientation', ['horizontal', 'vertical'])
def test_style_subsets_artists(orientation):
    """Check that ``subset_styles`` are reflected in the drawn artists.

    Assigns one style dict per subset of ``generate_counts()`` directly to
    ``upset.subset_styles`` and plots. It then compares the intersection
    bar patches, and the active matrix dots, against a table of expected
    RGB colours, line widths and solid/non-solid line styles. Inactive
    matrix dots must match an unstyled baseline plot.
    """
    # Check that subset_styles are all appropriately reflected in matplotlib
    # artists.
    # This may be a bit overkill, and too coupled with implementation details.
    is_vertical = orientation == 'vertical'
    data = generate_counts()
    upset = UpSet(data, orientation=orientation)
    # One style per subset; {} applies no overrides.
    subset_styles = [
        {"facecolor": "black"},
        {"facecolor": "red"},
        {"edgecolor": "red"},
        {"edgecolor": "red", "linewidth": 4},
        {"linestyle": "dotted"},
        {"edgecolor": "red", "facecolor": "blue", "hatch": "//"},
        {"facecolor": "blue"},
        {},
    ]

    # Styles are reversed for vertical plots so that the same `expected`
    # table applies to both orientations (presumably vertical plots draw
    # subsets in the opposite order -- confirm).
    if is_vertical:
        upset.subset_styles = subset_styles[::-1]
    else:
        upset.subset_styles = subset_styles

    upset_axes = upset.plot()

    # Intersection bars. _get_patch_data is defined elsewhere in this module;
    # here it is relied on to provide per-patch "fc"/"ec" colour sequences
    # plus "ls" and "lw" columns. Colours are split into one column per
    # channel.
    int_rects = _get_patch_data(upset_axes["intersections"], is_vertical)
    int_rects[["fc_r", "fc_g", "fc_b", "fc_a"]] = (
        int_rects.pop("fc").apply(lambda x: pd.Series(x)))
    int_rects[["ec_r", "ec_g", "ec_b", "ec_a"]] = (
        int_rects.pop("ec").apply(lambda x: pd.Series(x)))
    # A missing (NaN) linestyle is treated as solid.
    int_rects["ls_is_solid"] = int_rects.pop("ls").map(
        lambda x: x == "solid" or pd.isna(x))
    # Expected values, one row per entry of `subset_styles` above; alpha
    # channels are deliberately not compared.
    expected = pd.DataFrame({
        "fc_r": [0, 1, 0, 0, 0, 0, 0, 0],
        "fc_g": [0, 0, 0, 0, 0, 0, 0, 0],
        "fc_b": [0, 0, 0, 0, 0, 1, 1, 0],
        "ec_r": [0, 1, 1, 1, 0, 1, 0, 0],
        "ec_g": [0, 0, 0, 0, 0, 0, 0, 0],
        "ec_b": [0, 0, 0, 0, 0, 0, 1, 0],
        "lw": [1, 1, 1, 4, 1, 1, 1, 1],
        "ls_is_solid": [True, True, True, True, False, True, True, True],
    })

    assert_frame_equal(expected, int_rects[expected.columns],
                       check_dtype=False)

    # Matrix dots: compare the styled plot against an unstyled plot of the
    # same data.
    styled_dots = _dots_to_dataframe(upset_axes["matrix"], is_vertical)
    baseline_dots = _dots_to_dataframe(
        UpSet(data, orientation=orientation).plot()["matrix"],
        is_vertical
    )
    # Dots drawn with alpha < 1 in the baseline are presumably the inactive
    # (category-absent) dots; styling must leave them untouched.
    inactive_dot_mask = (baseline_dots[["fc_a"]] < 1).values.ravel()
    assert_frame_equal(baseline_dots.loc[inactive_dot_mask],
                       styled_dots.loc[inactive_dot_mask])

    styled_dots = styled_dots.loc[~inactive_dot_mask]

    # Collapse the active dots of each matrix column (x) to their distinct
    # style rows; a column whose dots share one style yields a single row.
    styled_dots = styled_dots.drop(columns="y").groupby("x").apply(
        lambda df: df.drop_duplicates())
    # Assumes a solid linestyle carries no dash sequence (ls_seq is None).
    styled_dots["ls_is_solid"] = styled_dots.pop("ls_seq").isna()
    # expected.iloc[1:]: the first subset presumably has no active dots
    # (the empty intersection) -- TODO confirm.
    assert_frame_equal(expected.iloc[1:].reset_index(drop=True),
                       styled_dots[expected.columns].reset_index(drop=True),
                       check_dtype=False)

    # TODO: check lines between dots
    # matrix_line_collection = upset_axes["matrix"].collections[1]
949

950

951
def test_many_categories():
    """Constructing an UpSet with hundreds of categories must not fail.

    Regression test for GH#193.
    """
    n_cats = 250
    # Single-character category names, starting at chr(33) ('!').
    columns = [chr(33 + offset) for offset in range(n_cats)]
    # Two rows: the first is a member of category 0 only, the second of
    # category 1 only.
    rows = [[col == member for col in range(n_cats)] for member in (0, 1)]
    data = pd.DataFrame(rows, columns=columns)
    data["value"] = 1
    UpSet(data.set_index(columns)["value"])
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc