IBM / unitxt / 15530158291

09 Jun 2025 08:23AM UTC coverage: 80.267% (+0.03%) from 80.242%
Build 15530158291

Pull Request #1644: Use elaborated cache key and use it for filelock semaphore
Merge d45934119 into b3a894d7c (github / web-flow)

1696 of 2089 branches covered (81.19%)

Branch coverage included in aggregate %.

10511 of 13119 relevant lines covered (80.12%)

0.8 hits per line

Source File: src/unitxt/api.py (77.93% covered)
import hashlib
import inspect
import json
import os
import random
import time
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

import filelock
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError
from huggingface_hub import constants as hf_constants

from .artifact import fetch_artifact
from .benchmark import Benchmark
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(
    dataset_query: str, overwrite_kwargs: Optional[Dict[str, Any]] = None
) -> DatasetRecipe:
    try:
        dataset_stream, _ = fetch_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    except Exception:
        dataset_stream = get_dataset_artifact(
            dataset_query, overwrite_kwargs=overwrite_kwargs
        )
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and keyword arguments at the same time. "
            "If you want to load a dataset from a card in the local catalog, use the query only. "
            "Otherwise, use keyword arguments only to specify the properties of the dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or keyword arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, (DatasetRecipe, Benchmark)):
        return dataset_query

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query, kwargs)

    elif kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    else:
        raise UnitxtError(
            "Specify either a dataset recipe artifact name string or recipe args."
        )

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional single split to choose
        data_classification_policy: data classification policy
        **kwargs: arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)

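# --- Editor's illustrative sketch (not part of api.py) ----------------------
# A minimal, hedged usage example of create_dataset() with an in-memory test
# set. It assumes the catalog task "tasks.qa.open" (question -> answers) and a
# template such as "templates.qa.open.simple" are available in the local
# catalog; adjust the names to whatever your catalog provides.
#
#     from unitxt import create_dataset
#
#     test_set = [
#         {"question": "What is the capital of France?", "answers": ["Paris"]},
#         {"question": "What is 2 + 2?", "answers": ["4"]},
#     ]
#     dataset = create_dataset(
#         task="tasks.qa.open",
#         test_set=test_set,
#         template="templates.qa.open.simple",  # assumed catalog template
#         split="test",
#     )
# -----------------------------------------------------------------------------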
def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
    lock_timeout=60,  # Timeout in seconds for acquiring the lock
):
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    hf_cache_home = hf_constants.HF_HOME
    lock_dir = os.path.join(hf_cache_home, "locks")
    os.makedirs(lock_dir, exist_ok=True)

    # Create a lock file path based on the dataset configuration
    lock_file = os.path.join(lock_dir, f"unitxt_{config_name}.lock")

    # Add retry logic
    max_attempts = 5
    base_wait = 5  # seconds

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,
            version=constants.version,
        )

        if split is not None:
            stream = {split: stream[split]}

        ds_builder._generators = stream

        for attempt in range(max_attempts):
            # Create a file lock with appropriate timeout
            lock = filelock.FileLock(lock_file, timeout=300)  # 5 minutes

            try:
                with lock:
                    ds_builder.download_and_prepare(
                        verification_mode="no_checks",
                        download_mode=None if use_cache else "force_redownload",
                    )

                # If we reach here, the lock was successfully acquired and released
                if streaming:
                    return ds_builder.as_streaming_dataset(split=split)
                return ds_builder.as_dataset(
                    split=split, run_post_process=False, verification_mode="no_checks"
                )

            except filelock.Timeout:
                if attempt < max_attempts - 1:  # Not the last attempt
                    wait_time = base_wait * (2 ** attempt) + random.uniform(0, 1)
                    time.sleep(wait_time)
                else:
                    raise TimeoutError(
                        f"Could not acquire lock for {config_name} after {max_attempts} attempts"
                    )

    except DatasetGenerationError as e:
        raise e.__cause__

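# --- Editor's illustrative sketch (not part of api.py) ----------------------
# The retry pattern used above, in isolation: processes that derive the same
# cache key compete for one lock file, and losers back off exponentially with
# jitter before retrying. A minimal, hedged sketch using only `filelock` and
# the stdlib; the helper name, path, and timeout values are illustrative.
#
#     import random
#     import time
#
#     import filelock
#
#     def with_file_lock(lock_path, work, max_attempts=5, base_wait=5):
#         for attempt in range(max_attempts):
#             lock = filelock.FileLock(lock_path, timeout=300)
#             try:
#                 with lock:
#                     return work()  # only one process runs this at a time
#             except filelock.Timeout:
#                 if attempt == max_attempts - 1:
#                     raise
#                 time.sleep(base_wait * (2 ** attempt) + random.uniform(0, 1))
#
#     # e.g. with_file_lock("/tmp/unitxt_recipe-deadbeef.lock", lambda: "prepared")
# -----------------------------------------------------------------------------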
def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    If both are given, the textual recipe is loaded, with the keyword args overriding the textual recipe args.

    Args:
        dataset_query (str, optional):
            A string query which specifies the dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load a dataset from a provided card that is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset

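# --- Editor's illustrative sketch (not part of api.py) ----------------------
# A minimal, hedged end-to-end use of load_dataset() and evaluate() from this
# module. It assumes "cards.wnli" and the referenced template exist in the
# local catalog (as in the docstring example above), that `predictions` is a
# list of model output strings aligned with the test instances, and that the
# returned EvaluationResults exposes `global_scores.summary`.
#
#     from unitxt import evaluate, load_dataset
#
#     dataset = load_dataset(
#         dataset_query="card=cards.wnli,"
#         "template=templates.classification.multi_class.relation.default",
#         split="test",
#         use_cache=True,  # reuse the prepared dataset on later runs
#     )
#     predictions = ["entailment"] * len(dataset)  # stand-in model outputs
#     results = evaluate(predictions=predictions, dataset=dataset)
#     print(results.global_scores.summary)
# -----------------------------------------------------------------------------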
def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions,
    dataset: Union[Dataset, IterableDataset] = None,
    data=None,
    calc_confidence_intervals: bool = True,
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(
        predictions=predictions,
        references=dataset,
        calc_confidence_intervals=calc_confidence_intervals,
    )
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions