IBM / unitxt / build 15374264096

01 Jun 2025 11:03AM UTC. Coverage: 80.41% (-0.03%) from 80.442%

Pull Request #1799: Make api_key_env_var optional in LoadFromAPI
Merge edbf98228 into a4f161074 (github / web-flow)

1687 of 2070 branches covered (81.5%); branch coverage is included in the aggregate %.
10389 of 12948 relevant lines covered (80.24%)
0.8 hits per line

Source File: src/unitxt/loaders.py (76.71% covered)
1
"""This section describes unitxt loaders.
2

3
Loaders: Generators of Unitxt Multistreams from existing data sources
4
=====================================================================
5

6
Unitxt is all about readily preparing any given data source for feeding into any given language model, and then,
7
post-processing the model's output, preparing it for any given evaluator.
8

9
Through that journey, the data advances in the form of a Unitxt Multistream, undergoing a sequential application
10
of various off-the-shelf operators (i.e., picked from the Unitxt catalog), or operators easily implemented by inheritance.
11
The journey starts with a Unitxt Loader that generates a Multistream from the given data source.
12
A loader, therefore, is the first step of any Unitxt Recipe.
13

14
The Unitxt catalog contains several loaders for the most popular data source formats.
15
All these loaders inherit from Loader, and hence, implementing a loader for a new type of data source is
16
straightforward.
17

18
Available Loaders Overview:
19
    - :class:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
20
    - :class:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
21
    - :class:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
22
    - :class:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
23
    - :class:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
24
    - :class:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
25
    - :class:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
26
    - :class:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
27

28

29

30

31
------------------------
32
"""
33

34
import fnmatch
1✔
35
import itertools
1✔
36
import json
1✔
37
import os
1✔
38
import tempfile
1✔
39
import time
1✔
40
from abc import abstractmethod
1✔
41
from pathlib import Path
1✔
42
from tempfile import TemporaryDirectory
1✔
43
from typing import (
1✔
44
    Any,
45
    Dict,
46
    Generator,
47
    Iterable,
48
    List,
49
    Mapping,
50
    Optional,
51
    Sequence,
52
    Union,
53
)
54

55
import pandas as pd
1✔
56
import requests
1✔
57
from datasets import (
1✔
58
    DatasetDict,
59
    IterableDataset,
60
    IterableDatasetDict,
61
    get_dataset_split_names,
62
)
63
from datasets import load_dataset as _hf_load_dataset
1✔
64
from huggingface_hub import HfApi
1✔
65
from tqdm import tqdm
1✔
66

67
from .dataclass import NonPositionalField
1✔
68
from .dict_utils import dict_get
1✔
69
from .error_utils import Documentation, UnitxtError, UnitxtWarning
1✔
70
from .fusion import FixedFusion
1✔
71
from .logging_utils import get_logger
1✔
72
from .operator import SourceOperator
1✔
73
from .operators import Set
1✔
74
from .settings_utils import get_settings
1✔
75
from .stream import DynamicStream, MultiStream
1✔
76
from .type_utils import isoftype
1✔
77
from .utils import LRUCache, recursive_copy, retry_connection_with_exponential_backoff
1✔
78

79
logger = get_logger()
1✔
80
settings = get_settings()
1✔
81

82
class UnitxtUnverifiedCodeError(UnitxtError):
1✔
83
    def __init__(self, path):
1✔
84
        super().__init__(f"Loader cannot load and run remote code from {path} in huggingface without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE.", Documentation.SETTINGS)
×
85

86
@retry_connection_with_exponential_backoff(backoff_factor=2)
1✔
87
def hf_load_dataset(path: str, *args, **kwargs):
1✔
88
    if settings.hf_offline_datasets_path is not None:
1✔
89
        path = os.path.join(settings.hf_offline_datasets_path, path)
×
90
    try:
1✔
91
        return _hf_load_dataset(
1✔
92
            path,
93
            *args, **kwargs,
94
            verification_mode="no_checks",
95
            trust_remote_code=settings.allow_unverified_code,
96
            download_mode="force_redownload" if settings.disable_hf_datasets_cache else "reuse_dataset_if_exists",
97
        )
98
    except ValueError as e:
1✔
99
        if "trust_remote_code" in str(e):
×
100
            raise UnitxtUnverifiedCodeError(path) from e
×
101
        raise e  # Re-raise
×
102

103

104
@retry_connection_with_exponential_backoff(backoff_factor=2)
1✔
105
def hf_get_dataset_splits(path: str, name: str, revision=None):
1✔
106
    try:
1✔
107
        return get_dataset_split_names(
1✔
108
            path=path,
109
            config_name=name,
110
            trust_remote_code=settings.allow_unverified_code,
111
            revision=revision,
112
        )
113
    except Exception as e:
1✔
114
        if "trust_remote_code" in str(e):
1✔
115
            raise UnitxtUnverifiedCodeError(path) from e
×
116

117
        if "Couldn't find cache" in str(e):
1✔
118
            raise FileNotFoundError(f"Dataset cache path={path}, name={name} was not found.") from e
×
119
        raise e  # Re-raise
1✔
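The UnitxtUnverifiedCodeError raised above names two ways to allow datasets that ship their own loading scripts. A hedged sketch of both follows; the programmatic variant assumes the settings object accepts attribute assignment.

.. code-block:: python

    # Option 1: export the environment variable named in the error message
    # before unitxt reads its settings.
    import os
    os.environ["UNITXT_ALLOW_UNVERIFIED_CODE"] = "True"

    # Option 2 (sketch, assuming attribute assignment is supported):
    from unitxt.settings_utils import get_settings
    get_settings().allow_unverified_code = True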
120

121
class Loader(SourceOperator):
1✔
122
    """A base class for all loaders.
123

124
    A loader is the first component in the Unitxt Recipe,
125
    responsible for loading data from various sources and preparing it as a MultiStream for processing.
126
    The loader_limit is an optional parameter used to control the maximum number of instances to load from the data source.  It is applied for each split separately.
127
    It is usually provided to the loader via the recipe (see standard.py)
128
    The loader can use this value to limit the amount of data downloaded from the source
129
    to reduce loading time.  However, this may not always be possible, so the
130
    loader may ignore this.  In any case, the recipe will limit the number of instances in the returned
131
    stream, after load is complete.
132

133
    Args:
134
        loader_limit: Optional integer to specify a limit on the number of records to load.
135
        streaming: Bool indicating if streaming should be used.
136
        num_proc: Optional integer to specify the number of processes to use for parallel dataset loading. Adjust the value according to the number of CPU cores available and the specific needs of your processing task.
137
    """
138

139
    loader_limit: int = None
1✔
140
    streaming: bool = False
1✔
141
    num_proc: int = None
1✔
142

143
    # class level shared cache:
144
    _loader_cache = LRUCache(max_size=settings.loader_cache_size)
1✔
145

146
    def get_limit(self) -> int:
1✔
147
        if settings.global_loader_limit is not None and self.loader_limit is not None:
1✔
148
            return min(int(settings.global_loader_limit), self.loader_limit)
1✔
149
        if settings.global_loader_limit is not None:
1✔
150
            return int(settings.global_loader_limit)
1✔
151
        return self.loader_limit
×
152

153
    def get_limiter(self):
1✔
154
        if settings.global_loader_limit is not None and self.loader_limit is not None:
1✔
155
            if int(settings.global_loader_limit) > self.loader_limit:
1✔
156
                return f"{self.__class__.__name__}.loader_limit"
1✔
157
            return "unitxt.settings.global_loader_limit"
1✔
158
        if settings.global_loader_limit is not None:
1✔
159
            return "unitxt.settings.global_loader_limit"
1✔
160
        return f"{self.__class__.__name__}.loader_limit"
×
161

162
    def log_limited_loading(self):
1✔
163
        if not hasattr(self, "_already_logged_limited_loading") or not self._already_logged_limited_loading:
1✔
164
            self._already_logged_limited_loading = True
1✔
165
            logger.info(
1✔
166
                f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
167
            )
168

169
    def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
1✔
170
        if self.data_classification_policy is None:
1✔
171
            get_logger().warning(
×
172
                f"The {self.get_pretty_print_name()} loader does not set the `data_classification_policy`. "
173
                f"This may lead to sending of undesired data to external services.\n"
174
                f"Set it to a list of classification identifiers. \n"
175
                f"For example:\n"
176
                f"data_classification_policy = ['public']\n"
177
                f" or \n"
178
                f"data_classification_policy =['confidential','pii'])\n"
179
            )
180

181
        operator = Set(
1✔
182
            fields={"data_classification_policy": self.data_classification_policy}
183
        )
184
        return operator(multi_stream)
1✔
185

186
    def set_default_data_classification(
1✔
187
        self, default_data_classification_policy, additional_info
188
    ):
189
        if self.data_classification_policy is None:
1✔
190
            if additional_info is not None:
1✔
191
                logger.info(
1✔
192
                    f"{self.get_pretty_print_name()} sets 'data_classification_policy' to "
193
                    f"{default_data_classification_policy} by default {additional_info}.\n"
194
                    "To use a different value or remove this message, explicitly set the "
195
                    "`data_classification_policy` attribute of the loader.\n"
196
                )
197
            self.data_classification_policy = default_data_classification_policy
1✔
198

199
    @abstractmethod
1✔
200
    def load_iterables(self) -> Dict[str, Iterable]:
1✔
201
        pass
×
202

203
    def _maybe_set_classification_policy(self):
1✔
204
        pass
1✔
205

206
    def load_data(self) -> MultiStream:
1✔
207
        try:
1✔
208
            iterables = self.load_iterables()
1✔
209
        except Exception as e:
×
210
            raise UnitxtError(f"Error in loader:\n{self}") from e
×
211
        if isoftype(iterables, MultiStream):
1✔
212
            return iterables
1✔
213
        return MultiStream.from_iterables(iterables, copying=True)
1✔
214

215
    def process(self) -> MultiStream:
1✔
216
        self._maybe_set_classification_policy()
1✔
217
        return self.add_data_classification(self.load_data())
1✔
218

219
    def get_splits(self):
1✔
220
        return list(self().keys())
×
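To make the limit resolution above concrete, here is a short sketch with illustrative numbers: get_limit() returns the smaller of unitxt.settings.global_loader_limit and loader_limit when both are set, and get_limiter() names whichever constraint is active.

.. code-block:: python

    # Illustrative only: assumes settings.global_loader_limit has been set to 100.
    loader = LoadHF(path="glue", name="mrpc", loader_limit=50)

    loader.get_limit()    # -> 50, i.e. min(global_loader_limit, loader_limit)
    loader.get_limiter()  # -> "LoadHF.loader_limit", since 100 > 50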
221

222

223
class LazyLoader(Loader):
1✔
224
    split: Optional[str] = NonPositionalField(default=None)
1✔
225

226
    @abstractmethod
1✔
227
    def get_splits(self) -> List[str]:
1✔
228
        pass
×
229

230
    @abstractmethod
1✔
231
    def split_generator(self, split: str) -> Generator:
1✔
232
        pass
×
233

234
    def load_iterables(self) -> Union[Dict[str, DynamicStream], IterableDatasetDict]:
1✔
235
        if self.split is not None:
1✔
236
            splits = [self.split]
1✔
237
        else:
238
            splits = self.get_splits()
1✔
239

240
        return MultiStream({
1✔
241
            split: DynamicStream(self.split_generator, gen_kwargs={"split": split})
242
            for split in splits
243
        })
244

245

246
class LoadHF(LazyLoader):
1✔
247
    """Loads datasets from the HuggingFace Hub.
248

249
    It supports loading with or without streaming,
250
    and it can filter datasets upon loading.
251

252
    Args:
253
        path:
254
            The path or identifier of the dataset on the HuggingFace Hub.
255
        name:
256
            An optional dataset name.
257
        data_dir:
258
            Optional directory to store downloaded data.
259
        split:
260
            Optional specification of which split to load.
261
        data_files:
262
            Optional specification of particular data files to load. When you provide a list of data_files to Hugging Face's load_dataset function without explicitly specifying the split argument, these files are automatically placed into the train split.
263
        revision:
264
            Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
265
        streaming (bool):
266
            indicating if streaming should be used.
267
        filtering_lambda (str, optional):
268
            A lambda function for filtering the data after loading.
269
        num_proc (int, optional):
270
            Specifies the number of processes to use for parallel dataset loading.
271

272
    Example:
273
        Loading glue's mrpc dataset
274

275
        .. code-block:: python
276

277
            load_hf = LoadHF(path='glue', name='mrpc')
278
    """
279

280
    path: str
1✔
281
    name: Optional[str] = None
1✔
282
    data_dir: Optional[str] = None
1✔
283
    split: Optional[str] = None
1✔
284
    data_files: Optional[
1✔
285
        Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
286
    ] = None
287
    revision: Optional[str] = None
1✔
288
    streaming: bool = None
1✔
289
    filtering_lambda: Optional[str] = None
1✔
290
    num_proc: Optional[int] = None
1✔
291
    splits: Optional[List[str]] = None
1✔
292

293
    def filter_load(self, dataset: DatasetDict):
1✔
294
        if not settings.allow_unverified_code:
1✔
295
            raise ValueError(
×
296
                f"{self.__class__.__name__} cannot run use filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or by setting environment variable: UNITXT_ALLOW_UNVERIFIED_CODE=True."
297
            )
298
        logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
1✔
299
        return dataset.filter(eval(self.filtering_lambda))
1✔
300

301
    def is_streaming(self) -> bool:
1✔
302
        if self.streaming is None:
1✔
303
            return settings.stream_hf_datasets_by_default
1✔
304
        return self.streaming
1✔
305

306
    def is_in_cache(self, split):
1✔
307
        dataset_id = str(self) + "_" + str(split)
1✔
308
        return dataset_id in self.__class__._loader_cache
1✔
309
    # returns a Dict when split names are not known in advance, and just the single split dataset if the split is known
310
    def load_dataset(
1✔
311
        self, split: str, streaming=None, disable_memory_caching=False
312
    ) -> Union[IterableDatasetDict, IterableDataset]:
313
        dataset_id = str(self) + "_" + str(split)
1✔
314
        dataset = self.__class__._loader_cache.get(dataset_id, None)
1✔
315
        if dataset is None:
1✔
316
            if streaming is None:
1✔
317
                streaming = self.is_streaming()
1✔
318

319
            dataset = hf_load_dataset(
1✔
320
                self.path,
321
                name=self.name,
322
                data_dir=self.data_dir,
323
                data_files=self.data_files,
324
                revision=self.revision,
325
                streaming=streaming,
326
                split=split,
327
                num_proc=self.num_proc,
328
            )
329

330
            if dataset is None:
1✔
331
                raise NotImplementedError() from None
×
332

333
            if not disable_memory_caching:
1✔
334
                self.__class__._loader_cache.max_size = settings.loader_cache_size
1✔
335
                self.__class__._loader_cache[dataset_id] = dataset
1✔
336
        self._already_logged_limited_loading = True
1✔
337

338
        return dataset
1✔
339

340
    def _maybe_set_classification_policy(self):
1✔
341
        if os.path.exists(self.path):
1✔
342
            self.set_default_data_classification(
1✔
343
                ["proprietary"], "when loading from local files"
344
            )
345
        else:
346
            self.set_default_data_classification(
1✔
347
                ["public"],
348
                None,  # No warning when loading from public hub
349
            )
350

351
    @retry_connection_with_exponential_backoff(max_retries=3, backoff_factor=2)
1✔
352
    def get_splits(self):
1✔
353
        if self.splits is not None:
1✔
354
            return self.splits
1✔
355
        if self.data_files is not None:
1✔
356
            if isinstance(self.data_files, dict):
1✔
357
                return list(self.data_files.keys())
1✔
358
            return ["train"]
1✔
359
        try:
1✔
360
            return hf_get_dataset_splits(
1✔
361
                path=self.path,
362
                name=self.name,
363
                revision=self.revision,
364
            )
365
        except Exception:
1✔
366
            UnitxtWarning(
1✔
367
                f'LoadHF(path="{self.path}", name="{self.name}") could not retrieve split names without loading the dataset. Consider defining "splits" in the LoadHF definition to improve loading time.'
368
            )
369
            try:
1✔
370
                dataset = self.load_dataset(
1✔
371
                    split=None, disable_memory_caching=True, streaming=True
372
                )
373
            except (
1✔
374
                NotImplementedError
375
            ):  # streaming is not supported for zipped files so we load without streaming
376
                dataset = self.load_dataset(split=None, streaming=False)
1✔
377

378
            if dataset is None:
1✔
379
                raise FileNotFoundError(f"Dataset path={self.path}, name={self.name} was not found.") from None
×
380

381
            return list(dataset.keys())
1✔
382

383
    def split_generator(self, split: str) -> Generator:
1✔
384
        if self.get_limit() is not None:
1✔
385
            if not self.is_in_cache(split):
1✔
386
                self.log_limited_loading()
1✔
387
        try:
1✔
388
            dataset = self.load_dataset(split=split)
1✔
389
        except (
×
390
            NotImplementedError
391
        ):  # streaming is not supported for zipped files so we load without streaming
392
            dataset = self.load_dataset(split=split, streaming=False)
×
393

394
        if self.filtering_lambda is not None:
1✔
395
            dataset = self.filter_load(dataset)
1✔
396

397
        limit = self.get_limit()
1✔
398
        if limit is None:
1✔
399
            yield from dataset
×
400
        else:
401
            for i, instance in enumerate(dataset):
1✔
402
                yield instance
1✔
403
                if i + 1 >= limit:
1✔
404
                    break
1✔
405

406
class LoadWithPandas(LazyLoader):
1✔
407
    """Utility base class for classes loading with pandas."""
408

409
    files: Dict[str, str]
1✔
410
    chunksize: int = 1000
1✔
411
    loader_limit: Optional[int] = None
1✔
412
    streaming: bool = True
1✔
413
    compression: Optional[str] = None
1✔
414

415
    def _maybe_set_classification_policy(self):
1✔
416
        self.set_default_data_classification(
1✔
417
            ["proprietary"], "when loading from local files"
418
        )
419

420
    def split_generator(self, split: str) -> Generator:
1✔
421
        dataset_id = str(self) + "_" + split
1✔
422
        dataset = self.__class__._loader_cache.get(dataset_id, None)
1✔
423
        if dataset is None:
1✔
424
            if self.get_limit() is not None:
1✔
425
                self.log_limited_loading()
1✔
426
            for attempt in range(settings.loaders_max_retries):
1✔
427
                try:
1✔
428
                    file = self.files[split]
1✔
429
                    if self.get_limit() is not None:
1✔
430
                        self.log_limited_loading()
1✔
431

432
                    try:
1✔
433
                        dataframe = self.read_dataframe(file)
1✔
434
                        break
1✔
435
                    except ValueError:
1✔
436
                        import fsspec
×
437

438
                        with fsspec.open(file, mode="rt") as file:
×
439
                            dataframe = self.read_dataframe(file)
×
440
                        break
×
441
                except Exception as e:
1✔
442
                    logger.warning(f"Attempt  load {attempt + 1} failed: {e}")
1✔
443
                    if attempt < settings.loaders_max_retries - 1:
1✔
444
                        time.sleep(2)
1✔
445
                    else:
446
                        raise e
1✔
447

448
            limit = self.get_limit()
1✔
449
            if limit is not None and len(dataframe) > limit:
1✔
450
                dataframe = dataframe.head(limit)
×
451

452
            dataset = dataframe.to_dict("records")
1✔
453

454
            self.__class__._loader_cache.max_size = settings.loader_cache_size
1✔
455
            self.__class__._loader_cache[dataset_id] = dataset
1✔
456

457
        for instance in self.__class__._loader_cache[dataset_id]:
1✔
458
            yield recursive_copy(instance)
1✔
459

460
    def get_splits(self) -> List[str]:
1✔
461
        return list(self.files.keys())
1✔
462

463

464
    def get_args(self) -> Dict[str, Any]:
1✔
465
        args = {}
1✔
466
        if self.compression is not None:
1✔
467
            args["compression"] = self.compression
×
468
        if self.get_limit() is not None:
1✔
469
            args["nrows"] = self.get_limit()
1✔
470
        return args
1✔
471

472
    @abstractmethod
1✔
473
    def read_dataframe(self, file) -> pd.DataFrame:
1✔
474
        ...
×
475

476
class LoadCSV(LoadWithPandas):
1✔
477
    """Loads data from CSV files.
478

479
    Supports streaming and can handle large files by loading them in chunks.
480

481
    Args:
482
        files (Dict[str, str]): A dictionary mapping names to file paths.
483
        chunksize: Size of the chunks to load at a time.
484
        loader_limit: Optional integer to specify a limit on the number of records to load.
485
        streaming: Bool indicating if streaming should be used.
486
        sep: String specifying the separator used in the CSV files.
487

488
    Example:
489
        Loading csv
490

491
        .. code-block:: python
492

493
            load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
494
    """
495

496
    sep: str = ","
1✔
497

498
    def read_dataframe(self, file) -> pd.DataFrame:
1✔
499
        return pd.read_csv(
1✔
500
            file,
501
            sep=self.sep,
502
            low_memory=self.streaming,
503
            **self.get_args()
504
        )
505

506

507
def read_file(source) -> bytes:
1✔
508

509
    if hasattr(source, "read"):
1✔
510
        return source.read()
×
511

512
    if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")):
1✔
513
        from urllib import request
×
514
        with request.urlopen(source) as response:
×
515
            return response.read()
×
516

517
    with open(source, "rb") as f:
1✔
518
        return f.read()
1✔
519

520
class LoadJsonFile(LoadWithPandas):
1✔
521
    """Loads data from JSON files.
522

523
    Supports streaming and can handle large files by loading them in chunks.
524

525
    Args:
526
        files (Dict[str, str]): A dictionary mapping names to file paths.
527
        chunksize: Size of the chunks to load at a time.
528
        loader_limit: Optional integer to specify a limit on the number of records to load.
529
        streaming: Bool indicating if streaming should be used.
530
        lines: Bool indicating whether the file is in JSON Lines format. Otherwise, a single JSON object is assumed in the file.
531
        data_field: Optional field within the JSON object that contains the list of instances.
532

533
    Example:
534
        Loading json lines
535

536
        .. code-block:: python
537

538
            load_json = LoadJsonFile(files={'train': 'path/to/train.jsonl'}, lines=True, chunksize=100)
539
    """
540

541
    lines: bool = False
1✔
542
    data_field: Optional[str] = None
1✔
543

544
    def read_dataframe(self, file) -> pd.DataFrame:
1✔
545

546
        args =  self.get_args()
1✔
547
        if not self.lines:
1✔
548
            data = json.loads(read_file(file))
1✔
549
            if (self.data_field):
1✔
550
                instances = dict_get(data, self.data_field)
1✔
551
                if not isoftype(instances,List[Dict[str,Any]]):
1✔
552
                    raise UnitxtError(f"{self.data_field} of file {file} is not a list of dictionariess in LoadJsonFile loader")
×
553
            else:
554
                if isoftype(data,Dict[str,Any]):
1✔
555
                    instances = [data]
1✔
556
                elif isoftype(data,List[Dict[str,Any]]):
1✔
557
                    instances=data
1✔
558
                else:
559
                    raise UnitxtError(f"data of file {file} is not dictionary or a list of dictionaries in LoadJsonFile loader")
×
560
            dataframe = pd.DataFrame(instances)
1✔
561
        else:
562
            if self.data_field is not None:
1✔
563
                raise UnitxtError("Can not load from a specific 'data_field' when loading multiple lines (lines=True)")
×
564
            dataframe = pd.read_json(
1✔
565
                file,
566
                lines=self.lines,
567
                **args
568
            )
569
        return dataframe
1✔
570

571

572

573
class LoadFromSklearn(LazyLoader):
1✔
574
    """Loads datasets from the sklearn library.
575

576
    This loader does not support streaming and is intended for use with sklearn's dataset fetch functions.
577

578
    Args:
579
        dataset_name: The name of the sklearn dataset to fetch.
580
        splits: A list of data splits to load, e.g., ['train', 'test'].
581

582
    Example:
583
        Loading from sklearn
584

585
        .. code-block:: python
586

587
            load_sklearn = LoadFromSklearn(dataset_name='iris', splits=['train', 'test'])
588
    """
589

590
    dataset_name: str
1✔
591
    splits: List[str] = ["train", "test"]
1✔
592

593
    _requirements_list: List[str] = ["scikit-learn", "pandas"]
1✔
594

595
    data_classification_policy = ["public"]
1✔
596

597
    def verify(self):
1✔
598
        super().verify()
×
599

600
        if self.streaming:
×
601
            raise NotImplementedError("LoadFromSklearn cannot load with streaming.")
×
602

603
    def prepare(self):
1✔
604
        super().prepare()
×
605
        from sklearn import datasets as sklearn_datasets
×
606

607
        self.downloader = getattr(sklearn_datasets, f"fetch_{self.dataset_name}")
×
608

609
    def get_splits(self):
1✔
610
        return self.splits
×
611

612
    def split_generator(self, split: str) -> Generator:
1✔
613
        dataset_id = str(self) + "_" + split
×
614
        dataset = self.__class__._loader_cache.get(dataset_id, None)
×
615
        if dataset is None:
×
616
            split_data = self.downloader(subset=split)
×
617
            targets = [split_data["target_names"][t] for t in split_data["target"]]
×
618
            df = pd.DataFrame([split_data["data"], targets]).T
×
619
            df.columns = ["data", "target"]
×
620
            dataset = df.to_dict("records")
×
621
            self.__class__._loader_cache.max_size = settings.loader_cache_size
×
622
            self.__class__._loader_cache[dataset_id] = dataset
×
623
        for instance in self.__class__._loader_cache[dataset_id]:
×
624
            yield recursive_copy(instance)
×
625

626

627
class MissingKaggleCredentialsError(ValueError):
1✔
628
    pass
1✔
629

630

631
class LoadFromKaggle(Loader):
1✔
632
    """Loads datasets from Kaggle.
633

634
    Requires Kaggle API credentials and does not support streaming.
635

636
    Args:
637
        url: URL to the Kaggle dataset.
638

639
    Example:
640
        Loading from kaggle
641

642
        .. code-block:: python
643

644
            load_kaggle = LoadFromKaggle(url='kaggle.com/dataset/example')
645
    """
646

647
    url: str
1✔
648

649
    _requirements_list: List[str] = ["opendatasets"]
1✔
650
    data_classification_policy = ["public"]
1✔
651

652
    def verify(self):
1✔
653
        super().verify()
×
654
        if not os.path.isfile("kaggle.json"):
×
655
            raise MissingKaggleCredentialsError(
×
656
                "Please obtain kaggle credentials https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/ and save them to local ./kaggle.json file"
657
            )
658

659
        if self.streaming:
×
660
            raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
×
661

662
    def prepare(self):
1✔
663
        super().prepare()
×
664
        from opendatasets import download
×
665

666
        self.downloader = download
×
667

668
    def load_iterables(self):
1✔
669
        with TemporaryDirectory() as temp_directory:
×
670
            self.downloader(self.url, temp_directory)
×
671
            return hf_load_dataset(temp_directory, streaming=False)
×
672

673

674
class LoadFromIBMCloud(Loader):
1✔
675
    """Loads data from IBM Cloud Object Storage.
676

677
    Does not support streaming and requires AWS-style access keys.
678
    data_files can be either:
679
    1. a list of file names, the split of each file is determined by the file name pattern
680
    2. Mapping: split -> file_name, e.g. {"test" : "test.json", "train": "train.json"}
681
    3. Mapping: split -> file_names, e.g. {"test" : ["test1.json", "test2.json"], "train": ["train.json"]}
682

683
    Args:
684
        endpoint_url_env:
685
            Environment variable name for the IBM Cloud endpoint URL.
686
        aws_access_key_id_env:
687
            Environment variable name for the AWS access key ID.
688
        aws_secret_access_key_env:
689
            Environment variable name for the AWS secret access key.
690
        bucket_name:
691
            Name of the S3 bucket from which to load data.
692
        data_dir:
693
            Optional directory path within the bucket.
694
        data_files:
695
            Union type allowing either a list of file names or a mapping of splits to file names.
696
        data_field:
697
            The dataset key for a nested JSON file, i.e., when multiple datasets are nested in the same file.
698
        caching (bool):
699
            indicating if caching is enabled to avoid re-downloading data.
700

701
    Example:
702
        Loading from IBM Cloud
703

704
        .. code-block:: python
705

706
            load_ibm_cloud = LoadFromIBMCloud(
707
                endpoint_url_env='IBM_CLOUD_ENDPOINT',
708
                aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
709
                aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY', # pragma: allowlist secret
710
                bucket_name='my-bucket'
711
            )
712
            multi_stream = load_ibm_cloud.process()
713
    """
714

715
    endpoint_url_env: str
1✔
716
    aws_access_key_id_env: str
1✔
717
    aws_secret_access_key_env: str
1✔
718
    bucket_name: str
1✔
719
    data_dir: str = None
1✔
720

721
    data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
1✔
722
    data_field: str = None
1✔
723
    caching: bool = True
1✔
724
    data_classification_policy = ["proprietary"]
1✔
725

726
    _requirements_list: List[str] = ["ibm-cos-sdk"]
1✔
727

728
    def _download_from_cos(self, cos, bucket_name, item_name, local_file):
1✔
729
        logger.info(f"Downloading {item_name} from {bucket_name} COS")
1✔
730
        try:
1✔
731
            response = cos.Object(bucket_name, item_name).get()
1✔
732
            size = response["ContentLength"]
1✔
733
            body = response["Body"]
1✔
734
        except Exception as e:
×
735
            raise Exception(
×
736
                f"Unabled to access {item_name} in {bucket_name} in COS", e
737
            ) from e
738

739
        if self.get_limit() is not None:
1✔
740
            if item_name.endswith(".jsonl"):
1✔
741
                first_lines = list(
1✔
742
                    itertools.islice(body.iter_lines(), self.get_limit())
743
                )
744
                with open(local_file, "wb") as downloaded_file:
1✔
745
                    for line in first_lines:
1✔
746
                        downloaded_file.write(line)
1✔
747
                        downloaded_file.write(b"\n")
1✔
748
                logger.info(
1✔
749
                    f"\nDownload successful limited to {self.get_limit()} lines"
750
                )
751
                return
1✔
752

753
        progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
×
754

755
        def upload_progress(chunk):
×
756
            progress_bar.update(chunk)
×
757

758
        try:
×
759
            cos.Bucket(bucket_name).download_file(
×
760
                item_name, local_file, Callback=upload_progress
761
            )
762
            logger.info("\nDownload Successful")
×
763
        except Exception as e:
×
764
            raise Exception(
×
765
                f"Unabled to download {item_name} in {bucket_name}", e
766
            ) from e
767

768
    def prepare(self):
1✔
769
        super().prepare()
1✔
770
        self.endpoint_url = os.getenv(self.endpoint_url_env)
1✔
771
        self.aws_access_key_id = os.getenv(self.aws_access_key_id_env)
1✔
772
        self.aws_secret_access_key = os.getenv(self.aws_secret_access_key_env)
1✔
773
        root_dir = os.getenv("UNITXT_IBM_COS_CACHE", None) or os.getcwd()
1✔
774
        self.cache_dir = os.path.join(root_dir, "ibmcos_datasets")
1✔
775

776
        if not os.path.exists(self.cache_dir):
1✔
777
            Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
1✔
778
        self.verified = False
1✔
779

780
    def lazy_verify(self):
1✔
781
        super().verify()
1✔
782
        assert (
1✔
783
            self.endpoint_url is not None
784
        ), f"Please set the {self.endpoint_url_env} environmental variable"
785
        assert (
1✔
786
            self.aws_access_key_id is not None
787
        ), f"Please set {self.aws_access_key_id_env} environmental variable"
788
        assert (
1✔
789
            self.aws_secret_access_key is not None
790
        ), f"Please set {self.aws_secret_access_key_env} environmental variable"
791
        if self.streaming:
1✔
792
            raise NotImplementedError("LoadFromKaggle cannot load with streaming.")
×
793

794
    def _maybe_set_classification_policy(self):
1✔
795
        self.set_default_data_classification(
1✔
796
            ["proprietary"], "when loading from IBM COS"
797
        )
798

799
    def load_iterables(self):
1✔
800
        if not self.verified:
1✔
801
            self.lazy_verify()
1✔
802
            self.verified = True
1✔
803
        import ibm_boto3
1✔
804

805
        cos = ibm_boto3.resource(
1✔
806
            "s3",
807
            aws_access_key_id=self.aws_access_key_id,
808
            aws_secret_access_key=self.aws_secret_access_key,
809
            endpoint_url=self.endpoint_url,
810
        )
811
        local_dir = os.path.join(
1✔
812
            self.cache_dir,
813
            self.bucket_name,
814
            self.data_dir or "",  # data_dir can be None
815
            f"loader_limit_{self.get_limit()}",
816
        )
817
        if not os.path.exists(local_dir):
1✔
818
            Path(local_dir).mkdir(parents=True, exist_ok=True)
1✔
819
        if isinstance(self.data_files, Mapping):
1✔
820
            data_files_names = list(self.data_files.values())
1✔
821
            if not isinstance(data_files_names[0], str):
1✔
822
                data_files_names = list(itertools.chain(*data_files_names))
1✔
823
        else:
824
            data_files_names = self.data_files
1✔
825

826
        for data_file in data_files_names:
1✔
827
            local_file = os.path.join(local_dir, data_file)
1✔
828
            if not self.caching or not os.path.exists(local_file):
1✔
829
                # Build object key based on parameters. Slash character is not
830
                # allowed to be part of object key in IBM COS.
831
                object_key = (
1✔
832
                    self.data_dir + "/" + data_file
833
                    if self.data_dir is not None
834
                    else data_file
835
                )
836
                with tempfile.NamedTemporaryFile() as temp_file:
1✔
837
                    # Download to a temporary file in the same file partition, and then do an atomic move
838
                    self._download_from_cos(
1✔
839
                        cos,
840
                        self.bucket_name,
841
                        object_key,
842
                        local_dir + "/" + os.path.basename(temp_file.name),
843
                    )
844
                    os.renames(
1✔
845
                        local_dir + "/" + os.path.basename(temp_file.name),
846
                        local_dir + "/" + data_file,
847
                    )
848

849
        if isinstance(self.data_files, list):
1✔
850
            dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
1✔
851
        else:
852
            dataset = hf_load_dataset(
1✔
853
                local_dir,
854
                streaming=False,
855
                data_files=self.data_files,
856
                field=self.data_field,
857
            )
858

859
        return dataset
1✔
860

861

862
class MultipleSourceLoader(LazyLoader):
1✔
863
    """Allows loading data from multiple sources, potentially mixing different types of loaders.
864

865
    Args:
866
        sources: A list of loaders that will be combined to form a unified dataset.
867

868
    Examples:
869
        1) Loading the train split from the HuggingFace Hub and the test set from a local file:
870

871
        .. code-block:: python
872

873
            MultipleSourceLoader(sources = [ LoadHF(path="public/data",split="train"), LoadCSV({"test": "mytest.csv"}) ])
874

875

876

877
        2) Loading a test set combined from two files
878

879
        .. code-block:: python
880

881
            MultipleSourceLoader(sources = [ LoadCSV({"test": "mytest1.csv"}, LoadCSV({"test": "mytest2.csv"}) ])
882
    """
883

884
    sources: List[Loader]
1✔
885

886
    def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
1✔
887
        if self.data_classification_policy is None:
1✔
888
            return multi_stream
1✔
889
        return super().add_data_classification(multi_stream)
×
890

891
    def get_splits(self):
1✔
892
        splits = []
1✔
893
        for loader in self.sources:
1✔
894
            splits.extend(loader.get_splits())
1✔
895
        return list(set(splits))
1✔
896

897
    def split_generator(self, split: str) -> Generator[Any, None, None]:
1✔
898
        yield from FixedFusion(
1✔
899
            subsets=self.sources,
900
            max_instances_per_subset=self.get_limit(),
901
            include_splits=[split],
902
        )()[split]
903

904

905
class LoadFromDictionary(Loader):
1✔
906
    """Allows loading data from a dictionary of constants.
907

908
    The loader can be used, for example, when debugging or working with small datasets.
909

910
    Args:
911
        data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded
912

913
    Example:
914
        Loading dictionary
915

916
        .. code-block:: python
917

918
            data = {
919
                "train": [{"input": "SomeInput1", "output": "SomeResult1"},
920
                          {"input": "SomeInput2", "output": "SomeResult2"}],
921
                "test":  [{"input": "SomeInput3", "output": "SomeResult3"},
922
                          {"input": "SomeInput4", "output": "SomeResult4"}]
923
            }
924
            loader = LoadFromDictionary(data=data)
925
    """
926

927
    data: Dict[str, List[Dict[str, Any]]]
1✔
928

929
    def verify(self):
1✔
930
        super().verify()
1✔
931
        if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
1✔
932
            raise ValueError(
1✔
933
                f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
934
                f"Expected data should map between split name and list of instances.\n"
935
                f"Received value: {self.data}\n"
936
            )
937
        for split in self.data.keys():
1✔
938
            if len(self.data[split]) == 0:
1✔
939
                raise ValueError(f"Split {split} has no instances.")
×
940
            first_instance = self.data[split][0]
1✔
941
            for instance in self.data[split]:
1✔
942
                if instance.keys() != first_instance.keys():
1✔
943
                    raise ValueError(
1✔
944
                        f"Not all instances in split '{split}' have the same fields.\n"
945
                        f"instance {instance} has different fields different from {first_instance}"
946
                    )
947

948
    def _maybe_set_classification_policy(self):
1✔
949
        self.set_default_data_classification(
1✔
950
            ["proprietary"], "when loading from python dictionary"
951
        )
952

953
    def load_iterables(self) -> MultiStream:
1✔
954
        return self.data
1✔
955

956

957
class LoadFromHFSpace(LazyLoader):
1✔
958
    """Used to load data from HuggingFace Spaces lazily.
959

960
    Args:
961
        space_name (str):
962
            Name of the HuggingFace Space to be accessed.
963
        data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
964
            Relative paths to files within a given repository. If given as a mapping,
965
            paths should be values, while keys should represent the type of respective files
966
            (training, testing etc.).
967
        path (str, optional):
968
            Absolute path to a directory where data should be downloaded.
969
        revision (str, optional):
970
            ID of a Git branch or commit to be used. By default, it is set to None,
971
            thus data is downloaded from the main branch of the accessed repository.
972
        use_token (bool, optional):
973
            Whether a token is used for authentication when accessing
974
            the HuggingFace Space. If necessary, the token is read from the HuggingFace
975
            config folder.
976
        token_env (str, optional):
977
            Name of an environment variable whose value will be used for
978
            authentication when accessing the HuggingFace Space, if necessary.
979
    """
980

981
    space_name: str
1✔
982
    data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
1✔
983
    path: Optional[str] = None
1✔
984
    revision: Optional[str] = None
1✔
985
    use_token: Optional[bool] = None
1✔
986
    token_env: Optional[str] = None
1✔
987
    requirements_list: List[str] = ["huggingface_hub"]
1✔
988

989
    streaming: bool = True
1✔
990

991
    def _get_token(self) -> Optional[Union[bool, str]]:
1✔
992
        if self.token_env:
1✔
993
            token = os.getenv(self.token_env)
×
994
            if not token:
×
995
                get_logger().warning(
×
996
                    f"The 'token_env' parameter was specified as '{self.token_env}', "
997
                    f"however, no environment variable under such a name was found. "
998
                    f"Therefore, the loader will not use any tokens for authentication."
999
                )
1000
            return token
×
1001
        return self.use_token
1✔
1002

1003
    @staticmethod
1✔
1004
    def _is_wildcard(path: str) -> bool:
1✔
1005
        wildcard_characters = ["*", "?", "[", "]"]
1✔
1006
        return any(char in path for char in wildcard_characters)
1✔
1007

1008

1009

1010
    def _get_repo_files(self):
1✔
1011
        if not hasattr(self, "_repo_files") or self._repo_files is None:
×
1012
            api = HfApi()
×
1013
            self._repo_files = api.list_repo_files(
×
1014
                self.space_name, repo_type="space", revision=self.revision
1015
            )
1016
        return self._repo_files
×
1017

1018
    def _get_sub_files(self, file: str) -> List[str]:
1✔
1019
        if self._is_wildcard(file):
1✔
1020
            return fnmatch.filter(self._get_repo_files(), file)
×
1021
        return [file]
1✔
1022

1023

1024
    def get_splits(self) -> List[str]:
1✔
1025
        if isinstance(self.data_files, Mapping):
1✔
1026
            return list(self.data_files.keys())
1✔
1027
        return ["train"]  # Default to 'train' if not specified
×
1028

1029
    def split_generator(self, split: str) -> Generator:
1✔
1030
        from huggingface_hub import hf_hub_download
1✔
1031
        from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
1✔
1032

1033
        token = self._get_token()
1✔
1034
        files = self.data_files.get(split, self.data_files) if isinstance(self.data_files, Mapping) else self.data_files
1✔
1035

1036
        if isinstance(files, str):
1✔
1037
            files = [files]
1✔
1038
        limit = self.get_limit()
1✔
1039

1040
        if limit is not None:
1✔
1041
            total = 0
1✔
1042
            self.log_limited_loading()
1✔
1043

1044
        for file in files:
1✔
1045
            for sub_file in self._get_sub_files(file):
1✔
1046
                try:
1✔
1047
                    file_path = hf_hub_download(
1✔
1048
                        repo_id=self.space_name,
1049
                        filename=sub_file,
1050
                        repo_type="space",
1051
                        token=token,
1052
                        revision=self.revision,
1053
                        local_dir=self.path,
1054
                    )
1055
                except EntryNotFoundError as e:
×
1056
                    raise ValueError(
×
1057
                        f"The file '{file}' was not found in the space '{self.space_name}'. "
1058
                        f"Please check if the filename is correct, or if it exists in that "
1059
                        f"Huggingface space."
1060
                    ) from e
1061
                except RepositoryNotFoundError as e:
×
1062
                    raise ValueError(
×
1063
                        f"The Huggingface space '{self.space_name}' was not found. "
1064
                        f"Please check if the name is correct and you have access to the space."
1065
                    ) from e
1066

1067
                with open(file_path, encoding="utf-8") as f:
1✔
1068
                    for line in f:
1✔
1069
                        yield json.loads(line.strip())
1✔
1070
                        if limit is not None:
1✔
1071
                            total += 1
1✔
1072
                            if total >= limit:
1✔
1073
                                return
1✔
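The docstring of LoadFromHFSpace above lists its arguments but gives no usage example; here is a hedged sketch in which the space name and file paths are hypothetical.

.. code-block:: python

    # Hypothetical space name and file paths, for illustration only.
    load_space = LoadFromHFSpace(
        space_name="some-org/some-space",
        data_files={
            "train": "data/train.jsonl",
            "test": "data/test_*.jsonl",  # wildcards are expanded via HfApi.list_repo_files
        },
        use_token=True,  # read the token from the local HuggingFace config if needed
    )
    multi_stream = load_space()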
1074

1075

1076

1077
class LoadFromAPI(Loader):
1✔
1078
    """Loads data from from API.
1079

1080
    This loader is designed to fetch data from an API endpoint,
1081
    handling authentication through an API key. It supports
1082
    customizable chunk sizes and limits for data retrieval.
1083

1084
    Args:
1085
        urls (Dict[str, str]):
1086
            A dictionary mapping split names to their respective API URLs.
1087
        chunksize (int, optional):
1088
            The size of data chunks to fetch in each request. Defaults to 100,000.
1089
        loader_limit (int, optional):
1090
            Limits the number of records to load. Applied per split. Defaults to None.
1091
        streaming (bool, optional):
1092
            Determines if data should be streamed. Defaults to False.
1093
        api_key_env_var (str, optional):
1094
            The name of the environment variable holding the API key.
1095
            Defaults to "SQL_API_KEY".
1096
        headers (Dict[str, Any], optional):
1097
            Additional headers to include in API requests. Defaults to None.
1098
        data_field (str, optional):
1099
            The name of the field in the API response that contains the data.
1100
            Defaults to "data".
1101
        method (str, optional):
1102
            The HTTP method to use for API requests. Defaults to "GET".
1103
        verify_cert (bool):
1104
            Whether to verify the SSL certificate.
1105
            Defaults to True.
1106
    """
1107

1108
    urls: Dict[str, str]
1✔
1109
    chunksize: int = 100000
1✔
1110
    loader_limit: Optional[int] = None
1✔
1111
    streaming: bool = False
1✔
1112
    api_key_env_var: Optional[str] = ""
1✔
1113
    headers: Optional[Dict[str, Any]] = None
1✔
1114
    data_field: str = "data"
1✔
1115
    method: str = "GET"
1✔
1116
    verify_cert: bool = True
1✔
1117

1118
    # class level shared cache:
1119
    _loader_cache = LRUCache(max_size=settings.loader_cache_size)
1✔
1120

1121
    def _maybe_set_classification_policy(self):
1✔
1122
        self.set_default_data_classification(["proprietary"], "when loading from API")
×
1123

1124
    def load_iterables(self) -> Dict[str, Iterable]:
1✔
1125
        if self.api_key_env_var is not None:
×
1126
            api_key = os.getenv(self.api_key_env_var, None)
×
1127
            if not api_key:
×
1128
                raise ValueError(
×
1129
                    f"The environment variable '{self.api_key_env_var}' must be set to use the LoadFromAPI loader."
1130
                )
1131
        else:
1132
            api_key = None
×
1133

1134
        base_headers = {
×
1135
            "Content-Type": "application/json",
1136
            "accept": "application/json",
1137
        }
1138

1139
        if api_key is not None:
×
1140
            base_headers["Authorization"] = f"Bearer {api_key}"
×
1141

1142
        if self.headers:
×
1143
            base_headers.update(self.headers)
×
1144

1145
        iterables = {}
×
1146
        for split_name, url in self.urls.items():
×
1147
            if self.get_limit() is not None:
×
1148
                self.log_limited_loading()
×
1149

1150
            if self.method == "GET":
×
1151
                response = requests.get(
×
1152
                    url,
1153
                    headers=base_headers,
1154
                    verify=self.verify_cert,
1155
                )
1156
            elif self.method == "POST":
×
1157
                response = requests.post(
×
1158
                    url,
1159
                    headers=base_headers,
1160
                    verify=self.verify_cert,
1161
                    json={},
1162
                )
1163
            else:
1164
                raise ValueError(f"Method {self.method} not supported")
×
1165

1166
            response.raise_for_status()
×
1167

1168
            data = json.loads(response.text)
×
1169

1170
            if self.data_field:
×
1171
                if self.data_field not in data:
×
1172
                    raise ValueError(
×
1173
                        f"Data field '{self.data_field}' not found in API response."
1174
                    )
1175
                data = data[self.data_field]
×
1176

1177
            if self.get_limit() is not None:
×
1178
                data = data[: self.get_limit()]
×
1179

1180
            iterables[split_name] = data
×
1181

1182
        return iterables
×
1183

1184
    def process(self) -> MultiStream:
1✔
1185
        self._maybe_set_classification_policy()
×
1186
        iterables = self.__class__._loader_cache.get(str(self), None)
×
1187
        if iterables is None:
×
1188
            iterables = self.load_iterables()
×
1189
            self.__class__._loader_cache.max_size = settings.loader_cache_size
×
1190
            self.__class__._loader_cache[str(self)] = iterables
×
1191
        return MultiStream.from_iterables(iterables, copying=True)
×