IBM / unitxt / build 12535194121

29 Dec 2024 12:03PM UTC. Coverage: 80.228% (+0.2%) from 80.023%.

Pull Request #1459 (github / web-flow): Add MapReduceMetric, a new base class to integrate all metrics into
Merge 7067995c0 into def3e0ea1

1365 of 1695 branches covered (80.53%). Branch coverage is included in the aggregate %.
8629 of 10762 relevant lines covered (80.18%)
0.8 hits per line

Source File: src/unitxt/api.py (66.67%)
import json
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import UNITXT_DATASET_SCHEMA, loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def load(source: Union[SourceOperator, str]):
    assert isinstance(
        source, (SourceOperator, str)
    ), "source must be a SourceOperator or a string"
    if isinstance(source, str):
        source, _ = fetch_artifact(source)
    return source().to_dataset()
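

# A minimal usage sketch (not part of api.py): load() accepts either a
# SourceOperator instance or a catalog reference to one, and materializes it
# as a Dataset. The catalog name below is hypothetical; substitute a recipe or
# benchmark that actually exists in your local catalog.
def _example_load():
    return load("recipes.my_recipe")  # hypothetical catalog entry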


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe
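

# A minimal usage sketch (not part of api.py), assuming the catalog entries
# named below exist locally. load_recipe() accepts either a single query string
# or keyword arguments that map onto DatasetRecipe fields, but not both at once.
def _example_load_recipe():
    recipe = load_recipe(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    )
    # Equivalently, as keyword arguments (each must be a DatasetRecipe field):
    recipe = load_recipe(
        card="cards.wnli",
        template="templates.classification.multi_class.relation.default",
        max_train_instances=5,
    )
    return recipe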


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional single split to return
        **kwargs: arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
    return load_dataset(card=card, split=split, **kwargs)
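

# A minimal usage sketch (not part of api.py). The instance field names
# ("question"/"answers") and the template name are assumptions about the
# tasks.qa.open schema and catalog; adjust them to the task you actually use.
def _example_create_dataset():
    test_set = [
        {"question": "What is the capital of Texas?", "answers": ["Austin"]},
        {"question": "What is the color of the sky?", "answers": ["Blue"]},
    ]
    return create_dataset(
        task="tasks.qa.open",
        test_set=test_set,
        split="test",
        template="templates.qa.open.simple",  # assumed catalog template name
    )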


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    disable_cache: Optional[bool] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a Unitxt streams dictionary
        split (str, optional):
            The split of the data to load
        disable_cache (bool, optional):
            Disables caching of the processed data
        **kwargs:
            Arguments used to load a dataset from a provided card, which is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    stream = recipe()
    if split is not None:
        stream = stream[split]

    if disable_cache is None:
        disable_cache = settings.disable_hf_datasets_cache

    if streaming:
        return stream.to_iterable_dataset(
            features=UNITXT_DATASET_SCHEMA,
        ).map(loads_instance, batched=True)

    return stream.to_dataset(
        features=UNITXT_DATASET_SCHEMA, disable_cache=disable_cache
    ).with_transform(loads_instance)
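

# A minimal usage sketch (not part of api.py), complementing the docstring
# above: the same catalog query can be loaded as a single split, streamed, or
# with the HF datasets cache disabled. Card and template names assume the
# local catalog.
def _example_load_dataset_variants():
    query = "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    test_split = load_dataset(query, split="test", disable_cache=True)
    streamed = load_dataset(query, split="test", streaming=True)
    return test_split, streamed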


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    return _compute(predictions=predictions, references=dataset)
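

# A minimal usage sketch (not part of api.py): evaluate() scores predictions
# against a dataset built by load_dataset(). Reading the aggregate scores via
# results[0]["score"]["global"] is an assumption about how the
# EvaluationResults returned by _compute() is laid out.
def _example_evaluate():
    dataset = load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
    )
    predictions = ["entailment"] * len(dataset)
    results = evaluate(predictions=predictions, data=dataset)
    return results[0]["score"]["global"]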


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_instance)
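

# A minimal usage sketch (not part of api.py): produce() pushes raw instances
# through the recipe's processing pipeline without loading the card's own data.
# A single dict returns a processed dict; a list returns a Dataset. The query
# below assumes those catalog entries exist and that the instance fields match
# the card's task.
def _example_produce(raw_instance: Dict[str, Any]):
    query = "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    single = produce(raw_instance, query)  # dict in -> processed dict out
    batch = produce([raw_instance], query)  # list in -> Dataset out
    return single, batch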


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
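

# A minimal usage sketch (not part of api.py): infer() produces the dataset,
# runs the engine, post-processes the raw outputs, and, with return_data=True,
# returns the data with prediction columns attached. A list of instances is
# passed so the result is a Dataset. HFPipelineBasedInferenceEngine and the
# model name are assumptions; any InferenceEngine, or a catalog reference
# resolvable by fetch_artifact, should work.
def _example_infer(raw_instances: List[Dict[str, Any]]):
    from .inference import HFPipelineBasedInferenceEngine  # assumed engine class

    engine = HFPipelineBasedInferenceEngine(
        model_name="google/flan-t5-small", max_new_tokens=32
    )
    return infer(
        raw_instances,
        engine=engine,
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        return_data=True,
    )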


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
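

# A minimal usage sketch (not part of api.py): select() asks an engine that
# implements OptionSelectingByLogProbsInferenceEngine to pick, per instance,
# the option with the highest log-probability. The engine is passed in rather
# than constructed, since this file does not show which concrete engines
# support option selection. A list of instances is used so return_data=True
# yields a Dataset.
def _example_select(
    raw_instances: List[Dict[str, Any]],
    engine: OptionSelectingByLogProbsInferenceEngine,
):
    query = "card=cards.wnli,template=templates.classification.multi_class.relation.default"
    return select(raw_instances, engine=engine, dataset_query=query, return_data=True)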