• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 12883566116

21 Jan 2025 09:07AM UTC coverage: 91.3% (-0.006%) from 91.306%
12883566116

push

github

web-flow
build: add `jsonschema` library to core dependencies (#8753)

* add jsonschema to core dependencies

* release note

8847 of 9690 relevant lines covered (91.3%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.73
haystack/components/validators/json_schema.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import json
1✔
6
from typing import Any, Dict, List, Optional
1✔
7

8
from jsonschema import ValidationError, validate
1✔
9

10
from haystack import component
1✔
11
from haystack.dataclasses import ChatMessage
1✔
12

13

14
def is_valid_json(s: str) -> bool:
    """
    Tell whether a string can be parsed as JSON.

    :param s: The candidate string.
    :returns: `True` when `json.loads` accepts the string; `False` when parsing raises a `ValueError`.
    """
    try:
        json.loads(s)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so malformed input lands here.
        parsed_ok = False
    else:
        parsed_ok = True
    return parsed_ok
1✔
26

27

28
@component
class JsonSchemaValidator:
    """
    Validates JSON content of `ChatMessage` against a specified [JSON Schema](https://json-schema.org/).

    If JSON content of a message conforms to the provided schema, the message is passed along the "validated" output.
    If the JSON content does not conform to the schema, the message is passed along the "validation_error" output.
    In the latter case, the error message is constructed using the provided `error_template` or a default template.
    These error ChatMessages can be used by LLMs in Haystack 2.x recovery loops.

    Usage example:

    ```python
    from typing import List

    from haystack import Pipeline
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.components.joiners import BranchJoiner
    from haystack.components.validators import JsonSchemaValidator
    from haystack import component
    from haystack.dataclasses import ChatMessage


    @component
    class MessageProducer:

        @component.output_types(messages=List[ChatMessage])
        def run(self, messages: List[ChatMessage]) -> dict:
            return {"messages": messages}


    p = Pipeline()
    p.add_component("llm", OpenAIChatGenerator(model="gpt-4-1106-preview",
                                               generation_kwargs={"response_format": {"type": "json_object"}}))
    p.add_component("schema_validator", JsonSchemaValidator())
    p.add_component("joiner_for_llm", BranchJoiner(List[ChatMessage]))
    p.add_component("message_producer", MessageProducer())

    p.connect("message_producer.messages", "joiner_for_llm")
    p.connect("joiner_for_llm", "llm")
    p.connect("llm.replies", "schema_validator.messages")
    p.connect("schema_validator.validation_error", "joiner_for_llm")

    result = p.run(data={
        "message_producer": {
            "messages": [ChatMessage.from_user("Generate JSON for person with name 'John' and age 30")]
        },
        "schema_validator": {
            "json_schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"}
                }
            }
        }
    })
    print(result)
    >> {'schema_validator': {'validated': [ChatMessage(content='\\n{\\n  "name": "John",\\n  "age": 30\\n}',
    role=<ChatRole.ASSISTANT: 'assistant'>, name=None, meta={'model': 'gpt-4-1106-preview', 'index': 0,
    'finish_reason': 'stop', 'usage': {'completion_tokens': 17, 'prompt_tokens': 20, 'total_tokens': 37}})]}}
    ```
    """

    # Default error description template
    default_error_template = (
        "The following generated JSON does not conform to the provided schema.\n"
        "Generated JSON: {failing_json}\n"
        "Error details:\n- Message: {error_message}\n"
        "- Error Path in JSON: {error_path}\n"
        "- Schema Path: {error_schema_path}\n"
        "Please match the following schema:\n"
        "{json_schema}\n"
        "and provide the corrected JSON content ONLY. Please do not output anything else than the raw corrected "
        "JSON string, this is the most important part of the task. Don't use any markdown and don't add any comment."
    )

    def __init__(self, json_schema: Optional[Dict[str, Any]] = None, error_template: Optional[str] = None):
        """
        Initialize the JsonSchemaValidator component.

        :param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/) against which
            the messages' content is validated.
        :param error_template: A custom template string for formatting the error message in case of validation failure.
        """
        self.json_schema = json_schema
        self.error_template = error_template

    @component.output_types(validated=List[ChatMessage], validation_error=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        json_schema: Optional[Dict[str, Any]] = None,
        error_template: Optional[str] = None,
    ) -> Dict[str, List[ChatMessage]]:
        """
        Validates the last of the provided messages against the specified json schema.

        If the message conforms to the schema, it is passed along the "validated" output. If it does not (or if its
        text is not parseable JSON at all), a recovery prompt is passed along the "validation_error" output.

        :param messages: A list of ChatMessage instances to be validated. The last message in this list is the one
            that is validated.
        :param json_schema: A dictionary representing the [JSON schema](https://json-schema.org/)
            against which the messages' content is validated. If not provided, the schema from the component init
            is used.
        :param error_template: A custom template string for formatting the error message in case of validation
            failure. If not provided, the `error_template` from the component init is used, and the class-level
            `default_error_template` after that.
        :return:  A dictionary with the following keys:
            - "validated": A list containing the last message, if it is valid.
            - "validation_error": A list containing a user message describing the problem, if it is invalid.
        :raises ValueError: If the last message has no text, if no JSON schema is provided, or if the message content
            is not a dictionary or a list of dictionaries.
        :raises IndexError: If `messages` is empty.
        """
        last_message = messages[-1]
        if last_message.text is None:
            raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {last_message}")
        if not is_valid_json(last_message.text):
            return {
                "validation_error": [
                    ChatMessage.from_user(
                        f"The message '{last_message.text}' is not a valid JSON object. "
                        "Please provide only a valid JSON object in string format. "
                        "Don't use any markdown and don't add any comment."
                    )
                ]
            }

        last_message_content = json.loads(last_message.text)
        json_schema = json_schema or self.json_schema
        # Precedence: run-time argument, then init-time value, then the class default.
        error_template = error_template or self.error_template or self.default_error_template

        if not json_schema:
            raise ValueError("Provide a JSON schema for validation either in the run method or in the component init.")
        # fc payload is json object but subtree `parameters` is string - we need to convert to json object
        # we need complete json to validate it against schema
        last_message_json = self._recursive_json_to_object(last_message_content)
        using_openai_schema: bool = self._is_openai_function_calling_schema(json_schema)
        # With a function-calling schema only its `parameters` subtree is an actual JSON schema.
        validation_schema = json_schema["parameters"] if using_openai_schema else json_schema

        # A single object is validated as a one-element list so both shapes share the same loop.
        contents = last_message_json if isinstance(last_message_json, list) else [last_message_json]
        try:
            for content in contents:
                if using_openai_schema:
                    # Only the call arguments are checked; a payload without these keys raises KeyError.
                    validate(instance=content["function"]["arguments"], schema=validation_schema)
                else:
                    validate(instance=content, schema=validation_schema)
        except ValidationError as e:
            error_path = " -> ".join(map(str, e.absolute_path)) if e.absolute_path else "N/A"
            error_schema_path = " -> ".join(map(str, e.absolute_schema_path)) if e.absolute_schema_path else "N/A"

            recovery_prompt = self._construct_error_recovery_message(
                error_template, str(e), error_path, error_schema_path, validation_schema, failing_json=last_message.text
            )
            return {"validation_error": [ChatMessage.from_user(recovery_prompt)]}

        return {"validated": [last_message]}

    def _construct_error_recovery_message(  # pylint: disable=too-many-positional-arguments
        self,
        error_template: str,
        error_message: str,
        error_path: str,
        error_schema_path: str,
        json_schema: Dict[str, Any],
        failing_json: str,
    ) -> str:
        """
        Constructs an error recovery message using a specified template or the default one if none is provided.

        :param error_template: A custom template string for formatting the error message in case of validation failure.
            An empty value falls back to `default_error_template`.
        :param error_message: The error message returned by the JSON schema validator.
        :param error_path: The path in the JSON content where the error occurred.
        :param error_schema_path: The path in the JSON schema where the error occurred.
        :param json_schema: The JSON schema against which the content is validated.
        :param failing_json: The generated invalid JSON string.
        :return: The template with its `{error_message}`, `{error_path}`, `{error_schema_path}`, `{json_schema}` and
            `{failing_json}` placeholders filled in.
        """
        error_template = error_template or self.default_error_template

        return error_template.format(
            error_message=error_message,
            error_path=error_path,
            error_schema_path=error_schema_path,
            json_schema=json_schema,
            failing_json=failing_json,
        )

    def _is_openai_function_calling_schema(self, json_schema: Dict[str, Any]) -> bool:
        """
        Checks if the provided schema looks like an OpenAI function calling schema.

        The check is purely structural: the schema must have the "name", "description" and "parameters" keys.

        :param json_schema: The JSON schema to check
        :return: `True` if all three keys are present; otherwise, `False`.
        """
        return all(key in json_schema for key in ["name", "description", "parameters"])

    def _recursive_json_to_object(self, data: Any) -> Any:
        """
        Convert string values that hold JSON objects or arrays into the corresponding Python objects.

        The input itself must be a dictionary or a list whose items are dictionaries (or further lists of them).
        Inside dictionaries, values are converted at any nesting depth. Returns a new data structure; the input is
        not modified.

        :param data: The data structure to be traversed.
        :return: A new data structure with JSON strings converted to dictionaries/lists.
        :raises ValueError: If `data`, or an item of a top-level list, is neither a dictionary nor a list.
        """
        if isinstance(data, list):
            return [self._recursive_json_to_object(item) for item in data]

        if isinstance(data, dict):
            return {key: self._convert_nested_value(value) for key, value in data.items()}

        # Scalars are not acceptable as the message content (or as items of the top-level list).
        raise ValueError("Input must be a dictionary or a list of dictionaries.")

    def _convert_nested_value(self, value: Any) -> Any:
        """
        Convert a single dictionary value, decoding it first if it is a JSON-encoded object or array.

        :param value: A value taken from a dictionary.
        :return: The converted value; strings that are not JSON objects/arrays are returned unchanged.
        """
        if isinstance(value, str):
            try:
                decoded = json.loads(value)
            except json.JSONDecodeError:
                return value
            if isinstance(decoded, (dict, list)):
                return self._convert_container(decoded)
            # Strings such as "42" or "true" parse as JSON scalars; preserve the original string value.
            return value
        return self._convert_container(value)

    def _convert_container(self, value: Any) -> Any:
        """
        Rebuild a nested dictionary/list, converting dictionary values along the way.

        Unlike the top-level entry point, scalars are allowed here and are returned unchanged, so nested arrays of
        plain values (e.g. `[1, 2]`) do not raise.

        :param value: A nested value of any type.
        :return: A new dictionary/list for containers; the value itself otherwise.
        """
        if isinstance(value, dict):
            return {key: self._convert_nested_value(item) for key, item in value.items()}
        if isinstance(value, list):
            return [self._convert_container(item) for item in value]
        return value
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc