IBM / unitxt / build 16589358505

29 Jul 2025 07:17AM UTC. Coverage: 81.227% (+0.005% from 81.222%)

Pull Request #1880: For load_dataset, use_cache default value is taken from settings
Merge 2d0af54e7 into 83063f920 (github / web-flow)

1557 of 1929 branches covered (80.72%); branch coverage is included in the aggregate %
10597 of 13034 relevant lines covered (81.3%)
0.81 hits per line
Source file: src/unitxt/api.py (80.88% covered)
import hashlib
import inspect
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
from .utils import lru_cache_decorator

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    except:
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError(
            "Specify either dataset recipe string artifact name or recipe args."
        )

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task:  The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set : required list of instances
        train_set : optional train_set
        validation_set: optional validation set
        split: optional one split to choose
        data_classification_policy: data_classification_policy
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)


def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    # Obtain data stream from the source
    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,  # Dictate the cache name
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare(
                verification_mode="no_checks",
                download_mode=None if use_cache else "force_redownload",
            )
        except DatasetGenerationError as e:
            if e.__cause__:
                raise e.__cause__ from None
            if e.__context__:
                raise e.__context__ from None
            raise

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, then the textual recipe is loaded with the key word args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False, the returned dataset is not cached.
            If set to None, the value of this parameter will be determined by setting.dataset_cache_default (default is False).
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    if use_cache is None:
        use_cache = settings.dataset_cache_default

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions: Optional[List[str]] = None,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache_decorator(max_size=128)
def _get_recipe_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs)


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    dataset_recipe = _get_recipe_with_cache(dataset_query, **kwargs)
    result = dataset_recipe.produce(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
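
For context, a small end-to-end sketch that uses only helpers defined in this file (load_dataset, post_process, evaluate); the catalog query is the one from the load_dataset docstring, and the predictions list is a placeholder standing in for real model outputs:

    from unitxt.api import evaluate, load_dataset, post_process

    dataset = load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
        use_cache=True,  # cache the prepared dataset on local disk
    )

    raw_predictions = ["entailment"] * len(dataset)  # placeholder model outputs
    predictions = post_process(raw_predictions, dataset)  # apply task post-processors
    results = evaluate(predictions=predictions, dataset=dataset)
    print(results.metadata["creation_time"])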