• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Project-OSmOSE / OSEkit / 19831033427

01 Dec 2025 05:08PM UTC coverage: 96.981% (+0.03%) from 96.955%
19831033427

Pull #307

github

web-flow
Merge 92d3f5588 into 27d210a34
Pull Request #307: [DRAFT] Relative paths

41 of 42 new or added lines in 6 files covered. (97.62%)

13 existing lines in 5 files now uncovered.

3951 of 4074 relevant lines covered (96.98%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.17
/src/osekit/core_api/base_dataset.py
1
"""BaseDataset: Base class for the Dataset objects.
2

3
Datasets are collections of Data, with methods
4
that simplify repeated operations on the data.
5
"""
6

7
from __future__ import annotations
1✔
8

9
import os
1✔
10
from bisect import bisect
1✔
11
from pathlib import Path
1✔
12
from typing import TYPE_CHECKING, Generic, Literal, TypeVar
1✔
13

14
from pandas import Timedelta, Timestamp, date_range
1✔
15
from soundfile import LibsndfileError
1✔
16
from tqdm import tqdm
1✔
17

18
from osekit.config import TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED
1✔
19
from osekit.config import global_logging_context as glc
1✔
20
from osekit.core_api.base_data import BaseData
1✔
21
from osekit.core_api.base_file import BaseFile
1✔
22
from osekit.core_api.event import Event
1✔
23
from osekit.core_api.json_serializer import deserialize_json, serialize_json
1✔
24
from osekit.utils.timestamp_utils import last_window_end
1✔
25

26
if TYPE_CHECKING:
27
    import pytz
28

29
TData = TypeVar("TData", bound=BaseData)
1✔
30
TFile = TypeVar("TFile", bound=BaseFile)
1✔
31

32

33
class BaseDataset(Generic[TData, TFile], Event):
1✔
34
    """Base class for Dataset objects.
35

36
    Datasets are collections of Data, with methods
37
    that simplify repeated operations on the data.
38
    """
39

40
    def __init__(
        self,
        data: list[TData],
        name: str | None = None,
        suffix: str = "",
        folder: Path | None = None,
    ) -> None:
        """Instantiate a Dataset object from the Data objects.

        Parameters
        ----------
        data: list[TData]
            The Data objects that make up the dataset.
        name: str | None
            Name of the dataset.
            If None, a default name derived from the dataset begin timestamp
            is used (see ``base_name``).
        suffix: str
            Suffix appended to the dataset name (see ``suffix``).
        folder: Path | None
            Folder in which the dataset files are located or to be written.

        """
        self.data = data
        self._suffix = suffix
        self._folder = folder
        self._name = name
        # Remember whether the name was user-provided or defaulted.
        self._has_default_name = name is None
53

54
    def __str__(self) -> str:
        """Return the dataset name (with suffix) as the string representation."""
        return self.name
57

58
    def __eq__(self, other: object) -> bool:
        """Return True if both datasets hold equal data, regardless of order.

        Data lists are compared after sorting by (begin, end) so that two
        datasets holding the same data in different orders compare equal.
        Comparing with a non-BaseDataset returns NotImplemented so Python can
        fall back to the other operand's __eq__.

        NOTE(review): defining __eq__ without __hash__ makes instances
        unhashable unless a parent class provides __hash__ — confirm intended.
        """
        if not isinstance(other, BaseDataset):
            return NotImplemented

        def _key(event: Event) -> tuple:
            return (event.begin, event.end)

        return sorted(self.data, key=_key) == sorted(other.data, key=_key)
64

65
    @property
    def base_name(self) -> str:
        """Name of the dataset without suffix.

        Defaults to the formatted begin timestamp when no name was given.
        """
        if self._name is not None:
            return self._name
        return self.begin.strftime(TIMESTAMP_FORMAT_EXPORTED_FILES_UNLOCALIZED)
73

74
    @base_name.setter
    def base_name(self, name: str) -> None:
        """Set the dataset name (without suffix).

        NOTE(review): this does not update ``_has_default_name``, so
        ``has_default_name`` keeps reporting the state set at construction
        time — confirm this is intended.
        """
        self._name = name
77

78
    @property
    def name(self) -> str:
        """Name of the dataset with suffix (suffix joined with an underscore)."""
        if self.suffix:
            return f"{self.base_name}_{self.suffix}"
        return self.base_name
82

83
    @name.setter
    def name(self, name: str | None) -> None:
        """Set the dataset base name (the suffix is managed separately)."""
        self._name = name
86

87
    @property
    def suffix(self) -> str:
        """Suffix that is applied to the name of the dataset.

        Used by the public API to suffix multiple core_api datasets that are
        created simultaneously and share the same name with their specific
        type, e.g. _audio or _spectro.
        """
        return self._suffix
96

97
    @suffix.setter
    def suffix(self, suffix: str | None) -> None:
        """Set the suffix appended to the dataset name."""
        self._suffix = suffix
100

101
    @property
    def has_default_name(self) -> bool:
        """Whether the dataset name was defaulted rather than user-provided at construction."""
        return self._has_default_name
105

106
    @property
    def begin(self) -> Timestamp:
        """Begin of the earliest data object."""
        return min(d.begin for d in self.data)
110

111
    @property
    def end(self) -> Timestamp:
        """End of the latest data object."""
        return max(d.end for d in self.data)
115

116
    @property
    def files(self) -> set[TFile]:
        """All files referred to by the Dataset (deduplicated)."""
        all_files: set[TFile] = set()
        for data in self.data:
            all_files.update(data.files)
        return all_files
120

121
    @property
    def folder(self) -> Path:
        """Folder in which the dataset files are located or to be written.

        Falls back to the parent folder of one dataset file when no folder was
        set explicitly (presumably all files share one parent — the first file
        found is used); None if there is neither a folder nor any file.
        """
        if self._folder is not None:
            return self._folder
        return next((file.path.parent for file in self.files), None)
129

130
    @folder.setter
    def folder(self, folder: Path) -> None:
        """Set the folder in which the dataset files might be written.

        Does not move any existing file; see ``move_files`` for that.

        Parameters
        ----------
        folder: Path
            The folder in which the dataset files might be written.

        """
        self._folder = folder
141

142
    def move_files(self, folder: Path) -> None:
        """Move the dataset files to the destination folder.

        Parameters
        ----------
        folder: Path
            Destination folder in which the dataset files will be moved.

        """
        # DISABLE_TQDM: any non-empty value disables the progress bar.
        disable_progress = os.environ.get("DISABLE_TQDM", "")
        for file in tqdm(self.files, disable=disable_progress):
            file.move(folder)
        self._folder = folder
154

155
    @property
    def data_duration(self) -> Timedelta:
        """Most frequent duration among the data of this dataset, rounded to the nearest second."""
        durations = [Timedelta(d.duration).round(freq="1s") for d in self.data]
        # Ties between equally frequent durations are broken arbitrarily.
        return max(set(durations), key=durations.count)
162

163
    def write(
        self,
        folder: Path,
        link: bool = False,
        first: int = 0,
        last: int | None = None,
    ) -> None:
        """Write all data objects in the specified folder.

        Parameters
        ----------
        folder: Path
            Folder in which to write the data.
        link: bool
            If True, the Data will be bound to the written file.
            Its items will be replaced with a single item, which will match
            the whole new File.
        first: int
            Index of the first data object to write.
        last: int | None
            Index after the last data object to write.
            Defaulted to the number of data objects.

        """
        if last is None:
            last = len(self.data)
        selection = self.data[first:last]
        for data in tqdm(selection, disable=os.environ.get("DISABLE_TQDM", "")):
            data.write(folder=folder, link=link)
192

193
    def set_files_as_relative(self) -> None:
        """Set the dataset folder as relative root of the files."""
        root = self.folder
        for file in self.files:
            file.relative_root = root
197

198
    def to_dict(self) -> dict:
        """Serialize a BaseDataset to a dictionary.

        Returns
        -------
        dict:
            The serialized dictionary representing the BaseDataset.

        """
        # Express the files relative to the dataset folder so that the
        # serialized form survives a folder move.
        self.set_files_as_relative()
        serialized_data = {str(data): data.to_dict() for data in self.data}
        return {
            "data": serialized_data,
            "name": self._name,
            "suffix": self.suffix,
            "folder": str(self.folder),
        }
214

215
    @classmethod
    def from_dict(cls, dictionary: dict, root_path: Path | None = None) -> BaseDataset:
        """Deserialize a BaseDataset from a dictionary.

        Parameters
        ----------
        dictionary: dict
            The serialized dictionary representing the BaseDataset.
        root_path: Path | None
            Path according to which the "files" values are expressed.
            If None, "files" values should be absolute.

        Returns
        -------
        BaseDataset
            The deserialized BaseDataset.

        """
        data = [
            BaseData.from_dict(dictionary=serialized, root_path=root_path)
            for serialized in dictionary["data"].values()
        ]
        return cls(
            data,
            name=dictionary["name"],
            suffix=dictionary["suffix"],
            folder=Path(dictionary["folder"]),
        )
242

243
    def write_json(self, folder: Path) -> None:
        """Write a serialized BaseDataset to a JSON file named after the dataset."""
        destination = folder / f"{self.name}.json"
        serialize_json(destination, self.to_dict())
246

247
    @classmethod
    def from_json(cls, file: Path) -> BaseDataset:
        """Deserialize a BaseDataset from a JSON file.

        Parameters
        ----------
        file: Path
            Path to the serialized JSON file representing the BaseDataset.

        Returns
        -------
        BaseDataset
            The deserialized BaseDataset.

        """
        # The file paths in the JSON are relative to the JSON file's folder.
        dictionary = deserialize_json(file)
        return cls.from_dict(dictionary=dictionary, root_path=file.parent)
263

264
    @classmethod
    def from_files(  # noqa: PLR0913
        cls,
        files: list[TFile],
        begin: Timestamp | None = None,
        end: Timestamp | None = None,
        mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total",
        data_duration: Timedelta | None = None,
        overlap: float = 0.0,
        name: str | None = None,
    ) -> BaseDataset:
        """Return a base BaseDataset object from a list of Files.

        Parameters
        ----------
        files: list[TFile]
            The list of files contained in the Dataset.
        begin: Timestamp | None
            Begin of the first data object.
            Defaulted to the begin of the first file.
        end: Timestamp | None
            End of the last data object.
            Defaulted to the end of the last file.
        mode: Literal["files", "timedelta_total", "timedelta_file"]
            Mode of creation of the dataset data from the original files.
            "files": one data will be created for each file.
            "timedelta_total": data objects of duration equal to data_duration will
            be created from the begin timestamp to the end timestamp.
            "timedelta_file": data objects of duration equal to data_duration will
            be created from the beginning of the first file that the begin timestamp is into, until it would resume
            in a data beginning between two files. Then, the next data object will be created from the
            beginning of the next original file and so on.
        data_duration: Timedelta | None
            Duration of the data objects.
            If mode is set to "files", this parameter has no effect.
            If provided, data will be evenly distributed between begin and end.
            Else, one data object will cover the whole time period.
        overlap: float
            Overlap percentage between consecutive data.
        name: str|None
            Name of the dataset.

        Returns
        -------
        BaseDataset[TItem, TFile]:
            The BaseDataset object.

        """
        if mode == "files":
            # One data per file; begin/end/data_duration have no effect here.
            data_base = [BaseData.from_files([f]) for f in files]
            data_base = BaseData.remove_overlaps(data_base)
            return cls(data=data_base, name=name)

        # Explicit None checks: "if not begin" relied on truthiness, which is
        # fragile for timestamp-like values.
        if begin is None:
            begin = min(file.begin for file in files)
        if end is None:
            end = max(file.end for file in files)
        # Truthiness intended here: a None or zero data_duration falls back to
        # a single data object covering the whole period.
        if data_duration:
            data_base = (
                cls._get_base_data_from_files_timedelta_total(
                    begin=begin,
                    end=end,
                    data_duration=data_duration,
                    files=files,
                    overlap=overlap,
                )
                if mode == "timedelta_total"
                else cls._get_base_data_from_files_timedelta_file(
                    begin=begin,
                    end=end,
                    data_duration=data_duration,
                    files=files,
                    overlap=overlap,
                )
            )
        else:
            data_base = [BaseData.from_files(files, begin=begin, end=end)]
        return cls(data_base, name=name)
342

343
    @classmethod
    def _get_base_data_from_files_timedelta_total(
        cls,
        begin: Timestamp,
        end: Timestamp,
        data_duration: Timedelta,
        files: list[TFile],
        overlap: float = 0,
    ) -> list[BaseData]:
        """Create fixed-duration BaseData objects spanning begin to end.

        Data windows start every ``data_duration * (1 - overlap)`` from begin
        (included) to end (excluded); each window is matched with the files it
        intersects.

        Raises
        ------
        ValueError
            If overlap is not in [0, 1).

        """
        if not 0 <= overlap < 1:
            message = f"Overlap ({overlap}) must be between 0 and 1."
            raise ValueError(message)

        sorted_files = sorted(files, key=lambda f: f.begin)
        window_hop = data_duration * (1 - overlap)
        # Index of the first file that may intersect the current window; files
        # are visited in order, so this only ever moves forward.
        first_file = 0
        output = []

        window_begins = date_range(begin, end, freq=window_hop, inclusive="left")
        for window_begin in tqdm(
            window_begins,
            disable=os.environ.get("DISABLE_TQDM", ""),
        ):
            window_end = Timestamp(window_begin + data_duration)
            # Skip files that end before the window starts.
            while (
                first_file < len(sorted_files)
                and sorted_files[first_file].end < window_begin
            ):
                first_file += 1
            # Collect the files that begin before the window ends.
            after_last_file = first_file
            while (
                after_last_file < len(sorted_files)
                and sorted_files[after_last_file].begin < window_end
            ):
                after_last_file += 1
            output.append(
                BaseData.from_files(
                    sorted_files[first_file:after_last_file],
                    window_begin,
                    window_end,
                ),
            )

        return output
386

387
    @classmethod
    def _get_base_data_from_files_timedelta_file(
        cls,
        begin: Timestamp,
        end: Timestamp,
        data_duration: Timedelta,
        files: list[TFile],
        overlap: float = 0,
    ) -> list[BaseData]:
        """Create fixed-duration BaseData objects aligned on file beginnings.

        Data windows of ``data_duration`` are created from the beginning of
        each chunk of reachable files, hopping by
        ``data_duration * (1 - overlap)``; when the windows can no longer
        reach the next file, a new chunk restarts on that file.

        Raises
        ------
        ValueError
            If overlap is not in [0, 1).

        """
        if not 0 <= overlap < 1:
            msg = f"Overlap ({overlap}) must be between 0 and 1."
            raise ValueError(msg)

        files = sorted(files, key=lambda file: file.begin)
        # Restrict the iteration to the files intersecting [begin, end].
        first = max(0, bisect(files, begin, key=lambda f: f.begin) - 1)
        last = bisect(files, end, key=lambda f: f.begin)

        data_hop = data_duration * (1 - overlap)

        output = []
        files_chunk = []
        for idx, file in tqdm(
            # start=first keeps idx aligned with the full files list: without
            # it, files[idx + 1:] below restarts from the head of the list
            # whenever first > 0, re-reading files that precede `file`.
            enumerate(files[first:last], start=first),
            disable=os.environ.get("DISABLE_TQDM", ""),
        ):
            if file in files_chunk:
                # Already consumed by the previous chunk.
                continue
            files_chunk = [file]

            # Extend the chunk with every following file that the data
            # windows can still reach.
            for next_file in files[idx + 1 :]:
                upper_data_limit = last_window_end(
                    begin=file.begin,
                    end=files_chunk[-1].end,
                    window_hop=data_hop,
                    window_duration=data_duration,
                )
                if upper_data_limit < next_file.begin:
                    break
                files_chunk.append(next_file)

            # NOTE(review): the full files list is passed here rather than
            # files_chunk; presumably BaseData.from_files narrows to the given
            # time range — confirm.
            output.extend(
                BaseData.from_files(files, data_begin, data_begin + data_duration)
                for data_begin in date_range(
                    file.begin,
                    files_chunk[-1].end,
                    freq=data_hop,
                    inclusive="left",
                )
            )

        return output
438

439
    @classmethod
    def from_folder(  # noqa: PLR0913
        cls,
        folder: Path,
        strptime_format: str,
        file_class: type[TFile] = BaseFile,
        supported_file_extensions: list[str] | None = None,
        begin: Timestamp | None = None,
        end: Timestamp | None = None,
        timezone: str | pytz.timezone | None = None,
        mode: Literal["files", "timedelta_total", "timedelta_file"] = "timedelta_total",
        overlap: float = 0.0,
        data_duration: Timedelta | None = None,
        name: str | None = None,
    ) -> BaseDataset:
        """Return a BaseDataset from a folder containing the base files.

        Parameters
        ----------
        folder: Path
            The folder containing the files.
        strptime_format: str
            The strptime format of the timestamps in the file names.
        file_class: type[Tfile]
            Derived type of BaseFile used to instantiate the dataset.
        supported_file_extensions: list[str]
            List of supported file extensions for parsing TFiles.
        begin: Timestamp | None
            The begin of the dataset.
            Defaulted to the begin of the first file.
        end: Timestamp | None
            The end of the dataset.
            Defaulted to the end of the last file.
        timezone: str | pytz.timezone | None
            The timezone in which the file should be localized.
            If None, the file begin/end will be tz-naive.
            If different from a timezone parsed from the filename, the timestamps'
            timezone will be converted from the parsed timezone
            to the specified timezone.
        mode: Literal["files", "timedelta_total", "timedelta_file"]
            Mode of creation of the dataset data from the original files.
            "files": one data will be created for each file.
            "timedelta_total": data objects of duration equal to data_duration will
            be created from the begin timestamp to the end timestamp.
            "timedelta_file": data objects of duration equal to data_duration will
            be created from the beginning of the first file that the begin timestamp is into, until it would resume
            in a data beginning between two files. Then, the next data object will be created from the
            beginning of the next original file and so on.
        overlap: float
            Overlap percentage between consecutive data.
        data_duration: Timedelta | None
            Duration of the data objects.
            If mode is set to "files", this parameter has no effect.
            If provided, data will be evenly distributed between begin and end.
            Else, one object will cover the whole time period.
        name: str|None
            Name of the dataset.

        Returns
        -------
        BaseDataset:
            The base dataset.

        Raises
        ------
        FileNotFoundError
            If no file in the folder could be parsed as a valid file.

        """
        if supported_file_extensions is None:
            # NOTE(review): an empty extension list rejects every file and
            # ends in FileNotFoundError — confirm this default is intended.
            supported_file_extensions = []
        valid_files = []
        rejected_files = []
        for file in tqdm(folder.iterdir(), disable=os.environ.get("DISABLE_TQDM", "")):
            if file.suffix.lower() not in supported_file_extensions:
                continue
            try:
                f = file_class(file, strptime_format=strptime_format, timezone=timezone)
                valid_files.append(f)
            except (ValueError, LibsndfileError):
                # Unparsable name or unreadable audio: collect for a single
                # warning below instead of failing the whole folder.
                rejected_files.append(file)

        if rejected_files:
            rejected_names = "\n\t".join(f.name for f in rejected_files)
            glc.logger.warning(
                f"The following files couldn't be parsed:\n\t{rejected_names}",
            )

        if not valid_files:
            raise FileNotFoundError(f"No valid file found in {folder}.")

        # Use cls rather than the hardcoded BaseDataset so that subclasses
        # calling from_folder get instances of their own type, consistently
        # with the other classmethod constructors.
        return cls.from_files(
            files=valid_files,
            begin=begin,
            end=end,
            mode=mode,
            overlap=overlap,
            data_duration=data_duration,
            name=name,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc