• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

f-dangel / backpack / 8116261751

01 Mar 2024 07:30PM UTC coverage: 98.375%. Remained the same
8116261751

Pull #323

github

web-flow
Merge 610195223 into e9b1dd361
Pull Request #323: [FIX | FMT] RTD build, apply latest `black` and `isort`

97 of 97 new or added lines in 97 files covered. (100.0%)

43 existing lines in 18 files now uncovered.

4420 of 4493 relevant lines covered (98.38%)

11.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.48
/backpack/context.py
1
"""Context class for BackPACK."""
2
from typing import Callable, Iterable, List, Tuple, Type
12✔
3

4
from torch.nn import Module
12✔
5
from torch.utils.hooks import RemovableHandle
12✔
6

7
from backpack.extensions.backprop_extension import BackpropExtension
12✔
8
from backpack.utils.hooks import no_op
12✔
9

10

11
class CTX:
    """Global class holding the configuration of the backward pass.

    All state is stored in class attributes and accessed through static
    methods, so every caller shares the same configuration.
    """

    # Extensions currently active; any number of them (variable-length tuple).
    active_exts: Tuple[BackpropExtension, ...] = tuple()
    # Whether debug mode is active.
    debug: bool = False
    # Hook run after all other extensions; ``no_op`` by default.
    extension_hook: Callable[[Module], None] = no_op
    # Handles of registered hooks, kept so that ``remove_hooks`` can undo them.
    hook_handles: List[RemovableHandle] = []
    # Stored ``retain_graph`` flag.
    retain_graph: bool = False

    @staticmethod
    def set_active_exts(active_exts: Iterable[BackpropExtension]) -> None:
        """Set the active backpack extensions.

        Args:
            active_exts: the extensions; the iterable is materialized into a
                tuple, so any number of extensions is accepted
        """
        CTX.active_exts = tuple(active_exts)

    @staticmethod
    def get_active_exts() -> Tuple[BackpropExtension, ...]:
        """Get the currently active extensions.

        Returns:
            active extensions (possibly empty)
        """
        return CTX.active_exts

    @staticmethod
    def add_hook_handle(hook_handle: RemovableHandle) -> None:
        """Add the hook handle to internal variable hook_handles.

        Args:
            hook_handle: the removable handle
        """
        CTX.hook_handles.append(hook_handle)

    @staticmethod
    def remove_hooks() -> None:
        """Remove all registered hooks and forget their handles."""
        for handle in CTX.hook_handles:
            handle.remove()
        # Rebind to a fresh list rather than mutating the old one.
        CTX.hook_handles = []

    @staticmethod
    def is_extension_active(*extension_classes: Type[BackpropExtension]) -> bool:
        """Return whether any of the specified extension classes is active.

        Args:
            *extension_classes: classes to test

        Returns:
            whether at least one active extension is an instance of one of the
            specified classes
        """
        return any(isinstance(ext, extension_classes) for ext in CTX.get_active_exts())

    @staticmethod
    def get_debug() -> bool:
        """Whether debug mode is active.

        Returns:
            whether debug mode is active
        """
        return CTX.debug

    @staticmethod
    def set_debug(debug: bool) -> None:
        """Set debug mode.

        Args:
            debug: the mode to set
        """
        CTX.debug = debug

    @staticmethod
    def get_extension_hook() -> Callable[[Module], None]:
        """Return the current extension hook to be run after all other extensions.

        Returns:
            current extension hook
        """
        return CTX.extension_hook

    @staticmethod
    def set_extension_hook(extension_hook: Callable[[Module], None]) -> None:
        """Set the current extension hook.

        Args:
            extension_hook: the extension hook to run after all other extensions
        """
        CTX.extension_hook = extension_hook

    @staticmethod
    def set_retain_graph(retain_graph: bool) -> None:
        """Set retain_graph.

        Args:
            retain_graph: new value for retain_graph
        """
        CTX.retain_graph = retain_graph

    @staticmethod
    def get_retain_graph() -> bool:
        """Get retain_graph.

        Returns:
            retain_graph
        """
        return CTX.retain_graph
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc