• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 15883725135

25 Jun 2025 06:03PM UTC coverage: 81.096% (+1.3%) from 79.803%
15883725135

push

github

web-flow
Use full artifact representation as the cache key for the dataset (#1644)

* Use elaborated cache key and use it for filelock

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Add documentation

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Another try

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Remove file lock

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Format

Signed-off-by: elronbandel <elronbandel@gmail.com>

---------

Signed-off-by: elronbandel <elronbandel@gmail.com>

1538 of 1909 branches covered (80.57%)

Branch coverage included in aggregate %.

10461 of 12887 relevant lines covered (81.17%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.0
src/unitxt/api.py
1
import hashlib
1✔
2
import inspect
1✔
3
import json
1✔
4
from datetime import datetime
1✔
5
from functools import lru_cache
1✔
6
from typing import Any, Dict, List, Optional, Union
1✔
7

8
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
1✔
9
from datasets.exceptions import DatasetGenerationError
1✔
10

11
from .artifact import fetch_artifact
1✔
12
from .benchmark import Benchmark
1✔
13
from .card import TaskCard
1✔
14
from .dataclass import to_dict
1✔
15
from .dataset_utils import get_dataset_artifact
1✔
16
from .error_utils import UnitxtError
1✔
17
from .inference import (
1✔
18
    InferenceEngine,
19
    LogProbInferenceEngine,
20
    OptionSelectingByLogProbsInferenceEngine,
21
)
22
from .loaders import LoadFromDictionary
1✔
23
from .logging_utils import get_logger
1✔
24
from .metric_utils import EvaluationResults, _compute, _inference_post_process
1✔
25
from .operator import SourceOperator
1✔
26
from .schema import loads_batch
1✔
27
from .settings_utils import get_constants, get_settings
1✔
28
from .standard import DatasetRecipe
1✔
29
from .task import Task
1✔
30

31
logger = get_logger()
32
constants = get_constants()
1✔
33
settings = get_settings()
1✔
34

35

36
def short_hex_hash(value, length=8):
    """Return a short, stable hex fingerprint of *value*.

    Args:
        value: String to hash.
        length: Number of leading hex characters of the SHA-256 digest to
            keep (default 8).

    Returns:
        str: The first *length* characters of the 64-character hex digest.
    """
    digest = hashlib.sha256(value.encode())
    return digest.hexdigest()[:length]
39

40

41
def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    """Resolve a textual dataset query into a :class:`DatasetRecipe`.

    First tries to fetch the query as a catalog artifact; if that fails,
    falls back to building the dataset artifact directly from the query.

    Args:
        dataset_query: Catalog artifact name or recipe query string.
        overwrite_kwargs: Optional parameter overrides applied to the artifact.

    Returns:
        DatasetRecipe: The resolved recipe.
    """
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    # A bare 'except:' would also swallow KeyboardInterrupt/SystemExit;
    # catch Exception so user interrupts still propagate.
    except Exception:
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream
53

54

55
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    """Build a :class:`DatasetRecipe` from keyword parameters.

    Args:
        dataset_params: Mapping of DatasetRecipe field names to values.

    Returns:
        DatasetRecipe: Recipe constructed from the given parameters.

    Raises:
        ValueError: If a key is not a DatasetRecipe attribute.
    """
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        # Raise an explicit exception instead of using 'assert' so the
        # validation is not stripped when Python runs with -O.
        if param not in recipe_attributes:
            raise ValueError(
                f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
                f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
            )
    return DatasetRecipe(**dataset_params)
63

64

65
def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
1✔
66
    if dataset_query and dataset_args:
×
67
        raise ValueError(
68
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
69
            "If you want to load dataset from a card in local catalog, use query only. "
70
            "Otherwise, use key-worded arguments only to specify properties of dataset."
71
        )
72

73
    if dataset_query:
×
74
        if not isinstance(dataset_query, str):
×
75
            raise ValueError(
76
                f"If specified, 'dataset_query' must be a string, however, "
77
                f"'{dataset_query}' was provided instead, which is of type "
78
                f"'{type(dataset_query)}'."
79
            )
80

81
    if not dataset_query and not dataset_args:
×
82
        raise ValueError(
83
            "Either 'dataset_query' or key-worded arguments must be provided."
84
        )
85

86

87
def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    """Resolve the input into a :class:`DatasetRecipe`.

    Accepts an already-built recipe/benchmark (returned unchanged), a
    textual catalog query, or keyword arguments describing the recipe.

    Raises:
        UnitxtError: If neither a query nor keyword arguments are given.
    """
    # Already a concrete recipe/benchmark object - nothing to resolve.
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        return _get_recipe_from_query(dataset_query, kwargs)

    if kwargs:
        return _get_recipe_from_dict(kwargs)

    raise UnitxtError(
        "Specify either dataset recipe string artifact name or recipe args."
    )
103

104

105
def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task:  The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set : required list of instances
        train_set : optional train_set
        validation_set: optional validation set
        split: optional one split to choose
        data_classification_policy: data_classification_policy
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Raises:
        UnitxtError: If no 'template' is passed and the task has no
            'default_template'.

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        # Raise the project's error type (was a bare Exception) for
        # consistency with load_recipe/evaluate; UnitxtError subclasses
        # Exception, so existing broad handlers keep working.
        raise UnitxtError(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)
151

152

153
def object_to_str_without_addresses(obj):
    """Stringify *obj*, removing the trailing memory-address segment if present.

    Default object reprs look like ``<pkg.MyClass object at 0x7f8b9d4d6e20>``;
    the address varies between executions, so it is cut off to keep the
    representation stable and comparable across runs.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: ``str(obj)`` with any trailing `` at 0x...`` part replaced by ``>``.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    text = str(obj)
    marker = " at 0x"
    if marker in text:
        # Keep everything before the address and re-close the repr bracket.
        text = text.split(marker)[0] + ">"
    return text
181

182

183
def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    """Materialize a SourceOperator into a Huggingface dataset.

    Args:
        source: Operator producing the data streams.
        split: Optional name of a single split to keep.
        use_cache: When True, reuse the HF cache entry keyed by the recipe
            signature instead of forcing a re-generation.
        streaming: When True, return a streaming dataset.
    """
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source. Memory addresses are
    # stripped from the representation so the signature (and thus the HF
    # cache key below) is stable across runs.
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    # Obtain data stream from the source
    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,  # Dictate the cache name
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare(
                verification_mode="no_checks",
                download_mode=None if use_cache else "force_redownload",
            )
        except DatasetGenerationError as e:
            # Surface the underlying error instead of the HF wrapper.
            if e.__cause__:
                raise e.__cause__ from None
            if e.__context__:
                raise e.__context__ from None
            raise

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        # Bugfix: 'raise e.__cause__' raises a TypeError when __cause__ is
        # None; unwrap only when a cause exists, otherwise re-raise as-is.
        if e.__cause__:
            raise e.__cause__ from None
        raise
230

231

232
def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, then the textual recipe is loaded with the key word args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    # Collect this call's own named arguments (plus the extra kwargs) via
    # frame introspection, so the metadata records exactly how the dataset
    # was requested without listing each parameter by hand.
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    # A dict of splits is returned when no single split was requested;
    # each split gets its own copy so later mutation of one split's
    # description does not leak into the others.
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset
303

304

305
def fill_metadata(**kwargs):
    """Return a copy of *kwargs* augmented with the unitxt version and a creation timestamp."""
    metadata = dict(kwargs)
    metadata["unitxt_version"] = get_constants().version
    # strftime yields microseconds; dropping the last 3 digits keeps
    # millisecond precision.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    metadata["creation_time"] = timestamp[:-3]
    return metadata
310

311

312
def evaluate(
    predictions,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    """Compute evaluation metrics for *predictions* against *dataset*.

    Args:
        predictions: Model outputs to score.
        dataset: The references/instances to score against.
        data: Deprecated alias for 'dataset'; takes precedence when given.
        calc_confidence_intervals: Whether to compute confidence intervals.

    Returns:
        EvaluationResults with provenance metadata attached.

    Raises:
        UnitxtError: If neither 'dataset' nor 'data' is provided.
    """
    if data is not None:
        dataset = data  # for backward compatibility
    if dataset is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    # Attach provenance metadata when the inputs carry it.
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    # Millisecond-precision timestamp (microseconds truncated).
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    evaluation_result.metadata["creation_time"] = timestamp[:-3]
    return evaluation_result
335

336

337
def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Delegate inference post-processing of *predictions* against the references in *data*."""
    return _inference_post_process(references=data, predictions=predictions)
339

340

341
@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    """Resolve the recipe once per (query, kwargs) and return its bound 'produce'.

    lru_cache keys on the query string plus the keyword arguments, so
    repeated produce() calls with identical parameters skip recipe
    re-construction.
    """
    recipe = load_recipe(dataset_query, **kwargs)
    return recipe.produce
344

345

346
def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    """Run the recipe's processing on ad-hoc instance(s).

    A single instance yields a single processed dict; a list of instances
    yields a Dataset.
    """
    wrapped_in_list = isinstance(instance_or_instances, list)
    instances = (
        instance_or_instances if wrapped_in_list else [instance_or_instances]
    )
    processed = _get_produce_with_cache(dataset_query, **kwargs)(instances)
    if not wrapped_in_list:
        # Mirror the scalar input: unwrap the single processed instance.
        return processed[0]
    return Dataset.from_list(processed).with_transform(loads_batch)
356

357

358
def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    """Produce instances with a recipe and run inference on them with *engine*.

    Args:
        instance_or_instances: A single instance dict or a list of them,
            forwarded to produce().
        engine: Inference engine (or an artifact reference resolvable by
            fetch_artifact) used for prediction.
        dataset_query: Optional recipe query forwarded to produce().
        return_data: If True, return the dataset with prediction columns
            added instead of just the predictions.
        return_log_probs: If True, request log-probabilities; requires the
            engine to be a LogProbInferenceEngine.
        return_meta_data: If True, engine outputs are objects carrying a
            '.prediction' attribute plus per-prediction metadata, rather
            than bare predictions.
        previous_messages: Optional per-instance message history prepended
            to each example's "source".
        **kwargs: Extra recipe arguments forwarded to produce().

    Returns:
        The post-processed predictions, or (when return_data=True) the
        dataset with "prediction", "raw_prediction" and optionally
        "infer_meta_data" columns added.

    Raises:
        NotImplementedError: If return_log_probs is set but the engine does
            not support log-probabilities.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        # NOTE(review): assumes example["source"] supports '+' with a list
        # of messages (i.e. is itself a list) — confirm with recipe output.
        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        # With metadata, outputs are wrapper objects; extract the raw value.
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        # Log-prob predictions are serialized to JSON strings.
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            # The prediction itself is added as its own column below; keep
            # only the remaining metadata fields here.
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
411

412

413
def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    """Produce instances with a recipe and pick options via the engine's log-probs.

    Args:
        instance_or_instances: A single instance dict or a list of them,
            forwarded to produce().
        engine: Engine (or artifact reference) whose select() chooses among
            the instances' options.
        dataset_query: Optional recipe query forwarded to produce().
        return_data: If True, return the dataset with a "prediction" column
            added instead of just the predictions.
        previous_messages: Optional per-instance message history prepended
            to each example's "source".
        **kwargs: Extra recipe arguments forwarded to produce().

    Returns:
        The engine's selections, or (when return_data=True) the dataset
        with a "prediction" column added.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        # NOTE(review): assumes example["source"] supports '+' with a list
        # of messages (i.e. is itself a list) — confirm with recipe output.
        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc