deepset-ai / canals / 5834429870

11 Aug 2023 03:38PM UTC coverage: 95.179% (+1.7%) from 93.466%

Pull #55 · github · web-flow
Merge 063ea707d into 8e973ea23
Pull Request #55: experiment: FSM implementation

176 of 179 branches covered (98.32%). Branch coverage included in aggregate %.

673 of 713 relevant lines covered (94.39%)

0.94 hits per line

Source file: canals/component/component.py (coverage: 97.67%)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
"""
    Attributes:

        component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

    All components must follow the contract below. This docstring is the source of truth for the component contract.

    <hr>

    `@component` decorator

    All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.
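
    For illustration only, the registration performed by the decorator can be sketched in plain Python. `register` and
    `REGISTRY` are hypothetical stand-ins, not the Canals API: the sketch mirrors only the `run()` check and the
    registry keyed by class name that `_Component._component()` keeps further down in this module, and it raises
    `TypeError` where Canals raises `ComponentError`.

    ```python
    from typing import Any, Dict

    # Hypothetical registry mapping class names to classes, like `_Component.registry`.
    REGISTRY: Dict[str, type] = {}


    def register(class_: type) -> type:
        # Reject classes that a pipeline could not run.
        if not hasattr(class_, "run"):
            raise TypeError(f"{class_.__name__} must have a 'run()' method.")
        # Store the class under its name so it can be looked up again when loading.
        REGISTRY[class_.__name__] = class_
        return class_


    @register
    class Doubler:  # hypothetical component
        def run(self, value: int) -> Dict[str, Any]:
            return {"value": value * 2}
    ```

    With this sketch, `REGISTRY["Doubler"]().run(value=3)` returns `{"value": 6}`.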

    <hr>

    `__init__(self, **kwargs)`

    Optional method.

    Components may have an `__init__` method where they define:

    - `self.init_parameters = {same parameters that the __init__ method received}`:
        In this dictionary a component can store any state it wants to persist when it is saved.
        These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
        Note that the `@component` decorator defined in this module does not save the arguments automatically:
        the default `to_dict()` serializes `init_parameters` if the component sets it, and an empty dictionary otherwise.
        Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

    Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
    dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
    time. If such values are needed, consider serializing them to a string.

    _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

    The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
    validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.), move it to
    the `warm_up()` method.
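
    A minimal sketch of an `__init__` that follows this contract. `Greeter` is a hypothetical component and the
    `@component` decorator is left out so the snippet runs without Canals installed; `json.dumps()` is used only as a
    quick check that the stored values are JSON serializable.

    ```python
    import json
    from typing import Any, Dict, List, Optional


    class Greeter:  # hypothetical; a real component would be decorated with @component
        def __init__(self, greeting: str = "Hello", names: Optional[List[str]] = None):
            names = names or []
            # Only basic, JSON-serializable values, mirroring the __init__ arguments.
            self.init_parameters = {"greeting": greeting, "names": names}
            # Cheap assignments only: heavy resources belong in warm_up().
            self.greeting = greeting
            self.names = names

        def run(self, name: str) -> Dict[str, Any]:
            return {"text": f"{self.greeting}, {name}!"}


    greeter = Greeter(greeting="Hi", names=["Ada"])
    # Raises TypeError if a stored value is not JSON serializable.
    serialized = json.dumps(greeter.init_parameters)
    # A new instance can be rebuilt from the stored parameters, as happens when a pipeline is loaded.
    clone = Greeter(**json.loads(serialized))
    ```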

    <hr>

    `warm_up(self)`

    Optional method.

    This method is called by Pipeline before the graph execution. Make sure to avoid double initialization,
    because Pipeline does not keep track of which components it has already called `warm_up()` on.
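
    One common way to make `warm_up()` safe to call twice is to skip the heavy step when its result already exists.
    In this sketch `Tagger` is hypothetical, the "model" is a small dictionary standing in for an expensive load, and
    `load_count` exists only to show that the guard works.

    ```python
    from typing import Any, Dict, Optional


    class Tagger:  # hypothetical component; @component is omitted
        def __init__(self, model_name: str = "tiny"):
            self.init_parameters = {"model_name": model_name}
            self.model_name = model_name
            self.model: Optional[Dict[str, str]] = None  # deliberately not loaded in __init__
            self.load_count = 0

        def warm_up(self):
            # Return early if the heavy initialization already happened.
            if self.model is not None:
                return
            self.load_count += 1
            self.model = {"cat": "NOUN", "runs": "VERB"}  # stand-in for loading a real model

        def run(self, word: str) -> Dict[str, Any]:
            # Load lazily in case warm_up() was never called.
            if self.model is None:
                self.warm_up()
            return {"tag": self.model.get(word, "UNK")}


    tagger = Tagger()
    tagger.warm_up()
    tagger.warm_up()  # no-op: the model is already loaded
    ```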

    <hr>

    `run(self, **kwargs)`

    Mandatory method.

    This is the method where the main functionality of the component should be carried out. It's called by
    `Pipeline.run()`.

    When the component should run, Pipeline will call this method with keyword arguments matching the parameters of
    its signature (or the input types declared with `component.set_input_types()`). These contain:

    - all the input values coming from other components connected to it,
    - for any missing input, the default value declared in the `run()` signature, if it exists.

    `run()` must return a dictionary whose keys are the output names declared with `@component.output_types()`
    (or `component.set_output_types()`).
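
    The sketch below reproduces, outside Canals, how `_component()` further down derives input sockets from the
    `run()` signature (name, annotation, whether the type is Optional, whether a default exists), and checks that
    `run()` returns a dictionary keyed by the declared output names. `Adder`, `OUTPUTS` and `describe_inputs` are
    hypothetical; `is_optional` is a local copy of `_is_optional`.

    ```python
    import inspect
    from typing import Any, Dict, Optional, Union, get_args, get_origin


    def is_optional(type_: Any) -> bool:
        # Same check as _is_optional(): a Union that includes NoneType.
        return get_origin(type_) is Union and type(None) in get_args(type_)


    def describe_inputs(class_: type) -> Dict[str, Dict[str, Any]]:
        # One socket description per run() parameter, skipping 'self'.
        signature = inspect.signature(class_.run)
        return {
            name: {
                "name": name,
                "type": param.annotation,
                "is_optional": is_optional(param.annotation),
                "has_default": param.default is not inspect.Parameter.empty,
            }
            for name, param in list(signature.parameters.items())[1:]
        }


    class Adder:  # hypothetical component; @component is omitted
        # Hypothetical stand-in for the output names @component.output_types() would record.
        OUTPUTS = {"total": int}

        def run(self, value: int, add: Optional[int] = 1) -> Dict[str, Any]:
            return {"total": value + (add or 0)}


    inputs = describe_inputs(Adder)
    result = Adder().run(value=2)
    ```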

"""

import logging
import inspect
from typing import Protocol, Union, Dict, Type, Any, get_origin, get_args
from functools import wraps

from canals.errors import ComponentError, ComponentDeserializationError


logger = logging.getLogger(__name__)


# We ignore too-few-public-methods Pylint error as this is only meant to be
# the definition of the Component interface.
class Component(Protocol):  # pylint: disable=too-few-public-methods
    """
    Abstract interface of a Component.
    This is only used by type checking tools.
    If you want to create a new Component use the @component decorator.
    """

    def run(self, **kwargs) -> Dict[str, Any]:
        """
        Takes the Component input and returns its output.
        Inputs are defined explicitly by the run method's signature or with `component.set_input_types()` if dynamic.
        Outputs are defined by decorating the run method with `@component.output_types()`
        or with `component.set_output_types()` if dynamic.
        """

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.
        """

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Component":
        """
        Deserializes the component from a dictionary.
        """


class _Component:
    """
    See module's docstring.

    Args:
        class_: the class that Canals should use as a component.
        serializable: whether to check, at init time, if the component can be saved with
        `save_pipelines()`.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_types(self, instance, **types):
        """
        Declares the input types of the instance's `run()` method.
        Use it when a component's inputs are only known at init time.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=str, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```
        """
        run_method = instance.run

        def wrapper(**kwargs):
            return run_method(**kwargs)

        # Store the input types in the run method
        wrapper.__canals_input__ = {
            name: {"name": name, "type": type_, "is_optional": _is_optional(type_), "has_default": False}
            for name, type_ in types.items()
        }
        wrapper.__canals_output__ = getattr(run_method, "__canals_output__", {})

        # Assigns the wrapped method to the instance's run()
        instance.run = wrapper

    def set_output_types(self, instance, **types):
        """
        Declares the output types of the instance's `run()` method.
        Use it when a component's outputs are only known at init time.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        if not types:
            return

        run_method = instance.run

        def wrapper(*args, **kwargs):
            return run_method(*args, **kwargs)

        # Store the output types in the run method
        wrapper.__canals_input__ = getattr(run_method, "__canals_input__", {})
        wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}

        # Assigns the wrapped method to the instance's run()
        instance.run = wrapper

    def output_types(self, **types):
        """
        Decorator factory that declares the output types of the run method.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method):
            """
            Decorator that stores the output types in the run method.
            """
            # Store the output types in the run method - used by the pipeline to build the sockets.

            @wraps(run_method)
            def wrapper(self, *args, **kwargs):
                return run_method(self, *args, **kwargs)

            wrapper.__canals_input__ = getattr(run_method, "__canals_input__", {})
            wrapper.__canals_output__ = {name: {"name": name, "type": type_} for name, type_ in types.items()}

            return wrapper

        return output_types_decorator

    def _component(self, class_):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering %s as a component", class_)

        # Check for run()
        if not hasattr(class_, "run"):
            raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
        run_signature = inspect.signature(class_.run)

        # Create the input sockets
        class_.run.__canals_input__ = {
            param: {
                "name": param,
                "type": run_signature.parameters[param].annotation,
                "is_optional": _is_optional(run_signature.parameters[param].annotation),
                "has_default": run_signature.parameters[param].default is not inspect.Parameter.empty,
            }
            for param in list(run_signature.parameters)[1:]  # First is 'self' and it doesn't matter.
        }

        # Save the component in the class registry (for deserialization)
        if class_.__name__ in self.registry:
            logger.error(
                "Component %s is already registered. Previously imported from '%s', new imported from '%s'",
                class_.__name__,
                self.registry[class_.__name__],
                class_,
            )
        self.registry[class_.__name__] = class_
        logger.debug("Registered Component %s", class_)

        setattr(class_, "__canals_component__", True)

        if not hasattr(class_, "to_dict"):
            class_.to_dict = _default_component_to_dict

        if not hasattr(class_, "from_dict"):
            class_.from_dict = classmethod(_default_component_from_dict)

        return class_

    def __call__(self, class_=None):
        """Allows us to use this decorator with and without parentheses."""
        if class_:
            return self._component(class_)

        return self._component


component = _Component()


def _is_optional(type_: type) -> bool:
    """
    Utility method that returns whether a type is Optional.
    """
    return get_origin(type_) is Union and type(None) in get_args(type_)


def _default_component_to_dict(comp: Component) -> Dict[str, Any]:
    """
    Default component serializer.
    Serializes a component to a dictionary.
    """
    return {
        "hash": id(comp),
        "type": comp.__class__.__name__,
        "init_parameters": getattr(comp, "init_parameters", {}),
    }


def _default_component_from_dict(cls: Type[Component], data: Dict[str, Any]) -> Component:
    """
    Default component deserializer.
    The "type" field in `data` must match the class that is being deserialized into.
    """
    init_params = data.get("init_parameters", {})
    if "type" not in data:
        raise ComponentDeserializationError("Missing 'type' in component serialization data")
    if data["type"] != cls.__name__:
        raise ComponentDeserializationError(f"Component '{data['type']}' can't be deserialized as '{cls.__name__}'")
    return cls(**init_params)
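
The default serializers at the end of this file can be exercised with a small round trip. The two functions below are local copies of `_default_component_to_dict()` and `_default_component_from_dict()`, with `ValueError` standing in for `ComponentDeserializationError` so the snippet runs without Canals installed; `Repeater` is a hypothetical component.

```python
from typing import Any, Dict


def to_dict(comp: Any) -> Dict[str, Any]:
    # Copy of _default_component_to_dict(): instance id, class name and init_parameters (if any).
    return {
        "hash": id(comp),
        "type": comp.__class__.__name__,
        "init_parameters": getattr(comp, "init_parameters", {}),
    }


def from_dict(cls: type, data: Dict[str, Any]) -> Any:
    # Copy of _default_component_from_dict(); ValueError replaces ComponentDeserializationError.
    init_params = data.get("init_parameters", {})
    if "type" not in data:
        raise ValueError("Missing 'type' in component serialization data")
    if data["type"] != cls.__name__:
        raise ValueError(f"Component '{data['type']}' can't be deserialized as '{cls.__name__}'")
    return cls(**init_params)


class Repeater:  # hypothetical component
    def __init__(self, times: int = 2):
        self.init_parameters = {"times": times}
        self.times = times

    def run(self, text: str) -> Dict[str, Any]:
        return {"text": text * self.times}


original = Repeater(times=3)
data = to_dict(original)
restored = from_dict(Repeater, data)
```

`_default_component_from_dict()` ignores the `"hash"` key: it only checks `"type"` and passes `"init_parameters"` to the constructor, so the restored object is a new instance with the same parameters.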