
IBM / unitxt / build 16150034228

08 Jul 2025 05:21PM UTC coverage: 81.255% (+0.2%) from 81.077%
Push via github (web-flow): Add comprehensive multi threading support and tests (#1853)

1553 of 1922 branches covered (80.8%); branch coverage is included in the aggregate %.
10550 of 12973 relevant lines covered (81.32%), 0.81 hits per line.

Source file: src/unitxt/api.py (80.6% covered). Uncovered lines fall mainly in _verify_dataset_args(), the DatasetGenerationError handlers in _source_to_dataset(), the return_log_probs branch of infer(), and select().
import hashlib
import inspect
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task
from .utils import lru_cache_decorator

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    except:
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError(
            "Specify either dataset recipe string artifact name or recipe args."
        )

    return recipe
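

# --- Illustrative usage sketch; not part of the original src/unitxt/api.py ---
# load_recipe() accepts either a catalog query string or explicit keyword
# arguments. A minimal, hedged sketch; it assumes the referenced card and
# template exist in the local Unitxt catalog (the names are taken from the
# load_dataset() docstring below).
def _example_load_recipe():
    from unitxt.api import load_recipe

    # Query-string form, resolved through fetch_artifact()/get_dataset_artifact():
    recipe = load_recipe(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    )

    # Keyword form, validated against DatasetRecipe fields by _get_recipe_from_dict():
    recipe = load_recipe(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
        loader_limit=10,
    )
    return recipe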


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of test instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional single split to load
        data_classification_policy: optional data classification policy
        **kwargs: Arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)
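

# --- Illustrative usage sketch; not part of the original src/unitxt/api.py ---
# A minimal, hedged example of create_dataset() for an open QA task. The
# instance field names ("question", "answers") are assumptions about what
# "tasks.qa.open" expects and may differ between catalog versions; the task
# and format names come from the docstring above, and the template name is an
# assumed catalog entry.
def _example_create_dataset():
    from unitxt.api import create_dataset

    test_set = [
        {"question": "What is the capital of France?", "answers": ["Paris"]},
        {"question": "Who wrote Hamlet?", "answers": ["William Shakespeare"]},
    ]
    return create_dataset(
        task="tasks.qa.open",
        test_set=test_set,
        template="templates.qa.open.simple",  # assumed catalog name
        format="formats.chatapi",
        split="test",
    )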


def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    # Obtain data stream from the source
    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,  # Dictate the cache name
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare(
                verification_mode="no_checks",
                download_mode=None if use_cache else "force_redownload",
            )
        except DatasetGenerationError as e:
            if e.__cause__:
                raise e.__cause__ from None
            if e.__context__:
                raise e.__context__ from None
            raise

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, the textual recipe is loaded, with the keyword args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load the dataset from a provided card that is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result
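

# --- Illustrative usage sketch; not part of the original src/unitxt/api.py ---
# A hedged, end-to-end sketch of load_dataset() followed by evaluate(). The
# dataset query is the one used in the load_dataset() docstring above; the
# predictions are placeholders rather than real model outputs.
def _example_evaluate():
    from unitxt.api import evaluate, load_dataset

    dataset = load_dataset(
        dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5",
        split="test",
    )
    predictions = ["3.0"] * len(dataset)  # placeholder predictions, one per test instance
    return evaluate(predictions=predictions, dataset=dataset)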


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache_decorator(max_size=128)
def _get_recipe_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs)


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    dataset_recipe = _get_recipe_with_cache(dataset_query, **kwargs)
    result = dataset_recipe.produce(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)
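

# --- Illustrative usage sketch; not part of the original src/unitxt/api.py ---
# produce() turns raw instances into model-ready inputs (the "source" field)
# using a cached recipe, without building a full dataset. A hedged sketch: the
# query string comes from the load_dataset() docstring, while the instance
# field names ("text_a", "text_b", "label") are assumptions about cards.wnli.
def _example_produce():
    from unitxt.api import produce

    instance = {
        "text_a": "The trophy doesn't fit into the suitcase because it is too large.",
        "text_b": "The trophy is too large.",
        "label": "entailment",
    }
    processed = produce(
        instance,
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
    )
    return processed["source"]  # the rendered model input for this single instance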


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
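

# --- Illustrative usage sketch; not part of the original src/unitxt/api.py ---
# A hedged sketch of infer(): instances are produced with a recipe and sent to
# an inference engine resolved via fetch_artifact(). The engine catalog name is
# an assumption, and the instance fields mirror the produce() sketch above.
# With return_data=True the returned dataset carries the "prediction" and
# "raw_prediction" columns added by infer().
def _example_infer():
    from unitxt.api import infer

    instances = [
        {
            "text_a": "The trophy doesn't fit into the suitcase because it is too large.",
            "text_b": "The trophy is too large.",
            "label": "entailment",
        }
    ]
    dataset_with_predictions = infer(
        instances,
        engine="engines.model.llama_3_8b_instruct.wml",  # assumed catalog engine name
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        return_data=True,
    )
    return dataset_with_predictions["prediction"]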


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions