• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 16470706949

23 Jul 2025 12:31PM UTC coverage: 81.122% (-0.1%) from 81.222%
16470706949

Pull #1861

github

web-flow
Merge c48d10af5 into 83063f920
Pull Request #1861: Fix compatibility with datasets 4.0

1585 of 1965 branches covered (80.66%)

Branch coverage included in aggregate %.

10735 of 13222 relevant lines covered (81.19%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.4
src/unitxt/utils.py
1
import copy
1✔
2
import functools
1✔
3
import importlib.util
1✔
4
import json
1✔
5
import os
1✔
6
import random
1✔
7
import re
1✔
8
import time
1✔
9
from collections import OrderedDict
1✔
10
from contextvars import ContextVar
1✔
11
from functools import wraps
1✔
12
from importlib.metadata import PackageNotFoundError
1✔
13
from importlib.metadata import version as get_installed_version
1✔
14
from typing import Any, Dict, Optional
1✔
15
from urllib.error import HTTPError as UrllibHTTPError
1✔
16

17
from packaging.requirements import Requirement
1✔
18
from packaging.version import Version
1✔
19
from requests.exceptions import ConnectionError, HTTPError
1✔
20
from requests.exceptions import Timeout as TimeoutError
1✔
21

22
from .logging_utils import get_logger
1✔
23
from .settings_utils import get_settings
1✔
24
from .text_utils import is_made_of_sub_strings
1✔
25

26
# Module-level singletons shared by all helpers below: the unitxt logger
# and the global settings object (source of max_connection_retries).
logger = get_logger()
settings = get_settings()
28

29

30
def retry_connection_with_exponential_backoff(
    max_retries=None,
    retry_exceptions=(
        ConnectionError,
        TimeoutError,
        HTTPError,
        FileNotFoundError,
        UrllibHTTPError,
    ),
    backoff_factor=1,
):
    """Decorator that implements retry with exponential backoff for network operations.

    Also handles errors that were triggered by the specified retry exceptions,
    whether they're direct causes or part of the exception context.

    Args:
        max_retries: Maximum number of retry attempts (falls back to settings if None)
        retry_exceptions: Tuple of exceptions that should trigger a retry
        backoff_factor: Base delay factor in seconds for backoff calculation

    Returns:
        The decorated function with retry logic

    Raises:
        ValueError: If the resolved retry count is not a positive integer, so the
            wrapped function could never be attempted.
    """

    def _is_retryable(exc, retryable_types):
        """Return True if exc, or anything in its cause/context chain, is a retryable type."""
        visited = set()  # prevent infinite loops in rare cyclic exception references
        current = exc
        while current is not None and id(current) not in visited:
            visited.add(id(current))
            if isinstance(current, retryable_types):
                return True
            # Prefer the explicit cause ("raise X from Y") over the implicit
            # context ("try: ... except: raise X").
            current = (
                current.__cause__
                if current.__cause__ is not None
                else current.__context__
            )
        return False

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Fall back to the globally configured retry count when the
            # decorator was not given an explicit one.
            retries = (
                max_retries
                if max_retries is not None
                else settings.max_connection_retries
            )

            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if not _is_retryable(e, retry_exceptions):
                        # Not a retry exception or caused by one: propagate immediately.
                        raise

                    if attempt >= retries - 1:  # Last attempt
                        raise  # Re-raise the last exception

                    # Exponential backoff with jitter to avoid synchronized retries.
                    wait_time = backoff_factor * (2**attempt) + random.uniform(0, 1)
                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt+1}/{retries}). "
                        f"Retrying in {wait_time:.2f}s. Error: {e!s}"
                    )
                    time.sleep(wait_time)

            # Only reachable when the loop body never ran, i.e. retries <= 0.
            raise ValueError(
                f"Retry loop for {func.__name__} never executed; max_retries "
                f"resolved to {retries}, which must be a positive integer."
            ) from None

        return wrapper

    return decorator

119

120
class Singleton(type):
    """Metaclass that turns every class using it into a process-wide singleton.

    The first call to the class constructs and stores the instance; every
    later call returns that stored instance, ignoring its arguments.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance

128

129
class LRUCache:
    """A least-recently-used cache whose storage is context-local.

    Entries live in an ``OrderedDict`` held inside a ``ContextVar``, so each
    execution context (e.g. asyncio task) sees its own cache. Reads refresh
    an entry's recency; inserts beyond ``max_size`` evict the oldest entry.
    A ``max_size`` of ``None`` means the cache is unbounded.
    """

    def __init__(self, max_size: Optional[int] = 10):
        self._max_size = max_size
        self._context_cache = ContextVar("context_lru_cache", default=None)

    def _get_cache(self):
        # Lazily create the per-context store on first access.
        store = self._context_cache.get()
        if store is None:
            store = OrderedDict()
            self._context_cache.set(store)
        return store

    def __setitem__(self, key, value):
        store = self._get_cache()
        if key in store:
            # Re-inserting moves the key to the most-recent position.
            store.pop(key)
        store[key] = value
        if self._max_size is None:
            return
        while len(store) > self._max_size:
            # Oldest entries sit at the front of the OrderedDict.
            store.popitem(last=False)

    def __getitem__(self, key):
        store = self._get_cache()
        if key not in store:
            raise KeyError(f"{key} not found in cache")
        store.move_to_end(key)  # mark as most recently used
        return store[key]

    def get(self, key, default=None):
        store = self._get_cache()
        if key not in store:
            return default
        store.move_to_end(key)  # mark as most recently used
        return store[key]

    def clear(self):
        """Clear all items from the cache."""
        self._get_cache().clear()

    def __contains__(self, key):
        return key in self._get_cache()

    def __len__(self):
        return len(self._get_cache())

    def __repr__(self):
        return f"LRUCache(max_size={self._max_size}, items={list(self._get_cache().items())})"
179

180

181
def lru_cache_decorator(max_size=128):
    """Decorator factory memoizing a function with a context-local LRU cache.

    Positional arguments plus sorted keyword items form the cache key, so
    arguments must be hashable. The wrapper gains a ``cache_clear`` method
    that empties the cache for the current context.
    """

    def decorator(func):
        memo = LRUCache(max_size=max_size)

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = args + tuple(sorted(kwargs.items())) if kwargs else args
            if cache_key in memo:
                return memo[cache_key]
            result = func(*args, **kwargs)
            memo[cache_key] = result
            return result

        wrapper.cache_clear = memo.clear
        return wrapper

    return decorator
200

201

202
@lru_cache_decorator(max_size=None)
def artifacts_json_cache(artifact_path):
    """Load the JSON artifact at ``artifact_path``, memoizing the result per path.

    The cache is unbounded (``max_size=None``), so each distinct path is read
    from disk at most once per execution context.
    """
    return load_json(artifact_path)
205

206

207
def flatten_dict(
    d: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
    """Flatten a nested dictionary into a single-level one.

    Keys along each nesting path are joined with ``sep``; non-dict values are
    kept as-is. ``parent_key`` is the prefix accumulated so far (used by the
    recursive calls).
    """
    flattened: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = parent_key + sep + key if parent_key else key
        if isinstance(value, dict):
            # Recurse, prefixing the nested keys with the path so far.
            flattened.update(flatten_dict(value, full_key, sep=sep))
        else:
            flattened[full_key] = value
    return flattened
219

220

221
def load_json(path):
    """Load and return the JSON document stored at ``path``.

    Raises:
        RuntimeError: If the file is not valid JSON; the error message
            includes the raw file content to ease debugging.
    """
    with open(path) as f:
        try:
            return json.load(f)
        except json.decoder.JSONDecodeError as e:
            # Re-read the raw text for the error message. Note: the previous
            # implementation used "\n".join(f.readlines()), which doubled
            # every newline since readlines() keeps them; read() preserves
            # the content exactly. Rewinding also avoids reopening the file.
            f.seek(0)
            file_content = f.read()
            raise RuntimeError(
                f"Failed to decode json file at '{path}' with file content:\n{file_content}"
            ) from e
231

232

233
def save_to_file(path, data):
    """Write ``data`` (a string) to ``path``, terminated by a newline.

    Any existing file at ``path`` is overwritten.
    """
    with open(path, "w") as out:
        out.write(data)
        out.write("\n")
237

238

239
def json_dump(data):
    """Serialize ``data`` to a 4-space-indented JSON string, keeping non-ASCII characters literal."""
    return json.dumps(data, ensure_ascii=False, indent=4)
241

242

243
def is_package_installed(package_name):
    """Check if a package is installed.

    Parameters:
    - package_name (str): The name of the package to check.

    Returns:
    - bool: True if the package is installed, False otherwise.
    """
    # find_spec returns None when no importable distribution is found.
    return importlib.util.find_spec(package_name) is not None
254

255

256
def is_module_available(module_name):
    """Check if a module is available in the current Python environment.

    Parameters:
    - module_name (str): The name of the module to check.

    Returns:
    - bool: True if the module is available, False otherwise.
    """
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True
270

271

272
def remove_numerics_and_quoted_texts(input_str):
    """Strip numeric literals and quoted string literals from ``input_str``.

    Removal order matters:

    1. Triple-quoted strings go first. Previously this pass ran last, after
       the ``"..."`` pass had already consumed the quote characters pair by
       pair, so multiline triple-quoted content was never actually removed.
    2. Floats before integers, so ``3.14`` does not leave a stray period.
    3. Single- then double-quoted (single-line) strings.

    Returns the input with all such spans deleted.
    """
    # Triple-quoted strings first (DOTALL lets them span lines); see note above.
    input_str = re.sub(r'""".*?"""', "", input_str, flags=re.DOTALL)

    # Remove floats before integers to avoid leaving stray periods
    input_str = re.sub(r"\d+\.\d+", "", input_str)

    # Remove integers
    input_str = re.sub(r"\d+", "", input_str)

    # Remove strings in single quotes
    input_str = re.sub(r"'.*?'", "", input_str)

    # Remove strings in double quotes
    return re.sub(r'".*?"', "", input_str)
287

288

289
def safe_eval(expression: str, context: dict, allowed_tokens: list) -> any:
    """Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.

    Args:
        expression (str): The expression to evaluate.
        context (dict): A dictionary mapping variable names to their values, which
                        can be used in the expression.
        allowed_tokens (list): A list of strings representing allowed tokens (such as
                               operators, function names, etc.) that can be used in the expression.

    Returns:
        any: The result of evaluating the expression.

    Raises:
        ValueError: If the expression contains tokens not in the allowed list or context keys.

    Note:
        This function should be used carefully, as it employs `eval`, which can
        execute arbitrary code. The function attempts to mitigate security risks
        by restricting the available tokens and not exposing built-in functions.
    """
    # The allowlist is the context's variable names plus the caller-provided tokens.
    allowed_sub_strings = list(context.keys()) + allowed_tokens
    # Numeric and quoted literals are stripped first so that only the remaining
    # identifiers/operators are checked against the allowlist.
    if is_made_of_sub_strings(
        remove_numerics_and_quoted_texts(expression), allowed_sub_strings
    ):
        # NOTE(review): emptying __builtins__ narrows eval's reach but is not a
        # complete sandbox — expressions should still come from trusted sources.
        return eval(expression, {"__builtins__": {}}, context)
    raise ValueError(
        f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
    )
318

319

320
def import_module_from_file(file_path):
    """Dynamically load and return a Python module from an arbitrary file path.

    The module's name is taken from the file's base name without extension.
    The module is executed on load but not registered in ``sys.modules``.
    """
    # Module name = file name without its extension.
    stem = os.path.splitext(os.path.basename(file_path))[0]
    # Build a spec for the file, materialize a module from it, then execute it.
    module_spec = importlib.util.spec_from_file_location(stem, file_path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
330

331

332
def deep_copy(obj):
    """Return a deep copy of ``obj``.

    All nested objects are duplicated, so mutating the copy never affects
    the original.

    Args:
        obj: The object to be deep copied.

    Returns:
        A deep copy of the original object.
    """
    return copy.deepcopy(obj)
342

343

344
def shallow_copy(obj):
    """Return a shallow copy of ``obj``.

    Only the top-level object is duplicated; nested objects are shared with
    the original.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A shallow copy of the original object.
    """
    return copy.copy(obj)
354

355

356
def recursive_copy(obj, internal_copy=None):
    """Recursively copies an object with a selective copy method.

    Containers (``dict``, ``list``, ``tuple``, and named tuples) are rebuilt
    with their contents copied recursively; every non-container leaf is passed
    through ``internal_copy`` when one is given, or returned unchanged
    otherwise.

    Args:
        obj: The object to be copied.
        internal_copy (callable, optional): The copy function applied to
            non-container leaves. If ``None``, leaves are returned as is.

    Returns:
        The recursively copied object.
    """
    if isinstance(obj, dict):
        # Rebuild with the same mapping type, copying each value.
        return type(obj)(
            {key: recursive_copy(value, internal_copy) for key, value in obj.items()}
        )

    if isinstance(obj, tuple):
        copied_items = (recursive_copy(item, internal_copy) for item in obj)
        if hasattr(obj, "_fields"):
            # Named tuples take their fields as positional arguments.
            return type(obj)(*copied_items)
        return type(obj)(copied_items)

    if isinstance(obj, list):
        return type(obj)(recursive_copy(item, internal_copy) for item in obj)

    # Non-container leaf: apply the selective copy method when provided.
    return obj if internal_copy is None else internal_copy(obj)
389

390

391
def recursive_deep_copy(obj):
    """Performs a recursive deep copy of the given object.

    This function uses `deep_copy` as the internal copy method for non-container objects.

    Args:
        obj: The object to be deep copied.

    Returns:
        A recursively deep-copied version of the original object.
    """
    # Containers (dict/list/tuple) are rebuilt by recursive_copy; every
    # non-container leaf goes through copy.deepcopy via deep_copy.
    return recursive_copy(obj, deep_copy)
403

404

405
def recursive_shallow_copy(obj):
    """Performs a recursive shallow copy of the given object.

    This function uses `shallow_copy` as the internal copy method for non-container objects.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A recursively shallow-copied version of the original object.
    """
    # Containers (dict/list/tuple) are rebuilt by recursive_copy; every
    # non-container leaf goes through copy.copy via shallow_copy.
    return recursive_copy(obj, shallow_copy)
417

418

419
class LongString(str):
    """A ``str`` subclass whose ``repr`` can be replaced with a custom string.

    Useful for keeping long values readable in logs and debuggers: pass
    ``repr_str`` to show a short placeholder instead of the full content.
    The string value itself behaves exactly like a normal ``str``.
    """

    def __new__(cls, value, *, repr_str=None):
        instance = super().__new__(cls, value)
        instance._repr_str = repr_str  # custom repr override, or None
        return instance

    def __repr__(self):
        if self._repr_str is None:
            return super().__repr__()
        return self._repr_str
429

430

431
class DistributionNotFound(Exception):
    """Raised when a required distribution is not installed."""

    def __init__(self, requirement):
        # Keep the offending requirement string for programmatic access.
        self.requirement = requirement
        message = f"Distribution not found for requirement: {requirement}"
        super().__init__(message)
435

436

437
class VersionConflict(Exception):
    """Raised when an installed distribution does not satisfy a requirement."""

    def __init__(self, dist, req):
        # Keep both sides of the conflict for programmatic inspection.
        self.dist = dist  # Distribution object, just emulate enough for your needs
        self.req = req
        message = f"Version conflict: {dist} does not satisfy {req}"
        super().__init__(message)
442

443

444
class DistStub:
    """Minimal stub mimicking ``pkg_resources.Distribution``.

    Carries only the attributes this module needs: ``project_name`` and
    ``version``.
    """

    def __init__(self, project_name, version):
        self.project_name = project_name
        self.version = version

    def __str__(self):
        # pkg_resources.Distribution renders as "<name> <version>". Matching
        # that keeps VersionConflict's f-string message informative instead of
        # showing the default "<... object at 0x...>" repr.
        return f"{self.project_name} {self.version}"
449

450

451
def require(requirements):
    """Minimal drop-in replacement for pkg_resources.require.

    Accepts a single requirement string or a list of them.
    Raises DistributionNotFound or VersionConflict.
    Returns nothing (side-effect only).
    """
    req_strings = [requirements] if isinstance(requirements, str) else requirements
    for req_string in req_strings:
        parsed = Requirement(req_string)
        # Skip requirements whose environment marker does not apply here.
        if parsed.marker and not parsed.marker.evaluate():
            continue
        try:
            installed = get_installed_version(parsed.name)
        except PackageNotFoundError as err:
            raise DistributionNotFound(req_string) from err
        # Pre-releases are accepted, matching pkg_resources' lenient behavior.
        if parsed.specifier and not parsed.specifier.contains(
            Version(installed), prereleases=True
        ):
            raise VersionConflict(DistStub(parsed.name, installed), req_string)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc