
ContinualAI / avalanche — build 8098020118

29 Feb 2024 02:57PM UTC · coverage: 51.806% (-12.4% from 64.17%)
Triggered by a push via GitHub (committer: web-flow): "Update test-coverage-coveralls.yml"

14756 of 28483 relevant lines covered (51.81%)
0.52 hits per line
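
As a quick check of how the headline figures relate (a minimal sketch; the variable names are illustrative and only the numbers reported above are used):

covered = 14756                    # relevant lines hit at least once
relevant = 28483                   # total relevant lines in this build
print(100 * covered / relevant)    # 51.806..., the reported coverage
print(51.806 - 64.17)              # -12.36..., the reported -12.4% change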

Source File

/tests/training/test_regularization.py — 24.0% of relevant lines covered (only the import and class/function definition lines were hit in this build; the test bodies were not executed)
import unittest

import torch
from torch.utils.data import DataLoader

from avalanche.models import SimpleMLP, MTSimpleMLP
from avalanche.models.utils import avalanche_model_adaptation
from avalanche.training.regularization import LearningWithoutForgetting
from tests.unit_tests_utils import get_fast_benchmark
import numpy as np
import random


class TestLwF(unittest.TestCase):
    def test_lwf(self):
        seed = 0
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.use_deterministic_algorithms(True)
        lwf = LearningWithoutForgetting()
        bm = get_fast_benchmark()

        # Single-headed model: the LwF penalty should be zero on the first
        # experience and strictly positive afterwards.
        teacher = SimpleMLP(input_size=6)
        model = SimpleMLP(input_size=6)
        for exp in bm.train_stream:
            mb_x, mb_y, mb_tl = list(DataLoader(exp.dataset))[0]
            mb_pred = model(mb_x)
            loss = lwf(mb_x, mb_pred, model)

            # zero loss on the first experience, non-zero afterwards
            if lwf.expcount == 0:
                assert loss == 0
            else:
                assert loss > 0.0
            lwf.update(exp, teacher)

        # Multi-task model: same check, plus gradients on previous heads.
        lwf = LearningWithoutForgetting()
        teacher = MTSimpleMLP(input_size=6)
        model = MTSimpleMLP(input_size=6)
        for exp in bm.train_stream:
            avalanche_model_adaptation(teacher, exp)
            avalanche_model_adaptation(model, exp)
            mb_x, mb_y, mb_tl = list(DataLoader(exp.dataset))[0]
            mb_pred = model(mb_x, task_labels=mb_tl)
            loss = lwf(mb_x, mb_pred, model)

            # zero loss on the first experience, non-zero afterwards
            if lwf.expcount == 0:
                assert loss == 0
            else:
                assert loss > 0.0

                # the distillation loss yields a non-zero gradient
                # for every previous head
                loss.backward()
                for tid in lwf.prev_classes_by_task.keys():
                    head = model.classifier.classifiers[str(tid)]
                    weight = head.classifier.weight
                    assert weight.grad is not None
                    assert torch.norm(weight.grad) > 0
                model.zero_grad()

            lwf.update(exp, teacher)


if __name__ == "__main__":
    unittest.main()
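
For context on what the call pattern above is exercising, here is a hedged sketch of how a regularizer with this interface might be combined with an ordinary task loss in a plain PyTorch loop. It reuses only the calls that appear in the test (LearningWithoutForgetting(), lwf(mb_x, mb_pred, model), lwf.update(...)); the optimizer, criterion, and batch size are illustrative assumptions, and this is not the training loop an Avalanche strategy necessarily uses.

import torch
from torch.utils.data import DataLoader

from avalanche.models import SimpleMLP
from avalanche.training.regularization import LearningWithoutForgetting
from tests.unit_tests_utils import get_fast_benchmark

bm = get_fast_benchmark()
model = SimpleMLP(input_size=6)
lwf = LearningWithoutForgetting()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # illustrative choice
criterion = torch.nn.CrossEntropyLoss()                   # illustrative choice

for exp in bm.train_stream:
    for mb_x, mb_y, mb_tl in DataLoader(exp.dataset, batch_size=32):
        optimizer.zero_grad()
        mb_pred = model(mb_x)
        # Task loss plus the distillation penalty; per the test above, the
        # penalty is zero on the first experience and positive afterwards.
        loss = criterion(mb_pred, mb_y) + lwf(mb_x, mb_pred, model)
        loss.backward()
        optimizer.step()
    # Register the just-trained model as the teacher for later experiences
    # (the test passes a separate teacher instance to the same call).
    lwf.update(exp, model)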