
IBM / unitxt / 12765217246

14 Jan 2025 09:58AM UTC coverage: 79.393% (+0.02%) from 79.372%

Build 12765217246 · push · github · web-flow
Add Tables Understanding Benchmark (#1506)

* init commit bench

Signed-off-by: ShirApp <shirashury@gmail.com>

* merge updates

* Make tables benchmark

Signed-off-by: elronbandel <elronbandel@gmail.com>

* modify prompts (instruction once)

* modify prompts (instruction once) in generation template

* change llm as judge metric for scigen (Yifan's code)

* updated recipes

* add table augmenter

* update table benchmark files

* delete some files from branch

* fix typo of augmeter list in benchmark code + update recipes to include loader limit

* fix typos

* drop personal scripts

* create updated json cards (tab fact+turl)

* updated cards (tab fact+turl)

* add tablebench visualization json file

* delete old file

* update df serializer test

* drop table bench visualization since it is not a part of the benchmark, and we are not sure about its evaluation metric

---------

Signed-off-by: ShirApp <shirashury@gmail.com>
Signed-off-by: elronbandel <elronbandel@gmail.com>
Co-authored-by: elronbandel <elronbandel@gmail.com>

1387 of 1735 branches covered (79.94%)

Branch coverage included in aggregate %.

8742 of 11023 relevant lines covered (79.31%)

0.79 hits per line

Source File

src/unitxt/loaders.py (82.02% covered)
"""This section describes unitxt loaders.

Loaders: Generators of Unitxt Multistreams from existing data sources
======================================================================

Unitxt is all about readily preparing any given data source for feeding into any given language model,
and then post-processing the model's output, preparing it for any given evaluator.

Through that journey, the data advances in the form of a Unitxt Multistream, undergoing a sequential application
of various off-the-shelf operators (i.e., picked from the Unitxt catalog) or operators easily implemented by inheritance.
The journey starts with a Unitxt Loader bearing a Multistream from the given data source.
A loader, therefore, is the first item of any Unitxt Recipe.

The Unitxt catalog contains several loaders for the most popular data source formats.
All these loaders inherit from Loader, and hence, implementing a loader for a new type of data source is
straightforward.

Available Loaders Overview:
    - :class:`LoadHF <unitxt.loaders.LoadHF>` - Loads data from HuggingFace Datasets.
    - :class:`LoadCSV <unitxt.loaders.LoadCSV>` - Imports data from CSV (Comma-Separated Values) files.
    - :class:`LoadFromKaggle <unitxt.loaders.LoadFromKaggle>` - Retrieves datasets from the Kaggle community site.
    - :class:`LoadFromIBMCloud <unitxt.loaders.LoadFromIBMCloud>` - Fetches datasets hosted on IBM Cloud.
    - :class:`LoadFromSklearn <unitxt.loaders.LoadFromSklearn>` - Loads datasets available through the sklearn library.
    - :class:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
    - :class:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
    - :class:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.

------------------------
"""

import fnmatch
import itertools
import os
import tempfile
from abc import abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union

import pandas as pd
from datasets import IterableDatasetDict
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm

from .dataclass import OptionalField
from .fusion import FixedFusion
from .logging_utils import get_logger
from .operator import SourceOperator
from .operators import Set
from .settings_utils import get_settings
from .stream import MultiStream
from .type_utils import isoftype
from .utils import LRUCache

logger = get_logger()
settings = get_settings()
63
class Loader(SourceOperator):
1✔
64
    """A base class for all loaders.
65

66
    A loader is the first component in the Unitxt Recipe,
67
    responsible for loading data from various sources and preparing it as a MultiStream for processing.
68
    The loader_limit is an optional parameter used to control the maximum number of instances to load from the data source.  It is applied for each split separately.
69
    It is usually provided to the loader via the recipe (see standard.py)
70
    The loader can use this value to limit the amount of data downloaded from the source
71
    to reduce loading time.  However, this may not always be possible, so the
72
    loader may ignore this.  In any case, the recipe, will limit the number of instances in the returned
73
    stream, after load is complete.
74

75
    Args:
76
        loader_limit: Optional integer to specify a limit on the number of records to load.
77
        streaming: Bool indicating if streaming should be used.
78
        num_proc: Optional integer to specify the number of processes to use for parallel dataset loading. Adjust the value according to the number of CPU cores available and the specific needs of your processing task.
79
    """
80

81
    loader_limit: int = None
1✔
82
    streaming: bool = False
1✔
83
    num_proc: int = None
1✔
84

85
    # class level shared cache:
86
    _loader_cache = LRUCache(max_size=settings.loader_cache_size)
1✔
87

88
    def get_limit(self) -> int:
1✔
89
        if settings.global_loader_limit is not None and self.loader_limit is not None:
1✔
90
            return min(int(settings.global_loader_limit), self.loader_limit)
1✔
91
        if settings.global_loader_limit is not None:
1✔
92
            return int(settings.global_loader_limit)
1✔
93
        return self.loader_limit
×
94

95
    def get_limiter(self):
1✔
96
        if settings.global_loader_limit is not None and self.loader_limit is not None:
1✔
97
            if int(settings.global_loader_limit) > self.loader_limit:
1✔
98
                return f"{self.__class__.__name__}.loader_limit"
1✔
99
            return "unitxt.settings.global_loader_limit"
1✔
100
        if settings.global_loader_limit is not None:
1✔
101
            return "unitxt.settings.global_loader_limit"
1✔
102
        return f"{self.__class__.__name__}.loader_limit"
×
103

104
    def log_limited_loading(self):
1✔
105
        logger.info(
1✔
106
            f"\nLoading limited to {self.get_limit()} instances by setting {self.get_limiter()};"
107
        )
108

109
    def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
1✔
110
        if self.data_classification_policy is None:
1✔
111
            get_logger().warning(
×
112
                f"The {self.get_pretty_print_name()} loader does not set the `data_classification_policy`. "
113
                f"This may lead to sending of undesired data to external services.\n"
114
                f"Set it to a list of classification identifiers. \n"
115
                f"For example:\n"
116
                f"data_classification_policy = ['public']\n"
117
                f" or \n"
118
                f"data_classification_policy =['confidential','pii'])\n"
119
            )
120

121
        operator = Set(
1✔
122
            fields={"data_classification_policy": self.data_classification_policy}
123
        )
124
        return operator(multi_stream)
1✔
125

126
    def set_default_data_classification(
1✔
127
        self, default_data_classification_policy, additional_info
128
    ):
129
        if self.data_classification_policy is None:
1✔
130
            if additional_info is not None:
1✔
131
                logger.info(
1✔
132
                    f"{self.get_pretty_print_name()} sets 'data_classification_policy' to "
133
                    f"{default_data_classification_policy} by default {additional_info}.\n"
134
                    "To use a different value or remove this message, explicitly set the "
135
                    "`data_classification_policy` attribute of the loader.\n"
136
                )
137
            self.data_classification_policy = default_data_classification_policy
1✔
138

139
    @abstractmethod
1✔
140
    def load_iterables(self) -> Dict[str, Iterable]:
1✔
141
        pass
×
142

143
    def _maybe_set_classification_policy(self):
1✔
144
        pass
1✔
145

146
    def load_data(self) -> MultiStream:
1✔
147
        iterables = self.__class__._loader_cache.get(str(self), None)
1✔
148
        if iterables is None:
1✔
149
            iterables = self.load_iterables()
1✔
150
            self.__class__._loader_cache.max_size = settings.loader_cache_size
1✔
151
            self.__class__._loader_cache[str(self)] = iterables
1✔
152
        return MultiStream.from_iterables(iterables, copying=True)
1✔
153

154
    def process(self) -> MultiStream:
1✔
155
        self._maybe_set_classification_policy()
1✔
156
        return self.add_data_classification(self.load_data())
1✔
157

158
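# A worked example of how the two limits above interact (illustrative comment, with a
# placeholder CSV path): get_limit() returns the stricter of the global setting and the
# per-loader limit, and get_limiter() names whichever one is in effect.
#
#     loader = LoadCSV(files={"train": "train.csv"}, loader_limit=100)
#     # with unitxt.settings.global_loader_limit = 1000:
#     loader.get_limit()    # -> 100, i.e. min(1000, 100)
#     loader.get_limiter()  # -> "LoadCSV.loader_limit"
#     # with loader_limit left as None:
#     loader.get_limit()    # -> 1000
#     loader.get_limiter()  # -> "unitxt.settings.global_loader_limit"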

class LoadHF(Loader):
    """Loads datasets from the HuggingFace Hub.

    It supports loading with or without streaming,
    and it can filter datasets upon loading.

    Args:
        path:
            The path or identifier of the dataset on the HuggingFace Hub.
        name:
            An optional dataset name.
        data_dir:
            Optional directory to store downloaded data.
        split:
            Optional specification of which split to load.
        data_files:
            Optional specification of particular data files to load.
        revision:
            Optional. The revision of the dataset. Often the commit id. Use in case you want to set the dataset version.
        streaming (bool):
            Bool indicating if streaming should be used.
        filtering_lambda (str, optional):
            A lambda function for filtering the data after loading.
        num_proc (int, optional):
            Specifies the number of processes to use for parallel dataset loading.

    Example:
        Loading glue's mrpc dataset

        .. code-block:: python

            load_hf = LoadHF(path='glue', name='mrpc')
    """

    path: str
    name: Optional[str] = None
    data_dir: Optional[str] = None
    split: Optional[str] = None
    data_files: Optional[
        Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
    ] = None
    revision: Optional[str] = None
    streaming: bool = True
    filtering_lambda: Optional[str] = None
    num_proc: Optional[int] = None
    requirements_list: List[str] = OptionalField(default_factory=list)

    def verify(self):
        for requirement in self.requirements_list:
            if requirement not in self._requirements_list:
                self._requirements_list.append(requirement)
        super().verify()

    def filter_load(self, dataset):
        if not settings.allow_unverified_code:
            raise ValueError(
                f"{self.__class__.__name__} cannot use a filtering_lambda expression without setting unitxt.settings.allow_unverified_code=True or by setting the environment variable: UNITXT_ALLOW_UNVERIFIED_CODE=True."
            )
        logger.info(f"\nLoading filtered by: {self.filtering_lambda};")
        return dataset.filter(eval(self.filtering_lambda))

    def stream_dataset(self):
        with tempfile.TemporaryDirectory() as dir_to_be_deleted:
            if settings.disable_hf_datasets_cache and not self.streaming:
                cache_dir = dir_to_be_deleted
            else:
                cache_dir = None
            try:
                dataset = hf_load_dataset(
                    self.path,
                    name=self.name,
                    data_dir=self.data_dir,
                    data_files=self.data_files,
                    revision=self.revision,
                    streaming=self.streaming,
                    cache_dir=cache_dir,
                    split=self.split,
                    trust_remote_code=settings.allow_unverified_code,
                    num_proc=self.num_proc,
                )
            except ValueError as e:
                if "trust_remote_code" in str(e):
                    raise ValueError(
                        f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting the environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                    ) from e
                raise e

        if self.split is not None:
            dataset = {self.split: dataset}

        if self.filtering_lambda is not None:
            dataset = self.filter_load(dataset)

        return dataset

    def load_dataset(self):
        with tempfile.TemporaryDirectory() as dir_to_be_deleted:
            if settings.disable_hf_datasets_cache:
                cache_dir = dir_to_be_deleted
            else:
                cache_dir = None
            try:
                dataset = hf_load_dataset(
                    self.path,
                    name=self.name,
                    data_dir=self.data_dir,
                    data_files=self.data_files,
                    streaming=False,
                    keep_in_memory=True,
                    cache_dir=cache_dir,
                    split=self.split,
                    trust_remote_code=settings.allow_unverified_code,
                    num_proc=self.num_proc,
                )
            except ValueError as e:
                if "trust_remote_code" in str(e):
                    raise ValueError(
                        f"{self.__class__.__name__} cannot run remote code from huggingface without setting unitxt.settings.allow_unverified_code=True or by setting the environment variable: UNITXT_ALLOW_UNVERIFIED_CODE."
                    ) from e
                raise e

        if self.split is None:
            for split in dataset.keys():
                dataset[split] = dataset[split].to_iterable_dataset()
        else:
            dataset = {self.split: dataset.to_iterable_dataset()}

        return dataset

    def _maybe_set_classification_policy(self):
        if os.path.exists(self.path):
            self.set_default_data_classification(
                ["proprietary"], "when loading from local files"
            )
        else:
            self.set_default_data_classification(
                ["public"],
                None,  # No warning when loading from public hub
            )

    def load_iterables(self) -> IterableDatasetDict:
        try:
            dataset = self.stream_dataset()
        except NotImplementedError:
            # Streaming is not supported for zipped files, so we load without streaming.
            dataset = self.load_dataset()

        if self.filtering_lambda is not None:
            dataset = self.filter_load(dataset)

        limit = self.get_limit()
        if limit is not None:
            self.log_limited_loading()
            result = {}
            for split_name in dataset:
                try:
                    split_limit = min(limit, len(dataset[split_name]))
                except:
                    split_limit = limit
                result[split_name] = dataset[split_name].take(split_limit)

            return result

        return dataset
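# An illustrative LoadHF sketch that goes beyond the docstring example above: combining a
# loader_limit with a filtering_lambda. The lambda is passed as a string and is evaluated
# only when unitxt.settings.allow_unverified_code is enabled; the 'label' field is specific
# to glue/mrpc and is only an assumption for other datasets.
#
#     loader = LoadHF(
#         path="glue",
#         name="mrpc",
#         loader_limit=200,
#         filtering_lambda="lambda instance: instance['label'] == 1",
#     )
#     multi_stream = loader.process()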

class LoadCSV(Loader):
    """Loads data from CSV files.

    Supports streaming and can handle large files by loading them in chunks.

    Args:
        files (Dict[str, str]): A dictionary mapping names to file paths.
        chunksize: Size of the chunks to load at a time.
        loader_limit: Optional integer to specify a limit on the number of records to load.
        streaming: Bool indicating if streaming should be used.
        sep: String specifying the separator used in the CSV files.

    Example:
        Loading csv

        .. code-block:: python

            load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
    """

    files: Dict[str, str]
    chunksize: int = 1000
    loader_limit: Optional[int] = None
    streaming: bool = True
    sep: str = ","

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
            ["proprietary"], "when loading from local files"
        )

    def load_iterables(self):
        iterables = {}
        for split_name, file_path in self.files.items():
            if self.get_limit() is not None:
                self.log_limited_loading()
                iterables[split_name] = pd.read_csv(
                    file_path, nrows=self.get_limit(), sep=self.sep
                ).to_dict("records")
            else:
                iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict(
                    "records"
                )
        return iterables
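# An illustrative LoadCSV sketch with several splits and a non-default separator (the file
# paths are placeholders). With a loader_limit set, only the first rows of each file are
# read, via pandas' nrows argument in load_iterables above.
#
#     loader = LoadCSV(
#         files={"train": "data/train.tsv", "validation": "data/dev.tsv"},
#         sep="\t",
#         loader_limit=1000,
#     )
#     multi_stream = loader.process()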

class LoadFromSklearn(Loader):
    """Loads datasets from the sklearn library.

    This loader does not support streaming and is intended for use with sklearn's dataset fetch functions.

    Args:
        dataset_name: The name of the sklearn dataset to fetch.
        splits: A list of data splits to load, e.g., ['train', 'test'].

    Example:
        Loading from sklearn

        .. code-block:: python

            load_sklearn = LoadFromSklearn(dataset_name='iris', splits=['train', 'test'])
    """

    dataset_name: str
    splits: List[str] = ["train", "test"]

    _requirements_list: List[str] = ["scikit-learn", "pandas"]

    data_classification_policy = ["public"]

    def verify(self):
        super().verify()

        if self.streaming:
            raise NotImplementedError("LoadFromSklearn cannot load with streaming.")

    def prepare(self):
        super().prepare()
        from sklearn import datasets as sklearn_datasets

        self.downloader = getattr(sklearn_datasets, f"fetch_{self.dataset_name}")

    def load_iterables(self):
        with TemporaryDirectory() as temp_directory:
            for split in self.splits:
                split_data = self.downloader(subset=split)
                targets = [split_data["target_names"][t] for t in split_data["target"]]
                df = pd.DataFrame([split_data["data"], targets]).T
                df.columns = ["data", "target"]
                df.to_csv(os.path.join(temp_directory, f"{split}.csv"), index=None)
            return hf_load_dataset(temp_directory, streaming=False)
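# An illustrative LoadFromSklearn sketch: prepare() resolves the downloader as
# sklearn.datasets.fetch_<dataset_name> and calls it with subset=<split>, so a dataset
# exposing that interface, such as 20newsgroups, fits this loader (the choice of fetcher
# is an assumption, not taken from the file itself):
#
#     loader = LoadFromSklearn(dataset_name="20newsgroups", splits=["train", "test"])
#     multi_stream = loader.process()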

class MissingKaggleCredentialsError(ValueError):
    pass


class LoadFromKaggle(Loader):
    """Loads datasets from Kaggle.

    Requires Kaggle API credentials and does not support streaming.

    Args:
        url: URL to the Kaggle dataset.

    Example:
        Loading from kaggle

        .. code-block:: python

            load_kaggle = LoadFromKaggle(url='kaggle.com/dataset/example')
    """

    url: str

    _requirements_list: List[str] = ["opendatasets"]
    data_classification_policy = ["public"]

    def verify(self):
        super().verify()
        if not os.path.isfile("kaggle.json"):
            raise MissingKaggleCredentialsError(
                "Please obtain Kaggle credentials (https://christianjmills.com/posts/kaggle-obtain-api-key-tutorial/) and save them to a local ./kaggle.json file"
            )

        if self.streaming:
            raise NotImplementedError("LoadFromKaggle cannot load with streaming.")

    def prepare(self):
        super().prepare()
        from opendatasets import download

        self.downloader = download

    def load_iterables(self):
        with TemporaryDirectory() as temp_directory:
            self.downloader(self.url, temp_directory)
            return hf_load_dataset(temp_directory, streaming=False)

class LoadFromIBMCloud(Loader):
    """Loads data from IBM Cloud Object Storage.

    Does not support streaming and requires AWS-style access keys.

    data_files can be either:

    1. a list of file names, where the split of each file is determined by the file name pattern
    2. Mapping: split -> file_name, e.g. {"test": "test.json", "train": "train.json"}
    3. Mapping: split -> file_names, e.g. {"test": ["test1.json", "test2.json"], "train": ["train.json"]}

    Args:
        endpoint_url_env:
            Environment variable name for the IBM Cloud endpoint URL.
        aws_access_key_id_env:
            Environment variable name for the AWS access key ID.
        aws_secret_access_key_env:
            Environment variable name for the AWS secret access key.
        bucket_name:
            Name of the S3 bucket from which to load data.
        data_dir:
            Optional directory path within the bucket.
        data_files:
            Union type allowing either a list of file names or a mapping of splits to file names.
        data_field:
            The dataset key for a nested JSON file, i.e. when multiple datasets are nested in the same file.
        caching (bool):
            Bool indicating if caching is enabled to avoid re-downloading data.

    Example:
        Loading from IBM Cloud

        .. code-block:: python

            load_ibm_cloud = LoadFromIBMCloud(
                endpoint_url_env='IBM_CLOUD_ENDPOINT',
                aws_access_key_id_env='IBM_AWS_ACCESS_KEY_ID',
                aws_secret_access_key_env='IBM_AWS_SECRET_ACCESS_KEY',
                bucket_name='my-bucket'
            )
            multi_stream = load_ibm_cloud.process()
    """

    endpoint_url_env: str
    aws_access_key_id_env: str
    aws_secret_access_key_env: str
    bucket_name: str
    data_dir: str = None

    data_files: Union[Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
    data_field: str = None
    caching: bool = True
    data_classification_policy = ["proprietary"]

    _requirements_list: List[str] = ["ibm-cos-sdk"]

    def _download_from_cos(self, cos, bucket_name, item_name, local_file):
        logger.info(f"Downloading {item_name} from {bucket_name} COS")
        try:
            response = cos.Object(bucket_name, item_name).get()
            size = response["ContentLength"]
            body = response["Body"]
        except Exception as e:
            raise Exception(
                f"Unable to access {item_name} in {bucket_name} in COS", e
            ) from e

        if self.get_limit() is not None:
            if item_name.endswith(".jsonl"):
                first_lines = list(
                    itertools.islice(body.iter_lines(), self.get_limit())
                )
                with open(local_file, "wb") as downloaded_file:
                    for line in first_lines:
                        downloaded_file.write(line)
                        downloaded_file.write(b"\n")
                logger.info(
                    f"\nDownload successful limited to {self.get_limit()} lines"
                )
                return

        progress_bar = tqdm(total=size, unit="iB", unit_scale=True)

        def upload_progress(chunk):
            progress_bar.update(chunk)

        try:
            cos.Bucket(bucket_name).download_file(
                item_name, local_file, Callback=upload_progress
            )
            logger.info("\nDownload Successful")
        except Exception as e:
            raise Exception(
                f"Unable to download {item_name} in {bucket_name}", e
            ) from e

    def prepare(self):
        super().prepare()
        self.endpoint_url = os.getenv(self.endpoint_url_env)
        self.aws_access_key_id = os.getenv(self.aws_access_key_id_env)
        self.aws_secret_access_key = os.getenv(self.aws_secret_access_key_env)
        root_dir = os.getenv("UNITXT_IBM_COS_CACHE", None) or os.getcwd()
        self.cache_dir = os.path.join(root_dir, "ibmcos_datasets")

        if not os.path.exists(self.cache_dir):
            Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
        self.verified = False

    def lazy_verify(self):
        super().verify()
        assert (
            self.endpoint_url is not None
        ), f"Please set the {self.endpoint_url_env} environment variable"
        assert (
            self.aws_access_key_id is not None
        ), f"Please set the {self.aws_access_key_id_env} environment variable"
        assert (
            self.aws_secret_access_key is not None
        ), f"Please set the {self.aws_secret_access_key_env} environment variable"
        if self.streaming:
            raise NotImplementedError("LoadFromIBMCloud cannot load with streaming.")

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
            ["proprietary"], "when loading from IBM COS"
        )

    def load_iterables(self):
        if not self.verified:
            self.lazy_verify()
            self.verified = True
        import ibm_boto3

        cos = ibm_boto3.resource(
            "s3",
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            endpoint_url=self.endpoint_url,
        )
        local_dir = os.path.join(
            self.cache_dir,
            self.bucket_name,
            self.data_dir or "",  # data_dir can be None
            f"loader_limit_{self.get_limit()}",
        )
        if not os.path.exists(local_dir):
            Path(local_dir).mkdir(parents=True, exist_ok=True)
        if isinstance(self.data_files, Mapping):
            data_files_names = list(self.data_files.values())
            if not isinstance(data_files_names[0], str):
                data_files_names = list(itertools.chain(*data_files_names))
        else:
            data_files_names = self.data_files

        for data_file in data_files_names:
            local_file = os.path.join(local_dir, data_file)
            if not self.caching or not os.path.exists(local_file):
                # Build the object key based on the parameters. A slash character is not
                # allowed to be part of an object key in IBM COS.
                object_key = (
                    self.data_dir + "/" + data_file
                    if self.data_dir is not None
                    else data_file
                )
                with tempfile.NamedTemporaryFile() as temp_file:
                    # Download to a temporary file in the same file partition, and then do an atomic move.
                    self._download_from_cos(
                        cos,
                        self.bucket_name,
                        object_key,
                        local_dir + "/" + os.path.basename(temp_file.name),
                    )
                    os.renames(
                        local_dir + "/" + os.path.basename(temp_file.name),
                        local_dir + "/" + data_file,
                    )

        if isinstance(self.data_files, list):
            dataset = hf_load_dataset(local_dir, streaming=False, field=self.data_field)
        else:
            dataset = hf_load_dataset(
                local_dir,
                streaming=False,
                data_files=self.data_files,
                field=self.data_field,
            )

        return dataset
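# An illustrative LoadFromIBMCloud sketch: the three *_env arguments name environment
# variables (resolved in prepare() via os.getenv), so those variables must be exported
# before process() is called. The variable names, bucket, and file names below are
# placeholders.
#
#     # export IBM_CLOUD_ENDPOINT=..., IBM_AWS_ACCESS_KEY_ID=..., IBM_AWS_SECRET_ACCESS_KEY=...
#     loader = LoadFromIBMCloud(
#         endpoint_url_env="IBM_CLOUD_ENDPOINT",
#         aws_access_key_id_env="IBM_AWS_ACCESS_KEY_ID",
#         aws_secret_access_key_env="IBM_AWS_SECRET_ACCESS_KEY",
#         bucket_name="my-bucket",
#         data_dir="my-dataset",
#         data_files={"train": "train.jsonl", "test": "test.jsonl"},
#     )
#     multi_stream = loader.process()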

class MultipleSourceLoader(Loader):
    """Allows loading data from multiple sources, potentially mixing different types of loaders.

    Args:
        sources: A list of loaders that will be combined to form a unified dataset.

    Examples:
        1) Loading the train split from the HuggingFace Hub and the test set from a local file:

        .. code-block:: python

            MultipleSourceLoader(sources = [ LoadHF(path="public/data", split="train"), LoadCSV(files={"test": "mytest.csv"}) ])

        2) Loading a test set combined from two files:

        .. code-block:: python

            MultipleSourceLoader(sources = [ LoadCSV(files={"test": "mytest1.csv"}), LoadCSV(files={"test": "mytest2.csv"}) ])
    """

    sources: List[Loader]

    # MultipleSourceLoader uses the data classification from its source loaders,
    # so it is only added here if explicitly requested as an override.
    def add_data_classification(self, multi_stream: MultiStream) -> MultiStream:
        if self.data_classification_policy is None:
            return multi_stream
        return super().add_data_classification(multi_stream)

    def load_iterables(self):
        pass

    def load_data(self):
        return FixedFusion(
            subsets=self.sources, max_instances_per_subset=self.get_limit()
        ).process()
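# An illustrative MultipleSourceLoader sketch (paths are placeholders): load_data() fuses
# the sources with FixedFusion, so get_limit() caps the number of instances taken from each
# source rather than from the combined stream.
#
#     loader = MultipleSourceLoader(
#         sources=[
#             LoadHF(path="glue", name="mrpc", split="train"),
#             LoadCSV(files={"test": "data/mytest.csv"}),
#         ],
#         loader_limit=500,
#     )
#     multi_stream = loader.process()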

class LoadFromDictionary(Loader):
    """Allows loading data from a dictionary of constants.

    The loader can be used, for example, when debugging or working with small datasets.

    Args:
        data (Dict[str, List[Dict[str, Any]]]): a dictionary of constants from which the data will be loaded

    Example:
        Loading dictionary

        .. code-block:: python

            data = {
                "train": [{"input": "SomeInput1", "output": "SomeResult1"},
                          {"input": "SomeInput2", "output": "SomeResult2"}],
                "test":  [{"input": "SomeInput3", "output": "SomeResult3"},
                          {"input": "SomeInput4", "output": "SomeResult4"}]
            }
            loader = LoadFromDictionary(data=data)
    """

    data: Dict[str, List[Dict[str, Any]]]

    def verify(self):
        super().verify()
        if not isoftype(self.data, Dict[str, List[Dict[str, Any]]]):
            raise ValueError(
                f"Passed data to LoadFromDictionary is not of type Dict[str, List[Dict[str, Any]]].\n"
                f"Expected data should map between split name and list of instances.\n"
                f"Received value: {self.data}\n"
            )
        for split in self.data.keys():
            if len(self.data[split]) == 0:
                raise ValueError(f"Split {split} has no instances.")
            first_instance = self.data[split][0]
            for instance in self.data[split]:
                if instance.keys() != first_instance.keys():
                    raise ValueError(
                        f"Not all instances in split '{split}' have the same fields.\n"
                        f"Instance {instance} has fields different from {first_instance}."
                    )

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
            ["proprietary"], "when loading from python dictionary"
        )

    def load_iterables(self) -> MultiStream:
        return self.data

class LoadFromHFSpace(LoadHF):
    """Used to load data from HuggingFace Spaces.

    The loader first tries to download all files specified in the 'data_files' parameter
    from the given space and then reads them as a HuggingFace Dataset.

    Args:
        space_name (str):
            Name of the HuggingFace Space to be accessed.
        data_files (str | Sequence[str] | Mapping[str, str | Sequence[str]]):
            Relative paths to files within a given repository. If given as a mapping,
            paths should be values, while keys should represent the type of respective files
            (training, testing etc.).
        path (str, optional):
            Absolute path to a directory where data should be downloaded.
        revision (str, optional):
            ID of a Git branch or commit to be used. By default, it is set to None,
            thus data is downloaded from the main branch of the accessed repository.
        use_token (bool, optional):
            Whether a token is used for authentication when accessing
            the HuggingFace Space. If necessary, the token is read from the HuggingFace
            config folder.
        token_env (str, optional):
            Key of an env variable whose value will be used for
            authentication when accessing the HuggingFace Space - if necessary.

    Example:
        Loading from a HuggingFace Space

        .. code-block:: python

            loader = LoadFromHFSpace(
                space_name="lmsys/mt-bench",
                data_files={
                    "train": [
                        "data/mt_bench/model_answer/gpt-3.5-turbo.jsonl",
                        "data/mt_bench/model_answer/gpt-4.jsonl",
                    ],
                    "test": "data/mt_bench/model_answer/tulu-30b.jsonl",
                },
            )
    """

    space_name: str
    data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
    path: Optional[str] = None
    revision: Optional[str] = None
    use_token: Optional[bool] = None
    token_env: Optional[str] = None
    requirements_list: List[str] = ["huggingface_hub"]

    def _get_token(self) -> Optional[Union[bool, str]]:
        if self.token_env:
            token = os.getenv(self.token_env)
            if not token:
                get_logger().warning(
                    f"The 'token_env' parameter was specified as '{self.token_env}', "
                    f"however, no environment variable under such a name was found. "
                    f"Therefore, the loader will not use any tokens for authentication."
                )
            return token
        return self.use_token

    def _download_file_from_space(self, filename: str) -> str:
        from huggingface_hub import hf_hub_download
        from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

        token = self._get_token()

        try:
            file_path = hf_hub_download(
                repo_id=self.space_name,
                filename=filename,
                repo_type="space",
                token=token,
                revision=self.revision,
                local_dir=self.path,
            )
        except EntryNotFoundError as e:
            raise ValueError(
                f"The file '{filename}' was not found in the space '{self.space_name}'. "
                f"Please check if the filename is correct, or if it exists in that "
                f"Huggingface space."
            ) from e
        except RepositoryNotFoundError as e:
            raise ValueError(
                f"The Huggingface space '{self.space_name}' was not found. "
                f"Please check if the name is correct and you have access to the space."
            ) from e

        return file_path

    def _download_data(self) -> str:
        if isinstance(self.data_files, str):
            data_files = [self.data_files]
        elif isinstance(self.data_files, Mapping):
            data_files = list(self.data_files.values())
        else:
            data_files = self.data_files

        dir_paths_list = []
        for files in data_files:
            if isinstance(files, str):
                files = [files]

            paths = [self._download_file_from_space(file) for file in files]
            dir_paths = [
                path.replace(file_url, "") for path, file_url in zip(paths, files)
            ]
            dir_paths_list.extend(dir_paths)

        # All files - within the same space - are downloaded into the same base directory:
        assert len(set(dir_paths_list)) == 1

        return f"{dir_paths_list.pop()}"

    @staticmethod
    def _is_wildcard(path: str) -> bool:
        wildcard_characters = ["*", "?", "[", "]"]
        return any(char in path for char in wildcard_characters)

    def _get_file_list_from_wildcard_path(
        self, pattern: str, repo_files: List
    ) -> List[str]:
        if self._is_wildcard(pattern):
            return fnmatch.filter(repo_files, pattern)
        return [pattern]

    def _map_wildcard_path_to_full_paths(self):
        api = HfApi()
        repo_files = api.list_repo_files(
            self.space_name, repo_type="space", revision=self.revision
        )
        if isinstance(self.data_files, str):
            self.data_files = self._get_file_list_from_wildcard_path(
                self.data_files, repo_files
            )
        elif isinstance(self.data_files, Mapping):
            new_mapping = {}
            for k, v in self.data_files.items():
                if isinstance(v, list):
                    assert all(isinstance(s, str) for s in v)
                    new_mapping[k] = [
                        file
                        for p in v
                        for file in self._get_file_list_from_wildcard_path(
                            p, repo_files
                        )
                    ]
                elif isinstance(v, str):
                    new_mapping[k] = self._get_file_list_from_wildcard_path(
                        v, repo_files
                    )
                else:
                    raise NotImplementedError(
                        f"Loader does not support input 'data_files' of type Mapping[{type(v)}]"
                    )

            self.data_files = new_mapping
        elif isinstance(self.data_files, list):
            assert all(isinstance(s, str) for s in self.data_files)
            self.data_files = [
                file
                for p in self.data_files
                for file in self._get_file_list_from_wildcard_path(p, repo_files)
            ]
        else:
            raise NotImplementedError(
                f"Loader does not support input 'data_files' of type {type(self.data_files)}"
            )

    def _maybe_set_classification_policy(self):
        self.set_default_data_classification(
            ["public"], "when loading from Huggingface spaces"
        )

    def load_data(self):
        self._map_wildcard_path_to_full_paths()
        self.path = self._download_data()
        return super().load_data()
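# An illustrative LoadFromHFSpace sketch: because _map_wildcard_path_to_full_paths expands
# fnmatch-style patterns against the space's file listing, data_files may use wildcards.
# The space name is reused from the docstring example above; the pattern itself is a
# placeholder.
#
#     loader = LoadFromHFSpace(
#         space_name="lmsys/mt-bench",
#         data_files={"train": "data/mt_bench/model_answer/*.jsonl"},
#     )
#     multi_stream = loader.process()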