IBM / unitxt / build 12787870510

15 Jan 2025 12:02PM UTC coverage: 79.421% (+0.03%) from 79.393%

Pull Request #1512: Keep metadata over main unitxt stages (github / web-flow)
Merge 918569791 into 614bb1224

1391 of 1739 branches covered (79.99%); branch coverage included in aggregate %.
8767 of 11051 relevant lines covered (79.33%), 0.79 hits per line

Source File: src/unitxt/api.py (70.74% covered)
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

import pkg_resources
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import UNITXT_DATASET_SCHEMA, loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def load(source: Union[SourceOperator, str]):
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"
    if isinstance(source, str):
        source, _ = fetch_artifact(source)
    return source().to_dataset()


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe
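For illustration, load_recipe() accepts either of the forms handled by the two branches above. The catalog names in this sketch are taken from the load_dataset() docstring further down and are assumed to be present in the local catalog:

    # query form: one string holding all recipe parameters
    recipe = load_recipe(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    )

    # keyword form: the same parameters passed as DatasetRecipe attributes
    recipe = load_recipe(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
        max_train_instances=5,  # any other DatasetRecipe attribute can be added here
    )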


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional one split to choose
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
    return load_dataset(card=card, split=split, **kwargs)
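A minimal sketch of calling create_dataset() with inline data. The task and format names come from the docstring example above; the instance field names ("question", "answers") are assumptions about that catalog task, and a template= argument would be required if the task defines no default_template (per the check above):

    test_set = [
        {"question": "What is the capital of France?", "answers": ["Paris"]},
        {"question": "Who wrote Hamlet?", "answers": ["William Shakespeare"]},
    ]
    dataset = create_dataset(
        task="tasks.qa.open",
        test_set=test_set,
        split="test",
        format="formats.chatapi",
    )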


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    disable_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as Unitxt streams dictionary
        split (str, optional):
            The split of the data to load
        disable_cache (bool, optional):
            Disable caching process of the data
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    stream = recipe()
    if split is not None:
        stream = stream[split]

    if disable_cache is None:
        disable_cache = settings.disable_hf_datasets_cache

    if streaming:
        dataset = stream.to_iterable_dataset(
            features=UNITXT_DATASET_SCHEMA,
        ).map(loads_instance, batched=True)
    else:
        dataset = stream.to_dataset(
            features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
        ).with_transform(loads_instance)

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset
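The docstring example covers the query and card forms; the sketch below additionally exercises the split and streaming arguments (catalog names reused from the docstring, assumed to exist in the local catalog):

    # one split, materialized as a regular Dataset
    test_data = load_dataset(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
        split="test",
    )

    # all splits, loaded lazily instead of materialized
    streamed = load_dataset(
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        streaming=True,
    )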


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["packages_freeze"] = {
        d.project_name: d.version for d in pkg_resources.working_set
    }
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result
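A hedged end-to-end sketch of how the metadata attached by load_dataset() above reaches the evaluation results; the predictions here are placeholders rather than real model output, and the catalog names are the same assumptions as before:

    dataset = load_dataset(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
        split="test",
        loader_limit=10,
    )
    predictions = ["entailment"] * len(dataset)  # placeholder predictions
    results = evaluate(predictions=predictions, data=dataset)
    # the recipe arguments and unitxt_version recorded by fill_metadata() travel along
    print(results.metadata["dataset"]["unitxt_version"])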


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_instance)
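A rough sketch of produce(): a single dict in gives one processed instance back, a list gives a Dataset. The query string is the docstring example again, and the instance field names below are assumptions about that task's schema:

    instance = {
        "text_a": "The drain is clogged.",      # assumed field name
        "text_b": "The drain needs cleaning.",  # assumed field name
    }
    rendered = produce(
        instance,
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
    )
    print(rendered["source"])  # the fully rendered prompt for this instance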


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
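A sketch of infer() with return_data=True. HFPipelineBasedInferenceEngine and its arguments are assumptions about what .inference provides; any configured InferenceEngine instance would do, and the instance and query mirror the produce() sketch above:

    engine = HFPipelineBasedInferenceEngine(  # assumed engine class and arguments
        model_name="google/flan-t5-small", max_new_tokens=32
    )
    enriched = infer(
        [{"text_a": "The drain is clogged.", "text_b": "The drain needs cleaning."}],
        engine=engine,
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        return_data=True,
    )
    print(enriched["prediction"])  # post-processed predictions added as a column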


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions