IBM / unitxt, build 16704320175

03 Aug 2025 11:05AM UTC coverage: 80.829% (-0.4%) from 81.213%

Pull Request #1845: Allow using python functions instead of operators (e.g. in pre-processing pipeline)
Merge 59428aa88 into 5372aa6df

1576 of 1970 branches covered (80.0%); branch coverage is included in the aggregate %.
10685 of 13199 relevant lines covered (80.95%)
0.81 hits per line

Source File: src/unitxt/api.py (81.94% covered)
import hashlib
import inspect
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
from .utils import json_dump, lru_cache_decorator

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    except:
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError(
            "Specify either dataset recipe string artifact name or recipe args."
        )

    return recipe

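
# Illustrative usage sketch (not part of the original file): load_recipe() accepts either
# a catalog query string or explicit DatasetRecipe keyword arguments. The catalog names
# below are the ones used in the load_dataset() docstring further down in this module.
#
#     recipe = load_recipe(
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default"
#     )
#     # or, equivalently, keyword arguments validated against DatasetRecipe's fields:
#     recipe = load_recipe(card=card, template=template, max_train_instances=5)  # card/template built elsewhere
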
def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of test instances
        train_set: optional list of train instances
        validation_set: optional list of validation instances
        split: optional name of a single split to return
        data_classification_policy: data classification policy of the provided data
        **kwargs: arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)

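
# Illustrative sketch (not part of the original file): building a dataset directly from
# in-memory instances. "tasks.qa.open" is the task named in the docstring above; the
# instance field names ("question", "answers") are assumptions about what that task expects.
#
#     test_set = [{"question": "What is the capital of France?", "answers": ["Paris"]}]
#     dataset = create_dataset(
#         task="tasks.qa.open",
#         test_set=test_set,
#         split="test",
#         template=template,  # a template object or catalog name; required if the task defines no default_template
#     )
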
def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _remove_id_keys(obj):
    if isinstance(obj, dict):
        return {k: _remove_id_keys(v) for k, v in obj.items() if k != "__id__"}
    if isinstance(obj, list):
        return [_remove_id_keys(item) for item in obj]
    return obj


def _artifact_string_repr(artifact):
    artifact_dict = to_dict(artifact, object_to_str_without_addresses)
    artifact_dict_without_ids = _remove_id_keys(artifact_dict)
    return json_dump(artifact_dict_without_ids)

def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = _artifact_string_repr(source)
    config_name = "recipe-" + short_hex_hash(source_signature)
    # Obtain data stream from the source
    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,  # Dictate the cache name
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare(
                verification_mode="no_checks",
                download_mode=None if use_cache else "force_redownload",
            )
        except DatasetGenerationError as e:
            if e.__cause__:
                raise e.__cause__ from None
            if e.__context__:
                raise e.__context__ from None
            raise

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__

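
# Note on caching (added comment, not part of the original file): the builder's config_name
# is derived from a hash of the recipe's signature (_artifact_string_repr), so identical
# recipes map to the same HuggingFace cache entry. With use_cache=False the builder is
# forced to regenerate the data ("force_redownload"); with use_cache=True a previously
# prepared dataset with the same signature is reused from local disk.
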
def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then the dataset is loaded from a card
    in the local catalog based on parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, the textual recipe is loaded, with the keyword args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            the local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False, the returned dataset is not cached.
            If set to None, the value of this parameter will be determined by settings.dataset_cache_default (default is False).
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load a dataset from a provided card, which is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    if use_cache is None:
        use_cache = settings.dataset_cache_default

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata

def evaluate(
    predictions: Optional[List[str]] = None,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result

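
# Illustrative sketch (not part of the original file): a typical flow loads a dataset split,
# obtains predictions from a model (e.g. via infer() below, or any external inference code),
# and passes both to evaluate(), which returns an EvaluationResults object computed by the
# task's metrics. The catalog query reuses the names from the load_dataset() docstring.
#
#     dataset = load_dataset(
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default",
#         split="test",
#     )
#     predictions = ["entailment", "not entailment", ...]  # one output string per instance
#     results = evaluate(predictions=predictions, data=dataset)
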
def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache_decorator(max_size=128)
def _get_recipe_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs)


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    dataset_recipe = _get_recipe_with_cache(dataset_query, **kwargs)
    result = dataset_recipe.produce(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)

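
# Illustrative sketch (not part of the original file): produce() renders raw instance dicts
# through a recipe and returns the model-ready form (a single dict for a single instance, a
# Dataset for a list). The query string reuses the catalog names from the load_dataset()
# docstring; the instance field names below are assumptions about the cards.wnli task.
#
#     instance = {"text_a": "...", "text_b": "...", "classes": ["entailment", "not entailment"]}
#     rendered = produce(
#         instance,
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default",
#     )
#     print(rendered["source"])  # the prompt that would be sent to a model
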
def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions

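
# Illustrative sketch (not part of the original file): infer() renders the instances with
# produce(), runs the supplied engine (an InferenceEngine instance, or a catalog name that
# fetch_artifact() can resolve), post-processes the raw outputs, and can return them attached
# to the dataset. The engine below is assumed to be constructed elsewhere.
#
#     dataset_with_predictions = infer(
#         instances,                  # a list of raw instance dicts
#         engine=my_engine,           # any InferenceEngine
#         dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
#         return_data=True,           # attach "prediction" and "raw_prediction" columns
#     )
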
def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions