• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 16806334066

07 Aug 2025 01:46PM UTC coverage: 80.704% (-0.4%) from 81.079%
16806334066

Pull #1845

github

web-flow
Merge 42b8ebc45 into 8381fb80e
Pull Request #1845: Allow using python functions instead of operators (e.g in pre-processing pipeline)

1610 of 2013 branches covered (79.98%)

Branch coverage included in aggregate %.

10862 of 13441 relevant lines covered (80.81%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.44
src/unitxt/utils.py
1
import copy
1✔
2
import functools
1✔
3
import importlib.util
1✔
4
import inspect
1✔
5
import json
1✔
6
import os
1✔
7
import random
1✔
8
import re
1✔
9
import time
1✔
10
import types
1✔
11
from collections import OrderedDict
1✔
12
from contextvars import ContextVar
1✔
13
from functools import wraps
1✔
14
from importlib.metadata import PackageNotFoundError
1✔
15
from importlib.metadata import version as get_installed_version
1✔
16
from typing import Any, Dict, Optional
1✔
17
from urllib.error import HTTPError as UrllibHTTPError
1✔
18

19
from packaging.requirements import Requirement
1✔
20
from packaging.version import Version
1✔
21
from requests.exceptions import ConnectionError, HTTPError
1✔
22
from requests.exceptions import Timeout as TimeoutError
1✔
23

24
from .logging_utils import get_logger
1✔
25
from .settings_utils import get_settings
1✔
26
from .text_utils import is_made_of_sub_strings
1✔
27

28
logger = get_logger()
29
settings = get_settings()
1✔
30

31

32
def retry_connection_with_exponential_backoff(
    max_retries=None,
    retry_exceptions=(
        ConnectionError,
        TimeoutError,
        HTTPError,
        FileNotFoundError,
        UrllibHTTPError,
    ),
    backoff_factor=1,
):
    """Decorator that implements retry with exponential backoff for network operations.

    Also handles errors that were triggered by the specified retry exceptions,
    whether they're direct causes or part of the exception context.

    Args:
        max_retries: Maximum number of retry attempts (falls back to settings if None)
        retry_exceptions: Tuple of exceptions that should trigger a retry
        backoff_factor: Base delay factor in seconds for backoff calculation

    Returns:
        The decorated function with retry logic
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Fall back to the globally configured retry count when not provided.
            retries = (
                max_retries
                if max_retries is not None
                else settings.max_connection_retries
            )

            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Walk the exception chain to see whether this error was
                    # (directly or transitively) triggered by a retryable one.
                    should_retry = False
                    current_exc = e

                    # Check the exception chain for both __cause__ (explicit) and __context__ (implicit)
                    visited_exceptions = (
                        set()
                    )  # To prevent infinite loops in rare cyclic exception references

                    while (
                        current_exc is not None
                        and id(current_exc) not in visited_exceptions
                    ):
                        visited_exceptions.add(id(current_exc))

                        if isinstance(current_exc, retry_exceptions):
                            should_retry = True
                            break

                        # First check __cause__ (from "raise X from Y")
                        if current_exc.__cause__ is not None:
                            current_exc = current_exc.__cause__
                        # Then check __context__ (from "try: ... except: raise X")
                        elif current_exc.__context__ is not None:
                            current_exc = current_exc.__context__
                        else:
                            # No more causes in the chain
                            break

                    if not should_retry:
                        # Not a retry exception or caused by a retry exception, so re-raise
                        raise

                    if attempt >= retries - 1:  # Last attempt
                        raise  # Re-raise the last exception

                    # Exponential backoff with jitter to avoid thundering-herd retries.
                    wait_time = backoff_factor * (2**attempt) + random.uniform(0, 1)
                    logger.warning(
                        f"{func.__name__} failed (attempt {attempt+1}/{retries}). "
                        f"Retrying in {wait_time:.2f}s. Error: {e!s}"
                    )
                    time.sleep(wait_time)

            # Only reachable when retries < 1: the loop body never executed, so
            # the wrapped function was never even attempted.
            raise ValueError(
                f"Cannot call {func.__name__}: the configured number of "
                f"connection retries ({retries}) must be a positive integer."
            ) from None

        return wrapper

    return decorator
120

121

122
class Singleton(type):
    """Metaclass guaranteeing at most one instance per class.

    The first instantiation builds and caches the object; every later
    call returns the cached instance unchanged.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
129

130

131
class LRUCache:
    """A least-recently-used cache whose storage lives in a ContextVar.

    Each execution context (thread / async task) sees its own cache
    contents. When ``max_size`` is not None, inserting past capacity
    evicts the least recently used entries first.
    """

    def __init__(self, max_size: Optional[int] = 10):
        self._max_size = max_size
        self._context_cache = ContextVar("context_lru_cache", default=None)

    def _get_cache(self):
        # Lazily create the per-context OrderedDict on first use.
        store = self._context_cache.get()
        if store is None:
            store = OrderedDict()
            self._context_cache.set(store)
        return store

    def __setitem__(self, key, value):
        store = self._get_cache()
        if key in store:
            del store[key]
        store[key] = value
        if self._max_size is not None:
            # Evict oldest entries until we are back within capacity.
            while len(store) > self._max_size:
                store.popitem(last=False)

    def __getitem__(self, key):
        store = self._get_cache()
        if key not in store:
            raise KeyError(f"{key} not found in cache")
        store.move_to_end(key)  # mark as most recently used
        return store[key]

    def get(self, key, default=None):
        store = self._get_cache()
        if key not in store:
            return default
        store.move_to_end(key)  # mark as most recently used
        return store[key]

    def clear(self):
        """Clear all items from the cache."""
        self._get_cache().clear()

    def __contains__(self, key):
        return key in self._get_cache()

    def __len__(self):
        return len(self._get_cache())

    def __repr__(self):
        return f"LRUCache(max_size={self._max_size}, items={list(self._get_cache().items())})"
181

182

183
def lru_cache_decorator(max_size=128):
    """Decorator factory: memoize a function with an LRU-evicting cache.

    Keys are built from the positional args plus sorted keyword items, so
    every argument must be hashable. The wrapped function exposes
    ``cache_clear()`` to empty its cache.
    """

    def decorator(func):
        cache = LRUCache(max_size=max_size)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Sort kwargs so the cache key is independent of keyword order.
            key = args if not kwargs else args + tuple(sorted(kwargs.items()))
            if key in cache:
                return cache[key]
            value = func(*args, **kwargs)
            cache[key] = value
            return value

        wrapper.cache_clear = cache.clear
        return wrapper

    return decorator
202

203

204
@lru_cache_decorator(max_size=None)
def artifacts_json_cache(artifact_path):
    """Load and parse the JSON artifact at *artifact_path*, caching the result.

    The cache is unbounded (``max_size=None``) and keyed by the path, so each
    artifact file is read and decoded at most once per execution context.
    """
    return load_json(artifact_path)
207

208

209
def flatten_dict(
    d: Dict[str, Any], parent_key: str = "", sep: str = "_"
) -> Dict[str, Any]:
    """Flatten a nested dict into a single level, joining keys with *sep*.

    E.g. ``{"a": {"b": 1}, "c": 2}`` becomes ``{"a_b": 1, "c": 2}``.
    """
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recurse, accumulating the nested leaves under the joined key.
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat
221

222

223
def load_json(path):
    """Read and parse the JSON file at *path*.

    Uses ``decode_function`` as the object hook so dicts produced by
    ``encode_function`` are revived back into callables.

    Raises:
        RuntimeError: if the file is not valid JSON; the file content is
            embedded in the message to ease debugging.
    """
    with open(path) as f:
        try:
            return json.load(f, object_hook=decode_function)
        except json.decoder.JSONDecodeError as e:
            # Re-open to show the content: the original handle was consumed.
            # f.read() keeps the text verbatim (the previous
            # "\n".join(f.readlines()) doubled every newline).
            with open(path) as f:
                file_content = f.read()
            raise RuntimeError(
                f"Failed to decode json file at '{path}' with file content:\n{file_content}"
            ) from e
233

234

235
def save_to_file(path, data):
    """Write string *data* to *path*, ensuring a trailing newline."""
    with open(path, "w") as out:
        out.write(data)
        out.write("\n")
239

240

241
def encode_function(obj):
    """JSON ``default`` hook that serializes plain (module-level) functions.

    A function becomes ``{"__function__": name, "source": code}``; bound
    methods and any other object raise TypeError.
    """
    # MethodType and FunctionType are disjoint, so the check order is free.
    if isinstance(obj, types.MethodType):
        raise TypeError(
            f"Method {obj.__func__.__name__} of class {obj.__self__.__class__.__name__} is not JSON serializable"
        )
    if not isinstance(obj, types.FunctionType):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    try:
        return {"__function__": obj.__name__, "source": get_function_source(obj)}
    except Exception as e:
        raise TypeError(f"Failed to serialize function {obj.__name__}") from e
253

254

255
def json_dump(data, sort_keys=False):
    """Serialize *data* to an indented JSON string.

    Plain functions inside *data* are serialized via ``encode_function``;
    non-ASCII characters are kept as-is (``ensure_ascii=False``).
    """
    return json.dumps(
        data, indent=4, default=encode_function, ensure_ascii=False, sort_keys=sort_keys
    )
259

260

261
def get_function_source(func):
    """Return the source text of *func*.

    Functions revived by ``decode_function`` carry their original text in
    ``__exec_source__``; when present that stored text is returned instead
    of consulting ``inspect``.
    """
    try:
        return func.__exec_source__
    except AttributeError:
        return inspect.getsource(func)
265

266

267
def decode_function(obj):
    """JSON object hook that revives functions serialized by ``encode_function``.

    Dicts carrying both ``"__function__"`` and ``"source"`` keys are executed
    to rebuild the named callable; any other dict is returned unchanged.

    NOTE(security): this exec()s the stored source — only load JSON from
    trusted locations.

    Raises:
        ValueError: if the source fails to execute, or does not define a
            callable with the expected name.
    """
    if "__function__" not in obj or "source" not in obj:
        return obj

    func_name = obj["__function__"]
    source = obj["source"]
    namespace = {}
    try:
        exec(source, namespace)
    except Exception as e:
        raise ValueError(
            f"Failed to load function {func_name!r} from source:\n{source}"
        ) from e

    func = namespace.get(func_name)
    # Validate BEFORE mutating: previously __exec_source__ was assigned first,
    # so a missing name crashed with an AttributeError on None that was then
    # masked by the generic "Failed to load" error.
    if not callable(func):
        raise ValueError(f"Source did not define a callable named {func_name!r}")

    # Keep the original text so get_function_source can round-trip it.
    func.__exec_source__ = source
    return func
287

288

289
def json_load(s):
    """Parse JSON string *s*, reviving serialized functions via ``decode_function``."""
    return json.loads(s, object_hook=decode_function)
291

292

293
def is_package_installed(package_name):
    """Check if a package is installed.

    Parameters:
    - package_name (str): The name of the package to check.

    Returns:
    - bool: True if the package is installed, False otherwise.
    """
    # find_spec returns None when no importable distribution provides the name.
    return importlib.util.find_spec(package_name) is not None
304

305

306
def is_module_available(module_name):
    """Check if a module is available in the current Python environment.

    Parameters:
    - module_name (str): The name of the module to check.

    Returns:
    - bool: True if the module is available, False otherwise.
    """
    # Availability is determined by actually attempting the import.
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True
320

321

322
def remove_numerics_and_quoted_texts(input_str):
    """Strip numeric literals and quoted strings from *input_str*.

    Used by ``safe_eval`` so token whitelisting only has to consider the
    identifiers and operators of an expression, not its literal values.
    """
    # Triple-quoted strings must be removed FIRST: the '".*?"' pattern below
    # would otherwise consume the '"""' delimiters pairwise and leave
    # multi-line contents behind.
    input_str = re.sub(r'""".*?"""', "", input_str, flags=re.DOTALL)

    # Remove floats before integers to avoid leaving stray periods
    input_str = re.sub(r"\d+\.\d+", "", input_str)

    # Remove integers
    input_str = re.sub(r"\d+", "", input_str)

    # Remove strings in single quotes
    input_str = re.sub(r"'.*?'", "", input_str)

    # Remove strings in double quotes
    return re.sub(r'".*?"', "", input_str)
337

338

339
def safe_eval(expression: str, context: dict, allowed_tokens: list) -> Any:
    """Evaluates a given expression in a restricted environment, allowing only specified tokens and context variables.

    Args:
        expression (str): The expression to evaluate.
        context (dict): A dictionary mapping variable names to their values, which
                        can be used in the expression.
        allowed_tokens (list): A list of strings representing allowed tokens (such as
                               operators, function names, etc.) that can be used in the expression.

    Returns:
        Any: The result of evaluating the expression.

    Raises:
        ValueError: If the expression contains tokens not in the allowed list or context keys.

    Note:
        This function should be used carefully, as it employs `eval`, which can
        execute arbitrary code. The function attempts to mitigate security risks
        by restricting the available tokens and not exposing built-in functions.
    """
    allowed_sub_strings = list(context.keys()) + allowed_tokens
    # Literals are stripped first so only identifiers/operators are vetted.
    if not is_made_of_sub_strings(
        remove_numerics_and_quoted_texts(expression), allowed_sub_strings
    ):
        raise ValueError(
            f"The expression '{expression}' can not be evaluated because it contains tokens outside the allowed list of {allowed_sub_strings}."
        )
    # NOTE(security): an empty __builtins__ is a mitigation, not a sandbox —
    # only call with trusted allowed_tokens/context.
    return eval(expression, {"__builtins__": {}}, context)
368

369

370
def import_module_from_file(file_path):
    """Dynamically import the Python file at *file_path* as a module.

    The module name is derived from the file's base name without its
    extension. Returns the loaded module object.
    """
    base = os.path.basename(file_path)
    module_name, _ = os.path.splitext(base)
    # Build a spec for the file, materialize a module from it, then execute
    # the file's code inside that module's namespace.
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
380

381

382
def deep_copy(obj):
    """Creates a deep copy of the given object.

    All nested containers and objects are copied recursively via
    ``copy.deepcopy``.

    Args:
        obj: The object to be deep copied.

    Returns:
        A deep copy of the original object.
    """
    return copy.deepcopy(obj)
392

393

394
def shallow_copy(obj):
    """Creates a shallow copy of the given object.

    Only the top-level object is copied (via ``copy.copy``); nested objects
    are shared with the original.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A shallow copy of the original object.
    """
    return copy.copy(obj)
404

405

406
def recursive_copy(obj, internal_copy=None):
    """Recursively copies an object with a selective copy method.

    Containers (``dict``, ``list``, ``tuple``, named tuples) are rebuilt
    with their contents copied recursively; every other value is passed to
    *internal_copy* when given, or returned as-is otherwise.

    Args:
        obj: The object to be copied.
        internal_copy (callable, optional): The copy function applied to
            non-container leaves. If ``None``, leaves are returned as is.

    Returns:
        The recursively copied object.
    """
    if isinstance(obj, dict):
        # Preserve the concrete mapping type of the original.
        return type(obj)(
            {k: recursive_copy(v, internal_copy) for k, v in obj.items()}
        )

    # Named tuples are detected by their _fields attribute and need
    # positional construction.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(*(recursive_copy(item, internal_copy) for item in obj))

    if isinstance(obj, (tuple, list)):
        return type(obj)(recursive_copy(item, internal_copy) for item in obj)

    # Leaf value: delegate to the supplied copier, or share the reference.
    return obj if internal_copy is None else internal_copy(obj)
439

440

441
def recursive_deep_copy(obj):
    """Performs a recursive deep copy of the given object.

    This function uses `deep_copy` as the internal copy method for non-container objects,
    while containers themselves are rebuilt by `recursive_copy`.

    Args:
        obj: The object to be deep copied.

    Returns:
        A recursively deep-copied version of the original object.
    """
    return recursive_copy(obj, deep_copy)
453

454

455
def recursive_shallow_copy(obj):
    """Performs a recursive shallow copy of the given object.

    This function uses `shallow_copy` as the internal copy method for non-container objects,
    while containers themselves are rebuilt by `recursive_copy`.

    Args:
        obj: The object to be shallow copied.

    Returns:
        A recursively shallow-copied version of the original object.
    """
    return recursive_copy(obj, shallow_copy)
467

468

469
class LongString(str):
    """A ``str`` subclass whose ``repr`` can be overridden with a placeholder.

    Useful for keeping huge strings out of logs and debug output: pass
    ``repr_str`` to control what ``repr()`` shows, while the value itself
    behaves as a normal string.
    """

    def __new__(cls, value, *, repr_str=None):
        obj = super().__new__(cls, value)
        obj._repr_str = repr_str
        return obj

    def __repr__(self):
        # getattr with default: instances that bypass __new__ (e.g. via
        # pickling or copy protocols) may lack the attribute entirely, and
        # the original unguarded access raised AttributeError for them.
        custom = getattr(self, "_repr_str", None)
        if custom is not None:
            return custom
        return super().__repr__()
479

480

481
class DistributionNotFound(Exception):
    """Raised when a required distribution is not installed at all."""

    def __init__(self, requirement):
        super().__init__(f"Distribution not found for requirement: {requirement}")
        # Keep the offending requirement string for programmatic access.
        self.requirement = requirement
485

486

487
class VersionConflict(Exception):
    """Raised when an installed distribution does not satisfy a requirement."""

    def __init__(self, dist, req):
        super().__init__(f"Version conflict: {dist} does not satisfy {req}")
        # Distribution-like object (only needs a useful str()) and the
        # requirement it violates.
        self.dist = dist
        self.req = req
492

493

494
class DistStub:
    """Minimal stand-in for pkg_resources.Distribution: name + version only."""

    def __init__(self, project_name, version):
        self.version = version
        self.project_name = project_name
499

500

501
def require(requirements):
    """Minimal drop-in replacement for pkg_resources.require.

    Accepts a single requirement string or a list of them.
    Raises DistributionNotFound or VersionConflict.
    Returns nothing (side-effect only).
    """
    specs = [requirements] if isinstance(requirements, str) else requirements
    for spec in specs:
        requirement = Requirement(spec)
        # Skip requirements whose environment marker does not apply here.
        if requirement.marker and not requirement.marker.evaluate():
            continue
        try:
            installed = get_installed_version(requirement.name)
        except PackageNotFoundError as err:
            raise DistributionNotFound(spec) from err
        # prereleases=True mirrors pkg_resources' lenient matching.
        if requirement.specifier and not requirement.specifier.contains(
            Version(installed), prereleases=True
        ):
            raise VersionConflict(DistStub(requirement.name, installed), spec)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc