
IBM / unitxt / build 13930739237

18 Mar 2025 06:26PM UTC coverage: 80.698% (-0.04%) from 80.737%

push · github · web-flow
Fix some bugs in inference engine tests (#1682)

* Fix some bugs in inference engine tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix some bugs in inference engine tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix bug in name conversion in rits

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Add engine id

Signed-off-by: elronbandel <elronbandel@gmail.com>

* fix

Signed-off-by: elronbandel <elronbandel@gmail.com>

* fix

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Use greedy decoding and remove redundant cache

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix hf-auto model test

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Touch up watsonx tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix inference tests.

1. Use a local inference engine on CPU when testing inference engines, for reproducibility.
2. In the cache mechanism, don't assume that infer on an empty list yields an empty list (a sketch of this guard follows the commit message).

Signed-off-by: Elad Venezian <eladv@il.ibm.com>

* Fix setting of data classification policy

Signed-off-by: elronbandel <elronbandel@gmail.com>

---------

Signed-off-by: elronbandel <elronbandel@gmail.com>
Signed-off-by: Elad Venezian <eladv@il.ibm.com>
Co-authored-by: Elad Venezian <eladv@il.ibm.com>
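As a rough illustration of the second point under "Fix inference tests" above: a caching layer around an inference engine should short-circuit on empty input rather than assuming infer([]) returns []. The helper below is a hypothetical sketch, not the actual unitxt cache code:

def cached_infer(infer_fn, instances, cache):
    # Hypothetical caching wrapper: reuse cached predictions and call the
    # engine only for instances that are not cached yet.
    if not instances:
        # Guard explicitly: never call the engine on an empty list, and never
        # assume that doing so would return an empty list.
        return []
    keys = [repr(inst) for inst in instances]
    missing = [inst for key, inst in zip(keys, instances) if key not in cache]
    if missing:
        predictions = infer_fn(missing)
        for inst, pred in zip(missing, predictions):
            cache[repr(inst)] = pred
    return [cache[key] for key in keys]

# Example with a trivial stand-in "engine":
cache = {}
print(cached_infer(lambda xs: [x.upper() for x in xs], ["a", "b"], cache))  # ['A', 'B']
print(cached_infer(lambda xs: [x.upper() for x in xs], [], cache))          # []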

1556 of 1922 branches covered (80.96%)

Branch coverage included in aggregate %.

9749 of 12087 relevant lines covered (80.66%)

0.81 hits per line
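A quick arithmetic check of the aggregate, assuming it simply pools covered lines and covered branches (the pooling formula is an assumption, but it reproduces the reported figure):

covered = 9749 + 1556   # covered lines + covered branches = 11305
total = 12087 + 1922    # relevant lines + branches = 14009
print(round(100 * covered / total, 3))   # 80.698, the aggregate reported above

The -0.04% change likewise matches 80.698% - 80.737% = -0.039%, rounded.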

Source File: src/unitxt/api.py (81.15% covered; lines not hit by the test suite are marked "# uncovered" below)
import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(  # uncovered
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(  # uncovered
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(  # uncovered
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query  # uncovered

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task:  The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set : required list of instances
        train_set : optional train_set
        validation_set: optional validation set
        split: optional one split to choose
        data_classification_policy: data_classification_policy
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set  # uncovered
    if validation_set is not None:
        data["validation"] = validation_set  # uncovered
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(  # uncovered
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy), task=task)
    return load_dataset(card=card, split=split, **kwargs)


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(repr(source)),
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        ds_builder.download_and_prepare(
            verification_mode="no_checks",
            download_mode=None if use_cache else "force_redownload",
        )

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)  # uncovered

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:  # uncovered
        raise e.__cause__  # uncovered


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")  # uncovered
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata  # uncovered
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):  # uncovered
            raise NotImplementedError(  # uncovered
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)  # uncovered
        raw_predictions = (  # uncovered
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [  # uncovered
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [  # uncovered
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:  # uncovered
                del infer_output["prediction"]  # uncovered
            dataset = dataset.add_column("infer_meta_data", infer_output_list)  # uncovered
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)  # uncovered
    if previous_messages is not None:  # uncovered

        def add_previous_messages(example, index):  # uncovered
            example["source"] = previous_messages[index] + example["source"]  # uncovered
            return example  # uncovered

        dataset = dataset.map(add_previous_messages, with_indices=True)  # uncovered
    engine, _ = fetch_artifact(engine)  # uncovered
    predictions = engine.select(dataset)  # uncovered
    # predictions = post_process(raw_predictions, dataset)
    if return_data:  # uncovered
        return dataset.add_column("prediction", predictions)  # uncovered
    return predictions  # uncovered
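For orientation, a minimal usage sketch of the API in this file, mirroring the docstring example above; the choice of split and the placeholder predictions are assumptions for illustration, not taken from the source:

from unitxt.api import evaluate, load_dataset

# Card and template must exist in the local catalog (same query as in the
# load_dataset() docstring above); split="test" is assumed for illustration.
dataset = load_dataset(
    dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5",
    split="test",
)

# Placeholder predictions so the sketch stays self-contained; in practice they
# would come from an inference engine via infer() above.
predictions = ["3.0"] * len(dataset)

results = evaluate(predictions, dataset=dataset)
print(results.metadata["creation_time"])  # metadata fields are set in evaluate() above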