• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ContinualAI / avalanche / 8098020118

29 Feb 2024 02:57PM UTC coverage: 51.806% (-12.4%) from 64.17%
8098020118

push

github

web-flow
Update test-coverage-coveralls.yml

14756 of 28483 relevant lines covered (51.81%)

0.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

30.19
/tests/benchmarks/utils/test_data_attribute.py
1
import unittest
1✔
2

3
import numpy as np
1✔
4
import torch
1✔
5

6
from avalanche.benchmarks.utils import (
1✔
7
    _taskaware_classification_subset,
8
    make_avalanche_dataset,
9
)
10
from avalanche.benchmarks.utils.data_attribute import DataAttribute, TensorDataAttribute
1✔
11

12

13
class DataAttributeTests(unittest.TestCase):
    """Unit tests for the views and operations exposed by ``DataAttribute``.

    Each test wraps a small torch tensor in a ``DataAttribute`` named
    ``"task_labels"`` and checks one property or method of the wrapper.
    """

    def test_tensor_uniques(self):
        """Expect ``uniques`` of an all-zero tensor to be ``{0.0}``."""
        zeros = torch.zeros(10)
        attribute = DataAttribute(zeros, "task_labels")
        self.assertEqual(attribute.uniques, {0.0})

    def test_count(self):
        """Expect ``count`` to report ten occurrences for each label."""
        zeros = torch.zeros(10, dtype=torch.int)
        ones = torch.ones(10, dtype=torch.int)
        labels = torch.cat([zeros, ones])
        attribute = DataAttribute(labels, "task_labels")
        expected_counts = {0: 10, 1: 10}
        self.assertEqual(attribute.count, expected_counts)

    def test_val_to_idx(self):
        """Expect ``val_to_idx`` to list the positions holding each label."""
        zeros = torch.zeros(10, dtype=torch.int)
        ones = torch.ones(10, dtype=torch.int)
        labels = torch.cat([zeros, ones])
        attribute = DataAttribute(labels, "task_labels")
        # Zeros occupy positions 0..9, ones occupy positions 10..19.
        expected_positions = {0: list(range(10)), 1: list(range(10, 20))}
        self.assertEqual(attribute.val_to_idx, expected_positions)

    def test_subset(self):
        """Expect ``subset`` to return the values at the requested indices."""
        zeros = torch.zeros(10, dtype=torch.int)
        ones = torch.ones(10, dtype=torch.int)
        labels = torch.cat([zeros, ones])
        attribute = DataAttribute(labels, "task_labels")

        head = attribute.subset(range(10))
        self.assertEqual(list(head.data), list(zeros))

        tail = attribute.subset(range(10, 20))
        self.assertEqual(list(tail.data), list(ones))

    def test_concat(self):
        """Expect ``concat`` with itself to repeat the values in order."""
        zeros = torch.zeros(10, dtype=torch.int)
        ones = torch.ones(10, dtype=torch.int)
        labels = torch.cat([zeros, ones])
        attribute = DataAttribute(labels, "task_labels")

        doubled = attribute.concat(attribute)
        expected = torch.cat([zeros, ones, zeros, ones])
        self.assertEqual(list(doubled.data), list(expected))
48

49

50
class TensorDataAttributeTests(unittest.TestCase):
    """Unit tests for tensor-valued data attributes.

    Covers subsetting, concatenation, and the order in which attributes
    flagged with ``use_in_getitem=True`` show up in a dataset item.
    """

    def test_subset(self):
        """Expect ``subset`` to return the values at the requested indices."""
        zeros = torch.zeros(10)
        ones = torch.ones(10)
        attribute = TensorDataAttribute(torch.cat([zeros, ones]), "logit")

        head = attribute.subset(range(10))
        self.assertEqual(list(head.data), list(zeros))

        tail = attribute.subset(range(10, 20))
        self.assertEqual(list(tail.data), list(ones))

    def test_concat(self):
        """Expect ``concat`` with itself to repeat the values in order."""
        zeros = torch.zeros(10)
        ones = torch.ones(10)
        # NOTE(review): this builds a plain DataAttribute although the
        # enclosing suite targets TensorDataAttribute -- confirm which class
        # this test is meant to exercise.
        attribute = DataAttribute(torch.cat([zeros, ones]), "logits")

        doubled = attribute.concat(attribute)
        expected = torch.cat([zeros, ones, zeros, ones])
        self.assertEqual(list(doubled.data), list(expected))

    def test_swap(self):
        """Expect data attributes to always come back in the same order.

        Two attributes are attached first and a third one afterwards; the
        item at index 0 must list them in the order they were attached.
        """
        # Stand-in (x, y) samples: ten pairs (i, i).
        indices = np.arange(10)
        samples = list(zip(indices, indices))
        ones = torch.ones(10).tolist()
        twos = (torch.ones(10) * 2).tolist()
        threes = (torch.ones(10) * 3).tolist()

        initial_attributes = [
            TensorDataAttribute(ones, name="logits", use_in_getitem=True),
            TensorDataAttribute(twos, name="logits2", use_in_getitem=True),
        ]
        dataset = make_avalanche_dataset(
            samples, data_attributes=initial_attributes
        )

        # Attach one more attribute on top of the existing dataset.
        late_attributes = [
            TensorDataAttribute(threes, name="logits0", use_in_getitem=True),
        ]
        dataset = make_avalanche_dataset(
            dataset, data_attributes=late_attributes
        )

        expected_item = [0.0, 0.0, 1.0, 2.0, 3.0]
        self.assertSequenceEqual(expected_item, dataset[0])
92

93

94
# Run the whole test module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc