• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 19899324781

03 Dec 2025 03:30PM UTC coverage: 92.196% (-0.003%) from 92.199%
19899324781

Pull #10189

github

web-flow
Merge 0dade7e2d into bad2937ae
Pull Request #10189: fix: Improve error messages for non-string templates in ConditionalRouter

14082 of 15274 relevant lines covered (92.2%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.97
haystack/components/routers/conditional_router.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import ast
1✔
6
import contextlib
1✔
7
from typing import Any, Callable, Mapping, Optional, Sequence, TypedDict, Union, get_args, get_origin
1✔
8

9
from jinja2 import Environment, TemplateSyntaxError, meta
1✔
10
from jinja2.nativetypes import NativeEnvironment
1✔
11
from jinja2.sandbox import SandboxedEnvironment
1✔
12

13
from haystack import component, default_from_dict, default_to_dict, logging
1✔
14
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
1✔
15

16
# Module-level logger; ConditionalRouter uses it to warn about unsafe mode and unused optional variables.
logger = logging.getLogger(__name__)
1✔
17

18

19
class NoRouteSelectedException(Exception):
    """
    Signals that `ConditionalRouter.run` found no route whose condition evaluated to a truthy value.
    """
21

22

23
class RouteConditionException(Exception):
    """
    Signals that `ConditionalRouter.run` could not render or evaluate the templates of a route.
    """
25

26

27
# Shape of a single ConditionalRouter route:
# - condition:   Jinja2 template whose rendered value decides whether the route is selected.
# - output:      Jinja2 template (or list of templates) producing the published value(s).
# - output_name: name (or list of names) under which the value(s) are published.
# - output_type: declared type (or list of types) of the published value(s).
Route = TypedDict(
    "Route",
    {
        "condition": str,
        "output": Union[str, list[str]],
        "output_name": Union[str, list[str]],
        "output_type": Union[type, list[type]],
    },
)
1✔
32

33

34
@component
class ConditionalRouter:
    """
    Routes data based on specific conditions.

    You define these conditions in a list of dictionaries called `routes`.
    Each dictionary in this list represents a single route. Each route has these four elements:
    - `condition`: A Jinja2 string expression that determines if the route is selected.
    - `output`: A Jinja2 expression defining the route's output value.
    - `output_type`: The type of the output data (for example, `str`, `list[int]`).
    - `output_name`: The name you want to use to publish `output`. This name is used to connect
    the router to other components in the pipeline.

    `output`, `output_type` and `output_name` can also be lists of the same length, in which case the selected
    route publishes several outputs at once.

    ### Usage example

    ```python
    from haystack.components.routers import ConditionalRouter

    routes = [
        {
            "condition": "{{streams|length > 2}}",
            "output": "{{streams}}",
            "output_name": "enough_streams",
            "output_type": list[int],
        },
        {
            "condition": "{{streams|length <= 2}}",
            "output": "{{streams}}",
            "output_name": "insufficient_streams",
            "output_type": list[int],
        },
    ]
    router = ConditionalRouter(routes)
    # When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
    kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
    result = router.run(**kwargs)
    assert result == {"enough_streams": [1, 2, 3]}
    ```

    In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the
    stream count exceeds two. The second route directs 'streams' to 'insufficient_streams' if there
    are two or fewer streams.

    In the pipeline setup, the Router connects to other components using the output names. For example,
    'enough_streams' might connect to a component that processes streams, while
    'insufficient_streams' might connect to a component that fetches more streams.

    Here is a pipeline that uses `ConditionalRouter` and routes the fetched `ByteStreams` to
    different components depending on the number of streams fetched:

    ```python
    from haystack import Pipeline
    from haystack.dataclasses import ByteStream
    from haystack.components.routers import ConditionalRouter

    routes = [
        {
            "condition": "{{streams|length > 2}}",
            "output": "{{streams}}",
            "output_name": "enough_streams",
            "output_type": list[ByteStream],
        },
        {
            "condition": "{{streams|length <= 2}}",
            "output": "{{streams}}",
            "output_name": "insufficient_streams",
            "output_type": list[ByteStream],
        },
    ]

    router = ConditionalRouter(routes)
    pipe = Pipeline()
    pipe.add_component("router", router)
    ...
    pipe.connect("router.enough_streams", "some_component_a.streams")
    pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
    ...
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        routes: list[Route],
        custom_filters: Optional[dict[str, Callable]] = None,
        unsafe: bool = False,
        validate_output_type: bool = False,
        optional_variables: Optional[list[str]] = None,
    ):
        """
        Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.

        :param routes: A list of dictionaries, each defining a route.
            Each route has these four elements:
            - `condition`: A Jinja2 string expression that determines if the route is selected.
            - `output`: A Jinja2 expression defining the route's output value.
            - `output_type`: The type of the output data (for example, `str`, `list[int]`).
            - `output_name`: The name you want to use to publish `output`. This name is used to connect
            the router to other components in the pipeline.
        :param custom_filters: A dictionary of custom Jinja2 filters used in the condition expressions.
            For example, passing `{"my_filter": my_filter_fcn}` where:
            - `my_filter` is the name of the custom filter.
            - `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
              `{{ my_var|my_filter }}` can then be used inside a route condition expression:
                `"condition": "{{ my_var|my_filter == 'foo' }}"`.
        :param unsafe:
            Enable execution of arbitrary code in the Jinja template.
            This should only be used if you trust the source of the template as it can lead to remote code execution.
        :param validate_output_type:
            Enable validation of routes' output.
            If a route output doesn't match the declared type, a ValueError is raised when the component runs.
        :param optional_variables:
            A list of variable names that are optional in your route conditions and outputs.
            If these variables are not provided at runtime, they will be set to `None`.
            This allows you to write routes that can handle missing inputs gracefully without raising errors.

            Example usage with a default fallback route in a Pipeline:
            ```python
            from haystack import Pipeline
            from haystack.components.routers import ConditionalRouter

            routes = [
                {
                    "condition": '{{ path == "rag" }}',
                    "output": "{{ question }}",
                    "output_name": "rag_route",
                    "output_type": str
                },
                {
                    "condition": "{{ True }}",  # fallback route
                    "output": "{{ question }}",
                    "output_name": "default_route",
                    "output_type": str
                }
            ]

            router = ConditionalRouter(routes, optional_variables=["path"])
            pipe = Pipeline()
            pipe.add_component("router", router)

            # When 'path' is provided in the pipeline:
            result = pipe.run(data={"router": {"question": "What?", "path": "rag"}})
            assert result["router"] == {"rag_route": "What?"}

            # When 'path' is not provided, fallback route is taken:
            result = pipe.run(data={"router": {"question": "What?"}})
            assert result["router"] == {"default_route": "What?"}
            ```

            This pattern is particularly useful when:
            - You want to provide default/fallback behavior when certain inputs are missing
            - Some variables are only needed for specific routing conditions
            - You're building flexible pipelines where not all inputs are guaranteed to be present
        :raises ValueError:
            If a route is not a dictionary, lacks one of the mandatory fields, has `output`, `output_type` and
            `output_name` lists of different lengths, or contains a template that is not a valid Jinja2 string.
        """
        self.routes: list[Route] = routes
        self.custom_filters = custom_filters or {}
        self._unsafe = unsafe
        self._validate_output_type = validate_output_type
        self.optional_variables = optional_variables or []

        # Create a Jinja environment to inspect variables in the condition templates
        if self._unsafe:
            msg = (
                "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
                "Use this only if you trust the source of the template."
            )
            logger.warning(msg)

        # NativeEnvironment renders to Python objects; SandboxedEnvironment renders to (sandboxed) text.
        self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment()
        self._env.filters.update(self.custom_filters)

        self._validate_routes(routes)
        # Inspect the routes to determine input and output types.
        input_types: set[str] = set()  # let's just store the name, type will always be Any
        output_types: dict[str, Union[type, list[type]]] = {}

        for route in routes:
            # extract inputs
            route_input_names = self._extract_variables(
                self._env,
                [route["condition"]] + (route["output"] if isinstance(route["output"], list) else [route["output"]]),
            )
            input_types.update(route_input_names)

            # extract outputs
            output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]
            output_types_list = (
                route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
            )

            output_types.update(dict(zip(output_names, output_types_list)))

        # remove optional variables from mandatory input types
        mandatory_input_types = input_types - set(self.optional_variables)

        # warn about unused optional variables
        unused_optional_vars = set(self.optional_variables) - input_types if self.optional_variables else None
        if unused_optional_vars:
            logger.warning(
                "The following optional variables are specified but not used in any route: {unused_optional_vars}. "
                "Check if there's a typo in variable names.",
                unused_optional_vars=unused_optional_vars,
            )

        # add mandatory input types
        component.set_input_types(self, **dict.fromkeys(mandatory_input_types, Any))

        # now add optional input types
        for optional_var_name in self.optional_variables:
            component.set_input_type(self, name=optional_var_name, type=Any, default=None)

        # set output types
        component.set_output_types(self, **output_types)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialized_routes = []
        for route in self.routes:
            serialized_output_type = (
                [serialize_type(t) for t in route["output_type"]]
                if isinstance(route["output_type"], list)
                else serialize_type(route["output_type"])
            )
            serialized_routes.append({**route, "output_type": serialized_output_type})
        se_filters = {name: serialize_callable(filter_func) for name, filter_func in self.custom_filters.items()}
        return default_to_dict(
            self,
            routes=serialized_routes,
            custom_filters=se_filters,
            unsafe=self._unsafe,
            validate_output_type=self._validate_output_type,
            optional_variables=self.optional_variables,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConditionalRouter":
        """
        Deserializes the component from a dictionary.

        The input dictionary is left untouched: deserialized routes and filters are written into copies, so the same
        serialized data can be deserialized more than once.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = dict(data.get("init_parameters", {}))

        routes = init_params.get("routes")
        if routes is not None:
            deserialized_routes = []
            for route in routes:
                # output_type needs to be deserialized from a string to a type
                output_type = route["output_type"]
                if isinstance(output_type, list):
                    output_type = [deserialize_type(t) for t in output_type]
                else:
                    output_type = deserialize_type(output_type)
                deserialized_routes.append({**route, "output_type": output_type})
            init_params["routes"] = deserialized_routes

        # Since the custom_filters are typed as optional in the init signature, they may be missing or None in the
        # serialized data; in both cases they are passed through unchanged.
        custom_filters = init_params.get("custom_filters")
        if custom_filters is not None:
            init_params["custom_filters"] = {
                name: deserialize_callable(filter_func) if filter_func else None
                for name, filter_func in custom_filters.items()
            }

        return default_from_dict(cls, {**data, "init_parameters": init_params})

    def run(self, **kwargs):
        """
        Executes the routing logic.

        Executes the routing logic by evaluating the specified boolean condition expressions for each route in the
        order they are listed. The method directs the flow of data to the output specified in the first route whose
        `condition` is True.

        :param kwargs: All variables used in the `condition` expressed in the routes. When the component is used in a
            pipeline, these variables are passed from the previous component's output.

        :returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output`
            of the selected route.

        :raises NoRouteSelectedException:
            If no `condition` in the routes is `True`.
        :raises RouteConditionException:
            If there is an error parsing or evaluating the `condition` expression in the routes, or rendering the
            `output` of the selected route.
        :raises ValueError:
            If type validation is enabled and route type doesn't match actual value type.
        """
        for route in self.routes:
            if not self._route_condition_holds(route, kwargs):
                continue

            # Handle multiple outputs
            outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
            output_types = route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
            output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]

            result = {}
            for output, output_type, output_name in zip(outputs, output_types, output_names):
                output_value = self._render_output(route, output, kwargs)

                # Type validation runs outside the error wrapping above so a mismatch surfaces as ValueError.
                if self._validate_output_type and not self._output_matches_type(output_value, output_type):
                    raise ValueError(f"Route '{output_name}' type doesn't match expected type")

                result[output_name] = output_value

            return result

        raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")

    def _route_condition_holds(self, route: Route, kwargs: dict[str, Any]) -> bool:
        """
        Renders the `condition` template of `route` with `kwargs` and returns its truthiness.

        :param route: The route whose condition is evaluated.
        :param kwargs: The variables available to the template.
        :returns: `True` if the rendered condition is truthy, `False` otherwise.
        :raises RouteConditionException: If the template cannot be rendered or, in safe mode, the rendered text is not
            a Python literal.
        """
        try:
            rendered = self._env.from_string(route["condition"]).render(**kwargs)
            if not self._unsafe:
                # The sandboxed environment renders to text; turn it back into a Python value (e.g. "True" -> True).
                rendered = ast.literal_eval(rendered)
            return bool(rendered)
        except Exception as e:
            msg = f"Error evaluating condition for route '{route}': {e}"
            raise RouteConditionException(msg) from e

    def _render_output(self, route: Route, output: str, kwargs: dict[str, Any]) -> Any:
        """
        Renders one `output` template of the selected `route` with `kwargs`.

        :param route: The selected route, used in the error message.
        :param output: The output template to render.
        :param kwargs: The variables available to the template.
        :returns: The rendered value. In safe mode, text that is a Python literal is converted to that literal;
            any other text is returned unchanged.
        :raises RouteConditionException: If the template cannot be rendered.
        """
        try:
            output_value = self._env.from_string(output).render(**kwargs)
        except Exception as e:
            msg = f"Error rendering output for route '{route}': {e}"
            raise RouteConditionException(msg) from e

        if not self._unsafe:
            # The output may be a plain string (not a valid literal) or a literal structure (list, dict, number...).
            # Failing to parse is expected for plain strings, so the rendered text is kept as-is in that case.
            # This doesn't support any user types.
            with contextlib.suppress(Exception):
                output_value = ast.literal_eval(output_value)
        return output_value

    def _validate_routes(self, routes: list[Route]):
        """
        Validates a list of routes.

        :param routes: A list of routes.
        :raises ValueError: If a route is not a dictionary, misses a mandatory field, has output lists of different
            lengths, or contains a condition/output template that is not a valid Jinja2 string.
        """
        for route in routes:
            try:
                keys = set(route.keys())
            except AttributeError:
                raise ValueError(f"Route must be a dictionary, got: {route}")

            mandatory_fields = {"condition", "output", "output_type", "output_name"}
            if not mandatory_fields.issubset(keys):
                raise ValueError(
                    f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
                )

            # Validate outputs are consistent
            outputs = route["output"] if isinstance(route["output"], list) else [route["output"]]
            output_types = route["output_type"] if isinstance(route["output_type"], list) else [route["output_type"]]
            output_names = route["output_name"] if isinstance(route["output_name"], list) else [route["output_name"]]

            # Check lengths match
            if not len(outputs) == len(output_types) == len(output_names):
                raise ValueError(f"Route output, output_type and output_name must have same length: {route}")

            # Validate templates
            condition_value = route["condition"]
            if not self._validate_template(self._env, condition_value):
                if not isinstance(condition_value, str):
                    raise ValueError(
                        f"Invalid template for condition: {condition_value!r} (type: {type(condition_value).__name__}). "
                        "Condition must be a string representing a valid Jinja2 template. "
                        f"For example, use {str(condition_value)!r} instead of {condition_value!r}."
                    )
                raise ValueError(f"Invalid template for condition: {condition_value}")

            for output in outputs:
                if not self._validate_template(self._env, output):
                    if not isinstance(output, str):
                        raise ValueError(
                            f"Invalid template for output: {output!r} (type: {type(output).__name__}). "
                            "Output must be a string representing a valid Jinja2 template. "
                            f"For example, use {str(output)!r} instead of {output!r}."
                        )
                    raise ValueError(f"Invalid template for output: {output}")

    def _extract_variables(self, env: Environment, templates: list[str]) -> set[str]:
        """
        Extracts all variables from a list of Jinja template strings.

        :param env: A Jinja environment.
        :param templates: A list of Jinja template strings.
        :returns: A set of variable names.
        """
        variables = set()
        for template in templates:
            variables.update(meta.find_undeclared_variables(env.parse(template)))
        return variables

    def _validate_template(self, env: Environment, template_text: str):
        """
        Validates a template string by parsing it with Jinja.

        :param env: A Jinja environment.
        :param template_text: A Jinja template string.
        :returns: `True` if the template is a string and parses without syntax errors, `False` otherwise.
        """
        # Check if template_text is a string before attempting to parse
        if not isinstance(template_text, str):
            return False
        try:
            env.parse(template_text)
            return True
        except TemplateSyntaxError:
            return False

    def _output_matches_type(self, value: Any, expected_type: type):  # noqa: PLR0911 # pylint: disable=too-many-return-statements
        """
        Checks whether `value` type matches the `expected_type`.

        Supports `Any`, plain classes, parameterized sequences (`list[int]`, `tuple[int, ...]`, `tuple[int, str]`),
        parameterized mappings (`dict[str, int]`) and `typing.Union` / `Optional`. Container elements are checked
        recursively; unparameterized generics (e.g. `typing.List`) only check the container. Any other parameterized
        type is reported as not matching.

        :param value: The value produced by a route's output template.
        :param expected_type: The route's declared output type.
        :returns: `True` if `value` matches `expected_type`, `False` otherwise.
        """
        # Handle Any type
        if expected_type is Any:
            return True

        # Get the origin type (List, Dict, etc) and type arguments
        origin = get_origin(expected_type)
        args = get_args(expected_type)

        # Handle basic types (int, str, etc)
        if origin is None:
            return isinstance(value, expected_type)

        # Handle Union types (including Optional)
        if origin is Union:
            return any(self._output_matches_type(value, arg) for arg in args)

        # Handle Sequence types (List, Tuple, etc)
        if isinstance(origin, type) and issubclass(origin, Sequence):
            if not isinstance(value, Sequence):
                return False
            # A string is a sequence of characters, but it must not satisfy e.g. `list[str]`.
            if isinstance(value, (str, bytes, bytearray)) and not isinstance(value, origin):
                return False
            # Empty sequences and unparameterized generics are valid
            if not value or not args:
                return True
            # Fixed-length tuples (tuple[int, str]) are checked position by position
            if origin is tuple and not (len(args) == 2 and args[1] is Ellipsis):
                return len(value) == len(args) and all(
                    self._output_matches_type(item, item_type) for item, item_type in zip(value, args)
                )
            # Check each element against the sequence's type parameter
            return all(self._output_matches_type(item, args[0]) for item in value)

        # Handle Mapping types (Dict, etc)
        if isinstance(origin, type) and issubclass(origin, Mapping):
            if not isinstance(value, Mapping):
                return False
            # Empty mappings and unparameterized generics are valid
            if not value or not args:
                return True
            key_type = args[0]
            # Single-parameter mappings (e.g. Counter[str]) don't declare a value type
            value_type = args[1] if len(args) > 1 else Any
            # Check all keys and values match their respective types
            return all(
                self._output_matches_type(k, key_type) and self._output_matches_type(v, value_type)
                for k, v in value.items()
            )

        return False
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc