• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18848249981

27 Oct 2025 04:24PM UTC coverage: 92.214% (-0.005%) from 92.219%
18848249981

Pull #9616

github

web-flow
Merge cee0f6bf7 into 4ce5b683d
Pull Request #9616: feat: Add generation_kwargs to run methods of Agent

13467 of 14604 relevant lines covered (92.21%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.48
haystack/components/agents/agent.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import inspect
1✔
6
from dataclasses import dataclass
1✔
7
from typing import Any, Optional, Union, cast
1✔
8

9
from haystack import logging, tracing
1✔
10
from haystack.components.generators.chat.types import ChatGenerator
1✔
11
from haystack.components.tools import ToolInvoker
1✔
12
from haystack.core.component.component import component
1✔
13
from haystack.core.errors import PipelineRuntimeError
1✔
14
from haystack.core.pipeline.async_pipeline import AsyncPipeline
1✔
15
from haystack.core.pipeline.breakpoint import (
1✔
16
    _create_pipeline_snapshot_from_chat_generator,
17
    _create_pipeline_snapshot_from_tool_invoker,
18
    _trigger_chat_generator_breakpoint,
19
    _trigger_tool_invoker_breakpoint,
20
    _validate_tool_breakpoint_is_valid,
21
)
22
from haystack.core.pipeline.pipeline import Pipeline
1✔
23
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
1✔
24
from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict
1✔
25
from haystack.dataclasses import ChatMessage, ChatRole
1✔
26
from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, PipelineSnapshot, ToolBreakpoint
1✔
27
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback
1✔
28
from haystack.tools import (
1✔
29
    Tool,
30
    Toolset,
31
    ToolsType,
32
    deserialize_tools_or_toolset_inplace,
33
    flatten_tools_or_toolsets,
34
    serialize_tools_or_toolset,
35
)
36
from haystack.utils import _deserialize_value_with_schema
1✔
37
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
38
from haystack.utils.deserialization import deserialize_chatgenerator_inplace
1✔
39

40
from .state.state import State, _schema_from_dict, _schema_to_dict, _validate_schema
1✔
41
from .state.state_utils import merge_lists
1✔
42

43
logger = logging.getLogger(__name__)
1✔
44

45

46
@dataclass
class _ExecutionContext:
    """
    Context for executing the agent.

    :param state: The current state of the agent, including messages and any additional data.
    :param component_visits: A dictionary tracking how many times each component has been visited.
        Keys are the internal component names "chat_generator" and "tool_invoker".
    :param chat_generator_inputs: Runtime inputs to be passed to the chat generator.
    :param tool_invoker_inputs: Runtime inputs to be passed to the tool invoker.
    :param counter: A counter to track the number of steps taken in the agent's run.
    :param skip_chat_generator: A flag to indicate whether to skip the chat generator in the next iteration.
        This is useful when resuming from a ToolBreakpoint where the ToolInvoker needs to be called first.
    """

    state: State
    component_visits: dict
    chat_generator_inputs: dict
    tool_invoker_inputs: dict
    counter: int = 0
    skip_chat_generator: bool = False
66

67

68
@component
1✔
69
class Agent:
1✔
70
    """
71
    A Haystack component that implements a tool-using agent with provider-agnostic chat model support.
72

73
    The component processes messages and executes tools until an exit condition is met.
74
    The exit condition can be triggered either by a direct text response or by invoking a specific designated tool.
75
    Multiple exit conditions can be specified.
76

77
    When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits.
78

79
    ### Usage example
80
    ```python
81
    from haystack.components.agents import Agent
82
    from haystack.components.generators.chat import OpenAIChatGenerator
83
    from haystack.dataclasses import ChatMessage
84
    from haystack.tools.tool import Tool
85

86
    tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")]
87

88
    agent = Agent(
89
        chat_generator=OpenAIChatGenerator(),
90
        tools=tools,
91
        exit_conditions=["search"],
92
    )
93

94
    # Run the agent
95
    result = agent.run(
96
        messages=[ChatMessage.from_user("Find information about Haystack")]
97
    )
98

99
    assert "messages" in result  # Contains conversation history
100
    ```
101
    """
102

103
    def __init__(
        self,
        *,
        chat_generator: ChatGenerator,
        tools: Optional[ToolsType] = None,
        system_prompt: Optional[str] = None,
        exit_conditions: Optional[list[str]] = None,
        state_schema: Optional[dict[str, Any]] = None,
        max_agent_steps: int = 100,
        streaming_callback: Optional[StreamingCallbackT] = None,
        raise_on_tool_invocation_failure: bool = False,
        tool_invoker_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the agent component.

        :param chat_generator: An instance of the chat generator that your agent should use. It must support tools.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset that the agent can use.
        :param system_prompt: System prompt for the agent.
        :param exit_conditions: List of conditions that will cause the agent to return.
            Can include "text" if the agent should return when it generates a message without tool calls,
            or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
        :param state_schema: The schema for the runtime state used by the tools.
        :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100.
            If the agent exceeds this number of steps, it will stop and return the current state.
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
        :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
            If set to False, the exception will be turned into a chat message and passed to the LLM.
        :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
        :raises TypeError: If the chat_generator does not support tools parameter in its run method.
        :raises ValueError: If the exit_conditions are not valid.
        """
        # Check if chat_generator supports tools parameter
        chat_generator_run_method = inspect.signature(chat_generator.run)
        if "tools" not in chat_generator_run_method.parameters:
            raise TypeError(
                f"{type(chat_generator).__name__} does not accept tools parameter in its run method. "
                "The Agent component requires a chat generator that supports tools."
            )

        # "text" is always a legal exit condition; the others are the names of the flattened tools.
        valid_exits = ["text"] + [tool.name for tool in flatten_tools_or_toolsets(tools)]
        if exit_conditions is None:
            exit_conditions = ["text"]
        if not all(condition in valid_exits for condition in exit_conditions):
            raise ValueError(
                f"Invalid exit conditions provided: {exit_conditions}. "
                f"Valid exit conditions must be a subset of {valid_exits}. "
                "Ensure that each exit condition corresponds to either 'text' or a valid tool name."
            )

        # Validate state schema if provided
        if state_schema is not None:
            _validate_schema(state_schema)
        # Keep the schema exactly as the user provided it; to_dict() serializes this one.
        self._state_schema = state_schema or {}

        # Initialize state schema: the resolved copy always contains a "messages" entry.
        resolved_state_schema = _deepcopy_with_exceptions(self._state_schema)
        if resolved_state_schema.get("messages") is None:
            resolved_state_schema["messages"] = {"type": list[ChatMessage], "handler": merge_lists}
        self.state_schema = resolved_state_schema

        self.chat_generator = chat_generator
        self.tools = tools or []
        self.system_prompt = system_prompt
        self.exit_conditions = exit_conditions
        self.max_agent_steps = max_agent_steps
        self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
        self.streaming_callback = streaming_callback

        # Expose every state-schema entry as a component output, and (unless it clashes with an
        # existing run() parameter) also as an optional input.
        output_types = {"last_message": ChatMessage}
        for param, config in self.state_schema.items():
            output_types[param] = config["type"]
            # Skip setting input types for parameters that are already in the run method
            if param in ["messages", "streaming_callback"]:
                continue
            component.set_input_type(self, name=param, type=config["type"], default=None)
        component.set_output_types(self, **output_types)

        self.tool_invoker_kwargs = tool_invoker_kwargs
        self._tool_invoker = None
        if self.tools:
            # User-supplied tool_invoker_kwargs take precedence over the Agent-derived defaults.
            resolved_tool_invoker_kwargs = {
                "tools": self.tools,
                "raise_on_failure": self.raise_on_tool_invocation_failure,
                **(tool_invoker_kwargs or {}),
            }
            self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs)
        else:
            logger.warning(
                "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text "
                "responses. To enable tool usage, pass tools directly to the Agent, not to the chat_generator."
            )

        self._is_warmed_up = False
198

199
    def warm_up(self) -> None:
        """
        Warm up the Agent.

        Delegates to the chat generator's and tool invoker's own ``warm_up`` methods when they
        exist. The work is performed at most once; subsequent calls are no-ops.
        """
        if self._is_warmed_up:
            return
        if hasattr(self.chat_generator, "warm_up"):
            self.chat_generator.warm_up()
        # Check for None *before* hasattr: the previous ordering only worked by accident,
        # because hasattr(None, "warm_up") happens to be False.
        if self._tool_invoker is not None and hasattr(self._tool_invoker, "warm_up"):
            self._tool_invoker.warm_up()
        self._is_warmed_up = True
209

210
    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the component to a dictionary.

        :return: Dictionary with serialized data
        """
        serialized_callback = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        init_parameters = {
            "chat_generator": component_to_dict(obj=self.chat_generator, name="chat_generator"),
            "tools": serialize_tools_or_toolset(self.tools),
            "system_prompt": self.system_prompt,
            "exit_conditions": self.exit_conditions,
            # We serialize the original state schema, not the resolved one, so the user's input round-trips.
            "state_schema": _schema_to_dict(self._state_schema),
            "max_agent_steps": self.max_agent_steps,
            "streaming_callback": serialized_callback,
            "raise_on_tool_invocation_failure": self.raise_on_tool_invocation_failure,
            "tool_invoker_kwargs": self.tool_invoker_kwargs,
        }
        return default_to_dict(self, **init_parameters)
229

230
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Agent":
        """
        Deserialize the agent from a dictionary.

        :param data: Dictionary to deserialize from
        :return: Deserialized agent
        """
        init_params = data.get("init_parameters", {})

        # Rebuild nested objects in place before handing off to the default deserializer.
        deserialize_chatgenerator_inplace(init_params, key="chat_generator")

        if "state_schema" in init_params:
            init_params["state_schema"] = _schema_from_dict(init_params["state_schema"])

        serialized_callback = init_params.get("streaming_callback")
        if serialized_callback is not None:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback)

        deserialize_tools_or_toolset_inplace(init_params, key="tools")

        return default_from_dict(cls, data)
251

252
    def _create_agent_span(self) -> Any:
        """Create and return a tracing span covering a single agent run."""
        span_tags = {
            "haystack.agent.max_steps": self.max_agent_steps,
            "haystack.agent.tools": self.tools,
            "haystack.agent.exit_conditions": self.exit_conditions,
            "haystack.agent.state_schema": _schema_to_dict(self.state_schema),
        }
        return tracing.tracer.trace("haystack.agent.run", tags=span_tags)
263

264
    def _initialize_fresh_execution(
        self,
        messages: list[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT],
        requires_async: bool,
        *,
        system_prompt: Optional[str] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[ToolsType, list[str]]] = None,
        **kwargs,
    ) -> _ExecutionContext:
        """
        Build the execution context for a run that starts from scratch (no snapshot).

        :param messages: List of ChatMessage objects to start the agent with.
        :param streaming_callback: Optional callback for streaming responses.
        :param requires_async: Whether the agent run requires asynchronous execution.
        :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
        :param generation_kwargs: Additional keyword arguments for chat generator. These parameters will
            override the parameters passed during component initialization.
        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
            When passing tool names, tools are selected from the Agent's originally configured tools.
        :param kwargs: Additional data to pass to the State used by the Agent.
        """
        # A runtime system prompt takes precedence; a falsy one falls back to the configured default.
        effective_system_prompt = system_prompt or self.system_prompt
        if effective_system_prompt is not None:
            messages = [ChatMessage.from_system(effective_system_prompt), *messages]

        if not any(not m.is_from(ChatRole.SYSTEM) for m in messages):
            logger.warning("All messages provided to the Agent component are system messages. This is not recommended.")

        state = State(schema=self.state_schema, data=kwargs)
        state.set("messages", messages)

        resolved_callback = select_streaming_callback(  # type: ignore[call-overload]
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=requires_async
        )

        selected_tools = self._select_tools(tools)
        tool_invoker_inputs: dict[str, Any] = {"tools": selected_tools}
        generator_inputs: dict[str, Any] = {"tools": selected_tools}
        if resolved_callback is not None:
            tool_invoker_inputs["streaming_callback"] = resolved_callback
            generator_inputs["streaming_callback"] = resolved_callback
        if generation_kwargs is not None:
            generator_inputs["generation_kwargs"] = generation_kwargs

        return _ExecutionContext(
            state=state,
            component_visits={"chat_generator": 0, "tool_invoker": 0},
            chat_generator_inputs=generator_inputs,
            tool_invoker_inputs=tool_invoker_inputs,
        )
317

318
    def _select_tools(self, tools: Optional[Union[ToolsType, list[str]]] = None) -> ToolsType:
        """
        Resolve which tools the current run should use.

        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
            When passing tool names, tools are selected from the Agent's originally configured tools.
        :returns: Selected tools for the current run.
        :raises ValueError: If tool names are provided but no tools were configured at initialization,
            or if any provided tool name is not valid.
        :raises TypeError: If tools is not a list of Tool objects, a Toolset, or a list of tool names (strings).
        """
        if tools is None:
            # No override: fall back to the tools configured at initialization.
            return self.tools

        is_name_list = isinstance(tools, list) and all(isinstance(entry, str) for entry in tools)
        if is_name_list:
            if not self.tools:
                raise ValueError("No tools were configured for the Agent at initialization.")
            available_tools = flatten_tools_or_toolsets(self.tools)
            requested_names = cast(list[str], tools)  # mypy thinks this could still be list[Tool] or Toolset
            valid_tool_names = {tool.name for tool in available_tools}
            invalid_tool_names = {name for name in requested_names if name not in valid_tool_names}
            if invalid_tool_names:
                raise ValueError(
                    f"The following tool names are not valid: {invalid_tool_names}. "
                    f"Valid tool names are: {valid_tool_names}."
                )
            return [tool for tool in available_tools if tool.name in requested_names]

        if isinstance(tools, Toolset):
            return tools

        if isinstance(tools, list):
            return cast(list[Union[Tool, Toolset]], tools)  # mypy can't narrow the Union type from isinstance check

        raise TypeError(
            "tools must be a list of Tool and/or Toolset objects, a Toolset, or a list of tool names (strings)."
        )
355

356
    def _initialize_from_snapshot(
        self,
        snapshot: AgentSnapshot,
        streaming_callback: Optional[StreamingCallbackT],
        requires_async: bool,
        *,
        generation_kwargs: Optional[dict[str, Any]] = None,
        tools: Optional[Union[ToolsType, list[str]]] = None,
    ) -> _ExecutionContext:
        """
        Rebuild the execution context from a previously captured AgentSnapshot.

        :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution.
        :param streaming_callback: Optional callback for streaming responses.
        :param requires_async: Whether the agent run requires asynchronous execution.
        :param generation_kwargs: Additional keyword arguments for chat generator. These parameters will
            override the parameters passed during component initialization.
        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
            When passing tool names, tools are selected from the Agent's originally configured tools.
        """
        # Restore the serialized inputs of both wrapped components.
        restored_generator_inputs = _deserialize_value_with_schema(snapshot.component_inputs["chat_generator"])
        restored_invoker_inputs = _deserialize_value_with_schema(snapshot.component_inputs["tool_invoker"])

        state = State(schema=self.state_schema, data=restored_invoker_inputs["state"].data)

        inner_break_point = snapshot.break_point.break_point
        # Resuming from a ToolBreakpoint means the tool invoker must run before the chat generator.
        skip_chat_generator = isinstance(inner_break_point, ToolBreakpoint)

        # Prefer the callback captured in the snapshot over the runtime one, then validate sync/async.
        restored_callback = restored_generator_inputs.get("streaming_callback", streaming_callback)
        resolved_callback = select_streaming_callback(  # type: ignore[call-overload]
            init_callback=self.streaming_callback, runtime_callback=restored_callback, requires_async=requires_async
        )

        selected_tools = self._select_tools(tools)
        tool_invoker_inputs: dict[str, Any] = {"tools": selected_tools}
        generator_inputs: dict[str, Any] = {"tools": selected_tools}
        if resolved_callback is not None:
            tool_invoker_inputs["streaming_callback"] = resolved_callback
            generator_inputs["streaming_callback"] = resolved_callback
        if generation_kwargs is not None:
            generator_inputs["generation_kwargs"] = generation_kwargs

        return _ExecutionContext(
            state=state,
            component_visits=snapshot.component_visits,
            chat_generator_inputs=generator_inputs,
            tool_invoker_inputs=tool_invoker_inputs,
            counter=inner_break_point.visit_count,
            skip_chat_generator=skip_chat_generator,
        )
408

409
    def _runtime_checks(self, break_point: Optional[AgentBreakpoint], snapshot: Optional[AgentSnapshot]) -> None:
        """
        Validate preconditions before a run starts.

        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
            for "tool_invoker".
        :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution.
        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
        :raises ValueError: If the break_point is invalid.
        """
        # Warm-up is only mandatory when the chat generator actually exposes a warm_up method.
        needs_warm_up = hasattr(self.chat_generator, "warm_up")
        if needs_warm_up and not self._is_warmed_up:
            raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")

        if break_point is not None and isinstance(break_point.break_point, ToolBreakpoint):
            _validate_tool_breakpoint_is_valid(agent_breakpoint=break_point, tools=self.tools)
424

425
    @staticmethod
    def _check_chat_generator_breakpoint(
        execution_context: _ExecutionContext,
        break_point: Optional[AgentBreakpoint],
        parent_snapshot: Optional[PipelineSnapshot],
    ) -> None:
        """
        Trigger the chat generator breakpoint if the current visit count matches it.

        When it matches, a pipeline snapshot is created from the current execution context and the
        chat generator breakpoint is raised.

        :param execution_context: The current execution context of the agent.
        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
            for "tool_invoker".
        :param parent_snapshot: An optional parent snapshot for the agent execution.
        """
        if break_point is None:
            return
        inner_break_point = break_point.break_point
        if inner_break_point.component_name != "chat_generator":
            return
        if execution_context.component_visits["chat_generator"] != inner_break_point.visit_count:
            return

        pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
            execution_context=execution_context, break_point=break_point, parent_snapshot=parent_snapshot
        )
        _trigger_chat_generator_breakpoint(pipeline_snapshot=pipeline_snapshot)
450

451
    @staticmethod
    def _check_tool_invoker_breakpoint(
        execution_context: _ExecutionContext,
        break_point: Optional[AgentBreakpoint],
        parent_snapshot: Optional[PipelineSnapshot],
    ) -> None:
        """
        Trigger the tool invoker breakpoint if the current visit count matches it.

        When it matches, a pipeline snapshot is created from the current execution context and the
        tool invoker breakpoint is raised with the most recent LLM message.

        :param execution_context: The current execution context of the agent.
        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
            for "tool_invoker".
        :param parent_snapshot: An optional parent snapshot for the agent execution.
        """
        if break_point is None:
            return
        inner_break_point = break_point.break_point
        should_trigger = (
            inner_break_point.component_name == "tool_invoker"
            and inner_break_point.visit_count == execution_context.component_visits["tool_invoker"]
        )
        if not should_trigger:
            return

        pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
            execution_context=execution_context, break_point=break_point, parent_snapshot=parent_snapshot
        )
        _trigger_tool_invoker_breakpoint(
            llm_messages=execution_context.state.data["messages"][-1:], pipeline_snapshot=pipeline_snapshot
        )
478

479
    def run(  # noqa: PLR0915
1✔
480
        self,
481
        messages: list[ChatMessage],
482
        streaming_callback: Optional[StreamingCallbackT] = None,
483
        *,
484
        generation_kwargs: Optional[dict[str, Any]] = None,
485
        break_point: Optional[AgentBreakpoint] = None,
486
        snapshot: Optional[AgentSnapshot] = None,
487
        system_prompt: Optional[str] = None,
488
        tools: Optional[Union[ToolsType, list[str]]] = None,
489
        **kwargs: Any,
490
    ) -> dict[str, Any]:
491
        """
492
        Process messages and execute tools until an exit condition is met.
493

494
        :param messages: List of Haystack ChatMessage objects to process.
495
        :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
496
            The same callback can be configured to emit tool results when a tool is called.
497
        :param generation_kwargs: Additional keyword arguments for LLM. These parameters will
498
            override the parameters passed during component initialization.
499
        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
500
            for "tool_invoker".
501
        :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
502
            the relevant information to restart the Agent execution from where it left off.
503
        :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
504
        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
505
            When passing tool names, tools are selected from the Agent's originally configured tools.
506
        :param kwargs: Additional data to pass to the State schema used by the Agent.
507
            The keys must match the schema defined in the Agent's `state_schema`.
508
        :returns:
509
            A dictionary with the following keys:
510
            - "messages": List of all messages exchanged during the agent's run.
511
            - "last_message": The last message exchanged during the agent's run.
512
            - Any additional keys defined in the `state_schema`.
513
        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
514
        :raises BreakpointException: If an agent breakpoint is triggered.
515
        """
516
        # We pop parent_snapshot from kwargs to avoid passing it into State.
517
        parent_snapshot = kwargs.pop("parent_snapshot", None)
1✔
518
        agent_inputs = {
1✔
519
            "messages": messages,
520
            "streaming_callback": streaming_callback,
521
            "break_point": break_point,
522
            "snapshot": snapshot,
523
            **kwargs,
524
        }
525
        self._runtime_checks(break_point=break_point, snapshot=snapshot)
1✔
526

527
        if snapshot:
1✔
528
            exe_context = self._initialize_from_snapshot(
1✔
529
                snapshot=snapshot,
530
                streaming_callback=streaming_callback,
531
                requires_async=False,
532
                tools=tools,
533
                generation_kwargs=generation_kwargs,
534
            )
535
        else:
536
            exe_context = self._initialize_fresh_execution(
1✔
537
                messages=messages,
538
                streaming_callback=streaming_callback,
539
                requires_async=False,
540
                system_prompt=system_prompt,
541
                tools=tools,
542
                generation_kwargs=generation_kwargs,
543
                **kwargs,
544
            )
545

546
        with self._create_agent_span() as span:
1✔
547
            span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))
1✔
548

549
            while exe_context.counter < self.max_agent_steps:
1✔
550
                # Handle breakpoint and ChatGenerator call
551
                Agent._check_chat_generator_breakpoint(
1✔
552
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
553
                )
554
                # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
555
                if exe_context.skip_chat_generator:
1✔
556
                    llm_messages = exe_context.state.get("messages", [])[-1:]
1✔
557
                    # Set to False so the next iteration will call the chat generator
558
                    exe_context.skip_chat_generator = False
1✔
559
                else:
560
                    try:
1✔
561
                        result = Pipeline._run_component(
1✔
562
                            component_name="chat_generator",
563
                            component={"instance": self.chat_generator},
564
                            inputs={
565
                                "messages": exe_context.state.data["messages"],
566
                                **exe_context.chat_generator_inputs,
567
                            },
568
                            component_visits=exe_context.component_visits,
569
                            parent_span=span,
570
                        )
571
                    except PipelineRuntimeError as e:
1✔
572
                        pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
1✔
573
                            agent_name=getattr(self, "__component_name__", None),
574
                            execution_context=exe_context,
575
                            parent_snapshot=parent_snapshot,
576
                        )
577
                        e.pipeline_snapshot = pipeline_snapshot
1✔
578
                        raise e
1✔
579

580
                    llm_messages = result["replies"]
1✔
581
                    exe_context.state.set("messages", llm_messages)
1✔
582

583
                # Check if any of the LLM responses contain a tool call or if the LLM is not using tools
584
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
1✔
585
                    exe_context.counter += 1
1✔
586
                    break
1✔
587

588
                # Handle breakpoint and ToolInvoker call
589
                Agent._check_tool_invoker_breakpoint(
1✔
590
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
591
                )
592
                try:
1✔
593
                    # We only send the messages from the LLM to the tool invoker
594
                    tool_invoker_result = Pipeline._run_component(
1✔
595
                        component_name="tool_invoker",
596
                        component={"instance": self._tool_invoker},
597
                        inputs={
598
                            "messages": llm_messages,
599
                            "state": exe_context.state,
600
                            **exe_context.tool_invoker_inputs,
601
                        },
602
                        component_visits=exe_context.component_visits,
603
                        parent_span=span,
604
                    )
605
                except PipelineRuntimeError as e:
1✔
606
                    # Access the original Tool Invoker exception
607
                    original_error = e.__cause__
1✔
608
                    tool_name = getattr(original_error, "tool_name", None)
1✔
609

610
                    pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
1✔
611
                        tool_name=tool_name,
612
                        agent_name=getattr(self, "__component_name__", None),
613
                        execution_context=exe_context,
614
                        parent_snapshot=parent_snapshot,
615
                    )
616
                    e.pipeline_snapshot = pipeline_snapshot
1✔
617
                    raise e
1✔
618

619
                tool_messages = tool_invoker_result["tool_messages"]
1✔
620
                exe_context.state = tool_invoker_result["state"]
1✔
621
                exe_context.state.set("messages", tool_messages)
1✔
622

623
                # Check if any LLM message's tool call name matches an exit condition
624
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
1✔
625
                    exe_context.counter += 1
1✔
626
                    break
1✔
627

628
                # Increment the step counter
629
                exe_context.counter += 1
1✔
630

631
            if exe_context.counter >= self.max_agent_steps:
1✔
632
                logger.warning(
1✔
633
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
634
                    max_agent_steps=self.max_agent_steps,
635
                )
636
            span.set_content_tag("haystack.agent.output", exe_context.state.data)
1✔
637
            span.set_tag("haystack.agent.steps_taken", exe_context.counter)
1✔
638

639
        result = {**exe_context.state.data}
1✔
640
        if msgs := result.get("messages"):
1✔
641
            result["last_message"] = msgs[-1]
1✔
642
        return result
1✔
643

644
    async def run_async(
        self,
        messages: list[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        generation_kwargs: Optional[dict[str, Any]] = None,
        break_point: Optional[AgentBreakpoint] = None,
        snapshot: Optional[AgentSnapshot] = None,
        system_prompt: Optional[str] = None,
        tools: Optional[Union[ToolsType, list[str]]] = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Asynchronously process messages and execute tools until the exit condition is met.

        This is the asynchronous version of the `run` method. It follows the same logic but uses
        asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator
        if available.

        :param messages: List of Haystack ChatMessage objects to process.
        :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the
            LLM. The same callback can be configured to emit tool results when a tool is called.
        :param generation_kwargs: Additional keyword arguments for LLM. These parameters will
            override the parameters passed during component initialization.
        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
            for "tool_invoker".
        :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
            the relevant information to restart the Agent execution from where it left off.
        :param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
        :param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
        :param kwargs: Additional data to pass to the State schema used by the Agent.
            The keys must match the schema defined in the Agent's `state_schema`.
        :returns:
            A dictionary with the following keys:
            - "messages": List of all messages exchanged during the agent's run.
            - "last_message": The last message exchanged during the agent's run.
            - Any additional keys defined in the `state_schema`.
        :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
        :raises BreakpointException: If an agent breakpoint is triggered.
        """
        # We pop parent_snapshot from kwargs to avoid passing it into State.
        parent_snapshot = kwargs.pop("parent_snapshot", None)
        agent_inputs = {
            "messages": messages,
            "streaming_callback": streaming_callback,
            "break_point": break_point,
            "snapshot": snapshot,
            **kwargs,
        }
        self._runtime_checks(break_point=break_point, snapshot=snapshot)

        if snapshot:
            exe_context = self._initialize_from_snapshot(
                snapshot=snapshot,
                streaming_callback=streaming_callback,
                requires_async=True,
                tools=tools,
                generation_kwargs=generation_kwargs,
            )
        else:
            exe_context = self._initialize_fresh_execution(
                messages=messages,
                streaming_callback=streaming_callback,
                requires_async=True,
                system_prompt=system_prompt,
                tools=tools,
                generation_kwargs=generation_kwargs,
                **kwargs,
            )

        with self._create_agent_span() as span:
            span.set_content_tag("haystack.agent.input", _deepcopy_with_exceptions(agent_inputs))

            while exe_context.counter < self.max_agent_steps:
                # Handle breakpoint and ChatGenerator call
                self._check_chat_generator_breakpoint(
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
                )
                # We skip the chat generator when restarting from a snapshot from a ToolBreakpoint
                if exe_context.skip_chat_generator:
                    llm_messages = exe_context.state.get("messages", [])[-1:]
                    # Set to False so the next iteration will call the chat generator
                    exe_context.skip_chat_generator = False
                else:
                    try:
                        result = await AsyncPipeline._run_component_async(
                            component_name="chat_generator",
                            component={"instance": self.chat_generator},
                            component_inputs={
                                "messages": exe_context.state.data["messages"],
                                **exe_context.chat_generator_inputs,
                            },
                            component_visits=exe_context.component_visits,
                            parent_span=span,
                        )
                    except PipelineRuntimeError as e:
                        # Mirror the sync `run`: attach a pipeline snapshot to the error so the
                        # failed execution can be resumed from where it left off.
                        pipeline_snapshot = _create_pipeline_snapshot_from_chat_generator(
                            agent_name=getattr(self, "__component_name__", None),
                            execution_context=exe_context,
                            parent_snapshot=parent_snapshot,
                        )
                        e.pipeline_snapshot = pipeline_snapshot
                        raise e

                    llm_messages = result["replies"]
                    exe_context.state.set("messages", llm_messages)

                # Check if any of the LLM responses contain a tool call or if the LLM is not using tools
                if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
                    exe_context.counter += 1
                    break

                # Handle breakpoint and ToolInvoker call
                self._check_tool_invoker_breakpoint(
                    execution_context=exe_context, break_point=break_point, parent_snapshot=parent_snapshot
                )
                try:
                    # We only send the messages from the LLM to the tool invoker
                    tool_invoker_result = await AsyncPipeline._run_component_async(
                        component_name="tool_invoker",
                        component={"instance": self._tool_invoker},
                        component_inputs={
                            "messages": llm_messages,
                            "state": exe_context.state,
                            **exe_context.tool_invoker_inputs,
                        },
                        component_visits=exe_context.component_visits,
                        parent_span=span,
                    )
                except PipelineRuntimeError as e:
                    # Access the original Tool Invoker exception
                    original_error = e.__cause__
                    tool_name = getattr(original_error, "tool_name", None)

                    # Mirror the sync `run`: attach a snapshot identifying the failing tool.
                    pipeline_snapshot = _create_pipeline_snapshot_from_tool_invoker(
                        tool_name=tool_name,
                        agent_name=getattr(self, "__component_name__", None),
                        execution_context=exe_context,
                        parent_snapshot=parent_snapshot,
                    )
                    e.pipeline_snapshot = pipeline_snapshot
                    raise e

                tool_messages = tool_invoker_result["tool_messages"]
                exe_context.state = tool_invoker_result["state"]
                exe_context.state.set("messages", tool_messages)

                # Check if any LLM message's tool call name matches an exit condition
                if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
                    exe_context.counter += 1
                    break

                # Increment the step counter
                exe_context.counter += 1

            if exe_context.counter >= self.max_agent_steps:
                logger.warning(
                    "Agent reached maximum agent steps of {max_agent_steps}, stopping.",
                    max_agent_steps=self.max_agent_steps,
                )
            span.set_content_tag("haystack.agent.output", exe_context.state.data)
            span.set_tag("haystack.agent.steps_taken", exe_context.counter)

        result = {**exe_context.state.data}
        if msgs := result.get("messages"):
            result["last_message"] = msgs[-1]
        return result
786

787
    def _check_exit_conditions(self, llm_messages: list[ChatMessage], tool_messages: list[ChatMessage]) -> bool:
1✔
788
        """
789
        Check if any of the LLM messages' tool calls match an exit condition and if there are no errors.
790

791
        :param llm_messages: List of messages from the LLM
792
        :param tool_messages: List of messages from the tool invoker
793
        :return: True if an exit condition is met and there are no errors, False otherwise
794
        """
795
        matched_exit_conditions = set()
1✔
796
        has_errors = False
1✔
797

798
        for msg in llm_messages:
1✔
799
            if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions:
1✔
800
                matched_exit_conditions.add(msg.tool_call.tool_name)
1✔
801

802
                # Check if any error is specifically from the tool matching the exit condition
803
                tool_errors = [
1✔
804
                    tool_msg.tool_call_result.error
805
                    for tool_msg in tool_messages
806
                    if tool_msg.tool_call_result is not None
807
                    and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name
808
                ]
809
                if any(tool_errors):
1✔
810
                    has_errors = True
×
811
                    # No need to check further if we found an error
812
                    break
×
813

814
        # Only return True if at least one exit condition was matched AND none had errors
815
        return bool(matched_exit_conditions) and not has_errors
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc