IBM / unitxt, build 13101814348 (push event, GitHub, committer: web-flow)
02 Feb 2025 07:30PM UTC. Coverage: 79.304% (-0.02% from 79.32%)

Commit: Revisit huggingface cache policy (#1564)

* Revisit huggingface cache policy

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Enable streaming for LoadFromHFSpace and clean up commented code

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Disable Hugging Face datasets cache in CatalogPreparationTestCase

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Enable streaming for wiki_bio loader in TaskCard and update JSON configuration

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Add conditional test card execution for 'doqa_travel' subset in chat_rag_bench

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Enhance memory and performance logging in catalog preparation tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Return parallel execution to 1 and adjust modulo for deterministic test runs

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Try 1

Signed-off-by: elronbandel <elronbandel@gmail.com>

* try 1 fixed

Signed-off-by: elronbandel <elronbandel@gmail.com>

* trial 2

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Stop testing social iqa until problem resolved

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update social iqa card to use specific revision and enable testing

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Refactor translation card testing logic and remove unused dataset loading

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update head_qa card loader path and streamline dataset configuration

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Enable streaming for websrc card loader in configuration

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Add revision reference to Winogrande card loaders

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Add revision reference to PIQA card loader

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Another trial

Signed-off-by: elronbandel <e... (continued)

1451 of 1823 branches covered (79.59%)

Branch coverage included in aggregate %.

9163 of 11561 relevant lines covered (79.26%)

0.79 hits per line
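Since branch coverage is counted in the aggregate, the headline figure works out as (9163 + 1451) covered lines and branches over (11561 + 1823) relevant ones, i.e. 10614 / 13384 ≈ 79.30%, consistent with the 79.304% reported above.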

Source File: src/unitxt/api.py (74.23% covered)

In the listing below, executable lines that the test suite did not hit are flagged with a trailing "# not covered" comment; all other executable lines were covered.
import hashlib
import inspect
import json
import tempfile
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .card import TaskCard
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(  # not covered
            "Cannot provide 'dataset_query' and key-worded arguments at the same time. "
            "If you want to load dataset from a card in local catalog, use query only. "
            "Otherwise, use key-worded arguments only to specify properties of dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(  # not covered
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(  # not covered
            "Either 'dataset_query' or key-worded arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query  # not covered

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates dataset from input data based on a specific task.

    Args:
        task:  The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set : required list of instances
        train_set : optional train_set
        validation_set: optional validation set
        split: optional one split to choose
        **kwargs: Arguments used to load dataset from provided datasets (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}  # not covered
    if train_set is not None:  # not covered
        data["train"] = train_set  # not covered
    if validation_set is not None:  # not covered
        data["validation"] = validation_set  # not covered
    task, _ = fetch_artifact(task)  # not covered

    if "template" not in kwargs and task.default_template is None:  # not covered
        raise Exception(  # not covered
            f"No 'template' was passed to the create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)  # not covered
    return load_dataset(card=card, split=split, **kwargs)  # not covered


def _source_to_dataset(
    source: SourceOperator, split=None, use_cache=False, streaming=False
):
    from .dataset import Dataset as UnitxtDataset

    stream = source()

    with tempfile.TemporaryDirectory() as dir_to_be_deleted:
        cache_dir = dir_to_be_deleted if not use_cache else None
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name="recipe-" + short_hex_hash(source.to_json()),
            hash=hash(source.to_json()),
            version=constants.version,
            cache_dir=cache_dir,
        )
        if split is not None:
            stream = {split: stream[split]}
        ds_builder._generators = stream

        try:
            ds_builder.download_and_prepare()

            if streaming:
                return ds_builder.as_streaming_dataset(split=split)  # not covered

            return ds_builder.as_dataset(
                split=split, run_post_process=False, verification_mode="no_checks"
            )
        except DatasetGenerationError as e:  # not covered
            raise e.__cause__  # not covered


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads dataset.

    If the 'dataset_query' argument is provided, then dataset is loaded from a card
    in local catalog based on parameters specified in the query.

    Alternatively, dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from
            local catalog or name of specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoid loading all the data to memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load dataset from provided card, which is not present in local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")  # not covered
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata  # not covered
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_instance)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):  # not covered
            example["source"] = previous_messages[index] + example["source"]  # not covered
            return example  # not covered

        dataset = dataset.map(add_previous_messages, with_indices=True)  # not covered
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):  # not covered
            raise NotImplementedError(  # not covered
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)  # not covered
        raw_predictions = (  # not covered
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [  # not covered
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [  # not covered
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:  # not covered
                del infer_output["prediction"]  # not covered
            dataset = dataset.add_column("infer_meta_data", infer_output_list)  # not covered
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)  # not covered
    if previous_messages is not None:  # not covered

        def add_previous_messages(example, index):  # not covered
            example["source"] = previous_messages[index] + example["source"]  # not covered
            return example  # not covered

        dataset = dataset.map(add_previous_messages, with_indices=True)  # not covered
    engine, _ = fetch_artifact(engine)  # not covered
    predictions = engine.select(dataset)  # not covered
    # predictions = post_process(raw_predictions, dataset)
    if return_data:  # not covered
        return dataset.add_column("prediction", predictions)  # not covered
    return predictions  # not covered
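
For orientation, here is a minimal usage sketch of the public entry points defined in this file. It is illustrative only and not part of the covered source: the card and template identifiers are the same catalog entries used in the docstrings above, and the placeholder predictions merely show the call shape of evaluate().

    from unitxt.api import evaluate, load_dataset

    # Load the test split of a catalog dataset; the card and template must be
    # present in the local catalog, as noted in the load_dataset() docstring.
    dataset = load_dataset(
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
    )

    # Predictions would normally come from infer() with an InferenceEngine;
    # constant placeholder strings are used here only to illustrate evaluate().
    predictions = ["entailment"] * len(dataset)

    results = evaluate(predictions=predictions, dataset=dataset)
    print(results)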