• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 15805318527

22 Jun 2025 09:48AM UTC coverage: 79.887% (-0.3%) from 80.211%
15805318527

Pull #1811

github

web-flow
Merge 00f9eb5f5 into 9aa85a9a0
Pull Request #1811: Add multi turn tool calling task and support multiple tools per call

1698 of 2104 branches covered (80.7%)

Branch coverage included in aggregate %.

10563 of 13244 relevant lines covered (79.76%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.31
src/unitxt/processors.py
1
import ast
1✔
2
import copy
1✔
3
import json
1✔
4
import re
1✔
5
import string
1✔
6
from difflib import get_close_matches
1✔
7
from typing import Any, Dict
1✔
8

9
import numpy as np
1✔
10

11
from .deprecation_utils import deprecation
1✔
12
from .error_utils import Documentation, UnitxtError
1✔
13
from .operator import MultiStreamOperator
1✔
14
from .operators import FieldOperator, InstanceFieldOperator
1✔
15
from .settings_utils import get_constants
1✔
16
from .type_utils import isoftype
1✔
17

18
constants = get_constants()
1✔
19

20

21
class PostProcess(MultiStreamOperator):
    """Run one instance-field operator over predictions and/or references.

    During ``prepare`` the wrapped ``operator`` is shallow-copied twice: one
    copy targets the ``"prediction"`` field, the other targets the
    ``"references"`` field with ``process_every_value`` enabled and the
    inference stream excluded.
    """

    operator: InstanceFieldOperator
    process_prediction: bool = True
    process_references: bool = True

    def prepare(self):
        """Validate ``operator`` and build the two field-specific copies.

        Raises:
            UnitxtError: if ``operator`` is not an ``InstanceFieldOperator``.
        """
        super().prepare()
        operator_is_valid = isoftype(self.operator, InstanceFieldOperator)
        if not operator_is_valid:
            raise UnitxtError(
                f"PostProcess requires operator field to be of type InstanceFieldOperator. Got object of type <{type(self.operator).__name__}>.",
                Documentation.POST_PROCESSORS,
            )

        # Copy that rewrites the prediction field.
        self.prediction_operator = copy.copy(self.operator)
        self.prediction_operator.field = "prediction"

        # Copy that rewrites each value of the references field.
        self.references_operator = copy.copy(self.operator)
        self.references_operator.field = "references"
        self.references_operator.process_every_value = True
        self.references_operator.dont_apply_to_streams = [constants.inference_stream]

    def process(self, multi_stream):
        """Apply the enabled stages (prediction first, then references)."""
        stages = (
            (self.process_prediction, self.prediction_operator),
            (self.process_references, self.references_operator),
        )
        for enabled, stage in stages:
            if enabled:
                multi_stream = stage(multi_stream)
        return multi_stream
1✔
46

47

48
class ToString(FieldOperator):
    """Convert the field value to its ``str`` representation."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text
1✔
51

52

53
class ToStringStripped(FieldOperator):
    """Convert the field value to ``str`` and strip surrounding whitespace."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text.strip()
1✔
56

57

58
class SplitStrip(FieldOperator):
    """Split the value on ``delimiter``, optionally stripping each piece."""

    delimiter: str = " "
    strip_every_element: bool = False

    def process_value(self, text: Any) -> Any:
        pieces = text.split(self.delimiter)
        if self.strip_every_element:
            return [piece.strip() for piece in pieces]
        return pieces
67

68

69
class ToListByComma(SplitStrip):
    """``SplitStrip`` preset: split on ``","`` and strip every element."""

    strip_every_element = True
    delimiter = ","
1✔
72

73

74
class ToListByCommaSpace(SplitStrip):
    """``SplitStrip`` preset: split on ``", "`` and strip every element."""

    strip_every_element = True
    delimiter = ", "
1✔
77

78

79
class RegexParser(FieldOperator):
    """A processor that uses regex in order to parse a string.

    Returns ``re.findall(regex, text)``, or an empty list when
    ``termination_regex`` is set and fully matches the text.
    """

    regex: str
    termination_regex: str = None

    def process_value(self, text: Any) -> Any:
        stop_pattern = self.termination_regex
        terminated = stop_pattern is not None and re.fullmatch(stop_pattern, text)
        return [] if terminated else re.findall(self.regex, text)
1✔
91

92

93
class ExtractWithRegex(RegexParser):
    """Return the first ``RegexParser`` match, or ``""`` when there is none."""

    def process_value(self, text: Any) -> Any:
        found = super().process_value(text)
        return found[0] if found else ""
×
99

100

101
class GroupDictWithRegex(FieldOperator):
    """Match ``pattern`` at the start of the value and return its named groups.

    Returns an empty dict when the pattern does not match.
    """

    pattern: str

    def process_value(self, value: Any) -> Any:
        found = re.match(self.pattern, value)
        return found.groupdict() if found else {}
×
109

110

111
class ListToEmptyEntitiesTuples(FieldOperator):
    """Turn an iterable of items into ``(str(item), "")`` tuples.

    Returns an empty list when the value cannot be iterated.
    """

    def process_value(self, lst: Any) -> Any:
        try:
            return [(str(item), "") for item in lst]
        except TypeError:
            # Iterating a non-iterable value (e.g. None or a number) raises
            # TypeError. The previous handler caught json.JSONDecodeError,
            # which this code never raises, so the fallback could not trigger.
            return []
×
117

118

119
class DictOfListsToPairs(FieldOperator):
    """Flatten a mapping of key -> list of strings into a list of pairs.

    Each pair is ``(key, value)`` when ``position_key_before_value`` is true,
    otherwise ``(value, key)``. Any malformed input (not a mapping, values
    that are not iterable, or a non-string element) yields an empty list.
    """

    position_key_before_value: bool = True

    def process_value(self, obj: Any) -> Any:
        pairs = []
        try:
            for key, values in obj.items():
                for value in values:
                    # Explicit check instead of `assert`, which is stripped
                    # when Python runs with -O.
                    if not isinstance(value, str):
                        return []
                    pairs.append(
                        (key, value) if self.position_key_before_value else (value, key)
                    )
        except (AttributeError, TypeError, ValueError):
            # No .items(), non-iterable values, or items that do not unpack
            # into (key, values): treat as unparsable input.
            return []
        return pairs
1✔
135

136

137
class TakeFirstNonEmptyLine(FieldOperator):
    """Return the first line of the stripped text, itself stripped.

    Because the whole text is stripped first, the first line is non-empty
    unless the text contains only whitespace, in which case ``""`` is returned.
    """

    def process_value(self, text: Any) -> Any:
        # str.split always yields at least one element, so the former
        # `len(parts) == 0` guard was unreachable and has been removed.
        first_line = str(text).strip().split("\n")[0]
        return first_line.strip()
1✔
143

144

145
class TakeLastNonEmptyLine(FieldOperator):
    """Return the last line of the stripped text, itself stripped.

    Because the whole text is stripped first, the last line is non-empty
    unless the text contains only whitespace, in which case ``""`` is returned.
    """

    def process_value(self, text: Any) -> Any:
        # str.split always yields at least one element, so the former
        # `len(parts) == 0` guard was unreachable and has been removed.
        last_line = str(text).strip().split("\n")[-1]
        return last_line.strip()
×
151

152

153
class ConvertToBoolean(FieldOperator):
    """Map free text to ``"FALSE"``, ``"TRUE"`` or ``"OTHER"``.

    The value is stringified, stripped and lower-cased, then checked for cue
    substrings. Negative cues are tested first, so text containing both kinds
    of cue yields ``"FALSE"``.
    """

    def process_value(self, text: Any) -> Any:
        normalized = str(text).strip().lower()
        cue_table = (
            ("FALSE", ["no", "not", "wrong", "false"]),
            ("TRUE", ["yes", "right", "correct", "true"]),
        )
        for label, cues in cue_table:
            if any(cue in normalized for cue in cues):
                return label
        return "OTHER"
1✔
161

162

163
class LowerCaseTillPunc(FieldOperator):
    """Lower-case the text and keep only what precedes the first of ``.,!?;``."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        # Splitting once on the punctuation class leaves the prefix at index
        # 0, or the whole string when no punctuation is present.
        return re.split(r"[.,!?;]", lowered, maxsplit=1)[0]
1✔
171

172

173
class Lower(FieldOperator):
    """Return the value lower-cased via its ``lower()`` method."""

    def process_value(self, text: Any) -> Any:
        lowered = text.lower()
        return lowered
1✔
176

177

178
class Upper(FieldOperator):
    """Stringify the value and return it upper-cased."""

    def process_value(self, text: Any) -> Any:
        as_text = str(text)
        return as_text.upper()
×
181

182

183
@deprecation("2.0.0", alternative=Lower)
class LowerCase(Lower):
    """Deprecated alias of ``Lower``; adds no behavior of its own."""
1✔
186

187

188
class Capitalize(FieldOperator):
    """Return the value transformed by its ``capitalize()`` method."""

    def process_value(self, text: Any) -> Any:
        capitalized = text.capitalize()
        return capitalized
1✔
191

192

193
class GetStringAfter(FieldOperator):
    """Return the stripped text following the first occurrence of ``substring``.

    When ``substring`` does not occur, the whole text is returned (stripped).
    """

    substring: str

    def process_value(self, text: Any) -> Any:
        head, separator, tail = text.partition(self.substring)
        # An empty separator slot means the substring was not found, so the
        # full text (held in `head`) is the result.
        remainder = tail if separator else head
        return remainder.strip()
×
198

199

200
class MatchClosestOption(InstanceFieldOperator):
    """Replace the value with the closest option in ``task_data``.

    Options are read from ``instance["task_data"][options_field]`` and ranked
    with ``difflib.get_close_matches`` using a cutoff of 0.0.
    """

    options_field: str = "options"

    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        candidates = instance["task_data"][self.options_field]
        ranked = get_close_matches(value, candidates, n=1, cutoff=0.0)
        return ranked[0]
1✔
206

207

208
def process_instance_value(self, value, instance):
    """Return the option closest to ``value``, or ``None`` if there is none.

    Options are read from ``instance[self.options_field]``.
    NOTE(review): this is a module-level function that takes ``self``; it
    looks like a leftover variant of ``MatchClosestOption`` - confirm whether
    anything still calls it.
    """
    # n=1 asks difflib for only the single best candidate.
    ranked = get_close_matches(value, instance[self.options_field], n=1, cutoff=0)
    if not ranked:
        return None
    return ranked[0]
×
213

214

215
class Substring(FieldOperator):
    """Return the slice ``text[begin:end]``; ``end=None`` means "to the end"."""

    begin: int = 0
    end: int = None

    def process_value(self, text: Any) -> Any:
        # A None stop in a slice is the same as omitting it, so one
        # expression covers both the bounded and the open-ended case.
        return text[self.begin : self.end]
1✔
223

224

225
class FirstCharacter(FieldOperator):
    """Return the first word character (``\\w``) in the text, or ``""``."""

    def process_value(self, text: Any) -> Any:
        found = re.search(r"\s*(\w)", text)
        if found is None:
            return ""
        return found.group(1)
×
231

232

233
class TakeFirstWord(FieldOperator):
    """Return the first number or word found in the text, or ``""``.

    A number may carry leading ``-`` signs and dot-separated digit groups;
    otherwise the first run of word characters is taken.
    """

    def process_value(self, text: Any) -> Any:
        found = re.search(r"([-]*[0-9]+(\.([0-9]+))*)|([\w]+)", text)
        if not found:
            return ""
        span_start, span_end = found.span()
        return text[span_start:span_end]
1✔
239

240

241
class YesNoToInt(FieldOperator):
    """Map ``"yes"`` to ``"1"`` and ``"no"`` to ``"0"``; pass anything else through."""

    def process_value(self, text: Any) -> Any:
        for word, digit in (("yes", "1"), ("no", "0")):
            if text == word:
                return digit
        return text
1✔
248

249

250
class YesToOneElseZero(FieldOperator):
    """Return ``"1"`` when the value equals ``"yes"``, otherwise ``"0"``."""

    def process_value(self, text: Any) -> Any:
        return "1" if text == "yes" else "0"
1✔
255

256

257
class StrToFloatFormat(FieldOperator):
    """Normalize a numeric value to ``str(float(value))``.

    Values that cannot be converted to ``float`` are returned as ``str(value)``.
    """

    def process_value(self, text: Any) -> Any:
        try:
            number = float(text)
        except Exception:
            return str(text)
        return str(number)
1✔
263

264

265
class ToYesOrNone(FieldOperator):
    """Return ``"yes"`` when the value equals ``"yes"``, otherwise ``"none"``."""

    def process_value(self, text: Any) -> Any:
        return "yes" if text == "yes" else "none"
1✔
270

271

272
class StanceToProCon(FieldOperator):
    """Map a stance label to ``"PRO"``, ``"CON"`` or ``"none"``.

    ``"positive"`` becomes ``"PRO"``; ``"negative"`` and ``"suggestion"``
    become ``"CON"``; every other value becomes ``"none"``.
    """

    def process_value(self, text: Any) -> Any:
        if text == "positive":
            label = "PRO"
        elif text in ["negative", "suggestion"]:
            label = "CON"
        else:
            label = "none"
        return label
1✔
279

280

281
class StringEquals(FieldOperator):
    """Reduce the text to ``string`` / ``"not " + string`` when it contains them.

    Matching is case-insensitive and the negated form is checked first. If
    neither form occurs, the original text is returned unchanged.
    """

    string: str

    def process_value(self, text: Any) -> Any:
        target = self.string.lower()
        negated_target = "not " + target
        haystack = text.lower()
        if negated_target in haystack:
            return negated_target
        if target in haystack:
            return target
        return text
×
290

291

292
@deprecation("2.0.0", alternative=StringEquals)
class StringOrNotString(StringEquals):
    """Deprecated alias of ``StringEquals``; adds no behavior of its own."""
1✔
295

296

297
class ExtractMtBenchRatingJudgment(FieldOperator):
    """Extract a ``[[rating]]`` number from the text and divide it by 10.

    Returns ``0.0`` when no ``[[number]]`` marker is found.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
        # Explicit None check replaces the former bare `except:`, which also
        # swallowed unrelated errors such as KeyboardInterrupt.
        if match is None:
            return 0.0
        return float(match.group(1)) / 10
1✔
304

305

306
class ExtractHarmRatingJudgement(FieldOperator):
    """Extract a ``[[rating]]`` number and map it to ``rating * 0.25 - 0.25``.

    Returns NaN when no ``[[number]]`` marker is found.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([\d]+\.?[\d]*)\]\]", text)
        if match is None:
            # `np.NaN` was removed in NumPy 2.0, where the old fallback raised
            # AttributeError; `np.nan` exists in every NumPy version.
            return np.nan
        return float(match.group(1)) * 0.25 - 0.25
×
313

314

315
class ExtractMtBenchLabelJudgment(FieldOperator):
    """Extract the label inside the first ``[[...]]`` marker of the text.

    Returns the string ``"None"`` when no marker is found.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit None check replaces the former bare `except:`.
        if match is None:
            return "None"
        return str(match.group(1))
1✔
322

323

324
class LiteralEval(FieldOperator):
    """Parse a string holding a Python literal with ``ast.literal_eval``.

    ``None`` and the empty string are returned unchanged.

    Raises:
        ValueError: if the value is neither ``None`` nor a ``str``.
    """

    def process_value(self, text: Any) -> Any:
        if text is None:
            return text
        if not isinstance(text, str):
            raise ValueError(
                f"LiteralEval: field '{self.field}' is expected to be of 'str' input type, got: {type(text)}"
            )
        if text == "":
            return text
        stripped = text.strip()
        return ast.literal_eval(stripped)
1✔
333

334

335
class ExtractSafeUnsafeJudgment(FieldOperator):
    """Return ``1.0`` when the first line of the text is ``safe``, else ``0.0``.

    The text is stringified and stripped first; the comparison is
    case-insensitive.
    """

    def process_value(self, text: Any) -> Any:
        lines = str(text).strip().split("\n")
        verdict = lines[0].lower()
        return 1.0 if verdict == "safe" else 0.0
×
341

342

343
class ExtractArenaHardNumericalJudgment(FieldOperator):
    """Convert a ``[[verdict]]`` marker into a signed numeric score.

    ``A>B`` -> 1, ``A>>B`` -> 3, ``B>A`` -> -1, ``B>>A`` -> -3. Any other
    verdict, or a missing marker, yields 0.
    """

    def process_value(self, text: Any) -> Any:
        match = re.search(r"\[\[([^\]]+)\]\]", text)
        # Explicit None check replaces the former bare `except:`.
        if match is None:
            return 0
        verdict_scores = {"A>B": 1, "A>>B": 3, "B>A": -1, "B>>A": -3}
        return verdict_scores.get(match.group(1), 0)
×
360

361

362
class InferDictsToBinaryLogprobs(FieldOperator):
    """Turn per-token top-logprob dicts into a positive-class probability.

    Scans up to ``num_logprobs_to_take`` token positions (from the start, or
    from the end when ``take_logprobs_from_end`` is set). For the first
    position whose matched positive + negative probability mass exceeds
    ``min_probability_mass``, returns ``pos / (pos + neg)``. Returns 0 when no
    position qualifies.
    """

    neg_class_name: str
    pos_class_name: str

    take_logprobs_from_end: bool = False
    num_logprobs_to_take: int = 3
    min_probability_mass = 0.0001

    def verify(self):
        """Reject class names where one contains the other (case-insensitive).

        Raises:
            ValueError: if the two class names overlap.
        """
        super().verify()
        neg_lower = self.neg_class_name.lower()
        pos_lower = self.pos_class_name.lower()
        if neg_lower in pos_lower or pos_lower in neg_lower:
            # The closing quote after the negative class name was missing.
            raise ValueError(
                f'Class names in {self.__class__.__name__} should not overlap, got "{self.pos_class_name}" and "{self.neg_class_name}"'
            )

    def process_value(self, obj: Any) -> Any:
        for i in self.get_token_range(obj):
            try:
                pos_probs, neg_probs = self.get_pos_neg_probs(pred_dict=obj[i])
            except Exception:
                # Best-effort: skip a malformed token entry and try the next
                # position. Unlike the former bare `except:`, this lets
                # KeyboardInterrupt and SystemExit propagate.
                continue
            if pos_probs or neg_probs:
                sum_probs = sum(pos_probs) + sum(neg_probs)
                if sum_probs > self.min_probability_mass:
                    return sum(pos_probs) / sum_probs
        return 0

    def get_pos_neg_probs(self, pred_dict):
        """Return ``[pos_probs, neg_probs]`` for one token position.

        Each list holds ``exp(logprob)`` for every entry of
        ``pred_dict["top_tokens"]`` whose ``"text"`` fully matches the class
        name, ignoring case and surrounding non-word characters.
        """
        token_logprobs = pred_dict["top_tokens"]

        pos_and_neg_probs = []
        for class_name in [self.pos_class_name, self.neg_class_name]:
            # We need to capture different variants of model behavior and tokenizers, for example with opening space,
            # punctuation etc. but avoid longer words that contain the class name.
            # For example, for class "yes" we would capture "YES," and " Yes" but not "yesterday".
            # re.escape keeps class names containing regex metacharacters
            # (e.g. "c++", "n/a (none)") from being interpreted as patterns.
            name_regex = re.compile(
                rf"(\W|Ġ|_)*{re.escape(class_name)}(\W|Ġ|_)*", flags=re.IGNORECASE
            )
            class_probs = [
                np.exp(d["logprob"])
                for d in token_logprobs
                if name_regex.fullmatch(d["text"])
            ]
            pos_and_neg_probs.append(class_probs)
        return pos_and_neg_probs

    def get_token_range(self, obj: Any) -> range:
        """Return the indices to scan: forward from 0, or backward from -1."""
        n_tokens = min(self.num_logprobs_to_take, len(obj))
        if self.take_logprobs_from_end:
            return range(-1, -(n_tokens + 1), -1)
        return range(n_tokens)
×
416

417

418
class RemoveArticles(FieldOperator):
    """Replace each standalone lower-case ``a``, ``an`` or ``the`` with a space."""

    def process_value(self, text: Any) -> Any:
        article_pattern = re.compile(r"\b(a|an|the)\b")
        return article_pattern.sub(" ", text)
1✔
421

422

423
class RemovePunctuations(FieldOperator):
    """Drop every character that appears in ``string.punctuation``."""

    def process_value(self, text: Any) -> Any:
        banned = set(string.punctuation)
        kept = [ch for ch in text if ch not in banned]
        return "".join(kept)
1✔
427

428

429
class FixWhiteSpace(FieldOperator):
    """Collapse whitespace runs to single spaces and trim both ends."""

    def process_value(self, text: Any) -> Any:
        tokens = text.split()
        return " ".join(tokens)
1✔
432

433

434
class AddPrefix(FieldOperator):
    """Strip the text and prepend ``prefix`` unless it already starts with it."""

    prefix: str

    def process_value(self, text: str) -> str:
        stripped = text.strip()
        already_prefixed = stripped.startswith(self.prefix)
        return stripped if already_prefixed else self.prefix + stripped
1✔
442

443

444
class GetSQL(FieldOperator):
    """Operator that pulls the most likely SQL query out of free text.

    Search order: the last ```` ```sql ```` fenced block; otherwise the last
    generic fenced block if it starts with an SQL keyword; otherwise the text
    from the last SQL keyword to the end. The chosen candidate is cut at its
    first semicolon and cleared of leftover code-fence markers.
    """

    def process_value(self, text: str) -> str:
        """Extract the most plausible SQL query from ``text``.

        Args:
            text: Input string that may contain an SQL query mixed with
                other text (e.g., explanations, markdown).

        Returns:
            The extracted query, ``"Input must be a string"`` for non-string
            input, or ``"No query found in generation"`` when nothing
            suitable is found.
        """
        if not isinstance(text, str):
            return "Input must be a string"

        candidate = None

        # 1. Prefer explicitly tagged ```sql fenced blocks (last one wins).
        tagged_blocks = re.findall(
            r"```sql\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE
        )
        if tagged_blocks:
            candidate = tagged_blocks[-1].strip()
        else:
            # 2. Otherwise accept the last generic fenced block, but only if
            #    it starts with a common SQL keyword.
            untagged_blocks = re.findall(r"```\s*(.*?)\s*```", text, re.DOTALL)
            if untagged_blocks:
                last_block = untagged_blocks[-1].strip()
                leading_keyword = (
                    r"^(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|WITH|DROP|TRUNCATE)\b"
                )
                if re.match(leading_keyword, last_block, re.IGNORECASE):
                    candidate = last_block

        # 3. No usable fenced block: fall back to the last SQL keyword found
        #    anywhere in the text and take everything from there on.
        if candidate is None:
            any_keyword = (
                r"\b(SELECT|INSERT|UPDATE|DELETE|CREATE|ALTER|WITH|DROP|TRUNCATE)\b"
            )
            keyword_hits = list(re.finditer(any_keyword, text, re.IGNORECASE))
            if keyword_hits:
                candidate = text[keyword_hits[-1].start() :].strip()

        # 4. Nothing found (or an empty candidate): report it.
        if not candidate:
            return "No query found in generation"

        # 5. Keep only what precedes the first semicolon (the whole candidate
        #    when there is none), then remove leftover fence markers.
        before_semicolon = candidate.partition(";")[0].strip()
        return before_semicolon.replace("```sql", "").replace("```", "").strip()
525

526

527
class ScaleNumberToZeroOneReturnZeroIfFails(FieldOperator):
    """Min-max scale a number from ``[min_val, max_val]`` to ``[0, 1]``.

    Returns 0 when the value cannot be converted to ``float`` or when
    ``max_val`` equals ``min_val``.
    """

    max_val = 10
    min_val = 0

    def process_value(self, text: Any) -> Any:
        try:
            number = float(text)
            # Divide by the width of the range, not by max_val alone. The two
            # agree for the default min_val of 0, but only the range width
            # maps [min_val, max_val] onto [0, 1] for a non-zero min_val.
            return (number - self.min_val) / (self.max_val - self.min_val)
        except (TypeError, ValueError, OverflowError, ZeroDivisionError):
            return 0
1✔
537

538

539
class ExtractVerbalJudgment(FieldOperator):
    """Map a verbal judgment to an evenly spaced score between 0 and 1.

    The stripped, lower-cased text is compared against ``classes`` in order.
    The first class the text starts with determines the score
    ``index / (len(classes) - 1)``. Returns 0 when no class matches.
    """

    classes = ["not", "somewhat", "mostly", "completely"]

    def process_value(self, text: Any) -> Any:
        top_index = len(self.classes) - 1
        matching_scores = (
            position / (top_index)
            for position, label in enumerate(self.classes)
            if text.strip().lower().startswith(label)
        )
        # The generator is lazy, so only the first matching class is scored.
        return next(matching_scores, 0)
1✔
548

549

550
class ExtractVerbalJudgementBadGood(ExtractVerbalJudgment):
    """``ExtractVerbalJudgment`` with a five-step bad-to-good scale."""

    classes = [
        "very bad",
        "bad",
        "mediocre",
        "good",
        "very good",
    ]
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc