
IBM / unitxt / 14293170121

06 Apr 2025 01:46PM UTC coverage: 80.205% (-0.01% from 80.217%)

Build 14293170121: push event via GitHub (committer: web-flow)
Commit: Update version to 1.22.0 (#1717)
Signed-off-by: elronbandel <elronbandel@gmail.com>

1582 of 1966 branches covered (80.47%); branch coverage is included in the aggregate %.
9905 of 12356 relevant lines covered (80.16%)
0.8 hits per line

Source File

src/unitxt/api.py (81.25% covered)
import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()

def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]
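A tiny usage sketch (not from the source file): short_hex_hash() is used further down to derive a stable config name for a recipe, so equal recipe representations always map to the same short identifier.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import short_hex_hash

config_name = "recipe-" + short_hex_hash("card=cards.wnli")
# Deterministic: the same input always yields the same 8-character hex prefix.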

def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )

def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe
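A minimal sketch of the two equivalent ways load_recipe() can be called; the catalog entries named here are taken from the docstring examples further below and are otherwise assumptions.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import load_recipe

# From a single query string (card and template must exist in the local catalog):
recipe = load_recipe(
    "card=cards.wnli,template=templates.classification.multi_class.relation.default"
)

# Or from keyword arguments that map onto DatasetRecipe fields:
recipe = load_recipe(
    card="cards.wnli",
    template="templates.classification.multi_class.relation.default",
    max_train_instances=5,
)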

def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of a task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html) or a Task object
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional name of a single split to return
        data_classification_policy: optional data classification policy for the provided data
        **kwargs: additional arguments passed through to load_dataset() (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)
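A hedged usage sketch for create_dataset(), expanding the docstring example above; the instance fields are assumed to match the catalog task "tasks.qa.open", and the format is the one named in the docstring.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import create_dataset

# Field names are assumed to match the "tasks.qa.open" task definition.
test_set = [
    {"question": "What is the capital of France?", "answers": ["Paris"]},
    {"question": "Who wrote Hamlet?", "answers": ["William Shakespeare"]},
]

dataset = create_dataset(
    task="tasks.qa.open",
    test_set=test_set,
    split="test",
    format="formats.chatapi",
)
# If the chosen task has no 'default_template', an explicit 'template=...'
# argument is required (see the check in create_dataset above).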

def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(repr(source)),
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        ds_builder.download_and_prepare(
            verification_mode="no_checks",
            download_mode=None if use_cache else "force_redownload",
        )

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__

def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies the dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk, so that if the same dataset is loaded again it is read from disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition changes, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load the dataset from a provided card, which is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset
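A short, hedged sketch of load_dataset() in its query form, reusing the catalog entries from the docstring; split, streaming, and use_cache are the options documented above.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import load_dataset

# A single split, loaded eagerly:
test_data = load_dataset(
    dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
    split="test",
)

# All splits, streamed and cached on disk (safe only for fixed catalog definitions):
data = load_dataset(
    dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5",
    streaming=True,
    use_cache=True,
)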

def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata

def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result
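A hedged end-to-end sketch tying load_dataset() and evaluate() together; the placeholder predictions (one string per test instance) stand in for real model output.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import evaluate, load_dataset

test_data = load_dataset(
    dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
    split="test",
)

# Placeholder predictions; in practice these come from a model.
predictions = ["entailment"] * len(test_data)

results = evaluate(predictions=predictions, dataset=test_data)
# 'results' is an EvaluationResults object whose metadata carries the dataset
# description and creation_time set by evaluate() above.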

def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce

def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)
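A hedged sketch of produce(), which applies a recipe to raw instances and returns model-ready examples without running inference; the card and template names are hypothetical placeholders, and the instance fields must match the task behind the chosen card.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import produce

# A single raw instance; its fields must match the task of the chosen card.
instance = {"question": "What is the capital of France?", "answers": ["Paris"]}

processed = produce(
    instance,
    "card=cards.my_qa_card,template=templates.my_qa_template",  # hypothetical catalog entries
)
print(processed["source"])  # the rendered prompt that would be sent to a model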

def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
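A hedged sketch of infer(), which chains produce(), an inference engine, and post_process(); the engine construction shown is an assumption (any InferenceEngine instance, or an engine reference resolvable by fetch_artifact, can be passed), and the catalog names are hypothetical.

# Illustrative sketch, not part of src/unitxt/api.py
from unitxt.api import infer
from unitxt.inference import HFPipelineBasedInferenceEngine  # assumed engine; any InferenceEngine works

engine = HFPipelineBasedInferenceEngine(model_name="google/flan-t5-small", max_new_tokens=32)

instances = [{"question": "What is the capital of France?", "answers": ["Paris"]}]

predictions = infer(
    instances,
    engine=engine,
    dataset_query="card=cards.my_qa_card,template=templates.my_qa_template",  # hypothetical
)
# With return_data=True, infer() instead returns the dataset extended with
# 'prediction' and 'raw_prediction' columns.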

def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions