IBM / unitxt — build 14955009337

11 May 2025 10:44AM UTC — coverage: 80.026% (down 0.05% from 80.074%)

Push build via GitHub, committed by web-flow:

Add support to mix args and textual query in load_dataset (#1778)

Signed-off-by: elronbandel <elronbandel@gmail.com>

1644 of 2040 branches covered (80.59%); branch coverage is included in the aggregate %.

10251 of 12824 relevant lines covered (79.94%), 0.8 hits per line.
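
The change under test lets load_dataset() accept a textual recipe query together with keyword arguments, with the keyword arguments overriding the corresponding fields of the textual recipe (see the load_dataset docstring in the source below). A minimal usage sketch, assuming the card and template names exist in the local catalog:

    from unitxt.api import load_dataset

    # Textual query combined with keyword overrides: the keyword argument
    # max_train_instances=5 takes precedence over the value in the query string.
    # The card and template names are illustrative catalog entries.
    dataset = load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=20",
        max_train_instances=5,
        split="train",
    )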

Source File

src/unitxt/api.py — 76.44% covered (uncovered lines are marked "# not covered" in the listing below)
import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]]=None) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
    except:
        dataset_stream = get_dataset_artifact(dataset_query, overwrite_kwargs=overwrite_kwargs)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:  # not covered
        raise ValueError(  # not covered
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:  # not covered
        if not isinstance(dataset_query, str):  # not covered
            raise ValueError(  # not covered
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:  # not covered
        raise ValueError(  # not covered
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query  # not covered

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError("Specify either dataset recipe string artifact name or recipe args.")  # not covered

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy:  Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task:  The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set : required list of instances
        train_set : optional train_set
        validation_set: optional validation set
        split: optional one split to choose
        data_classification_policy: data_classification_policy
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set  # not covered
    if validation_set is not None:
        data["validation"] = validation_set  # not covered
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(  # not covered
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data, data_classification_policy=data_classification_policy), task=task)
    return load_dataset(card=card, split=split, **kwargs)


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(repr(source)),
            version=constants.version,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        ds_builder.download_and_prepare(
            verification_mode="no_checks",
            download_mode=None if use_cache else "force_redownload",
        )

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)  # not covered

        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )

    except DatasetGenerationError as e:  # not covered
        raise e.__cause__  # not covered


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, then the textual recipe is loaded with the key word args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")  # not covered
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata  # not covered
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):  # not covered
            raise NotImplementedError(  # not covered
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)  # not covered
        raw_predictions = (  # not covered
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [  # not covered
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [  # not covered
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:  # not covered
                del infer_output["prediction"]  # not covered
            dataset = dataset.add_column("infer_meta_data", infer_output_list)  # not covered
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)  # not covered
    if previous_messages is not None:  # not covered

        def add_previous_messages(example, index):  # not covered
            example["source"] = previous_messages[index] + example["source"]  # not covered
            return example  # not covered

        dataset = dataset.map(add_previous_messages, with_indices=True)  # not covered
    engine, _ = fetch_artifact(engine)  # not covered
    predictions = engine.select(dataset)  # not covered
    # predictions = post_process(raw_predictions, dataset)
    if return_data:  # not covered
        return dataset.add_column("prediction", predictions)  # not covered
    return predictions  # not covered
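
For orientation, here is a minimal sketch of how the public functions in this file compose: load_dataset() builds the processed dataset and records its recipe metadata, and evaluate() scores predictions against it. The card, template, recipe limits, and stubbed predictions are illustrative assumptions, not part of the build above:

    from unitxt.api import evaluate, load_dataset

    # Build a small test split from a catalog card (names assumed to exist locally).
    dataset = load_dataset(
        "card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
        max_test_instances=8,
    )

    # One raw model output per instance "source"; stubbed so the sketch is self-contained.
    predictions = ["entailment" for _ in dataset]

    # evaluate() computes the task metrics and copies the dataset metadata that
    # load_dataset() stored in dataset.info.description into the result's metadata.
    results = evaluate(predictions, dataset=dataset)
    print(results.metadata["dataset"])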