• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 12373296669

17 Dec 2024 12:45PM UTC coverage: 90.485% (+0.005%) from 90.48%
12373296669

Pull #8651

github

web-flow
Merge 297c12748 into a5b57f4b1
Pull Request #8651: fix: fix deserialization issues in multi-threading environments

8112 of 8965 relevant lines covered (90.49%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.04
haystack/utils/type_serialization.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import importlib
1✔
6
import inspect
1✔
7
import sys
1✔
8
import typing
1✔
9
from threading import Lock, RLock
1✔
10
from types import ModuleType
1✔
11
from typing import Any, get_args, get_origin
1✔
12

13
from haystack import DeserializationError
1✔
14

15
# Serializes the importlib.import_module calls made by thread_safe_import.
# A reentrant lock is used because importing a module may run module-level code that itself
# calls thread_safe_import (e.g. by deserializing a type from a not-yet-imported module);
# with a plain Lock the same thread would block forever on the lock it already holds.
_import_lock = RLock()
16

17

18
def serialize_type(target: Any) -> str:
    """
    Convert a type, typing construct, or instance into its string form, qualified with its module name.

    - A `str` that contains a dot is treated as already serialized and is returned unchanged.
    - Classes and parameterized generics (e.g. `typing.List[int]`, `list[int]`) are serialized directly;
      any other object is serialized through its class, so `serialize_type(5)` gives `"int"`.
    - Arguments of parameterized generics outside `typing` are serialized recursively.
    - The `builtins.` prefix is omitted, and no prefix is added when the module cannot be determined.

    :param target:
        The object to serialize: an instance, a class, or a typing object.
    :return:
        The string representation of the type.
    :raises ValueError:
        If no name can be derived for the type.
    """
    if isinstance(target, str) and "." in target:
        return target

    # Keep classes and parameterized generics as they are; swap any other value for its class.
    if isinstance(target, type) or get_origin(target):
        type_obj = target
    else:
        type_obj = type(target)

    representation = repr(type_obj)
    origin = get_origin(type_obj)
    typing_prefix = "typing."

    if representation.startswith(typing_prefix):
        # typing objects already render their arguments, e.g. "typing.List[int]" -> "List[int]";
        # the module name ("typing") is prepended again below.
        type_name = representation[len(typing_prefix):]
    elif origin:
        # Parameterized generic whose repr is not a typing path (e.g. `list[int]` or a user-defined
        # generic): rebuild "<origin name>[<serialized args>]".
        serialized_args = ", ".join(map(serialize_type, get_args(type_obj)))
        type_name = f"{origin.__name__}[{serialized_args}]"
    elif hasattr(type_obj, "__name__"):
        type_name = type_obj.__name__
    else:
        raise ValueError(f"Could not serialize type: {representation}")

    module = inspect.getmodule(type_obj)
    has_module_name = bool(module) and hasattr(module, "__name__")
    # Builtins are returned bare ("str" instead of "builtins.str") to keep the output short.
    if not has_module_name or module.__name__ == "builtins":
        return type_name
    return f"{module.__name__}.{type_name}"
68

69

70
# Typing equivalents used before Python 3.9, where builtin containers are not subscriptable.
_BUILTIN_TO_TYPING = {
    list: typing.List,
    dict: typing.Dict,
    set: typing.Set,
    tuple: typing.Tuple,
    frozenset: typing.FrozenSet,
}

# Bare type strings that cannot be resolved by a module attribute lookup:
# "NoneType" appears e.g. in "typing.Union[int, str, NoneType]" (it is not an attribute of `builtins`),
# and "..." appears in "typing.Tuple[int, ...]" / "typing.Callable[..., int]".
_SPECIAL_TYPE_NAMES = {"NoneType": type(None), "...": Ellipsis}


def _parse_generic_args(args_str: str) -> typing.List[str]:
    """
    Split the text inside a generic's brackets into its top-level, whitespace-stripped arguments.

    Commas nested inside square brackets do not split, so `"str, typing.Dict[int, str]"`
    yields `["str", "typing.Dict[int, str]"]`.
    """
    args = []
    bracket_count = 0
    current_arg = ""

    for char in args_str:
        if char == "[":
            bracket_count += 1
        elif char == "]":
            bracket_count -= 1

        if char == "," and bracket_count == 0:
            args.append(current_arg.strip())
            current_arg = ""
        else:
            current_arg += char

    if current_arg:
        args.append(current_arg.strip())

    return args


def _deserialize_generic_arg(arg_str: str) -> Any:
    """
    Deserialize one argument of a generic type.

    A bracketed argument such as `[int, str]` (the parameter list in `typing.Callable[[int, str], bool]`)
    becomes a list of deserialized types; any other argument is passed to `deserialize_type`.
    """
    if arg_str.startswith("[") and arg_str.endswith("]"):
        return [deserialize_type(arg) for arg in _parse_generic_args(arg_str[1:-1])]
    return deserialize_type(arg_str)


def deserialize_type(type_str: str) -> Any:
    """
    Deserializes a type given its full import path as a string, including nested generic types.

    This function will dynamically import the module if it's not already imported
    and then retrieve the type object from it. It also handles nested generic types like
    `typing.List[typing.Dict[int, str]]`, `typing.Callable` parameter lists such as
    `typing.Callable[[int, str], bool]`, `...` as in `typing.Tuple[int, ...]`, and `NoneType`
    as in `typing.Union[int, str, NoneType]`.

    :param type_str:
        The string representation of the type's full import path.
    :returns:
        The deserialized type object.
    :raises DeserializationError:
        If the module cannot be imported, the type cannot be found in it, or the generic
        arguments cannot be applied to the main type.
    """
    if type_str in _SPECIAL_TYPE_NAMES:
        return _SPECIAL_TYPE_NAMES[type_str]

    if "[" in type_str and type_str.endswith("]"):
        # Handle generics: "<main type>[<arg>, <arg>, ...]"
        main_type_str, generics_str = type_str.split("[", 1)
        generics_str = generics_str[:-1]

        main_type = deserialize_type(main_type_str)
        generic_args = tuple(_deserialize_generic_arg(arg) for arg in _parse_generic_args(generics_str))

        # Before Python 3.9 builtin containers are not subscriptable, so use their typing equivalents.
        # Types without an equivalent (e.g. user-defined generics) are subscripted directly instead of
        # failing with a KeyError.
        if sys.version_info >= (3, 9) or repr(main_type).startswith("typing."):
            subscriptable = main_type
        else:
            subscriptable = _BUILTIN_TO_TYPING.get(main_type, main_type)

        try:
            return subscriptable[generic_args]
        except TypeError as e:
            raise DeserializationError(f"Could not apply arguments {generic_args} to type {main_type}") from e

    # Handle non-generics: "<module path>.<name>", or a bare name looked up in builtins
    parts = type_str.split(".")
    module_name = ".".join(parts[:-1]) or "builtins"
    type_name = parts[-1]

    module = sys.modules.get(module_name)
    if module is None:
        try:
            module = thread_safe_import(module_name)
        except ImportError as e:
            raise DeserializationError(f"Could not import the module: {module_name}") from e

    # Compare against None instead of testing truthiness: a class whose metaclass defines
    # __bool__ or __len__ can be falsy and must still be returned.
    deserialized_type = getattr(module, type_name, None)
    if deserialized_type is None:
        raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

    return deserialized_type
148

149

150
def thread_safe_import(module_name: str) -> ModuleType:
    """
    Import a module while holding the module-level `_import_lock`.

    Importing modules from several threads at once can lead to race conditions, so every import
    performed through this function is serialized by a single shared lock around
    `importlib.import_module`.

    :param module_name: the module to import
    :returns: the imported module object
    :raises ImportError: propagated from `importlib.import_module` when the module cannot be imported
    """
    with _import_lock:
        imported_module = importlib.import_module(module_name)
    return imported_module
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc