• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 10074443031

24 Jul 2024 09:51AM UTC coverage: 90.084% (-0.04%) from 90.122%
10074443031

Pull #7943

github

web-flow
Merge 3c2a91368 into 0c9dc008f
Pull Request #7943: feat: Multimodal ChatMessage

6995 of 7765 relevant lines covered (90.08%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.07
haystack/components/builders/chat_prompt_builder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from copy import deepcopy
1✔
6
from typing import Any, Dict, List, Optional, Set, Union
1✔
7

8
from jinja2 import Template, meta
1✔
9

10
from haystack import component, default_from_dict, default_to_dict, logging
1✔
11
from haystack.dataclasses.byte_stream import ByteStream
1✔
12
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
1✔
13

14
# Module-level logger; not referenced by ChatPromptBuilder in the code shown in this file.
logger = logging.getLogger(__name__)
1✔
15

16

17
@component
class ChatPromptBuilder:
    """
    ChatPromptBuilder is a component that renders a chat prompt from a template string using Jinja2 templates.

    It is designed to construct prompts for the pipeline using static or dynamic templates: Users can change
    the prompt template at runtime by providing a new template for each pipeline run invocation if needed.

    The template variables found in the init template string are used as input types for the component and are all
    optional, unless explicitly specified. If an optional template variable is not provided as an input, it will be
    replaced with an empty string in the rendered prompt. Use `variables` and `required_variables` to specify the input
    types and required variables.

    Only user and system messages are treated as templates; messages from any other role are passed through
    unchanged. A message's content may be a `str`, a `ByteStream` (decoded with `to_string()`, rendered, and
    re-encoded), or a list mixing `str` and `ByteStream` parts.

    Usage example with static prompt template:
    ```python
    template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
    builder = ChatPromptBuilder(template=template)
    builder.run(target_language="spanish", snippet="I can't speak spanish.")
    ```

    Usage example of overriding the static template at runtime:
    ```python
    template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
    builder = ChatPromptBuilder(template=template)
    builder.run(target_language="spanish", snippet="I can't speak spanish.")

    msg = "Translate to {{ target_language }} and summarize. Context: {{ snippet }}; Summary:"
    summary_template = [ChatMessage.from_user(msg)]
    builder.run(target_language="spanish", snippet="I can't speak spanish.", template=summary_template)
    ```

    Usage example with dynamic prompt template:
    ```python
    from haystack.components.builders import ChatPromptBuilder
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack import Pipeline
    from haystack.utils import Secret

    # no parameter init, we don't use any runtime template variables
    prompt_builder = ChatPromptBuilder()
    llm = OpenAIChatGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")

    pipe = Pipeline()
    pipe.add_component("prompt_builder", prompt_builder)
    pipe.add_component("llm", llm)
    pipe.connect("prompt_builder.prompt", "llm.messages")

    location = "Berlin"
    language = "English"
    system_message = ChatMessage.from_system("You are an assistant giving information to tourists in {{language}}")
    messages = [system_message, ChatMessage.from_user("Tell me about {{location}}")]

    res = pipe.run(data={"prompt_builder": {"template_variables": {"location": location, "language": language},
                                        "template": messages}})
    print(res)

    >> {'llm': {'replies': [ChatMessage(content="Berlin is the capital city of Germany and one of the most vibrant
    and diverse cities in Europe. Here are some key things to know...Enjoy your time exploring the vibrant and dynamic
    capital of Germany!", role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-3.5-turbo-0613',
    'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 27, 'completion_tokens': 681, 'total_tokens':
    708}})]}}


    messages = [system_message, ChatMessage.from_user("What's the weather forecast for {{location}} in the next
    {{day_count}} days?")]

    res = pipe.run(data={"prompt_builder": {"template_variables": {"location": location, "day_count": "5"},
                                        "template": messages}})

    print(res)
    >> {'llm': {'replies': [ChatMessage(content="Here is the weather forecast for Berlin in the next 5
    days:\n\nDay 1: Mostly cloudy with a high of 22°C (72°F) and...so it's always a good idea to check for updates
    closer to your visit.", role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-3.5-turbo-0613',
    'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 37, 'completion_tokens': 201,
    'total_tokens': 238}})]}}
    ```

    Note how in the example above, we can dynamically change the prompt template by providing a new template to the
    run method of the pipeline.

    """

    def __init__(
        self,
        template: Optional[List[ChatMessage]] = None,
        required_variables: Optional[List[str]] = None,
        variables: Optional[List[str]] = None,
    ):
        """
        Constructs a ChatPromptBuilder component.

        :param template:
            A list of `ChatMessage` instances. All user and system messages are treated as potentially having jinja2
            templates and are rendered with the provided template variables. If not provided, the template
            must be provided at runtime using the `template` parameter of the `run` method.
        :param required_variables: An optional list of input variables that must be provided at all times.
            If any of them is missing when `run` is called, a `ValueError` is raised.
        :param variables:
            A list of template variable names you can use in prompt construction. For example,
            if `variables` contains the string `documents`, the component will create an input called
            `documents` of type `Any`. These variable names are used to resolve variables and their values during
            pipeline execution. The values associated with variables from the pipeline runtime are then injected into
            template placeholders of a prompt text template that is provided to the `run` method.
            If not provided, variables are inferred from `template`.
        :raises ValueError:
            If variables are inferred from `template` and a message content (or content part) has an invalid type.
        """
        # Keep the raw constructor arguments so that `to_dict` serializes exactly what the user passed.
        self._variables = variables
        self._required_variables = required_variables
        self.required_variables = required_variables or []
        self.template = template

        # A fresh list is built on inference, so the caller's `variables` list is never mutated.
        input_variables = list(variables or [])
        if template and not input_variables:
            input_variables = self._infer_variables(template)

        # setup inputs: `template` mirrors the type accepted by `run`, i.e. a list of ChatMessage.
        static_input_slots = {
            "template": Optional[List[ChatMessage]],
            "template_variables": Optional[Dict[str, Any]],
        }
        component.set_input_types(self, **static_input_slots)
        for var in input_variables:
            if var in self.required_variables:
                # No default value -> the pipeline treats this input as mandatory.
                component.set_input_type(self, var, Any)
            else:
                # Optional input, rendered as an empty string when not provided.
                component.set_input_type(self, var, Any, "")

    @component.output_types(prompt=List[ChatMessage])
    def run(
        self,
        template: Optional[List[ChatMessage]] = None,
        template_variables: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Executes the prompt building process.

        It applies the template variables to render the final prompt. You can provide variables either via pipeline
        (set through `variables` or inferred from `template` at initialization) or via additional template variables
        set directly to this method. On collision, the variables provided directly to this method take precedence.

        :param template:
            An optional list of ChatMessages to overwrite ChatPromptBuilder's default template. If None, the default
            template provided at initialization is used.
        :param template_variables:
            An optional dictionary of template variables. These are additional variables users can provide directly
            to this method in contrast to pipeline variables.
        :param kwargs:
            Pipeline variables (typically resolved from a pipeline) which are merged with the provided template
            variables.

        :returns: A dictionary with the following keys:
            - `prompt`: The updated list of `ChatMessage` instances after rendering the found templates.
        :raises ValueError:
            If the effective template is empty or contains elements that are not instances of `ChatMessage`,
            if a required variable is missing, or if a message content (or content part) has an invalid type.
        """
        # Explicit `template_variables` override pipeline-provided kwargs on key collision.
        template_variables_combined = {**kwargs, **(template_variables or {})}

        if template is None:
            template = self.template

        if not template:
            raise ValueError(
                f"The {self.__class__.__name__} requires a non-empty list of ChatMessage instances. "
                f"Please provide a valid list of ChatMessage instances to render the prompt."
            )

        if not all(isinstance(message, ChatMessage) for message in template):
            raise ValueError(
                f"The {self.__class__.__name__} expects a list containing only ChatMessage instances. "
                f"The provided list contains other types. Please ensure that all elements in the list "
                f"are ChatMessage instances."
            )

        # Required variables are required "at all times", so validate once per call rather than per message.
        self._validate_variables(set(template_variables_combined.keys()))

        processed_messages = [
            self._render_message(message, template_variables_combined) if self._is_templated(message) else message
            for message in template
        ]
        return {"prompt": processed_messages}

    def _validate_variables(self, provided_variables: Set[str]):
        """
        Checks if all the required template variables are provided.

        :param provided_variables:
            A set of provided template variables.
        :raises ValueError:
            If any of the required template variables is not provided.
        """
        missing_variables = [var for var in self.required_variables if var not in provided_variables]
        if missing_variables:
            missing_vars_str = ", ".join(missing_variables)
            raise ValueError(
                f"Missing required input variables in ChatPromptBuilder: {missing_vars_str}. "
                f"Required variables: {self.required_variables}. Provided variables: {provided_variables}."
            )

    @staticmethod
    def _is_templated(message: ChatMessage) -> bool:
        """Return True for messages whose content is rendered as a template (user and system messages)."""
        return message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM)

    @staticmethod
    def _invalid_part_error(part: Any) -> ValueError:
        """Build the error raised when one element of a list content is neither `str` nor `ByteStream`."""
        return ValueError(
            f"One of the elements of the content of a ChatMessage is of an invalid type: {type(part)}\n"
            f"Valid Types: str and ByteStream.\n"
            f"Content: {part}"
        )

    @staticmethod
    def _invalid_content_error(content: Any) -> ValueError:
        """Build the error raised when a message content is not `str`, `ByteStream` or a list of those."""
        return ValueError(
            f"The content of a ChatMessage is of an invalid type: {type(content)}\n"
            f"Valid Types: str, ByteStream and List[Union[str, ByteStream]].\n"
            f"Content: {content}"
        )

    @classmethod
    def _part_text(cls, part: Any) -> str:
        """Return the template text of a single content part, raising for unsupported part types."""
        if isinstance(part, str):
            return part
        if isinstance(part, ByteStream):
            # NOTE(review): binary payloads (e.g. images in multimodal messages) would presumably fail to decode
            # here — confirm whether non-text ByteStreams should be skipped instead of rendered.
            return part.to_string()
        raise cls._invalid_part_error(part)

    @classmethod
    def _message_texts(cls, message: ChatMessage) -> List[str]:
        """Return every template text contained in a message's content, in content order."""
        content = message.content
        if isinstance(content, (str, ByteStream)):
            return [cls._part_text(content)]
        if isinstance(content, list):
            return [cls._part_text(part) for part in content]
        raise cls._invalid_content_error(content)

    @staticmethod
    def _find_template_variables(text: str) -> List[str]:
        """Return the undeclared Jinja2 variables referenced by `text`."""
        ast = Template(text).environment.parse(text)
        return list(meta.find_undeclared_variables(ast))

    @classmethod
    def _infer_variables(cls, template: List[ChatMessage]) -> List[str]:
        """Infer the de-duplicated template variables from all user and system messages of `template`."""
        inferred: List[str] = []
        for message in template:
            if not cls._is_templated(message):
                continue
            for text in cls._message_texts(message):
                for var in cls._find_template_variables(text):
                    if var not in inferred:
                        inferred.append(var)
        return inferred

    @classmethod
    def _render_part(cls, part: Any, template_variables: Dict[str, Any]) -> Union[str, ByteStream]:
        """Render one content part; a ByteStream is copied and its data replaced with the UTF-8 rendered text."""
        rendered_text = Template(cls._part_text(part)).render(template_variables)
        if isinstance(part, ByteStream):
            rendered_stream = deepcopy(part)
            rendered_stream.data = rendered_text.encode()
            return rendered_stream
        return rendered_text

    @classmethod
    def _render_message(cls, message: ChatMessage, template_variables: Dict[str, Any]) -> ChatMessage:
        """Return a deep copy of `message` with its content rendered; the input message is left untouched."""
        content = message.content
        if isinstance(content, (str, ByteStream)):
            rendered_content: Union[str, ByteStream, List[Union[str, ByteStream]]] = cls._render_part(
                content, template_variables
            )
        elif isinstance(content, list):
            rendered_content = [cls._render_part(part, template_variables) for part in content]
        else:
            raise cls._invalid_content_error(content)

        rendered_message = deepcopy(message)
        rendered_message.content = rendered_content
        return rendered_message

    def to_dict(self) -> Dict[str, Any]:
        """
        Returns a dictionary representation of the component.

        :returns:
            Serialized dictionary representation of the component. `template` is `None` when no default
            template was set.
        """
        template = [m.to_dict() for m in self.template] if self.template is not None else None

        return default_to_dict(
            self, template=template, variables=self._variables, required_variables=self._required_variables
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatPromptBuilder":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary to deserialize and create the component.

        :returns:
            The deserialized component.
        """
        init_parameters = data["init_parameters"]
        # `to_dict` serializes a missing template as None; only deserialize when messages are actually present.
        template = init_parameters.get("template")
        if template:
            init_parameters["template"] = [ChatMessage.from_dict(d) for d in template]

        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc