• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12635975303

06 Jan 2025 04:10PM UTC coverage: 80.075% (+0.02%) from 80.057%
12635975303

Pull #1473

github

web-flow
Merge cf8796abc into b32bb80fa
Pull Request #1473: Adding typed recipe test

1347 of 1676 branches covered (80.37%)

Branch coverage included in aggregate %.

8471 of 10585 relevant lines covered (80.03%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.78
src/unitxt/task.py
1
import warnings
1✔
2
from functools import lru_cache
1✔
3
from typing import Any, Dict, List, Optional, Union
1✔
4

5
from .deprecation_utils import deprecation
1✔
6
from .error_utils import Documentation, UnitxtError, UnitxtWarning
1✔
7
from .logging_utils import get_logger
1✔
8
from .metrics import MetricsList
1✔
9
from .operator import InstanceOperator
1✔
10
from .operators import ArtifactFetcherMixin
1✔
11
from .settings_utils import get_constants
1✔
12
from .templates import Template
1✔
13
from .type_utils import (
1✔
14
    Type,
15
    get_args,
16
    get_origin,
17
    is_type_dict,
18
    isoftype,
19
    parse_type_dict,
20
    parse_type_string,
21
    to_type_dict,
22
    to_type_string,
23
    verify_required_schema,
24
)
25

26
# Shared constants object; Task.process compares stream names against constants.inference_stream.
constants = get_constants()
1✔
27
# Module-level logger obtained from the package's logging utilities.
logger = get_logger()
1✔
28

29

30
@deprecation(
    version="2.0.0",
    msg="use python type instead of type strings (e.g Dict[str] instead of 'Dict[str]')",
)
def parse_string_types_instead_of_actual_objects(obj):
    """Parse a legacy, string-based type specification.

    Kept only for backward compatibility (see the deprecation notice above).

    Args:
        obj: Either a dictionary, which is handed to ``parse_type_dict``, or any
            other value (expected to be a type string), which is handed to
            ``parse_type_string``.

    Returns:
        Whatever the selected parser returns for ``obj``.
    """
    parser = parse_type_dict if isinstance(obj, dict) else parse_type_string
    return parser(obj)
1✔
38

39

40
class Task(InstanceOperator, ArtifactFetcherMixin):
    """Task packs the different instance fields into dictionaries by their roles in the task.

    Args:
        input_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary with string names of instance input fields and types of respective values.
            Types given as strings are still accepted but are deprecated.
            In case a list is passed, each type will be assumed to be Any.
        reference_fields (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Dictionary with string names of instance output fields and types of respective values.
            Types given as strings are still accepted but are deprecated.
            In case a list is passed, each type will be assumed to be Any.
        inputs (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Deprecated alias of 'input_fields'. Cannot be set together with 'input_fields'.
        outputs (Union[Dict[str, Type], Dict[str, str], List[str]]):
            Deprecated alias of 'reference_fields'. Cannot be set together with 'reference_fields'.
        metrics (List[str]):
            List of names of metrics to be used in the task.
        prediction_type (Optional[Union[Type, str]]):
            Need to be consistent with all used metrics. Defaults to None, which means that it will
            be set to Any.
        augmentable_inputs (List[str]):
            Names of input fields; each of them must be a key of 'input_fields'.
        defaults (Optional[Dict[str, Any]]):
            An optional dictionary with default values for chosen input/output keys. Needs to be
            consistent with names and types provided in 'input_fields' and/or 'reference_fields' arguments.
            Will not overwrite values if already provided in a given instance.
        default_template (Template):
            Optional. When set, it must be of type Template.

    The output instance contains the following fields:
        1. "input_fields" whose value is a sub-dictionary of the input instance, consisting of all the fields listed in Arg 'input_fields'.
        2. "reference_fields" -- for the fields listed in Arg "reference_fields" (omitted for the inference stream).
        3. "metrics" -- to contain the value of Arg 'metrics'
        4. "data_classification_policy", "media" and "recipe_metadata" -- copied from the input instance,
           with an empty list / dict when absent.
        5. "demos" -- copied from the input instance, only when it is present there.
    """

    input_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    reference_fields: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    # Deprecated aliases of 'input_fields' / 'reference_fields' (see task_deprecations).
    inputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    outputs: Optional[Union[Dict[str, Type], Dict[str, str], List[str]]] = None
    metrics: List[str]
    prediction_type: Optional[Union[Type, str]] = None
    augmentable_inputs: List[str] = []
    defaults: Optional[Dict[str, Any]] = None
    default_template: Template = None

    def prepare_args(self):
        """Normalize the field specifications before verification.

        Resolves the deprecated 'inputs' / 'outputs' aliases into 'input_fields' /
        'reference_fields', and converts string-based type specifications through
        ``parse_string_types_instead_of_actual_objects``.

        Raises:
            UnitxtError: If a field and its deprecated alias are both set, or if
                'default_template' is set but is not a Template.
        """
        super().prepare_args()

        if self.input_fields is not None and self.inputs is not None:
            raise UnitxtError(
                "Conflicting attributes: 'input_fields' cannot be set simultaneously with 'inputs'. Use only 'input_fields'",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is not None and self.outputs is not None:
            # The deprecated attribute is called 'outputs'; name it correctly in the message.
            raise UnitxtError(
                "Conflicting attributes: 'reference_fields' cannot be set simultaneously with 'outputs'. Use only 'reference_fields'",
                Documentation.ADDING_TASK,
            )

        if self.default_template is not None and not isoftype(
            self.default_template, Template
        ):
            raise UnitxtError(
                f"The task's 'default_template' attribute is not of type Template. The 'default_template' attribute is of type {type(self.default_template)}: {self.default_template}",
                Documentation.ADDING_TASK,
            )

        # Fall back to the deprecated aliases when the new names were not provided.
        self.input_fields = (
            self.input_fields if self.input_fields is not None else self.inputs
        )
        self.reference_fields = (
            self.reference_fields if self.reference_fields is not None else self.outputs
        )

        # String-based type specifications go through the deprecated parsing helper.
        if isoftype(self.input_fields, Dict[str, str]):
            self.input_fields = parse_string_types_instead_of_actual_objects(
                self.input_fields
            )
        if isoftype(self.reference_fields, Dict[str, str]):
            self.reference_fields = parse_string_types_instead_of_actual_objects(
                self.reference_fields
            )

        if isinstance(self.prediction_type, str):
            self.prediction_type = parse_string_types_instead_of_actual_objects(
                self.prediction_type
            )

        # Keep the deprecated aliases in sync with the (possibly parsed) values,
        # but only when the caller actually used them.
        if hasattr(self, "inputs") and self.inputs is not None:
            self.inputs = self.input_fields

        if hasattr(self, "outputs") and self.outputs is not None:
            self.outputs = self.reference_fields

    def task_deprecations(self):
        """Emit a DeprecationWarning for each deprecated alias ('inputs', 'outputs') in use."""
        if hasattr(self, "inputs") and self.inputs is not None:
            depr_message = (
                "The 'inputs' field is deprecated. Please use 'input_fields' instead."
            )
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
        if hasattr(self, "outputs") and self.outputs is not None:
            depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify(self):
        """Validate the task definition and fill in permissive fallbacks.

        Field specifications that are not proper type dictionaries are replaced by
        dictionaries (list entries are typed as Any), and a missing 'prediction_type'
        is replaced by Any; in both cases a UnitxtWarning is issued.

        Raises:
            UnitxtError: If 'input_fields' or 'reference_fields' is not set, or if a
                metric's prediction type is incompatible (see check_metrics_type).
            AssertionError: If an augmentable input is not one of the input fields,
                or if 'defaults' is inconsistent (see verify_defaults).
        """
        self.task_deprecations()

        if self.input_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'input_fields' not set.",
                Documentation.ADDING_TASK,
            )
        if self.reference_fields is None:
            raise UnitxtError(
                "Missing attribute in task: 'reference_fields' not set.",
                Documentation.ADDING_TASK,
            )
        for io_type in ["input_fields", "reference_fields"]:
            data = getattr(self, io_type)

            if isinstance(data, list) or not is_type_dict(data):
                UnitxtWarning(
                    f"'{io_type}' field of Task should be a dictionary of field names and their types. "
                    f"For example, {{'text': str, 'classes': List[str]}}. Instead only '{data}' was "
                    f"passed. All types will be assumed to be 'Any'. In future version of unitxt this "
                    f"will raise an exception.",
                    Documentation.ADDING_TASK,
                )
                if isinstance(data, dict):
                    data = parse_type_dict(to_type_dict(data))
                else:
                    data = {key: Any for key in data}

                setattr(self, io_type, data)

        if not self.prediction_type:
            UnitxtWarning(
                "'prediction_type' was not set in Task. It is used to check the output of "
                "template post processors is compatible with the expected input of the metrics. "
                "Setting `prediction_type` to 'Any' (no checking is done). In future version "
                "of unitxt this will raise an exception.",
                Documentation.ADDING_TASK,
            )
            self.prediction_type = Any

        self.check_metrics_type()

        # NOTE: assert-based validation is kept so callers that expect AssertionError
        # keep working; be aware that asserts are skipped under ``python -O``.
        for augmentable_input in self.augmentable_inputs:
            assert (
                augmentable_input in self.input_fields
            ), f"augmentable_input {augmentable_input} is not part of {self.input_fields}"

        self.verify_defaults()

    @classmethod
    def process_data_after_load(cls, data):
        """Post-process a loaded data dictionary in place and return it.

        Every dict-valued entry among 'inputs', 'input_fields', 'outputs' and
        'reference_fields' is passed through ``parse_type_dict``; 'prediction_type',
        when present, is passed through ``parse_type_string``.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                data[dict_name] = parse_type_dict(data[dict_name])
        if "prediction_type" in data:
            data["prediction_type"] = parse_type_string(data["prediction_type"])
        return data

    def process_data_before_dump(self, data):
        """Pre-process a data dictionary in place before dumping and return it.

        Every dict-valued entry among 'inputs', 'input_fields', 'outputs' and
        'reference_fields' that is not already a ``Dict[str, str]`` is passed through
        ``to_type_dict``; a non-string 'prediction_type' is passed through
        ``to_type_string``.
        """
        possible_dicts = ["inputs", "input_fields", "outputs", "reference_fields"]
        for dict_name in possible_dicts:
            if dict_name in data and isinstance(data[dict_name], dict):
                if not isoftype(data[dict_name], Dict[str, str]):
                    data[dict_name] = to_type_dict(data[dict_name])
        if "prediction_type" in data:
            if not isinstance(data["prediction_type"], str):
                data["prediction_type"] = to_type_string(data["prediction_type"])
        return data

    @classmethod
    @lru_cache(maxsize=None)
    def get_metrics_artifacts(cls, metric_id: str):
        """Fetch the metric artifact(s) for ``metric_id``; results are memoized.

        Returns:
            The ``items`` of the fetched artifact when it is a MetricsList, otherwise
            a single-element list holding the fetched artifact.
        """
        metric = cls.get_artifact(metric_id)
        if isinstance(metric, MetricsList):
            return metric.items
        return [metric]

    def check_metrics_type(self) -> None:
        """Check that every metric accepts the task's prediction type.

        A metric is compatible when its prediction type equals the task's, when either
        of the two is Any, or when the metric's prediction type is a Union that lists
        the task's prediction type among its arguments.

        Raises:
            UnitxtError: On the first incompatible metric.
        """
        prediction_type = self.prediction_type
        for metric_id in self.metrics:
            metric_artifacts_list = Task.get_metrics_artifacts(metric_id)
            for metric_artifact in metric_artifacts_list:
                metric_prediction_type = metric_artifact.prediction_type
                if (
                    prediction_type == metric_prediction_type
                    or prediction_type == Any
                    or metric_prediction_type == Any
                    or (
                        get_origin(metric_prediction_type) is Union
                        and prediction_type in get_args(metric_prediction_type)
                    )
                ):
                    continue

                raise UnitxtError(
                    f"The task's prediction type ({prediction_type}) and '{metric_id}' "
                    f"metric's prediction type ({metric_prediction_type}) are different.",
                    Documentation.ADDING_TASK,
                )

    def verify_defaults(self):
        """Check that 'defaults', when given, matches the declared fields.

        Raises:
            UnitxtError: If 'defaults' is truthy but not a dictionary.
            AssertionError: If a key is not a string, does not name a field of
                'input_fields' or 'reference_fields', or its value is not of the
                field's declared type.
        """
        if self.defaults:
            if not isinstance(self.defaults, dict):
                raise UnitxtError(
                    f"If specified, the 'defaults' must be a dictionary, "
                    f"however, '{self.defaults}' was provided instead, "
                    f"which is of type '{to_type_string(type(self.defaults))}'.",
                    Documentation.ADDING_TASK,
                )

            for default_name, default_value in self.defaults.items():
                assert isinstance(default_name, str), (
                    f"If specified, all keys of the 'defaults' must be strings, "
                    f"however, the key '{default_name}' is of type '{to_type_string(type(default_name))}'."
                )

                # Input fields take precedence when a name appears in both dictionaries.
                val_type = self.input_fields.get(
                    default_name
                ) or self.reference_fields.get(default_name)

                assert val_type, (
                    f"If specified, all keys of the 'defaults' must refer to a chosen "
                    f"key in either 'input_fields' or 'reference_fields'. However, the name '{default_name}' "
                    f"was provided which does not match any of the keys."
                )

                assert isoftype(default_value, val_type), (
                    f"The value of '{default_name}' from the 'defaults' must be of "
                    f"type '{to_type_string(val_type)}', however, it is of type '{to_type_string(type(default_value))}'."
                )

    def set_default_values(self, instance: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``instance`` with missing keys filled from 'defaults'.

        Values already present in ``instance`` win over the defaults. When 'defaults'
        is empty or None the very same ``instance`` object is returned; otherwise a
        new dictionary is built.
        """
        if self.defaults:
            instance = {**self.defaults, **instance}
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the task-structured instance from a raw instance.

        Args:
            instance: The raw instance; defaults are applied to it first.
            stream_name: Name of the stream being processed. For the inference
                stream (``constants.inference_stream``) the reference fields are
                neither verified nor included in the result.

        Returns:
            A new dictionary with the keys described in the class docstring.
        """
        instance = self.set_default_values(instance)

        verify_required_schema(
            self.input_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        input_fields = {key: instance[key] for key in self.input_fields}
        data_classification_policy = instance.get("data_classification_policy", [])

        result = {
            "input_fields": input_fields,
            "metrics": self.metrics,
            "data_classification_policy": data_classification_policy,
            "media": instance.get("media", {}),
            "recipe_metadata": instance.get("recipe_metadata", {}),
        }
        if "demos" in instance:
            # for the case of recipe.skip_demoed_instances
            result["demos"] = instance["demos"]

        if stream_name == constants.inference_stream:
            return result

        verify_required_schema(
            self.reference_fields,
            instance,
            class_name="Task",
            id=self.__id__,
            description=self.__description__,
        )
        result["reference_fields"] = {
            key: instance[key] for key in self.reference_fields
        }

        return result
1✔
321

322

323
@deprecation(version="2.0.0", alternative=Task)
class FormTask(Task):
    """Deprecated name for :class:`Task`; it adds no behavior of its own."""
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc