• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12618578668

05 Jan 2025 10:05AM UTC coverage: 80.02% (-0.01%) from 80.031%
12618578668

Pull #1470

github

web-flow
Merge 5f0250f7f into 5689aedfe
Pull Request #1470: Fix the type handling for tasks to support string types

1339 of 1668 branches covered (80.28%)

Branch coverage included in aggregate %.

8449 of 10564 relevant lines covered (79.98%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.7
src/unitxt/task.py
1
import warnings
1✔
2
from functools import lru_cache
1✔
3
from typing import Any, Dict, List, Optional, Union
1✔
4

5
from .deprecation_utils import deprecation
1✔
6
from .error_utils import Documentation, UnitxtError, UnitxtWarning
1✔
7
from .logging_utils import get_logger
1✔
8
from .metrics import MetricsList
1✔
9
from .operator import InstanceOperator
1✔
10
from .operators import ArtifactFetcherMixin
1✔
11
from .settings_utils import get_constants
1✔
12
from .templates import Template
1✔
13
from .type_utils import (
1✔
14
    Type,
15
    get_args,
16
    get_origin,
17
    is_type_dict,
18
    isoftype,
19
    parse_type_dict,
20
    parse_type_string,
21
    to_type_dict,
22
    to_type_string,
23
    verify_required_schema,
24
)
25

26
# Project-wide constants; `constants.inference_stream` is compared against the
# stream name in Task.process to decide whether reference fields are emitted.
constants = get_constants()
# Module-level logger (not referenced in the rest of this file).
logger = get_logger()
28

29

30
@deprecation(
    version="2.0.0",
    msg="use python types instead of type strings (e.g. Dict[str, str] instead of 'Dict[str, str]')",
)
def parse_string_types_instead_of_actual_objects(obj):
    """Parse deprecated string type annotations into python type objects.

    This is wrapped by ``deprecation`` (version 2.0.0), so every call is flagged
    as deprecated usage; callers should pass real python types instead.

    Args:
        obj: Either a dict mapping field names to type strings
            (e.g. ``{"text": "str"}``), or a single type string (e.g. ``"str"``).

    Returns:
        ``parse_type_dict(obj)`` when ``obj`` is a dict, otherwise
        ``parse_type_string(obj)``.
    """
    if isinstance(obj, dict):
        return parse_type_dict(obj)
    return parse_type_string(obj)
38

39

40
class Task(InstanceOperator, ArtifactFetcherMixin):
    """Task packs the different instance fields into dictionaries by their roles in the task.

    Args:
        input_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary with string names of instance input fields and types of respective values.
            Types should be python types (e.g. ``{"text": str}``); type strings
            (e.g. ``{"text": "str"}``) are still accepted but deprecated.
            In case a list is passed, each type will be assumed to be Any.
        reference_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary with string names of instance output fields and types of respective values.
            Same conventions as 'input_fields'.
        inputs, outputs:
            Deprecated aliases of 'input_fields' and 'reference_fields' respectively.
            Each alias cannot be set together with the field it replaces.
        metrics (List[str]):
            List of names of metrics to be used in the task.
        prediction_type (Optional[Union[Type, str]]):
            Need to be consistent with all used metrics. Defaults to None, which means that it will
            be set to Any. Type strings are accepted but deprecated.
        augmentable_inputs (List[str]):
            Names of input fields marked as augmentable; each must be a key of 'input_fields'.
        defaults (Optional[Dict[str, Any]]):
            An optional dictionary with default values for chosen input/output keys. Needs to be
            consistent with names and types provided in 'input_fields' and/or 'reference_fields' arguments.
            Will not overwrite values if already provided in a given instance.
        default_template (Template):
            Optional template for the task; if set, it must be a Template instance.

    The output instance contains the fields:
        1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
        2. "reference_fields" -- for the fields listed in Arg "reference_fields" (omitted for the inference stream).
        3. "metrics" -- to contain the value of Arg 'metrics'.
        4. "data_classification_policy", "media" and "recipe_metadata" -- copied from the
           input instance (empty defaults when absent), plus "demos" when present.
    """

    input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    inputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    outputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    metrics: List[str]
    prediction_type: Optional[Union[Type, str]] = None
    augmentable_inputs: List[str] = []
    defaults: Optional[Dict[str, Any]] = None
    default_template: Template = None

    def prepare_args(self):
        """Resolve deprecated aliases and parse deprecated string type annotations.

        Raises:
            UnitxtError: if a field and its deprecated alias are both set, or if
                'default_template' is set to something other than a Template.
        """
        super().prepare_args()

        if self.input_fields is not None and self.inputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is not None and self.outputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'outputs'. Use only 'reference_fields'",
                Documentation.ADDING_TASK,
            )

        if self.default_template is not None and not isoftype(
            self.default_template, Template
        ):
            raise UnitxtError(
                f"The task's 'default_template' attribute is not of type Template. The 'default_template' attribute is of type {type(self.default_template)}: {self.default_template}",
                Documentation.ADDING_TASK,
            )

        # Fall back to the deprecated aliases when the new names are unset.
        self.input_fields = (
            self.input_fields if self.input_fields is not None else self.inputs
        )
        self.reference_fields = (
            self.reference_fields if self.reference_fields is not None else self.outputs
        )

        # Only non-empty dicts of type strings go through the deprecated parser:
        # an empty dict can match Dict[str, str] vacuously but contains no type
        # strings, so it must not be reported as deprecated usage.
        if self.input_fields and isoftype(self.input_fields, Dict[str, str]):
            self.input_fields = parse_string_types_instead_of_actual_objects(
                self.input_fields
            )
        if self.reference_fields and isoftype(self.reference_fields, Dict[str, str]):
            self.reference_fields = parse_string_types_instead_of_actual_objects(
                self.reference_fields
            )

        if isinstance(self.prediction_type, str):
            self.prediction_type = parse_string_types_instead_of_actual_objects(
                self.prediction_type
            )

    def task_deprecations(self):
        """Emit a DeprecationWarning for each deprecated alias ('inputs'/'outputs') that is set."""
        if hasattr(self, "inputs") and self.inputs is not None:
            depr_message = (
                "The 'inputs' field is deprecated. Please use 'input_fields' instead."
            )
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

        if hasattr(self, "outputs") and self.outputs is not None:
            depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify(self):
        """Validate the task configuration and normalize its field type dicts.

        Side effects:
            - 'input_fields' / 'reference_fields' given as a list are replaced by
              ``{name: Any}`` dicts; dicts that are not type dicts are re-parsed.
            - An unset 'prediction_type' is replaced by ``Any``.

        Raises:
            UnitxtError: if 'input_fields' or 'reference_fields' is missing, or if a
                metric's prediction type is incompatible (see check_metrics_type).
        """
        self.task_deprecations()

        if self.input_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'input_fields' not set.",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'reference_fields' not set.",
                Documentation.ADDING_TASK,
            )
        for io_type in ["input_fields", "reference_fields"]:
            data = (
                self.input_fields
                if io_type == "input_fields"
                else self.reference_fields
            )

            if isinstance(data, list) or not is_type_dict(data):
                UnitxtWarning(
                    f"'{io_type}' field of Task should be a dictionary of field names and their types. "
                    f"For example, {{'text': str, 'classes': List[str]}}. Instead only '{data}' was "
                    f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
                    f"will raise an exception.",
                    Documentation.ADDING_TASK,
                )
                if isinstance(data, dict):
                    # Round-trip through the string form to normalize the dict's types.
                    data = parse_type_dict(to_type_dict(data))
                else:
                    data = {key: Any for key in data}

                if io_type == "input_fields":
                    self.input_fields = data
                else:
                    self.reference_fields = data

        if not self.prediction_type:
            UnitxtWarning(
                "'prediction_type' was not set in Task. It is used to check the output of "
                "template post processors is compatible with the expected input of the metrics. "
                "Setting `prediction_type` to 'Any' (no checking is done). In future version "
                "of unitxt this will raise an exception.",
                Documentation.ADDING_TASK,
            )
            self.prediction_type = Any

        self.check_metrics_type()

        for augmentable_input in self.augmentable_inputs:
            assert (
                augmentable_input in self.input_fields
            ), f"augmentable_input {augmentable_input} is not part of {self.input_fields}"

        self.verify_defaults()

    @classmethod
    def process_data_after_load(cls, data):
        """Convert serialized type strings in loaded artifact data back into python types.

        Dict-valued field specs are parsed with ``parse_type_dict`` and
        'prediction_type' with ``parse_type_string``; list specs are left as-is.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                data[dict_name] = parse_type_dict(data[dict_name])
        if "prediction_type" in data:
            data["prediction_type"] = parse_type_string(data["prediction_type"])
        return data

    def process_data_before_dump(self, data):
        """Convert python types in artifact data to type strings for serialization.

        Values that are already strings (or dicts of strings) are left untouched.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                if not isoftype(data[dict_name], Dict[str, str]):
                    data[dict_name] = to_type_dict(data[dict_name])
        if "prediction_type" in data:
            if not isinstance(data["prediction_type"], str):
                data["prediction_type"] = to_type_string(data["prediction_type"])
        return data

    @classmethod
    @lru_cache(maxsize=None)
    def get_metrics_artifacts(cls, metric_id: str):
        """Return the list of metric artifacts referenced by ``metric_id``.

        A MetricsList artifact is expanded to its items; any other artifact is
        wrapped in a single-element list. Results are cached per (cls, metric_id)
        for the life of the process.
        """
        metric = cls.get_artifact(metric_id)
        if isinstance(metric, MetricsList):
            return metric.items
        return [metric]

    def check_metrics_type(self) -> None:
        """Verify that every metric accepts the task's prediction type.

        The types are considered compatible when they are equal, when either one
        is ``Any``, or when the metric's type is a ``Union`` containing the task's type.

        Raises:
            UnitxtError: on the first incompatible metric.
        """
        prediction_type = self.prediction_type
        for metric_id in self.metrics:
            metric_artifacts_list = Task.get_metrics_artifacts(metric_id)
            for metric_artifact in metric_artifacts_list:
                metric_prediction_type = metric_artifact.prediction_type
                if (
                    prediction_type == metric_prediction_type
                    or prediction_type == Any
                    or metric_prediction_type == Any
                    or (
                        get_origin(metric_prediction_type) is Union
                        and prediction_type in get_args(metric_prediction_type)
                    )
                ):
                    continue

                raise UnitxtError(
                    f"The task's prediction type ({prediction_type}) and '{metric_id}' "
                    f"metric's prediction type ({metric_prediction_type}) are different.",
                    Documentation.ADDING_TASK,
                )

    def verify_defaults(self):
        """Check that 'defaults' is a dict whose keys are declared fields and whose values match the field types.

        Raises:
            UnitxtError: if 'defaults' is set but is not a dict.
            AssertionError: if a key is not a string, is not a declared field,
                or its value does not match the declared type.
        """
        if self.defaults:
            if not isinstance(self.defaults, dict):
                raise UnitxtError(
                    f"If specified, the 'defaults' must be a dictionary, "
                    f"however, '{self.defaults}' was provided instead, "
                    f"which is of type '{to_type_string(type(self.defaults))}'.",
                    Documentation.ADDING_TASK,
                )

            for default_name, default_value in self.defaults.items():
                assert isinstance(default_name, str), (
                    f"If specified, all keys of the 'defaults' must be strings, "
                    f"however, the key '{default_name}' is of type '{to_type_string(type(default_name))}'."
                )

                # Input fields take precedence when a name appears in both dicts.
                val_type = self.input_fields.get(
                    default_name
                ) or self.reference_fields.get(default_name)

                assert val_type, (
                    f"If specified, all keys of the 'defaults' must refer to a chosen "
                    f"key in either 'input_fields' or 'reference_fields'. However, the name '{default_name}' "
                    f"was provided which does not match any of the keys."
                )

                assert isoftype(default_value, val_type), (
                    f"The value of '{default_name}' from the 'defaults' must be of "
                    f"type '{to_type_string(val_type)}', however, it is of type '{to_type_string(type(default_value))}'."
                )

    def set_default_values(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``instance`` with missing keys filled from 'defaults'.

        Values already present in ``instance`` take precedence over defaults.
        """
        if self.defaults:
            instance = {**self.defaults, **instance}
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Pack an instance into the task's output format.

        Args:
            instance: The raw instance; must contain every declared input field
                (and every reference field unless ``stream_name`` is the inference stream).
            stream_name: Name of the stream being processed.

        Returns:
            A new dict with "input_fields", "metrics", "data_classification_policy",
            "media", "recipe_metadata", optionally "demos", and — except on the
            inference stream — "reference_fields".
        """
        instance = self.set_default_values(instance)

        verify_required_schema(
            self.input_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        input_fields = {key: instance[key] for key in self.input_fields.keys()}
        data_classification_policy = instance.get("data_classification_policy", [])

        result = {
            "input_fields": input_fields,
            "metrics": self.metrics,
            "data_classification_policy": data_classification_policy,
            "media": instance.get("media", {}),
            "recipe_metadata": instance.get("recipe_metadata", {}),
        }
        if "demos" in instance:
            # for the case of recipe.skip_demoed_instances
            result["demos"] = instance["demos"]

        # Reference fields are neither required nor emitted at inference time.
        if stream_name == constants.inference_stream:
            return result

        verify_required_schema(
            self.reference_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        result["reference_fields"] = {
            key: instance[key] for key in self.reference_fields.keys()
        }

        return result
316

317

318
# Deprecated alias of Task, kept for backward compatibility. The `deprecation`
# decorator marks it as deprecated as of version 2.0.0 and names `Task` as the
# alternative to use instead.
@deprecation(version="2.0.0", alternative=Task)
class FormTask(Task):
    pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc