• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 14706437845

28 Apr 2025 11:16AM UTC coverage: 80.149% (+0.1%) from 80.035%
14706437845

Pull #1764

github

web-flow
Merge 1ce583a5e into 29ef085a0
Pull Request #1764: Add tool calling support + Berkeley Tool Calling Benchmark (simple-v3)

1643 of 2034 branches covered (80.78%)

Branch coverage included in aggregate %.

10268 of 12827 relevant lines covered (80.05%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.13
src/unitxt/schema.py
1
import json
1✔
2
from typing import Any, Dict, List, Optional
1✔
3

4
from datasets import Audio, Features, Sequence, Value
1✔
5
from datasets import Image as DatasetImage
1✔
6

7
from .artifact import Artifact
1✔
8
from .dict_utils import dict_get
1✔
9
from .image_operators import ImageDataString
1✔
10
from .operator import InstanceOperatorValidator
1✔
11
from .settings_utils import get_constants, get_settings
1✔
12
from .type_utils import isoftype
1✔
13
from .types import Image
1✔
14

15
# Library-wide constants and user-configurable settings singletons
# (resolved once at import time; see .settings_utils).
constants = get_constants()
settings = get_settings()

# HuggingFace `datasets` schema of a fully processed evaluation instance.
# Scalar fields are strings ("task_data" is a JSON-encoded string); media
# payloads use the native datasets Image/Audio features.
UNITXT_DATASET_SCHEMA = Features(
    {
        "source": Value("string"),
        "target": Value("string"),
        "references": Sequence(Value("string")),
        "metrics": Sequence(Value("string")),
        "groups": Sequence(Value("string")),
        "subset": Sequence(Value("string")),
        "media": {
            "images": Sequence(DatasetImage()),
            "audios": Sequence(Audio()),
        },
        "postprocessors": Sequence(Value("string")),
        "task_data": Value(dtype="string"),
        "data_classification_policy": Sequence(Value("string")),
    }
)
35

36
# Schema of an instance prepared for inference only: like
# UNITXT_DATASET_SCHEMA but without "target"/"references", since no ground
# truth exists at inference time.
# NOTE(review): "images" here is built from `.types.Image` (a typing
# construct, so `Image()` yields an empty dict feature), whereas
# UNITXT_DATASET_SCHEMA uses the datasets `Image` feature (`DatasetImage`).
# Confirm this asymmetry is intentional.
UNITXT_INFERENCE_SCHEMA = Features(
    {
        "source": Value("string"),
        "metrics": Sequence(Value("string")),
        "groups": Sequence(Value("string")),
        "subset": Sequence(Value("string")),
        "postprocessors": Sequence(Value("string")),
        "task_data": Value(dtype="string"),
        "data_classification_policy": Sequence(Value("string")),
        "media": {
            "images": Sequence(Image()),
            "audios": Sequence(Audio()),
        },
    }
)
51

52

53
def get_schema(stream_name):
    """Return the datasets Features schema appropriate for a stream.

    The dedicated inference stream gets the target-less inference schema;
    every other stream gets the full dataset schema.
    """
    is_inference = stream_name == constants.inference_stream
    return UNITXT_INFERENCE_SCHEMA if is_inference else UNITXT_DATASET_SCHEMA
57

58

59
def load_chat_source(chat_str):
    """Deserialize a JSON chat string into a list of turn dicts.

    In multi-part turn contents, every "image_url" part has its URL wrapped
    in ImageDataString so downstream code can recognize inline image data.
    """
    turns = json.loads(chat_str)
    for turn in turns:
        content = turn["content"]
        if not isinstance(content, list):
            continue
        for part in content:
            if part["type"] != "image_url":
                continue
            part["image_url"]["url"] = ImageDataString(part["image_url"]["url"])
    return turns
69

70
def loads_batch(batch):
    """Decode serialized columns of a batch in place and return it.

    A "source" column whose entries look like JSON chat lists is parsed
    back into chat structures; a "task_data" column of JSON strings is
    parsed into dicts unless settings keep task data as text.
    """
    if "source" in batch:
        head = batch["source"][0]
        if isinstance(head, str) and head.startswith(
            ('[{"role":', '[{"content":')
        ):
            batch["source"] = [load_chat_source(s) for s in batch["source"]]
    if not settings.task_data_as_text and "task_data" in batch:
        if isinstance(batch["task_data"][0], str):
            batch["task_data"] = [json.loads(s) for s in batch["task_data"]]
    return batch
87

88
def loads_instance(instance):
    """Decode serialized fields of a single instance in place and return it.

    Mirrors loads_batch for one instance: a JSON-chat "source" string is
    parsed into a chat list, and a "task_data" JSON string is parsed into
    a dict unless settings keep task data as text.
    """
    source = instance.get("source")
    if isinstance(source, str) and source.startswith(
        ('[{"role":', '[{"content":')
    ):
        instance["source"] = load_chat_source(source)
    if not settings.task_data_as_text:
        task_data = instance.get("task_data")
        if isinstance(task_data, str):
            instance["task_data"] = json.loads(task_data)
    return instance
105

106

107
class FinalizeDataset(InstanceOperatorValidator):
    """Final recipe operator: casts each instance into the unitxt schema.

    Builds the ``task_data`` payload, attaches recipe metadata, computes
    ``groups`` labels from ``group_by``, normalizes media fields, serializes
    artifact metrics/postprocessors, and removes fields that are not part of
    the target schema (see get_schema).
    """

    # Each entry is an attribute name (or list of names) whose values define
    # one grouping of instances.
    group_by: List[List[str]]
    # When True, drop instance keys that are absent from the target schema.
    remove_unnecessary_fields: bool = True

    @staticmethod
    def artifact_to_jsonable(artifact):
        """Return the artifact's catalog id when it has one, else its dict form."""
        if artifact.__id__ is None:
            return artifact.to_dict()
        return artifact.__id__

    def _prepare_media(self, instance):
        """Ensure media sub-fields exist and unwrap typed Image entries."""
        if "media" not in instance:
            instance["media"] = {}

        if "images" not in instance["media"]:
            instance["media"]["images"] = []

        if "audios" not in instance["media"]:
            instance["media"]["audios"] = []

        # Typed Image dicts are replaced by their raw "image" payload so they
        # fit the datasets Image feature of the schema.
        for i in range(len(instance["media"]["images"])):
            if isoftype(instance["media"]["images"][i], Image):
                instance["media"]["images"][i] = instance["media"]["images"][i]["image"]

        return instance

    def _get_instance_task_data(
        self, instance: Dict[str, Any], use_reference_fields=True
    ) -> Dict[str, Any]:
        """Collect task data: input fields, classification-policy metadata,
        optionally reference fields, and any ``__tools__`` payload."""
        task_data = {
            **instance["input_fields"],
            "metadata": {
                "data_classification_policy": instance["data_classification_policy"],
            },
        }
        if use_reference_fields:
            task_data = {**task_data, **instance["reference_fields"]}

        if "__tools__" in instance:
            task_data["__tools__"] = instance["__tools__"]
        return task_data

    def serialize_instance_fields(self, instance, task_data):
        """Serialize task_data (when configured) and non-string sources to JSON."""
        if settings.task_data_as_text:
            instance["task_data"] = json.dumps(task_data)

        # Chat-format sources are lists of turns; the schema stores "source"
        # as a single JSON string.
        if not isinstance(instance["source"], str):
            instance["source"] = json.dumps(instance["source"])
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Finalize one instance into the schema selected by ``stream_name``."""
        # Inference streams carry no ground truth, so reference fields are skipped.
        task_data = self._get_instance_task_data(
            instance,
            use_reference_fields=stream_name != constants.inference_stream,
        )

        # Recipe-level metadata travels inside task_data["metadata"].
        task_data["metadata"]["num_demos"] = instance["recipe_metadata"]["num_demos"]
        task_data["metadata"]["demos_pool_size"] = instance["recipe_metadata"][
            "demos_pool_size"
        ]
        task_data["metadata"]["template"] = self.artifact_to_jsonable(
            instance["recipe_metadata"]["template"]
        )
        if "criteria" in task_data and isinstance(task_data["criteria"], Artifact):
            task_data["criteria"] = self.artifact_to_jsonable(task_data["criteria"])
        if constants.demos_field in instance:
            # Demos are popped off the instance; each demo keeps its reference
            # fields (default use_reference_fields=True).
            task_data[constants.demos_field] = [
                self._get_instance_task_data(instance)
                for instance in instance.pop(constants.demos_field)
            ]

        instance = self.serialize_instance_fields(instance, task_data)

        if self.remove_unnecessary_fields:
            keys_to_delete = []

            for key in instance.keys():
                if key not in get_schema(stream_name):
                    keys_to_delete.append(key)

            for key in keys_to_delete:
                del instance[key]

        # Group labels are JSON strings of {attribute: value} looked up in
        # task_data and its metadata (metadata keys flattened to top level).
        data = {**task_data, **task_data["metadata"]}
        groups = []
        for group_attributes in self.group_by:
            group = {}
            if isinstance(group_attributes, str):
                group_attributes = [group_attributes]
            for attribute in group_attributes:
                group[attribute] = dict_get(data, attribute)
            groups.append(json.dumps(group))

        instance["groups"] = groups
        instance["subset"] = []

        instance = self._prepare_media(instance)

        # Artifact metrics/postprocessors are serialized to JSON strings so
        # they fit the Sequence(Value("string")) schema fields.
        instance["metrics"] = [
            metric.to_json() if isinstance(metric, Artifact) else metric
            for metric in instance["metrics"]
        ]
        instance["postprocessors"] = [
            processor.to_json() if isinstance(processor, Artifact) else processor
            for processor in instance["postprocessors"]
        ]

        return instance

    def validate(self, instance: Dict[str, Any], stream_name: Optional[str] = None):
        """Assert the finalized instance has, and is encodable by, the schema."""
        # verify the instance has the required schema
        assert instance is not None, "Instance is None"
        assert isinstance(
            instance, dict
        ), f"Instance should be a dict, got {type(instance)}"
        schema = get_schema(stream_name)
        assert all(
            key in instance for key in schema
        ), f"Instance should have the following keys: {schema}. Instance is: {instance}"
        schema.encode_example(instance)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc