• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 17578922709

09 Sep 2025 09:57AM UTC coverage: 92.063% (-0.003%) from 92.066%
17578922709

Pull #9754

github

web-flow
Merge 868faf198 into 34f1a0412
Pull Request #9754: feat: support structured outputs in `OpenAIChatGenerator`

12992 of 14112 relevant lines covered (92.06%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.92
haystack/core/super_component/super_component.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import functools
1✔
6
from pathlib import Path
1✔
7
from types import new_class
1✔
8
from typing import Any, Optional, TypeVar, Union
1✔
9

10
from haystack import logging
1✔
11
from haystack.core.component.component import component
1✔
12
from haystack.core.pipeline.async_pipeline import AsyncPipeline
1✔
13
from haystack.core.pipeline.pipeline import Pipeline
1✔
14
from haystack.core.pipeline.utils import parse_connect_string
1✔
15
from haystack.core.serialization import default_from_dict, default_to_dict, generate_qualified_class_name
1✔
16
from haystack.core.super_component.utils import _delegate_default, _is_compatible
1✔
17

18
logger = logging.getLogger(__name__)
1✔
19

20
T = TypeVar("T")
1✔
21

22

23
class InvalidMappingTypeError(Exception):
    """Exception raised when an input/output mapping (or one of its entries) has the wrong type."""
27

28

29
class InvalidMappingValueError(Exception):
    """Exception raised when an input/output mapping references missing components, sockets, or values."""
33

34

35
@component
1✔
36
class _SuperComponent:
1✔
37
    def __init__(
        self,
        pipeline: Union[Pipeline, AsyncPipeline],
        input_mapping: Optional[dict[str, list[str]]] = None,
        output_mapping: Optional[dict[str, str]] = None,
    ) -> None:
        """
        Creates a SuperComponent with optional input and output mappings.

        :param pipeline: The pipeline instance or async pipeline instance to be wrapped
        :param input_mapping: A dictionary mapping component input names to pipeline input socket paths.
            If not provided, a default input mapping will be created based on all pipeline inputs.
            Example:
            ```python
            input_mapping={
                "query": ["retriever.query", "prompt_builder.query"],
            }
            ```
        :param output_mapping: A dictionary mapping pipeline output socket paths to component output names.
            If not provided, a default output mapping will be created based on all pipeline outputs.
            Example:
            ```python
            output_mapping={
                "retriever.documents": "documents",
                "generator.replies": "replies",
            }
            ```
        :raises InvalidMappingTypeError: Raised if a mapping has an invalid type or type conflicts occur
        :raises InvalidMappingValueError: Raised if a mapping references nonexistent components or sockets
        :raises ValueError: Raised if no pipeline is provided
        """
        if pipeline is None:
            raise ValueError("Pipeline must be provided to SuperComponent.")

        self.pipeline: Union[Pipeline, AsyncPipeline] = pipeline
        # Lazily warmed up on first warm_up() call; see warm_up().
        self._warmed_up = False

        # Determine input types based on pipeline and mapping.
        # If no mapping was given, every pipeline input socket is exposed under its own name.
        pipeline_inputs = self.pipeline.inputs()
        resolved_input_mapping = (
            input_mapping if input_mapping is not None else self._create_input_mapping(pipeline_inputs)
        )
        self._validate_input_mapping(pipeline_inputs, resolved_input_mapping)
        input_types = self._resolve_input_types_from_mapping(pipeline_inputs, resolved_input_mapping)
        # Set input types on the component so the wrapper exposes the mapped sockets
        for input_name, info in input_types.items():
            component.set_input_type(self, name=input_name, **info)

        self.input_mapping: dict[str, list[str]] = resolved_input_mapping
        # The user-supplied mapping (possibly None) is kept separately so serialization
        # round-trips the original constructor arguments, not the resolved defaults.
        self._original_input_mapping = input_mapping

        # Set output types based on pipeline and mapping.
        # Default mappings only consider leaf outputs, but explicit mappings may reference
        # any component's output, so validation/resolution use the full output set.
        leaf_pipeline_outputs = self.pipeline.outputs()
        all_possible_pipeline_outputs = self.pipeline.outputs(include_components_with_connected_outputs=True)

        resolved_output_mapping = (
            output_mapping if output_mapping is not None else self._create_output_mapping(leaf_pipeline_outputs)
        )
        self._validate_output_mapping(all_possible_pipeline_outputs, resolved_output_mapping)
        output_types = self._resolve_output_types_from_mapping(all_possible_pipeline_outputs, resolved_output_mapping)
        # Set output types on the component
        component.set_output_types(self, **output_types)
        self.output_mapping: dict[str, str] = resolved_output_mapping
        self._original_output_mapping = output_mapping
100

101
    def warm_up(self) -> None:
        """
        Warms up the SuperComponent by warming up the wrapped pipeline.

        Idempotent: the underlying pipeline is warmed up at most once.
        """
        if self._warmed_up:
            return
        self.pipeline.warm_up()
        self._warmed_up = True
108

109
    def run(self, **kwargs: Any) -> dict[str, Any]:
        """
        Runs the wrapped pipeline with the provided inputs.

        Steps:
        1. Maps the inputs from kwargs to pipeline component inputs
        2. Runs the pipeline
        3. Maps the pipeline outputs to the SuperComponent's outputs

        :param kwargs: Keyword arguments matching the SuperComponent's input names
        :returns:
            Dictionary containing the SuperComponent's output values
        """
        # Arguments still carrying the sentinel default were never provided by the
        # caller, so they are dropped and the wrapped components' own defaults apply.
        provided = {name: value for name, value in kwargs.items() if value != _delegate_default}
        mapped_inputs = self._map_explicit_inputs(input_mapping=self.input_mapping, inputs=provided)
        raw_outputs = self.pipeline.run(data=mapped_inputs, include_outputs_from=self._get_include_outputs_from())
        return self._map_explicit_outputs(raw_outputs, self.output_mapping)
127

128
    def _get_include_outputs_from(self) -> set[str]:
        """Return the names of all components referenced by ``output_mapping`` keys."""
        component_names: set[str] = set()
        for output_path in self.output_mapping:
            component_names.add(self._split_component_path(output_path)[0])
        return component_names
131

132
    async def run_async(self, **kwargs: Any) -> dict[str, Any]:
        """
        Runs the wrapped pipeline with the provided inputs async.

        Steps:
        1. Maps the inputs from kwargs to pipeline component inputs
        2. Runs the pipeline async
        3. Maps the pipeline outputs to the SuperComponent's outputs

        :param kwargs: Keyword arguments matching the SuperComponent's input names
        :returns:
            Dictionary containing the SuperComponent's output values
        :raises TypeError:
            If the pipeline is not an AsyncPipeline
        """
        if not isinstance(self.pipeline, AsyncPipeline):
            raise TypeError("Pipeline is not an AsyncPipeline. run_async is not supported.")

        filtered_inputs = {param: value for param, value in kwargs.items() if value != _delegate_default}
        pipeline_inputs = self._map_explicit_inputs(input_mapping=self.input_mapping, inputs=filtered_inputs)
        # Pass include_outputs_from just like the sync run() does: without it, output_mapping
        # entries that reference non-leaf components would be silently missing from the
        # async results (_map_explicit_outputs skips components absent from pipeline_outputs).
        include_outputs_from = self._get_include_outputs_from()
        pipeline_outputs = await self.pipeline.run_async(data=pipeline_inputs, include_outputs_from=include_outputs_from)
        return self._map_explicit_outputs(pipeline_outputs, self.output_mapping)
154

155
    @staticmethod
    def _split_component_path(path: str) -> tuple[str, str]:
        """
        Splits a component path into a component name and a socket name.

        :param path: A string in the format "component_name.socket_name".
        :returns:
            A tuple containing (component_name, socket_name).
        :raises InvalidMappingValueError:
            If the path format is incorrect.
        """
        component_name, socket = parse_connect_string(path)
        # parse_connect_string returns no socket when the separator is missing.
        if socket is None:
            raise InvalidMappingValueError(f"Invalid path format: '{path}'. Expected 'component_name.socket_name'.")
        return component_name, socket
170

171
    def _validate_input_mapping(
        self, pipeline_inputs: dict[str, dict[str, Any]], input_mapping: dict[str, list[str]]
    ) -> None:
        """
        Checks that every socket path in the input mapping points at an existing pipeline input.

        :param pipeline_inputs: A dictionary containing pipeline input specifications.
        :param input_mapping: A dictionary mapping wrapper input names to pipeline socket paths.
        :raises InvalidMappingTypeError:
            If the mapping itself, or one of its path lists, has the wrong type.
        :raises InvalidMappingValueError:
            If the input mapping contains nonexistent components or sockets.
        """
        if not isinstance(input_mapping, dict):
            raise InvalidMappingTypeError("input_mapping must be a dictionary")

        for wrapper_input_name, socket_paths in input_mapping.items():
            if not isinstance(socket_paths, list):
                raise InvalidMappingTypeError(f"Input paths for '{wrapper_input_name}' must be a list of strings.")
            for socket_path in socket_paths:
                comp_name, socket_name = self._split_component_path(socket_path)
                component_sockets = pipeline_inputs.get(comp_name)
                if component_sockets is None:
                    raise InvalidMappingValueError(
                        f"Component '{comp_name}' not found in pipeline inputs.\n"
                        f"Available components: {list(pipeline_inputs.keys())}"
                    )
                if socket_name not in component_sockets:
                    raise InvalidMappingValueError(
                        f"Input socket '{socket_name}' not found in component '{comp_name}'.\n"
                        f"Available inputs for '{comp_name}': {list(component_sockets.keys())}"
                    )
202

203
    def _resolve_input_types_from_mapping(
        self, pipeline_inputs: dict[str, dict[str, Any]], input_mapping: dict[str, list[str]]
    ) -> dict[str, dict[str, Any]]:
        """
        Resolves and validates input types based on the provided input mapping.

        This function ensures that all mapped pipeline inputs are compatible, consolidating types
        when multiple mappings exist. It also determines whether an input is mandatory or has a default value.

        :param pipeline_inputs: A dictionary containing pipeline input specifications.
        :param input_mapping: A dictionary mapping SuperComponent inputs to pipeline socket paths.
        :returns:
            A dictionary specifying the resolved input types and their properties.
        :raises InvalidMappingTypeError:
            If the input mapping contains incompatible types.
        """
        aggregated_inputs: dict[str, dict[str, Any]] = {}
        for wrapper_input_name, pipeline_input_paths in input_mapping.items():
            for path in pipeline_input_paths:
                comp_name, socket_name = self._split_component_path(path)
                # Assumes socket_info carries at least "type" and "is_mandatory" keys,
                # as produced by Pipeline.inputs() — validated earlier by _validate_input_mapping.
                socket_info = pipeline_inputs[comp_name][socket_name]

                # Add to aggregated inputs
                existing_socket_info = aggregated_inputs.get(wrapper_input_name)
                if existing_socket_info is None:
                    # First socket seen for this wrapper input seeds the aggregate.
                    # Optional sockets get the sentinel default so run() can detect
                    # "not provided by the caller" and drop the argument.
                    aggregated_inputs[wrapper_input_name] = {"type": socket_info["type"]}
                    if not socket_info["is_mandatory"]:
                        aggregated_inputs[wrapper_input_name]["default"] = _delegate_default
                    continue

                is_compatible, common_type = _is_compatible(existing_socket_info["type"], socket_info["type"])

                if not is_compatible:
                    raise InvalidMappingTypeError(
                        f"Type conflict for input '{socket_name}' from component '{comp_name}'. "
                        f"Existing type: {existing_socket_info['type']}, new type: {socket_info['type']}."
                    )

                # Use the common type for the aggregated input
                aggregated_inputs[wrapper_input_name]["type"] = common_type

                # If any socket requires mandatory inputs then the aggregated input is also considered mandatory.
                # So we use the type of the mandatory input and remove the default value if it exists.
                if socket_info["is_mandatory"]:
                    aggregated_inputs[wrapper_input_name].pop("default", None)

        return aggregated_inputs
250

251
    @staticmethod
    def _create_input_mapping(pipeline_inputs: dict[str, dict[str, Any]]) -> dict[str, list[str]]:
        """
        Create a default input mapping from pipeline inputs.

        Each pipeline input socket is exposed under its own socket name; sockets that
        share a name across components are grouped under a single wrapper input.

        :param pipeline_inputs: Dictionary of pipeline input specifications
        :returns:
            Dictionary mapping SuperComponent input names to pipeline socket paths
        """
        mapping: dict[str, list[str]] = {}
        for comp_name, sockets in pipeline_inputs.items():
            for socket_name in sockets:
                mapping.setdefault(socket_name, []).append(f"{comp_name}.{socket_name}")
        return mapping
269

270
    def _validate_output_mapping(
        self, pipeline_outputs: dict[str, dict[str, Any]], output_mapping: dict[str, str]
    ) -> None:
        """
        Checks that every socket path in the output mapping points at an existing pipeline output.

        :param pipeline_outputs: A dictionary containing pipeline output specifications.
        :param output_mapping: A dictionary mapping pipeline socket paths to wrapper output names.
        :raises InvalidMappingTypeError:
            If an output name in the mapping is not a string.
        :raises InvalidMappingValueError:
            If the output mapping contains nonexistent components or sockets.
        """
        for output_path, wrapper_output_name in output_mapping.items():
            if not isinstance(wrapper_output_name, str):
                raise InvalidMappingTypeError("Output names in output_mapping must be strings.")
            comp_name, socket_name = self._split_component_path(output_path)
            component_sockets = pipeline_outputs.get(comp_name)
            if component_sockets is None:
                raise InvalidMappingValueError(f"Component '{comp_name}' not found among pipeline outputs.")
            if socket_name not in component_sockets:
                raise InvalidMappingValueError(f"Output socket '{socket_name}' not found in component '{comp_name}'.")
291

292
    def _resolve_output_types_from_mapping(
        self, pipeline_outputs: dict[str, dict[str, Any]], output_mapping: dict[str, str]
    ) -> dict[str, Any]:
        """
        Resolves output types based on the provided output mapping.

        Each mapped pipeline output socket contributes its declared type under the
        corresponding SuperComponent output name; duplicate names are rejected.

        :param pipeline_outputs: A dictionary containing pipeline output specifications.
        :param output_mapping: A dictionary mapping pipeline output socket paths to SuperComponent output names.
        :returns:
            A dictionary mapping SuperComponent output names to their resolved types.
        :raises InvalidMappingValueError:
            If the output mapping contains duplicate output names.
        """
        resolved: dict[str, Any] = {}
        for output_path, wrapper_output_name in output_mapping.items():
            comp_name, socket_name = self._split_component_path(output_path)
            if wrapper_output_name in resolved:
                raise InvalidMappingValueError(f"Duplicate output name '{wrapper_output_name}' in output_mapping.")
            resolved[wrapper_output_name] = pipeline_outputs[comp_name][socket_name]["type"]
        return resolved
315

316
    @staticmethod
    def _create_output_mapping(pipeline_outputs: dict[str, dict[str, Any]]) -> dict[str, str]:
        """
        Create a default output mapping from pipeline outputs.

        Every output socket is exposed under its own name; a socket name produced by
        more than one component is ambiguous and rejected.

        :param pipeline_outputs: Dictionary of pipeline output specifications
        :returns:
            Dictionary mapping pipeline socket paths to SuperComponent output names
        :raises InvalidMappingValueError:
            If there are output name conflicts between components
        """
        mapping: dict[str, str] = {}
        seen_names: set[str] = set()
        for comp_name, sockets in pipeline_outputs.items():
            for socket_name in sockets:
                if socket_name in seen_names:
                    raise InvalidMappingValueError(
                        f"Output name conflict: '{socket_name}' is produced by multiple components. "
                        "Please provide an output_mapping to resolve this conflict."
                    )
                seen_names.add(socket_name)
                mapping[f"{comp_name}.{socket_name}"] = socket_name
        return mapping
339

340
    def _map_explicit_inputs(
        self, input_mapping: dict[str, list[str]], inputs: dict[str, Any]
    ) -> dict[str, dict[str, Any]]:
        """
        Map wrapper-level inputs onto pipeline component inputs.

        Wrapper inputs absent from ``inputs`` are skipped; each provided value is
        fanned out to every socket path it is mapped to.

        :param input_mapping: Mapping configuration for inputs
        :param inputs: Input arguments provided to wrapper
        :return: Dictionary of mapped pipeline inputs
        """
        mapped: dict[str, dict[str, Any]] = {}
        for wrapper_input_name, socket_paths in input_mapping.items():
            if wrapper_input_name not in inputs:
                continue

            value = inputs[wrapper_input_name]
            for socket_path in socket_paths:
                comp_name, input_name = self._split_component_path(socket_path)
                mapped.setdefault(comp_name, {})[input_name] = value

        return mapped
362

363
    def _map_explicit_outputs(
        self, pipeline_outputs: dict[str, dict[str, Any]], output_mapping: dict[str, str]
    ) -> dict[str, Any]:
        """
        Map raw pipeline outputs onto the SuperComponent's output names.

        Mapped sockets missing from the pipeline result are silently skipped.

        :param pipeline_outputs: Raw outputs from pipeline execution
        :param output_mapping: Output mapping configuration
        :return: Dictionary of mapped outputs
        """
        mapped: dict[str, Any] = {}
        for output_path, wrapper_output_name in output_mapping.items():
            comp_name, socket_name = self._split_component_path(output_path)
            component_outputs = pipeline_outputs.get(comp_name)
            if component_outputs is not None and socket_name in component_outputs:
                mapped[wrapper_output_name] = component_outputs[socket_name]
        return mapped
379

380
    def _to_super_component_dict(self) -> dict[str, Any]:
        """
        Convert to a SuperComponent dictionary representation.

        Serializes the original (user-supplied) mappings rather than the resolved
        ones so that deserialization reproduces the original constructor call.

        :return: Dictionary containing serialized SuperComponent data
        """
        serialized = default_to_dict(
            self,
            pipeline=self.pipeline.to_dict(),
            input_mapping=self._original_input_mapping,
            output_mapping=self._original_output_mapping,
            is_pipeline_async=isinstance(self.pipeline, AsyncPipeline),
        )
        # Always report the public SuperComponent type so subclasses round-trip
        # through the generic deserializer.
        serialized["type"] = generate_qualified_class_name(SuperComponent)
        return serialized
397

398

399
@component
1✔
400
class SuperComponent(_SuperComponent):
1✔
401
    """
402
    A class for creating super components that wrap around a Pipeline.
403

404
    This component allows for remapping of input and output socket names between the wrapped pipeline and the
405
    SuperComponent's input and output names. This is useful for creating higher-level components that abstract
406
    away the details of the wrapped pipeline.
407

408
    ### Usage example
409

410
    ```python
411
    from haystack import Pipeline, SuperComponent
412
    from haystack.components.generators.chat import OpenAIChatGenerator
413
    from haystack.components.builders import ChatPromptBuilder
414
    from haystack.components.retrievers import InMemoryBM25Retriever
415
    from haystack.dataclasses.chat_message import ChatMessage
416
    from haystack.document_stores.in_memory import InMemoryDocumentStore
417
    from haystack.dataclasses import Document
418

419
    document_store = InMemoryDocumentStore()
420
    documents = [
421
        Document(content="Paris is the capital of France."),
422
        Document(content="London is the capital of England."),
423
    ]
424
    document_store.write_documents(documents)
425

426
    prompt_template = [
427
        ChatMessage.from_user(
428
        '''
429
        According to the following documents:
430
        {% for document in documents %}
431
        {{document.content}}
432
        {% endfor %}
433
        Answer the given question: {{query}}
434
        Answer:
435
        '''
436
        )
437
    ]
438

439
    prompt_builder = ChatPromptBuilder(template=prompt_template, required_variables="*")
440

441
    pipeline = Pipeline()
442
    pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store))
443
    pipeline.add_component("prompt_builder", prompt_builder)
444
    pipeline.add_component("llm", OpenAIChatGenerator())
445
    pipeline.connect("retriever.documents", "prompt_builder.documents")
446
    pipeline.connect("prompt_builder.prompt", "llm.messages")
447

448
    # Create a super component with simplified input/output mapping
449
    wrapper = SuperComponent(
450
        pipeline=pipeline,
451
        input_mapping={
452
            "query": ["retriever.query", "prompt_builder.query"],
453
        },
454
        output_mapping={"llm.replies": "replies"}
455
    )
456

457
    # Run the pipeline with simplified interface
458
    result = wrapper.run(query="What is the capital of France?")
459
    print(result)
460
    {'replies': [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>,
461
     _content=[TextContent(text='The capital of France is Paris.')],...)
462
    ```
463

464
    """
465

466
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the SuperComponent into a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialized = self._to_super_component_dict()
        return serialized
474

475
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SuperComponent":
        """
        Deserializes the SuperComponent from a dictionary.

        :param data: The dictionary to deserialize from.
        :returns:
            The deserialized SuperComponent.
        """
        init_params = data["init_parameters"]
        # The flag is popped so it is not forwarded to __init__, which does not accept it.
        run_async = init_params.pop("is_pipeline_async", False)
        pipeline_cls = AsyncPipeline if run_async else Pipeline
        init_params["pipeline"] = pipeline_cls.from_dict(init_params["pipeline"])
        return default_from_dict(cls, data)
489

490
    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
        """
        Display an image representing this SuperComponent's underlying pipeline in a Jupyter notebook.

        This function generates a diagram of the Pipeline using a Mermaid server and displays it directly in
        the notebook.

        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.

        :param params:
            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
            Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :param timeout:
            Timeout in seconds for the request to the Mermaid server.

        :raises PipelineDrawingError:
            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
        """
        # Rendering is fully delegated to the wrapped pipeline's own Mermaid support.
        self.pipeline.show(server_url=server_url, params=params, timeout=timeout)
523

524
    def draw(
        self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
    ) -> None:
        """
        Save an image representing this SuperComponent's underlying pipeline to the specified file path.

        This function generates a diagram of the Pipeline using the Mermaid server and saves it to the provided path.

        :param path:
            The file path where the generated image will be saved.
        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.
        :param params:
            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
            Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :param timeout:
            Timeout in seconds for the request to the Mermaid server.

        :raises PipelineDrawingError:
            If there is an issue with rendering or saving the image.
        """
        # Rendering and saving are fully delegated to the wrapped pipeline.
        self.pipeline.draw(path=path, server_url=server_url, params=params, timeout=timeout)
559

560

561
def super_component(cls: type[T]) -> type[T]:
    """
    Decorator that converts a class into a SuperComponent.

    This decorator:
    1. Creates a new class that inherits from SuperComponent
    2. Copies all methods and attributes from the original class
    3. Adds initialization logic to properly set up the SuperComponent

    The decorated class should define:
    - pipeline: A Pipeline or AsyncPipeline instance in the __init__ method
    - input_mapping: Dictionary mapping component inputs to pipeline inputs (optional)
    - output_mapping: Dictionary mapping pipeline outputs to component outputs (optional)

    :param cls: The class to convert into a SuperComponent.
    :returns: A new component class with the same name, module, and docstring as ``cls``.
    :raises ValueError: At instantiation time, if the decorated class's __init__ did not
        set a ``pipeline`` attribute.
    """
    logger.debug("Registering {cls} as a super_component", cls=cls)

    # Store the original __init__ method
    original_init = cls.__init__

    # Create a new __init__ method that will initialize both the original class and SuperComponent
    def init_wrapper(self, *args, **kwargs):
        # Call the original __init__ to set up pipeline and mappings
        original_init(self, *args, **kwargs)

        # Verify required attributes
        if not hasattr(self, "pipeline"):
            raise ValueError(f"Class {cls.__name__} decorated with @super_component must define a 'pipeline' attribute")

        # Initialize SuperComponent
        # input_mapping/output_mapping are optional on the decorated class, hence getattr.
        _SuperComponent.__init__(
            self,
            pipeline=self.pipeline,
            input_mapping=getattr(self, "input_mapping", None),
            output_mapping=getattr(self, "output_mapping", None),
        )

    # Preserve original init's signature for IDEs/docs/tools
    init_wrapper = functools.wraps(original_init)(init_wrapper)

    # Function to copy namespace from the original class
    def copy_class_namespace(namespace):
        """Copy all attributes from the original class except special ones."""
        for key, val in dict(cls.__dict__).items():
            # Skip special attributes that should be recreated
            if key in ("__dict__", "__weakref__"):
                continue

            # Override __init__ with our wrapper
            if key == "__init__":
                namespace["__init__"] = init_wrapper
                continue

            namespace[key] = val

    # Create a new class inheriting from SuperComponent with the original methods
    # We use (SuperComponent,) + cls.__bases__ to make the new class inherit from
    # SuperComponent and all the original class's bases
    new_cls = new_class(cls.__name__, (_SuperComponent,) + cls.__bases__, {}, copy_class_namespace)

    # Copy other class attributes so the new class is indistinguishable from the
    # original in tracebacks, pickling lookups, and generated docs.
    new_cls.__module__ = cls.__module__
    new_cls.__qualname__ = cls.__qualname__
    new_cls.__doc__ = cls.__doc__

    # Apply the component decorator to the new class
    return component(new_cls)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc