• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 26152825991

20 May 2026 09:08AM UTC coverage: 80.863% (-0.04%) from 80.903%
26152825991

Pull #1966

github

web-flow
Merge 914761465 into 37e11e0f6
Pull Request #1966: fix: CI compatibility fixes (HF_TOKEN, arena-hard migration, datasets 4.8.5)

1607 of 2007 branches covered (80.07%)

Branch coverage included in aggregate %.

10955 of 13528 relevant lines covered (80.98%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.88
src/unitxt/api.py
1
import hashlib
1✔
2
import inspect
1✔
3
import json
1✔
4
from datetime import datetime
1✔
5
from typing import Any, Dict, List, Optional, Union
1✔
6

7
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
1✔
8
from datasets.exceptions import DatasetGenerationError
1✔
9

10
from .artifact import fetch_artifact
1✔
11
from .benchmark import Benchmark
1✔
12
from .card import TaskCard
1✔
13
from .dataclass import to_dict
1✔
14
from .dataset_utils import get_dataset_artifact
1✔
15
from .error_utils import UnitxtError
1✔
16
from .inference import (
1✔
17
    InferenceEngine,
18
    LogProbInferenceEngine,
19
    OptionSelectingByLogProbsInferenceEngine,
20
)
21
from .loaders import LoadFromDictionary
1✔
22
from .logging_utils import get_logger
1✔
23
from .metric_utils import EvaluationResults, _compute, _inference_post_process
1✔
24
from .operator import SourceOperator
1✔
25
from .schema import loads_batch
1✔
26
from .settings_utils import get_constants, get_settings
1✔
27
from .standard import DatasetRecipe
1✔
28
from .task import Task
1✔
29
from .utils import lru_cache_decorator
1✔
30

31
# Module-level logger, obtained from the package's logging utilities.
logger = get_logger()
32
# Package constants; ``constants.version`` is passed as the dataset builder version in ``_source_to_dataset``.
constants = get_constants()
1✔
33
# Package settings; ``settings.dataset_cache_default`` is the fallback for ``use_cache`` in ``load_dataset``.
settings = get_settings()
1✔
34

35

36
def short_hex_hash(value, length=8):
    """Return the leading ``length`` hex characters of the SHA-256 digest of ``value``.

    Args:
        value: Text to hash; it is encoded with ``str.encode()`` before hashing.
        length: How many characters of the 64-character hex digest to keep.

    Returns:
        str: The truncated hex digest.
    """
    hasher = hashlib.sha256()
    hasher.update(value.encode())
    return hasher.hexdigest()[:length]
1✔
39

40

41
def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    """Resolve a textual dataset query into a recipe artifact.

    The query is first resolved with ``fetch_artifact``. If that raises, the
    same query is handed to ``get_dataset_artifact`` instead.

    Args:
        dataset_query: The query string identifying the dataset/recipe.
        overwrite_kwargs: Optional values forwarded as ``overwrite_kwargs`` to
            whichever resolver is used.

    Returns:
        The artifact produced by the resolver that succeeded.
    """
    try:
        recipe, _ = fetch_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
    except Exception:
        # Only ordinary errors trigger the fallback; KeyboardInterrupt and
        # SystemExit must propagate rather than be swallowed by a bare except.
        recipe = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return recipe
1✔
53

54

55
def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    """Build a ``DatasetRecipe`` from keyword parameters after validating their names.

    Every key must be one of the fields listed in
    ``DatasetRecipe.__dict__["__fields__"]``.

    Args:
        dataset_params: Mapping of recipe attribute names to values.

    Returns:
        A ``DatasetRecipe`` constructed from ``dataset_params``.

    Raises:
        AssertionError: If a key is not an attribute of ``DatasetRecipe``.
    """
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params:
        if param not in recipe_attributes:
            # Raised explicitly instead of via ``assert`` so the validation is
            # not stripped when Python runs with -O. The exception type stays
            # AssertionError for callers that already catch it.
            raise AssertionError(
                f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
                f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
            )
    return DatasetRecipe(**dataset_params)
1✔
63

64

65
def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
1✔
66
    if dataset_query and dataset_args:
×
67
        raise ValueError(
68
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
69
            "If you want to load dataset from a card in local catalog, use query only. "
70
            "Otherwise, use key-worded arguments only to specify properties of dataset."
71
        )
72

73
    if dataset_query:
×
74
        if not isinstance(dataset_query, str):
×
75
            raise ValueError(
76
                f"If specified, 'dataset_query' must be a string, however, "
77
                f"'{dataset_query}' was provided instead, which is of type "
78
                f"'{type(dataset_query)}'."
79
            )
80

81
    if not dataset_query and not dataset_args:
×
82
        raise ValueError(
83
            "Either 'dataset_query' or key-worded arguments must be provided."
84
        )
85

86

87
def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    """Obtain a recipe from a ready object, a textual query, or keyword arguments.

    Args:
        dataset_query: A ``DatasetRecipe``/``Benchmark`` instance (returned
            unchanged) or a query string resolved via ``_get_recipe_from_query``
            with ``kwargs`` as overrides.
        **kwargs: Recipe attributes; used on their own to build the recipe when
            no query is given.

    Returns:
        The resolved recipe.

    Raises:
        UnitxtError: If neither a query nor keyword arguments are supplied.
    """
    # An already-built recipe or benchmark is passed through untouched.
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        return _get_recipe_from_query(dataset_query, kwargs)

    if kwargs:
        return _get_recipe_from_dict(kwargs)

    raise UnitxtError(
        "Specify either dataset recipe string artifact name or recipe args."
    )
1✔
103

104

105
def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Create a dataset for a task from in-memory instances.

    The instances are wrapped in a ``LoadFromDictionary`` loader inside a
    ``TaskCard`` and the card is then loaded through ``load_dataset``.

    Args:
        task: The task, or its name in the Unitxt Catalog
            (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html).
        test_set: Required list of instances, stored under the "test" split.
        train_set: Optional instances for the "train" split.
        validation_set: Optional instances for the "validation" split.
        split: Optional single split to choose.
        data_classification_policy: Forwarded to the loader.
        **kwargs: Arguments forwarded to ``load_dataset()``.

    Returns:
        Whatever ``load_dataset`` returns for the constructed card.

    Raises:
        Exception: If no 'template' is passed and the task has no
            ``default_template``.

    Example:
        template = Template(...)
        dataset = create_dataset(
            task="tasks.qa.open",
            test_set=[...],
            template=template,
            format="formats.chatapi",
        )
    """
    optional_splits = {"train": train_set, "validation": validation_set}
    data = {"test": test_set}
    data.update(
        {name: rows for name, rows in optional_splits.items() if rows is not None}
    )

    task, _ = fetch_artifact(task)

    template_missing = "template" not in kwargs
    if template_missing and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    loader = LoadFromDictionary(
        data=data, data_classification_policy=data_classification_policy
    )
    card = TaskCard(loader=loader, task=task)
    return load_dataset(card=card, split=split, **kwargs)
1✔
151

152

153
def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    Default reprs embed the object's memory address (e.g. ``<object_name at 0x123abc>``),
    which varies between executions. Every `` at 0x<hex>`` fragment is removed, so the
    result is stable across runs. Text after an address is preserved, so reprs that
    contain several addresses (a list of objects, a bound method) keep all of their
    parts instead of being cut off at the first address.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: ``str(obj)`` with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    import re

    return re.sub(r" at 0x[0-9a-fA-F]+", "", str(obj))
1✔
181

182

183
def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    """Materialize the streams produced by ``source`` through the Unitxt dataset builder.

    Args:
        source: Operator that is called with no arguments to obtain the streams.
        split: When given, only ``stream[split]`` is handed to the builder and
            the same value is passed as ``split`` when reading the result back.
        use_cache: When falsy, the builder is run with
            ``download_mode="force_redownload"``.
        streaming: When truthy, the result of ``as_streaming_dataset`` is
            returned instead of ``as_dataset``.

    Returns:
        The object returned by the builder's ``as_dataset`` or
        ``as_streaming_dataset``.

    Raises:
        Exception: When the builder raises ``DatasetGenerationError``, the
            error's ``__cause__`` (or else its ``__context__``) is raised in
            its place. If neither exists, the original error is re-raised.
    """
    # Imported locally; presumably to avoid a circular import (confirm before moving).
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    # Obtain data stream from the source
    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,  # Dictate the cache name
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        ds_builder.download_and_prepare(
            verification_mode="no_checks",
            download_mode=None if use_cache else "force_redownload",
        )

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(split=split)

    except DatasetGenerationError as e:
        # Surface the underlying failure rather than the wrapper. Fall back to
        # the wrapper itself when nothing is chained, since raising a missing
        # cause (``raise None``) would itself fail with a TypeError.
        underlying = e.__cause__ if e.__cause__ is not None else e.__context__
        if underlying is None:
            raise
        raise underlying from None
×
228

229

230
def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, then the textual recipe is loaded with the key word args overriding the textual recipe args.

    The arguments of the call, together with the Unitxt version and a creation
    timestamp (see ``fill_metadata``), are stored in ``info.description`` of
    every returned dataset.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False, the returned dataset is not cached.
            If set to None, the value of this parameter will be determined by settings.dataset_cache_default (default is False).
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        The dataset object produced by ``_source_to_dataset``: either a mapping
        of splits or a single dataset.

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    if use_cache is None:
        use_cache = settings.dataset_cache_default

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    # Record the call's arguments explicitly instead of reading them back from
    # the interpreter frame (``inspect.currentframe()`` may return None on
    # some Python implementations). ``use_cache`` is the resolved value.
    all_kwargs = {
        "dataset_query": dataset_query,
        "split": split,
        "streaming": streaming,
        "use_cache": use_cache,
    }
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)

    if isinstance(dataset, dict):
        # Each split gets its own copy so the splits do not share one dict.
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset
1✔
305

306

307
def fill_metadata(**kwargs):
    """Return the given keyword arguments plus version and creation-time entries.

    Args:
        **kwargs: Arbitrary entries to record.

    Returns:
        dict: A new dict holding ``kwargs`` together with "unitxt_version"
        (from ``get_constants().version``) and "creation_time" (the current
        local time with millisecond precision).
    """
    return {
        **kwargs,
        "unitxt_version": get_constants().version,
        "creation_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
    }
1✔
312

313

314
def evaluate(
    predictions: Optional[List[str]] = None,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    """Score predictions against a dataset and attach run metadata to the result.

    Args:
        predictions: The predictions to score.
        dataset: The dataset used as references.
        data: Backward-compatible alias of ``dataset``; takes precedence when given.
        calc_confidence_intervals: Forwarded to ``_compute``.

    Returns:
        The result of ``_compute``. Its ``metadata`` gains "creation_time" and,
        when available, "dataset" (from ``dataset.info.description``) and
        "predictions" (from ``predictions.metadata``).

    Raises:
        UnitxtError: If neither ``dataset`` nor ``data`` is provided.
    """
    # ``data`` is kept for backward compatibility and overrides ``dataset``.
    if data is not None:
        dataset = data
    if dataset is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")

    results = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )

    if hasattr(dataset, "info"):
        if hasattr(dataset.info, "description"):
            results.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        results.metadata["predictions"] = predictions.metadata

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    results.metadata["creation_time"] = timestamp[:-3]
    return results
1✔
337

338

339
def post_process(predictions, data) -> List[Dict[str, Any]]:
    """Post-process predictions, using ``data`` as the references.

    Thin wrapper that delegates to ``_inference_post_process``.
    """
    processed = _inference_post_process(predictions=predictions, references=data)
    return processed
1✔
341

342

343
@lru_cache_decorator(max_size=128)
def _get_recipe_with_cache(dataset_query: Optional[str] = None, **kwargs):
    """Resolve a recipe through ``load_recipe`` behind ``lru_cache_decorator``."""
    recipe = load_recipe(dataset_query, **kwargs)
    return recipe
1✔
346

347

348
def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    """Run one instance, or a list of instances, through a recipe.

    Args:
        instance_or_instances: A single instance or a ``list`` of instances.
        dataset_query: Optional query identifying the recipe.
        **kwargs: Recipe arguments forwarded to the recipe lookup.

    Returns:
        The single produced instance when a non-list was given. Otherwise a
        ``Dataset`` built from the produced instances, with ``loads_batch``
        set as its transform.
    """
    single_instance = not isinstance(instance_or_instances, list)
    instances = [instance_or_instances] if single_instance else instance_or_instances

    recipe = _get_recipe_with_cache(dataset_query, **kwargs)
    produced = recipe.produce(instances)

    if single_instance:
        return produced[0]
    return Dataset.from_list(produced).with_transform(loads_batch)
1✔
359

360

361
def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    """Produce model inputs from instances, run an inference engine, and post-process.

    Args:
        instance_or_instances: Instance(s) handed to ``produce``.
        engine: The inference engine, resolved through ``fetch_artifact``.
        dataset_query: Optional query identifying the recipe used by ``produce``.
        return_data: When True, return the dataset with "prediction" and
            "raw_prediction" columns added (plus "infer_meta_data" when
            ``return_meta_data`` is set) instead of only the predictions.
        return_log_probs: When True, call ``engine.infer_log_probs`` and
            JSON-encode each raw prediction.
        return_meta_data: Forwarded to the engine. When True, the engine
            outputs are treated as objects exposing a ``prediction`` attribute.
        previous_messages: Optional per-instance values; entry ``i`` is
            prepended (with ``+``) to the "source" field of example ``i``.
        **kwargs: Recipe arguments forwarded to ``produce``.

    Returns:
        The post-processed predictions, or the augmented dataset when
        ``return_data`` is True.

    Raises:
        NotImplementedError: If ``return_log_probs`` is True but the engine is
            not a ``LogProbInferenceEngine``.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)

    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)

    # With metadata requested, the predictions are unwrapped from the outputs.
    raw_predictions = (
        [output.prediction for output in infer_outputs]
        if return_meta_data
        else infer_outputs
    )
    if return_log_probs:
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]

    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            # Build filtered copies rather than deleting "prediction" from each
            # output's live ``__dict__``, which would strip the attribute from
            # the engine's own output objects.
            infer_output_list = [
                {
                    key: value
                    for key, value in infer_output.__dict__.items()
                    if key != "prediction"
                }
                for infer_output in infer_outputs
            ]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
1✔
414

415

416
def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    """Produce model inputs from instances and let the engine select a prediction for each.

    Args:
        instance_or_instances: Instance(s) handed to ``produce``.
        engine: The engine, resolved through ``fetch_artifact``; its
            ``select`` method is called on the produced dataset.
        dataset_query: Optional query identifying the recipe used by ``produce``.
        return_data: When True, return the dataset with a "prediction" column
            added instead of only the predictions.
        previous_messages: Optional per-instance values; entry ``i`` is
            prepended (with ``+``) to the "source" field of example ``i``.
        **kwargs: Recipe arguments forwarded to ``produce``.

    Returns:
        The engine's predictions, or the augmented dataset when
        ``return_data`` is True. No post-processing is applied.
    """
    dataset = produce(instance_or_instances, dataset_query, **kwargs)

    if previous_messages is not None:

        def add_previous_messages(example, index):
            history = previous_messages[index]
            example["source"] = history + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)

    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    if not return_data:
        return predictions
    return dataset.add_column("prediction", predictions)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc