IBM / unitxt / build 13674045039

05 Mar 2025 10:38AM UTC coverage: 80.442% (+0.04%) from 80.404%

Pull Request #1644: Use elaborated cache key and use it for filelock semaphore (see the sketch below)
Merge 88875e9f7 into 2ab381920 (github, web-flow)

1551 of 1928 branches covered (80.45%); branch coverage is included in the aggregate %.
9723 of 12087 relevant lines covered (80.44%)
0.8 hits per line
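
The PR title refers to the cache-key and file-lock logic in _source_to_dataset in the listing below. What follows is a minimal sketch of that pattern, not the unitxt implementation itself (the helper names signature_for and prepare_once, and the recipe dict passed to them, are assumptions made for this illustration): a stable, address-free signature of the recipe is hashed into a short config name, and that same name keys both the prepared dataset config and its lock file, so concurrent processes preparing the same recipe serialize on a single lock.

    import hashlib
    import json
    import os

    import filelock  # third-party dependency, as used in the listing below


    def signature_for(recipe: dict) -> str:
        # Deterministic, address-free description of the recipe.
        # In api.py this role is played by to_dict(source, object_to_str_without_addresses).
        return json.dumps(recipe, sort_keys=True)


    def short_hex_hash(value: str, length: int = 8) -> str:
        return hashlib.sha256(value.encode()).hexdigest()[:length]


    def prepare_once(recipe: dict, lock_timeout: int = 60) -> str:
        # Hypothetical helper: prepare a dataset config at most once across processes.
        config_name = "recipe-" + short_hex_hash(signature_for(recipe))

        # The same elaborated key names both the cached config and its lock file.
        lock_file = os.path.join(
            os.path.expanduser("~"), ".cache", "unitxt", f"{config_name}.lock"
        )
        os.makedirs(os.path.dirname(lock_file), exist_ok=True)

        try:
            with filelock.FileLock(lock_file, timeout=lock_timeout):
                pass  # download_and_prepare(...) would run here, guarded by the lock
        except filelock.Timeout:
            raise TimeoutError(
                f"Could not acquire lock for {config_name} within {lock_timeout} seconds."
            )

        return config_name

Because the signature strips memory addresses, two runs that build an identical recipe map to the same config name (and the same lock), while different recipes get different locks and never block each other.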

Source File: src/unitxt/api.py (81.9% covered)

import hashlib
import inspect
import json
from datetime import datetime
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.exceptions import DatasetGenerationError

from .artifact import fetch_artifact
from .card import TaskCard
from .dataclass import to_dict
from .dataset_utils import get_dataset_artifact
from .error_utils import UnitxtError
from .inference import (
    InferenceEngine,
    LogProbInferenceEngine,
    OptionSelectingByLogProbsInferenceEngine,
)
from .loaders import LoadFromDictionary
from .logging_utils import get_logger
from .metric_utils import EvaluationResults, _compute, _inference_post_process
from .operator import SourceOperator
from .schema import loads_instance
from .settings_utils import get_constants, get_settings
from .standard import DatasetRecipe
from .task import Task

logger = get_logger()
constants = get_constants()
settings = get_settings()


def short_hex_hash(value, length=8):
    h = hashlib.sha256(value.encode()).hexdigest()  # Full 64-character hex
    return h[:length]


def _get_recipe_from_query(dataset_query: str) -> DatasetRecipe:
    dataset_query = dataset_query.replace("sys_prompt", "instruction")
    try:
        dataset_stream, _ = fetch_artifact(dataset_query)
    except:
        dataset_stream = get_dataset_artifact(dataset_query)
    return dataset_stream


def _get_recipe_from_dict(dataset_params: Dict[str, Any]) -> DatasetRecipe:
    recipe_attributes = list(DatasetRecipe.__dict__["__fields__"].keys())
    for param in dataset_params.keys():
        assert param in recipe_attributes, (
            f"The parameter '{param}' is not an attribute of the 'DatasetRecipe' class. "
            f"Please check if the name is correct. The available attributes are: '{recipe_attributes}'."
        )
    return DatasetRecipe(**dataset_params)


def _verify_dataset_args(dataset_query: Optional[str] = None, dataset_args=None):
    if dataset_query and dataset_args:
        raise ValueError(
            "Cannot provide 'dataset_query' and keyword arguments at the same time. "
            "If you want to load a dataset from a card in the local catalog, use the query only. "
            "Otherwise, use keyword arguments only to specify the properties of the dataset."
        )

    if dataset_query:
        if not isinstance(dataset_query, str):
            raise ValueError(
                f"If specified, 'dataset_query' must be a string, however, "
                f"'{dataset_query}' was provided instead, which is of type "
                f"'{type(dataset_query)}'."
            )

    if not dataset_query and not dataset_args:
        raise ValueError(
            "Either 'dataset_query' or keyword arguments must be provided."
        )


def load_recipe(dataset_query: Optional[str] = None, **kwargs) -> DatasetRecipe:
    if isinstance(dataset_query, DatasetRecipe):
        return dataset_query

    _verify_dataset_args(dataset_query, kwargs)

    if dataset_query:
        recipe = _get_recipe_from_query(dataset_query)

    if kwargs:
        recipe = _get_recipe_from_dict(kwargs)

    return recipe


def create_dataset(
    task: Union[str, Task],
    test_set: List[Dict[Any, Any]],
    train_set: Optional[List[Dict[Any, Any]]] = None,
    validation_set: Optional[List[Dict[Any, Any]]] = None,
    split: Optional[str] = None,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Creates a dataset from input data based on a specific task.

    Args:
        task: The name of the task from the Unitxt Catalog (https://www.unitxt.ai/en/latest/catalog/catalog.tasks.__dir__.html)
        test_set: required list of instances
        train_set: optional train set
        validation_set: optional validation set
        split: optional; a single split to return
        **kwargs: Arguments used to load the dataset from the provided data (see load_dataset())

    Returns:
        DatasetDict

    Example:
        template = Template(...)
        dataset = create_dataset(task="tasks.qa.open", template=template, format="formats.chatapi")
    """
    data = {"test": test_set}
    if train_set is not None:
        data["train"] = train_set
    if validation_set is not None:
        data["validation"] = validation_set
    task, _ = fetch_artifact(task)

    if "template" not in kwargs and task.default_template is None:
        raise Exception(
            f"No 'template' was passed to create_dataset() and the given task ('{task.__id__}') has no 'default_template' field."
        )

    card = TaskCard(loader=LoadFromDictionary(data=data), task=task)
    return load_dataset(card=card, split=split, **kwargs)

def object_to_str_without_addresses(obj):
    """Generates a string representation of a Python object while removing memory address references.

    This function is useful for creating consistent and comparable string representations of objects
    that would otherwise include memory addresses (e.g., `<object_name at 0x123abc>`), which can vary
    between executions. By stripping the memory address, the function ensures that the representation
    is stable and independent of the object's location in memory.

    Args:
        obj: Any Python object to be converted to a string representation.

    Returns:
        str: A string representation of the object with memory addresses removed if present.

    Example:
        ```python
        class MyClass:
            pass

        obj = MyClass()
        print(str(obj))  # "<__main__.MyClass object at 0x7f8b9d4d6e20>"
        print(object_to_str_without_addresses(obj))  # "<__main__.MyClass object>"
        ```
    """
    obj_str = str(obj)
    if " at 0x" in obj_str:
        obj_str = obj_str.split(" at 0x")[0] + ">"
    return obj_str


def _source_to_dataset(
    source: SourceOperator,
    split=None,
    use_cache=False,
    streaming=False,
    lock_timeout=60,  # Timeout in seconds for acquiring the lock
):
    import json
    import os

    import filelock

    from .dataset import Dataset as UnitxtDataset

    # Generate a unique signature for the source
    source_signature = json.dumps(to_dict(source, object_to_str_without_addresses), sort_keys=True)
    config_name = "recipe-" + short_hex_hash(source_signature)

    stream = source()

    try:
        ds_builder = UnitxtDataset(
            dataset_name="unitxt",
            config_name=config_name,
            version=constants.version,
        )

        if split is not None:
            stream = {split: stream[split]}

        ds_builder._generators = stream

        # Create a lock file path based on the dataset configuration
        lock_file = os.path.join(os.path.expanduser("~"), ".cache", "unitxt", f"{config_name}.lock")
        os.makedirs(os.path.dirname(lock_file), exist_ok=True)

        # Create a file lock
        lock = filelock.FileLock(lock_file, timeout=lock_timeout)

        # Only protect the download_and_prepare operation with the lock
        try:
            with lock:
                ds_builder.download_and_prepare(
                    verification_mode="no_checks",
                    download_mode=None if use_cache else "force_redownload",
                )
        except filelock.Timeout:
            raise TimeoutError(f"Could not acquire lock for {config_name} within {lock_timeout} seconds. Another process may be preparing the same dataset.")

        if streaming:
            return ds_builder.as_streaming_dataset(split=split)
        return ds_builder.as_dataset(
            split=split, run_post_process=False, verification_mode="no_checks"
        )
    except DatasetGenerationError as e:
        raise e.__cause__

def load_dataset(
    dataset_query: Optional[str] = None,
    split: Optional[str] = None,
    streaming: bool = False,
    use_cache: Optional[bool] = False,
    **kwargs,
) -> Union[DatasetDict, IterableDatasetDict, Dataset, IterableDataset]:
    """Loads a dataset.

    If the 'dataset_query' argument is provided, the dataset is loaded from a card
    in the local catalog based on the parameters specified in the query.

    Alternatively, the dataset is loaded from a provided card based on explicitly
    given parameters.

    Args:
        dataset_query (str, optional):
            A string query which specifies a dataset to load from the
            local catalog, or the name of a specific recipe or benchmark in the catalog. For
            example, ``"card=cards.wnli,template=templates.classification.multi_class.relation.default"``.
        streaming (bool, False):
            When True, yields the data as a stream.
            This is useful when loading very large datasets.
            Loading datasets as streams avoids loading all the data into memory, but requires the dataset's loader to support streaming.
        split (str, optional):
            The split of the data to load.
        use_cache (bool, optional):
            If set to True, the returned Huggingface dataset is cached on local disk such that if the same dataset is loaded again, it will be loaded from local disk, resulting in faster runs.
            If set to False (default), the returned dataset is not cached.
            Note that if caching is enabled and the dataset card definition is changed, the old version in the cache may be returned.
            Enable caching only if you are sure you are working with fixed Unitxt datasets and definitions (e.g. running using predefined datasets from the Unitxt catalog).
        **kwargs:
            Arguments used to load the dataset from a provided card that is not present in the local catalog.

    Returns:
        DatasetDict

    :Example:

        .. code-block:: python

            dataset = load_dataset(
                dataset_query="card=cards.stsb,template=templates.regression.two_texts.simple,max_train_instances=5"
            )  # card and template must be present in local catalog

            # or built programmatically
            card = TaskCard(...)
            template = Template(...)
            loader_limit = 10
            dataset = load_dataset(card=card, template=template, loader_limit=loader_limit)

    """
    recipe = load_recipe(dataset_query, **kwargs)

    dataset = _source_to_dataset(
        source=recipe, split=split, use_cache=use_cache, streaming=streaming
    )

    frame = inspect.currentframe()
    args, _, _, values = inspect.getargvalues(frame)
    all_kwargs = {key: values[key] for key in args if key != "kwargs"}
    all_kwargs.update(kwargs)
    metadata = fill_metadata(**all_kwargs)
    if isinstance(dataset, dict):
        for ds in dataset.values():
            ds.info.description = metadata.copy()
    else:
        dataset.info.description = metadata
    return dataset


def fill_metadata(**kwargs):
    metadata = kwargs.copy()
    metadata["unitxt_version"] = get_constants().version
    metadata["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return metadata


def evaluate(
    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
) -> EvaluationResults:
    if dataset is None and data is None:
        raise UnitxtError(message="Specify 'dataset' in evaluate")
    if data is not None:
        dataset = data  # for backward compatibility
    evaluation_result = _compute(predictions=predictions, references=dataset)
    if hasattr(dataset, "info") and hasattr(dataset.info, "description"):
        evaluation_result.metadata["dataset"] = dataset.info.description
    if hasattr(predictions, "metadata"):
        evaluation_result.metadata["predictions"] = predictions.metadata
    evaluation_result.metadata["creation_time"] = datetime.now().strftime(
        "%Y-%m-%d %H:%M:%S.%f"
    )[:-3]
    return evaluation_result


def post_process(predictions, data) -> List[Dict[str, Any]]:
    return _inference_post_process(predictions=predictions, references=data)


@lru_cache
def _get_produce_with_cache(dataset_query: Optional[str] = None, **kwargs):
    return load_recipe(dataset_query, **kwargs).produce


def produce(
    instance_or_instances, dataset_query: Optional[str] = None, **kwargs
) -> Union[Dataset, Dict[str, Any]]:
    is_list = isinstance(instance_or_instances, list)
    if not is_list:
        instance_or_instances = [instance_or_instances]
    result = _get_produce_with_cache(dataset_query, **kwargs)(instance_or_instances)
    if not is_list:
        return result[0]
    return Dataset.from_list(result).with_transform(loads_instance)


def infer(
    instance_or_instances,
    engine: InferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    return_log_probs: bool = False,
    return_meta_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    if return_log_probs:
        if not isinstance(engine, LogProbInferenceEngine):
            raise NotImplementedError(
                f"Error in infer: return_log_probs set to True but supplied engine "
                f"{engine.__class__.__name__} does not support logprobs."
            )
        infer_outputs = engine.infer_log_probs(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
        raw_predictions = [
            json.dumps(raw_prediction) for raw_prediction in raw_predictions
        ]
    else:
        infer_outputs = engine.infer(dataset, return_meta_data)
        raw_predictions = (
            [output.prediction for output in infer_outputs]
            if return_meta_data
            else infer_outputs
        )
    predictions = post_process(raw_predictions, dataset)
    if return_data:
        if return_meta_data:
            infer_output_list = [
                infer_output.__dict__ for infer_output in infer_outputs
            ]
            for infer_output in infer_output_list:
                del infer_output["prediction"]
            dataset = dataset.add_column("infer_meta_data", infer_output_list)
        dataset = dataset.add_column("prediction", predictions)
        return dataset.add_column("raw_prediction", raw_predictions)
    return predictions


def select(
    instance_or_instances,
    engine: OptionSelectingByLogProbsInferenceEngine,
    dataset_query: Optional[str] = None,
    return_data: bool = False,
    previous_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
):
    dataset = produce(instance_or_instances, dataset_query, **kwargs)
    if previous_messages is not None:

        def add_previous_messages(example, index):
            example["source"] = previous_messages[index] + example["source"]
            return example

        dataset = dataset.map(add_previous_messages, with_indices=True)
    engine, _ = fetch_artifact(engine)
    predictions = engine.select(dataset)
    # predictions = post_process(raw_predictions, dataset)
    if return_data:
        return dataset.add_column("prediction", predictions)
    return predictions
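
Taken together, the entry points in this file support an end-to-end flow: load_dataset (or create_dataset) builds the processed dataset, infer/produce feed it to an inference engine, and evaluate scores the predictions. A minimal usage sketch, assuming the catalog assets named in the load_dataset docstring above are available locally and using stubbed predictions in place of a real infer(...) call with an InferenceEngine (the import path follows this file's location; the package may also re-export these functions at its top level):

    from unitxt.api import evaluate, load_dataset

    # Card and template must be present in the local catalog
    # (query taken from the load_dataset docstring above).
    dataset = load_dataset(
        dataset_query="card=cards.wnli,template=templates.classification.multi_class.relation.default",
        split="test",
        use_cache=True,  # reuse the prepared config keyed by the recipe signature
    )

    # Stubbed predictions for illustration; a real run would obtain them via
    # infer(...) with an InferenceEngine or from an external system under test.
    predictions = ["entailment" for _ in dataset]

    # Metadata recorded by load_dataset() (unitxt_version, creation_time, recipe kwargs)
    # is attached to the evaluation results alongside the scores.
    results = evaluate(predictions=predictions, dataset=dataset)
    print(results)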