• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14384570060

10 Apr 2025 03:37PM UTC coverage: 90.388% (+0.02%) from 90.373%
14384570060

Pull #9213

github

web-flow
Merge e953558d2 into 45aa9608b
Pull Request #9213: feat: Minimal chat generator protocol

10683 of 11819 relevant lines covered (90.39%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.28
haystack/components/agents/agent.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import inspect
1✔
6
from copy import deepcopy
1✔
7
from typing import Any, Dict, List, Optional
1✔
8

9
from haystack import component, default_from_dict, default_to_dict, logging
1✔
10
from haystack.components.generators.chat.types import ChatGenerator
1✔
11
from haystack.components.tools import ToolInvoker
1✔
12
from haystack.core.serialization import component_to_dict
1✔
13
from haystack.dataclasses import ChatMessage
1✔
14
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
1✔
15
from haystack.dataclasses.state_utils import merge_lists
1✔
16
from haystack.dataclasses.streaming_chunk import SyncStreamingCallbackT
1✔
17
from haystack.tools import Tool, deserialize_tools_or_toolset_inplace
1✔
18
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
19
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
1✔
20

21
# Module-level logger; `logging` here is Haystack's wrapper (imported from `haystack`), not the stdlib module.
logger = logging.getLogger(__name__)
22

23

24
@component
class Agent:
    """
    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.

    The component processes messages and executes tools until an exit condition is met.
    The exit condition can be triggered either by a direct text response or by invoking a specific designated tool.

    ### Usage example
    ```python
    from haystack.components.agents import Agent
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")]

    agent = Agent(
        chat_generator=OpenAIChatGenerator(),
        tools=tools,
        exit_conditions=["search"],
    )

    # Run the agent
    result = agent.run(
        messages=[ChatMessage.from_user("Find information about Haystack")]
    )

    assert "messages" in result  # Contains conversation history
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[List[Tool]] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[List[str]] = None,
        state_schema: Optional[Dict[str, Any]] = None,
        max_agent_steps: int = 100,
        raise_on_tool_invocation_failure: bool = False,
        streaming_callback: Optional[SyncStreamingCallbackT] = None,
    ):
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: List of Tool objects available to the agent
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it will stop and return the current state.
        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
            If set to False, the exception will be turned into a chat message and passed to the LLM.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
        :raises TypeError: If the chat_generator does not support tools parameter in its run method.
        :raises ValueError: If any exit condition is neither "text" nor the name of a provided tool.
        """
        # Check if chat_generator supports tools parameter; inspecting the run signature
        # works for any ChatGenerator implementation without needing a marker interface.
        chat_generator_run_method = inspect.signature(chat_generator.run)
        if "tools" not in chat_generator_run_method.parameters:
            raise TypeError(
                f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
                "The Agent component requires a chat generator that supports tools."
            )

        # "text" is always a legal exit condition; tool names are legal only for tools actually provided.
        valid_exits = ["text"] + [tool.name for tool in tools or []]
        if exit_conditions is None:
            exit_conditions = ["text"]
        if not all(condition in valid_exits for condition in exit_conditions):
            raise ValueError(
                f"Invalid exit conditions provided: {exit_conditions}. "
                f"Valid exit conditions must be a subset of {valid_exits}. "
                "Ensure that each exit condition corresponds to either 'text' or a valid tool name."
            )

        # Validate state schema if provided
        if state_schema is not None:
            _validate_schema(state_schema)
        # Keep the raw user-supplied schema separately so to_dict() can serialize exactly what was passed in.
        self._state_schema = state_schema or {}

        # Initialize state schema: the resolved copy always carries a "messages" entry so the
        # conversation history is merged (via merge_lists) into State on every step.
        resolved_state_schema = deepcopy(self._state_schema)
        if resolved_state_schema.get("messages") is None:
            resolved_state_schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self.state_schema = resolved_state_schema

        self.chat_generator = chat_generator
        self.tools = tools or []
        self.system_prompt = system_prompt
        self.exit_conditions = exit_conditions
        self.max_agent_steps = max_agent_steps
        self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
        self.streaming_callback = streaming_callback

        # Every state-schema entry becomes a component output; all except run()'s own
        # parameters also become optional component inputs (default None).
        output_types = {}
        for param, config in self.state_schema.items():
            output_types[param] = config["type"]
            # Skip setting input types for parameters that are already in the run method
            if param in ["messages", "streaming_callback"]:
                continue
            component.set_input_type(self, name=param, type=config["type"], default=None)
        component.set_output_types(self, **output_types)

        self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Warm up the Agent.

        Delegates to the chat generator's own warm_up() when it defines one; idempotent.
        """
        if not self._is_warmed_up:
            if hasattr(self.chat_generator, "warm_up"):
                self.chat_generator.warm_up()
            self._is_warmed_up = True

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the component to a dictionary.

        :return: Dictionary with serialized data
        """
        if self.streaming_callback is not None:
            streaming_callback = serialize_callable(self.streaming_callback)
        else:
            streaming_callback = None

        return default_to_dict(
            self,
            chat_generator=component_to_dict(self.chat_generator, "chat_generator"),
            tools=[t.to_dict() for t in self.tools],
            system_prompt=self.system_prompt,
            exit_conditions=self.exit_conditions,
            # We serialize the original state schema, not the resolved one to reflect the original user input
            state_schema=_schema_to_dict(self._state_schema),
            max_agent_steps=self.max_agent_steps,
            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
            streaming_callback=streaming_callback,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Agent":
        """
        Deserialize the agent from a dictionary.

        :param data: Dictionary to deserialize from
        :return: Deserialized agent
        """
        init_params = data.get("init_parameters", {})

        deserialize_chatgenerator_inplace(init_params, key="chat_generator")

        if "state_schema" in init_params:
            init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])

        if init_params.get("streaming_callback") is not None:
            init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])

        deserialize_tools_or_toolset_inplace(init_params, key="tools")

        return default_from_dict(cls, data)

    def run(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[SyncStreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Process messages and execute tools until the exit condition is met.

        :param messages: List of chat messages to process
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            Takes precedence over the callback configured at init time.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :return: Dictionary containing messages and outputs matching the defined output types
        :raises RuntimeError: If the chat generator needs warming up and `warm_up()` was not called first.
        """
        # Only enforce warm-up when the chat generator actually exposes a warm_up method.
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")

        state = State(schema=self.state_schema, data=kwargs)

        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages
        state.set("messages", messages)

        generator_inputs: Dict[str, Any] = {"tools": self.tools}

        # The per-call callback overrides the one configured at init time.
        selected_callback = streaming_callback or self.streaming_callback
        if selected_callback is not None:
            generator_inputs["streaming_callback"] = selected_callback

        # Repeat until the exit condition is met
        counter = 0
        while counter < self.max_agent_steps:
            # 1. Call the ChatGenerator
            llm_messages = self.chat_generator.run(messages=messages, **generator_inputs)["replies"]
            state.set("messages", llm_messages)

            # 2. Check if any of the LLM responses contain a tool call
            if not any(msg.tool_call for msg in llm_messages):
                # Plain-text reply: this is the implicit "text" exit condition.
                return {**state.data}

            # 3. Call the ToolInvoker
            # We only send the messages from the LLM to the tool invoker
            tool_invoker_result = self._tool_invoker.run(messages=llm_messages, state=state)
            tool_messages = tool_invoker_result["tool_messages"]
            state = tool_invoker_result["state"]
            state.set("messages", tool_messages)

            # 4. Check if any LLM message's tool call name matches an exit condition
            if self.exit_conditions != ["text"]:
                matched_exit_conditions = set()
                has_errors = False

                for msg in llm_messages:
                    if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
                        matched_exit_conditions.add(msg.tool_call.tool_name)

                        # Check if any error is specifically from the tool matching the exit condition
                        tool_errors = [
                            tool_msg.tool_call_result.error
                            for tool_msg in tool_messages
                            if tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
                        ]
                        if any(tool_errors):
                            has_errors = True
                            # No need to check further if we found an error
                            break

                # Only return if at least one exit condition was matched AND none had errors
                if matched_exit_conditions and not has_errors:
                    return {**state.data}

            # 5. Fetch the combined messages and send them back to the LLM
            messages = state.get("messages")
            counter += 1

        logger.warning(
            "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.", max_agent_steps=self.max_agent_steps
        )
        return {**state.data}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc