• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 17578922709

09 Sep 2025 09:57AM UTC coverage: 92.063% (-0.003%) from 92.066%
17578922709

Pull #9754

github

web-flow
Merge 868faf198 into 34f1a0412
Pull Request #9754: feat: support structured outputs in `OpenAIChatGenerator`

12992 of 14112 relevant lines covered (92.06%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.67
haystack/core/pipeline/breakpoint.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import json
1✔
6
from copy import deepcopy
1✔
7
from dataclasses import replace
1✔
8
from datetime import datetime
1✔
9
from pathlib import Path
1✔
10
from typing import TYPE_CHECKING, Any, Optional, Union
1✔
11

12
from networkx import MultiDiGraph
1✔
13

14
from haystack import logging
1✔
15
from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError
1✔
16
from haystack.dataclasses import ChatMessage
1✔
17
from haystack.dataclasses.breakpoints import (
1✔
18
    AgentBreakpoint,
19
    AgentSnapshot,
20
    Breakpoint,
21
    PipelineSnapshot,
22
    PipelineState,
23
    ToolBreakpoint,
24
)
25
from haystack.utils.base_serialization import _serialize_value_with_schema
1✔
26

27
if TYPE_CHECKING:
28
    from haystack.tools.tool import Tool
29
    from haystack.tools.toolset import Toolset
30

31
# Module-level logger, obtained from haystack's logging module (`from haystack import logging` above)
# rather than directly from the standard library.
logger = logging.getLogger(__name__)
1✔
32

33

34
def _validate_break_point_against_pipeline(
    break_point: Union[Breakpoint, AgentBreakpoint], graph: MultiDiGraph
) -> None:
    """
    Validates the breakpoints passed to the pipeline.

    Makes sure the breakpoint refers to a component registered in the pipeline. For an AgentBreakpoint, the named
    Agent must be a node of the pipeline graph. If the AgentBreakpoint wraps a ToolBreakpoint that names a specific
    tool, that tool must be one of the Agent's tools. A ToolBreakpoint whose ``tool_name`` is None (break on any
    tool call) does not target a particular tool and is accepted.

    :param break_point: a breakpoint to validate, can be Breakpoint or AgentBreakpoint
    :param graph: the pipeline graph; its node names are the registered component names
    :raises ValueError: If the component, Agent, or tool referenced by the breakpoint is not registered.
    """
    # all Breakpoints must refer to a valid component in the pipeline
    if isinstance(break_point, Breakpoint) and break_point.component_name not in graph.nodes:
        raise ValueError(f"break_point {break_point} is not a registered component in the pipeline")

    if not isinstance(break_point, AgentBreakpoint):
        return

    breakpoint_agent_component = graph.nodes.get(break_point.agent_name)
    if not breakpoint_agent_component:
        raise ValueError(f"break_point {break_point} is not a registered Agent component in the pipeline")

    tool_breakpoint = break_point.break_point
    # Only a ToolBreakpoint that targets a named tool has something to check. tool_name=None means "any tool",
    # the same convention followed by _validate_tool_breakpoint_is_valid and _handle_tool_invoker_breakpoint.
    if not isinstance(tool_breakpoint, ToolBreakpoint) or tool_breakpoint.tool_name is None:
        return

    # `or []` guards against an Agent whose `tools` attribute is None (it would not be iterable).
    agent_tools = breakpoint_agent_component["instance"].tools or []
    if tool_breakpoint.tool_name not in {tool.name for tool in agent_tools}:
        raise ValueError(f"break_point {break_point.break_point} is not a registered tool in the Agent component")
63

64

65
def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnapshot, graph: MultiDiGraph) -> None:
    """
    Validates that the pipeline_snapshot contains valid configuration for the current pipeline.

    Raises a PipelineInvalidPipelineSnapshotError if any component named in the snapshot's
    'ordered_component_names', in its serialized original input data, or in its 'component_visits' is not part
    of the target pipeline. On success, logs the component the pipeline resumes from and its visit count.

    :param pipeline_snapshot: The saved state to validate.
    :param graph: The graph of the current pipeline; its node names are the valid component names.
    :raises PipelineInvalidPipelineSnapshotError: If the snapshot references components missing from the pipeline.
    """
    valid_components = set(graph.nodes.keys())

    def _check_known_components(component_names: Any, field_name: str) -> None:
        # Shared check for every snapshot field that references pipeline components by name.
        invalid_components = set(component_names) - valid_components
        if invalid_components:
            raise PipelineInvalidPipelineSnapshotError(
                f"Invalid pipeline snapshot: components {invalid_components} in '{field_name}' "
                f"are not part of the current pipeline."
            )

    # The checks run in this order, and each field is only read once the previous check has passed.
    _check_known_components(pipeline_snapshot.ordered_component_names, "ordered_component_names")
    # The original input data is stored serialized; its 'serialized_data' mapping is keyed by component name.
    _check_known_components(pipeline_snapshot.original_input_data["serialized_data"].keys(), "input_data")
    _check_known_components(pipeline_snapshot.pipeline_state.component_visits.keys(), "component_visits")

    if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
        component_name = pipeline_snapshot.break_point.agent_name
    else:
        component_name = pipeline_snapshot.break_point.component_name

    # The visit count only feeds the log message below. Use .get so that a snapshot which did not record this
    # component does not abort the resume with a KeyError (same default as _save_pipeline_snapshot_to_file).
    visit_count = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)

    logger.info(
        "Resuming pipeline from {component} with visit count {visits}", component=component_name, visits=visit_count
    )
113

114

115
def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot:
    """
    Load a saved pipeline snapshot.

    :param file_path: Path to the pipeline_snapshot file.
    :returns:
        The PipelineSnapshot deserialized from the file.
    :raises FileNotFoundError: If the file does not exist.
    :raises json.JSONDecodeError: If the file does not contain valid JSON.
    :raises IOError: If the file cannot be read for another OS-level reason.
    :raises ValueError: If the JSON content cannot be turned into a PipelineSnapshot.
    """
    file_path = Path(file_path)

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            pipeline_snapshot_dict = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"File not found: {file_path}") from e
    except json.JSONDecodeError as e:
        # Pass only the bare reason (e.msg). JSONDecodeError appends "line X column Y (char Z)" itself,
        # so using str(e) here would repeat the position in the message.
        raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {e.msg}", e.doc, e.pos) from e
    except IOError as e:
        raise IOError(f"Error reading {file_path}: {str(e)}") from e

    try:
        pipeline_snapshot = PipelineSnapshot.from_dict(pipeline_snapshot_dict)
    except ValueError as e:
        raise ValueError(f"Invalid pipeline snapshot from {file_path}: {str(e)}") from e

    logger.info("Successfully loaded the pipeline snapshot from: {file_path}", file_path=file_path)
    return pipeline_snapshot
143

144

145
def _save_pipeline_snapshot_to_file(
    *, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime
) -> None:
    """
    Save the pipeline snapshot dictionary to a JSON file.

    ``snapshot_file_path`` is treated as a directory: it is created (including missing parents) if needed, and the
    snapshot is written inside it as ``{component}_{visits}_{timestamp}.json``. For an AgentBreakpoint the name is
    ``{agent}_{component}_{visits}_{timestamp}.json``, where ``component`` is the Agent-internal component targeted
    by the breakpoint.

    :param pipeline_snapshot: The pipeline snapshot to save.
    :param snapshot_file_path: The directory where to save the file.
    :param dt: The datetime object for timestamping.
    :raises:
        ValueError: If the snapshot_file_path is not a string or a Path object.
        Exception: If saving the JSON snapshot fails.
    """
    snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path
    if not isinstance(snapshot_file_path, Path):
        raise ValueError("Debug path must be a string or a Path object.")

    # parents=True so that a nested directory that does not exist yet does not make the save fail.
    snapshot_file_path.mkdir(parents=True, exist_ok=True)

    # Generate filename
    timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S")
    break_point = pipeline_snapshot.break_point
    # We check if the breakpoint is an AgentBreakpoint to differentiate between agent and non-agent breakpoints
    if isinstance(break_point, AgentBreakpoint):
        component_name = break_point.break_point.component_name
        # The inner component (e.g. "chat_generator" / "tool_invoker") is counted in the Agent's own visit
        # counter, not in the pipeline's. Read it from the attached AgentSnapshot when there is one.
        agent_snapshot = pipeline_snapshot.agent_snapshot
        visits = (
            agent_snapshot.component_visits
            if agent_snapshot is not None
            else pipeline_snapshot.pipeline_state.component_visits
        )
        visit_nr = visits.get(component_name, 0)
        file_name = f"{break_point.agent_name}_{component_name}_{visit_nr}_{timestamp}.json"
    else:
        component_name = break_point.component_name
        visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
        file_name = f"{component_name}_{visit_nr}_{timestamp}.json"

    full_path = snapshot_file_path / file_name
    try:
        # Serialize fully before opening the file. A serialization error then cannot leave an empty or
        # partially written snapshot file behind.
        serialized_snapshot = json.dumps(pipeline_snapshot.to_dict(), indent=2)
        # Explicit UTF-8 to match the encoding load_pipeline_snapshot reads with.
        with open(full_path, "w", encoding="utf-8") as f_out:
            f_out.write(serialized_snapshot)
        logger.info("Pipeline snapshot saved at: {file_path}", file_path=full_path)
    except Exception as e:
        logger.error("Failed to save pipeline snapshot: {error}", error=str(e))
        raise
183

184

185
def _create_pipeline_snapshot(
    *,
    inputs: dict[str, Any],
    break_point: Union[AgentBreakpoint, Breakpoint],
    component_visits: dict[str, int],
    original_input_data: Optional[dict[str, Any]] = None,
    ordered_component_names: Optional[list[str]] = None,
    include_outputs_from: Optional[set[str]] = None,
    pipeline_outputs: Optional[dict[str, Any]] = None,
) -> PipelineSnapshot:
    """
    Create a snapshot of the pipeline at the point where the breakpoint was triggered.

    Both ``original_input_data`` and ``inputs`` are first flattened with ``_transform_json_structure``, which
    unwraps ``{"sender": ..., "value": ...}`` entries. They are then serialized together with their schema.

    :param inputs: The current pipeline inputs.
    :param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint.
    :param component_visits: The visit counts of the pipeline components, keyed by component name.
    :param original_input_data: The original input data passed to the pipeline.
    :param ordered_component_names: The ordered component names; an empty list when not given.
    :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results;
        an empty set when not given.
    :param pipeline_outputs: The pipeline outputs produced so far; an empty dict when not given.
    :returns: A new PipelineSnapshot stamped with the current time.
    """
    created_at = datetime.now()

    flat_original_input = _transform_json_structure(original_input_data)
    flat_current_inputs = _transform_json_structure(inputs)

    state = PipelineState(
        inputs=_serialize_value_with_schema(flat_current_inputs),  # current pipeline inputs
        component_visits=component_visits,
        pipeline_outputs=pipeline_outputs or {},
    )
    return PipelineSnapshot(
        pipeline_state=state,
        timestamp=created_at,
        break_point=break_point,
        original_input_data=_serialize_value_with_schema(flat_original_input),
        ordered_component_names=ordered_component_names or [],
        include_outputs_from=include_outputs_from or set(),
    )
223

224

225
def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot:
    """
    Save the pipeline snapshot to a file if its breakpoint asks for it.

    The target directory is the breakpoint's ``snapshot_file_path``. For an AgentBreakpoint it is taken from the
    breakpoint the AgentBreakpoint wraps. When that value is None, nothing is written. The file is timestamped with
    the snapshot's own timestamp, or with the current time if the snapshot has none.

    :param pipeline_snapshot: The pipeline snapshot to save.
    :returns:
        The same ``pipeline_snapshot`` object that was passed in.
    """
    break_point = pipeline_snapshot.break_point
    target_dir = (
        break_point.break_point.snapshot_file_path
        if isinstance(break_point, AgentBreakpoint)
        else break_point.snapshot_file_path
    )

    if target_dir is None:
        return pipeline_snapshot

    _save_pipeline_snapshot_to_file(
        pipeline_snapshot=pipeline_snapshot,
        snapshot_file_path=target_dir,
        dt=pipeline_snapshot.timestamp or datetime.now(),
    )
    return pipeline_snapshot
254

255

256
def _is_sender_value_pair(candidate: Any) -> bool:
    """Return True if ``candidate`` is a dict holding both a 'sender' and a 'value' key."""
    return isinstance(candidate, dict) and "value" in candidate and "sender" in candidate


def _transform_json_structure(data: Union[dict[str, Any], list[Any], Any]) -> Any:
    """
    Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level.

    For example:
    "key": [{"sender": null, "value": "some value"}] -> "key": "some value"

    A dict with both 'sender' and 'value' is replaced by its 'value', which is returned as-is without further
    transformation. Other dicts and lists are transformed recursively. A list whose single element is such a
    'sender'/'value' dict is unwrapped to that value. Any other value is returned unchanged.

    :param data: The JSON structure to transform.
    :returns: The transformed structure.
    """
    if _is_sender_value_pair(data):
        return data["value"]

    if isinstance(data, dict):
        return {key: _transform_json_structure(item) for key, item in data.items()}

    if not isinstance(data, list):
        return data

    converted = [_transform_json_structure(element) for element in data]
    # Unwrap only when the original single element was a 'sender'/'value' pair.
    if len(data) == 1 and _is_sender_value_pair(data[0]):
        return converted[0]
    return converted
284

285

286
def _trigger_break_point(*, pipeline_snapshot: PipelineSnapshot, pipeline_outputs: dict[str, Any]) -> None:
    """
    Trigger a breakpoint by saving a snapshot and raising an exception.

    :param pipeline_snapshot: The current pipeline snapshot containing the state and break point.
    :param pipeline_outputs: Current pipeline outputs, attached to the exception as ``results``.
    :raises BreakpointException: Always. It carries the name of the breaking component (the Agent's name for an
        AgentBreakpoint), the snapshot's pipeline inputs and the pipeline outputs.
    """
    _save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot)

    break_point = pipeline_snapshot.break_point
    component_name = break_point.component_name if isinstance(break_point, Breakpoint) else break_point.agent_name

    state = pipeline_snapshot.pipeline_state
    raise BreakpointException(
        message=f"Breaking at component {component_name} at visit count {state.component_visits[component_name]}",
        component=component_name,
        inputs=state.inputs,
        results=pipeline_outputs,
    )
306

307

308
def _create_agent_snapshot(
    *, component_visits: dict[str, int], agent_breakpoint: AgentBreakpoint, component_inputs: dict[str, Any]
) -> AgentSnapshot:
    """
    Create a snapshot of the agent's state.

    :param component_visits: The visit counts for the agent's components.
    :param agent_breakpoint: AgentBreakpoint object containing breakpoints.
    :param component_inputs: Inputs of the agent's components. The "chat_generator" and "tool_invoker" entries are
        required; each is deep-copied and serialized together with its schema.
    :returns: An AgentSnapshot holding the serialized component inputs, the visit counts, the breakpoint and the
        current time.
    """
    # Deep-copied before serialization, presumably so the snapshot does not alias objects the agent keeps using.
    serialized_inputs = {
        name: _serialize_value_with_schema(deepcopy(component_inputs[name]))
        for name in ("chat_generator", "tool_invoker")
    }
    return AgentSnapshot(
        component_inputs=serialized_inputs,
        component_visits=component_visits,
        break_point=agent_breakpoint,
        timestamp=datetime.now(),
    )
327

328

329
def _validate_tool_breakpoint_is_valid(
    agent_breakpoint: AgentBreakpoint, tools: Union[list["Tool"], "Toolset"]
) -> None:
    """
    Validates the AgentBreakpoint passed to the agent.

    Checks that a ToolBreakpoint naming a specific tool refers to one of the agent's tools. A ToolBreakpoint
    without a tool name (None or empty) is accepted without further checks.

    :param agent_breakpoint: AgentBreakpoint object whose ``break_point`` must be a ToolBreakpoint.
    :param tools: List of Tool objects or a Toolset that the agent can use.
    :raises ValueError: If the tool named in the ToolBreakpoint is not available in the agent's tools.
    """
    known_tool_names = {tool.name for tool in tools}
    inner_breakpoint = agent_breakpoint.break_point
    # Narrows the type for mypy; callers already ensure this is a ToolBreakpoint before calling.
    assert isinstance(inner_breakpoint, ToolBreakpoint)

    requested_tool = inner_breakpoint.tool_name
    if not requested_tool or requested_tool in known_tool_names:
        return
    raise ValueError(f"Tool '{requested_tool}' is not available in the agent's tools")
348

349

350
def _build_agent_breakpoint_snapshot(
    agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot]
) -> PipelineSnapshot:
    """
    Build the PipelineSnapshot that is saved when a breakpoint inside an Agent is triggered.

    :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints.
    :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
    :returns: A copy of ``parent_snapshot`` with ``agent_snapshot`` attached. If there is no parent snapshot, returns
        an otherwise empty PipelineSnapshot carrying the agent snapshot, its timestamp and its breakpoint.
    """
    if parent_snapshot is not None:
        # dataclasses.replace builds a new object; the parent snapshot itself is left untouched.
        return replace(parent_snapshot, agent_snapshot=agent_snapshot)

    # Create an empty pipeline snapshot if no parent snapshot is provided
    return PipelineSnapshot(
        pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}),
        timestamp=agent_snapshot.timestamp,
        break_point=agent_snapshot.break_point,
        agent_snapshot=agent_snapshot,
        original_input_data={},
        ordered_component_names=[],
        include_outputs_from=set(),
    )


def _trigger_chat_generator_breakpoint(
    *, agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot]
) -> None:
    """
    Trigger a breakpoint before ChatGenerator execution in Agent.

    :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints
    :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
    :raises BreakpointException: Always raised when this function is called, indicating a breakpoint has been triggered.
    """
    break_point = agent_snapshot.break_point.break_point

    _save_pipeline_snapshot(pipeline_snapshot=_build_agent_breakpoint_snapshot(agent_snapshot, parent_snapshot))

    msg = (
        f"Breaking at {break_point.component_name} visit count "
        f"{agent_snapshot.component_visits[break_point.component_name]}"
    )
    logger.info(msg)
    raise BreakpointException(
        message=msg,
        component=break_point.component_name,
        inputs=agent_snapshot.component_inputs,
        # The Agent state is read from the serialized tool_invoker inputs.
        results=agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"],
    )


def _handle_tool_invoker_breakpoint(
    *, llm_messages: list[ChatMessage], agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot]
) -> None:
    """
    Check if a tool call breakpoint should be triggered before executing the tool invoker.

    The breakpoint fires when any tool call in ``llm_messages`` (not only the first call of each message) matches
    the ToolBreakpoint. A ToolBreakpoint without a tool name matches any tool call.

    :param llm_messages: List of ChatMessage objects containing potential tool calls.
    :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints.
    :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
    :raises BreakpointException: If the breakpoint is triggered, indicating a breakpoint has been reached for a tool
        call.
    """
    tool_breakpoint = agent_snapshot.break_point.break_point
    if not isinstance(tool_breakpoint, ToolBreakpoint):
        return

    # A message can carry several (parallel) tool calls, and ChatMessage.tool_call exposes only the first one.
    # Inspect every entry of tool_calls so a targeted tool is found even when it is not the first call.
    if tool_breakpoint.tool_name is None:
        # Break for any tool call
        should_break = any(msg.tool_calls for msg in llm_messages)
    else:
        # Break only for the specific tool
        should_break = any(
            tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages for tool_call in msg.tool_calls
        )

    if not should_break:
        return  # No breakpoint triggered

    _save_pipeline_snapshot(pipeline_snapshot=_build_agent_breakpoint_snapshot(agent_snapshot, parent_snapshot))

    msg = (
        f"Breaking at {tool_breakpoint.component_name} visit count "
        f"{agent_snapshot.component_visits[tool_breakpoint.component_name]}"
    )
    if tool_breakpoint.tool_name:
        msg += f" for tool {tool_breakpoint.tool_name}"
    logger.info(msg)

    raise BreakpointException(
        message=msg,
        component=tool_breakpoint.component_name,
        inputs=agent_snapshot.component_inputs,
        # The Agent state is read from the serialized tool_invoker inputs.
        results=agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"],
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc