• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14373748478

10 Apr 2025 06:24AM UTC coverage: 90.378% (+0.03%) from 90.353%
14373748478

Pull #9195

github

web-flow
Merge b41725749 into 65a4b7406
Pull Request #9195: fix: Set `messages` in state_schema at init time of Agent

10670 of 11806 relevant lines covered (90.38%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.25
haystack/dataclasses/state.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from copy import deepcopy
1✔
6
from typing import Any, Callable, Dict, List, Optional
1✔
7

8
from haystack.dataclasses import ChatMessage
1✔
9
from haystack.dataclasses.state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values
1✔
10
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
11
from haystack.utils.type_serialization import deserialize_type, serialize_type
1✔
12

13

14
def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Produce a serializable copy of a state schema.

    The type of every entry is passed through `serialize_type`. A handler is only included in the
    output when the entry defines a truthy one, in which case it is passed through `serialize_callable`.

    :param schema: Dictionary mapping parameter names to their type and handler configs
    :returns: Dictionary mapping each parameter name to its serialized type and, if set, serialized handler
    """
    result: Dict[str, Any] = {}
    for name, spec in schema.items():
        entry = {"type": serialize_type(spec["type"])}
        handler = spec.get("handler")
        if handler:
            entry["handler"] = serialize_callable(handler)
        result[name] = entry
    return result
1✔
31

32

33
def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a state schema from its serialized form.

    The type of every entry is passed through `deserialize_type`. A handler is only restored when the
    serialized entry carries a truthy one, in which case it is passed through `deserialize_callable`.

    :param schema: Dictionary containing serialized schema information
    :returns: Dictionary mapping each parameter name to its deserialized type and, if set, handler
    """
    result: Dict[str, Any] = {}
    for name, spec in schema.items():
        entry = {"type": deserialize_type(spec["type"])}
        serialized_handler = spec.get("handler")
        if serialized_handler:
            entry["handler"] = deserialize_callable(serialized_handler)
        result[name] = entry
    return result
1✔
51

52

53
def _validate_schema(schema: Dict[str, Any]) -> None:
1✔
54
    """
55
    Validate that a schema dictionary meets all required constraints.
56

57
    Checks that each parameter definition has a valid type field and that any handler
58
    specified is a callable function.
59

60
    :param schema: Dictionary mapping parameter names to their type and handler configs
61
    :raises ValueError: If schema validation fails due to missing or invalid fields
62
    """
63
    for param, definition in schema.items():
1✔
64
        if "type" not in definition:
1✔
65
            raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
1✔
66
        if not _is_valid_type(definition["type"]):
1✔
67
            raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
1✔
68
        if definition.get("handler") is not None and not callable(definition["handler"]):
1✔
69
            raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
1✔
70
        if param == "messages" and definition["type"] is not List[ChatMessage]:
1✔
71
            raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")
×
72

73

74
class State:
    """
    A container that holds a validated schema and an internal `_data` dictionary.

    Each schema entry has:
      "parameter_name": {
        "type": SomeType,
        "handler": Optional[Callable[[Any, Any], Any]]
      }

    A handler is called as `handler(current_value, new_value)` and its return value is what gets stored.
    """

    def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
        """
        Initialize a State object with a schema and optional data.

        The schema is validated and deep-copied, so later changes to the passed dictionary do not affect
        the state. If the schema has no "messages" entry, one of type `List[ChatMessage]` with the
        `merge_lists` handler is added.

        :param schema: Dictionary mapping parameter names to their type and handler configs.
            Type must be a valid Python type, and handler must be a callable function or None.
            If handler is None, the default handler for the type will be used. The default handlers are:
                - For list types: `haystack.dataclasses.state_utils.merge_lists`
                - For all other types: `haystack.dataclasses.state_utils.replace_values`
        :param data: Optional dictionary of initial data to populate the state. A non-empty dictionary is
            stored as-is (not copied); None or an empty dictionary results in a fresh empty dictionary.
        :raises ValueError: If the schema fails validation
        """
        _validate_schema(schema)
        self.schema = deepcopy(schema)
        if self.schema.get("messages") is None:
            self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
        self._data = data or {}

        # Set default handlers if not provided in schema
        for definition in self.schema.values():
            # Skip if handler is already defined and not None
            if definition.get("handler") is not None:
                continue
            # Set default handler based on type
            if _is_list_type(definition["type"]):
                definition["handler"] = merge_lists
            else:
                definition["handler"] = replace_values

    def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from the state by key.

        The result is a deep copy, so mutating it does not change the stored value.

        :param key: Key to look up in the state
        :param default: Value to return if key is not found
        :returns: Deep copy of the value associated with key, or of default if key is not found
        """
        return deepcopy(self._data.get(key, default))

    def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
        """
        Set or merge a value in the state according to schema rules.

        Value is merged or overwritten according to these rules:
          - if handler_override is given, use that
          - else use the handler defined in the schema for 'key'

        :param key: Key to store the value under
        :param value: Value to store or merge
        :param handler_override: Optional function to override the default merge behavior
        :raises ValueError: If key is not defined in the schema
        """
        # If key not in schema, we throw an error
        definition = self.schema.get(key, None)
        if definition is None:
            raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")

        # Choose the handler with an explicit None check (as done for schema handlers) so that a
        # supplied override is honoured even if the callable object happens to evaluate as falsy.
        handler = handler_override if handler_override is not None else definition["handler"]

        # Get current value from state (None if unset) and apply handler
        current_value = self._data.get(key, None)
        self._data[key] = handler(current_value, value)

    @property
    def data(self) -> Dict[str, Any]:
        """
        All current data of the state.

        This is the internal dictionary itself, not a copy.
        """
        return self._data

    def has(self, key: str) -> bool:
        """
        Check if a key exists in the state.

        :param key: Key to check for existence
        :returns: True if key exists in state, False otherwise
        """
        return key in self._data
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc