
deepset-ai / canals / 5834361039

11 Aug 2023 03:31PM UTC coverage: 93.466% (-0.06%) from 93.524%

Pull Request #82: feat: remove `init_parameters` decorator
Merge 5f0b60f21 into 19f2e8fac, committed by web-flow on GitHub

178 of 183 branches covered (97.27%)

Branch coverage included in aggregate %.

666 of 720 relevant lines covered (92.5%)

0.93 hits per line

Source file: canals/component/component.py (coverage: 97.67%)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
"""
    Attributes:

        component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

    All components must follow the contract below. This docstring is the source of truth for the component contract.

    <hr>

    `@component` decorator

    All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.

    <hr>

    `__init__(self, **kwargs)`

    Optional method.

    Components may have an `__init__` method where they define:

    - `self.init_parameters = {same parameters that the __init__ method received}`:
        In this dictionary you can store any state the component wishes to persist when it is saved.
        These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
        Note that the `@component` decorator does not save these arguments automatically: the default `to_dict()`
        serializes whatever the component stores in `self.init_parameters`, or an empty dictionary if it is not set.
        Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

    Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
    dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
    time. If such values are needed, consider serializing them to a string.

    _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

    The `__init__` must be extremely lightweight, because it is called frequently during the construction and
    validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.), refer to
    the `warm_up()` method.

    <hr>

    `warm_up(self)`

    Optional method.

    This method is called by Pipeline before the graph execution. Make sure to avoid double initialization,
    because Pipeline does not keep track of which components it has already called `warm_up()` on.

    <hr>

    `run(self, **kwargs)`

    Mandatory method.

    This is the method where the main functionality of the component should be carried out. It's called by
    `Pipeline.run()`.

    When the component should run, Pipeline will call this method with keyword arguments matching its input
    sockets. The input sockets are taken from the parameters of the `run()` signature (parameters annotated as
    `Optional[...]` become optional sockets), or are declared with `component.set_input_types()` if they are dynamic.

    `run()` must return a dictionary whose keys match the output types declared with `@component.output_types()`,
    or with `component.set_output_types()` if they are dynamic.

"""

import logging
import inspect
from typing import Protocol, Union, Dict, Type, Any, get_origin, get_args
from functools import wraps

from canals.errors import ComponentError, ComponentDeserializationError


logger = logging.getLogger(__name__)


# We ignore too-few-public-methods Pylint error as this is only meant to be
# the definition of the Component interface.
class Component(Protocol):  # pylint: disable=too-few-public-methods
    """
    Abstract interface of a Component.
    This is only used by type checking tools.
    If you want to create a new Component, use the @component decorator.
    """

    def run(self, **kwargs) -> Dict[str, Any]:
        """
        Takes the Component input and returns its output.
        Inputs are defined explicitly by the run method's signature or with `component.set_input_types()` if dynamic.
        Outputs are defined by decorating the run method with `@component.output_types()`
        or with `component.set_output_types()` if dynamic.
        """

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.
        """

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Component":
        """
        Deserializes the component from a dictionary.
        """


class _Component:
    """
    See the module docstring.

    Args:
        class_: the class that Canals should use as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_types(self, instance, **types):
        """
        Method that declares the input types of the run method of a component instance, for components whose
        inputs are dynamic.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=str, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```
        """
        run_method = instance.run

        def wrapper(**kwargs):
            return run_method(**kwargs)

        # Store the input types in the run method
        wrapper.__canals_input__ = {
            name: {"name": name, "type": type_, "is_optional": _is_optional(type_)} for name, type_ in types.items()
        }
        wrapper.__canals_output__ = getattr(run_method, "__canals_output__", {})

        # Assign the wrapped method to the instance's run()
        instance.run = wrapper

    def set_output_types(self, instance, **types):
        """
        Method that declares the output types of the run method of a component instance, for components whose
        outputs are dynamic.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        if not types:
            return  # not covered

        run_method = instance.run

        def wrapper(*args, **kwargs):
            return run_method(*args, **kwargs)

        # Store the output types in the run method
        wrapper.__canals_input__ = getattr(run_method, "__canals_input__", {})
        wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}

        # Assign the wrapped method to the instance's run()
        instance.run = wrapper

    def output_types(self, **types):
        """
        Decorator factory that declares the output types of the run method.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method):
            """
            Decorator that stores the declared output types on the run method.
            """

            @wraps(run_method)
            def wrapper(self, *args, **kwargs):
                return run_method(self, *args, **kwargs)

            # Store the output types in the run method - used by the pipeline to build the sockets.
            wrapper.__canals_input__ = getattr(run_method, "__canals_input__", {})
            wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}

            return wrapper

        return output_types_decorator

    def _component(self, class_):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering %s as a component", class_)

        # Check for run()
        if not hasattr(class_, "run"):
            raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
        run_signature = inspect.signature(class_.run)

        # Create the input sockets
        class_.run.__canals_input__ = {
            param: {
                "name": param,
                "type": run_signature.parameters[param].annotation,
                "is_optional": _is_optional(run_signature.parameters[param].annotation),
            }
            for param in list(run_signature.parameters)[1:]  # First is 'self' and it doesn't matter.
        }

        # Save the component in the class registry (for deserialization)
        if class_.__name__ in self.registry:
            logger.error(
                "Component %s is already registered. Previously imported from '%s', new imported from '%s'",
                class_.__name__,
                self.registry[class_.__name__],
                class_,
            )
        self.registry[class_.__name__] = class_
        logger.debug("Registered Component %s", class_)

        setattr(class_, "__canals_component__", True)

        if not hasattr(class_, "to_dict"):
            class_.to_dict = _default_component_to_dict

        if not hasattr(class_, "from_dict"):
            class_.from_dict = classmethod(_default_component_from_dict)

        return class_

    def __call__(self, class_=None):
        """Allows us to use this decorator with and without parentheses."""
        if class_:
            return self._component(class_)

        return self._component  # not covered


component = _Component()


def _is_optional(type_: type) -> bool:
    """
    Utility method that returns whether a type is Optional.
    """
    return get_origin(type_) is Union and type(None) in get_args(type_)


def _default_component_to_dict(comp: Component) -> Dict[str, Any]:
    """
    Default component serializer.
    Serializes a component to a dictionary.
    """
    return {
        "hash": id(comp),
        "type": comp.__class__.__name__,
        "init_parameters": getattr(comp, "init_parameters", {}),
    }


def _default_component_from_dict(cls: Type[Component], data: Dict[str, Any]) -> Component:
    """
    Default component deserializer.
    The "type" field in `data` must match the class that is being deserialized into.
    """
    init_params = data.get("init_parameters", {})
    if "type" not in data:
        raise ComponentDeserializationError("Missing 'type' in component serialization data")
    if data["type"] != cls.__name__:
        raise ComponentDeserializationError(f"Component '{data['type']}' can't be deserialized as '{cls.__name__}'")
    return cls(**init_params)
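As a rough illustration of the contract described in the module docstring, the sketch below is a plain class that does not import Canals, so it omits the `@component` decorator. `WordCounter`, its `lowercase` parameter, the `warm_up_calls` counter and the small vocabulary standing in for a model are all hypothetical. It keeps `__init__` cheap, records its arguments in `init_parameters`, guards `warm_up()` against double initialization, and returns a dictionary from `run()`.

```python
from typing import Any, Dict, Optional


class WordCounter:
    """Hypothetical component written to the contract in the module docstring (no Canals import)."""

    def __init__(self, lowercase: bool = True):
        # Keep __init__ cheap: just remember the arguments so they can be serialized later.
        self.init_parameters = {"lowercase": lowercase}
        self.lowercase = lowercase
        self._vocabulary: Optional[set] = None  # heavy state, loaded lazily by warm_up()
        self.warm_up_calls = 0  # hypothetical counter, only here to show the guard works

    def warm_up(self) -> None:
        # The pipeline does not track which components were warmed up,
        # so the component guards against double initialization itself.
        if self._vocabulary is not None:
            return
        self.warm_up_calls += 1
        self._vocabulary = {"canals", "pipeline", "component"}  # stand-in for loading a model

    def run(self, text: str) -> Dict[str, Any]:
        # run() receives its inputs as keyword arguments and returns a dictionary of outputs.
        words = text.lower().split() if self.lowercase else text.split()
        known = [word for word in words if word in (self._vocabulary or set())]
        return {"count": len(words), "known": known}


counter = WordCounter(lowercase=True)
counter.warm_up()
counter.warm_up()  # second call returns early
print(counter.run(text="Canals builds a Pipeline"))
```

The second `warm_up()` call does nothing, so the expensive initialization runs once even if a pipeline warms up the same component twice.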
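The socket-building step in `_component` and the `_is_optional` helper can be reproduced outside Canals with `inspect.signature` and `typing.get_origin`/`get_args`. This standalone sketch mirrors that logic; `is_optional`, `input_sockets` and the `Retriever` class are hypothetical names used only for illustration.

```python
import inspect
from typing import Any, Dict, Optional, Union, get_args, get_origin


def is_optional(type_: Any) -> bool:
    # Same check as `_is_optional`: Optional[X] is spelled Union[X, None] internally.
    return get_origin(type_) is Union and type(None) in get_args(type_)


def input_sockets(class_: type) -> Dict[str, Dict[str, Any]]:
    # Same idea as `_component`: one socket per run() parameter, skipping the first one (`self`).
    signature = inspect.signature(class_.run)
    return {
        name: {"name": name, "type": param.annotation, "is_optional": is_optional(param.annotation)}
        for name, param in list(signature.parameters.items())[1:]
    }


class Retriever:  # hypothetical component
    def run(self, query: str, top_k: Optional[int] = None, filters: Union[Dict[str, Any], None] = None):
        return {"documents": []}


for socket_name, socket in input_sockets(Retriever).items():
    print(socket_name, socket["is_optional"])
```

One consequence of this check worth knowing: on Python 3.10 and later, `get_origin(int | None)` returns `types.UnionType` rather than `typing.Union`, so this check (like `_is_optional` above) reports `False` for a parameter written as `int | None`.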
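The `output_types` method follows a common decorator-factory pattern: the factory captures the declared types, the inner decorator wraps `run`, and the declared outputs are attached to the wrapper as an attribute. The sketch below reimplements that pattern outside Canals (`Doubler` is hypothetical). It also shows why `@wraps` matters here: `inspect.signature` follows the `__wrapped__` attribute that `functools.wraps` sets, so a later step such as `_component` still sees `(self, value: int)` rather than `(self, *args, **kwargs)`.

```python
import inspect
from functools import wraps
from typing import Any, Callable, Dict


def output_types(**types: type) -> Callable[[Callable], Callable]:
    """Standalone sketch of a decorator factory like `_Component.output_types`."""

    def decorator(run_method: Callable) -> Callable:
        @wraps(run_method)  # copies __name__/__doc__ and sets __wrapped__, which inspect.signature follows
        def wrapper(self, *args, **kwargs):
            return run_method(self, *args, **kwargs)

        # Attach the declared outputs to the function object so a pipeline can read them later.
        wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}
        return wrapper

    return decorator


class Doubler:  # hypothetical component
    @output_types(value=int)
    def run(self, value: int) -> Dict[str, Any]:
        return {"value": value * 2}


print(Doubler.run.__canals_output__)
print(inspect.signature(Doubler.run))
print(Doubler().run(value=21))
```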
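`set_input_types` works per instance rather than per class: it wraps the bound `run` method and assigns the wrapper as an instance attribute, which shadows the class's `run()` because plain functions are non-data descriptors. The standalone sketch below mirrors that behavior; `Concatenate` and its `names` argument are hypothetical, chosen to show inputs that depend on `__init__` arguments.

```python
from typing import Any, Dict, Union, get_args, get_origin


def _is_optional(type_: Any) -> bool:
    return get_origin(type_) is Union and type(None) in get_args(type_)


def set_input_types(instance: Any, **types: Any) -> None:
    """Standalone sketch mirroring `_Component.set_input_types`: wraps one instance's run()."""
    run_method = instance.run  # bound method of this particular instance

    def wrapper(**kwargs):
        return run_method(**kwargs)

    wrapper.__canals_input__ = {
        name: {"name": name, "type": type_, "is_optional": _is_optional(type_)} for name, type_ in types.items()
    }
    wrapper.__canals_output__ = getattr(run_method, "__canals_output__", {})
    # The instance attribute shadows the class's run(), so other instances are unaffected.
    instance.run = wrapper


class Concatenate:  # hypothetical component whose inputs depend on __init__ arguments
    def __init__(self, names):
        self.names = names
        set_input_types(self, **{name: str for name in names})

    def run(self, **kwargs) -> Dict[str, Any]:
        return {"text": " ".join(kwargs[name] for name in self.names)}


first = Concatenate(["a", "b"])
second = Concatenate(["x"])
print(sorted(first.run.__canals_input__), sorted(second.run.__canals_input__))
print(first.run(a="hello", b="world"))
```

Because the declared types live on each instance's wrapper, two instances of the same class can expose different input sockets.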
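The default `to_dict()`/`from_dict()` pair persists only `init_parameters` and rebuilds the component by calling `__init__` again. The standalone sketch below mirrors `_default_component_to_dict` and `_default_component_from_dict` as plain functions (in Canals the latter is attached as a `classmethod`); `DeserializationError`, `Greeter` and `Forgetful` are hypothetical stand-ins. It shows that a component which never sets `init_parameters` loses its constructor arguments on a round trip.

```python
from typing import Any, Dict, Type


class DeserializationError(Exception):
    """Hypothetical stand-in for canals.errors.ComponentDeserializationError."""


def default_to_dict(comp: Any) -> Dict[str, Any]:
    # Mirrors `_default_component_to_dict`: only `init_parameters` is persisted.
    return {
        "hash": id(comp),
        "type": comp.__class__.__name__,
        "init_parameters": getattr(comp, "init_parameters", {}),
    }


def default_from_dict(cls: Type, data: Dict[str, Any]) -> Any:
    # Mirrors `_default_component_from_dict`: check "type", then call __init__ again.
    if "type" not in data:
        raise DeserializationError("Missing 'type' in component serialization data")
    if data["type"] != cls.__name__:
        raise DeserializationError(f"Component '{data['type']}' can't be deserialized as '{cls.__name__}'")
    return cls(**data.get("init_parameters", {}))


class Greeter:  # hypothetical component that sets init_parameters itself
    def __init__(self, greeting: str = "Hello"):
        self.init_parameters = {"greeting": greeting}
        self.greeting = greeting


class Forgetful:  # hypothetical component that never sets init_parameters
    def __init__(self, greeting: str = "Hello"):
        self.greeting = greeting


restored = default_from_dict(Greeter, default_to_dict(Greeter(greeting="Ciao")))
print(restored.greeting)  # the argument survives the round trip

lost = default_from_dict(Forgetful, default_to_dict(Forgetful(greeting="Ciao")))
print(lost.greeting)  # falls back to the default, because nothing was saved
```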
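`_Component.__call__` lets the same object be used both as `@component` and as `@component()`, and `_component` keeps a registry keyed by class name. A minimal standalone sketch of that pattern follows; `Registry`, `A` and `B` are hypothetical, and `TypeError` stands in for `ComponentError`.

```python
import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


class Registry:
    """Standalone sketch of a class-decorator object usable as @registry and @registry()."""

    def __init__(self) -> None:
        self.classes: Dict[str, type] = {}

    def _register(self, class_: type) -> type:
        if not hasattr(class_, "run"):
            raise TypeError(f"{class_.__name__} must have a 'run()' method")
        if class_.__name__ in self.classes:
            # Same name from another module: the new class replaces the old one, so log it.
            logger.error("%s is already registered", class_.__name__)
        self.classes[class_.__name__] = class_
        return class_

    def __call__(self, class_: Optional[type] = None):
        # @registry passes the class directly; @registry() passes nothing and expects a decorator back.
        if class_:
            return self._register(class_)
        return self._register


registry = Registry()


@registry
class A:
    def run(self) -> Dict[str, Any]:
        return {}


@registry()
class B:
    def run(self) -> Dict[str, Any]:
        return {}


print(sorted(registry.classes))
```

Keying the registry by `__name__` is what makes the duplicate-name check in `_component` necessary: two classes with the same name from different modules would otherwise silently replace each other.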