• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15049844454

15 May 2025 04:07PM UTC coverage: 90.446% (+0.04%) from 90.41%
15049844454

Pull #9345

github

web-flow
Merge 9e4071f83 into 2a64cd4e9
Pull Request #9345: feat: add serialization to `State` / move `State` to utils

10981 of 12141 relevant lines covered (90.45%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.57
haystack/components/agents/agent.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import inspect
1✔
6
from typing import Any, Dict, List, Optional, Union
1✔
7

8
from haystack import component, default_from_dict, default_to_dict, logging, tracing
1✔
9
from haystack.components.generators.chat.types import ChatGenerator
1✔
10
from haystack.components.tools import ToolInvoker
1✔
11
from haystack.core.pipeline.async_pipeline import AsyncPipeline
1✔
12
from haystack.core.pipeline.pipeline import Pipeline
1✔
13
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
1✔
14
from haystack.core.serialization import component_to_dict
1✔
15
from haystack.dataclasses import ChatMessage
1✔
16
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
1✔
17
from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
1✔
18
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
19
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
1✔
20
from haystack.utils.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
1✔
21
from haystack.utils.state_utils import merge_lists
1✔
22

23
# Module-level logger (haystack's logging wrapper; supports structured kwargs in messages).
logger = logging.getLogger(__name__)
24

25

26
@component
1✔
27
class Agent:
1✔
28
    """
29
    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.
30

31
    The component processes messages and executes tools until an exit_condition is met.
32
    The exit_condition can be triggered either by a direct text response or by invoking a specific designated tool.
33

34
    When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits.
35

36
    ### Usage example
37
    ```python
38
    from haystack.components.agents import Agent
39
    from haystack.components.generators.chat import OpenAIChatGenerator
40
    from haystack.dataclasses import ChatMessage
41
    from haystack.tools.tool import Tool
42

43
    tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")]
44

45
    agent = Agent(
46
        chat_generator=OpenAIChatGenerator(),
47
        tools=tools,
48
        exit_condition="search",
49
    )
50

51
    # Run the agent
52
    result = agent.run(
53
        messages=[ChatMessage.from_user("Find information about Haystack")]
54
    )
55

56
    assert "messages" in result  # Contains conversation history
57
    ```
58
    """
59

60
    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[List[str]] = None,
        state_schema: Optional[Dict[str, Any]] = None,
        max_agent_steps: int = 100,
        raise_on_tool_invocation_failure: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: List of Tool objects or a Toolset that the agent can use.
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it will stop and return the current state.
        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
            If set to False, the exception will be turned into a chat message and passed to the LLM.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :raises TypeError: If the chat_generator does not support tools parameter in its run method.
        :raises ValueError: If any exit condition is neither "text" nor the name of a provided tool.
        """
        # Check if chat_generator supports tools parameter by inspecting its run() signature.
        chat_generator_run_method = inspect.signature(chat_generator.run)
        if "tools" not in chat_generator_run_method.parameters:
            raise TypeError(
                f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
                "The Agent component requires a chat generator that supports tools."
            )

        # "text" is always a valid exit; tool names are valid exits too (works for both List[Tool] and Toolset).
        valid_exits = ["text"] + [tool.name for tool in tools or []]
        if exit_conditions is None:
            exit_conditions = ["text"]
        if not all(condition in valid_exits for condition in exit_conditions):
            raise ValueError(
                f"Invalid exit conditions provided: {exit_conditions}. "
                f"Valid exit conditions must be a subset of {valid_exits}. "
                "Ensure that each exit condition corresponds to either 'text' or a valid tool name."
            )

        # Validate state schema if provided
        if state_schema is not None:
            _validate_schema(state_schema)
        # Keep the raw user-supplied schema separately; to_dict() serializes this one, not the resolved one.
        self._state_schema = state_schema or {}

        # Initialize the resolved state schema: a deep copy of the user schema with a "messages"
        # entry guaranteed to exist (merged with merge_lists across steps).
        resolved_state_schema = _deepcopy_with_exceptions(self._state_schema)
        if resolved_state_schema.get("messages") is None:
            resolved_state_schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self.state_schema = resolved_state_schema

        self.chat_generator = chat_generator
        self.tools = tools or []
        self.system_prompt = system_prompt
        self.exit_conditions = exit_conditions
        self.max_agent_steps = max_agent_steps
        self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
        self.streaming_callback = streaming_callback

        # Register every schema entry as a component output; schema entries other than the ones
        # already present on run() also become optional component inputs.
        output_types = {}
        for param, config in self.state_schema.items():
            output_types[param] = config["type"]
            # Skip setting input types for parameters that are already in the run method
            if param in ["messages", "streaming_callback"]:
                continue
            component.set_input_type(self, name=param, type=config["type"], default=None)
        component.set_output_types(self, **output_types)

        # Without tools the Agent degenerates into a plain ChatGenerator (single reply, then exit).
        self._tool_invoker = None
        if self.tools:
            self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure)
        else:
            logger.warning(
                "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
                "responses. To enable tool usage, pass tools directly to the Agent, not to the chat_generator."
            )

        self._is_warmed_up = False
146

147
    def warm_up(self) -> None:
        """
        Warm up the Agent.

        Delegates to the chat generator's own ``warm_up`` when it exposes one;
        subsequent calls are no-ops.
        """
        if self._is_warmed_up:
            return
        if hasattr(self.chat_generator, "warm_up"):
            self.chat_generator.warm_up()
        self._is_warmed_up = True
155

156
    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the component to a dictionary.

        :return: Dictionary with serialized data
        """
        callback_repr = None if self.streaming_callback is None else serialize_callable(self.streaming_callback)

        return default_to_dict(
            self,
            chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"),
            tools=serialize_tools_or_toolset(self.tools),
            system_prompt=self.system_prompt,
            exit_conditions=self.exit_conditions,
            # Serialize the original user-provided schema, not the resolved one,
            # so round-tripping reproduces the constructor arguments.
            state_schema=_schema_to_dict(self._state_schema),
            max_agent_steps=self.max_agent_steps,
            raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
            streaming_callback=callback_repr,
        )
179

180
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Agent":
        """
        Deserialize the agent from a dictionary.

        :param data: Dictionary to deserialize from
        :return: Deserialized agent
        """
        params = data.get("init_parameters", {})

        # Rehydrate nested components/tools in place (params is the same dict held by `data`).
        deserialize_chatgenerator_inplace(params, key="chat_generator")
        deserialize_tools_or_toolset_inplace(params, key="tools")

        if "state_schema" in params:
            params["state_schema"] = _schema_from_dict(params["state_schema"])

        callback_path = params.get("streaming_callback")
        if callback_path is not None:
            params["streaming_callback"] = deserialize_callable(callback_path)

        return default_from_dict(cls, data)
201

202
    def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
        """Assemble the keyword arguments forwarded to the chat generator on every step."""
        inputs: Dict[str, Any] = {"tools": self.tools}
        # Only pass the callback through when one was actually selected.
        if streaming_callback is not None:
            inputs["streaming_callback"] = streaming_callback
        return inputs
208

209
    def _create_agent_span(self) -> Any:
        """Open the tracing span that wraps a single agent run."""
        span_tags = {
            "haystack.agent.max_steps": self.max_agent_steps,
            "haystack.agent.tools": self.tools,
            "haystack.agent.exit_conditions": self.exit_conditions,
            "haystack.agent.state_schema": _schema_to_dict(self.state_schema),
        }
        return tracing.tracer.trace("haystack.agent.run", tags=span_tags)
220

221
    def run(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Process messages and execute tools until the exit condition is met.

        :param messages: List of chat messages to process
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :return: Dictionary containing messages and outputs matching the defined output types
        :raises RuntimeError: If the chat generator requires warm-up and `warm_up()` was not called.
        """
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")

        # Prepend the system prompt (if any) without mutating the caller's list.
        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages

        # The runtime callback takes precedence over the one configured at init time.
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        # All conversation state (messages plus any user-schema entries) lives in a State object.
        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)

        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)

        # Visit counters shared with the pipeline machinery for tracing/bookkeeping.
        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
        with self._create_agent_span() as span:
            span.set_content_tag(
                "haystack.agent.input",
                _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
            )
            counter = 0
            while counter < self.max_agent_steps:
                # 1. Call the ChatGenerator
                llm_messages = Pipeline._run_component(
                    component_name="chat_generator",
                    component={"instance": self.chat_generator},
                    inputs={"messages": messages, **generator_inputs},
                    component_visits=component_visits,
                    parent_span=span,
                )["replies"]
                state.set("messages", llm_messages)

                # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    counter += 1
                    break

                # 3. Call the ToolInvoker
                # We only send the messages from the LLM to the tool invoker
                tool_invoker_result = Pipeline._run_component(
                    component_name="tool_invoker",
                    component={"instance": self._tool_invoker},
                    inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
                    component_visits=component_visits,
                    parent_span=span,
                )
                tool_messages = tool_invoker_result["tool_messages"]
                # The tool invoker may have mutated/replaced the state; adopt its copy.
                state = tool_invoker_result["state"]
                state.set("messages", tool_messages)

                # 4. Check if any LLM message's tool call name matches an exit condition
                # (skipped when the only exit condition is "text", which step 2 already handles)
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    counter += 1
                    break

                # 5. Fetch the combined messages and send them back to the LLM
                messages = state.get("messages")
                counter += 1

            if counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", state.data)
            span.set_tag("haystack.agent.steps_taken", counter)
        return state.data
305

306
    async def run_async(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Asynchronously process messages and execute tools until the exit condition is met.

        This is the asynchronous version of the `run` method. It follows the same logic but uses
        asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
        if available.

        :param messages: List of chat messages to process
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :return: Dictionary containing messages and outputs matching the defined output types
        :raises RuntimeError: If the chat generator requires warm-up and `warm_up()` was not called.
        """
        if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")

        # Prepend the system prompt (if any) without mutating the caller's list.
        if self.system_prompt is not None:
            messages = [ChatMessage.from_system(self.system_prompt)] + messages

        # requires_async=True: only async-capable callbacks are accepted here, unlike run().
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
        )

        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)

        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)

        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
        with self._create_agent_span() as span:
            span.set_content_tag(
                "haystack.agent.input",
                _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
            )
            counter = 0
            while counter < self.max_agent_steps:
                # 1. Call the ChatGenerator
                result = await AsyncPipeline._run_component_async(
                    component_name="chat_generator",
                    component={"instance": self.chat_generator},
                    component_inputs={"messages": messages, **generator_inputs},
                    component_visits=component_visits,
                    max_runs_per_component=self.max_agent_steps,
                    parent_span=span,
                )
                llm_messages = result["replies"]
                state.set("messages", llm_messages)

                # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    counter += 1
                    break

                # 3. Call the ToolInvoker
                # We only send the messages from the LLM to the tool invoker
                # Check if the ToolInvoker supports async execution. Currently, it doesn't.
                # NOTE(review): unlike run(), streaming_callback is not forwarded to the tool
                # invoker here — confirm whether that is intentional for the async path.
                tool_invoker_result = await AsyncPipeline._run_component_async(
                    component_name="tool_invoker",
                    component={"instance": self._tool_invoker},
                    component_inputs={"messages": llm_messages, "state": state},
                    component_visits=component_visits,
                    max_runs_per_component=self.max_agent_steps,
                    parent_span=span,
                )
                tool_messages = tool_invoker_result["tool_messages"]
                # The tool invoker may have mutated/replaced the state; adopt its copy.
                state = tool_invoker_result["state"]
                state.set("messages", tool_messages)

                # 4. Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    counter += 1
                    break

                # 5. Fetch the combined messages and send them back to the LLM
                messages = state.get("messages")
                counter += 1

            if counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", state.data)
            span.set_tag("haystack.agent.steps_taken", counter)
        return state.data
398

399
    def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool:
        """
        Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.

        :param llm_messages: List of messages from the LLM
        :param tool_messages: List of messages from the tool invoker
        :return: True if an exit condition is met and there are no errors, False otherwise
        """
        matched_any = False

        for message in llm_messages:
            call = message.tool_call
            if not (call and call.tool_name in self.exit_conditions):
                continue
            matched_any = True

            # An error from the tool that matched the exit condition vetoes the exit:
            # the agent should keep running so the LLM can react to the failure.
            errors = (
                result_msg.tool_call_result.error
                for result_msg in tool_messages
                if result_msg.tool_call_result is not None
                and result_msg.tool_call_result.origin.tool_name == call.tool_name
            )
            if any(errors):
                return False

        # True only when at least one exit condition matched and none of them errored.
        return matched_any
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc