deepset-ai / canals / 5953104377

23 Aug 2023 03:05PM UTC coverage: 92.958% (+0.1%) from 92.818%
5953104377

push

github

web-flow
Check Component I/O socket names are valid (#100)

185 of 190 branches covered (97.37%)

Branch coverage included in aggregate %.

673 of 733 relevant lines covered (91.81%)

0.92 hits per line

Source File
95.7
canals/component/component.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
"""
    Attributes:

        component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

    All components must follow the contract below. This docstring is the source of truth for the component contract.

    <hr>

    `@component` decorator

    All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.

    <hr>

    `__init__(self, **kwargs)`

    Optional method.

    Components may have an `__init__` method where they define:

    - `self.init_parameters = {same parameters that the __init__ method received}`:
        In this dictionary you can store any state the component needs to persist when it is saved.
        These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
        Note that by default the `@component` decorator saves the arguments automatically.
        However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
        Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

    Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
    dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
    time. If such values are needed, consider serializing them to a string.

    _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

    The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
    validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.) refer to
    the `warm_up()` method.

    <hr>

    `warm_up(self)`

    Optional method.

    This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
    because Pipeline will not keep track of which components it called `warm_up()` on.

    <hr>

    `run(self, data)`

    Mandatory method.

    This is the method where the main functionality of the component should be carried out. It's called by
    `Pipeline.run()`.

    When the component should run, Pipeline will call this method with an instance of the dataclass returned by the
    method decorated with `@component.input`. This dataclass contains:

    - all the input values coming from other components connected to it,
    - if any is missing, the corresponding value defined in `self.defaults`, if it exists.

    `run()` must return a single instance of the dataclass declared through the method decorated with
    `@component.output`.

"""

import logging
import inspect
from typing import Protocol, Union, Dict, Any, get_origin, get_args
from keyword import iskeyword
from functools import wraps

from canals.errors import ComponentError


logger = logging.getLogger(__name__)


class Component(Protocol):
    """
    Abstract interface of a Component.
    This is only used by type checking tools.
    If you want to create a new Component use the @component decorator.
    """

    def run(self, **kwargs) -> Dict[str, Any]:
        """
        Takes the Component input and returns its output.
        Inputs are defined explicitly by the run method's signature or with `component.set_input_types()` if dynamic.
        Outputs are defined by decorating the run method with `@component.output_types()`
        or with `component.set_output_types()` if dynamic.
        """

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.
        """

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Component":
        """
        Deserializes the component from a dictionary.
        """


class _Component:
    """
    See module's docstring.

    Args:
        class_: the class that Canals should use as a component.
        serializable: whether to check, at init time, if the component can be saved with
        `save_pipelines()`.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_types(self, instance, **types):
        """
        Method that validates the input kwargs of the run method.
        `types` names must be valid Python identifiers and must not clash with any keyword.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                # The component instance must be passed as the first argument.
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```
        """
        for name in types:
            if not _is_valid_socket_name(name):
                raise ComponentError(
                    f"Invalid socket name '{name}'. Socket names must be valid Python identifiers and must not clash with any keyword."
                )

        run_method = instance.run

        def wrapper(**kwargs):
            return run_method(**kwargs)

        # Store the input types in the run method
        wrapper.__canals_input__ = {
            name: {"name": name, "type": type_, "is_optional": _is_optional(type_)} for name, type_ in types.items()
        }
        wrapper.__canals_output__ = getattr(run_method, "__canals_output__", {})

        # Assigns the wrapped method to the instance's run()
        instance.run = wrapper

    def set_output_types(self, instance, **types):
        """
        Method that validates the output dictionary of the run method.
        `types` names must be valid Python identifiers and must not clash with any keyword.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                # The component instance must be passed as the first argument.
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        if not types:
            return

        for name in types:
            if not _is_valid_socket_name(name):
                raise ComponentError(
                    f"Invalid socket name '{name}'. Socket names must be valid Python identifiers and must not clash with any keyword."
                )

        run_method = instance.run

        def wrapper(*args, **kwargs):
            return run_method(*args, **kwargs)

        # Store the output types in the run method
        wrapper.__canals_input__ = getattr(run_method, "__canals_input__", {})
        wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}

        # Assigns the wrapped method to the instance's run()
        instance.run = wrapper

    def output_types(self, **types):
        """
        Decorator factory that validates the output dictionary of the run method.
        `types` names must be valid Python identifiers and must not clash with any keyword.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        for name in types:
            if not _is_valid_socket_name(name):
                raise ComponentError(
                    f"Invalid socket name '{name}'. Socket names must be valid Python identifiers and must not clash with any keyword."
                )

        def output_types_decorator(run_method):
            """
            Decorator that validates the output dictionary of the run method.
            """
            # Store the output types in the run method - used by the pipeline to build the sockets.

            @wraps(run_method)
            def wrapper(self, *args, **kwargs):
                return run_method(self, *args, **kwargs)

            wrapper.__canals_input__ = getattr(run_method, "__canals_input__", {})
            wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}

            return wrapper

        return output_types_decorator

    def _component(self, class_):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering %s as a component", class_)

        # Check for required methods
        if not hasattr(class_, "run"):
            raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
        run_signature = inspect.signature(class_.run)

        if not hasattr(class_, "to_dict"):
            raise ComponentError(
                f"{class_.__name__} must have a 'to_dict()' method. See the docs for more information."
            )

        if not hasattr(class_, "from_dict"):
            raise ComponentError(
                f"{class_.__name__} must have a 'from_dict()' method. See the docs for more information."
            )

        # Create the input sockets
        class_.run.__canals_input__ = {
            param: {
                "name": param,
                "type": run_signature.parameters[param].annotation,
                "is_optional": _is_optional(run_signature.parameters[param].annotation),
            }
            for param in list(run_signature.parameters)[1:]  # First is 'self' and it doesn't matter.
        }

        # Save the component in the class registry (for deserialization)
        if class_.__name__ in self.registry:
            logger.error(
                "Component %s is already registered. Previous imported from '%s', new imported from '%s'",
                class_.__name__,
                self.registry[class_.__name__],
                class_,
            )
        self.registry[class_.__name__] = class_
        logger.debug("Registered Component %s", class_)

        setattr(class_, "__canals_component__", True)

        return class_

    def __call__(self, class_=None):
        """Allows us to use this decorator with or without parentheses."""
        if class_:
            return self._component(class_)

        return self._component


component = _Component()
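
The `__call__` dispatch above is what lets the decorator be written either as `@component` or as `@component()`. A minimal, self-contained sketch of that pattern follows; `_Registry`, `_register`, `register`, `Bare` and `Called` are hypothetical names used only for illustration, and the sketch leaves out the method checks and socket creation that `_component` performs.

```python
class _Registry:
    """Hypothetical stand-in for `_Component`, reduced to the `__call__` dispatch."""

    def __init__(self):
        self.registry = {}

    def _register(self, class_):
        # Record the class under its name and hand it back unchanged.
        self.registry[class_.__name__] = class_
        return class_

    def __call__(self, class_=None):
        # `@register` passes the class directly; `@register()` passes nothing,
        # so the bound method is returned and applied to the class afterwards.
        if class_:
            return self._register(class_)
        return self._register


register = _Registry()


@register
class Bare:
    pass


@register()
class Called:
    pass


registered_names = sorted(register.registry)
```

Both spellings end up registering the class, so `registered_names` contains `Bare` and `Called`.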


def _is_optional(type_: type) -> bool:
    """
    Utility method that returns whether a type is Optional.
    """
    return get_origin(type_) is Union and type(None) in get_args(type_)
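
To make the `_is_optional` check concrete, the sketch below copies the same expression into a standalone function (named `is_optional` here only so the snippet runs without importing Canals) and applies it to a few annotations. Annotations written with the `int | None` syntax are left out on purpose, since how `get_origin` reports them may depend on the Python version.

```python
from typing import Optional, Union, get_args, get_origin


def is_optional(type_) -> bool:
    # Same expression as `_is_optional` above, copied so the snippet is self-contained.
    return get_origin(type_) is Union and type(None) in get_args(type_)


results = {
    "Optional[int]": is_optional(Optional[int]),
    "Union[int, None, str]": is_optional(Union[int, None, str]),
    "Union[int, str]": is_optional(Union[int, str]),
    "int": is_optional(int),
}
```

Only the annotations whose `Union` includes `None` are treated as optional; a plain `int` or a `Union` without `None` is not.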


def _is_valid_socket_name(name: str) -> bool:
    """
    Utility method that checks if a string is a valid socket name.
    Socket names must be valid Python identifiers and must not clash with any keyword.
    """
    return name.isidentifier() and not iskeyword(name)
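
As an illustration of which names pass the socket-name rule, the sketch below repeats the same check in a standalone function (`is_valid_socket_name`, renamed only to keep the snippet independent of Canals) and runs it on a handful of made-up candidate names.

```python
from keyword import iskeyword


def is_valid_socket_name(name: str) -> bool:
    # Same expression as `_is_valid_socket_name` above.
    return name.isidentifier() and not iskeyword(name)


# Candidate names are made up for illustration.
candidates = ["value_1", "_private", "1st_value", "my-output", "class"]
checks = {name: is_valid_socket_name(name) for name in candidates}
```

`value_1` and `_private` are accepted; `1st_value` and `my-output` are rejected because they are not identifiers, and `class` is rejected because it is a keyword.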