• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 16704320175

03 Aug 2025 11:05AM UTC coverage: 80.829% (-0.4%) from 81.213%
16704320175

Pull #1845

github

web-flow
Merge 59428aa88 into 5372aa6df
Pull Request #1845: Allow using python functions instead of operators (e.g in pre-processing pipeline)

1576 of 1970 branches covered (80.0%)

Branch coverage included in aggregate %.

10685 of 13199 relevant lines covered (80.95%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.84
src/unitxt/utils.py
1
import copy
1✔
2
import functools
1✔
3
import importlib.util
1✔
4
import inspect
1✔
5
import json
1✔
6
import os
1✔
7
import random
1✔
8
import re
1✔
9
import time
1✔
10
import types
1✔
11
from collections import OrderedDict
1✔
12
from contextvars import ContextVar
1✔
13
from functools import wraps
1✔
14
from typing import Any, Dict, Optional
1✔
15
from urllib.error import HTTPError as UrllibHTTPError
1✔
16

17
from requests.exceptions import ConnectionError, HTTPError
1✔
18
from requests.exceptions import Timeout as TimeoutError
1✔
19

20
from .logging_utils import get_logger
1✔
21
from .settings_utils import get_settings
1✔
22
from .text_utils import is_made_of_sub_strings
1✔
23

24
# Module-level singletons shared by all helpers in this file.
logger = get_logger()
settings = get_settings()
26

27

28
def retry_connection_with_exponential_backoff(
    max_retries=None,
    retry_exceptions=(
        ConnectionError,
        TimeoutError,
        HTTPError,
        FileNotFoundError,
        UrllibHTTPError,
    ),
    backoff_factor=1,
):
    """Decorator that implements retry with exponential backoff for network operations.

    Also handles errors that were triggered by the specified retry exceptions,
    whether they're direct causes or part of the exception context.

    Args:
        max_retries: Maximum number of retry attempts (falls back to settings if None)
        retry_exceptions: Tuple of exceptions that should trigger a retry
        backoff_factor: Base delay factor in seconds for backoff calculation

    Returns:
        The decorated function with retry logic

    Raises:
        ValueError: If the effective number of retries is not a positive integer.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Get max_retries from settings if not provided
            retries = (
                max_retries
                if max_retries is not None
                else settings.max_connection_retries
            )

            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Check if this exception or any of its causes match the retry exceptions
                    should_retry = False
                    current_exc = e

                    # Check the exception chain for both __cause__ (explicit) and __context__ (implicit)
                    visited_exceptions = (
                        set()
                    )  # To prevent infinite loops in rare cyclic exception references

                    while (
                        current_exc is not None
                        and id(current_exc) not in visited_exceptions
                    ):
                        visited_exceptions.add(id(current_exc))

                        if isinstance(current_exc, retry_exceptions):
                            should_retry = True
                            break

                        # First check __cause__ (from "raise X from Y")
                        if current_exc.__cause__ is not None:
                            current_exc = current_exc.__cause__
                        # Then check __context__ (from "try: ... except: raise X")
                        elif current_exc.__context__ is not None:
                            current_exc = current_exc.__context__
                        else:
                            # No more causes in the chain
                            break

                    if not should_retry:
                        # Not a retry exception or caused by a retry exception, so re-raise
                        raise

                    if attempt >= retries - 1:  # Last attempt
                        raise  # Re-raise the last exception

                    # Calculate exponential backoff with jitter
                    wait_time = backoff_factor * (2**attempt) + random.uniform(0, 1)
                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt+1}/{retries}). "
                        f"Retrying in {wait_time:.2f}s. Error: {e!s}"
                    )
                    time.sleep(wait_time)

            # Only reachable when retries <= 0: the for-loop body never executed,
            # so the function was never even attempted. Fail loudly and precisely
            # instead of with an opaque message.
            raise ValueError(
                f"Retry loop exited without calling the function; "
                f"max_retries must be a positive integer but was {retries}"
            ) from None

        return wrapper

    return decorator
116

117

118
class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first instantiation constructs and memoizes the object; every later
    call returns that same object, with any new arguments ignored.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # EAFP: the common case after warm-up is a cache hit.
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
125

126

127
class LRUCache:
    """A least-recently-used cache whose storage is context-local.

    The backing OrderedDict lives in a ContextVar, so each execution context
    (e.g. each asyncio task) sees its own independent cache contents.

    Args:
        max_size: Maximum number of entries to keep; ``None`` means unbounded.
    """

    def __init__(self, max_size: Optional[int] = 10):
        self._max_size = max_size
        self._context_cache = ContextVar("context_lru_cache", default=None)

    def _get_cache(self) -> OrderedDict:
        # Lazily create the per-context OrderedDict on first use.
        cache = self._context_cache.get()
        if cache is None:
            cache = OrderedDict()
            self._context_cache.set(cache)
        return cache

    def __setitem__(self, key, value):
        """Insert/overwrite `key`, marking it most-recently-used; evict LRU entries past max_size."""
        cache = self._get_cache()
        cache[key] = value
        # move_to_end replaces the original pop-and-reinsert dance: it marks
        # the key most-recently-used in one O(1) OrderedDict operation.
        cache.move_to_end(key)
        if self._max_size is not None:
            while len(cache) > self._max_size:
                cache.popitem(last=False)  # evict least-recently-used

    def __getitem__(self, key):
        """Return the value for `key`, refreshing its recency; raise KeyError if absent."""
        cache = self._get_cache()
        if key not in cache:
            raise KeyError(f"{key} not found in cache")
        cache.move_to_end(key)
        return cache[key]

    def get(self, key, default=None):
        """Like __getitem__, but return `default` instead of raising when absent."""
        cache = self._get_cache()
        if key not in cache:
            return default
        cache.move_to_end(key)
        return cache[key]

    def clear(self):
        """Clear all items from the cache."""
        self._get_cache().clear()

    def __contains__(self, key):
        # Membership does NOT refresh recency, matching the original contract.
        return key in self._get_cache()

    def __len__(self):
        return len(self._get_cache())

    def __repr__(self):
        return f"LRUCache(max_size={self._max_size}, items={list(self._get_cache().items())})"
177

178

179
def lru_cache_decorator(max_size=128):
    """Decorator factory: memoize a function in a context-local LRUCache.

    Positional args plus sorted keyword items form the cache key, so every
    argument must be hashable. The wrapper exposes ``cache_clear()`` to reset
    the cache for the current context.
    """

    def decorator(func):
        cache = LRUCache(max_size=max_size)

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = args
            if kwargs:
                cache_key = cache_key + tuple(sorted(kwargs.items()))
            if cache_key not in cache:
                cache[cache_key] = func(*args, **kwargs)
            return cache[cache_key]

        wrapper.cache_clear = cache.clear
        return wrapper

    return decorator
198

199

200
# NOTE: max_size=None makes this cache unbounded — it grows with every
# distinct artifact_path loaded in the current context.
@lru_cache_decorator(max_size=None)
def artifacts_json_cache(artifact_path):
    """Load the JSON artifact at `artifact_path`, caching the result by path."""
    return load_json(artifact_path)
203

204

205
def flatten_dict(
    d: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
    """Flatten a nested dict into a single level, joining keys with `sep`.

    Example: {"a": {"b": 1}, "c": 2} -> {"a_b": 1, "c": 2}
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        compound_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, compound_key, sep=sep))
        else:
            flat[compound_key] = value
    return flat
217

218

219
def load_json(path):
    """Load a JSON file, reviving serialized functions via `decode_function`.

    Raises:
        RuntimeError: chaining the original JSONDecodeError and including the
            offending file content, when the file is not valid JSON.
    """
    with open(path) as f:
        try:
            return json.load(f, object_hook=decode_function)
        except json.decoder.JSONDecodeError as e:
            # Re-read the file for the error message. Use read() here:
            # "\n".join(f.readlines()) would double every newline, because
            # readlines() keeps the trailing "\n" on each line.
            with open(path) as f:
                file_content = f.read()
            raise RuntimeError(
                f"Failed to decode json file at '{path}' with file content:\n{file_content}"
            ) from e
229

230

231
def save_to_file(path, data):
    """Write the string `data` to `path`, ensuring a trailing newline."""
    with open(path, "w") as out:
        out.write(data + "\n")
235

236

237
def encode_function(obj):
    """JSON `default` hook that serializes plain (module-level) functions.

    A function is encoded as {"__function__": name, "source": source-text};
    bound methods and any other non-serializable object raise TypeError, as
    the json module expects from a `default` hook.
    """
    # Guard clause: bound methods are rejected explicitly with a clearer
    # message. MethodType and FunctionType are disjoint, so order is safe.
    if isinstance(obj, types.MethodType):
        raise TypeError(
            f"Method {obj.__func__.__name__} of class {obj.__self__.__class__.__name__} is not JSON serializable"
        )
    if isinstance(obj, types.FunctionType):
        try:
            return {"__function__": obj.__name__, "source": get_function_source(obj)}
        except Exception as e:
            raise TypeError(f"Failed to serialize function {obj.__name__}") from e
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
249

250

251
def json_dump(data, sort_keys=False):
    """Serialize `data` to a 4-space-indented JSON string.

    Plain functions inside `data` are serialized via `encode_function`;
    non-ASCII characters are emitted as-is (ensure_ascii=False).
    """
    return json.dumps(
        data, indent=4, default=encode_function, ensure_ascii=False, sort_keys=sort_keys
    )
255

256

257
def get_function_source(func):
    """Return the source text of `func`.

    Prefers the source cached on `__exec_source__` (set by `decode_function`
    for functions revived from JSON, where inspect cannot find a file);
    otherwise falls back to `inspect.getsource`.
    """
    try:
        return func.__exec_source__
    except AttributeError:
        return inspect.getsource(func)
261

262

263
def decode_function(obj):
    """json `object_hook` that revives functions serialized by `encode_function`.

    Dicts carrying both "__function__" and "source" keys are exec'd and the
    named callable is returned with its source cached on `__exec_source__`;
    all other dicts pass through unchanged.

    SECURITY NOTE: this exec()s the embedded source string — only use it on
    JSON produced by this codebase / trusted sources.

    Raises:
        ValueError: if the source fails to exec or does not define the
            expected callable.
    """
    # Detect our special function marker
    if "__function__" in obj and "source" in obj:
        namespace = {}
        func_name = obj["__function__"]
        try:
            exec(obj["source"], namespace)
            func = namespace.get(func_name)
            if not callable(func):
                raise ValueError(
                    f"Source did not define a callable named {func_name!r}"
                )
            # Cache the source only AFTER validation: setting an attribute on
            # a non-function (e.g. None when the name is missing) would raise
            # AttributeError and mask the clearer "not callable" error above.
            func.__exec_source__ = obj["source"]
            return func
        except Exception as e:
            raise ValueError(
                f"Failed to load function {func_name!r} from source:\n{obj['source']}"
            ) from e

    return obj
283

284

285
def json_load(s):
    """Parse the JSON string `s`, reviving serialized functions via `decode_function`."""
    return json.loads(s, object_hook=decode_function)
287

288

289
def is_package_installed(package_name):
    """Check if a package is installed.

    Parameters:
    - package_name (str): The name of the package to check.

    Returns:
    - bool: True if the package is installed, False otherwise.
    """
    # find_spec returns None when the package cannot be located; no import
    # of the package itself is performed.
    return importlib.util.find_spec(package_name) is not None
300

301

302
def is_module_available(module_name):
    """Check if a module is available in the current Python environment.

    Parameters:
    - module_name (str): The name of the module to check.

    Returns:
    - bool: True if the module is available, False otherwise.

    Note: unlike `is_package_installed`, this actually imports the module,
    so module-level side effects will run.
    """
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True
316

317

318
def remove_numerics_and_quoted_texts(input_str):
    """Strip numeric literals and quoted string literals from `input_str`.

    Used by `safe_eval` so that only the remaining identifier/operator tokens
    need to be checked against an allow-list.
    """
    # Remove floats first to avoid leaving stray periods
    input_str = re.sub(r"\d+\.\d+", "", input_str)

    # Remove integers
    input_str = re.sub(r"\d+", "", input_str)

    # Remove strings in single quotes
    input_str = re.sub(r"'.*?'", "", input_str)

    # Remove triple-quoted strings BEFORE double-quoted ones: the lazy
    # '".*?"' pattern would otherwise consume the quote pairs of a
    # triple-quoted string first and leave fragments behind.
    input_str = re.sub(r'""".*?"""', "", input_str, flags=re.DOTALL)

    # Remove strings in double quotes
    return re.sub(r'".*?"', "", input_str)
333

334

335
def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
    """Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.

    Args:
        expression (str): The expression to evaluate.
        context (dict): A dictionary mapping variable names to their values, which
                        can be used in the expression.
        allowed_tokens (list): A list of strings representing allowed tokens (such as
                               operators, function names, etc.) that can be used in the expression.

    Returns:
        any: The result of evaluating the expression.

    Raises:
        ValueError: If the expression contains tokens not in the allowed list or context keys.

    Note:
        This function should be used carefully, as it employs `eval`, which can
        execute arbitrary code. The function attempts to mitigate security risks
        by restricting the available tokens and not exposing built-in functions.
    """
    allowed_sub_strings = list(context.keys()) + allowed_tokens
    # Numeric and quoted literals are always permitted, so strip them before
    # checking the remaining tokens against the allow-list.
    stripped_expression = remove_numerics_and_quoted_texts(expression)
    if not is_made_of_sub_strings(stripped_expression, allowed_sub_strings):
        raise ValueError(
            f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
        )
    # Builtins are hidden from the evaluation environment on purpose.
    return eval(expression, {"__builtins__": {}}, context)
364

365

366
def import_module_from_file(file_path):
    """Dynamically import a Python source file and return the resulting module.

    The module is named after the file's basename (without extension). Note
    that the module is NOT registered in sys.modules.
    """
    # Module name = file name without its extension.
    stem = os.path.splitext(os.path.basename(file_path))[0]
    spec = importlib.util.spec_from_file_location(stem, file_path)
    loaded_module = importlib.util.module_from_spec(spec)
    # Executing the module runs its top-level code.
    spec.loader.exec_module(loaded_module)
    return loaded_module
376

377

378
def deep_copy(obj):
    """Creates a deep copy of the given object.

    Thin wrapper over `copy.deepcopy`, kept for symmetry with `shallow_copy`
    and for use as the `internal_copy` argument of `recursive_copy`.

    Args:
        obj: The object to be deep copied.

    Returns:
        A deep copy of the original object.
    """
    return copy.deepcopy(obj)
388

389

390
def shallow_copy(obj):
    """Creates a shallow copy of the given object.

    Thin wrapper over `copy.copy`, kept for symmetry with `deep_copy`
    and for use as the `internal_copy` argument of `recursive_copy`.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A shallow copy of the original object.
    """
    return copy.copy(obj)
400

401

402
def recursive_copy(obj, internal_copy=None):
    """Recursively copies an object with a selective copy method.

    `dict`, `list`, `tuple`, and namedtuple containers are rebuilt element by
    element (preserving their concrete type); every other value is passed to
    `internal_copy` when given, or returned as-is when `internal_copy` is None.

    NOTE(review): rebuilding dict subclasses via `type(obj)(mapping)` assumes
    the subclass accepts a mapping as its first argument — e.g.
    `collections.defaultdict` would not; confirm callers only use plain
    dict-like types.

    Args:
        obj: The object to be copied.
        internal_copy (callable, optional): The copy function to use for non-container objects.
            If `None`, objects without a `copy` method are returned as is.

    Returns:
        The recursively copied object.
    """
    if isinstance(obj, dict):
        copied_items = {
            key: recursive_copy(value, internal_copy) for key, value in obj.items()
        }
        return type(obj)(copied_items)

    if isinstance(obj, tuple):
        copied = (recursive_copy(element, internal_copy) for element in obj)
        # Namedtuples (detected by _fields) take positional args, not an iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*copied)
        return type(obj)(copied)

    if isinstance(obj, list):
        return type(obj)(recursive_copy(element, internal_copy) for element in obj)

    # Leaf value: delegate to internal_copy, or hand back unchanged.
    return obj if internal_copy is None else internal_copy(obj)
435

436

437
def recursive_deep_copy(obj):
    """Performs a recursive deep copy of the given object.

    This function uses `deep_copy` as the internal copy method for non-container objects.

    Args:
        obj: The object to be deep copied.

    Returns:
        A recursively deep-copied version of the original object.
    """
    return recursive_copy(obj, deep_copy)
449

450

451
def recursive_shallow_copy(obj):
    """Performs a recursive shallow copy of the given object.

    This function uses `shallow_copy` as the internal copy method for non-container objects.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A recursively shallow-copied version of the original object.
    """
    return recursive_copy(obj, shallow_copy)
463

464

465
class LongString(str):
    """A str subclass whose repr can be overridden with a short placeholder.

    Useful for keeping very long strings out of logs/reprs: pass `repr_str`
    to control what `repr()` shows while the full value is preserved for all
    other string operations.
    """

    def __new__(cls, value, *, repr_str=None):
        instance = str.__new__(cls, value)
        # None means "fall back to the normal str repr".
        instance._repr_str = repr_str
        return instance

    def __repr__(self):
        if self._repr_str is None:
            return super().__repr__()
        return self._repr_str
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc