• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12535194121

29 Dec 2024 12:03PM UTC coverage: 80.228% (+0.2%) from 80.023%
12535194121

Pull #1459

github

web-flow
Merge 7067995c0 into def3e0ea1
Pull Request #1459: Add MapReduceMetric a new base class to integrate all metrics into

1365 of 1695 branches covered (80.53%)

Branch coverage included in aggregate %.

8629 of 10762 relevant lines covered (80.18%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.19
src/unitxt/processors.py
1
import ast
1✔
2
import copy
1✔
3
import json
1✔
4
import re
1✔
5
import string
1✔
6
from difflib import get_close_matches
1✔
7
from typing import Any, Dict
1✔
8

9
import numpy as np
1✔
10

11
from .deprecation_utils import deprecation
1✔
12
from .error_utils import Documentation, UnitxtError
1✔
13
from .operator import MultiStreamOperator
1✔
14
from .operators import FieldOperator, InstanceFieldOperator
1✔
15
from .settings_utils import get_constants
1✔
16
from .type_utils import isoftype
1✔
17

18
# Settings constants shared by this module; PostProcess reads
# `constants.inference_stream` from it.
constants = get_constants()
1✔
19

20

21
class PostProcess(MultiStreamOperator):
    """Apply one InstanceFieldOperator to predictions and/or references.

    ``prepare`` builds two shallow copies of ``operator``: one targeting the
    ``prediction`` field and one targeting the ``references`` field (with
    ``process_every_value`` enabled and ``dont_apply_to_streams`` set to the
    inference stream). ``process`` runs whichever copies are enabled.
    """

    operator: InstanceFieldOperator
    process_prediction: bool = True
    process_references: bool = True

    def prepare(self):
        super().prepare()
        if not isoftype(self.operator, InstanceFieldOperator):
            raise UnitxtError(
                f"PostProcess requires operator field to be of type InstanceFieldOperator. Got object of type <{type(self.operator).__name__}>.",
                Documentation.POST_PROCESSORS,
            )

        # Shallow copies, so each variant can point at its own field without
        # mutating the operator the user passed in.
        prediction_op = copy.copy(self.operator)
        prediction_op.field = "prediction"
        self.prediction_operator = prediction_op

        references_op = copy.copy(self.operator)
        references_op.field = "references"
        # NOTE(review): presumably makes the operator act on each element of
        # the references list -- confirm in InstanceFieldOperator.
        references_op.process_every_value = True
        references_op.dont_apply_to_streams = [constants.inference_stream]
        self.references_operator = references_op

    def process(self, multi_stream):
        # Predictions first, then references -- same order as the flags.
        for enabled, operator in (
            (self.process_prediction, self.prediction_operator),
            (self.process_references, self.references_operator),
        ):
            if enabled:
                multi_stream = operator(multi_stream)
        return multi_stream
1✔
46

47

48
class ToString(FieldOperator):
    """Convert the field value to its ``str`` representation."""

    def process_value(self, text: Any) -> Any:
        as_string = str(text)
        return as_string
1✔
51

52

53
class ToStringStripped(FieldOperator):
    """Convert the field value to ``str`` and trim surrounding whitespace."""

    def process_value(self, text: Any) -> Any:
        as_string = str(text)
        return as_string.strip()
1✔
56

57

58
class SplitStrip(FieldOperator):
    """Split a string on ``delimiter`` into a list.

    When ``strip_every_element`` is true, each piece is whitespace-stripped.
    """

    delimiter: str = " "
    strip_every_element: bool = False

    def process_value(self, text: Any) -> Any:
        pieces = text.split(self.delimiter)
        if not self.strip_every_element:
            return pieces
        return [piece.strip() for piece in pieces]
67

68

69
class ToListByComma(SplitStrip):
    """Split on ``","`` and strip whitespace from every resulting element."""

    delimiter = ","
    strip_every_element = True
1✔
72

73

74
class ToListByCommaSpace(SplitStrip):
    """Split on ``", "`` and strip whitespace from every resulting element."""

    delimiter = ", "
    strip_every_element = True
1✔
77

78

79
class RegexParser(FieldOperator):
    """A processor that uses regex in order to parse a string.

    Returns ``re.findall(regex, text)``, or ``[]`` when ``termination_regex``
    is set and fully matches the text.
    """

    regex: str
    termination_regex: str = None

    def process_value(self, text: Any) -> Any:
        terminated = self.termination_regex is not None and re.fullmatch(
            self.termination_regex, text
        )
        if terminated:
            return []
        return re.findall(self.regex, text)
1✔
91

92

93
class ExtractWithRegex(RegexParser):
    """Return only the first ``RegexParser`` match, or ``""`` if there is none."""

    def process_value(self, text: Any) -> Any:
        found = super().process_value(text)
        return found[0] if found else ""
×
99

100

101
class ListToEmptyEntitiesTuples(FieldOperator):
    """Map each item of an iterable to a ``(str(item), "")`` tuple.

    Returns ``[]`` when the value is not iterable (e.g. ``None``).
    """

    def process_value(self, lst: Any) -> Any:
        try:
            return [(str(item), "") for item in lst]
        # Iterating a non-iterable raises TypeError. The previous handler
        # caught json.JSONDecodeError, which this code can never raise.
        except TypeError:
            return []
×
107

108

109
class DictOfListsToPairs(FieldOperator):
    """Flatten ``{key: [value, ...]}`` into a list of ``(key, value)`` pairs.

    With ``position_key_before_value=False`` the pairs are ``(value, key)``.
    Malformed input (no ``.items()``, non-iterable values, or any value that
    is not a ``str``) yields ``[]``.
    """

    position_key_before_value: bool = True

    def process_value(self, obj: Any) -> Any:
        try:
            result = []
            for key, values in obj.items():
                for value in values:
                    # Explicit check instead of `assert`, which `python -O` strips.
                    if not isinstance(value, str):
                        return []
                    result.append(
                        (key, value) if self.position_key_before_value else (value, key)
                    )
            return result
        # Best-effort parse: any ordinary error means "no pairs", but do not
        # swallow KeyboardInterrupt/SystemExit like the bare `except:` did.
        except Exception:
            return []
1✔
125

126

127
class TakeFirstNonEmptyLine(FieldOperator):
    """Return the first non-empty line of ``str(text)``, whitespace-stripped.

    Returns ``""`` if the text is empty or all whitespace.
    """

    def process_value(self, text: Any) -> Any:
        # Stripping the whole text first guarantees its first line is non-empty
        # (unless the text is blank). str.split/partition always yield at least
        # one part, so the old `len(parts) == 0` branch was unreachable.
        first_line, _, _ = str(text).strip().partition("\n")
        return first_line.strip()
1✔
133

134

135
class TakeLastNonEmptyLine(FieldOperator):
    """Return the last non-empty line of ``str(text)``, whitespace-stripped.

    Returns ``""`` if the text is empty or all whitespace.
    """

    def process_value(self, text: Any) -> Any:
        # Stripping the whole text first guarantees its last line is non-empty
        # (unless the text is blank). rpartition returns the whole string as the
        # tail when there is no newline, so no separate "no parts" branch is needed.
        _, _, last_line = str(text).strip().rpartition("\n")
        return last_line.strip()
×
141

142

143
class ConvertToBoolean(FieldOperator):
    """Map free text to ``"FALSE"``, ``"TRUE"`` or ``"OTHER"``.

    Negative markers are checked before positive ones.
    """

    def process_value(self, text: Any) -> Any:
        normalized = str(text).strip().lower()
        # NOTE(review): plain substring tests, so e.g. "know" contains "no" and
        # maps to "FALSE". Kept as-is to preserve existing scores.
        negative_markers = ["no", "not", "wrong", "false"]
        positive_markers = ["yes", "right", "correct", "true"]
        if any(marker in normalized for marker in negative_markers):
            return "FALSE"
        if any(marker in normalized for marker in positive_markers):
            return "TRUE"
        return "OTHER"
1✔
151

152

153
class LowerCaseTillPunc(FieldOperator):
    """Lower-case the text and cut it at the first ``. , ! ? ;`` character."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        first_punc = re.search(r"[.,!?;]", lowered)
        # Keep only the text before the first punctuation mark, if any.
        return lowered[: first_punc.start()] if first_punc else lowered
1✔
161

162

163
class Lower(FieldOperator):
    """Lower-case the field value by calling its ``.lower()`` method."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        return lowered
1✔
166

167

168
class Upper(FieldOperator):
    """Upper-case the ``str`` form of the field value."""

    def process_value(self, text: Any) -> Any:
        as_string = str(text)
        return as_string.upper()
×
171

172

173
@deprecation("2.0.0", alternative=Lower)
class LowerCase(Lower):
    """Deprecated alias of ``Lower``; use ``Lower`` instead."""
1✔
176

177

178
class Capitalize(FieldOperator):
    """Apply ``str.capitalize`` semantics via the value's ``.capitalize()``."""

    def process_value(self, text: Any) -> Any:
        capitalized = text.capitalize()
        return capitalized
1✔
181

182

183
class GetStringAfter(FieldOperator):
    """Return the stripped text after the first ``substring`` occurrence.

    If ``substring`` does not occur, the whole text is returned, stripped.
    """

    substring: str

    def process_value(self, text: Any) -> Any:
        # maxsplit=1: the last element is everything after the first occurrence,
        # or the whole text when there is no occurrence.
        *_, remainder = text.split(self.substring, 1)
        return remainder.strip()
×
188

189

190
class MatchClosestOption(InstanceFieldOperator):
    """Replace the value with the most similar entry of the instance's options.

    Options are read from ``instance["task_data"][options_field]`` and compared
    with ``difflib.get_close_matches`` (``cutoff=0.0``). An empty options list
    raises ``IndexError``.
    """

    options_field: str = "options"

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        task_data = instance["task_data"]
        candidates = task_data[self.options_field]
        closest = get_close_matches(value, candidates, n=1, cutoff=0.0)
        return closest[0]
1✔
196

197

198
def process_instance_value(self, value, instance):
    """Return the option closest to ``value``, or ``None`` if there are none.

    NOTE(review): this is a module-level function taking ``self``. It reads
    ``instance[self.options_field]`` directly, not via ``task_data``, and no
    code in this file calls it. It looks like a leftover -- confirm before
    relying on it.
    """
    candidates = instance[self.options_field]
    # n=1 keeps only the single best match; cutoff=0 accepts any similarity.
    best = get_close_matches(value, candidates, n=1, cutoff=0)
    if not best:
        return None
    return best[0]
×
203

204

205
class Substring(FieldOperator):
    """Slice the value as ``text[begin:end]``; ``end=None`` slices to the end."""

    begin: int = 0
    end: int = None

    def process_value(self, text: Any) -> Any:
        # A None stop bound already means "to the end", so one slice covers both cases.
        return text[self.begin : self.end]
1✔
213

214

215
class FirstCharacter(FieldOperator):
    """Return the first word character (``\\w``) in the text, or ``""``."""

    def process_value(self, text: Any) -> Any:
        found = re.search(r"\s*(\w)", text)
        return found.group(1) if found else ""
×
221

222

223
class TakeFirstWord(FieldOperator):
    """Return the first number (optional leading dashes/decimals) or word, else ``""``."""

    def process_value(self, text: Any) -> Any:
        found = re.search(r"([-]*[0-9]+(\.([0-9]+))*)|([\w]+)", text)
        if not found:
            return ""
        return found.group(0)
1✔
229

230

231
class YesNoToInt(FieldOperator):
    """Map ``"yes"`` to ``"1"`` and ``"no"`` to ``"0"``; other values pass through."""

    def process_value(self, text: Any) -> Any:
        if text == "yes":
            return "1"
        return "0" if text == "no" else text
1✔
238

239

240
class YesToOneElseZero(FieldOperator):
    """Map ``"yes"`` to ``"1"`` and anything else to ``"0"``."""

    def process_value(self, text: Any) -> Any:
        return "1" if text == "yes" else "0"
1✔
245

246

247
class StrToFloatFormat(FieldOperator):
    """Normalize numeric text to ``str(float(text))``; otherwise return ``str(text)``."""

    def process_value(self, text: Any) -> Any:
        try:
            as_float = float(text)
        except Exception:
            return str(text)
        return str(as_float)
1✔
253

254

255
class ToYesOrNone(FieldOperator):
    """Keep ``"yes"``; map every other value to ``"none"``."""

    def process_value(self, text: Any) -> Any:
        return "yes" if text == "yes" else "none"
1✔
260

261

262
class StanceToProCon(FieldOperator):
    """Map ``positive`` to ``PRO``, ``negative``/``suggestion`` to ``CON``, else ``none``."""

    def process_value(self, text: Any) -> Any:
        if text == "positive":
            return "PRO"
        return "CON" if text in ("negative", "suggestion") else "none"
1✔
269

270

271
class StringEquals(FieldOperator):
    """Collapse text mentioning ``string`` to a canonical lower-case form.

    Returns ``"not <string>"`` if that phrase occurs (case-insensitively),
    else ``"<string>"`` if the string occurs, else the text unchanged.
    """

    string: str

    def process_value(self, text: Any) -> Any:
        target = self.string.lower()
        haystack = text.lower()
        negated = "not " + target
        # The negated phrase is checked first because it also contains `target`.
        if negated in haystack:
            return negated
        if target in haystack:
            return target
        return text
×
280

281

282
@deprecation("2.0.0", alternative=StringEquals)
class StringOrNotString(StringEquals):
    """Deprecated alias of ``StringEquals``; use ``StringEquals`` instead."""
1✔
285

286

287
class ExtractMtBenchRatingJudgment(FieldOperator):
    """Parse a ``[[N]]`` rating from judge output and scale it by 1/10.

    Returns ``0.0`` when no ``[[number]]`` rating is present.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
        # Explicit no-match handling, replacing a bare `except:` that relied on
        # AttributeError from `None.group`.
        if match is None:
            return 0.0
        try:
            return float(match.group(1)) / 10
        except ValueError:
            return 0.0
1✔
294

295

296
class ExtractMtBenchLabelJudgment(FieldOperator):
    """Map a ``[[A]]``/``[[B]]``/``[[C]]`` verdict to a pairwise label.

    Returns ``"choice_a"``, ``"choice_b"`` or ``"tie"``. Returns the string
    ``"None"`` when no verdict is found or it is not one of the known letters.
    """

    options = {
        "A": "choice_a",
        "B": "choice_b",
        "C": "tie",
    }

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit no-match handling instead of a bare `except:`.
        if match is None:
            return "None"
        return self.options.get(match.group(1), "None")
1✔
309

310

311
class LiteralEval(FieldOperator):
    """Parse a string field as a Python literal with ``ast.literal_eval``.

    ``None`` and empty or whitespace-only strings are returned unchanged.
    Surrounding whitespace is stripped before parsing.

    Raises:
        ValueError: if the value is neither ``None`` nor a ``str``.
    """

    def process_value(self, text: Any) -> Any:
        if text is not None and not isinstance(text, str):
            raise ValueError(
                f"LiteralEval: field '{self.field}' is expected to be of 'str' input type, got: {type(text)}"
            )
        if text is None:
            return text
        stripped = text.strip()
        # Test emptiness on the text that is actually parsed, so whitespace-only
        # input is treated like "" instead of raising SyntaxError.
        if stripped == "":
            return text
        return ast.literal_eval(stripped)
1✔
320

321

322
class ExtractSafeUnsafeJudgment(FieldOperator):
    """Return ``1.0`` if the first line is exactly ``safe`` (case-insensitive), else ``0.0``."""

    def process_value(self, text: Any) -> Any:
        head, _, _ = str(text).strip().partition("\n")
        return 1.0 if head.lower() == "safe" else 0.0
×
328

329

330
class ExtractArenaHardNumericalJudgment(FieldOperator):
    """Convert an Arena-Hard ``[[verdict]]`` into a signed integer score.

    ``A>>B`` is 3, ``A>B`` is 1, ``B>A`` is -1 and ``B>>A`` is -3. Any other
    verdict, or a missing ``[[...]]`` marker, scores 0.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit no-match handling, replacing a bare `except:`.
        if match is None:
            return 0
        verdict_scores = {"A>B": 1, "A>>B": 3, "B>A": -1, "B>>A": -3}
        return verdict_scores.get(match.group(1), 0)
×
347

348

349
class InferDictsToBinaryLogprobs(FieldOperator):
    """Turn per-token top-logprob dicts into a positive-class probability.

    ``get_pos_neg_probs`` reads ``obj[i]["top_tokens"]``, a list of dicts with
    ``"text"`` and ``"logprob"`` keys. Up to ``num_logprobs_to_take`` token
    positions are scanned, from the start or, with ``take_logprobs_from_end``,
    from the end. The first position whose combined probability for the two
    class names exceeds ``min_probability_mass`` yields
    ``P(pos) / (P(pos) + P(neg))``. Returns 0 if no scanned position qualifies.
    """

    neg_class_name: str
    pos_class_name: str

    take_logprobs_from_end: bool = False
    num_logprobs_to_take: int = 3
    min_probability_mass = 0.0001

    def verify(self):
        super().verify()
        # Overlapping names (e.g. "yes"/"yess") would let one class match the other.
        if (
            self.neg_class_name.lower() in self.pos_class_name.lower()
            or self.pos_class_name.lower() in self.neg_class_name.lower()
        ):
            raise ValueError(
                f'Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}"'
            )

    def process_value(self, obj: Any) -> Any:
        for i in self.get_token_range(obj):
            try:
                pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
                if pos_probs or neg_probs:
                    sum_probs = sum(pos_probs) + sum(neg_probs)
                    if sum_probs > self.min_probability_mass:
                        return sum(pos_probs) / sum_probs
            except Exception:
                # Best-effort: skip malformed token entries, but no longer swallow
                # KeyboardInterrupt/SystemExit like the bare `except:` did.
                continue
        return 0

    def get_pos_neg_probs(self, pred_dict):
        """Return ``[pos_probs, neg_probs]``: exp(logprob) of top tokens matching each class."""
        token_logprobs = pred_dict["top_tokens"]

        pos_and_neg_probs = []
        for class_name in [self.pos_class_name, self.neg_class_name]:
            # We need to capture different variants of model behavior and tokenizers, for example with opening space,
            # punctuation etc. but avoid longer words that contain the class name.
            # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
            # re.escape: the class name is matched literally, so metacharacters in
            # it (".", "+", "(", ...) cannot break or broaden the pattern.
            name_regex = re.compile(
                rf"(\W|Ġ|_)*{re.escape(class_name)}(\W|Ġ|_)*", flags=re.IGNORECASE
            )
            class_probs = [
                np.exp(d["logprob"])
                for d in token_logprobs
                if name_regex.fullmatch(d["text"])
            ]
            pos_and_neg_probs.append(class_probs)
        return pos_and_neg_probs

    def get_token_range(self, obj: Any) -> range:
        """Indices of the token positions to scan: first N, or last N (-1, -2, ...)."""
        n_tokens = min(self.num_logprobs_to_take, len(obj))
        if self.take_logprobs_from_end:
            return range(-1, -(n_tokens + 1), -1)
        return range(n_tokens)
×
403

404

405
class RemoveArticles(FieldOperator):
    """Replace whole-word ``a``/``an``/``the`` (case-sensitive) with a space."""

    def process_value(self, text: Any) -> Any:
        article_pattern = re.compile(r"\b(a|an|the)\b")
        return article_pattern.sub(" ", text)
×
408

409

410
class RemovePunctuations(FieldOperator):
    """Drop every character found in ``string.punctuation``."""

    def process_value(self, text: Any) -> Any:
        punctuation = frozenset(string.punctuation)
        kept = [char for char in text if char not in punctuation]
        return "".join(kept)
×
414

415

416
class FixWhiteSpace(FieldOperator):
    """Collapse every whitespace run to one space and trim the ends."""

    def process_value(self, text: Any) -> Any:
        tokens = text.split()
        return " ".join(tokens)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc