ContinualAI / avalanche / build 9268694310

28 May 2024 11:45AM UTC coverage: 51.799% (-0.004%) from 51.803%

Pull Request #1647: Improved Efficiency of the DiskUsage Metric
Merge 2995a272c into 8f0e61f23 (via GitHub web-flow)

18 of 30 new or added lines in 2 files covered. (60.0%)

2 existing lines in 2 files now uncovered.

15091 of 29134 relevant lines covered (51.8%)

0.52 hits per line

Source File: /avalanche/benchmarks/datasets/torchaudio_wrapper.py (file coverage: 0.0%; none of its executable lines were hit in this build)
################################################################################
# Copyright (c) 2022 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Author(s): Andrea Cossu                                                      #
# E-mail: contact@continualai.org                                              #
# Website: www.continualai.org                                                 #
################################################################################

"""This module conveniently wraps TorchAudio datasets to expose a clean and
comprehensive Avalanche API."""
import os

try:
    import torchaudio
    from torchaudio.datasets import SPEECHCOMMANDS
except ImportError:
    import warnings

    warnings.warn(
        "TorchAudio package is required to load SpeechCommands. "
        "You can install it as extra dependency with "
        "`pip install avalanche-lib[extra]`"
    )
    SPEECHCOMMANDS = object
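
# Note: binding SPEECHCOMMANDS to `object` (above) keeps this module
# importable without torchaudio; in that case SpeechCommandsData degrades to
# a plain-object subclass and fails at super().__init__() when instantiated.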

from avalanche.benchmarks.utils import _make_taskaware_classification_dataset
from avalanche.benchmarks.datasets import default_dataset_location
import torch


def speech_commands_collate(batch):
    tensors, targets, t_labels = [], [], []
    for waveform, label, rate, sid, uid, t_label in batch:
        tensors += [waveform]
        targets += [torch.tensor(label)]
        t_labels += [torch.tensor(t_label)]
    tensors = [item.t() for item in tensors]
    tensors_padded = torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=0.0
    )

    if len(tensors_padded.size()) == 2:  # no MFCC, add feature dimension
        tensors_padded = tensors_padded.unsqueeze(-1)
    targets = torch.stack(targets)
    t_labels = torch.stack(t_labels)
    return [tensors_padded, targets, t_labels]
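
# Illustrative batch format produced by speech_commands_collate once the
# dataset has been wrapped by _make_taskaware_classification_dataset below
# (assumed example: three raw 16 kHz clips of 12000, 16000 and 8000 samples,
# no MFCC preprocessing):
#   tensors_padded: (3, 16000, 1)  waveforms zero-padded along time
#   targets:        (3,)           integer class indices
#   t_labels:       (3,)           task labels added by the Avalanche wrapper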


class SpeechCommandsData(SPEECHCOMMANDS):
    def __init__(self, root, url, download, subset, mfcc_preprocessing):
        os.makedirs(root, exist_ok=True)
        super().__init__(root=root, download=download, subset=subset, url=url)
        self.labels_names = [
            "backward",
            "bed",
            "bird",
            "cat",
            "dog",
            "down",
            "eight",
            "five",
            "follow",
            "forward",
            "four",
            "go",
            "happy",
            "house",
            "learn",
            "left",
            "marvin",
            "nine",
            "no",
            "off",
            "on",
            "one",
            "right",
            "seven",
            "sheila",
            "six",
            "stop",
            "three",
            "tree",
            "two",
            "up",
            "visual",
            "wow",
            "yes",
            "zero",
        ]
        self.mfcc_preprocessing = mfcc_preprocessing

    def __getitem__(self, item):
        wave, rate, label, speaker_id, ut_number = super().__getitem__(item)
        label = self.labels_names.index(label)
        wave = wave.squeeze(0)  # (T,)
        if self.mfcc_preprocessing is not None:
            assert rate == self.mfcc_preprocessing.sample_rate
            # (T, MFCC)
            wave = self.mfcc_preprocessing(wave).permute(1, 0)
        return wave, label, rate, speaker_id, ut_number
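
# Illustrative __getitem__ shapes (assuming a 1-second clip at 16 kHz):
# without preprocessing, wave is the raw waveform of shape (16000,); with an
# MFCC transform it becomes (time_frames, n_mfcc) via the permute above.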
103

104

105
def SpeechCommands(
×
106
    root=default_dataset_location("speech_commands"),
107
    url="speech_commands_v0.02",
108
    download=True,
109
    subset=None,
110
    mfcc_preprocessing=None,
111
):
112
    """
113
    root: dataset root location
114
    url: version name of the dataset
115
    download: automatically download the dataset, if not present
116
    subset: one of 'training', 'validation', 'testing'
117
    mfcc_preprocessing: an optional torchaudio.transforms.MFCC instance
118
        to preprocess each audio. Warning: this may slow down the execution
119
        since preprocessing is applied on-the-fly each time a sample is
120
        retrieved from the dataset.
121
    """
122
    dataset = SpeechCommandsData(
×
123
        root=root,
124
        download=download,
125
        subset=subset,
126
        url=url,
127
        mfcc_preprocessing=mfcc_preprocessing,
128
    )
129
    labels = [datapoint[1] for datapoint in dataset]
×
130
    return _make_taskaware_classification_dataset(
×
131
        dataset, collate_fn=speech_commands_collate, targets=labels
132
    )
133

134

135
__all__ = ["SpeechCommands"]
×
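
For orientation, here is a minimal usage sketch of the wrapper above. It is not part of the covered file: the script name, batch size, and the explicit collate_fn argument are illustrative assumptions, and it presumes torchaudio is installed (e.g. via `pip install avalanche-lib[extra]`) and that the default dataset location is writable.

# usage_sketch.py (hypothetical, not part of torchaudio_wrapper.py)
from torch.utils.data import DataLoader

from avalanche.benchmarks.datasets.torchaudio_wrapper import (
    SpeechCommands,
    speech_commands_collate,
)

# Build the task-aware training split; the first call downloads
# Speech Commands v0.02 to the default dataset location.
train_set = SpeechCommands(subset="training", download=True)

# The collate function pads variable-length clips within each mini-batch.
loader = DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,
    collate_fn=speech_commands_collate,
)

waveforms, targets, task_labels = next(iter(loader))
print(waveforms.shape)    # (8, T_max, 1) padded raw waveforms
print(targets.shape)      # (8,) class indices into labels_names
print(task_labels.shape)  # (8,) task labels from the Avalanche wrapper

Passing an mfcc_preprocessing transform instead would change the per-item shape to (time_frames, n_mfcc) before batching, at the cost of the on-the-fly preprocessing overhead noted in the docstring.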