• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12568880050

01 Jan 2025 08:32AM UTC coverage: 80.188% (+0.2%) from 80.023%
12568880050

Pull #1459

github

web-flow
Merge af792bced into def3e0ea1
Pull Request #1459: Add MapReduceMetric a new base class to integrate all metrics into

1365 of 1696 branches covered (80.48%)

Branch coverage included in aggregate %.

8636 of 10776 relevant lines covered (80.14%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

70.75
src/unitxt/processors.py
1
import ast
1✔
2
import copy
1✔
3
import json
1✔
4
import re
1✔
5
import string
1✔
6
from difflib import get_close_matches
1✔
7
from typing import Any, Dict
1✔
8

9
import numpy as np
1✔
10

11
from .deprecation_utils import deprecation
1✔
12
from .error_utils import Documentation, UnitxtError
1✔
13
from .operator import MultiStreamOperator
1✔
14
from .operators import FieldOperator, InstanceFieldOperator
1✔
15
from .settings_utils import get_constants
1✔
16
from .type_utils import isoftype
1✔
17

18
# Constants object obtained from settings_utils.get_constants(); this module
# reads `constants.inference_stream` from it in PostProcess.prepare.
constants = get_constants()
1✔
19

20

21
class PostProcess(MultiStreamOperator):
    """Run one instance-field operator over predictions and/or references.

    ``prepare`` makes two shallow copies of ``operator``: one whose ``field``
    is ``"prediction"`` and one whose ``field`` is ``"references"``. The
    references copy additionally gets ``process_every_value = True`` and
    ``dont_apply_to_streams = [constants.inference_stream]``.
    ``process_prediction`` / ``process_references`` select which of the two
    copies ``process`` actually applies.
    """

    operator: InstanceFieldOperator
    process_prediction: bool = True
    process_references: bool = True

    def prepare(self):
        """Validate ``operator`` and build the per-field operator copies."""
        super().prepare()
        if not isoftype(self.operator, InstanceFieldOperator):
            raise UnitxtError(
                f"PostProcess requires operator field to be of type InstanceFieldOperator. Got object of type <{type(self.operator).__name__}>.",
                Documentation.POST_PROCESSORS,
            )

        # Copy aimed at the "prediction" field.
        for_prediction = copy.copy(self.operator)
        for_prediction.field = "prediction"
        self.prediction_operator = for_prediction

        # Copy aimed at the "references" field.
        for_references = copy.copy(self.operator)
        for_references.field = "references"
        for_references.process_every_value = True
        for_references.dont_apply_to_streams = [constants.inference_stream]
        self.references_operator = for_references

    def process(self, multi_stream):
        """Apply the enabled operator copies in order: prediction, then references."""
        stream = multi_stream
        if self.process_prediction:
            stream = self.prediction_operator(stream)
        if self.process_references:
            stream = self.references_operator(stream)
        return stream
1✔
46

47

48
class ToString(FieldOperator):
    """Convert the field value to its ``str()`` representation."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text
1✔
51

52

53
class ToStringStripped(FieldOperator):
    """Convert the field value to ``str`` and strip surrounding whitespace."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text.strip()
1✔
56

57

58
class SplitStrip(FieldOperator):
    """Split the value on ``delimiter``; optionally strip every resulting piece."""

    delimiter: str = " "
    strip_every_element: bool = False

    def process_value(self, text: Any) -> Any:
        pieces = text.split(self.delimiter)
        if self.strip_every_element:
            return [piece.strip() for piece in pieces]
        return list(pieces)
67

68

69
class ToListByComma(SplitStrip):
    """``SplitStrip`` preset: split on "," and strip every element."""

    delimiter = ","
    strip_every_element = True
1✔
72

73

74
class ToListByCommaSpace(SplitStrip):
    """``SplitStrip`` preset: split on ", " and strip every element."""

    delimiter = ", "
    strip_every_element = True
1✔
77

78

79
class RegexParser(FieldOperator):
    """A processor that uses regex in order to parse a string.

    Returns ``re.findall(regex, text)``. When ``termination_regex`` is set and
    matches the *entire* text, an empty list is returned instead.
    """

    regex: str
    termination_regex: str = None

    def process_value(self, text: Any) -> Any:
        stop_pattern = self.termination_regex
        if stop_pattern is not None:
            if re.fullmatch(stop_pattern, text):
                return []
        return re.findall(self.regex, text)
1✔
91

92

93
class ExtractWithRegex(RegexParser):
    """Return the first match found by ``RegexParser``, or "" when there is none."""

    def process_value(self, text: Any) -> Any:
        found = super().process_value(text)
        return found[0] if found else ""
×
99

100

101
class ListToEmptyEntitiesTuples(FieldOperator):
    """Turn an iterable of items into ``(str(item), "")`` tuples.

    Returns an empty list when the value cannot be iterated (for example
    ``None``). The original handler caught only ``json.JSONDecodeError``,
    which nothing in the guarded expression can raise, so the fallback was
    unreachable; ``TypeError`` is what a non-iterable value actually produces.
    """

    def process_value(self, lst: Any) -> Any:
        try:
            return [(str(item), "") for item in lst]
        except (TypeError, json.JSONDecodeError):
            # JSONDecodeError is kept so anything previously caught still is.
            return []
×
107

108

109
class DictOfListsToPairs(FieldOperator):
    """Flatten a mapping of ``key -> iterable of str`` into a list of pairs.

    Each pair is ``(key, value)`` when ``position_key_before_value`` is true,
    otherwise ``(value, key)``. Any malformed input (not a mapping, values not
    iterable, a non-string value) yields an empty list.
    """

    position_key_before_value: bool = True

    def process_value(self, obj: Any) -> Any:
        try:
            pairs = []
            for key, values in obj.items():
                for value in values:
                    # Explicit check instead of `assert`, which is removed
                    # when Python runs with -O.
                    if not isinstance(value, str):
                        return []
                    pairs.append(
                        (key, value) if self.position_key_before_value else (value, key)
                    )
            return pairs
        except Exception:
            # Best-effort parsing: bad input gives [], but unlike a bare
            # `except:` this lets KeyboardInterrupt/SystemExit propagate.
            return []
1✔
125

126

127
class TakeFirstNonEmptyLine(FieldOperator):
    """Return the first line of the stripped text, itself stripped.

    The whole text is stripped before splitting, so the first line is
    non-empty unless the text is blank, in which case "" is returned.
    """

    def process_value(self, text: Any) -> Any:
        # The former `len(parts) == 0` guard could never fire: splitting a
        # string always yields at least one element.
        first_line, _, _ = str(text).strip().partition("\n")
        return first_line.strip()
1✔
133

134

135
class TakeLastNonEmptyLine(FieldOperator):
    """Return the last line of the stripped text, itself stripped.

    The whole text is stripped before splitting, so the last line is
    non-empty unless the text is blank, in which case "" is returned.
    """

    def process_value(self, text: Any) -> Any:
        # The former `len(parts) == 0` guard could never fire: splitting a
        # string always yields at least one element.
        _, _, last_line = str(text).strip().rpartition("\n")
        return last_line.strip()
×
141

142

143
class ConvertToBoolean(FieldOperator):
    """Map free text to "FALSE", "TRUE" or "OTHER" by substring lookup.

    The value is stringified, stripped and lower-cased. Negative markers are
    checked before positive ones, and markers are matched as substrings (not
    whole words).
    """

    def process_value(self, text: Any) -> Any:
        normalized = str(text).strip().lower()
        label_markers = (
            ("FALSE", ["no", "not", "wrong", "false"]),
            ("TRUE", ["yes", "right", "correct", "true"]),
        )
        for label, markers in label_markers:
            for marker in markers:
                if marker in normalized:
                    return label
        return "OTHER"
1✔
151

152

153
class LowerCaseTillPunc(FieldOperator):
    """Lower-case the text and cut it before the first of ``. , ! ? ;``."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        first_punc = re.search(r"[.,!?;]", lowered)
        if not first_punc:
            return lowered
        # Keep only what precedes the first punctuation mark.
        return lowered[: first_punc.start()]
1✔
161

162

163
class Lower(FieldOperator):
    """Lower-case the value via its own ``lower()`` method."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        return lowered
1✔
166

167

168
class Upper(FieldOperator):
    """Stringify the value and upper-case it."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text.upper()
×
171

172

173
class Title(FieldOperator):
    """Stringify the value and title-case it."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text.title()
×
176

177

178
class TakeUntilPunc(FieldOperator):
    """Keep the text that precedes the first run of ``\\p{P}`` characters.

    Uses the third-party ``regex`` package (declared in ``_requirements_list``
    and imported lazily in ``prepare``).
    """

    _requirements_list = ["regex"]

    def prepare(self):
        """Compile the punctuation pattern once per operator instance."""
        super().prepare()
        import regex

        self.pattern = regex.compile(r"\p{P}+")

    def process_value(self, text: Any) -> Any:
        first_punc = self.pattern.search(text)
        return text[: first_punc.start()] if first_punc else text
×
192

193

194
@deprecation("2.0.0", alternative=Lower)
class LowerCase(Lower):
    """Deprecated alias of ``Lower``; adds no behavior of its own."""
1✔
197

198

199
class Capitalize(FieldOperator):
    """Return the value's ``capitalize()`` result."""

    def process_value(self, text: Any) -> Any:
        capitalized = text.capitalize()
        return capitalized
1✔
202

203

204
class GetStringAfter(FieldOperator):
    """Return the stripped text after the first occurrence of ``substring``.

    When ``substring`` does not occur, the whole text (stripped) is returned.
    """

    substring: str

    def process_value(self, text: Any) -> Any:
        # At most one split: the last piece is everything after the first hit,
        # or the entire text when there was no hit.
        pieces = text.split(self.substring, 1)
        tail = pieces[-1]
        return tail.strip()
×
209

210

211
class MatchClosestOption(InstanceFieldOperator):
    """Replace the value with the most similar entry of the instance's options.

    Options are read from ``instance["task_data"][options_field]`` and ranked
    with ``difflib.get_close_matches`` using ``cutoff=0.0``.
    """

    options_field: str = "options"

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        candidates = instance["task_data"][self.options_field]
        ranked = get_close_matches(value, candidates, n=1, cutoff=0.0)
        return ranked[0]
1✔
217

218

219
def process_instance_value(self, value, instance):
    """Return the option closest to ``value``, or ``None`` if there are no options.

    Module-level function that takes ``self`` explicitly; options are read
    from ``instance[self.options_field]``.
    """
    # n=1 asks for the single best candidate; cutoff=0 accepts any similarity.
    ranked = get_close_matches(value, instance[self.options_field], n=1, cutoff=0)
    if not ranked:
        return None
    return ranked[0]
×
224

225

226
class Substring(FieldOperator):
    """Slice the value as ``text[begin:end]``; ``end=None`` means "to the end"."""

    begin: int = 0
    end: int = None

    def process_value(self, text: Any) -> Any:
        # A slice whose stop is None is the open-ended slice text[begin:].
        return text[self.begin : self.end]
1✔
234

235

236
class FirstCharacter(FieldOperator):
    """Return the first ``\\w`` character found in the text, or "" if none."""

    def process_value(self, text: Any) -> Any:
        found = re.search(r"\s*(\w)", text)
        # Group 1 always participates in a successful match.
        return found.group(1) if found else ""
×
242

243

244
class TakeFirstWord(FieldOperator):
    """Return the first number (optionally signed/dotted) or word in the text.

    Returns "" when the pattern finds nothing.
    """

    def process_value(self, text: Any) -> Any:
        hit = re.search(r"([-]*[0-9]+(\.([0-9]+))*)|([\w]+)", text)
        if not hit:
            return ""
        first, last = hit.span()
        return text[first:last]
1✔
250

251

252
class YesNoToInt(FieldOperator):
    """Map "yes" to "1" and "no" to "0"; any other value passes through unchanged."""

    def process_value(self, text: Any) -> Any:
        if text == "yes":
            converted = "1"
        elif text == "no":
            converted = "0"
        else:
            converted = text
        return converted
1✔
259

260

261
class YesToOneElseZero(FieldOperator):
    """Map "yes" to "1" and every other value to "0"."""

    def process_value(self, text: Any) -> Any:
        return "1" if text == "yes" else "0"
1✔
266

267

268
class StrToFloatFormat(FieldOperator):
    """Render the value as ``str(float(value))``; fall back to ``str(value)``.

    The fallback is used whenever the float conversion raises.
    """

    def process_value(self, text: Any) -> Any:
        try:
            as_float = float(text)
        except Exception:
            return str(text)
        return str(as_float)
1✔
274

275

276
class ToYesOrNone(FieldOperator):
    """Keep "yes" as is; map every other value to the string "none"."""

    def process_value(self, text: Any) -> Any:
        return "yes" if text == "yes" else "none"
1✔
281

282

283
class StanceToProCon(FieldOperator):
    """Map "positive" to "PRO", "negative"/"suggestion" to "CON", anything else to "none"."""

    def process_value(self, text: Any) -> Any:
        if text == "positive":
            label = "PRO"
        elif text in ["negative", "suggestion"]:
            label = "CON"
        else:
            label = "none"
        return label
1✔
290

291

292
class StringEquals(FieldOperator):
    """Collapse the text to ``string`` or "not " + ``string`` when it contains them.

    Matching is case-insensitive and by substring; the negated form is checked
    first. Returned labels are lower-cased. Text containing neither is
    returned unchanged.
    """

    string: str

    def process_value(self, text: Any) -> Any:
        target = self.string.lower()
        negated = "not " + target
        haystack = text.lower()
        if negated in haystack:
            return negated
        if target in haystack:
            return target
        return text
×
301

302

303
@deprecation("2.0.0", alternative=StringEquals)
class StringOrNotString(StringEquals):
    """Deprecated alias of ``StringEquals``; adds no behavior of its own."""
1✔
306

307

308
class ExtractMtBenchRatingJudgment(FieldOperator):
    """Extract a ``[[number]]`` rating from the text and divide it by 10.

    Returns ``0.0`` when no ``[[number]]`` marker is present or the captured
    number cannot be converted to float.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
        # Explicit "no marker" check instead of relying on a bare `except:`
        # to catch the AttributeError from `None.group`.
        if match is None:
            return 0.0
        try:
            return float(match.group(1)) / 10
        except ValueError:
            return 0.0
1✔
315

316

317
class ExtractMtBenchLabelJudgment(FieldOperator):
    """Map a ``[[X]]`` verdict in the text to a label via ``options``.

    Returns the string "None" when there is no ``[[...]]`` marker or the
    captured verdict is not a key of ``options``.
    """

    options = {
        "A": "choice_a",
        "B": "choice_b",
        "C": "tie",
    }

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit "no marker" check instead of a bare `except:` around
        # `None.group`.
        if match is None:
            return "None"
        return self.options.get(match.group(1), "None")
1✔
330

331

332
class LiteralEval(FieldOperator):
    """Parse a string field with ``ast.literal_eval`` after stripping it.

    ``None`` and "" are returned unchanged.

    Raises:
        ValueError: if the value is neither ``None`` nor a ``str``.
    """

    def process_value(self, text: Any) -> Any:
        if text is None:
            return text
        if not isinstance(text, str):
            raise ValueError(
                f"LiteralEval: field '{self.field}' is expected to be of 'str' input type, got: {type(text)}"
            )
        if text == "":
            return text
        stripped = text.strip()
        return ast.literal_eval(stripped)
1✔
341

342

343
class ExtractSafeUnsafeJudgment(FieldOperator):
    """Return 1.0 when the first line of the stripped text is "safe" (any case), else 0.0."""

    def process_value(self, text: Any) -> Any:
        stripped = str(text).strip()
        verdict = stripped.split("\n")[0].lower()
        return 1.0 if verdict == "safe" else 0.0
×
349

350

351
class ExtractArenaHardNumericalJudgment(FieldOperator):
    """Convert a ``[[verdict]]`` marker in the text to a signed integer score.

    "A>B" -> 1, "A>>B" -> 3, "B>A" -> -1, "B>>A" -> -3. Any other verdict, or
    text without a ``[[...]]`` marker, gives 0.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit "no marker" check instead of a bare `except:` around
        # `None.group`.
        if match is None:
            return 0
        scores = {"A>B": 1, "A>>B": 3, "B>A": -1, "B>>A": -3}
        return scores.get(match.group(1), 0)
×
368

369

370
class InferDictsToBinaryLogprobs(FieldOperator):
    """Turn per-token top-logprob dicts into a positive-class probability.

    The value is indexed by token position; each entry is read as a dict with
    a "top_tokens" list whose items carry "text" and "logprob" keys. For the
    first inspected position whose positive + negative probability mass
    exceeds ``min_probability_mass``, the result is
    ``pos_mass / (pos_mass + neg_mass)``. If no position qualifies, 0 is
    returned.

    ``num_logprobs_to_take`` bounds how many positions are inspected, and
    ``take_logprobs_from_end`` makes the scan start from the last position.
    """

    neg_class_name: str
    pos_class_name: str

    take_logprobs_from_end: bool = False
    num_logprobs_to_take: int = 3
    min_probability_mass = 0.0001

    def verify(self):
        """Reject class names where one contains the other (case-insensitive)."""
        super().verify()
        neg_lower = self.neg_class_name.lower()
        pos_lower = self.pos_class_name.lower()
        if neg_lower in pos_lower or pos_lower in neg_lower:
            # The original message was missing the closing quote after the
            # negative class name.
            raise ValueError(
                f'Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}"'
            )

    def process_value(self, obj: Any) -> Any:
        for i in self.get_token_range(obj):
            try:
                pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
                if pos_probs or neg_probs:
                    sum_probs = sum(pos_probs) + sum(neg_probs)
                    if sum_probs > self.min_probability_mass:
                        return sum(pos_probs) / sum_probs
            except Exception:
                # Best-effort: a malformed entry is skipped and the next
                # position is tried. Unlike a bare `except:`, this lets
                # KeyboardInterrupt/SystemExit propagate.
                continue
        return 0

    def get_pos_neg_probs(self, pred_dict):
        """Return ``[pos_probs, neg_probs]`` for one token position.

        Each list holds ``exp(logprob)`` for every top token whose text fully
        matches the corresponding class name, ignoring case and surrounding
        non-word characters.
        """
        token_logprobs = pred_dict["top_tokens"]

        pos_and_neg_probs = []
        for class_name in [self.pos_class_name, self.neg_class_name]:
            # We need to capture different variants of model behavior and tokenizers, for example with opening space,
            # punctuation etc. but avoid longer words that contain the class name.
            # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
            # The class name is escaped so that regex metacharacters in it
            # are matched literally.
            name_regex = re.compile(
                rf"(\W|Ġ|_)*{re.escape(class_name)}(\W|Ġ|_)*", flags=re.IGNORECASE
            )
            class_probs = [
                np.exp(d["logprob"])
                for d in token_logprobs
                if name_regex.fullmatch(d["text"])
            ]
            pos_and_neg_probs.append(class_probs)
        return pos_and_neg_probs

    def get_token_range(self, obj: Any) -> range:
        """Return the token indices to inspect, at most ``num_logprobs_to_take``."""
        n_tokens = min(self.num_logprobs_to_take, len(obj))
        if self.take_logprobs_from_end:
            # Negative indices: -1, -2, ..., -n_tokens.
            return range(-1, -(n_tokens + 1), -1)
        return range(n_tokens)
×
424

425

426
class RemoveArticles(FieldOperator):
    """Replace each whole-word "a", "an" or "the" (lower-case only) with a space."""

    def process_value(self, text: Any) -> Any:
        article_pattern = r"\b(a|an|the)\b"
        return re.sub(article_pattern, " ", text)
×
429

430

431
class RemovePunctuations(FieldOperator):
    """Drop every character that appears in ``string.punctuation``."""

    def process_value(self, text: Any) -> Any:
        banned = set(string.punctuation)
        kept = [ch for ch in text if ch not in banned]
        return "".join(kept)
×
435

436

437
class FixWhiteSpace(FieldOperator):
    """Collapse every run of whitespace to one space and trim both ends."""

    def process_value(self, text: Any) -> Any:
        words = text.split()
        return " ".join(words)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc