IBM / unitxt / build 14010321715
Pull Request #1644: Use elaborated cache key and use it for filelock semaphore
Merge 99e69969e into 4aad1e02f (github / web-flow)

22 Mar 2025 04:21PM UTC coverage: 80.26% (+0.03% from 80.233%)
1578 of 1957 branches covered (80.63%); branch coverage is included in the aggregate %.
9855 of 12288 relevant lines covered (80.2%), 0.8 hits per line.

Source File: src/unitxt/api.py (81.08% covered)

import hashlib
import inspect
import json
import os
import random
import time
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

import filelock
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError
from huggingface_hub import constants as hf_constants

from .artifact import fetch_artifact
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_batch
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]
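
# Usage sketch (illustrative): short_hex_hash() is used further below to derive a
# compact, deterministic cache key from a recipe's JSON signature. For example:
#
#     >>> short_hex_hash("hello")
#     '2cf24dba'
#     >>> short_hex_hash("hello", length=12)
#     '2cf24dba5fb0'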


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and keyword arguments at the same time. "
            "To load a dataset from a card in the local catalog, use the query only. "
            "Otherwise, use keyword arguments only to specify the properties of the dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or keyword arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe
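
# Usage sketch (illustrative; the card and template names are the ones used in the
# load_dataset() docstring below and must exist in the local catalog):
#
#     recipe = load_recipe(
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default"
#     )
#     # or, equivalently, via keyword arguments (DatasetRecipe attributes):
#     recipe = load_recipe(
#         card="cards.wnli",
#         template="templates.classification.multi_class.relation.default",
#         max_train_instances=5,
#     )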


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    data_classification_policy: Optional[List[str]] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional name of a single split to return
        data_classification_policy: optional data classification policy for the provided data (list of strings)
        **kwargs: arguments used to load the dataset from the created card (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(
        loader=LoadFromDictionary(
            data=data, data_classification_policy=data_classification_policy
        ),
        task=task,
    )
    return load_dataset(card=card, split=split, **kwargs)
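
# Usage sketch expanding the docstring example above. The instance fields and the
# template name are assumptions for illustration; they must match the chosen task
# and exist in the catalog:
#
#     test_set = [
#         {"question": "What is the capital of Canada?", "answers": ["Ottawa"]},
#         {"question": "How many days are there in a week?", "answers": ["7"]},
#     ]
#     dataset = create_dataset(
#         task="tasks.qa.open",
#         test_set=test_set,
#         split="test",
#         template="templates.qa.open.simple",  # hypothetical catalog name
#         format="formats.chatapi",
#     )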


def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
    lock_timeout=60,  # Timeout in seconds for acquiring the lock
):
    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(
        to_dict(source, object_to_str_without_addresses), sort_keys=True
    )
    config_name = "recipe-" + short_hex_hash(source_signature)
    hf_cache_home = hf_constants.HF_HOME
    lock_dir = os.path.join(hf_cache_home, "locks")
    os.makedirs(lock_dir, exist_ok=True)

    # Create a lock file path based on the dataset configuration
    lock_file = os.path.join(lock_dir, f"unitxt_{config_name}.lock")

    # Retry parameters for acquiring the lock
    max_attempts = 5
    base_wait = 5  # seconds

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,
            version=constants.version,
        )

        if split is not None:
            stream = {split: stream[split]}

        ds_builder._generators = stream

        for attempt in range(max_attempts):
            # Create a file lock with an appropriate timeout
            lock = filelock.FileLock(lock_file, timeout=300)  # 5 minutes

            try:
                with lock:
                    ds_builder.download_and_prepare(
                        verification_mode="no_checks",
                        download_mode=None if use_cache else "force_redownload",
                    )

                # If we reach here, the lock was successfully acquired and released
                if streaming:
                    return ds_builder.as_streaming_dataset(split=split)
                return ds_builder.as_dataset(
                    split=split, run_post_process=False, verification_mode="no_checks"
                )

            except filelock.Timeout:
                if attempt < max_attempts - 1:  # Not the last attempt
                    wait_time = base_wait * (2**attempt) + random.uniform(0, 1)
                    time.sleep(wait_time)
                else:
                    raise TimeoutError(
                        f"Could not acquire lock for {config_name} after {max_attempts} attempts"
                    )

    except DatasetGenerationError as e:
        raise e.__cause__
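
# Retry/backoff sketch for the filelock loop above (illustrative): with
# max_attempts=5 and base_wait=5, a filelock.Timeout on attempt k triggers a wait
# of roughly base_wait * 2**k seconds plus up to one second of jitter before the
# next attempt, i.e. about 5, 10, 20 and 40 seconds, after which a TimeoutError
# is raised:
#
#     >>> [5 * 2**attempt for attempt in range(4)]
#     [5, 10, 20, 40]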


def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading a dataset as a stream avoids loading all of the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load a dataset from a provided card that is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    # Record the call arguments as metadata on the returned dataset(s)
    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset
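
# Usage sketch complementing the docstring example above (the cards and templates
# must be present in the local catalog; the split/streaming/use_cache values are
# illustrative):
#
#     dataset = load_dataset(
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default",
#         split="test",
#     )
#     streamed = load_dataset(
#         "card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5",
#         streaming=True,
#         use_cache=True,
#     )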


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result
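
# End-to-end sketch (illustrative): evaluating model outputs against a dataset
# returned by load_dataset(). The predictions are placeholder strings, one per
# instance:
#
#     dataset = load_dataset(
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default",
#         split="test",
#     )
#     predictions = ["entailment"] * len(dataset)  # placeholder model outputs
#     results = evaluate(predictions=predictions, data=dataset)
#     print(results.metadata["creation_time"])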


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_batch)
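
# Usage sketch (illustrative): produce() renders raw instances into model-ready
# inputs using the same recipe specification accepted by load_dataset(). The
# instance fields are placeholders and must match the card's task definition:
#
#     instance = {"text_a": "...", "text_b": "..."}  # placeholder task fields
#     model_input = produce(
#         instance,
#         "card=cards.wnli,template=templates.classification.multi_class.relation.default",
#     )
#     print(model_input["source"])  # the rendered prompt sent to the model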


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions
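
# Usage sketch (illustrative): infer() chains produce(), the engine's inference
# call, and post_process(). The engine is a placeholder for any InferenceEngine
# instance or catalog reference (fetch_artifact() is applied to it):
#
#     instances = [...]  # raw task instances matching the card's task fields
#     engine = ...       # any InferenceEngine implementation
#     results = infer(
#         instances,
#         engine=engine,
#         dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
#         return_data=True,  # also return the dataset with 'prediction' and 'raw_prediction' columns
#     )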


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
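
# Usage sketch (illustrative): select() mirrors infer() but asks an
# OptionSelectingByLogProbsInferenceEngine to pick among the options of each
# rendered instance; the names below are placeholders:
#
#     choices = select(
#         instances,
#         engine=option_selecting_engine,  # must implement OptionSelectingByLogProbsInferenceEngine
#         dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
#     )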