• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 12066860315

28 Nov 2024 10:16AM UTC coverage: 90.307% (-0.05%) from 90.359%
12066860315

push

github

web-flow
refactor: update components to access `ChatMessage.text` instead of `content` (#8589)

* introduce text property and deprecate content

* release note

* use chatmessage.text

* release note

* linting

8031 of 8893 relevant lines covered (90.31%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.19
haystack/components/validators/json_schema.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import json
1✔
6
from typing import Any, Dict, List, Optional
1✔
7

8
from haystack import component
1✔
9
from haystack.dataclasses import ChatMessage
1✔
10
from haystack.lazy_imports import LazyImport
1✔
11

12
with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
1✔
13
    from jsonschema import ValidationError, validate
1✔
14

15

16
def is_valid_json(s: str) -> bool:
    """
    Check if the provided string is a valid JSON.

    :param s: The string to be checked.
    :returns: `True` if `s` can be parsed by `json.loads`; `False` if parsing fails or if `s` is of a type
        that `json.loads` does not accept (for example `None`).
    """
    try:
        json.loads(s)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (malformed JSON text);
        # TypeError is raised by json.loads for inputs that are not str/bytes/bytearray.
        return False
    return True
1✔
28

29

30
@component
class JsonSchemaValidator:
    """
    Validates JSON content of `ChatMessage` against a specified [JSON Schema](https://json-schema.org/).

    If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output.
    If the JSON content does not conform to the schema, the message is passed along the "validation_error" output.
    In the latter case, the error message is constructed using the provided `error_template` or a default template.
    These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops.

    Usage example:

    ```python
    from typing import List

    from haystack import Pipeline
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.components.joiners import BranchJoiner
    from haystack.components.validators import JsonSchemaValidator
    from haystack import component
    from haystack.dataclasses import ChatMessage


    @component
    class MessageProducer:

        @component.output_types(messages=List[ChatMessage])
        def run(self, messages: List[ChatMessage]) -> dict:
            return {"messages": messages}


    p = Pipeline()
    p.add_component("llm", OpenAIChatGenerator(model="gpt-4-1106-preview",
                                               generation_kwargs={"response_format": {"type": "json_object"}}))
    p.add_component("schema_validator", JsonSchemaValidator())
    p.add_component("joiner_for_llm", BranchJoiner(List[ChatMessage]))
    p.add_component("message_producer", MessageProducer())

    p.connect("message_producer.messages", "joiner_for_llm")
    p.connect("joiner_for_llm", "llm")
    p.connect("llm.replies", "schema_validator.messages")
    p.connect("schema_validator.validation_error", "joiner_for_llm")

    result = p.run(data={
        "message_producer": {
            "messages": [ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]
        },
        "schema_validator": {
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"}
                }
            }
        }
    })
    print(result)
    >> {'schema_validator': {'validated': [ChatMessage(content='\\n{\\n  "name": "John",\\n  "age": 30\\n}',
    role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0,
    'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}}
    ```
    """

    # Default error description template.
    # Placeholders filled in by `_construct_error_recovery_message`:
    # {failing_json}, {error_message}, {error_path}, {error_schema_path}, {json_schema}.
    default_error_template = (
        "The following generated JSON does not conform to the provided schema.\n"
        "Generated JSON: {failing_json}\n"
        "Error details:\n- Message: {error_message}\n"
        "- Error Path in JSON: {error_path}\n"
        "- Schema Path: {error_schema_path}\n"
        "Please match the following schema:\n"
        "{json_schema}\n"
        "and provide the corrected JSON content ONLY. Please do not output anything else than the raw corrected "
        "JSON string, this is the most important part of the task. Don't use any markdown and don't add any comment."
    )

    def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
        """
        Initialize the JsonSchemaValidator component.

        :param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/) against which
            the messages' content is validated.
        :param error_template: A custom template string for formatting the error message in case of validation failure.
        """
        # Fails early if the optional `jsonschema` dependency is not installed.
        jsonschema_import.check()
        self.json_schema = json_schema
        self.error_template = error_template

    @component.output_types(validated=List[ChatMessage], validation_error=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        json_schema: Optional[Dict[str, Any]] = None,
        error_template: Optional[str] = None,
    ) -> Dict[str, List[ChatMessage]]:
        """
        Validates the last of the provided messages against the specified json schema.

        If it does, the message is passed along the "validated" output. If it does not, the message is passed along
        the "validation_error" output.

        :param messages: A list of ChatMessage instances to be validated. The last message in this list is the one
            that is validated.
        :param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/)
            against which the messages' content is validated. If not provided, the schema from the component init
            is used.
        :param error_template: A custom template string for formatting the error message in case of validation. If not
            provided, the `error_template` from the component init is used.
        :returns: A dictionary with exactly one of the following keys:
            - "validated": A list containing the last message, if it is valid.
            - "validation_error": A list containing a single user message that describes the problem, if the last
              message is not valid JSON or does not conform to the schema.
        :raises ValueError: If the last message has no text, if no JSON schema is provided, or if the message
            content is not a dictionary or a list of dictionaries.
        """
        last_message = messages[-1]
        if last_message.text is None:
            raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {last_message}")

        # Parse once: a decoding failure is reported to the LLM as a recoverable validation error.
        try:
            last_message_content = json.loads(last_message.text)
        except ValueError:
            return {
                "validation_error": [
                    ChatMessage.from_user(
                        f"The message '{last_message.text}' is not a valid JSON object. "
                        f"Please provide only a valid JSON object in string format. "
                        f"Don't use any markdown and don't add any comment."
                    )
                ]
            }

        json_schema = json_schema or self.json_schema
        error_template = error_template or self.error_template or self.default_error_template

        if not json_schema:
            raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.")

        # fc payload is json object but subtree `parameters` is string - we need to convert to json object
        # we need complete json to validate it against schema
        last_message_json = self._recursive_json_to_object(last_message_content)

        # For an OpenAI function calling schema only the `parameters` sub-schema describes the arguments.
        using_openai_schema: bool = self._is_openai_function_calling_schema(json_schema)
        validation_schema = json_schema["parameters"] if using_openai_schema else json_schema

        # A single JSON object is validated the same way as a list with one element.
        candidates = last_message_json if isinstance(last_message_json, list) else [last_message_json]
        try:
            for content in candidates:
                # NOTE: with a function calling schema every element is expected to carry
                # `function.arguments`; a missing key propagates as KeyError/TypeError from the subscript.
                instance = content["function"]["arguments"] if using_openai_schema else content
                validate(instance=instance, schema=validation_schema)
        except ValidationError as e:
            error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A"
            error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A"

            recovery_prompt = self._construct_error_recovery_message(
                error_template, str(e), error_path, error_schema_path, validation_schema, failing_json=last_message.text
            )
            return {"validation_error": [ChatMessage.from_user(recovery_prompt)]}

        return {"validated": [last_message]}

    def _construct_error_recovery_message(  # pylint: disable=too-many-positional-arguments
        self,
        error_template: str,
        error_message: str,
        error_path: str,
        error_schema_path: str,
        json_schema: Dict[str, Any],
        failing_json: str,
    ) -> str:
        """
        Constructs an error recovery message using a specified template or the default one if none is provided.

        :param error_template: A custom template string for formatting the error message in case of validation failure.
        :param error_message: The error message returned by the JSON schema validator.
        :param error_path: The path in the JSON content where the error occurred.
        :param error_schema_path: The path in the JSON schema where the error occurred.
        :param json_schema: The JSON schema against which the content is validated.
        :param failing_json: The generated invalid JSON string.
        :returns: The template with all placeholders substituted via `str.format`.
        """
        # Falsy templates (None or "") fall back to the class-level default.
        error_template = error_template or self.default_error_template

        return error_template.format(
            error_message=error_message,
            error_path=error_path,
            error_schema_path=error_schema_path,
            json_schema=json_schema,
            failing_json=failing_json,
        )

    def _is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
        """
        Checks if the provided schema is a valid OpenAI function calling schema.

        The check only looks for the presence of the "name", "description" and "parameters" keys.

        :param json_schema: The JSON schema to check
        :returns: `True` if the schema is a valid OpenAI function calling schema; otherwise, `False`.
        """
        return all(key in json_schema for key in ["name", "description", "parameters"])

    def _recursive_json_to_object(self, data: Any) -> Any:
        """
        Convert any string values that are valid JSON objects into dictionary objects.

        Returns a new data structure; the input is not modified.

        The top-level value must be a dictionary or a list whose elements are themselves dictionaries or lists.
        Below a dictionary, string values that decode to a JSON object or array are replaced by the decoded
        structure, and nested dictionaries and lists are traversed; all other values are kept unchanged.

        :param data: The data structure to be traversed.
        :returns: A new data structure with JSON strings converted to dictionary objects.
        :raises ValueError: If `data` (or an element of a top-level list) is neither a dictionary nor a list.
        """
        if isinstance(data, list):
            return [self._recursive_json_to_object(item) for item in data]

        if isinstance(data, dict):
            return {key: self._convert_dict_value(value) for key, value in data.items()}

        # Scalars are not acceptable as the message payload itself.
        raise ValueError("Input must be a dictionary or a list of dictionaries.")

    def _convert_dict_value(self, value: Any) -> Any:
        """
        Convert a single dictionary value, decoding it first if it is a string holding a JSON object or array.

        :param value: The dictionary value to convert.
        :returns: The decoded and converted structure, or the original value if it is not a JSON container.
        """
        if isinstance(value, str):
            try:
                decoded = json.loads(value)
            except json.JSONDecodeError:
                return value
            if isinstance(decoded, (dict, list)):
                return self._convert_container(decoded)
            # Strings such as "1" or "true" decode to scalars; preserve the original string value.
            return value
        return self._convert_container(value)

    def _convert_container(self, value: Any) -> Any:
        """
        Traverse a nested dictionary or list without rejecting scalar elements.

        :param value: The value to traverse.
        :returns: A converted copy for dictionaries and lists; any other value unchanged.
        """
        if isinstance(value, dict):
            return {key: self._convert_dict_value(item) for key, item in value.items()}
        if isinstance(value, list):
            # List elements are only descended into; plain strings inside lists are left as they are.
            return [self._convert_container(item) for item in value]
        return value
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc