• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14645042953

24 Apr 2025 03:06PM UTC coverage: 90.447% (-0.04%) from 90.482%
14645042953

Pull #9303

github

web-flow
Merge fdc9cc510 into f97472329
Pull Request #9303: fix: make `HuggingFaceAPIChatGenerator` convert Tool Call `arguments` from string

10860 of 12007 relevant lines covered (90.45%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.5
haystack/components/routers/conditional_router.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import ast
1✔
6
import contextlib
1✔
7
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Union, get_args, get_origin
1✔
8

9
from jinja2 import Environment, TemplateSyntaxError, meta
1✔
10
from jinja2.nativetypes import NativeEnvironment
1✔
11
from jinja2.sandbox import SandboxedEnvironment
1✔
12

13
from haystack import component, default_from_dict, default_to_dict, logging
1✔
14
from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type
1✔
15

16
logger = logging.getLogger(__name__)
1✔
17

18

19
class NoRouteSelectedException(Exception):
    """Signals that ConditionalRouter evaluated its routes and none of them was selected."""
21

22

23
class RouteConditionException(Exception):
    """Signals a failure while parsing or evaluating a route's condition expression in ConditionalRouter."""
25

26

27
@component
1✔
28
class ConditionalRouter:
1✔
29
    """
30
    Routes data based on specific conditions.
31

32
    You define these conditions in a list of dictionaries called `routes`.
33
    Each dictionary in this list represents a single route. Each route has these four elements:
34
    - `condition`: A Jinja2 string expression that determines if the route is selected.
35
    - `output`: A Jinja2 expression defining the route's output value.
36
    - `output_type`: The type of the output data (for example, `str`, `List[int]`).
37
    - `output_name`: The name you want to use to publish `output`. This name is used to connect
38
    the router to other components in the pipeline.
39

40
    ### Usage example
41

42
    ```python
43
    from typing import List
44
    from haystack.components.routers import ConditionalRouter
45

46
    routes = [
47
        {
48
            "condition": "{{streams|length > 2}}",
49
            "output": "{{streams}}",
50
            "output_name": "enough_streams",
51
            "output_type": List[int],
52
        },
53
        {
54
            "condition": "{{streams|length <= 2}}",
55
            "output": "{{streams}}",
56
            "output_name": "insufficient_streams",
57
            "output_type": List[int],
58
        },
59
    ]
60
    router = ConditionalRouter(routes)
61
    # When 'streams' has more than 2 items, 'enough_streams' output will activate, emitting the list [1, 2, 3]
62
    kwargs = {"streams": [1, 2, 3], "query": "Haystack"}
63
    result = router.run(**kwargs)
64
    assert result == {"enough_streams": [1, 2, 3]}
65
    ```
66

67
    In this example, we configure two routes. The first route sends the 'streams' value to 'enough_streams' if the
68
    stream count exceeds two. The second route directs 'streams' to 'insufficient_streams' if there
69
    are two or fewer streams.
70

71
    In the pipeline setup, the Router connects to other components using the output names. For example,
72
    'enough_streams' might connect to a component that processes streams, while
73
    'insufficient_streams' might connect to a component that fetches more streams.
74

75

76
    Here is a pipeline that uses `ConditionalRouter` and routes the fetched `ByteStreams` to
77
    different components depending on the number of streams fetched:
78

79
    ```python
80
    from typing import List
81
    from haystack import Pipeline
82
    from haystack.dataclasses import ByteStream
83
    from haystack.components.routers import ConditionalRouter
84

85
    routes = [
86
        {
87
            "condition": "{{streams|length > 2}}",
88
            "output": "{{streams}}",
89
            "output_name": "enough_streams",
90
            "output_type": List[ByteStream],
91
        },
92
        {
93
            "condition": "{{streams|length <= 2}}",
94
            "output": "{{streams}}",
95
            "output_name": "insufficient_streams",
96
            "output_type": List[ByteStream],
97
        },
98
    ]
99

100
    router = ConditionalRouter(routes)
    pipe = Pipeline()
101
    pipe.add_component("router", router)
102
    ...
103
    pipe.connect("router.enough_streams", "some_component_a.streams")
104
    pipe.connect("router.insufficient_streams", "some_component_b.streams_or_some_other_input")
105
    ...
106
    ```
107
    """
108

109
    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        routes: List[Dict],
        custom_filters: Optional[Dict[str, Callable]] = None,
        unsafe: bool = False,
        validate_output_type: bool = False,
        optional_variables: Optional[List[str]] = None,
    ):
        """
        Initializes the `ConditionalRouter` with a list of routes detailing the conditions for routing.

        :param routes: A list of dictionaries, each defining a route.
            Each route has these four elements:
            - `condition`: A Jinja2 string expression that determines if the route is selected.
            - `output`: A Jinja2 expression defining the route's output value.
            - `output_type`: The type of the output data (for example, `str`, `List[int]`).
            - `output_name`: The name you want to use to publish `output`. This name is used to connect
            the router to other components in the pipeline.
        :param custom_filters: A dictionary of custom Jinja2 filters used in the condition expressions.
            For example, passing `{"my_filter": my_filter_fcn}` where:
            - `my_filter` is the name of the custom filter.
            - `my_filter_fcn` is a callable that takes `my_var:str` and returns `my_var[:3]`.
              `{{ my_var|my_filter }}` can then be used inside a route condition expression:
                `"condition": "{{ my_var|my_filter == 'foo' }}"`.
        :param unsafe:
            Enable execution of arbitrary code in the Jinja template.
            This should only be used if you trust the source of the template as it can lead to remote code execution.
        :param validate_output_type:
            Enable validation of routes' output.
            If a route output doesn't match the declared type, a ValueError is raised when running.
        :param optional_variables:
            A list of variable names that are optional in your route conditions and outputs.
            If these variables are not provided at runtime, they will be set to `None`.
            This allows you to write routes that can handle missing inputs gracefully without raising errors.

            Example usage with a default fallback route in a Pipeline:
            ```python
            from haystack import Pipeline
            from haystack.components.routers import ConditionalRouter

            routes = [
                {
                    "condition": '{{ path == "rag" }}',
                    "output": "{{ question }}",
                    "output_name": "rag_route",
                    "output_type": str
                },
                {
                    "condition": "{{ True }}",  # fallback route
                    "output": "{{ question }}",
                    "output_name": "default_route",
                    "output_type": str
                }
            ]

            router = ConditionalRouter(routes, optional_variables=["path"])
            pipe = Pipeline()
            pipe.add_component("router", router)

            # When 'path' is provided in the pipeline:
            result = pipe.run(data={"router": {"question": "What?", "path": "rag"}})
            assert result["router"] == {"rag_route": "What?"}

            # When 'path' is not provided, fallback route is taken:
            result = pipe.run(data={"router": {"question": "What?"}})
            assert result["router"] == {"default_route": "What?"}
            ```

            This pattern is particularly useful when:
            - You want to provide default/fallback behavior when certain inputs are missing
            - Some variables are only needed for specific routing conditions
            - You're building flexible pipelines where not all inputs are guaranteed to be present
        """

        def as_list(field):
            # Route fields may hold a single item or a list of items; normalise to a list.
            return field if isinstance(field, list) else [field]

        self.routes: List[dict] = routes
        self.custom_filters = custom_filters or {}
        self._unsafe = unsafe
        self._validate_output_type = validate_output_type
        self.optional_variables = optional_variables or []

        # Pick the Jinja environment: native (arbitrary code, with a warning) or sandboxed.
        if self._unsafe:
            logger.warning(
                "Unsafe mode is enabled. This allows execution of arbitrary code in the Jinja template. "
                "Use this only if you trust the source of the template."
            )
            self._env = NativeEnvironment()
        else:
            self._env = SandboxedEnvironment()
        self._env.filters.update(self.custom_filters)

        self._validate_routes(routes)

        # Walk the routes once to collect the template variables (inputs) and the declared outputs.
        input_names: Set[str] = set()  # only names are tracked, every input is typed as Any
        declared_outputs: Dict[str, str] = {}
        for route in routes:
            templates = [route["condition"], *as_list(route["output"])]
            input_names.update(self._extract_variables(self._env, templates))

            for out_name, out_type in zip(as_list(route["output_name"]), as_list(route["output_type"])):
                declared_outputs[out_name] = out_type

        optional_names = set(self.optional_variables)

        # Optional variables that no template references are most likely typos: warn about them.
        unused_optional_vars = optional_names - input_names
        if unused_optional_vars:
            logger.warning(
                "The following optional variables are specified but not used in any route: {unused_optional_vars}. "
                "Check if there's a typo in variable names.",
                unused_optional_vars=unused_optional_vars,
            )

        # Mandatory inputs are all template variables that were not declared optional.
        required_names = input_names - optional_names
        component.set_input_types(self, **{name: Any for name in required_names})

        # Optional inputs are registered one by one with a `None` default.
        for optional_var_name in self.optional_variables:
            component.set_input_type(self, name=optional_var_name, type=Any, default=None)

        component.set_output_types(self, **declared_outputs)
1✔
241

242
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Copy every route, swapping its `output_type` for the serialized representation.
        routes_payload = [{**route, "output_type": serialize_type(route["output_type"])} for route in self.routes]

        # Custom filters are callables, so each one is passed through `serialize_callable`.
        filters_payload = {}
        for filter_name, filter_fn in self.custom_filters.items():
            filters_payload[filter_name] = serialize_callable(filter_fn)

        return default_to_dict(
            self,
            routes=routes_payload,
            custom_filters=filters_payload,
            unsafe=self._unsafe,
            validate_output_type=self._validate_output_type,
            optional_variables=self.optional_variables,
        )
262

263
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConditionalRouter":
        """
        Deserializes the component from a dictionary.

        Note that `data` is updated in place: the routes' `output_type` entries and the custom
        filters are replaced by their deserialized counterparts before the component is built.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data.get("init_parameters", {})

        # Each route stores its `output_type` in serialized form; turn it back into a type.
        for route in init_params.get("routes"):
            route["output_type"] = deserialize_type(route["output_type"])

        # `custom_filters` is optional: a missing key yields an empty dict (nothing to do),
        # an explicit `None` is left untouched.
        filters = init_params.get("custom_filters", {})
        if filters is not None:
            for filter_name, serialized_filter in filters.items():
                filters[filter_name] = deserialize_callable(serialized_filter) if serialized_filter else None

        return default_from_dict(cls, data)
1✔
286

287
    def run(self, **kwargs):
        """
        Executes the routing logic.

        The `condition` of each route is rendered and evaluated in the order the routes are listed.
        The outputs of the first route whose condition is truthy are rendered and returned; the
        remaining routes are not evaluated.

        :param kwargs: All variables used in the `condition` expressed in the routes. When the component is used in a
            pipeline, these variables are passed from the previous component's output.

        :returns: A dictionary where the key is the `output_name` of the selected route and the value is the `output`
            of the selected route.

        :raises NoRouteSelectedException:
            If no `condition` in the routes is `True`.
        :raises RouteConditionException:
            If there is an error parsing or evaluating the `condition` expression in the routes.
        :raises ValueError:
            If type validation is enabled and route type doesn't match actual value type.
        """

        def as_list(field):
            # Route fields may hold a single item or a list of items; normalise to a list.
            return field if isinstance(field, list) else [field]

        # In safe mode templates render to strings, which are then parsed as Python literals.
        safe_mode = not self._unsafe

        for route in self.routes:
            try:
                condition_value = self._env.from_string(route["condition"]).render(**kwargs)
                if safe_mode:
                    condition_value = ast.literal_eval(condition_value)
                if not condition_value:
                    continue

                # A route may declare one output or several parallel lists of outputs.
                result = {}
                for output, output_type, output_name in zip(
                    as_list(route["output"]), as_list(route["output_type"]), as_list(route["output_name"])
                ):
                    output_value = self._env.from_string(output).render(**kwargs)

                    if safe_mode:
                        # The rendered output may be any literal structure or just plain text.
                        # If it cannot be parsed as a literal, the rendered string is kept as is.
                        # User-defined types are not supported here.
                        with contextlib.suppress(Exception):
                            output_value = ast.literal_eval(output_value)

                    if self._validate_output_type and not self._output_matches_type(output_value, output_type):
                        raise ValueError(f"Route '{output_name}' type doesn't match expected type")

                    result[output_name] = output_value

                return result

            except ValueError:
                # Type-validation failures must reach the caller as ValueError.
                # NOTE(review): this lets every ValueError through unchanged, including one raised by
                # `ast.literal_eval` on a condition that does not render to a literal.
                raise
            except Exception as e:
                msg = f"Error evaluating condition for route '{route}': {e}"
                raise RouteConditionException(msg) from e

        raise NoRouteSelectedException(f"No route fired. Routes: {self.routes}")
1✔
357

358
    def _validate_routes(self, routes: List[Dict]):
        """
        Validates a list of routes.

        A route is valid when it exposes `keys()`, contains all mandatory fields, declares the same
        number of outputs, output types and output names, and all its templates can be parsed.

        :param routes: A list of routes.
        :raises ValueError: If any route fails one of the checks above.
        """
        required_fields = {"condition", "output", "output_type", "output_name"}

        def as_list(field):
            # Route fields may hold a single item or a list of items; normalise to a list.
            return field if isinstance(field, list) else [field]

        for route in routes:
            try:
                present_fields = set(route.keys())
            except AttributeError:
                raise ValueError(f"Route must be a dictionary, got: {route}")

            if not required_fields <= present_fields:
                raise ValueError(
                    f"Route must contain 'condition', 'output', 'output_type' and 'output_name' fields: {route}"
                )

            # The three output-related fields are parallel lists and must line up one to one.
            outputs = as_list(route["output"])
            lengths = {len(outputs), len(as_list(route["output_type"])), len(as_list(route["output_name"]))}
            if len(lengths) != 1:
                raise ValueError(f"Route output, output_type and output_name must have same length: {route}")

            # Every template has to be parseable by the configured Jinja environment.
            if not self._validate_template(self._env, route["condition"]):
                raise ValueError(f"Invalid template for condition: {route['condition']}")

            for output in outputs:
                if not self._validate_template(self._env, output):
                    raise ValueError(f"Invalid template for output: {output}")
×
393

394
    def _extract_variables(self, env: Environment, templates: List[str]) -> Set[str]:
        """
        Collects the undeclared variables referenced by a list of Jinja template strings.

        :param env: A Jinja environment.
        :param templates: A list of Jinja template strings.
        :returns: A set of variable names.
        """
        found: Set[str] = set()
        for template_text in templates:
            # Parse each template and merge the variables it references but does not declare.
            parsed_template = env.parse(template_text)
            found.update(meta.find_undeclared_variables(parsed_template))
        return found
1✔
407

408
    def _validate_template(self, env: Environment, template_text: str):
        """
        Checks that a template string can be parsed by Jinja.

        :param env: A Jinja environment.
        :param template_text: A Jinja template string.
        :returns: `True` if the template is valid, `False` otherwise.
        """
        try:
            env.parse(template_text)
        except TemplateSyntaxError:
            # Only syntax errors mean "invalid"; any other exception propagates to the caller.
            return False
        return True
1✔
421

422
    def _output_matches_type(self, value: Any, expected_type: type):  # noqa: PLR0911 # pylint: disable=too-many-return-statements
1✔
423
        """
424
        Checks whether `value` type matches the `expected_type`.
425
        """
426
        # Handle Any type
427
        if expected_type is Any:
1✔
428
            return True
×
429

430
        # Get the origin type (List, Dict, etc) and type arguments
431
        origin = get_origin(expected_type)
1✔
432
        args = get_args(expected_type)
1✔
433

434
        # Handle basic types (int, str, etc)
435
        if origin is None:
1✔
436
            return isinstance(value, expected_type)
1✔
437

438
        # Handle Sequence types (List, Tuple, etc)
439
        if isinstance(origin, type) and issubclass(origin, Sequence):
1✔
440
            if not isinstance(value, Sequence):
1✔
441
                return False
×
442
            # Empty sequence is valid
443
            if not value:
1✔
444
                return True
×
445
            # Check each element against the sequence's type parameter
446
            return all(self._output_matches_type(item, args[0]) for item in value)
1✔
447

448
        # Handle basic types (int, str, etc)
449
        if origin is None:
×
450
            return isinstance(value, expected_type)
×
451

452
        # Handle Mapping types (Dict, etc)
453
        if isinstance(origin, type) and issubclass(origin, Mapping):
×
454
            if not isinstance(value, Mapping):
×
455
                return False
×
456
            # Empty mapping is valid
457
            if not value:
×
458
                return True
×
459
            key_type, value_type = args
×
460
            # Check all keys and values match their respective types
461
            return all(
×
462
                self._output_matches_type(k, key_type) and self._output_matches_type(v, value_type)
463
                for k, v in value.items()
464
            )
465

466
        # Handle Union types (including Optional)
467
        if origin is Union:
×
468
            return any(self._output_matches_type(value, arg) for arg in args)
×
469

470
        return False
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc