• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 10112264105

26 Jul 2024 01:40PM UTC coverage: 90.045% (-0.001%) from 90.046%
10112264105

Pull #8095

github

web-flow
Merge e16cefc3a into 47f4db869
Pull Request #8095: fix: Fix issue that could lead to RCE if using unsecure Jinja templates

6793 of 7544 relevant lines covered (90.05%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.53
haystack/components/routers/conditional_router.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import ast
1✔
6
import contextlib
1✔
7
from typing import Any, Callable, Dict, List, Optional, Set
1✔
8

9
from jinja2 import Environment, TemplateSyntaxError, meta
1✔
10
from jinja2.nativetypes import NativeEnvironment
1✔
11
from jinja2.sandbox import SandboxedEnvironment
1✔
12

13
from haystack import component, default_from_dict, default_to_dict, logging
1✔
14
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
1✔
15

16
# Module-level logger for this component; currently not referenced by the code in this file.
logger = logging.getLogger(__name__)
1✔
17

18

19
class NoRouteSelectedException(Exception):
    """Raised by `ConditionalRouter.run` when none of the configured route conditions evaluates to a truthy value."""
21

22

23
class RouteConditionException(Exception):
    """Raised by `ConditionalRouter.run` when a route's `condition` or `output` template cannot be rendered or evaluated."""
25

26

27
@component
class ConditionalRouter:
    """
    `ConditionalRouter` allows data routing based on specific conditions.

    This is achieved by defining a list named `routes`. Each element in this list is a dictionary representing a
    single route.
    A route dictionary comprises four key elements:
    - `condition`: A Jinja2 string expression that determines if the route is selected.
    - `output`: A Jinja2 expression defining the route's output value.
    - `output_type`: The type of the output data (e.g., `str`, `List[int]`).
    - `output_name`: The name under which the `output` value of the route is published. This name is used to connect
    the router to other components in the pipeline.

    Both `condition` and `output` are rendered with a Jinja2 `SandboxedEnvironment`, and the rendered text is turned
    back into a Python value with `ast.literal_eval`. A condition must therefore render to a Python literal such as
    `True` or `False`. An output is converted only if its rendering is a valid Python literal (numbers, lists, dicts,
    ...). Otherwise, for example for plain text or objects of custom types, the rendered string is returned as-is.

    Usage example:
    ```python
    from typing import List
    from haystack.components.routers import ConditionalRouter

    routes = [
        {
            "condition": "{{streams|length > 2}}",
            "output": "{{streams}}",
            "output_name": "enough_streams",
            "output_type": List[int],
        },
        {
            "condition": "{{streams|length <= 2}}",
            "output": "{{streams}}",
            "output_name": "insufficient_streams",
            "output_type": List[int],
        },
    ]
    router = ConditionalRouter(routes)
    # When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
    kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
    result = router.run(**kwargs)
    assert result == {"enough_streams": [1, 2, 3]}
    ```

    In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the
    stream count exceeds two. Conversely, the second route directs 'streams' to 'insufficient_streams' when there
    are two or fewer streams.

    In the pipeline setup, the router is connected to other components using the output names. For example, the
    'enough_streams' output might be connected to another component that processes the streams, while the
    'insufficient_streams' output might be connected to a component that fetches more streams, and so on.


    Here is a pseudocode example of a pipeline that uses the `ConditionalRouter` and routes fetched `ByteStreams` to
    different components depending on the number of streams fetched:
    ```python
    from typing import List
    from haystack import Pipeline
    from haystack.dataclasses import ByteStream
    from haystack.components.routers import ConditionalRouter

    routes = [
        {
            "condition": "{{streams|length > 2}}",
            "output": "{{streams}}",
            "output_name": "enough_streams",
            "output_type": List[ByteStream],
        },
        {
            "condition": "{{streams|length <= 2}}",
            "output": "{{streams}}",
            "output_name": "insufficient_streams",
            "output_type": List[ByteStream],
        },
    ]
    router = ConditionalRouter(routes)

    pipe = Pipeline()
    pipe.add_component("router", router)
    ...
    pipe.connect("router.enough_streams", "some_component_a.streams")
    pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
    ...
    ```
    """

    def __init__(self, routes: List[Dict], custom_filters: Optional[Dict[str, Callable]] = None):
        """
        Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.

        :param routes: A list of dictionaries, each defining a route.
            A route dictionary comprises four key elements:
            - `condition`: A Jinja2 string expression that determines if the route is selected.
            - `output`: A Jinja2 expression defining the route's output value.
            - `output_type`: The type of the output data (e.g., str, List[int]).
            - `output_name`: The name under which the `output` value of the route is published. This name is used to
                connect the router to other components in the pipeline.
        :param custom_filters: A dictionary of custom Jinja2 filters to be used in the `condition` and `output`
            expressions.
            For example, passing `{"my_filter": my_filter_fcn}` where:
            - `my_filter` is the name of the custom filter.
            - `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
              `{{ my_var|my_filter }}` can then be used inside a route condition expression like so:
                `"condition": "{{ my_var|my_filter == 'foo' }}"`.
        :raises ValueError: If a route is not a dictionary, lacks a mandatory field, or has a `condition` or `output`
            that is not a valid Jinja2 template.
        """
        self._validate_routes(routes)
        self.routes: List[dict] = routes
        self.custom_filters = custom_filters or {}

        # Sandboxed Jinja environment, used both to inspect template variables below and to render templates in
        # `run`. The sandbox restricts what templates may access (unsafe attributes/methods are blocked), which
        # mitigates code execution through crafted templates.
        self._env = SandboxedEnvironment()
        self._env.filters.update(self.custom_filters)

        # Inspect the routes to determine input and output types.
        input_types: Set[str] = set()  # let's just store the name, type will always be Any
        output_types: Dict[str, Any] = {}

        for route in routes:
            # Every undeclared variable referenced by a route's templates becomes an input of the component.
            route_input_names = self._extract_variables(self._env, [route["output"], route["condition"]])
            input_types.update(route_input_names)

            # Each route publishes its output under `output_name`; a later route with the same name overrides the type.
            output_types.update({route["output_name"]: route["output_type"]})

        component.set_input_types(self, **{var: Any for var in input_types})
        component.set_output_types(self, **output_types)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The router's own routes are left untouched: each route is copied and only the copy's `output_type` is
        serialized to a string, so `to_dict` can be called repeatedly and the router keeps working afterwards.

        :returns:
            Dictionary with serialized data.
        """
        # Shallow-copy each route instead of overwriting `output_type` in place. `self.routes` is the list the user
        # passed in, and mutating it would leave the router holding strings where types are expected.
        serialized_routes = [{**route, "output_type": serialize_type(route["output_type"])} for route in self.routes]
        se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
        return default_to_dict(self, routes=serialized_routes, custom_filters=se_filters)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
        """
        Deserializes the component from a dictionary.

        The input dictionary is not modified; deserialized routes and filters are written to new dictionaries.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data.get("init_parameters", {})
        # Build new route dicts rather than overwriting `output_type` in the caller's data. Otherwise a second
        # `from_dict` on the same dictionary would hand `deserialize_type` an already-deserialized type.
        routes = [
            {**route, "output_type": deserialize_type(route["output_type"])} for route in init_params.get("routes")
        ]
        new_init_params = {**init_params, "routes": routes}
        if "custom_filters" in init_params:
            new_init_params["custom_filters"] = {
                name: deserialize_callable(filter_func) if filter_func else None
                for name, filter_func in (init_params["custom_filters"] or {}).items()
            }
        return default_from_dict(cls, {**data, "init_parameters": new_init_params})

    def run(self, **kwargs):
        """
        Executes the routing logic.

        Executes the routing logic by evaluating the specified boolean condition expressions for each route in the
        order they are listed. The method directs the flow of data to the output specified in the first route whose
        `condition` is True.

        :param kwargs: All variables used in the `condition` expressed in the routes. When the component is used in a
            pipeline, these variables are passed from the previous component's output.

        :returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output`
            of the selected route.

        :raises NoRouteSelectedException: If no `condition` in the routes is `True`.
        :raises RouteConditionException: If there is an error parsing or evaluating the `condition` or `output`
            expression of a route.
        """
        for route in self.routes:
            try:
                # Render the condition in the sandboxed environment, then parse the rendered text (e.g. "True")
                # back into a Python value with `ast.literal_eval`, which only accepts literals.
                t = self._env.from_string(route["condition"])
                rendered = t.render(**kwargs)
                if ast.literal_eval(rendered):
                    # We now evaluate the `output` expression to determine the route output
                    t_output = self._env.from_string(route["output"])
                    output = t_output.render(**kwargs)
                    # Try to turn the rendered output back into a Python literal (list, dict, number, ...).
                    # This fails for plain text and for non-literal objects such as user types; in that case
                    # the rendered string itself is returned.
                    with contextlib.suppress(Exception):
                        output = ast.literal_eval(output)
                    # and return the output as a dictionary under the output_name key
                    return {route["output_name"]: output}
            except Exception as e:
                raise RouteConditionException(f"Error evaluating condition for route '{route}': {e}") from e

        raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")

    def _validate_routes(self, routes: List[Dict]):
        """
        Validates a list of routes.

        :param routes: A list of routes.
        :raises ValueError: If a route is not a dictionary, lacks one of the mandatory fields, or has a `condition`
            or `output` that Jinja cannot parse.
        """
        # Templates are only parsed here (see `_validate_template`), never rendered.
        env = NativeEnvironment()
        mandatory_fields = {"condition", "output", "output_type", "output_name"}
        for route in routes:
            try:
                keys = set(route.keys())
            except AttributeError:
                # Suppress the AttributeError context: the ValueError fully describes the problem.
                raise ValueError(f"Route must be a dictionary, got: {route}") from None

            if not mandatory_fields.issubset(keys):
                raise ValueError(
                    f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
                )
            for field in ["condition", "output"]:
                if not self._validate_template(env, route[field]):
                    raise ValueError(f"Invalid template for field '{field}': {route[field]}")

    def _extract_variables(self, env: SandboxedEnvironment, templates: List[str]) -> Set[str]:
        """
        Extracts all variables from a list of Jinja template strings.

        :param env: A Jinja environment.
        :param templates: A list of Jinja template strings.
        :returns: A set of variable names that the templates reference but do not define themselves.
        """
        variables: Set[str] = set()
        for template in templates:
            # Named `parsed_template` (not `ast`) to avoid shadowing the module-level `ast` import.
            parsed_template = env.parse(template)
            variables.update(meta.find_undeclared_variables(parsed_template))
        return variables

    def _validate_template(self, env: Environment, template_text: str) -> bool:
        """
        Validates a template string by parsing it with Jinja.

        Only syntax is checked; the template is not compiled or rendered.

        :param env: A Jinja environment.
        :param template_text: A Jinja template string.
        :returns: `True` if the template is valid, `False` otherwise.
        """
        try:
            env.parse(template_text)
            return True
        except TemplateSyntaxError:
            return False
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc