• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 14703910144

28 Apr 2025 08:46AM UTC coverage: 79.483% (-0.6%) from 80.035%
14703910144

Pull #1764

github

web-flow
Merge 7f307a34d into 29ef085a0
Pull Request #1764: Add tool calling support + Berkeley Tool Calling Benchmark (simple-v3)

1622 of 2034 branches covered (79.74%)

Branch coverage included in aggregate %.

10190 of 12827 relevant lines covered (79.44%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.45
src/unitxt/templates.py
1
import json
1✔
2
from abc import abstractmethod
1✔
3
from random import random
1✔
4
from typing import Any, Dict, List, Optional, Tuple, Union
1✔
5

6
from .artifact import Artifact
1✔
7
from .collections import DictCollection, ListCollection
1✔
8
from .dataclass import NonPositionalField
1✔
9
from .dict_utils import dict_set
1✔
10
from .error_utils import Documentation, UnitxtError
1✔
11
from .operator import InstanceOperator, Operator
1✔
12
from .random_utils import new_random_generator
1✔
13
from .serializers import (
1✔
14
    DialogSerializer,
15
    ImageSerializer,
16
    ListSerializer,
17
    MultiTypeSerializer,
18
    NumberQuantizingSerializer,
19
    Serializer,
20
    SQLDatabaseAsSchemaSerializer,
21
    TableSerializer,
22
    ToolCallSerializer,
23
    ToolsSerializer,
24
    VideoSerializer,
25
)
26
from .settings_utils import get_constants
1✔
27
from .type_utils import isoftype, to_type_string
1✔
28

29
# Module-wide settings object; this module reads names such as
# `inference_stream` and `instruction_field` from it.
constants = get_constants()
1✔
30

31

32
class TemplateFormatKeyError(UnitxtError):
    """Error raised when a template format string references a missing key.

    The message lists the keys that *are* available in ``data``, together with
    the offending format string, so the template author can spot the mismatch.

    Args:
        template: the template whose format string failed.
        data: the mapping the format string was rendered against.
        data_type: human-readable kind of the keys (e.g. "input field").
        format_str: the format string that could not be rendered.
        format_name: the template attribute holding ``format_str``.
    """

    def __init__(self, template, data, data_type, format_str, format_name):
        available = ", ".join(data.keys())
        owner = template.__class__.__name__
        message = (
            f"Available {data_type}s are [{available}] "
            f"but {owner}.{format_name} format requires a different ones: '{format_str}'"
        )
        super().__init__(message, Documentation.ADDING_TEMPLATE)
40

41

42
class Template(InstanceOperator):
    """The role of template is to take the fields of every instance and verbalize it.

    Meaning the template is taking the instance and generating source, target and references.

    Args:
        skip_rendered_instance (bool): if "source", "target", and "references" are
            already defined fields in the instance, skip its processing.
        postprocessors: a list of processors (artifact names given as strings, or
            Operator objects) to be applied on the model output.
        instruction: a formatting string that yields an instruction with potential
            participation of values from the "input_fields" part of the instance.
            A value stored in the instance under the instruction field takes
            precedence over this default.
        target_prefix: a formatting string used to format the prompt. Like
            ``instruction``, it is rendered with the serialized "input_fields"
            of the instance.
        title_fields: names of input fields whose values are title-cased (via
            their ``.title()`` method) before serialization.
        serializer: the serializer applied to every input and reference field
            value to turn it into a string.
    """

    skip_rendered_instance: bool = NonPositionalField(default=True)
    postprocessors: List[str] = NonPositionalField(
        default_factory=lambda: ["processors.to_string_stripped"]
    )
    instruction: str = NonPositionalField(default="")
    target_prefix: str = NonPositionalField(default="")
    title_fields: List[str] = NonPositionalField(default_factory=list)
    serializer: Serializer = NonPositionalField(
        default_factory=lambda: MultiTypeSerializer(
            serializers=[
                ImageSerializer(),
                VideoSerializer(),
                TableSerializer(),
                ToolCallSerializer(),
                ToolsSerializer(),
                DialogSerializer(),
                ListSerializer(),
                SQLDatabaseAsSchemaSerializer(),
            ]
        )
    )

    def verify(self):
        super().verify()
        assert isoftype(
            self.postprocessors, List[Union[Operator, str]]
        ), f"The template post processors field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."

    def input_fields_to_instruction_and_target_prefix(self, input_fields, instruction):
        """Render the instruction and ``self.target_prefix`` with the input fields.

        Returns:
            A ``(instruction, target_prefix)`` tuple of rendered strings.
        """
        instruction = self.apply_formatting(
            input_fields, "input field", instruction, "instruction"
        )
        target_prefix = self.apply_formatting(
            input_fields,
            "input field",
            self.target_prefix,
            "target_prefix",
        )
        return instruction, target_prefix

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Hook run on non-inference streams before serialization; identity by default."""
        return input_fields, reference_fields

    def preprocess_input_fields(self, input_fields: Dict[str, Any]):
        """Hook run on the input fields of every stream; identity by default."""
        return input_fields

    def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
        """Hook run on the reference fields before serialization; identity by default."""
        return reference_fields

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Verbalize one instance into source, instruction, target prefix, target and references.

        Raises:
            ValueError: if the instance belongs to a non-inference stream and has
                no "reference_fields".
        """
        if self.skip_rendered_instance:
            if (
                "source" in instance
                and "target" in instance
                and "references" in instance
            ):
                return instance

        input_fields = instance.get("input_fields")
        reference_fields = instance.get("reference_fields")
        is_inference = stream_name == constants.inference_stream

        if not is_inference:
            # Validate before the preprocessing hook: subclass hooks index into
            # reference_fields, and would otherwise fail with an opaque
            # TypeError instead of this explicit error.
            if reference_fields is None:
                raise ValueError("Should have reference_fields")
            input_fields, reference_fields = self.preprocess_input_and_reference_fields(
                input_fields, reference_fields
            )

        input_fields = self.preprocess_input_fields(input_fields)

        self.set_titles(input_fields)

        serialized_inputs = self.serialize(input_fields, instance)

        source = self.input_fields_to_source(serialized_inputs)
        instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
            serialized_inputs, instance.get(constants.instruction_field, self.instruction)
        )

        result = {
            **instance,
            "source": source,
            constants.instruction_field: instruction,
            "target_prefix": target_prefix,
            "postprocessors": self.postprocessors,
        }

        # Inference instances carry no targets; stop after rendering the inputs.
        if is_inference:
            return self.post_process_instance(result)

        reference_fields = self.preprocess_reference_fields(reference_fields)

        serialized_references = self.serialize(
            reference_fields, instance
        )  # Dict[str, str]

        target, references = self.reference_fields_to_target_and_references(
            serialized_references
        )

        result["target"] = target
        result["references"] = references

        return self.post_process_instance(result)

    def post_process_instance(self, instance):
        """Hook applied to the rendered instance before it is returned; identity by default."""
        return instance

    def serialize(
        self, data: Dict[str, Any], instance: Dict[str, Any]
    ) -> Dict[str, str]:
        """Serialize every value of ``data`` with ``self.serializer``."""
        return {k: self.serializer.serialize(v, instance) for k, v in data.items()}

    @abstractmethod
    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        """Build the "source" string from the serialized input fields."""

    def set_titles(self, data):
        """Title-case, in place, the values of ``data`` listed in ``title_fields``."""
        for field in self.title_fields:
            data[field] = data[field].title()

    @abstractmethod
    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Build the ``(target, references)`` pair from the serialized reference fields."""

    def apply_formatting(
        self, data: Dict[str, Any], data_type: str, format_str: str, format_name: str
    ) -> str:
        """Render ``format_str`` with ``str.format(**data)``.

        Raises:
            UnitxtError: if ``format_str`` is None (required attribute not set).
            TemplateFormatKeyError: if ``format_str`` references a key missing
                from ``data``.
        """
        try:
            if format_str is None:
                raise UnitxtError(
                    f"Required field '{format_name}' of class {self.__class__.__name__} not set in {self.__class__.__name__}",
                    Documentation.ADDING_TEMPLATE,
                )
            return format_str.format(**data)
        except KeyError as e:
            raise TemplateFormatKeyError(
                self, data, data_type, format_str, format_name
            ) from e
201

202

203
class ApplyTemplate(InstanceOperator):
    """Render an instance, and any demos attached to it, with a template.

    Subclasses choose the template through :meth:`get_template`.

    Args:
        demos_field: optional name of an instance field holding a list of demo
            instances. When set, each demo is rendered with the same template
            before the instance itself.
    """

    demos_field: Optional[str] = None

    @abstractmethod
    def get_template(self, instance: Dict[str, Any]) -> Template:
        pass

    def apply(
        self,
        template: Template,
        instance: Dict[str, Any],
        stream_name: Optional[str] = None,
    ):
        """Run ``template`` on a single instance."""
        return template.process_instance(instance, stream_name)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        chosen_template = self.get_template(instance)

        demos_key = self.demos_field
        if demos_key is not None:
            if demos_key not in instance:
                raise ValueError("Demos field is missing.")
            # Demos are rendered with the default stream_name (None).
            rendered_demos = []
            for demo in instance[demos_key]:
                rendered_demos.append(self.apply(chosen_template, demo))
            instance[demos_key] = rendered_demos

        # Record which template produced this instance.
        dict_set(instance, "recipe_metadata/template", chosen_template)
        return self.apply(chosen_template, instance, stream_name)
232

233

234
class ApplySingleTemplate(ApplyTemplate):
    """Apply the single configured ``template`` to every instance."""

    template: Template

    def get_template(self, instance: Dict[str, Any]) -> Template:
        # The instance plays no role in the choice.
        chosen = self.template
        return chosen
239

240

241
class ApplyRandomTemplate(ApplyTemplate):
    """Apply a template drawn from ``templates`` for each instance.

    The random generator is seeded from the instance's input and reference
    fields (via ``new_random_generator``), so the draw depends on the
    instance's content.
    """

    templates: List[Template]

    def get_template(self, instance: Dict[str, Any]) -> Template:
        # "reference_fields" may be missing or None (Template.process tolerates
        # that for inference streams); seed from the input fields alone in that
        # case instead of failing with KeyError / TypeError. When the field is
        # present, the seed mapping is exactly what it was before.
        seed_fields = {
            **instance["input_fields"],
            **(instance.get("reference_fields") or {}),
        }
        random_generator = new_random_generator(seed_fields)
        return random_generator.choice(self.templates)
249

250

251
class InputFormatTemplate(Template):
    """Template whose "source" is ``input_format`` rendered with the input fields."""

    input_format: str

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        fmt = self.input_format
        return self.apply_formatting(input_fields, "input field", fmt, "input_format")
261

262

263
class OutputFormatTemplate(Template):
    """Template whose target is ``output_format`` rendered with the reference fields.

    The rendered target is also the single reference. If ``output_format`` is
    left as None, ``apply_formatting`` raises a "Required field" UnitxtError.
    """

    output_format: str = None

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, [target])`` where target is the rendered output format."""
        target = self.apply_formatting(
            reference_fields,
            "reference field",
            self.output_format,
            "output_format",
        )
        return target, [target]
277

278

279
class JsonOutputFormatTemplate(Template):
    """Template whose target is a JSON object built from selected reference fields.

    Args:
        output_fields: maps a reference field name to the key it is stored under
            in the JSON target.
        wrap_with_list_fields: reference field names whose value is wrapped in a
            single-element list before being stored.
    """

    output_fields: Dict[str, str]
    wrap_with_list_fields: List[str]

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, [target])`` where target is the JSON-encoded object.

        Raises:
            KeyError: if a field listed in ``output_fields`` is missing from
                ``reference_fields``.
        """
        data = {}
        for field, target_field in self.output_fields.items():
            value = reference_fields[field]
            data[target_field] = [value] if field in self.wrap_with_list_fields else value
        # ensure_ascii=False keeps non-ASCII text readable in the target string.
        target = json.dumps(data, ensure_ascii=False)
        return target, [target]
295

296

297
class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate):
    """Render 'source' from the input fields and 'target'/'references' from the reference fields.

    ``input_format`` glues the input fields of the processed instance into the
    'source' string; ``output_format`` glues its reference fields into the
    'target' string, which also serves as the single entry of 'references'.
    """
304

305

306
class JsonOutputTemplate(InputFormatTemplate, JsonOutputFormatTemplate):
    """Render 'source' with ``input_format`` and 'target'/'references' as JSON.

    The source side behaves like :class:`InputFormatTemplate`. The target is the
    JSON object assembled from the reference fields listed in ``output_fields``,
    and it is also the single entry of 'references'.
    """
313

314

315
class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
    """InputOutputTemplate whose single reference has its own format string.

    Args:
        reference: format string, rendered with the reference fields, that
            produces the sole reference. The target is still rendered from
            ``output_format``.
    """

    reference: str

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, [reference])``, each rendered from its own format string."""
        target = self.apply_formatting(
            reference_fields,
            "reference field",
            self.output_format,
            "output_format",
        )
        reference = self.apply_formatting(
            reference_fields,
            "reference field",
            self.reference,
            "reference",
        )
        return target, [reference]
334

335

336
class PairwiseChoiceTemplate(InputOutputTemplate):
    """PairwiseChoiceTemplate.

    Requirements:
     The answer field value should be of type Literal["choice_a", "choice_b", "tie"]

    Args:
         choice_a_field (str):
            The field which contains choice_a value
         choice_b_field (str):
            The field which contains choice_b value
         answer_field (str):
            The field which contains the answer value.
            Should be of type Literal["choice_a", "choice_b", "tie"]
         choice_a_label (str):
            The label of choice A answer as it is verbalized in the template.
         choice_b_label (str):
            The label of choice B answer as it is verbalized in the template.
         choice_tie_label (str):
            The label of a tie answer as it should be verbalized in the template.
         shuffle (bool):
            whether to shuffle the choices or not. This is done to take into account position bias.

    Preprocessing (input and reference fields are modified in place):
     1. The raw answer ("choice_a" / "choice_b" / "tie") is replaced by
        choice_a_label / choice_b_label / choice_tie_label.
     2. If shuffle is set, 50% of the time:
        a. The values of choice_a_field and choice_b_field are swapped.
        b. An answer of choice_a_label becomes choice_b_label and vice versa;
           choice_tie_label is left unchanged.

    Note:
     The shuffle decision is drawn from the module-level ``random.random()``,
     not from a generator seeded per instance.
    """

    choice_a_field: str
    choice_b_field: str
    answer_field: str
    choice_a_label: str
    choice_b_label: str
    choice_tie_label: str
    shuffle: bool

    def verify(self):
        super().verify()

    def verbalize_answer_field(self, reference_fields: Dict[str, object]):
        """Replace the raw answer in ``reference_fields`` (in place) with its label."""
        answer = reference_fields[self.answer_field]
        assert answer in ["choice_a", "choice_b", "tie"], (
            f"Unexpected value '{answer}' in answer field '{self.answer_field}'; "
            f"expected one of ['choice_a', 'choice_b', 'tie']."
        )
        if answer == "choice_a":
            reference_fields[self.answer_field] = self.choice_a_label
        elif answer == "choice_b":
            reference_fields[self.answer_field] = self.choice_b_label
        else:
            reference_fields[self.answer_field] = self.choice_tie_label

        return reference_fields

    def shuffle_values(
        self, input_fields: Dict[str, object], reference_fields: Dict[str, object]
    ):
        """With probability ~0.5, swap the two choices and flip the answer label accordingly.

        Expects the answer to be already verbalized (see ``verbalize_answer_field``).
        Both dicts are modified in place and returned.
        """
        if not self.shuffle:
            return input_fields, reference_fields
        outcome = random()  # A float in [0, 1)
        if outcome <= 0.5:
            choice_a_value = input_fields[self.choice_a_field]
            choice_b_value = input_fields[self.choice_b_field]

            input_fields[self.choice_a_field] = choice_b_value
            input_fields[self.choice_b_field] = choice_a_value

            answer = reference_fields[self.answer_field]
            assert answer in [
                self.choice_a_label,
                self.choice_b_label,
                self.choice_tie_label,
            ], (
                f"Unexpected value '{answer}' in answer field '{self.answer_field}'; "
                f"expected one of the labels {[self.choice_a_label, self.choice_b_label, self.choice_tie_label]}."
            )
            if answer == self.choice_a_label:
                reference_fields[self.answer_field] = self.choice_b_label
            elif answer == self.choice_b_label:
                reference_fields[self.answer_field] = self.choice_a_label

        return input_fields, reference_fields

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Verbalize first: shuffle_values compares against the labels.
        reference_fields = self.verbalize_answer_field(reference_fields)
        input_fields, reference_fields = self.shuffle_values(
            input_fields, reference_fields
        )
        return input_fields, reference_fields
424

425

426
class DialogFieldsData(Artifact):
    """Describes how to render one dialog-valued input field as text.

    Attributes:
        user_role_label: text written before every "user" turn.
        assistant_role_label: text written before every "assistant" turn.
        system_role_label: text written before every "system" turn.
        dialog_field: name of the input field holding the dialog.
    """

    user_role_label: str
    assistant_role_label: str
    system_role_label: str
    dialog_field: str
431

432

433
class DialogTemplate(InputOutputTemplate):
    """InputOutputTemplate that first flattens dialog input fields into text.

    Args:
        dialog_fields: one DialogFieldsData per dialog input field to render.
        turns_separator: string placed between consecutive rendered turns.
        label_separator: string placed between a role label and the turn text.
    """

    dialog_fields: List[DialogFieldsData]
    turns_separator: str = "\n\n"
    label_separator: str = " "

    def process_dialog(self, input_fields: Dict[str, object]):
        """Replace, in place, each configured dialog field with its rendered text.

        Each turn ``(role, text)`` becomes ``f"{label}{label_separator}{text}"``
        and the rendered turns are joined with ``turns_separator``. Turns whose
        role is not "user", "assistant" or "system" are skipped; they no longer
        leave a stray leading separator when they open the dialog.
        """
        for dialog_fields in self.dialog_fields:
            dialog = input_fields[dialog_fields.dialog_field]
            # TODO: update isoftype method to support Literal verification and check
            #  it's List[Tuple[Literal["user", "assistant", "system"], str]] (Issue #799)
            assert isoftype(dialog, List[Tuple[str, str]])

            role_labels = {
                "user": dialog_fields.user_role_label,
                "assistant": dialog_fields.assistant_role_label,
                "system": dialog_fields.system_role_label,
            }
            rendered_turns = [
                f"{role_labels[turn_type]}{self.label_separator}{turn_text}"
                for turn_type, turn_text in dialog
                if turn_type in role_labels
            ]
            # join() instead of repeated += avoids quadratic string building.
            input_fields[dialog_fields.dialog_field] = self.turns_separator.join(
                rendered_turns
            )
        return input_fields

    def preprocess_input_fields(self, input_fields: Dict[str, Any]):
        return self.process_dialog(input_fields)
465

466

467
class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
    """Pairwise-choice template whose inputs may include dialogs rendered as text."""
469

470

471
class PairwiseComparativeRatingTemplate(InputOutputTemplate):
    """PairwiseComparativeRatingTemplate.

    Args:
         choice_a_field (str): The field which contains choice_a value

         choice_b_field (str): The field which contains choice_b value

         choice_a_id_field (str): A field paired with choice_a_field (presumably an
             id of choice_a); it is swapped together with choice_a_field.

         choice_b_id_field (str): A field paired with choice_b_field (presumably an
             id of choice_b); it is swapped together with choice_b_field.

         answer_field (str): The field which contains the answer value. The value should be an int.
             Positive for preferring choice_a, and negative for preferring choice_b

         shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.

    shuffle: 50% of the time:
    | 1) The values of choice_a_field and choice_b_field, and of choice_a_id_field
    |    and choice_b_id_field, are swapped.
    | 2) The value of answer_field is negated, so the preference still points at
    |    the same (now moved) choice.

    The shuffle decision is drawn from the module-level ``random.random()``.
    """

    choice_a_field: str
    choice_b_field: str
    choice_a_id_field: str
    choice_b_id_field: str
    answer_field: str
    shuffle: bool

    def shuffle_values(
        self, input_fields: Dict[str, object], reference_fields: Dict[str, object]
    ):
        """With probability ~0.5, swap the choices (and their ids) and negate the answer.

        Both dicts are modified in place and returned.
        """
        if not self.shuffle:
            return input_fields, reference_fields
        outcome = random()  # A float in [0, 1)
        if outcome <= 0.5:
            choice_a_value = input_fields[self.choice_a_field]
            choice_b_value = input_fields[self.choice_b_field]
            input_fields[self.choice_a_field] = choice_b_value
            input_fields[self.choice_b_field] = choice_a_value

            choice_a_id_value = input_fields[self.choice_a_id_field]
            choice_b_id_value = input_fields[self.choice_b_id_field]
            input_fields[self.choice_a_id_field] = choice_b_id_value
            input_fields[self.choice_b_id_field] = choice_a_id_value

            answer = reference_fields[self.answer_field]
            assert isinstance(answer, int), (
                f"Answer field '{self.answer_field}' must hold an int, got {type(answer).__name__}."
            )
            # Flip the sign: the preferred choice has moved to the other slot.
            reference_fields[self.answer_field] = -int(answer)

        return input_fields, reference_fields

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        input_fields, reference_fields = self.shuffle_values(
            input_fields, reference_fields
        )
        return input_fields, reference_fields
528

529

530
class MultipleChoiceTemplate(InputFormatTemplate):
    """Formats the input that specifies a multiple-choice question, with a list of possible answers to choose from, and identifies the correct answer.

    Args:
        target_prefix (str): Optional prefix that can be added before the target label in
            generated prompts or outputs.
        choices_field (str): The key under which the multiple choices are stored in the
            input and reference dictionaries.
        target_field (str): The key under which the correct choice is stored in the
            reference dictionary (can be integer index or textual label).
        choices_separator (str): A string used to join formatted
            choices (e.g. ", ").
        source_choice_format (str): A Python format string used for displaying each choice
            in the input fields (e.g. "{choice_numeral}. {choice_text}").
        target_choice_format (str): A Python format string used for displaying each choice
            in the target or final output (e.g. "{choice_numeral}").
        enumerator (str): Determines how choice numerals are enumerated. Possible values
            include "capitals" (A-Z), "lowercase" (a-z), "numbers" (1-20), or
            "roman" (I-XX). Any other value is indexed directly, so an explicit
            string or list of numerals can be given.
        shuffle_choices (bool): If True, shuffle the choices. The shuffling seed can be
            set with `shuffle_choices_seed`.
        shuffle_choices_seed (int, optional): If provided, the choices are shuffled with
            this fixed integer seed for reproducibility.
        sort_choices_by_length (bool): If True, sorts choices
            by their length (ascending).
        sort_choices_alphabetically (bool): If True, sorts choices
            in alphabetical order.
        reverse_choices (bool): If True, reverses the order of the choices after any
            sorting has been applied. Defaults to False to preserve backward compatibility.
        place_correct_choice_position (int, optional): If provided, the correct choice is
            moved to this position. Negative values count from the end, as in Python
            lists, and are resolved per instance against that instance's choice count.
    """

    target_prefix: str = ""
    choices_field: str = "choices"
    target_field: str = "label"
    choices_separator: str = ", "
    source_choice_format: str = "{choice_numeral}. {choice_text}"
    target_choice_format: str = "{choice_numeral}"
    enumerator: str = "capitals"

    shuffle_choices: bool = False
    shuffle_choices_seed: int = None
    sort_choices_by_length: bool = False
    sort_choices_alphabetically: bool = False
    reverse_choices: bool = False  # False by default for backward-compat
    place_correct_choice_position: int = None

    def prepare(self):
        super().prepare()
        # Named enumerators are expanded to concrete numeral sequences.
        if self.enumerator == "capitals":
            self.enumerator = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        if self.enumerator == "lowercase":
            self.enumerator = "abcdefghijklmnopqrstuvwxyz"
        if self.enumerator == "numbers":
            self.enumerator = [str(i + 1) for i in range(20)]
        if self.enumerator == "roman":
            self.enumerator = [
                "I",
                "II",
                "III",
                "IV",
                "V",
                "VI",
                "VII",
                "VIII",
                "IX",
                "X",
                "XI",
                "XII",
                "XIII",
                "XIV",
                "XV",
                "XVI",
                "XVII",
                "XVIII",
                "XIX",
                "XX",
            ]

    def verify(self):
        super().verify()
        if self.shuffle_choices and (
            self.sort_choices_by_length
            or self.sort_choices_alphabetically
            or self.reverse_choices
            or self.place_correct_choice_position is not None
        ):
            raise UnitxtError(
                "You cannot combine shuffle_choices with sorting or reversing flags."
            )

        if self.sort_choices_by_length and self.sort_choices_alphabetically:
            raise UnitxtError(
                "You cannot combine both sort_choices_by_length and sort_choices_alphabetically simultaneously."
            )
        if self.place_correct_choice_position is not None and (
            self.sort_choices_by_length
            or self.sort_choices_alphabetically
            or self.reverse_choices
        ):
            raise UnitxtError(
                "You cannot combine place_correct_choice_position with sorting or reversing flags."
            )

    def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> List[str]:
        """Format every choice in ``data[self.choices_field]`` with its numeral.

        Raises:
            IndexError: if there are more choices than numerals in ``enumerator``.
        """
        return [
            choice_format.format(
                choice_text=choice,
                choice_numeral=self.enumerator[i],
            )
            for i, choice in enumerate(data[self.choices_field])
        ]

    def inputs_to_numerals(self, input_fields: Dict[str, Any]) -> List[str]:
        """Return the bare numerals of the choices (e.g. ["A", "B", "C"])."""
        return self.inputs_to_choices(input_fields, "{choice_numeral}")

    def prepare_multiple_choice_inputs(
        self, input_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Add a "numerals" field and replace the choices by one formatted string."""
        choices = self.inputs_to_choices(input_fields, self.source_choice_format)
        return {
            "numerals": self.inputs_to_numerals(input_fields),
            **input_fields,
            self.choices_field: self.choices_separator.join(choices),
        }

    def preprocess_input_fields(self, input_fields: Dict[str, Any]) -> Dict[str, Any]:
        return self.prepare_multiple_choice_inputs(input_fields)

    def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> int:
        """Return the target as an index, looking textual targets up in the choices.

        Raises:
            UnitxtError: if a textual target is not one of the choices.
        """
        target = reference_fields[self.target_field]

        if not isinstance(target, int):
            try:
                return reference_fields[self.choices_field].index(target)
            except ValueError as e:
                raise UnitxtError(
                    f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {reference_fields[self.choices_field]}",
                    Documentation.ADDING_TEMPLATE,
                ) from e
        return target

    def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
        """Reduce the reference fields to the formatted correct choice.

        Raises:
            UnitxtError: if the target cannot be located among the choices.
        """
        target_index = self.outputs_to_target_index(reference_fields)

        choices = self.inputs_to_choices(reference_fields, self.target_choice_format)

        try:
            target = choices[target_index]
        except IndexError as e:
            raise UnitxtError(
                f"MultipleChoiceTemplate cannot find index number {target_index} in choices: {choices}",
                Documentation.ADDING_TEMPLATE,
            ) from e

        return {self.target_field: target}

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        target = reference_fields[self.target_field]
        return target, [target]

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Reorder the choices (sort / reverse / shuffle / fixed position) and re-point the target.

        The choices list is reordered in place, stored back in both dicts, and the
        target is replaced by the new index of the originally correct choice.

        Raises:
            ValueError: if ``place_correct_choice_position`` is out of range for
                this instance's number of choices.
        """
        if (
            not self.shuffle_choices
            and not self.sort_choices_by_length
            and not self.sort_choices_alphabetically
            and not self.reverse_choices
            and self.place_correct_choice_position is None
        ):
            return input_fields, reference_fields

        choices = input_fields[self.choices_field]
        target_index = self.outputs_to_target_index(reference_fields)
        # Remember the correct choice by value; its index changes below.
        original_label_choice = reference_fields[self.choices_field][target_index]

        if self.sort_choices_by_length:
            choices.sort(key=len)
        if self.sort_choices_alphabetically:
            choices.sort()
        if self.reverse_choices:
            choices.reverse()
        if self.shuffle_choices:
            random_generator = new_random_generator(
                self.shuffle_choices_seed
                if self.shuffle_choices_seed is not None
                else {**input_fields}
            )
            random_generator.shuffle(choices)
        if self.place_correct_choice_position is not None:
            # Supporting negative indexes similar to Python lists: -1 is the last
            # position, -2 the one before last, etc. The resolved position is kept
            # local: writing it back to self would make later instances (which
            # may have a different number of choices) reuse this instance's value.
            position = self.place_correct_choice_position
            if position < 0:
                position += len(choices)
            if not 0 <= position < len(choices):
                raise ValueError(
                    f"place_correct_choice_position={self.place_correct_choice_position} out of range (0..{len(choices) - 1})."
                )
            # Move the original label choice to the requested position.
            choices.remove(original_label_choice)
            choices.insert(position, original_label_choice)

        # Update both input_fields and reference_fields once at the end
        input_fields[self.choices_field] = choices
        reference_fields[self.choices_field] = choices
        reference_fields[self.target_field] = choices.index(original_label_choice)

        return input_fields, reference_fields

    def post_process_instance(self, instance):
        """Expose the target-formatted choices as ``input_fields["options"]``."""
        instance["input_fields"]["options"] = self.inputs_to_choices(
            instance["input_fields"], self.target_choice_format
        )
        return instance
761

762

763
class YesNoTemplate(InputFormatTemplate):
    """A template for generating binary Yes/No questions asking whether an input text is of a specific class.

    Args:
        input_format:
            Defines the format of the question.
        class_field:
            Defines the reference field that contains the name of the class that
            this template asks about.
        label_field:
            Defines the reference field which contains the gold labels of the input
            text. It must be a list. If the value of ``class_field`` is one of these
            gold labels, the correct output is self.yes_answer (by default, "Yes").
            Otherwise the correct output is self.no_answer (by default, "No").
        yes_answer:
            The output value when the queried class is among the gold labels.
            Defaults to "Yes".
        no_answer:
            The output value when the queried class is not among the gold labels.
            Defaults to "No".
    """

    input_format: str = None
    class_field: str = None
    label_field: str = None
    yes_answer: str = "Yes"
    no_answer: str = "No"

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, [target])``, where target is the yes or the no answer.

        Raises:
            UnitxtError: if the label or class field is missing, if the gold labels
                are not a list, or if the queried class is not a non-empty string.
        """
        try:
            gold_class_names = reference_fields[self.label_field]
        except KeyError as e:
            raise UnitxtError(
                f"Available reference_fields are {list(reference_fields.keys())}, missing required label field: '{self.label_field}'."
            ) from e
        if not isinstance(gold_class_names, list):
            raise UnitxtError(
                f"Unexpected value for gold_class_names: '{gold_class_names}'. Expecting a list."
            )
        try:
            queried_class_name = reference_fields[self.class_field]
        except KeyError as e:
            raise UnitxtError(
                f"Available reference_fields are {list(reference_fields.keys())}, missing required class field: '{self.class_field}'."
            ) from e
        if not queried_class_name or not isinstance(queried_class_name, str):
            raise UnitxtError(
                f"Unexpected value for queried_class_names: '{queried_class_name}'. Expected a string."
            )
        # Membership test: an instance may carry several gold labels.
        if queried_class_name in gold_class_names:
            return self.yes_answer, [self.yes_answer]
        return self.no_answer, [self.no_answer]
816

817

818
class NullTemplate(Template):
    """Template that renders an empty prompt and yields no references.

    The input fields are ignored, the target is the empty string, the
    reference list is empty, and the postprocessor list is empty.
    """

    postprocessors = []

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        # The prompt is always empty, whatever the input fields contain.
        empty_source = ""
        return empty_source

    def reference_fields_to_target_and_references(self, reference_fields):
        # Empty target text and no references at all.
        no_references = []
        return "", no_references
828

829

830
class KeyValTemplate(Template):
    """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance.

    Args specify with what separators to glue together the input and output designated fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references').

    Args:
        pairs_separator: Joins consecutive key/value pairs (inputs and outputs).
        key_val_separator: Joins a key and its value in the input fields.
        use_keys_for_inputs: Whether input pairs include the key or only the value.
        outputs_key_val_separator: Joins a key and its value in the output fields.
        use_keys_for_outputs: Whether output pairs include the key or only the value.
    """

    pairs_separator: str = ", "
    key_val_separator: str = ": "
    use_keys_for_inputs: bool = True
    outputs_key_val_separator: str = ": "
    use_keys_for_outputs: bool = False

    def process_dict(
        self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
    ) -> str:
        """Render ``data`` as ``pairs_sep``-joined entries.

        Each entry is ``key<key_val_sep>value`` when ``use_keys`` is true,
        otherwise just ``value``. Keys and values are converted with ``str``.
        """
        pairs = []
        for key, val in data.items():
            # str(key) so non-string keys do not make join() raise TypeError.
            key_val = [str(key), str(val)] if use_keys else [str(val)]
            pairs.append(key_val_sep.join(key_val))
        return pairs_sep.join(pairs)

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        return self.process_dict(
            input_fields,
            key_val_sep=self.key_val_separator,
            pairs_sep=self.pairs_separator,
            use_keys=self.use_keys_for_inputs,
        )

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return the rendered output fields as the target and as the single reference."""
        target = self.process_dict(
            reference_fields,
            # Outputs have their own separator; previously this field was ignored.
            key_val_sep=self.outputs_key_val_separator,
            pairs_sep=self.pairs_separator,
            use_keys=self.use_keys_for_outputs,
        )
        return target, [target]
869

870

871
class OutputQuantizingTemplate(InputOutputTemplate):
    """Input/output template that registers a ``NumberQuantizingSerializer`` on its serializer.

    Args:
        quantum: Value passed as ``quantum`` to the ``NumberQuantizingSerializer``.
    """

    serializer: MultiTypeSerializer = NonPositionalField(
        default_factory=MultiTypeSerializer
    )
    quantum: Union[float, int] = 0.1

    def prepare(self):
        super().prepare()
        # Hand the quantizing serializer to the MultiTypeSerializer after the
        # base preparation has run.
        quantizing_serializer = NumberQuantizingSerializer(quantum=self.quantum)
        extra_serializers = [quantizing_serializer]
        self.serializer.add_serializers(extra_serializers)
882

883

884
class MultiLabelTemplate(InputOutputTemplate):
    """Input/output template whose target is a list of labels joined into one string.

    Args:
        labels_field: Reference field holding the list of labels.
        labels_separator: String used to join the labels.
        empty_label: Placeholder used as the only label when the list is empty.
    """

    labels_field: str = "labels"
    labels_separator: str = ", "
    postprocessors = ["processors.to_list_by_comma"]
    output_format: str = "{labels}"
    empty_label: str = "None"

    def preprocess_reference_fields(
        self, reference_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Replace the reference fields with ``{labels_field: joined_labels}``.

        Raises:
            UnitxtError: if the labels field does not hold a list.
        """
        labels = reference_fields[self.labels_field]
        if not isinstance(labels, list):
            raise UnitxtError(
                f"MultiLabelTemplate requires labels field '{self.labels_field}' to be a list. Got {self.labels_field}<{type(labels).__name__}>: {labels}",
                Documentation.ADDING_TEMPLATE,
            )
        if len(labels) == 0:
            labels = [self.empty_label]
        # str() each label so non-string labels (e.g. ints) can be joined.
        labels_str = self.labels_separator.join(str(label) for label in labels)
        return {self.labels_field: labels_str}
904

905

906
class MultiReferenceTemplate(InputOutputTemplate):
    """Input/output template whose reference field holds several acceptable references.

    Args:
        references_field: Reference field holding a ``List[str]`` of references.
        random_reference: If true, the target is a randomly chosen reference
            (seeded from the reference fields); otherwise it is the first one.
    """

    references_field: str = "references"
    random_reference: bool = False
    serializer: Serializer = NonPositionalField(default_factory=MultiTypeSerializer)

    def serialize(
        self, data: Dict[str, Any], instance: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Serialize every field; a list/tuple references field is serialized item by item."""
        result = {}
        for k, v in data.items():
            # Only iterate real sequences. Iterating a plain string would split
            # it into characters and silently pass the List[str] check later.
            if k == self.references_field and isinstance(v, (list, tuple)):
                v = [self.serializer.serialize(item, instance) for item in v]
            else:
                v = self.serializer.serialize(v, instance)
            result[k] = v
        return result

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, references)``; ``("", [])`` when there are no references.

        Raises:
            UnitxtError: if the references field is not a ``List[str]``.
        """
        references = reference_fields[self.references_field]
        if not isoftype(references, List[str]):
            raise UnitxtError(
                f"MultiReferenceTemplate requires references field '{self.references_field}' to be List[str]. Got {self.references_field}<{type(references).__name__}>: {references}",
                Documentation.ADDING_TEMPLATE,
            )
        if len(references) == 0:
            return "", []

        if self.random_reference:
            random_generator = new_random_generator(reference_fields)
            target = random_generator.choice(references)
        else:
            target = references[0]

        return target, references
942

943

944
def escape_chars(s, chars_to_escape):
    """Return ``s`` with a backslash inserted before each occurrence of ``chars_to_escape``.

    Escaping is done in a single pass. The backslashes this function inserts
    are therefore never escaped again, even when ``"\\"`` itself is one of the
    characters to escape. Empty entries in ``chars_to_escape`` are ignored.

    Args:
        s: Text to escape.
        chars_to_escape: Iterable of characters (or substrings) to escape.

    Returns:
        The escaped string.
    """
    import re

    targets = [c for c in chars_to_escape if c]
    if not targets:
        return s
    # Longest first, so multi-character entries win over their prefixes.
    pattern = "|".join(
        re.escape(c) for c in sorted(targets, key=len, reverse=True)
    )
    return re.sub(pattern, lambda m: "\\" + m.group(0), s)
948

949

950
class SpanLabelingBaseTemplate(MultiLabelTemplate):
    """Base for templates that render labeled text spans as multi-label targets.

    Args:
        spans_starts_field: Reference field with the start offset of each span.
        spans_ends_field: Reference field with the end offset of each span.
        text_field: Reference field with the text the offsets index into.
        labels_support: If set, only spans whose label is in this list are kept.
    """

    spans_starts_field: str = "spans_starts"
    spans_ends_field: str = "spans_ends"
    text_field: str = "text"
    labels_support: list = None

    def extract_span_label_pairs(self, reference_fields):
        """Yield ``(span_text, label)`` pairs, ordered by ``(start, end, text, label)``."""
        spans_starts = reference_fields[self.spans_starts_field]
        spans_ends = reference_fields[self.spans_ends_field]
        text = reference_fields[self.text_field]
        labels = reference_fields[self.labels_field]

        # Filter by labels_support once, while collecting the spans.
        spans = [
            (span_start, span_end, text[span_start:span_end], label)
            for span_start, span_end, label in zip(spans_starts, spans_ends, labels)
            if self.labels_support is None or label in self.labels_support
        ]

        for _, _, span_text, label in sorted(spans):
            yield span_text, label

    def preprocess_reference_fields(
        self, reference_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        span_labels_pairs = self.extract_span_label_pairs(reference_fields)
        targets = self.span_label_pairs_to_targets(span_labels_pairs)
        # Key by self.labels_field: the parent reads reference_fields[self.labels_field],
        # so a hard-coded "labels" key would break a customized labels_field.
        return super().preprocess_reference_fields({self.labels_field: targets})

    @abstractmethod
    def span_label_pairs_to_targets(self, pairs):
        pass
981

982

983
class SpanLabelingTemplate(SpanLabelingBaseTemplate):
    """Span-labeling template that renders each span as ``span_label_format`` text.

    Args:
        span_label_format: Format string with ``{span}`` and ``{label}`` placeholders.
        escape_characters: Characters escaped (via ``escape_chars``) in the span
            text before formatting; ``None`` disables escaping.
    """

    span_label_format: str = "{span}: {label}"
    escape_characters: List[str] = [":", ","]
    postprocessors: List[str] = ["processors.to_span_label_pairs"]

    def span_label_pairs_to_targets(self, span_label_pairs):
        to_escape = self.escape_characters
        return [
            self.span_label_format.format(
                span=span if to_escape is None else escape_chars(span, to_escape),
                label=label,
            )
            for span, label in span_label_pairs
        ]
996

997

998
class SpanLabelingJsonTemplate(SpanLabelingBaseTemplate):
    """Span-labeling template whose single target is a JSON object of label -> spans.

    Labels appear in first-seen order. When there are no spans, no target is produced.
    """

    postprocessors = [
        "processors.load_json",
        "processors.dict_of_lists_to_value_key_pairs",
    ]

    def span_label_pairs_to_targets(self, span_label_pairs):
        spans_by_label = {}
        for span, label in span_label_pairs:
            spans_by_label.setdefault(label, []).append(span)
        if not spans_by_label:
            return []
        return [json.dumps(spans_by_label, ensure_ascii=False)]
1015

1016

1017
class TemplatesList(ListCollection):
    """A list collection whose items must all be ``Template`` instances."""

    def verify(self):
        """Check that every item is a ``Template``.

        Raises:
            AssertionError: if an item is not a ``Template``. The check uses an
                explicit raise rather than ``assert``, so it still runs under
                ``python -O``.
        """
        for template in self.items:
            if not isinstance(template, Template):
                raise AssertionError(
                    f"TemplatesList items must be Template instances, got {type(template).__name__}."
                )
1021

1022

1023
class TemplatesDict(DictCollection):
    """A dict collection whose values must all be ``Template`` instances."""

    def verify(self):
        """Check that every value is a ``Template``.

        Raises:
            AssertionError: if a value is not a ``Template``. The check uses an
                explicit raise rather than ``assert``, so it still runs under
                ``python -O``.
        """
        for key, template in self.items.items():
            if not isinstance(template, Template):
                raise AssertionError(
                    f"TemplatesDict value for key {key!r} must be a Template instance, got {type(template).__name__}."
                )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc