deepset-ai / haystack · build 15165238383 · 21 May 2025 02:42PM UTC
Coverage: 90.404% (-0.04%) from 90.443%
Pull Request #9275: feat: return common type in SuperComponent type compatibility check
Merge 82e69fe2c into 17432f710 (github / web-flow)
11135 of 12317 relevant lines covered (90.4%) · 0.9 hits per line

Source file: haystack/components/agents/agent.py (91.07% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging, tracing
from haystack.components.generators.chat.types import ChatGenerator
from haystack.components.tools import ToolInvoker
from haystack.core.pipeline.async_pipeline import AsyncPipeline
from haystack.core.pipeline.pipeline import Pipeline
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
from haystack.core.serialization import component_to_dict
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from haystack.dataclasses.state_utils import merge_lists
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace

logger = logging.getLogger(__name__)


@component
class Agent:
    """
    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.

    The component processes messages and executes tools until an exit condition is met.
    An exit condition can be triggered either by a direct text response or by invoking a specific, designated tool.

    When you call an Agent without tools, it acts as a ChatGenerator: it produces one response and then exits.

    ### Usage example
    ```python
    from haystack.components.agents import Agent
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")]

    agent = Agent(
        chat_generator=OpenAIChatGenerator(),
        tools=tools,
        exit_conditions=["search"],
    )

    # Run the agent
    result = agent.run(
        messages=[ChatMessage.from_user("Find information about Haystack")]
    )

    assert "messages" in result  # Contains conversation history
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[List[str]] = None,
        state_schema: Optional[Dict[str, Any]] = None,
        max_agent_steps: int = 100,
        raise_on_tool_invocation_failure: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: List of Tool objects or a Toolset that the agent can use.
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once that tool has been executed.
            Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it will stop and return the current state.
        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
            If set to False, the exception will be turned into a chat message and passed to the LLM.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :raises TypeError: If the chat_generator does not support the `tools` parameter in its run method.
        :raises ValueError: If the exit_conditions are not valid.
        """
        # Check if chat_generator supports the tools parameter
        chat_generator_run_method = inspect.signature(chat_generator.run)
        if "tools" not in chat_generator_run_method.parameters:
            raise TypeError(
                f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
                "The Agent component requires a chat generator that supports tools."
            )
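        # For example, OpenAIChatGenerator.run accepts a "tools" argument and passes
        # this check; generators without tool support are rejected up front.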

        valid_exits = ["text"] + [tool.name for tool in tools or []]
        if exit_conditions is None:
            exit_conditions = ["text"]
        if not all(condition in valid_exits for condition in exit_conditions):
            raise ValueError(
                f"Invalid exit conditions provided: {exit_conditions}. "
                f"Valid exit conditions must be a subset of {valid_exits}. "
                "Ensure that each exit condition corresponds to either 'text' or a valid tool name."
            )
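        # e.g. for a tool named "search", exit_conditions=["text", "search"] is valid:
        # the agent stops on a plain-text reply or right after the "search" tool runs.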

        # Validate state schema if provided
        if state_schema is not None:
            _validate_schema(state_schema)
        self._state_schema = state_schema or {}

        # Initialize state schema
        resolved_state_schema = _deepcopy_with_exceptions(self._state_schema)
        if resolved_state_schema.get("messages") is None:
            resolved_state_schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self.state_schema = resolved_state_schema
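        # A hypothetical entry such as {"notes": {"type": List[str], "handler": merge_lists}}
        # would let tools accumulate values under the shared state key "notes".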

        self.chat_generator = chat_generator
        self.tools = tools or []
        self.system_prompt = system_prompt
        self.exit_conditions = exit_conditions
        self.max_agent_steps = max_agent_steps
        self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
        self.streaming_callback = streaming_callback

        output_types = {"last_message": ChatMessage}
        for param, config in self.state_schema.items():
            output_types[param] = config["type"]
            # Skip setting input types for parameters that are already in the run method
            if param in ["messages", "streaming_callback"]:
                continue
            component.set_input_type(self, name=param, type=config["type"], default=None)
        component.set_output_types(self, **output_types)
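        # Every state-schema key becomes an output socket of the component; all keys
        # except "messages" and "streaming_callback" also become optional inputs.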

        self._tool_invoker = None
        if self.tools:
            self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
        else:
            logger.warning(
                "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
                "responses. To enable tool usage, pass tools directly to the Agent, not to the chat_generator."
            )

        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the Agent.
        """
        if not self._is_warmed_up:
            if hasattr(self.chat_generator, "warm_up"):
                self.chat_generator.warm_up()
            self._is_warmed_up = True
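
    # Note: run() and run_async() raise a RuntimeError if the chat generator defines
    # warm_up() but the Agent has not been warmed up yet.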

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the component to a dictionary.

        :return: Dictionary with serialized data
        """
        if self.streaming_callback is not None:
            streaming_callback = serialize_callable(self.streaming_callback)
        else:
            streaming_callback = None

        return default_to_dict(
            self,
            chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"),
            tools=serialize_tools_or_toolset(self.tools),
            system_prompt=self.system_prompt,
            exit_conditions=self.exit_conditions,
            # We serialize the original state schema, not the resolved one to reflect the original user input
            state_schema=_schema_to_dict(self._state_schema),
            max_agent_steps=self.max_agent_steps,
            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
            streaming_callback=streaming_callback,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Agent":
        """
        Deserialize the agent from a dictionary.

        :param data: Dictionary to deserialize from
        :return: Deserialized agent
        """
        init_params = data.get("init_parameters", {})

        deserialize_chatgenerator_inplace(init_params, key="chat_generator")

        if "state_schema" in init_params:
            init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])

        if init_params.get("streaming_callback") is not None:
            init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])

        deserialize_tools_or_toolset_inplace(init_params, key="tools")

        return default_from_dict(cls, data)
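
    # Round trip: Agent.from_dict(agent.to_dict()) reconstructs an equivalent Agent,
    # re-deserializing the chat generator, tools, state schema, and streaming callback.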

    def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
        """Prepare inputs for the chat generator."""
        generator_inputs: Dict[str, Any] = {"tools": self.tools}
        if streaming_callback is not None:
            generator_inputs["streaming_callback"] = streaming_callback
        return generator_inputs

    def _create_agent_span(self) -> Any:
        """Create a span for the agent run."""
        return tracing.tracer.trace(
            "haystack.agent.run",
            tags={
                "haystack.agent.max_steps": self.max_agent_steps,
                "haystack.agent.tools": self.tools,
                "haystack.agent.exit_conditions": self.exit_conditions,
                "haystack.agent.state_schema": _schema_to_dict(self.state_schema),
            },
        )

    def run(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Process messages and execute tools until an exit condition is met.

        :param messages: List of Haystack ChatMessage objects to process.
            If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :returns:
            A dictionary with the following keys:
            - "messages": List of all messages exchanged during the agent's run.
            - "last_message": The last message exchanged during the agent's run.
            - Any additional keys defined in the `state_schema`.
        """
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")

        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages

        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)
        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)

        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )
        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
        with self._create_agent_span() as span:
            span.set_content_tag(
                "haystack.agent.input",
                _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
            )
            counter = 0
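            # Each loop iteration is one agent step: a chat-generator call plus an
            # optional tool pass; the step counter is capped by max_agent_steps.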
            while counter < self.max_agent_steps:
                # 1. Call the ChatGenerator
                result = Pipeline._run_component(
                    component_name="chat_generator",
                    component={"instance": self.chat_generator},
                    inputs={"messages": messages, **generator_inputs},
                    component_visits=component_visits,
                    parent_span=span,
                )
                llm_messages = result["replies"]
                state.set("messages", llm_messages)

                # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    counter += 1
                    break

                # 3. Call the ToolInvoker
                # We only send the messages from the LLM to the tool invoker
                tool_invoker_result = Pipeline._run_component(
                    component_name="tool_invoker",
                    component={"instance": self._tool_invoker},
                    inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
                    component_visits=component_visits,
                    parent_span=span,
                )
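                # The ToolInvoker returns the tool result messages and the (possibly
                # updated) shared State; adopt the returned State before continuing.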
                tool_messages = tool_invoker_result["tool_messages"]
                state = tool_invoker_result["state"]
                state.set("messages", tool_messages)

                # 4. Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    counter += 1
                    break

                # 5. Fetch the combined messages and send them back to the LLM
                messages = state.get("messages")
                counter += 1

            if counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", state.data)
            span.set_tag("haystack.agent.steps_taken", counter)
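
        # The component output is the full state data plus a convenience "last_message" key.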
        result = {**state.data}
        all_messages = state.get("messages")
        if all_messages:
            result.update({"last_message": all_messages[-1]})
        return result

    async def run_async(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Asynchronously process messages and execute tools until an exit condition is met.

        This is the asynchronous version of the `run` method. It follows the same logic but uses
        asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
        if available.

        :param messages: List of chat messages to process.
        :param streaming_callback: An asynchronous callback that will be invoked when a response
            is streamed from the LLM. The same callback can be configured to emit tool results
            when a tool is called.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :returns:
            A dictionary with the following keys:
            - "messages": List of all messages exchanged during the agent's run.
            - "last_message": The last message exchanged during the agent's run.
            - Any additional keys defined in the `state_schema`.
        """
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")

        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages

        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)
        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)

        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
        )
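        # Unlike run(), requires_async=True here signals that the selected streaming
        # callback is expected to be awaitable in the async execution path.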
        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
        with self._create_agent_span() as span:
            span.set_content_tag(
                "haystack.agent.input",
                _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
            )
            counter = 0
            while counter < self.max_agent_steps:
                # 1. Call the ChatGenerator
                result = await AsyncPipeline._run_component_async(
                    component_name="chat_generator",
                    component={"instance": self.chat_generator},
                    component_inputs={"messages": messages, **generator_inputs},
                    component_visits=component_visits,
                    max_runs_per_component=self.max_agent_steps,
                    parent_span=span,
                )
                llm_messages = result["replies"]
                state.set("messages", llm_messages)

                # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    counter += 1
                    break

                # 3. Call the ToolInvoker
                # We only send the messages from the LLM to the tool invoker.
                # The ToolInvoker does not currently support async execution; it is
                # invoked through the same async pipeline helper.
                tool_invoker_result = await AsyncPipeline._run_component_async(
                    component_name="tool_invoker",
                    component={"instance": self._tool_invoker},
                    component_inputs={
                        "messages": llm_messages,
                        "state": state,
                        "streaming_callback": streaming_callback,
                    },
                    component_visits=component_visits,
                    max_runs_per_component=self.max_agent_steps,
                    parent_span=span,
                )
                tool_messages = tool_invoker_result["tool_messages"]
                state = tool_invoker_result["state"]
                state.set("messages", tool_messages)

                # 4. Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    counter += 1
                    break

                # 5. Fetch the combined messages and send them back to the LLM
                messages = state.get("messages")
                counter += 1

            if counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", state.data)
            span.set_tag("haystack.agent.steps_taken", counter)

        result = {**state.data}
        all_messages = state.get("messages")
        if all_messages:
            result.update({"last_message": all_messages[-1]})
        return result

    def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
        """
        Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.

        :param llm_messages: List of messages from the LLM
        :param tool_messages: List of messages from the tool invoker
        :return: True if an exit condition is met and there are no errors, False otherwise
        """
        matched_exit_conditions = set()
        has_errors = False

        for msg in llm_messages:
            if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
                matched_exit_conditions.add(msg.tool_call.tool_name)

                # Check if any error is specifically from the tool matching the exit condition
                tool_errors = [
                    tool_msg.tool_call_result.error
                    for tool_msg in tool_messages
                    if tool_msg.tool_call_result is not None
                    and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
                ]
                if any(tool_errors):
                    has_errors = True
                    # No need to check further if we found an error
                    break

        # Only return True if at least one exit condition was matched AND none had errors
        return bool(matched_exit_conditions) and not has_errors