IBM / unitxt / 13907679295

17 Mar 2025 07:00PM UTC coverage: 80.698% (-0.04%) from 80.737%

Pull Request #1682: Fix some bugs in inference engine tests
Merge 125ad9c9e into 287a80176 (github / web-flow)

1557 of 1923 branches covered (80.97%); branch coverage is included in the aggregate %
9752 of 12091 relevant lines covered (80.66%), 0.81 hits per line

Source File: src/unitxt/api.py (81.03% covered)

import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]

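
# A minimal, uncalled sketch of how short_hex_hash() is used later in this module:
# the repr of a recipe is hashed into a short, stable identifier for the Hugging Face
# config name. The input string below is purely illustrative.
def _example_short_hex_hash():
    config_name = "recipe-" + short_hex_hash("card=cards.wnli")
    return config_name  # "recipe-" followed by the first 8 hex digits of a SHA-256
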
def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and keyword arguments at the same time. "
            "If you want to load a dataset from a card in the local catalog, use the query only. "
            "Otherwise, use keyword arguments only to specify the properties of the dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or keyword arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe

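
# A brief, uncalled sketch of the two ways load_recipe() accepts its input: a
# catalog query string, or keyword arguments matching DatasetRecipe fields
# (never both). The catalog names are taken from the docstring examples below
# and assume those entries exist in the local catalog.
def _example_load_recipe():
    recipe_from_query = load_recipe(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    )
    recipe_from_kwargs = load_recipe(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
        loader_limit=10,
    )
    return recipe_from_query, recipe_from_kwargs
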
def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional single split to select
        data_classification_policy: optional data classification policy for the data
        **kwargs: Arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    args = {"data": data}
    if data_classification_policy is not None:
        args["default_data_classification_policy"] = data_classification_policy

    card = TaskCard(loader=LoadFromDictionary(**args), task=task)
    return load_dataset(card=card, split=split, **kwargs)

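
# An uncalled usage sketch for create_dataset(), extending the docstring example:
# a tiny open-QA test set is turned into a ready-to-use dataset. The instance field
# names and the "templates.qa.*" entry are assumptions about the "tasks.qa.open"
# task and the local catalog; adjust them to the task actually used.
def _example_create_dataset():
    test_set = [
        {"question": "What is the capital of France?", "answers": ["Paris"]},
        {"question": "What is 2 + 2?", "answers": ["4"]},
    ]
    return create_dataset(
        task="tasks.qa.open",
        test_set=test_set,
        split="test",
        template="templates.qa.open.title",
    )
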
def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(repr(source)),
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        ds_builder.download_and_prepare(
            verification_mode="no_checks",
            download_mode=None if use_cache else "force_redownload",
        )

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:
        raise e.__cause__


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies the dataset to load from the local
            catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, optional):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading a dataset as a stream avoids loading all the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk, so that if the same dataset is loaded again it is read from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running with predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load a dataset from a provided card, which is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset

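
# An uncalled sketch of the query-string form of load_dataset() described in the
# docstring above, assuming the referenced card and template are present in the
# local catalog; split selects just the test split of the resulting dataset.
def _example_load_dataset():
    return load_dataset(
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
    )
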
def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result

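
# An uncalled sketch of evaluate(): predictions are paired with the evaluated
# split that load_dataset() produced. The trivial "copy the target" predictions
# only illustrate the expected shapes; real predictions would come from a model.
def _example_evaluate():
    ds = load_dataset(
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
    )
    predictions = [instance["target"] for instance in ds]
    return evaluate(predictions=predictions, dataset=ds)
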
def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)

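
# An uncalled sketch of produce(): it renders raw task instances into model-ready
# instances (each gaining a "source" prompt) without calling any model. The query
# string and instance fields are placeholders; they must match the task of the
# recipe that the query resolves to.
def _example_produce(dataset_query: str, raw_instance: Dict[str, Any]):
    single = produce(raw_instance, dataset_query)  # a single dict in, a single dict out
    batch = produce([raw_instance, raw_instance], dataset_query)  # a list in, a Dataset out
    return single["source"], batch
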
def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions

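
# An uncalled sketch of infer(), assuming `my_engine` is some concrete
# InferenceEngine (or a catalog id that fetch_artifact() can resolve); no
# particular provider is implied. With return_data=True the predictions are
# attached back onto the produced dataset as extra columns.
def _example_infer(
    my_engine: InferenceEngine,
    dataset_query: str,
    instances: List[Dict[str, Any]],
):
    dataset_with_predictions = infer(
        instances,
        engine=my_engine,
        dataset_query=dataset_query,
        return_data=True,
    )
    return dataset_with_predictions["prediction"]
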
def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions

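
# An uncalled sketch of select(), assuming `my_engine` implements
# OptionSelectingByLogProbsInferenceEngine: the engine scores each instance's
# candidate options by log-probability and returns the chosen option per instance.
def _example_select(
    my_engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: str,
    instances: List[Dict[str, Any]],
):
    return select(instances, engine=my_engine, dataset_query=dataset_query)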