• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12652101016

07 Jan 2025 01:01PM UTC coverage: 80.239% (+0.005%) from 80.234%
12652101016

Pull #1487

github

web-flow
Merge ab2eec642 into 822eb5ad2
Pull Request #1487: Fix bug in metrics loading in tasks

1381 of 1712 branches covered (80.67%)

Branch coverage included in aggregate %.

8697 of 10848 relevant lines covered (80.17%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.81
src/unitxt/task.py
1
import warnings
1✔
2
from functools import lru_cache
1✔
3
from typing import Any, Dict, List, Optional, Union
1✔
4

5
from .deprecation_utils import deprecation
1✔
6
from .error_utils import Documentation, UnitxtError, UnitxtWarning
1✔
7
from .logging_utils import get_logger
1✔
8
from .metrics import MetricsList
1✔
9
from .operator import InstanceOperator
1✔
10
from .operators import ArtifactFetcherMixin
1✔
11
from .settings_utils import get_constants
1✔
12
from .templates import Template
1✔
13
from .type_utils import (
1✔
14
    Type,
15
    get_args,
16
    get_origin,
17
    is_type_dict,
18
    isoftype,
19
    parse_type_dict,
20
    parse_type_string,
21
    to_type_dict,
22
    to_type_string,
23
    verify_required_schema,
24
)
25

26
# Module-wide settings constants and logger, fetched once at import time.
constants, logger = get_constants(), get_logger()
1✔
28

29

30
@deprecation(
    version="2.0.0",
    msg="use python type instead of type strings (e.g Dict[str] instead of 'Dict[str]')",
)
def parse_string_types_instead_of_actual_objects(obj):
    """Convert type strings into actual type objects (deprecated path).

    A dict is handed to ``parse_type_dict``; anything else is handed to
    ``parse_type_string``. The result of the chosen parser is returned as-is.
    """
    parser = parse_type_dict if isinstance(obj, dict) else parse_type_string
    return parser(obj)
1✔
38

39

40
class Task(InstanceOperator, ArtifactFetcherMixin):
    """Task packs the different instance fields into dictionaries by their roles in the task.

    Args:
        input_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary mapping names of instance input fields to the types of their values.
            Types given as strings are parsed into type objects (deprecated usage).
            In case a list is passed, each type will be assumed to be Any.
        reference_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary mapping names of instance reference (output) fields to the types of
            their values. Types given as strings are parsed into type objects (deprecated usage).
            In case a list is passed, each type will be assumed to be Any.
        inputs:
            Deprecated alias of 'input_fields'; cannot be set together with it.
        outputs:
            Deprecated alias of 'reference_fields'; cannot be set together with it.
        metrics (List[str]):
            List of names of metrics to be used in the task. A single string is accepted
            and wrapped in a one-element list.
        prediction_type (Optional[Union[Type, str]]):
            Needs to be consistent with all used metrics. Defaults to None, which means that it
            will be set to Any. A type string is parsed into a type object (deprecated usage).
        augmentable_inputs (List[str]):
            Names of input fields marked as augmentable; each must be a key of 'input_fields'.
        defaults (Optional[Dict[str, Any]]):
            An optional dictionary with default values for chosen input/reference keys. Needs to be
            consistent with names and types provided in 'input_fields' and/or 'reference_fields'.
            Will not overwrite values if already provided in a given instance.
        default_template (Template):
            Optional default template of the task; if set, it must be a Template instance.

    The output instance contains the following fields:
        1. "input_fields" -- a sub-dictionary of the input instance, consisting of all the
           fields listed in Arg 'input_fields'.
        2. "reference_fields" -- the same, for the fields listed in Arg 'reference_fields'.
           Omitted when processing the inference stream.
        3. "metrics" -- the value of Arg 'metrics'.
        4. "data_classification_policy", "media", "recipe_metadata" -- copied from the input
           instance (defaulting to [], {} and {} respectively).
        5. "demos" -- copied from the input instance, only when present there.
    """

    input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    # Deprecated aliases of input_fields / reference_fields.
    inputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    outputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    metrics: List[str]
    prediction_type: Optional[Union[Type, str]] = None
    augmentable_inputs: List[str] = []
    defaults: Optional[Dict[str, Any]] = None
    default_template: Template = None

    def prepare_args(self):
        """Normalize the user-supplied attributes.

        Resolves the deprecated aliases, parses string types into type objects,
        and wraps a single metric name in a list.

        Raises:
            UnitxtError: if both a field and its deprecated alias are set, or if
                'default_template' is set to something that is not a Template.
        """
        super().prepare_args()
        if isinstance(self.metrics, str):
            self.metrics = [self.metrics]

        if self.input_fields is not None and self.inputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is not None and self.outputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'outputs'. Use only 'reference_fields'",
                Documentation.ADDING_TASK,
            )

        if self.default_template is not None and not isoftype(
            self.default_template, Template
        ):
            raise UnitxtError(
                f"The task's 'default_template' attribute is not of type Template. The 'default_template' attribute is of type {type(self.default_template)}: {self.default_template}",
                Documentation.ADDING_TASK,
            )

        # Fall back to the deprecated aliases when the new names are unset.
        self.input_fields = (
            self.input_fields if self.input_fields is not None else self.inputs
        )
        self.reference_fields = (
            self.reference_fields if self.reference_fields is not None else self.outputs
        )

        # Dicts whose values are all strings hold type strings; parse them.
        if isoftype(self.input_fields, Dict[str, str]):
            self.input_fields = parse_string_types_instead_of_actual_objects(
                self.input_fields
            )
        if isoftype(self.reference_fields, Dict[str, str]):
            self.reference_fields = parse_string_types_instead_of_actual_objects(
                self.reference_fields
            )

        if isinstance(self.prediction_type, str):
            self.prediction_type = parse_string_types_instead_of_actual_objects(
                self.prediction_type
            )

        # Keep the deprecated aliases in sync with the parsed values.
        if hasattr(self, "inputs") and self.inputs is not None:
            self.inputs = self.input_fields

        if hasattr(self, "outputs") and self.outputs is not None:
            self.outputs = self.reference_fields

    def task_deprecations(self):
        """Emit a DeprecationWarning for each deprecated alias ('inputs'/'outputs') that is set."""
        if hasattr(self, "inputs") and self.inputs is not None:
            depr_message = (
                "The 'inputs' field is deprecated. Please use 'input_fields' instead."
            )
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
        if hasattr(self, "outputs") and self.outputs is not None:
            depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify(self):
        """Validate the task definition, filling in permissive defaults where allowed.

        A list, or a dict that is not a proper type dict, for input/reference fields
        is converted to a type dict (list entries become ``Any``). A missing
        prediction type becomes ``Any``. Then it checks metric compatibility,
        augmentable inputs and defaults.

        Raises:
            UnitxtError: if 'input_fields' or 'reference_fields' is missing, or if a
                metric's prediction type is incompatible with the task's.
            AssertionError: if an augmentable input or a default is invalid.
        """
        self.task_deprecations()

        if self.input_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'input_fields' not set.",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'reference_fields' not set.",
                Documentation.ADDING_TASK,
            )

        for io_type in ("input_fields", "reference_fields"):
            data = getattr(self, io_type)

            if isinstance(data, list) or not is_type_dict(data):
                # NOTE(review): the warning object is constructed, not raised;
                # presumably UnitxtWarning emits the warning on construction — confirm.
                UnitxtWarning(
                    f"'{io_type}' field of Task should be a dictionary of field names and their types. "
                    f"For example, {{'text': str, 'classes': List[str]}}. Instead only '{data}' was "
                    f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
                    f"will raise an exception.",
                    Documentation.ADDING_TASK,
                )
                if isinstance(data, dict):
                    data = parse_type_dict(to_type_dict(data))
                else:
                    data = {key: Any for key in data}

                setattr(self, io_type, data)

        if not self.prediction_type:
            UnitxtWarning(
                "'prediction_type' was not set in Task. It is used to check the output of "
                "template post processors is compatible with the expected input of the metrics. "
                "Setting `prediction_type` to 'Any' (no checking is done). In future version "
                "of unitxt this will raise an exception.",
                Documentation.ADDING_TASK,
            )
            self.prediction_type = Any

        self.check_metrics_type()

        for augmentable_input in self.augmentable_inputs:
            assert (
                augmentable_input in self.input_fields
            ), f"augmentable_input {augmentable_input} is not part of {self.input_fields}"

        self.verify_defaults()

    @classmethod
    def process_data_after_load(cls, data):
        """Parse serialized type strings in loaded artifact data back into type objects."""
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                data[dict_name] = parse_type_dict(data[dict_name])
        if "prediction_type" in data:
            data["prediction_type"] = parse_type_string(data["prediction_type"])
        return data

    def process_data_before_dump(self, data):
        """Serialize type objects in artifact data to type strings before dumping.

        Values that are already strings (or dicts of strings) are left as-is.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                if not isoftype(data[dict_name], Dict[str, str]):
                    data[dict_name] = to_type_dict(data[dict_name])
        if "prediction_type" in data:
            if not isinstance(data["prediction_type"], str):
                data["prediction_type"] = to_type_string(data["prediction_type"])
        return data

    @classmethod
    @lru_cache(maxsize=None)
    def get_metrics_artifacts(cls, metric_id: str):
        """Fetch the metric artifact(s) for ``metric_id``, memoized per (cls, metric_id).

        Returns:
            The items of a MetricsList, or a one-element list with the single metric.
            The cached list object is shared between calls; do not mutate it.
        """
        metric = cls.get_artifact(metric_id)
        if isinstance(metric, MetricsList):
            return metric.items
        return [metric]

    def check_metrics_type(self) -> None:
        """Verify the task's prediction type is compatible with every metric's prediction type.

        Types are compatible if they are equal, if either is ``Any``, or if the
        metric's type is a ``Union`` that contains the task's type.

        Raises:
            UnitxtError: on the first incompatible metric.
        """
        prediction_type = self.prediction_type
        for metric_id in self.metrics:
            # Called on Task (not type(self)) so all subclasses share one cache.
            metric_artifacts_list = Task.get_metrics_artifacts(metric_id)
            for metric_artifact in metric_artifacts_list:
                metric_prediction_type = metric_artifact.prediction_type
                if (
                    prediction_type == metric_prediction_type
                    or prediction_type == Any
                    or metric_prediction_type == Any
                    or (
                        get_origin(metric_prediction_type) is Union
                        and prediction_type in get_args(metric_prediction_type)
                    )
                ):
                    continue

                raise UnitxtError(
                    f"The task's prediction type ({prediction_type}) and '{metric_id}' "
                    f"metric's prediction type ({metric_prediction_type}) are different.",
                    Documentation.ADDING_TASK,
                )

    def verify_defaults(self):
        """Check that 'defaults' (if set) is a dict of known field names to values of the right type.

        Raises:
            UnitxtError: if 'defaults' is not a dict.
            AssertionError: if a key is not a string, does not name an input or
                reference field, or its value does not match the field's type.
        """
        if self.defaults:
            if not isinstance(self.defaults, dict):
                raise UnitxtError(
                    f"If specified, the 'defaults' must be a dictionary, "
                    f"however, '{self.defaults}' was provided instead, "
                    f"which is of type '{to_type_string(type(self.defaults))}'.",
                    Documentation.ADDING_TASK,
                )

            for default_name, default_value in self.defaults.items():
                assert isinstance(default_name, str), (
                    f"If specified, all keys of the 'defaults' must be strings, "
                    f"however, the key '{default_name}' is of type '{to_type_string(type(default_name))}'."
                )

                # Input fields take precedence over reference fields of the same name.
                val_type = self.input_fields.get(
                    default_name
                ) or self.reference_fields.get(default_name)

                assert val_type, (
                    f"If specified, all keys of the 'defaults' must refer to a chosen "
                    f"key in either 'input_fields' or 'reference_fields'. However, the name '{default_name}' "
                    f"was provided which does not match any of the keys."
                )

                assert isoftype(default_value, val_type), (
                    f"The value of '{default_name}' from the 'defaults' must be of "
                    f"type '{to_type_string(val_type)}', however, it is of type '{to_type_string(type(default_value))}'."
                )

    def set_default_values(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict with 'defaults' filled in; values already in ``instance`` win."""
        if self.defaults:
            instance = {**self.defaults, **instance}
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pack ``instance`` fields into the task's output structure.

        Args:
            instance: The raw instance dictionary.
            stream_name: The name of the stream; for the inference stream the
                reference fields are neither verified nor included.

        Returns:
            A new dict with "input_fields", "metrics", "data_classification_policy",
            "media", "recipe_metadata", optionally "demos" and, except on the
            inference stream, "reference_fields".
        """
        instance = self.set_default_values(instance)

        verify_required_schema(
            self.input_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        input_fields = {key: instance[key] for key in self.input_fields.keys()}
        data_classification_policy = instance.get("data_classification_policy", [])

        result = {
            "input_fields": input_fields,
            "metrics": self.metrics,
            "data_classification_policy": data_classification_policy,
            "media": instance.get("media", {}),
            "recipe_metadata": instance.get("recipe_metadata", {}),
        }
        if "demos" in instance:
            # for the case of recipe.skip_demoed_instances
            result["demos"] = instance["demos"]

        if stream_name == constants.inference_stream:
            return result

        verify_required_schema(
            self.reference_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        result["reference_fields"] = {
            key: instance[key] for key in self.reference_fields.keys()
        }

        return result
1✔
323

324

325
@deprecation(version="2.0.0", alternative=Task)
class FormTask(Task):
    # Deprecated alias of Task kept for backward compatibility; adds no behavior.
    ...
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc