• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 13718432299

07 Mar 2025 10:10AM UTC coverage: 89.936% (-0.01%) from 89.949%
13718432299

push

github

web-flow
fix: Fix type serialization and deserialization (#8993)

* Expand tests

* New version of type serialization

* Adding more tests

* More tests

* Fix type serialization when using python 3.9

* Deserialization works with Optional now and we don't require 'typing.' to be present anymore

* Don't worry about Literal

* Add reno

* Fix mypy

* Pylint

* Add additional test

* Simplify

* Add back comment

* Fix types

* Fix

9687 of 10771 relevant lines covered (89.94%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.08
haystack/utils/type_serialization.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import builtins
1✔
6
import importlib
1✔
7
import inspect
1✔
8
import sys
1✔
9
import typing
1✔
10
from threading import Lock
1✔
11
from types import ModuleType
1✔
12
from typing import Any, get_args
1✔
13

14
from haystack import DeserializationError
1✔
15

16
_import_lock = Lock()
1✔
17

18

19
def serialize_type(target: Any) -> str:
    """
    Serializes a type or typing construct to its string representation, including the module name.

    Builtins are emitted without a module prefix (e.g. `int`), other types as `module.Name`, and
    generics recursively as `module.Name[arg1, arg2]`. Objects without a `__name__` attribute fall
    back to `str(target)`.

    `NoneType` is only omitted from the arguments of `Optional[...]`, where the name already implies
    it. In every other generic (e.g. `Union[int, str, None]` or `Tuple[int, None]`) it is kept and
    serialized as `NoneType`, so that deserializing the string reproduces the original type.

    :param target:
        The type (or typing construct) to serialize.
    :return:
        The string representation of the type.
    """
    name = getattr(target, "__name__", str(target))

    # Remove the 'typing.' prefix when using python <3.9
    if name.startswith("typing."):
        name = name[7:]
    # Remove the arguments from the name when using python <3.9
    if "[" in name:
        name = name.split("[")[0]

    # Get module name
    module = inspect.getmodule(target)
    module_name = ""
    # We omit the module name for builtins to not clutter the output
    if module and hasattr(module, "__name__") and module.__name__ != "builtins":
        module_name = f"{module.__name__}"

    qualified = f"{module_name}.{name}" if module_name else f"{name}"

    args = get_args(target)
    if args:
        if name == "Optional":
            # Optional[X] is Union[X, None]; the None is implied by the name itself.
            args = tuple(a for a in args if a is not type(None))
        args_str = ", ".join(serialize_type(a) for a in args)
        return f"{qualified}[{args_str}]"

    return qualified
53

54

55
def _parse_generic_args(args_str):
1✔
56
    args = []
1✔
57
    bracket_count = 0
1✔
58
    current_arg = ""
1✔
59

60
    for char in args_str:
1✔
61
        if char == "[":
1✔
62
            bracket_count += 1
1✔
63
        elif char == "]":
1✔
64
            bracket_count -= 1
1✔
65

66
        if char == "," and bracket_count == 0:
1✔
67
            args.append(current_arg.strip())
1✔
68
            current_arg = ""
1✔
69
        else:
70
            current_arg += char
1✔
71

72
    if current_arg:
1✔
73
        args.append(current_arg.strip())
1✔
74

75
    return args
1✔
76

77

78
def deserialize_type(type_str: str) -> Any:  # pylint: disable=too-many-return-statements
    """
    Deserializes a type given its full import path as a string, including nested generic types.

    This function will dynamically import the module if it's not already imported
    and then retrieve the type object from it. It also handles nested generic types like
    `typing.List[typing.Dict[int, str]]`. Names without a module prefix are looked up first in
    `builtins`, then in `typing`; `NoneType` resolves to `type(None)`.

    NOTE: importing a module executes its top-level code, so only deserialize type strings that
    come from a trusted source.

    :param type_str:
        The string representation of the type's full import path.
    :returns:
        The deserialized type object.
    :raises DeserializationError:
        If the module or type cannot be found, the module path is malformed, or the generic
        arguments cannot be applied to the main type.
    """
    # Handle generics, e.g. "typing.Dict[str, typing.List[int]]"
    if "[" in type_str and type_str.endswith("]"):
        main_type_str, generics_str = type_str.split("[", 1)
        generics_str = generics_str[:-1]

        main_type = deserialize_type(main_type_str)
        generic_args = [deserialize_type(arg) for arg in _parse_generic_args(generics_str)]
        if not generic_args:
            # e.g. "typing.List[]": there is nothing to subscript the main type with.
            raise DeserializationError(f"Could not apply arguments {generic_args} to type {main_type}")
        # A single argument is subscripted as-is (X[a]), several as a tuple (X[a, b]).
        subscript = tuple(generic_args) if len(generic_args) > 1 else generic_args[0]

        # Reconstruct
        try:
            if sys.version_info >= (3, 9) or repr(main_type).startswith("typing."):
                return main_type[subscript]
            # Python < 3.9: builtin containers are not subscriptable, use their typing aliases.
            type_mapping = {
                list: typing.List,
                dict: typing.Dict,
                set: typing.Set,
                tuple: typing.Tuple,
                frozenset: typing.FrozenSet,
            }
            return type_mapping[main_type][subscript]
        except (TypeError, AttributeError, KeyError) as e:
            # KeyError: on Python < 3.9 the main type has no typing alias in `type_mapping`.
            raise DeserializationError(f"Could not apply arguments {generic_args} to type {main_type}") from e

    # Handle non-generic types
    # First, check if there's a module prefix
    if "." in type_str:
        module_name, _, type_name = type_str.rpartition(".")

        module = sys.modules.get(module_name)
        if module is None:
            try:
                module = thread_safe_import(module_name)
            except (ImportError, ValueError, TypeError) as e:
                # ValueError/TypeError: malformed names such as "" or relative paths like "."
                raise DeserializationError(f"Could not import the module: {module_name}") from e

        # Get the class from the module
        if hasattr(module, type_name):
            return getattr(module, type_name)

        raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

    # No module prefix, check builtins and typing
    # First check builtins
    if hasattr(builtins, type_str):
        return getattr(builtins, type_str)

    # Then check typing
    if hasattr(typing, type_str):
        return getattr(typing, type_str)

    # Special case for NoneType (it is not exposed by `builtins` or `typing`)
    if type_str == "NoneType":
        return type(None)

    # Special case for None (fallback; normally already resolved via `builtins` above)
    if type_str == "None":
        return None

    raise DeserializationError(f"Could not deserialize type: {type_str}")
157

158

159
def thread_safe_import(module_name: str) -> ModuleType:
    """
    Import a module while holding the module-level import lock.

    Importing modules from several threads at once can lead to race conditions; routing the
    import through a single shared lock serializes the calls made via this helper.

    :param module_name: the module to import
    :returns: the imported module object
    """
    with _import_lock:
        imported = importlib.import_module(module_name)
    return imported
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc