IBM / unitxt, Coveralls build 13181104872

06 Feb 2025 02:33PM UTC coverage: 79.34% (-0.005% from 79.345%)

Pull Request #1582: Fix attempt to missing arrow dataset
Merge f60e212fd into 85195fae9 (via github web-flow)

1454 of 1823 branches covered (79.76%); branch coverage is included in the aggregate %.
9176 of 11575 relevant lines covered (79.27%), 0.79 hits per line.

Source file: src/unitxt/api.py (79.06% covered). In the listing below, lines not executed by the test suite are marked with a trailing "# not covered" comment.
import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(  # not covered
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(  # not covered
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(  # not covered
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query  # not covered

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train_set
        validation_set: optional validation set
        split: optional one split to choose
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set  # not covered
    if validation_set is not None:
        data["validation"] = validation_set  # not covered
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(  # not covered
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
    return load_dataset(card=card, split=split, **kwargs)


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(repr(source)),
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        ds_builder.download_and_prepare(
            verification_mode="no_checks",
            download_mode=None if use_cache else "force_redownload",
        )

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)  # not covered

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:  # not covered
        raise e.__cause__  # not covered


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")  # not covered
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata  # not covered
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_instance)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):  # not covered
            example["source"] = previous_messages[index] + example["source"]  # not covered
            return example  # not covered

        dataset = dataset.map(add_previous_messages, with_indices=True)  # not covered
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):  # not covered
            raise NotImplementedError(  # not covered
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)  # not covered
        raw_predictions = (  # not covered
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [  # not covered
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [  # not covered
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:  # not covered
                del infer_output["prediction"]  # not covered
            dataset = dataset.add_column("infer_meta_data", infer_output_list)  # not covered
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)  # not covered
    if previous_messages is not None:  # not covered

        def add_previous_messages(example, index):  # not covered
            example["source"] = previous_messages[index] + example["source"]  # not covered
            return example  # not covered

        dataset = dataset.map(add_previous_messages, with_indices=True)  # not covered
    engine, _ = fetch_artifact(engine)  # not covered
    predictions = engine.select(dataset)  # not covered
    # predictions = post_process(raw_predictions, dataset)
    if return_data:  # not covered
        return dataset.add_column("prediction", predictions)  # not covered
    return predictions  # not covered