• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 15880205261

25 Jun 2025 03:12PM UTC coverage: 79.77% (+0.06%) from 79.708%
15880205261

push

github

web-flow
Improved error messages (#1838)

* initial

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Improve error messages

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix error

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix ruff

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Add more contextual error information

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix all tests to pass

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix some more tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix another test

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix some error and add contexts

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Fix some tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

* Update inference tests

Signed-off-by: elronbandel <elronbandel@gmail.com>

---------

Signed-off-by: elronbandel <elronbandel@gmail.com>
Co-authored-by: Yoav Katz <68273864+yoavkatz@users.noreply.github.com>

1722 of 2141 branches covered (80.43%)

Branch coverage included in aggregate %.

10699 of 13430 relevant lines covered (79.66%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.86
src/unitxt/task.py
1
import warnings
1✔
2
from typing import Any, Dict, List, Optional, Union
1✔
3

4
from .artifact import fetch_artifact
1✔
5
from .deprecation_utils import deprecation
1✔
6
from .error_utils import Documentation, UnitxtError, UnitxtWarning, error_context
1✔
7
from .logging_utils import get_logger
1✔
8
from .metrics import MetricsList
1✔
9
from .operator import InstanceOperator
1✔
10
from .operators import ArtifactFetcherMixin
1✔
11
from .settings_utils import get_constants, get_settings
1✔
12
from .templates import Template
1✔
13
from .type_utils import (
1✔
14
    Type,
15
    get_args,
16
    get_origin,
17
    is_type_dict,
18
    isoftype,
19
    parse_type_dict,
20
    parse_type_string,
21
    to_type_dict,
22
    to_type_string,
23
    verify_required_schema,
24
)
25

26
# Module-level handles shared by the code in this file.
# `constants` supplies reserved field/stream names read in Task.process
# (demos_field, instruction_field, system_prompt_field, inference_stream).
constants = get_constants()
# Project logger; not referenced by the code visible in this file.
logger = get_logger()
# `settings.context(...)` is used by Task.get_metrics_artifact_without_load
# to fetch metric artifacts without running their prepare/verify steps.
settings = get_settings()
29

30

31
@deprecation(
    version="2.0.0",
    msg="use python type instead of type strings (e.g Dict[str] instead of 'Dict[str]')",
)
def parse_string_types_instead_of_actual_objects(obj):
    """Convert legacy string type annotations into Python type objects.

    Deprecated: callers should pass Python types directly instead of strings.

    Args:
        obj: Either a dict mapping field names to type strings, or a single
            type string.

    Returns:
        ``parse_type_dict(obj)`` when ``obj`` is a dict, otherwise
        ``parse_type_string(obj)``.
    """
    parser = parse_type_dict if isinstance(obj, dict) else parse_type_string
    return parser(obj)
39

40

41
class Task(InstanceOperator, ArtifactFetcherMixin):
    """Task packs the different instance fields into dictionaries by their roles in the task.

    Args:
        input_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary mapping the names of the instance input fields to the types of their
            values. Type strings (e.g. "List[str]") are still accepted but deprecated.
            In case a list is passed, each type will be assumed to be Any.
        reference_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary mapping the names of the instance reference (output) fields to the
            types of their values. In case a list is passed, each type will be assumed to be Any.
        inputs, outputs:
            Deprecated aliases of 'input_fields' and 'reference_fields'. Setting an alias
            together with its replacement raises an error.
        metrics (List[str]):
            List of names of metrics to be used in the task. A single string is wrapped
            into a one-element list.
        prediction_type (Optional[Union[Type, str]]):
            Needs to be consistent with all used metrics. Defaults to None, which means that
            it will be set to Any.
        augmentable_inputs (List[str]):
            Names of input fields eligible for augmentation; each must be a key of
            'input_fields'.
        defaults (Optional[Dict[str, Any]]):
            An optional dictionary with default values for chosen input/reference keys. Needs
            to be consistent with names and types provided in 'input_fields' and/or
            'reference_fields' arguments. Will not overwrite values if already provided in a
            given instance.
        default_template (Template):
            Optional template associated with the task; must be a Template if set.

    The output instance contains the following fields:
        1. "input_fields" -- a sub-dictionary of the input instance, consisting of all the
           fields listed in Arg 'input_fields'.
        2. "reference_fields" -- the same for the fields listed in Arg 'reference_fields'
           (omitted for the inference stream).
        3. "metrics" -- the value of Arg 'metrics'.
        4. "data_classification_policy", "media" and "recipe_metadata" -- copied from the
           input instance, defaulting to [], {} and {} respectively.
        The demos, instruction and system-prompt fields are passed through when they are
        present in the input instance.
    """

    # Field name -> type mapping (or list of names) for the task's inputs.
    input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    # Field name -> type mapping (or list of names) for the task's references.
    reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    # Deprecated alias of input_fields.
    inputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    # Deprecated alias of reference_fields.
    outputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    metrics: List[str]
    prediction_type: Optional[Union[Type, str]] = None
    augmentable_inputs: List[str] = []
    defaults: Optional[Dict[str, Any]] = None
    default_template: Template = None

    def prepare_args(self):
        """Normalize constructor arguments before verification.

        Wraps a single metric string into a list, rejects conflicting deprecated
        aliases, validates ``default_template``, folds ``inputs``/``outputs`` into
        ``input_fields``/``reference_fields``, and parses legacy string type
        annotations into Python types.

        Raises:
            UnitxtError: if a field and its deprecated alias are both set, or if
                ``default_template`` is set but is not a ``Template``.
        """
        super().prepare_args()
        if isinstance(self.metrics, str):
            self.metrics = [self.metrics]

        if self.input_fields is not None and self.inputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is not None and self.outputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'outputs'. Use only 'reference_fields'",
                Documentation.ADDING_TASK,
            )

        if self.default_template is not None and not isoftype(
            self.default_template, Template
        ):
            raise UnitxtError(
                f"The task's 'default_template' attribute is not of type Template. The 'default_template' attribute is of type {type(self.default_template)}: {self.default_template}",
                Documentation.ADDING_TASK,
            )

        # Fall back to the deprecated aliases when the new names were not given.
        self.input_fields = (
            self.input_fields if self.input_fields is not None else self.inputs
        )
        self.reference_fields = (
            self.reference_fields if self.reference_fields is not None else self.outputs
        )

        # Legacy form: types given as strings, e.g. {"text": "str"}.
        if isoftype(self.input_fields, Dict[str, str]):
            self.input_fields = parse_string_types_instead_of_actual_objects(
                self.input_fields
            )
        if isoftype(self.reference_fields, Dict[str, str]):
            self.reference_fields = parse_string_types_instead_of_actual_objects(
                self.reference_fields
            )

        if isinstance(self.prediction_type, str):
            self.prediction_type = parse_string_types_instead_of_actual_objects(
                self.prediction_type
            )

        # Mirror the normalized values back onto the deprecated aliases, but only
        # when the alias was actually used.
        if hasattr(self, "inputs") and self.inputs is not None:
            self.inputs = self.input_fields

        if hasattr(self, "outputs") and self.outputs is not None:
            self.outputs = self.reference_fields

    def task_deprecations(self):
        """Emit a DeprecationWarning for each deprecated alias that is set."""
        if hasattr(self, "inputs") and self.inputs is not None:
            depr_message = (
                "The 'inputs' field is deprecated. Please use 'input_fields' instead."
            )
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
        if hasattr(self, "outputs") and self.outputs is not None:
            depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify(self):
        """Validate the task definition, filling in permissive defaults where allowed.

        Converts list-style (or non-type-dict) field specifications to dicts of
        ``Any``, defaults ``prediction_type`` to ``Any``, and checks metric
        compatibility, ``augmentable_inputs`` and ``defaults``.

        Raises:
            UnitxtError: if ``input_fields`` or ``reference_fields`` is missing, or
                if a metric's prediction type is incompatible (see check_metrics_type).
            AssertionError: if an augmentable input is not an input field, or if
                ``defaults`` is inconsistent with the field definitions.
        """
        self.task_deprecations()

        if self.input_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'input_fields' not set.",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'reference_fields' not set.",
                Documentation.ADDING_TASK,
            )
        for io_type in ["input_fields", "reference_fields"]:
            data = (
                self.input_fields
                if io_type == "input_fields"
                else self.reference_fields
            )

            if isinstance(data, list) or not is_type_dict(data):
                # NOTE(review): the UnitxtWarning object is constructed but not raised;
                # presumably it reports itself on construction -- confirm in error_utils.
                UnitxtWarning(
                    f"'{io_type}' field of Task should be a dictionary of field names and their types. "
                    f"For example, {{'text': str, 'classes': List[str]}}. Instead only '{data}' was "
                    f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
                    f"will raise an exception.",
                    Documentation.ADDING_TASK,
                )
                if isinstance(data, dict):
                    data = parse_type_dict(to_type_dict(data))
                else:
                    data = {key: Any for key in data}

                if io_type == "input_fields":
                    self.input_fields = data
                else:
                    self.reference_fields = data

        if not self.prediction_type:
            UnitxtWarning(
                "'prediction_type' was not set in Task. It is used to check the output of "
                "template post processors is compatible with the expected input of the metrics. "
                "Setting `prediction_type` to 'Any' (no checking is done). In future version "
                "of unitxt this will raise an exception.",
                Documentation.ADDING_TASK,
            )
            self.prediction_type = Any

        self.check_metrics_type()

        for augmentable_input in self.augmentable_inputs:
            assert (
                augmentable_input in self.input_fields
            ), f"augmentable_input {augmentable_input} is not part of {self.input_fields}"

        self.verify_defaults()

    @classmethod
    def process_data_after_load(cls, data):
        """Parse serialized type strings in loaded artifact data back into types.

        Args:
            data: The raw artifact dict; modified in place.

        Returns:
            The same dict, with field-type dicts and ``prediction_type`` parsed.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                data[dict_name] = parse_type_dict(data[dict_name])
        if "prediction_type" in data:
            data["prediction_type"] = parse_type_string(data["prediction_type"])
        return data

    def process_data_before_dump(self, data):
        """Convert Python types in artifact data into strings for serialization.

        The inverse of ``process_data_after_load``. Values already in string
        form are left untouched.

        Args:
            data: The artifact dict about to be dumped; modified in place.

        Returns:
            The same dict, with field-type dicts and ``prediction_type`` stringified.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                if not isoftype(data[dict_name], Dict[str, str]):
                    data[dict_name] = to_type_dict(data[dict_name])
        if "prediction_type" in data:
            if not isinstance(data["prediction_type"], str):
                data["prediction_type"] = to_type_string(data["prediction_type"])
        return data

    @classmethod
    def get_metrics_artifact_without_load(cls, metric_id: str):
        """Fetch a metric artifact without running its prepare/verify steps.

        Args:
            metric_id: Catalog identifier of the metric (or metrics list).

        Returns:
            A list of metric artifacts: the items of a ``MetricsList``, or a
            one-element list otherwise.
        """
        with settings.context(skip_artifacts_prepare_and_verify=True):
            metric, _ = fetch_artifact(metric_id)
        if isinstance(metric, MetricsList):
            return metric.items
        return [metric]

    def check_metrics_type(self) -> None:
        """Ensure the task's prediction type is compatible with every metric.

        Types are compatible if they are equal, if either is ``Any``, or if the
        metric's type is a ``Union`` that lists the task's type.

        Raises:
            UnitxtError: on the first incompatible metric.
        """
        prediction_type = self.prediction_type
        for metric_id in self.metrics:
            metric_artifacts_list = Task.get_metrics_artifact_without_load(metric_id)
            for metric_artifact in metric_artifacts_list:
                metric_prediction_type = metric_artifact.prediction_type
                if (
                    prediction_type == metric_prediction_type
                    or prediction_type == Any
                    or metric_prediction_type == Any
                    or (
                        get_origin(metric_prediction_type) is Union
                        and prediction_type in get_args(metric_prediction_type)
                    )
                ):
                    continue

                raise UnitxtError(
                    f"The task's prediction type ({prediction_type}) and '{metric_id}' "
                    f"metric's prediction type ({metric_prediction_type}) are different.",
                    Documentation.ADDING_TASK,
                )

    def verify_defaults(self):
        """Check that ``defaults`` matches the declared field names and types.

        Raises:
            UnitxtError: if ``defaults`` is set but is not a dict.
            AssertionError: if a key is not a string, is not a declared input or
                reference field, or if its value does not match the field's type.
        """
        if self.defaults:
            if not isinstance(self.defaults, dict):
                raise UnitxtError(
                    f"If specified, the 'defaults' must be a dictionary, "
                    f"however, '{self.defaults}' was provided instead, "
                    f"which is of type '{to_type_string(type(self.defaults))}'.",
                    Documentation.ADDING_TASK,
                )

            for default_name, default_value in self.defaults.items():
                assert isinstance(default_name, str), (
                    f"If specified, all keys of the 'defaults' must be strings, "
                    f"however, the key '{default_name}' is of type '{to_type_string(type(default_name))}'."
                )

                # Input fields take precedence when a name appears in both mappings.
                val_type = self.input_fields.get(
                    default_name
                ) or self.reference_fields.get(default_name)

                assert val_type, (
                    f"If specified, all keys of the 'defaults' must refer to a chosen "
                    f"key in either 'input_fields' or 'reference_fields'. However, the name '{default_name}' "
                    f"was provided which does not match any of the keys."
                )

                assert isoftype(default_value, val_type), (
                    f"The value of '{default_name}' from the 'defaults' must be of "
                    f"type '{to_type_string(val_type)}', however, it is of type '{to_type_string(type(default_value))}'."
                )

    def set_default_values(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new instance dict with ``defaults`` filled in for missing keys.

        Values already present in ``instance`` win over the defaults.
        """
        if self.defaults:
            instance = {**self.defaults, **instance}
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pack an instance's fields by role after verifying the schema.

        Args:
            instance: The raw instance dict.
            stream_name: Name of the stream being processed. For the inference
                stream, reference fields are neither verified nor emitted.

        Returns:
            A dict with "input_fields", "metrics", "data_classification_policy",
            "media", "recipe_metadata", the optional demos/instruction/system-prompt
            pass-through fields, and (except for inference) "reference_fields".
        """
        instance = self.set_default_values(instance)

        with error_context(
            self,
            stage="Schema Verification",
            help="https://www.unitxt.ai/en/latest/docs/adding_task.html",
        ):
            verify_required_schema(
                self.input_fields,
                instance,
                class_name="Task",
                id=self.__id__,
                description=self.__description__,
            )
        input_fields = {key: instance[key] for key in self.input_fields.keys()}
        data_classification_policy = instance.get("data_classification_policy", [])

        result = {
            "input_fields": input_fields,
            "metrics": self.metrics,
            "data_classification_policy": data_classification_policy,
            "media": instance.get("media", {}),
            "recipe_metadata": instance.get("recipe_metadata", {}),
        }
        if constants.demos_field in instance:
            # for the case of recipe.skip_demoed_instances
            result[constants.demos_field] = instance[constants.demos_field]

        if constants.instruction_field in instance:
            result[constants.instruction_field] = instance[constants.instruction_field]

        if constants.system_prompt_field in instance:
            result[constants.system_prompt_field] = instance[
                constants.system_prompt_field
            ]

        # Inference instances carry no references, so stop before checking them.
        if stream_name == constants.inference_stream:
            return result

        verify_required_schema(
            self.reference_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        result["reference_fields"] = {
            key: instance[key] for key in self.reference_fields.keys()
        }

        return result
337

338

339
@deprecation(version="2.0.0", alternative=Task)
class FormTask(Task):
    """Deprecated subclass of Task that adds nothing to it; use ``Task`` directly."""
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc