IBM / unitxt / 17076945254
19 Aug 2025 05:20PM UTC. Coverage: 80.804% (-0.3%) from 81.081%.

Pull Request #1914: Refactor inference
Merge bd1e7d625 into 7a48aa9d3 (github / web-flow)

Branches: 1595 of 1991 covered (80.11%). Branch coverage is included in the aggregate %.
Lines: 10802 of 13351 relevant lines covered (80.91%), 0.81 hits per line.

Source file: src/unitxt/api.py (82.09% covered)
import hashlib
import inspect
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
from .utils import lru_cache_decorator

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]
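
# A quick sketch of what short_hex_hash() returns (prefix lengths only; the
# actual digest values are not shown here):
#
#     short_hex_hash("card=cards.wnli")      # first 8 hex chars of the SHA-256 digest
#     short_hex_hash("card=cards.wnli", 16)  # first 16 hex chars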


def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    except:
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError(
            "Specify either dataset recipe string artifact name or recipe args."
        )

    return recipe
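
# A minimal sketch of the two call styles load_recipe() accepts (the catalog
# entries named here are the ones used in the load_dataset() docstring below
# and are assumed to exist in the local catalog):
#
#     recipe = load_recipe(
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default"
#     )
#     # or, equivalently, keyword arguments mapped onto DatasetRecipe fields:
#     recipe = load_recipe(
#         card="cards.wnli",
#         template="templates.classification.multi_class.relation.default",
#     )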


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional name of a single split to return
        data_classification_policy: optional data classification policy of the data
        **kwargs: additional arguments passed through to load_dataset()

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)
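
# A slightly fuller sketch of the docstring example above. It assumes the input
# field names of tasks.qa.open ("question"/"answers") and a template name that
# exists in the local catalog; both are illustrative, not guaranteed:
#
#     dataset = create_dataset(
#         task="tasks.qa.open",
#         test_set=[{"question": "What is the capital of Texas?", "answers": ["Austin"]}],
#         template="templates.qa.open.simple",  # assumed catalog entry
#         split="test",
#     )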


def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    # Obtain data stream from the source
    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,  # Dictate the cache name
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare(
                verification_mode="no_checks",
                download_mode=None if use_cache else "force_redownload",
            )
        except DatasetGenerationError as e:
            if e.__cause__:
                raise e.__cause__ from None
            if e.__context__:
                raise e.__context__ from None
            raise

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then the dataset is loaded from a card
    in the local catalog based on parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, the textual recipe is loaded, with the keyword args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            the local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False, the returned dataset is not cached.
            If set to None, the value of this parameter will be determined by settings.dataset_cache_default (default is False).
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load a dataset from a provided card that is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    if use_cache is None:
        use_cache = settings.dataset_cache_default

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions: Optional[List[str]] = None,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result
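
# An end-to-end sketch tying load_dataset() and evaluate() together. The catalog
# entries are the ones from the load_dataset() docstring and are assumed to exist
# locally; the predictions are placeholders rather than real model output:
#
#     dataset = load_dataset(
#         "card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5",
#         split="test",
#     )
#     predictions = ["3.0" for _ in dataset]  # normally produced by a model / InferenceEngine
#     results = evaluate(predictions=predictions, dataset=dataset)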


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache_decorator(max_size=128)
def _get_recipe_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs)


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    dataset_recipe = _get_recipe_with_cache(dataset_query, **kwargs)
    result = dataset_recipe.produce(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)
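
# A sketch of how produce() treats a single instance versus a list. The query and
# the field names are placeholders; real instances must carry the input fields of
# the card's task:
#
#     query = "card=cards.some_card,template=templates.some_template"  # placeholder
#     single = produce({"field": "value"}, query)    # returns one processed dict
#     batch = produce([{"field": "value"}], query)   # returns a Dataset with one row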


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
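
# A sketch of calling infer(). The engine may be an InferenceEngine instance or a
# catalog artifact name; the query and instance fields below are placeholders:
#
#     predictions = infer(
#         [{"field": "value"}],
#         engine=my_engine,                       # any InferenceEngine, or a catalog artifact name
#         dataset_query="card=...,template=...",  # placeholder recipe query
#     )
#     # with return_data=True the processed dataset is returned instead, with
#     # "prediction" and "raw_prediction" columns attached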


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions