# src/unitxt/loaders.py
"""This section describes unitxt loaders.

Loaders: Generators of Unitxt Multistreams from existing data sources
=====================================================================

Unitxt is all about readily preparing any given data source for feeding into any given language model, and then
post-processing the model's output, preparing it for any given evaluator.

Throughout that journey, the data advances in the form of a Unitxt Multistream, undergoing a sequential application
of various off-the-shelf operators (i.e., picked from the Unitxt catalog), or operators easily implemented by inheriting.
The journey starts with a Unitxt Loader bearing a Multistream from the given data source.
A loader, therefore, is the first item in any Unitxt Recipe.

The Unitxt catalog contains several loaders for the most popular data source formats.
All these loaders inherit from Loader, and hence, implementing a loader to support a new type of data source is
straightforward.

Available Loaders Overview:
    - :class:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
    - :class:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
    - :class:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
    - :class:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
    - :class:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
    - :class:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
    - :class:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
    - :class:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.

------------------------
"""

import fnmatch
import itertools
import json
import os
import tempfile
import time
from abc import abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import (
    Any,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
)

import pandas as pd
import requests
from datasets import (
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    get_dataset_split_names,
)
from datasets import load_dataset as _hf_load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm

from .dataclass import NonPositionalField
from .error_utils import Documentation, UnitxtError, UnitxtWarning
from .fusion import FixedFusion
from .logging_utils import get_logger
from .operator import SourceOperator
from .operators import Set
from .settings_utils import get_settings
from .stream import DynamicStream, MultiStream
from .type_utils import isoftype
from .utils import LRUCache, recursive_copy, retry_connection_with_exponential_backoff

logger = get_logger()
settings = get_settings()


class UnitxtUnverifiedCodeError(UnitxtError):
    def __init__(self, path):
        super().__init__(
            f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or the environment variable UNITXT_ALLOW_UNVERIFIED_CODE.",
            Documentation.SETTINGS,
        )


@retry_connection_with_exponential_backoff(backoff_factor=2)
def hf_load_dataset(path: str, *args, **kwargs):
    if settings.hf_offline_datasets_path is not None:
        path = os.path.join(settings.hf_offline_datasets_path, path)
    try:
        return _hf_load_dataset(
            path,
            *args,
            **kwargs,
            verification_mode="no_checks",
            trust_remote_code=settings.allow_unverified_code,
            download_mode="force_redownload"
            if settings.disable_hf_datasets_cache
            else "reuse_dataset_if_exists",
        )
    except ValueError as e:
        if "trust_remote_code" in str(e):
            raise UnitxtUnverifiedCodeError(path) from e
        raise e  # re-raise


@retry_connection_with_exponential_backoff(backoff_factor=2)
def hf_get_dataset_splits(path: str, name: str):
    try:
        return get_dataset_split_names(
            path=path,
            config_name=name,
            trust_remote_code=settings.allow_unverified_code,
        )
    except Exception as e:
        if "trust_remote_code" in str(e):
            raise UnitxtUnverifiedCodeError(path) from e

        if "Couldn't find cache" in str(e):
            raise FileNotFoundError(
                f"Dataset cache path={path}, name={name} was not found."
            ) from e
        raise e  # re-raise


class Loader(SourceOperator):
    """A base class for all loaders.

    A loader is the first component in the Unitxt Recipe,
    responsible for loading data from various sources and preparing it as a MultiStream for processing.
    The loader_limit is an optional parameter used to control the maximum number of instances to load from the data source. It is applied to each split separately.
    It is usually provided to the loader via the recipe (see standard.py).
    The loader can use this value to limit the amount of data downloaded from the source
    to reduce loading time. However, this may not always be possible, so the
    loader may ignore this. In any case, the recipe will limit the number of instances in the returned
    stream after loading is complete.

    Args:
        loader_limit: Optional integer to specify a limit on the number of records to load.
        streaming: Bool indicating if streaming should be used.
        num_proc: Optional integer to specify the number of processes to use for parallel dataset loading. Adjust the value according to the number of CPU cores available and the specific needs of your processing task.
    """

    loader_limit: int = None
    streaming: bool = False
    num_proc: int = None

    # class-level shared cache:
    _loader_cache = LRUCache(max_size=settings.loader_cache_size)

    def get_limit(self) -> int:
        if settings.global_loader_limit is not None and self.loader_limit is not None:
            return min(int(settings.global_loader_limit), self.loader_limit)
        if settings.global_loader_limit is not None:
            return int(settings.global_loader_limit)
        return self.loader_limit

    def get_limiter(self):
        if settings.global_loader_limit is not None and self.loader_limit is not None:
            if int(settings.global_loader_limit) > self.loader_limit:
                return f"{self.__class__.__name__}.loader_limit"
            return "unitxt.settings.global_loader_limit"
        if settings.global_loader_limit is not None:
            return "unitxt.settings.global_loader_limit"
        return f"{self.__class__.__name__}.loader_limit"

    def log_limited_loading(self):
        if not hasattr(self, "_already_logged_limited_loading") or not self._already_logged_limited_loading:
            self._already_logged_limited_loading = True
            logger.info(
                f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
            )

    def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
        if self.data_classification_policy is None:
            get_logger().warning(
                f"The {self.get_pretty_print_name()} loader does not set the `data_classification_policy`. "
                f"This may lead to sending undesired data to external services.\n"
                f"Set it to a list of classification identifiers.\n"
                f"For example:\n"
                f"data_classification_policy = ['public']\n"
                f"or\n"
                f"data_classification_policy = ['confidential', 'pii']\n"
            )

        operator = Set(
            fields={"data_classification_policy": self.data_classification_policy}
        )
        return operator(multi_stream)

    def set_default_data_classification(
        self, default_data_classification_policy, additional_info
    ):
        if self.data_classification_policy is None:
            if additional_info is not None:
                logger.info(
                    f"{self.get_pretty_print_name()} sets 'data_classification_policy' to "
                    f"{default_data_classification_policy} by default {additional_info}.\n"
                    "To use a different value or remove this message, explicitly set the "
                    "`data_classification_policy` attribute of the loader.\n"
                )
            self.data_classification_policy = default_data_classification_policy

    @abstractmethod
    def load_iterables(self) -> Dict[str, Iterable]:
        pass

    def _maybe_set_classification_policy(self):
        pass

    def load_data(self) -> MultiStream:
        try:
            iterables = self.load_iterables()
        except Exception as e:
            raise UnitxtError(f"Error in loader:\n{self}") from e
        if isoftype(iterables, MultiStream):
            return iterables
        return MultiStream.from_iterables(iterables, copying=True)

    def process(self) -> MultiStream:
        self._maybe_set_classification_policy()
        return self.add_data_classification(self.load_data())

    def get_splits(self):
        return list(self().keys())


class LazyLoader(Loader):
    split: Optional[str] = NonPositionalField(default=None)

    @abstractmethod
    def get_splits(self) -> List[str]:
        pass

    @abstractmethod
    def split_generator(self, split: str) -> Generator:
        pass

    def load_iterables(self) -> Union[Dict[str, DynamicStream], IterableDatasetDict]:
        if self.split is not None:
            splits = [self.split]
        else:
            splits = self.get_splits()

        return MultiStream({
            split: DynamicStream(self.split_generator, gen_kwargs={"split": split})
            for split in splits
        })


class LoadHF(LazyLoader):
    """Loads datasets from the HuggingFace Hub.

    It supports loading with or without streaming,
    and it can filter datasets upon loading.

    Args:
        path:
            The path or identifier of the dataset on the HuggingFace Hub.
        name:
            An optional dataset name.
        data_dir:
            Optional directory to store downloaded data.
        split:
            Optional specification of which split to load.
        data_files:
            Optional specification of particular data files to load. When you provide a list of data_files to Hugging Face's load_dataset function without explicitly specifying the split argument, these files are automatically placed into the train split.
        revision:
            Optional. The revision of the dataset, often a commit id. Use it to pin the dataset version.
        streaming (bool):
            Indicating if streaming should be used.
        filtering_lambda (str, optional):
            A lambda function for filtering the data after loading.
        num_proc (int, optional):
            Specifies the number of processes to use for parallel dataset loading.

    Example:
        Loading glue's mrpc dataset

        .. code-block:: python

            load_hf = LoadHF(path='glue', name='mrpc')
    """

    path: str
    name: Optional[str] = None
    data_dir: Optional[str] = None
    split: Optional[str] = None
    data_files: Optional[
        Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
    ] = None
    revision: Optional[str] = None
    streaming: bool = None
    filtering_lambda: Optional[str] = None
    num_proc: Optional[int] = None
    splits: Optional[List[str]] = None

    def filter_load(self, dataset: DatasetDict):
        if not settings.allow_unverified_code:
            raise ValueError(
                f"{self.__class__.__name__} cannot use the filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or the environment variable UNITXT_ALLOW_UNVERIFIED_CODE=True."
            )
        logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
        return dataset.filter(eval(self.filtering_lambda))

    def is_streaming(self) -> bool:
        if self.streaming is None:
            return settings.stream_hf_datasets_by_default
        return self.streaming

    def is_in_cache(self, split):
        dataset_id = str(self) + "_" + str(split)
        return dataset_id in self.__class__._loader_cache

    # Returns a Dict when split names are not known in advance, and just the single split dataset if the split is known.
    def load_dataset(
        self, split: str, streaming=None, disable_memory_caching=False
    ) -> Union[IterableDatasetDict, IterableDataset]:
        dataset_id = str(self) + "_" + str(split)
        dataset = self.__class__._loader_cache.get(dataset_id, None)
        if dataset is None:
            if streaming is None:
                streaming = self.is_streaming()

            dataset = hf_load_dataset(
                self.path,
                name=self.name,
                data_dir=self.data_dir,
                data_files=self.data_files,
                revision=self.revision,
                streaming=streaming,
                split=split,
                num_proc=self.num_proc,
            )

            if dataset is None:
                raise NotImplementedError() from None

            if not disable_memory_caching:
                self.__class__._loader_cache.max_size = settings.loader_cache_size
                self.__class__._loader_cache[dataset_id] = dataset
        self._already_logged_limited_loading = True

        return dataset

    def _maybe_set_classification_policy(self):
        if os.path.exists(self.path):
            self.set_default_data_classification(
                ["proprietary"], "when loading from local files"
            )
        else:
            self.set_default_data_classification(
                ["public"],
                None,  # No warning when loading from public hub
            )

    @retry_connection_with_exponential_backoff(max_retries=3, backoff_factor=2)
    def get_splits(self):
        if self.splits is not None:
            return self.splits
        if self.data_files is not None:
            if isinstance(self.data_files, dict):
                return list(self.data_files.keys())
            return ["train"]
        try:
            return hf_get_dataset_splits(
                path=self.path,
                name=self.name,
            )
        except Exception:
            UnitxtWarning(
                f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
            )
            try:
                dataset = self.load_dataset(
                    split=None, disable_memory_caching=True, streaming=True
                )
            except (
                NotImplementedError
            ):  # streaming is not supported for zipped files so we load without streaming
                dataset = self.load_dataset(split=None, streaming=False)

            if dataset is None:
                raise FileNotFoundError(f"Dataset path={self.path}, name={self.name} was not found.") from None

            return list(dataset.keys())

    def split_generator(self, split: str) -> Generator:
        if self.get_limit() is not None:
            if not self.is_in_cache(split):
                self.log_limited_loading()
        try:
            dataset = self.load_dataset(split=split)
        except (
            NotImplementedError
        ):  # streaming is not supported for zipped files so we load without streaming
            dataset = self.load_dataset(split=split, streaming=False)

        if self.filtering_lambda is not None:
            dataset = self.filter_load(dataset)

        limit = self.get_limit()
        if limit is None:
            yield from dataset
        else:
            for i, instance in enumerate(dataset):
                yield instance
                if i + 1 >= limit:
                    break


class LoadCSV(LazyLoader):
    """Loads data from CSV files.

    Supports streaming and can handle large files by loading them in chunks.

    Args:
        files (Dict[str, str]): A dictionary mapping names to file paths.
        chunksize: Size of the chunks to load at a time.
        loader_limit: Optional integer to specify a limit on the number of records to load.
        streaming: Bool indicating if streaming should be used.
        sep: String specifying the separator used in the CSV files.

    Example:
        Loading csv

        .. code-block:: python

            load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
    """

    files: Dict[str, str]
    chunksize: int = 1000
    loader_limit: Optional[int] = None
    streaming: bool = True
    sep: str = ","
    compression: Optional[str] = None
    lines: Optional[bool] = None
    file_type: Literal["csv", "json"] = "csv"

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
            ["proprietary"], "when loading from local files"
        )

    def get_reader(self):
        if self.file_type == "csv":
            return pd.read_csv
        if self.file_type == "json":
            return pd.read_json
        raise ValueError(f"Unsupported file_type: {self.file_type}")

    def get_args(self):
        args = {}
        if self.file_type == "csv":
            args["sep"] = self.sep
            args["low_memory"] = self.streaming
        if self.compression is not None:
            args["compression"] = self.compression
        if self.lines is not None:
            args["lines"] = self.lines
        if self.get_limit() is not None:
            args["nrows"] = self.get_limit()
        return args

    def get_splits(self) -> List[str]:
        return list(self.files.keys())

    def split_generator(self, split: str) -> Generator:
        dataset_id = str(self) + "_" + split
        dataset = self.__class__._loader_cache.get(dataset_id, None)
        if dataset is None:
            if self.get_limit() is not None:
                self.log_limited_loading()
            for attempt in range(settings.loaders_max_retries):
                try:
                    reader = self.get_reader()
                    if self.get_limit() is not None:
                        self.log_limited_loading()

                    try:
                        dataset = reader(self.files[split], **self.get_args()).to_dict(
                            "records"
                        )
                        break
                    except ValueError:
                        import fsspec

                        with fsspec.open(self.files[split], mode="rt") as f:
                            dataset = reader(f, **self.get_args()).to_dict("records")
                        break
                except Exception as e:
                    logger.debug(f"Attempt csv load {attempt + 1} failed: {e}")
                    if attempt < settings.loaders_max_retries - 1:
                        time.sleep(2)
                    else:
                        raise e
            self.__class__._loader_cache.max_size = settings.loader_cache_size
            self.__class__._loader_cache[dataset_id] = dataset

        for instance in self.__class__._loader_cache[dataset_id]:
            yield recursive_copy(instance)


class LoadFromSklearn(LazyLoader):
    """Loads datasets from the sklearn library.

    This loader does not support streaming and is intended for use with sklearn's dataset fetch functions.

    Args:
        dataset_name: The name of the sklearn dataset to fetch.
        splits: A list of data splits to load, e.g., ['train', 'test'].

    Example:
        Loading from sklearn

        .. code-block:: python

            load_sklearn = LoadFromSklearn(dataset_name='iris', splits=['train', 'test'])
    """

    dataset_name: str
    splits: List[str] = ["train", "test"]

    _requirements_list: List[str] = ["scikit-learn", "pandas"]

    data_classification_policy = ["public"]

    def verify(self):
        super().verify()

        if self.streaming:
            raise NotImplementedError("LoadFromSklearn cannot load with streaming.")

    def prepare(self):
        super().prepare()
        from sklearn import datasets as sklearn_datasets

        self.downloader = getattr(sklearn_datasets, f"fetch_{self.dataset_name}")

    def get_splits(self):
        return self.splits

    def split_generator(self, split: str) -> Generator:
        dataset_id = str(self) + "_" + split
        dataset = self.__class__._loader_cache.get(dataset_id, None)
        if dataset is None:
            split_data = self.downloader(subset=split)
            targets = [split_data["target_names"][t] for t in split_data["target"]]
            df = pd.DataFrame([split_data["data"], targets]).T
            df.columns = ["data", "target"]
            dataset = df.to_dict("records")
            self.__class__._loader_cache.max_size = settings.loader_cache_size
            self.__class__._loader_cache[dataset_id] = dataset
        for instance in self.__class__._loader_cache[dataset_id]:
            yield recursive_copy(instance)


class MissingKaggleCredentialsError(ValueError):
    pass


class LoadFromKaggle(Loader):
    """Loads datasets from Kaggle.

    Requires Kaggle API credentials and does not support streaming.

    Args:
        url: URL to the Kaggle dataset.

    Example:
        Loading from kaggle

        .. code-block:: python

            load_kaggle = LoadFromKaggle(url='kaggle.com/dataset/example')
    """

    url: str

    _requirements_list: List[str] = ["opendatasets"]
    data_classification_policy = ["public"]

    def verify(self):
        super().verify()
        if not os.path.isfile("kaggle.json"):
            raise MissingKaggleCredentialsError(
                "Please obtain Kaggle credentials (see https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/) and save them to a local ./kaggle.json file"
            )

        if self.streaming:
            raise NotImplementedError("LoadFromKaggle cannot load with streaming.")

    def prepare(self):
        super().prepare()
        from opendatasets import download

        self.downloader = download

    def load_iterables(self):
        with TemporaryDirectory() as temp_directory:
            self.downloader(self.url, temp_directory)
            return hf_load_dataset(temp_directory, streaming=False)


class LoadFromIBMCloud(Loader):
1✔
600
    """Loads data from IBM Cloud Object Storage.
601

602
    Does not support streaming and requires AWS-style access keys.
603
    data_dir Can be either:
604
    1. a list of file names, the split of each file is determined by the file name pattern
605
    2. Mapping: split -> file_name, e.g. {"test" : "test.json", "train": "train.json"}
606
    3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
607

608
    Args:
609
        endpoint_url_env:
610
            Environment variable name for the IBM Cloud endpoint URL.
611
        aws_access_key_id_env:
612
            Environment variable name for the AWS access key ID.
613
        aws_secret_access_key_env:
614
            Environment variable name for the AWS secret access key.
615
        bucket_name:
616
            Name of the S3 bucket from which to load data.
617
        data_dir:
618
            Optional directory path within the bucket.
619
        data_files:
620
            Union type allowing either a list of file names or a mapping of splits to file names.
621
        data_field:
622
            The dataset key for nested JSON file, i.e. when multiple datasets are nested in the same file
623
        caching (bool):
624
            indicating if caching is enabled to avoid re-downloading data.
625

626
    Example:
627
        Loading from IBM Cloud
628

629
        .. code-block:: python
630

631
            load_ibm_cloud = LoadFromIBMCloud(
632
                endpoint_url_env='IBM_CLOUD_ENDPOINT',
633
                aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
634
                aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY', # pragma: allowlist secret
635
                bucket_name='my-bucket'
636
            )
637
            multi_stream = load_ibm_cloud.process()
638
    """
639

640
    endpoint_url_env: str
1✔
641
    aws_access_key_id_env: str
1✔
642
    aws_secret_access_key_env: str
1✔
643
    bucket_name: str
1✔
644
    data_dir: str = None
1✔
645

646
    data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
1✔
647
    data_field: str = None
1✔
648
    caching: bool = True
1✔
649
    data_classification_policy = ["proprietary"]
1✔
650

651
    _requirements_list: List[str] = ["ibm-cos-sdk"]
1✔
652

653
    def _download_from_cos(self, cos, bucket_name, item_name, local_file):
1✔
654
        logger.info(f"Downloading {item_name} from {bucket_name} COS")
1✔
655
        try:
1✔
656
            response = cos.Object(bucket_name, item_name).get()
1✔
657
            size = response["ContentLength"]
1✔
658
            body = response["Body"]
1✔
659
        except Exception as e:
×
660
            raise Exception(
×
661
                f"Unabled to access {item_name} in {bucket_name} in COS", e
662
            ) from e
663

664
        if self.get_limit() is not None:
1✔
665
            if item_name.endswith(".jsonl"):
1✔
666
                first_lines = list(
1✔
667
                    itertools.islice(body.iter_lines(), self.get_limit())
668
                )
669
                with open(local_file, "wb") as downloaded_file:
1✔
670
                    for line in first_lines:
1✔
671
                        downloaded_file.write(line)
1✔
672
                        downloaded_file.write(b"\n")
1✔
673
                logger.info(
1✔
674
                    f"\nDownload successful limited to {self.get_limit()} lines"
675
                )
676
                return
1✔
677

678
        progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
×
679

680
        def upload_progress(chunk):
×
681
            progress_bar.update(chunk)
×
682

683
        try:
×
684
            cos.Bucket(bucket_name).download_file(
×
685
                item_name, local_file, Callback=upload_progress
686
            )
687
            logger.info("\nDownload Successful")
×
688
        except Exception as e:
×
689
            raise Exception(
×
690
                f"Unabled to download {item_name} in {bucket_name}", e
691
            ) from e
692

693
    def prepare(self):
1✔
694
        super().prepare()
1✔
695
        self.endpoint_url = os.getenv(self.endpoint_url_env)
1✔
696
        self.aws_access_key_id = os.getenv(self.aws_access_key_id_env)
1✔
697
        self.aws_secret_access_key = os.getenv(self.aws_secret_access_key_env)
1✔
698
        root_dir = os.getenv("UNITXT_IBM_COS_CACHE", None) or os.getcwd()
1✔
699
        self.cache_dir = os.path.join(root_dir, "ibmcos_datasets")
1✔
700

701
        if not os.path.exists(self.cache_dir):
1✔
702
            Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
1✔
703
        self.verified = False
1✔
704

705
    def lazy_verify(self):
1✔
706
        super().verify()
1✔
707
        assert (
1✔
708
            self.endpoint_url is not None
709
        ), f"Please set the {self.endpoint_url_env} environmental variable"
710
        assert (
1✔
711
            self.aws_access_key_id is not None
712
        ), f"Please set {self.aws_access_key_id_env} environmental variable"
713
        assert (
1✔
714
            self.aws_secret_access_key is not None
715
        ), f"Please set {self.aws_secret_access_key_env} environmental variable"
716
        if self.streaming:
1✔
717
            raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
×
718

719
    def _maybe_set_classification_policy(self):
1✔
720
        self.set_default_data_classification(
1✔
721
            ["proprietary"], "when loading from IBM COS"
722
        )
723

724
    def load_iterables(self):
1✔
725
        if not self.verified:
1✔
726
            self.lazy_verify()
1✔
727
            self.verified = True
1✔
728
        import ibm_boto3
1✔
729

730
        cos = ibm_boto3.resource(
1✔
731
            "s3",
732
            aws_access_key_id=self.aws_access_key_id,
733
            aws_secret_access_key=self.aws_secret_access_key,
734
            endpoint_url=self.endpoint_url,
735
        )
736
        local_dir = os.path.join(
1✔
737
            self.cache_dir,
738
            self.bucket_name,
739
            self.data_dir or "",  # data_dir can be None
740
            f"loader_limit_{self.get_limit()}",
741
        )
742
        if not os.path.exists(local_dir):
1✔
743
            Path(local_dir).mkdir(parents=True, exist_ok=True)
1✔
744
        if isinstance(self.data_files, Mapping):
1✔
745
            data_files_names = list(self.data_files.values())
1✔
746
            if not isinstance(data_files_names[0], str):
1✔
747
                data_files_names = list(itertools.chain(*data_files_names))
1✔
748
        else:
749
            data_files_names = self.data_files
1✔
750

751
        for data_file in data_files_names:
1✔
752
            local_file = os.path.join(local_dir, data_file)
1✔
753
            if not self.caching or not os.path.exists(local_file):
1✔
754
                # Build object key based on parameters. Slash character is not
755
                # allowed to be part of object key in IBM COS.
756
                object_key = (
1✔
757
                    self.data_dir + "/" + data_file
758
                    if self.data_dir is not None
759
                    else data_file
760
                )
761
                with tempfile.NamedTemporaryFile() as temp_file:
1✔
762
                    # Download to  a temporary file in same file partition, and then do an atomic move
763
                    self._download_from_cos(
1✔
764
                        cos,
765
                        self.bucket_name,
766
                        object_key,
767
                        local_dir + "/" + os.path.basename(temp_file.name),
768
                    )
769
                    os.renames(
1✔
770
                        local_dir + "/" + os.path.basename(temp_file.name),
771
                        local_dir + "/" + data_file,
772
                    )
773

774
        if isinstance(self.data_files, list):
1✔
775
            dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
1✔
776
        else:
777
            dataset = hf_load_dataset(
1✔
778
                local_dir,
779
                streaming=False,
780
                data_files=self.data_files,
781
                field=self.data_field,
782
            )
783

784
        return dataset
1✔
785

786

class MultipleSourceLoader(LazyLoader):
    """Allows loading data from multiple sources, potentially mixing different types of loaders.

    Args:
        sources: A list of loaders that will be combined to form a unified dataset.

    Examples:
        1) Loading the train split from the HuggingFace Hub and the test set from a local file:

        .. code-block:: python

            MultipleSourceLoader(sources = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])

        2) Loading a test set combined from two files

        .. code-block:: python

            MultipleSourceLoader(sources = [ LoadCSV({"test": "mytest1.csv"}), LoadCSV({"test": "mytest2.csv"}) ])
    """

    sources: List[Loader]

    def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
        if self.data_classification_policy is None:
            return multi_stream
        return super().add_data_classification(multi_stream)

    def get_splits(self):
        splits = []
        for loader in self.sources:
            splits.extend(loader.get_splits())
        return list(set(splits))

    def split_generator(self, split: str) -> Generator[Any, None, None]:
        yield from FixedFusion(
            subsets=self.sources,
            max_instances_per_subset=self.get_limit(),
            include_splits=[split],
        )()[split]


class LoadFromDictionary(Loader):
    """Allows loading data from a dictionary of constants.

    The loader can be used, for example, when debugging or working with small datasets.

    Args:
        data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded

    Example:
        Loading dictionary

        .. code-block:: python

            data = {
                "train": [{"input": "SomeInput1", "output": "SomeResult1"},
                          {"input": "SomeInput2", "output": "SomeResult2"}],
                "test":  [{"input": "SomeInput3", "output": "SomeResult3"},
                          {"input": "SomeInput4", "output": "SomeResult4"}]
            }
            loader = LoadFromDictionary(data=data)
    """

    data: Dict[str, List[Dict[str, Any]]]

    def verify(self):
        super().verify()
        if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
            raise ValueError(
                f"The data passed to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
                f"Expected data should map a split name to a list of instances.\n"
                f"Received value: {self.data}\n"
            )
        for split in self.data.keys():
            if len(self.data[split]) == 0:
                raise ValueError(f"Split {split} has no instances.")
            first_instance = self.data[split][0]
            for instance in self.data[split]:
                if instance.keys() != first_instance.keys():
                    raise ValueError(
                        f"Not all instances in split '{split}' have the same fields.\n"
                        f"Instance {instance} has fields different from {first_instance}."
                    )

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
            ["proprietary"], "when loading from python dictionary"
        )

    def load_iterables(self) -> Dict[str, List[Dict[str, Any]]]:
        return self.data


class LoadFromHFSpace(LazyLoader):
    """Used to load data from HuggingFace Spaces lazily.

    Args:
        space_name (str):
            Name of the HuggingFace Space to be accessed.
        data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
            Relative paths to files within a given repository. If given as a mapping,
            paths should be values, while keys should represent the type of respective files
            (training, testing, etc.).
        path (str, optional):
            Absolute path to a directory where data should be downloaded.
        revision (str, optional):
            ID of a Git branch or commit to be used. By default, it is set to None,
            thus data is downloaded from the main branch of the accessed repository.
        use_token (bool, optional):
            Whether a token is used for authentication when accessing
            the HuggingFace Space. If necessary, the token is read from the HuggingFace
            config folder.
        token_env (str, optional):
            Name of an environment variable whose value will be used for
            authentication when accessing the HuggingFace Space, if necessary.
    """

    space_name: str
    data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
    path: Optional[str] = None
    revision: Optional[str] = None
    use_token: Optional[bool] = None
    token_env: Optional[str] = None
    requirements_list: List[str] = ["huggingface_hub"]

    streaming: bool = True

    def _get_token(self) -> Optional[Union[bool, str]]:
        if self.token_env:
            token = os.getenv(self.token_env)
            if not token:
                get_logger().warning(
                    f"The 'token_env' parameter was specified as '{self.token_env}', "
                    f"however, no environment variable under such a name was found. "
                    f"Therefore, the loader will not use any tokens for authentication."
                )
            return token
        return self.use_token

    @staticmethod
    def _is_wildcard(path: str) -> bool:
        wildcard_characters = ["*", "?", "[", "]"]
        return any(char in path for char in wildcard_characters)

    def _get_repo_files(self):
        if not hasattr(self, "_repo_files") or self._repo_files is None:
            api = HfApi()
            self._repo_files = api.list_repo_files(
                self.space_name, repo_type="space", revision=self.revision
            )
        return self._repo_files

    def _get_sub_files(self, file: str) -> List[str]:
        if self._is_wildcard(file):
            return fnmatch.filter(self._get_repo_files(), file)
        return [file]

    def get_splits(self) -> List[str]:
        if isinstance(self.data_files, Mapping):
            return list(self.data_files.keys())
        return ["train"]  # Default to 'train' if not specified

    def split_generator(self, split: str) -> Generator:
        from huggingface_hub import hf_hub_download
        from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

        token = self._get_token()
        files = self.data_files.get(split, self.data_files) if isinstance(self.data_files, Mapping) else self.data_files

        if isinstance(files, str):
            files = [files]
        limit = self.get_limit()

        if limit is not None:
            total = 0
            self.log_limited_loading()

        for file in files:
            for sub_file in self._get_sub_files(file):
                try:
                    file_path = hf_hub_download(
                        repo_id=self.space_name,
                        filename=sub_file,
                        repo_type="space",
                        token=token,
                        revision=self.revision,
                        local_dir=self.path,
                    )
                except EntryNotFoundError as e:
                    raise ValueError(
                        f"The file '{file}' was not found in the space '{self.space_name}'. "
                        f"Please check if the filename is correct, or if it exists in that "
                        f"Huggingface space."
                    ) from e
                except RepositoryNotFoundError as e:
                    raise ValueError(
                        f"The Huggingface space '{self.space_name}' was not found. "
                        f"Please check if the name is correct and you have access to the space."
                    ) from e

                with open(file_path, encoding="utf-8") as f:
                    for line in f:
                        yield json.loads(line.strip())
                        if limit is not None:
                            total += 1
                            if total >= limit:
                                return


class LoadFromAPI(Loader):
    """Loads data from an API.

    This loader is designed to fetch data from an API endpoint,
    handling authentication through an API key. It supports
    customizable chunk sizes and limits for data retrieval.

    Args:
        urls (Dict[str, str]):
            A dictionary mapping split names to their respective API URLs.
        chunksize (int, optional):
            The size of data chunks to fetch in each request. Defaults to 100,000.
        loader_limit (int, optional):
            Limits the number of records to load. Applied per split. Defaults to None.
        streaming (bool, optional):
            Determines if data should be streamed. Defaults to False.
        api_key_env_var (str, optional):
            The name of the environment variable holding the API key.
            Defaults to "SQL_API_KEY".
        headers (Dict[str, Any], optional):
            Additional headers to include in API requests. Defaults to None.
        data_field (str, optional):
            The name of the field in the API response that contains the data.
            Defaults to "data".
        method (str, optional):
            The HTTP method to use for API requests. Defaults to "GET".
        verify_cert (bool):
            Whether to verify the SSL certificate. Defaults to True.
    """

    urls: Dict[str, str]
    chunksize: int = 100000
    loader_limit: Optional[int] = None
    streaming: bool = False
    api_key_env_var: str = "SQL_API_KEY"
    headers: Optional[Dict[str, Any]] = None
    data_field: str = "data"
    method: str = "GET"
    verify_cert: bool = True

    # class-level shared cache:
    _loader_cache = LRUCache(max_size=settings.loader_cache_size)

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(["proprietary"], "when loading from API")

    def load_iterables(self) -> Dict[str, Iterable]:
        api_key = os.getenv(self.api_key_env_var, None)
        if not api_key:
            raise ValueError(
                f"The environment variable '{self.api_key_env_var}' must be set to use the LoadFromAPI loader."
            )

        base_headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        if self.headers:
            base_headers.update(self.headers)

        iterables = {}
        for split_name, url in self.urls.items():
            if self.get_limit() is not None:
                self.log_limited_loading()

            if self.method == "GET":
                response = requests.get(
                    url,
                    headers=base_headers,
                    verify=self.verify_cert,
                )
            elif self.method == "POST":
                response = requests.post(
                    url,
                    headers=base_headers,
                    verify=self.verify_cert,
                    json={},
                )
            else:
                raise ValueError(f"Method {self.method} not supported")

            response.raise_for_status()

            data = json.loads(response.text)

            if self.data_field:
                if self.data_field not in data:
                    raise ValueError(
                        f"Data field '{self.data_field}' not found in API response."
                    )
                data = data[self.data_field]

            if self.get_limit() is not None:
                data = data[: self.get_limit()]

            iterables[split_name] = data

        return iterables

    def process(self) -> MultiStream:
        self._maybe_set_classification_policy()
        iterables = self.__class__._loader_cache.get(str(self), None)
        if iterables is None:
            iterables = self.load_iterables()
            self.__class__._loader_cache.max_size = settings.loader_cache_size
            self.__class__._loader_cache[str(self)] = iterables
        return MultiStream.from_iterables(iterables, copying=True)