• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

f-dangel / backpack / 8116261751

01 Mar 2024 07:30PM UTC coverage: 98.375%. Remained the same
8116261751

Pull #323

github

web-flow
Merge 610195223 into e9b1dd361
Pull Request #323: [FIX | FMT] RTD build, apply latest `black` and `isort`

97 of 97 new or added lines in 97 files covered. (100.0%)

43 existing lines in 18 files now uncovered.

4420 of 4493 relevant lines covered (98.38%)

11.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.0
/backpack/utils/examples.py
1
"""Utility functions for examples."""
2
from typing import Iterator, List, Tuple
12✔
3

4
from torch import Tensor, stack, zeros
12✔
5
from torch.nn import Module
12✔
6
from torch.nn.utils.convert_parameters import parameters_to_vector
12✔
7
from torch.utils.data import DataLoader, Dataset
12✔
8
from torchvision.datasets import MNIST
12✔
9
from torchvision.transforms import Compose, Normalize, ToTensor
12✔
10

11
from backpack.hessianfree.ggnvp import ggn_vector_product
12✔
12
from backpack.utils.convert_parameters import vector_to_parameter_list
12✔
13

14

15
def load_mnist_dataset() -> Dataset:
    """Create the normalized MNIST training set.

    The data is stored under ``./data`` and requested with ``download=True``.
    Every image is converted to a tensor and then normalized with mean
    ``0.1307`` and standard deviation ``0.3081``.

    Returns:
        Normalized MNIST dataset
    """
    # single-channel normalization statistics (one entry per channel)
    mean, std = (0.1307,), (0.3081,)
    preprocessing = Compose([ToTensor(), Normalize(mean, std)])

    return MNIST(root="./data", train=True, transform=preprocessing, download=True)
27

28

29
def get_mnist_dataloader(batch_size: int = 64, shuffle: bool = True) -> DataLoader:
    """Wrap the normalized MNIST training set in a dataloader.

    Args:
        batch_size: Mini-batch size. Default: ``64``.
        shuffle: Randomly shuffle the data. Default: ``True``.

    Returns:
        MNIST dataloader
    """
    dataset = load_mnist_dataset()
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
×
40

41

42
def load_one_batch_mnist(
    batch_size: int = 64, shuffle: bool = True, flat: bool = False
) -> Tuple[Tensor, Tensor]:
    """Fetch the first mini-batch (inputs, labels) of an MNIST dataloader.

    Args:
        batch_size: Mini-batch size. Default: ``64``.
        shuffle: Randomly shuffle the data. Default: ``True``.
        flat: Flatten all non-batch dimensions of the inputs, returning a matrix
            ``[batch_size x 784]``. Default: ``False``.

    Returns:
        A single batch (inputs, labels) from MNIST.
    """
    batches = iter(get_mnist_dataloader(batch_size, shuffle))
    inputs, labels = next(batches)

    if not flat:
        return inputs, labels

    # keep the leading (batch) dimension, collapse everything else
    num_samples = inputs.shape[0]
    return inputs.reshape(num_samples, -1), labels
×
62

63

64
def autograd_diag_ggn_exact(
    X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None
) -> Tensor:
    """Compute diagonal entries of the generalized Gauss-Newton via ``torch.autograd``.

    The GGN columns are generated one at a time, and from each column only the
    element at the column's own index is kept.

    Args:
        X: Input to the model.
        y: Labels.
        model: The neural network.
        loss_function: Loss function module.
        idx: Indices for which the diagonal entries are computed. Default value ``None``
            computes the full diagonal.

    Returns:
        Exact GGN diagonal (flattened and concatenated).
    """
    columns = _autograd_ggn_exact_columns(X, y, model, loss_function, idx=idx)

    diagonal = []
    for position, column in columns:
        # the diagonal entry of column ``position`` sits at row ``position``
        diagonal.append(column[position])

    return stack(diagonal)
12✔
88

89

90
def _autograd_ggn_exact_columns(
    X: Tensor, y: Tensor, model: Module, loss_function: Module, idx: List[int] = None
) -> Iterator[Tuple[int, Tensor]]:
    """Generate exact generalized Gauss-Newton columns with ``torch.autograd``.

    Each column is obtained by multiplying the GGN onto a one-hot vector whose
    length is the total number of elements of the model's trainable parameters.

    Args:
        X: Input to the model.
        y: Labels.
        model: The neural network.
        loss_function: Loss function module.
        idx: Indices of columns that are computed. Default value ``None`` computes all
            columns.

    Yields:
        Tuple of column index and respective GGN column (flattened and concatenated).
    """
    params = [param for param in model.parameters() if param.requires_grad]
    num_params = sum(param.numel() for param in params)

    # a single forward pass is shared by all GGN-vector products below
    outputs = model(X)
    loss = loss_function(outputs, y)

    column_indices = list(range(num_params)) if idx is None else idx

    for column_index in column_indices:
        # one-hot vector selecting the requested column
        one_hot = zeros(num_params, device=loss.device, dtype=loss.dtype)
        one_hot[column_index] = 1.0
        one_hot_list = vector_to_parameter_list(one_hot, params)

        column_blocks = ggn_vector_product(loss, outputs, model, one_hot_list)
        # make each block contiguous before flattening into a single vector
        column = parameters_to_vector([block.contiguous() for block in column_blocks])

        yield column_index, column
12✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc