• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 11463116725

22 Oct 2024 03:08PM UTC coverage: 90.471% (+0.003%) from 90.468%
11463116725

push

github

web-flow
fix: Enforce basic Python types restriction on serialized component data (#8473)

7548 of 8343 relevant lines covered (90.47%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.34
haystack/core/serialization.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import inspect
1✔
6
from collections.abc import Callable
1✔
7
from dataclasses import dataclass
1✔
8
from importlib import import_module
1✔
9
from typing import Any, Dict, Iterable, Optional, Type
1✔
10

11
from haystack.core.component.component import _hook_component_init, logger
1✔
12
from haystack.core.errors import DeserializationError, SerializationError
1✔
13

14

15
@dataclass(frozen=True)
class DeserializationCallbacks:
    """
    Hooks that are called at well-defined points while a pipeline is being deserialized.

    :param component_pre_init:
        Called right before a component instance is constructed. It receives, in order,
        `component_name` (`str`), `component_class` (`Type`) and `init_params` (`Dict[str, Any]`).

        The hook may mutate `init_params` in place; that dictionary holds every
        argument that will be handed to the component's constructor.
    """

    # Optional pre-construction hook; leaving it as `None` means no hook is installed.
    component_pre_init: Optional[Callable] = None
31

32

33
def component_to_dict(obj: Any, name: str) -> Dict[str, Any]:
    """
    Converts a component instance into a dictionary.

    If a `to_dict` method is present in the component instance, that will be used instead of the default method.
    The default method rebuilds the init parameters by reading, for every constructor parameter, the instance
    attribute of the same name, falling back to the parameter's default value when that attribute is missing.

    :param obj:
        The component to be serialized.
    :param name:
        The name of the component.
    :returns:
        A dictionary representation of the component.

    :raises SerializationError:
        If the component has no `to_dict` method and the value of an init parameter without a default
        can't be determined.
        If a non-basic Python type is used in the serialized data.
    """
    if hasattr(obj, "to_dict"):
        data = obj.to_dict()
    else:
        init_parameters = {}
        for param_name, param in inspect.signature(obj.__init__).parameters.items():
            # Ignore `args` and `kwargs`, used by the default constructor
            if param_name in ("args", "kwargs"):
                continue
            try:
                # This only works if the Component constructor assigns the init
                # parameter to an instance variable or property with the same name
                param_value = getattr(obj, param_name)
            except AttributeError as e:
                # If the parameter doesn't have a default value, raise an error
                if param.default is param.empty:
                    raise SerializationError(
                        f"Cannot determine the value of the init parameter '{param_name}' "
                        f"for the class {obj.__class__.__name__}. "
                        f"You can fix this error by assigning 'self.{param_name} = {param_name}' or adding a "
                        f"custom serialization method 'to_dict' to the class."
                    ) from e
                # In case the init parameter was not assigned, we use the default value
                param_value = param.default
            init_parameters[param_name] = param_value

        data = default_to_dict(obj, **init_parameters)

    # Both the custom and the default path must produce data made of basic Python types only.
    _validate_component_to_dict_output(obj, name, data)
    return data
80

81

82
def _validate_component_to_dict_output(component: Any, name: str, data: Dict[str, Any]) -> None:
1✔
83
    # Ensure that only basic Python types are used in the serde data.
84
    def is_allowed_type(obj: Any) -> bool:
1✔
85
        return isinstance(obj, (str, int, float, bool, list, dict, set, tuple, type(None)))
1✔
86

87
    def check_iterable(l: Iterable[Any]):
1✔
88
        for v in l:
1✔
89
            if not is_allowed_type(v):
1✔
90
                raise SerializationError(
×
91
                    f"Component '{name}' of type '{type(component).__name__}' has an unsupported value "
92
                    f"of type '{type(v).__name__}' in the serialized data."
93
                )
94
            if isinstance(v, (list, set, tuple)):
1✔
95
                check_iterable(v)
1✔
96
            elif isinstance(v, dict):
1✔
97
                check_dict(v)
1✔
98

99
    def check_dict(d: Dict[str, Any]):
1✔
100
        if any(not isinstance(k, str) for k in data.keys()):
1✔
101
            raise SerializationError(
×
102
                f"Component '{name}' of type '{type(component).__name__}' has a non-string key in the serialized data."
103
            )
104

105
        for k, v in d.items():
1✔
106
            if not is_allowed_type(v):
1✔
107
                raise SerializationError(
1✔
108
                    f"Component '{name}' of type '{type(component).__name__}' has an unsupported value "
109
                    f"of type '{type(v).__name__}' in the serialized data under key '{k}'."
110
                )
111
            if isinstance(v, (list, set, tuple)):
1✔
112
                check_iterable(v)
1✔
113
            elif isinstance(v, dict):
1✔
114
                check_dict(v)
1✔
115

116
    check_dict(data)
1✔
117

118

119
def generate_qualified_class_name(cls: Type[object]) -> str:
    """
    Builds the fully qualified name of a class.

    :param cls:
        The class to name.
    :returns:
        The defining module of `cls` and its class name joined by a dot, e.g. `"builtins.int"`.
    """
    module_name = cls.__module__
    class_name = cls.__name__
    return f"{module_name}.{class_name}"
129

130

131
def component_from_dict(
    cls: Type[object], data: Dict[str, Any], name: str, callbacks: Optional[DeserializationCallbacks] = None
) -> Any:
    """
    Creates a component instance from a dictionary.

    If a `from_dict` method is present in the component class, that will be used instead of the default method.
    When `callbacks.component_pre_init` is set, it is installed through `_hook_component_init` for the duration
    of the deserialization and is called with `name`, the component class and its init parameters.

    :param cls:
        The class to be used for deserialization.
    :param data:
        The serialized data.
    :param name:
        The name of the component.
    :param callbacks:
        Callbacks to invoke during deserialization.
    :returns:
        The deserialized component.
    """

    def deserialize() -> Any:
        # The class's own `from_dict` takes precedence over the generic implementation.
        if not hasattr(cls, "from_dict"):
            return default_from_dict(cls, data)
        return cls.from_dict(data)

    pre_init = None if callbacks is None else callbacks.component_pre_init
    if pre_init is None:
        return deserialize()

    def on_component_pre_init(component_cls, init_params):
        assert pre_init is not None
        pre_init(name, component_cls, init_params)

    with _hook_component_init(on_component_pre_init):
        return deserialize()
167

168

169
def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
    """
    Utility function to serialize an object to a dictionary.

    This is mostly necessary for components but can be used by any object.
    `init_parameters` are parameters passed to the object class `__init__`.
    They must be defined explicitly as they'll be used when creating a new
    instance of `obj` with `from_dict`. Omitting them might cause deserialization
    errors or unexpected behaviors later, when calling `from_dict`.

    The `type` entry is the fully qualified class name, i.e. `"<module>.<class name>"`.

    An example usage, assuming `MyClass` is defined in the module `my_module`:

    ```python
    class MyClass:
        def __init__(self, my_param: int = 10):
            self.my_param = my_param

        def to_dict(self):
            return default_to_dict(self, my_param=self.my_param)


    obj = MyClass(my_param=1000)
    data = obj.to_dict()
    assert data == {
        "type": "my_module.MyClass",
        "init_parameters": {
            "my_param": 1000,
        },
    }
    ```

    :param obj:
        The object to be serialized.
    :param init_parameters:
        The parameters used to create a new instance of the class.
    :returns:
        A dictionary representation of the instance.
    """
    return {"type": generate_qualified_class_name(type(obj)), "init_parameters": init_parameters}
208

209

210
def default_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
    """
    Utility function to deserialize a dictionary to an object.

    This is mostly necessary for components but can be used by any object.

    The `type` entry of `data` must be present and equal to the fully qualified name of `cls`.
    The optional `init_parameters` entry is unpacked as keyword arguments to the `cls` constructor;
    when it is absent, `cls` is constructed without arguments.

    :param cls:
        The class to be used for deserialization.
    :param data:
        The serialized data.
    :returns:
        The deserialized object.

    :raises DeserializationError:
        If the `type` field in `data` is missing or it doesn't match the type of `cls`.
    """
    # Key presence is checked (not truthiness), so an explicit `None` type still reaches the comparison below.
    if "type" not in data:
        raise DeserializationError("Missing 'type' in serialization data")

    expected_type = generate_qualified_class_name(cls)
    if data["type"] != expected_type:
        raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")

    constructor_kwargs = data.get("init_parameters", {})
    return cls(**constructor_kwargs)
238

239

240
def import_class_by_name(fully_qualified_name: str) -> Type[object]:
    """
    Utility function to import (load) a class object based on its fully qualified class name.

    This function dynamically imports a class based on its string name.
    It splits the name into module path and class name, imports the module,
    and returns the class object.

    :param fully_qualified_name: the fully qualified class name as a string, e.g. `"package.module.ClassName"`.
    :returns: the class object.
    :raises ImportError: If the name is not of the form `"<module path>.<class name>"`,
        or if the class cannot be imported or found.
    """
    try:
        module_path, separator, class_name = fully_qualified_name.rpartition(".")
        # Reject names with no dot ("Foo"), an empty/relative module path (".Foo", "..Foo") or an empty class
        # name ("pkg.mod."). Without this check they leak ValueError/TypeError instead of the documented ImportError.
        if not separator or not module_path or not class_name or module_path.startswith("."):
            raise ImportError(f"'{fully_qualified_name}' is not a fully qualified class name")
        logger.debug(f"Attempting to import class '{class_name}' from module '{module_path}'")
        module = import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as error:
        logger.error(f"Failed to import class '{fully_qualified_name}'")
        raise ImportError(f"Could not import class '{fully_qualified_name}'") from error
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc