• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / canals / 5331939869

21 Jun 2023 08:39AM UTC coverage: 93.779% (-0.2%) from 93.998%
5331939869

Pull #23

github

web-flow
Merge ca1995464 into 4e413c7ba
Pull Request #23: Rework how component I/O is defined

171 of 176 branches covered (97.16%)

Branch coverage included in aggregate %.

643 of 692 relevant lines covered (92.92%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.83
canals/component/input_output.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4
import logging
1✔
5
from enum import Enum
1✔
6
from dataclasses import fields, is_dataclass, dataclass, asdict, MISSING
1✔
7

8
from canals.errors import ComponentError
1✔
9

10
logger = logging.getLogger(__name__)
1✔
11

12

13
def _make_fields_optional(class_: type):
1✔
14
    """
15
    Takes a dataclass definition and modifies its __init__ so that all fields have
16
    a default value set.
17
    If a field has a default factory use it to set the default value.
18
    If a field has neither a default factory or value default to None.
19
    """
20
    defaults = []
1✔
21
    for field in fields(class_):
1✔
22
        default = field.default
1✔
23
        if field.default is MISSING and field.default_factory is MISSING:
1✔
24
            default = None
1✔
25
        elif field.default is MISSING and field.default_factory is not MISSING:
1✔
26
            default = field.default_factory()
×
27
        defaults.append(default)
1✔
28
    # mypy complains we're accessing __init__ on an instance but it's not in reality.
29
    # class_ is a class definition and not an instance of it, so we're good.
30
    # Also only I/O dataclasses are meant to be passed to this function making it a bit safer.
31
    class_.__init__.__defaults__ = tuple(defaults)  # type: ignore
1✔
32

33

34
def _make_comparable(class_: type):
1✔
35
    """
36
    Overwrites the existing __eq__ method of class_ with a custom one.
37
    This is meant to be used only in I/O dataclasses, it takes into account
38
    whether the fields are marked as comparable or not.
39

40
    This is necessary since the automatically created __eq__ method in dataclasses
41
    also verifies the type of the class. That causes it to fail if the I/O dataclass
42
    is returned by a function.
43

44
    In here we don't compare the types of self and other but only their fields.
45
    """
46

47
    def comparator(self, other) -> bool:
1✔
48
        if not is_dataclass(other):
1✔
49
            return False
×
50

51
        fields_ = [f.name for f in fields(self) if f.compare]
1✔
52
        other_fields = [f.name for f in fields(other) if f.compare]
1✔
53
        if not len(fields_) == len(other_fields):
1✔
54
            return False
×
55

56
        self_dict, other_dict = asdict(self), asdict(other)
1✔
57
        for field in fields_:
1✔
58
            if not self_dict[field] == other_dict[field]:
1✔
59
                return False
×
60

61
        return True
1✔
62

63
    setattr(class_, "__eq__", comparator)
1✔
64

65

66
class Connection(Enum):
1✔
67
    INPUT = 1
1✔
68
    OUTPUT = 2
1✔
69
    INPUT_VARIADIC = 3
1✔
70

71

72
def _input(input_function=None, variadic: bool = False):
1✔
73
    """
74
    Decorator to mark a method that returns a dataclass defining a Component's input.
75

76
    The decorated function becomes a property.
77

78
    :param variadic: Set it to true to mark the dataclass returned by input_function as variadic,
79
        additional checks are done in this case, defaults to False
80
    """
81

82
    def decorator(function):
1✔
83
        def wrapper(self):
1✔
84
            class_ = function(self)
1✔
85
            # If the user didn't explicitly declare the returned class
86
            # as dataclass we do it out of convenience
87
            if not is_dataclass(class_):
1✔
88
                class_ = dataclass(class_)
1✔
89

90
            _make_comparable(class_)
1✔
91
            _make_fields_optional(class_)
1✔
92

93
            if variadic and len(fields(class_)) > 1:
1✔
94
                raise ComponentError(f"Variadic input dataclass {class_.__name__} must have only one field")
1✔
95

96
            if variadic:
1✔
97
                # Ugly hack to make variadic input work
98
                init = class_.__init__
1✔
99
                class_.__init__ = lambda self, *args: init(self, list(args))
1✔
100

101
            return class_
1✔
102

103
        # Magic field to ease some further checks, we set it in the wrapper
104
        # function so we access it like this <class>.<function>.fget.__canals_connection__
105
        wrapper.__canals_connection__ = Connection.INPUT_VARIADIC if variadic else Connection.INPUT
1✔
106

107
        # If we don't set the documentation explicitly the user wouldn't be able to access
108
        # since we make wrapper a property and not the original function.
109
        # This is not essential but a really nice to have.
110
        return property(fget=wrapper, doc=function.__doc__)
1✔
111

112
    # Check if we're called as @_input or @_input()
113
    if input_function:
1✔
114
        # Called with parens
115
        return decorator(input_function)
1✔
116

117
    # Called without parens
118
    return decorator
1✔
119

120

121
def _output(output_function=None):
1✔
122
    """
123
    Decorator to mark a method that returns a dataclass defining a Component's output.
124

125
    The decorated function becomes a property.
126
    """
127

128
    def decorator(function):
1✔
129
        def wrapper(self):
1✔
130
            class_ = function(self)
1✔
131
            if not is_dataclass(class_):
1✔
132
                class_ = dataclass(class_)
1✔
133
            _make_comparable(class_)
1✔
134
            return class_
1✔
135

136
        # Magic field to ease some further checks, we set it in the wrapper
137
        # function so we access it like this <class>.<function>.fget.__canals_connection__
138
        wrapper.__canals_connection__ = Connection.OUTPUT
1✔
139

140
        # If we don't set the documentation explicitly the user wouldn't be able to access
141
        # since we make wrapper a property and not the original function.
142
        # This is not essential but a really nice to have.
143
        return property(fget=wrapper, doc=function.__doc__)
1✔
144

145
    # Check if we're called as @_output or @_output()
146
    if output_function:
1✔
147
        # Called with parens
148
        return decorator(output_function)
1✔
149

150
    # Called without parens
151
    return decorator
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc