IBM / unitxt · build 15851684646 · Pull Request #1838: Improved error messages

24 Jun 2025 01:24PM UTC — coverage: 80.21% (+0.01% from 80.199%)
Merge 81822bb33 into 6f9c0666b (github / web-flow)

Branches: 1705 of 2106 covered (80.96%). Branch coverage included in aggregate %.
Lines: 10588 of 13220 relevant lines covered (80.09%), 0.8 hits per line.

Source file: src/unitxt/api.py — 74.5% coverage
import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()

def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]

def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    except:
        # Fall back to building the recipe directly from the query string when
        # it does not name an artifact in the catalog.
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream

def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)

def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and keyword arguments at the same time. "
            "If you want to load the dataset from a card in the local catalog, use the query only. "
            "Otherwise, use keyword arguments only to specify the properties of the dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or keyword arguments must be provided."
        )

def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError(
            "Specify either a dataset recipe artifact name (string) or recipe arguments."
        )

    return recipe

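# A minimal usage sketch for load_recipe(): the same recipe can be specified as
# a catalog query string or as keyword arguments. The card and template ids
# below are taken from the load_dataset() docstring and are assumed to exist in
# the local catalog.
def _example_load_recipe():
    recipe_from_query = load_recipe(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    )
    recipe_from_kwargs = load_recipe(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
    )
    return recipe_from_query, recipe_from_kwargs
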
def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional name of a single split to return
        data_classification_policy: optional data classification policy for the data
        **kwargs: arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)

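# A minimal usage sketch for create_dataset(): wraps an in-memory test set for
# the open-QA task named in the docstring above. The instance fields
# ("question", "answers") and the template id are illustrative assumptions
# about that task's schema and the local catalog.
def _example_create_dataset():
    test_set = [
        {"question": "What is the capital of France?", "answers": ["Paris"]},
        {"question": "What is the capital of Italy?", "answers": ["Rome"]},
    ]
    return create_dataset(
        task="tasks.qa.open",
        test_set=test_set,
        split="test",
        template="templates.qa.open.simple",  # assumed catalog id
    )
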
def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(repr(source)),
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare(
                verification_mode="no_checks",
                download_mode=None if use_cache else "force_redownload",
            )
        except DatasetGenerationError as e:
            # Surface the underlying error rather than the generic wrapper.
            if e.__cause__:
                raise e.__cause__ from None
            if e.__context__:
                raise e.__context__ from None
            raise

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__

def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, then the textual recipe is loaded with the keyword args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query that specifies the dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Hugging Face dataset is cached on local disk, such that if the same dataset is loaded again, it is read from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g., running with predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load the dataset from a provided card that is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset

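# A minimal usage sketch for load_dataset(): loads only the test split of a
# catalog recipe and enables local caching. The card and template ids are the
# ones from the docstring above and must be present in the local catalog.
def _example_load_dataset_test_split():
    return load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
        use_cache=True,
    )
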
def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata

def evaluate(
    predictions,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result

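# A minimal usage sketch for evaluate(): scores raw string predictions against
# a dataset returned by load_dataset() (typically its test split). The
# predictions would normally come from a model or inference engine; here they
# are supplied by the caller.
def _example_evaluate(test_dataset, predictions):
    results = evaluate(predictions=predictions, dataset=test_dataset)
    # The dataset description and creation time attached above are available
    # under results.metadata.
    return results
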
def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    # Cache the compiled recipe's produce() callable so repeated calls with the
    # same query and kwargs do not rebuild the recipe.
    return load_recipe(dataset_query, **kwargs).produce

def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)

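# A minimal usage sketch for produce(): turns a single raw instance into a
# model-ready example (including a rendered "source" prompt) using a catalog
# recipe. The card and template ids, and the instance fields, are illustrative
# assumptions; the fields must match the schema of the recipe's task.
def _example_produce():
    instance = {"question": "What is the capital of France?", "answers": ["Paris"]}
    return produce(
        instance,
        "card=cards.open_qa.example,template=templates.qa.open.simple",  # assumed ids
    )
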
def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions

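# A minimal usage sketch for infer(): prepares instances with a catalog recipe,
# runs them through an inference engine, and returns post-processed
# predictions. Any InferenceEngine instance (or an engine artifact resolvable
# by fetch_artifact) can be passed; the recipe ids reuse the assumptions from
# the sketches above.
def _example_infer(engine: InferenceEngine):
    instances = [
        {"question": "What is the capital of France?", "answers": ["Paris"]},
    ]
    return infer(
        instances,
        engine=engine,
        dataset_query="card=cards.open_qa.example,template=templates.qa.open.simple",  # assumed ids
        return_data=True,  # also return the dataset with 'prediction' and 'raw_prediction' columns
    )
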
def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
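
# A minimal usage sketch for select(): like infer(), but the engine scores
# predefined answer options by log-probability instead of generating free
# text, so the recipe should describe a task with options (e.g., multiple
# choice). All ids and instance fields below are illustrative assumptions.
def _example_select(engine: OptionSelectingByLogProbsInferenceEngine):
    instances = [
        {
            "question": "What is the capital of France?",
            "choices": ["Paris", "Rome", "Madrid"],
            "answer": 0,
        },
    ]
    return select(
        instances,
        engine=engine,
        dataset_query="card=cards.multiple_choice.example,template=templates.qa.multiple_choice.open",  # assumed ids
    )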