deepset-ai / haystack · build 14468821137 · 15 Apr 2025 11:58AM UTC
Coverage: 90.343% (+0.07%) from 90.278%

Pull Request #9240: feat: Agent tracing
Merge 7e893bbc8 into 656fe6dc6 (github / web-flow)

10656 of 11795 relevant lines covered (90.34%) · 0.9 hits per line

Source file: haystack/components/agents/agent.py (90.73% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Optional

from haystack import component, default_from_dict, default_to_dict, logging, tracing
from haystack.components.generators.chat.types import ChatGenerator
from haystack.components.tools import ToolInvoker
from haystack.core.pipeline.async_pipeline import AsyncPipeline
from haystack.core.pipeline.pipeline import Pipeline
from haystack.core.serialization import component_to_dict
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
from haystack.dataclasses.state_utils import merge_lists
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.tools import Tool, deserialize_tools_or_toolset_inplace
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.deserialization import deserialize_chatgenerator_inplace

logger = logging.getLogger(__name__)

@component
class Agent:
    """
    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.

    The component processes messages and executes tools until an exit condition is met.
    An exit condition can be triggered either by a direct text response or by invoking a specific designated tool.

    ### Usage example
    ```python
    from haystack.components.agents import Agent
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.tools.tool import Tool

    tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")]

    agent = Agent(
        chat_generator=OpenAIChatGenerator(),
        tools=tools,
        exit_conditions=["search"],
    )

    # Run the agent
    result = agent.run(
        messages=[ChatMessage.from_user("Find information about Haystack")]
    )

    assert "messages" in result  # Contains conversation history
    ```
    """

    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[List[Tool]] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[List[str]] = None,
        state_schema: Optional[Dict[str, Any]] = None,
        max_agent_steps: int = 100,
        raise_on_tool_invocation_failure: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: List of Tool objects available to the agent.
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once that tool has been executed. Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it stops and returns the current state.
        :param raise_on_tool_invocation_failure: Whether the agent should raise an exception when a tool invocation
            fails. If set to False, the exception is turned into a chat message and passed to the LLM.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
        :raises TypeError: If the chat_generator does not support the tools parameter in its run method.
        """
        # Check if chat_generator supports the tools parameter
        chat_generator_run_method = inspect.signature(chat_generator.run)
        if "tools" not in chat_generator_run_method.parameters:
            raise TypeError(
                f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
                "The Agent component requires a chat generator that supports tools."
            )

        valid_exits = ["text"] + [tool.name for tool in tools or []]
        if exit_conditions is None:
            exit_conditions = ["text"]
        if not all(condition in valid_exits for condition in exit_conditions):
            raise ValueError(
                f"Invalid exit conditions provided: {exit_conditions}. "
                f"Valid exit conditions must be a subset of {valid_exits}. "
                "Ensure that each exit condition corresponds to either 'text' or a valid tool name."
            )

        # Validate state schema if provided
        if state_schema is not None:
            _validate_schema(state_schema)
        self._state_schema = state_schema or {}

        # Initialize state schema
        resolved_state_schema = deepcopy(self._state_schema)
        if resolved_state_schema.get("messages") is None:
            resolved_state_schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self.state_schema = resolved_state_schema

        self.chat_generator = chat_generator
        self.tools = tools or []
        self.system_prompt = system_prompt
        self.exit_conditions = exit_conditions
        self.max_agent_steps = max_agent_steps
        self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
        self.streaming_callback = streaming_callback

        output_types = {}
        for param, config in self.state_schema.items():
            output_types[param] = config["type"]
            # Skip setting input types for parameters that are already in the run method
            if param in ["messages", "streaming_callback"]:
                continue
            component.set_input_type(self, name=param, type=config["type"], default=None)
        component.set_output_types(self, **output_types)

        self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
        self._is_warmed_up = False

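    # Illustrative sketch, not part of the original file: a custom ``state_schema`` lets tools
    # read and write shared runtime state. The "documents" key and Document type below are
    # hypothetical examples; every key declared in the schema becomes both an optional run()
    # input and an output of the component (see the loop over ``self.state_schema`` above).
    #
    #   from haystack.dataclasses import Document
    #
    #   agent = Agent(
    #       chat_generator=OpenAIChatGenerator(),
    #       tools=tools,
    #       state_schema={"documents": {"type": List[Document], "handler": merge_lists}},
    #   )
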
    def warm_up(self) -> None:
        """
        Warm up the Agent.
        """
        if not self._is_warmed_up:
            if hasattr(self.chat_generator, "warm_up"):
                self.chat_generator.warm_up()
            self._is_warmed_up = True

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the component to a dictionary.

        :return: Dictionary with serialized data
        """
        if self.streaming_callback is not None:
            streaming_callback = serialize_callable(self.streaming_callback)
        else:
            streaming_callback = None

        return default_to_dict(
            self,
            chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"),
            tools=[t.to_dict() for t in self.tools],
            system_prompt=self.system_prompt,
            exit_conditions=self.exit_conditions,
            # We serialize the original state schema, not the resolved one, to reflect the original user input
            state_schema=_schema_to_dict(self._state_schema),
            max_agent_steps=self.max_agent_steps,
            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
            streaming_callback=streaming_callback,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Agent":
        """
        Deserialize the agent from a dictionary.

        :param data: Dictionary to deserialize from
        :return: Deserialized agent
        """
        init_params = data.get("init_parameters", {})

        deserialize_chatgenerator_inplace(init_params, key="chat_generator")

        if "state_schema" in init_params:
            init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])

        if init_params.get("streaming_callback") is not None:
            init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"])

        deserialize_tools_or_toolset_inplace(init_params, key="tools")

        return default_from_dict(cls, data)

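    # Illustrative round-trip sketch, not part of the original file: to_dict/from_dict form
    # the pair that Haystack pipeline serialization builds on, so an Agent configuration can
    # be persisted and restored later.
    #
    #   serialized = agent.to_dict()            # plain dict, safe to dump as YAML/JSON
    #   restored = Agent.from_dict(serialized)
    #   assert restored.exit_conditions == agent.exit_conditions
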
    def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
        """Prepare inputs for the chat generator."""
        generator_inputs: Dict[str, Any] = {"tools": self.tools}
        selected_callback = streaming_callback or self.streaming_callback
        if selected_callback is not None:
            generator_inputs["streaming_callback"] = selected_callback
        return generator_inputs

    def _create_agent_span(self, input_data: Dict[str, Any]) -> Iterator[tracing.Span]:
        """Create a span for the agent run."""
        return tracing.tracer.trace(
            "haystack.agent.run",
            tags={
                "haystack.agent.input_data": input_data,
                "haystack.agent.max_steps": self.max_agent_steps,
                "haystack.agent.tools": self.tools,
                "haystack.agent.exit_conditions": self.exit_conditions,
                "haystack.agent.state_schema": self.state_schema,
            },
        )

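    # Illustrative sketch, not part of the original file: the span above goes through
    # Haystack's pluggable tracer, so whichever tracing backend is configured receives the
    # "haystack.agent.run" span and its tags. Content tags (such as the agent's output data
    # set at the end of run()) are typically only recorded when content tracing is opted
    # into; the flag name below is an assumption about haystack.tracing's proxy tracer.
    #
    #   from haystack import tracing
    #
    #   tracing.tracer.is_content_tracing_enabled = True  # assumed opt-in flag
    #   result = agent.run(messages=[ChatMessage.from_user("Hi")])
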
    def run(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Process messages and execute tools until an exit condition is met.

        :param messages: List of chat messages to process.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :return: Dictionary containing messages and outputs matching the defined output types.
        """
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")

        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages

        input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})

        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)

        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)

        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
        with self._create_agent_span(input_data) as span:
            counter = 0
            while counter < self.max_agent_steps:
                # 1. Call the ChatGenerator
                llm_messages = Pipeline._run_component(
                    component_name="chat_generator",
                    component={"instance": self.chat_generator},
                    inputs={"messages": messages, **generator_inputs},
                    component_visits=component_visits,
                    parent_span=span,
                )["replies"]
                state.set("messages", llm_messages)

                # 2. Check if any of the LLM responses contain a tool call
                if not any(msg.tool_call for msg in llm_messages):
                    break

                # 3. Call the ToolInvoker
                # We only send the messages from the LLM to the tool invoker
                tool_invoker_result = Pipeline._run_component(
                    component_name="tool_invoker",
                    component={"instance": self._tool_invoker},
                    inputs={"messages": llm_messages, "state": state},
                    component_visits=component_visits,
                    parent_span=span,
                )
                tool_messages = tool_invoker_result["tool_messages"]
                state = tool_invoker_result["state"]
                state.set("messages", tool_messages)

                # 4. Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    break

                # 5. Fetch the combined messages and send them back to the LLM
                messages = state.get("messages")
                counter += 1

            if counter >= self.max_agent_steps:
                logger.warning(
                    "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output_data", state.data)
            span.set_tag("haystack.agent.steps_taken", counter)

        return state.data

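    # Illustrative sketch, not part of the original file: extra keyword arguments to run()
    # must match keys declared in ``state_schema`` and come back in the result dict alongside
    # "messages". The "documents" key is the hypothetical schema entry from the sketch after
    # __init__ above.
    #
    #   result = agent.run(
    #       messages=[ChatMessage.from_user("Summarize these documents")],
    #       documents=my_documents,  # merged into State via the schema's handler
    #   )
    #   final_reply = result["messages"][-1].text
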
    async def run_async(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Asynchronously process messages and execute tools until an exit condition is met.

        This is the asynchronous version of the `run` method. It follows the same logic but uses
        asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
        if available.

        :param messages: List of chat messages to process.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :return: Dictionary containing messages and outputs matching the defined output types.
        """
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")

        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages

        input_data = deepcopy({"messages": messages, "streaming_callback": streaming_callback, **kwargs})

        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)

        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)

        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
        with self._create_agent_span(input_data) as span:
            counter = 0
            while counter < self.max_agent_steps:
                # 1. Call the ChatGenerator
                result = await AsyncPipeline._run_component_async(
                    component_name="chat_generator",
                    component={"instance": self.chat_generator},
                    component_inputs={"messages": messages, **generator_inputs},
                    component_visits=component_visits,
                    max_runs_per_component=self.max_agent_steps,
                    parent_span=span,
                )
                llm_messages = result["replies"]
                state.set("messages", llm_messages)

                # 2. Check if any of the LLM responses contain a tool call
                if not any(msg.tool_call for msg in llm_messages):
                    break

                # 3. Call the ToolInvoker
                # We only send the messages from the LLM to the tool invoker.
                # Check if the ToolInvoker supports async execution. Currently, it doesn't.
                tool_invoker_result = await AsyncPipeline._run_component_async(
                    component_name="tool_invoker",
                    component={"instance": self._tool_invoker},
                    component_inputs={"messages": llm_messages, "state": state},
                    component_visits=component_visits,
                    max_runs_per_component=self.max_agent_steps,
                    parent_span=span,
                )
                tool_messages = tool_invoker_result["tool_messages"]
                state = tool_invoker_result["state"]
                state.set("messages", tool_messages)

                # 4. Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    break

                # 5. Fetch the combined messages and send them back to the LLM
                messages = state.get("messages")
                counter += 1

            if counter >= self.max_agent_steps:
                logger.warning(
                    "Agent exceeded maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output_data", state.data)
            span.set_tag("haystack.agent.steps_taken", counter)

        return state.data

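    # Illustrative sketch, not part of the original file: run_async is awaited from an event
    # loop. Per the comment in the tool-invocation step above, only the chat generator side
    # can currently run truly asynchronously.
    #
    #   import asyncio
    #
    #   result = asyncio.run(
    #       agent.run_async(messages=[ChatMessage.from_user("Find information about Haystack")])
    #   )
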
    def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
        """
        Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.

        :param llm_messages: List of messages from the LLM.
        :param tool_messages: List of messages from the tool invoker.
        :return: True if an exit condition is met and there are no errors, False otherwise.
        """
        matched_exit_conditions = set()
        has_errors = False

        for msg in llm_messages:
            if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
                matched_exit_conditions.add(msg.tool_call.tool_name)

                # Check if any error is specifically from the tool matching the exit condition
                tool_errors = [
                    tool_msg.tool_call_result.error
                    for tool_msg in tool_messages
                    if tool_msg.tool_call_result is not None
                    and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
                ]
                if any(tool_errors):
                    has_errors = True
                    # No need to check further if we found an error
                    break

        # Only return True if at least one exit condition was matched AND none had errors
        return bool(matched_exit_conditions) and not has_errors
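
    # Illustrative walkthrough, not part of the original file: with exit_conditions=["search"],
    # a step in which the LLM calls the "search" tool ends the loop only if the matching tool
    # message carries no error. A failed "search" invocation keeps the loop running, so the LLM
    # sees the error message that ToolInvoker produced and can react to it.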