• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 13101960820

02 Feb 2025 07:54PM UTC coverage: 79.304%. Remained the same
13101960820

Pull #1567

github

web-flow
Merge cf7ff76ba into 7152be4f0
Pull Request #1567: fix the printout of empty strings in the yaml cards of the catalog

1451 of 1823 branches covered (79.59%)

Branch coverage included in aggregate %.

9163 of 11561 relevant lines covered (79.26%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.27
src/unitxt/templates.py
1
import json
1✔
2
from abc import abstractmethod
1✔
3
from random import random
1✔
4
from typing import Any, Dict, List, Optional, Tuple, Union
1✔
5

6
from .artifact import Artifact
1✔
7
from .collections import DictCollection, ListCollection
1✔
8
from .dataclass import NonPositionalField
1✔
9
from .dict_utils import dict_set
1✔
10
from .error_utils import Documentation, UnitxtError
1✔
11
from .operator import InstanceOperator, Operator
1✔
12
from .random_utils import new_random_generator
1✔
13
from .serializers import (
1✔
14
    DialogSerializer,
15
    ImageSerializer,
16
    ListSerializer,
17
    MultiTypeSerializer,
18
    NumberQuantizingSerializer,
19
    Serializer,
20
    SQLDatabaseAsSchemaSerializer,
21
    TableSerializer,
22
    VideoSerializer,
23
)
24
from .settings_utils import get_constants
1✔
25
from .type_utils import isoftype, to_type_string
1✔
26

27
# Shared project settings. `Template.process` compares the stream name against
# `constants.inference_stream` to decide whether target/references are rendered.
constants = get_constants()
28

29

30
class TemplateFormatKeyError(UnitxtError):
    """Raised when a template format string refers to a key that the data lacks.

    The message lists the keys that are available, names the template class and
    the offending format attribute, and echoes the format string itself.
    """

    def __init__(self, template, data, data_type, format_str, format_name):
        available_keys = ", ".join(data.keys())
        template_name = template.__class__.__name__
        message = (
            f"Available {data_type}s are [{available_keys}] "
            f"but {template_name}.{format_name} format requires a different ones: '{format_str}'"
        )
        super().__init__(message, Documentation.ADDING_TEMPLATE)
38

39

40
class Template(InstanceOperator):
    """The role of template is to take the fields of every instance and verbalize it.

    Meaning the template is taking the instance and generating source, target and references.

    Args:
        skip_rendered_instance (bool): if "source", "target", and "references" are already defined fields in the instance, skip its processing
        postprocessors: a list of strings being artifact names of text processors, to be applied on the model output
        instruction: a formatting string that yields an instruction with potential participation of values from the "input_fields" part of the instance
        target_prefix: a formatting string placed before the target. Like ``instruction``, it is
            formatted with the serialized values of the "input_fields" part of the instance.
        title_fields: names of input fields whose values are converted with ``.title()``
            before serialization.
        serializer: converts every input and reference field value into a string before
            the format strings are applied.
    """

    skip_rendered_instance: bool = NonPositionalField(default=True)
    postprocessors: List[str] = NonPositionalField(
        default_factory=lambda: ["processors.to_string_stripped"]
    )
    instruction: str = NonPositionalField(default="")
    target_prefix: str = NonPositionalField(default="")
    title_fields: List[str] = NonPositionalField(default_factory=list)
    serializer: Serializer = NonPositionalField(
        default_factory=lambda: MultiTypeSerializer(
            serializers=[
                ImageSerializer(),
                VideoSerializer(),
                TableSerializer(),
                DialogSerializer(),
                ListSerializer(),
                SQLDatabaseAsSchemaSerializer(),
            ]
        )
    )

    def verify(self):
        super().verify()
        assert isoftype(
            self.postprocessors, List[Union[Operator, str]]
        ), f"The template post processors field '{self.postprocessors}' is not a list of processors. Instead it is of type '{to_type_string(type(self.postprocessors))}'."

    def input_fields_to_instruction_and_target_prefix(self, input_fields):
        """Format ``instruction`` and ``target_prefix`` with ``input_fields``.

        Returns:
            An ``(instruction, target_prefix)`` tuple of formatted strings.
        """
        instruction = self.apply_formatting(
            input_fields, "input field", self.instruction, "instruction"
        )
        target_prefix = self.apply_formatting(
            input_fields,
            "input field",
            self.target_prefix,
            "target_prefix",
        )
        return instruction, target_prefix

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Hook that may rewrite input and reference fields jointly (identity by default).

        ``process`` calls it only when the stream is not the inference stream, and only
        after checking that ``reference_fields`` is present.
        """
        return input_fields, reference_fields

    def preprocess_input_fields(self, input_fields: Dict[str, Any]):
        """Hook applied to the input fields before serialization (identity by default)."""
        return input_fields

    def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
        """Hook applied to the reference fields before serialization (identity by default)."""
        return reference_fields

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Render ``instance`` into prompt parts and, outside inference, target/references.

        If ``skip_rendered_instance`` is set and the instance already has "source",
        "target" and "references", it is returned unchanged. Otherwise a new dict is
        built from ``instance`` plus "source", "instruction", "target_prefix" and
        "postprocessors". Unless ``stream_name`` is the inference stream, "target" and
        "references" are added too. The result goes through ``post_process_instance``.
        Note that the instance's input fields dict may be modified in place by the
        preprocessing hooks and by ``set_titles``.

        Raises:
            ValueError: if ``stream_name`` is not the inference stream and the instance
                has no "reference_fields".
        """
        if self.skip_rendered_instance:
            if (
                "source" in instance
                and "target" in instance
                and "references" in instance
            ):
                return instance

        input_fields = instance.get("input_fields")
        reference_fields = instance.get("reference_fields")
        is_inference = stream_name == constants.inference_stream

        if not is_inference:
            # Check before the preprocessing hook: subclass hooks index into
            # reference_fields and would otherwise fail with an unrelated error.
            if reference_fields is None:
                raise ValueError("Should have reference_fields")
            input_fields, reference_fields = self.preprocess_input_and_reference_fields(
                input_fields, reference_fields
            )

        input_fields = self.preprocess_input_fields(input_fields)

        self.set_titles(input_fields)

        serialized_inputs = self.serialize(input_fields, instance)

        source = self.input_fields_to_source(serialized_inputs)
        instruction, target_prefix = self.input_fields_to_instruction_and_target_prefix(
            serialized_inputs
        )

        result = {
            **instance,
            "source": source,
            "instruction": instruction,
            "target_prefix": target_prefix,
            "postprocessors": self.postprocessors,
        }

        if is_inference:
            return self.post_process_instance(result)

        reference_fields = self.preprocess_reference_fields(reference_fields)

        serialized_references = self.serialize(
            reference_fields, instance
        )  # Dict[str, str]

        target, references = self.reference_fields_to_target_and_references(
            serialized_references
        )

        result["target"] = target
        result["references"] = references

        return self.post_process_instance(result)

    def post_process_instance(self, instance):
        """Hook applied to the rendered instance just before it is returned (identity by default)."""
        return instance

    def serialize(
        self, data: Dict[str, Any], instance: Dict[str, Any]
    ) -> Dict[str, str]:
        """Serialize every value of ``data`` with ``self.serializer`` (keys are kept)."""
        return {k: self.serializer.serialize(v, instance) for k, v in data.items()}

    @abstractmethod
    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        """Build the "source" string from the serialized input fields."""
        pass

    def set_titles(self, data):
        """Apply ``.title()`` in place to every field of ``data`` listed in ``title_fields``."""
        for field in self.title_fields:
            data[field] = data[field].title()

    @abstractmethod
    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Build ``(target, references)`` from the serialized reference fields."""
        pass

    def apply_formatting(
        self, data: Dict[str, Any], data_type: str, format_str: str, format_name: str
    ) -> str:
        """Return ``format_str.format(**data)``.

        Args:
            data: values available to the format string.
            data_type: human-readable kind of the values (used in error messages).
            format_str: the ``str.format`` template to fill.
            format_name: name of the template attribute holding ``format_str``.

        Raises:
            UnitxtError: if ``format_str`` is None.
            TemplateFormatKeyError: if ``format_str`` references a key missing from ``data``.
        """
        try:
            if format_str is None:
                raise UnitxtError(
                    f"Required field '{format_name}' of class {self.__class__.__name__} not set in {self.__class__.__name__}",
                    Documentation.ADDING_TEMPLATE,
                )
            return format_str.format(**data)
        except KeyError as e:
            raise TemplateFormatKeyError(
                self, data, data_type, format_str, format_name
            ) from e
197

198

199
class ApplyTemplate(InstanceOperator):
    """Base operator that renders an instance, and optionally its demos, with a template.

    Subclasses choose the template per instance by implementing ``get_template``.

    Args:
        demos_field: name of the instance field holding demonstration instances. When
            set, each demo is rendered with the same template as the instance itself.
    """

    demos_field: Optional[str] = None

    @abstractmethod
    def get_template(self, instance: Dict[str, Any]) -> Template:
        """Return the template to render ``instance`` with."""
        pass

    def apply(
        self,
        template: Template,
        instance: Dict[str, Any],
        stream_name: Optional[str] = None,
    ):
        """Render a single instance with ``template``."""
        return template.process_instance(instance, stream_name)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        chosen_template = self.get_template(instance)

        demos_key = self.demos_field
        if demos_key is not None:
            if demos_key not in instance:
                raise ValueError("Demos field is missing.")
            # Demos are rendered with the default stream_name (None).
            rendered_demos = []
            for demo in instance[demos_key]:
                rendered_demos.append(self.apply(chosen_template, demo))
            instance[demos_key] = rendered_demos

        # Record which template rendered this instance.
        dict_set(instance, "recipe_metadata/template", chosen_template)
        return self.apply(chosen_template, instance, stream_name)
228

229

230
class ApplySingleTemplate(ApplyTemplate):
    """Render every instance with one fixed ``template``."""

    template: Template

    def get_template(self, instance: Dict[str, Any]) -> Template:
        # The instance does not influence the choice: one template serves all.
        return self.template
235

236

237
class ApplyRandomTemplate(ApplyTemplate):
    """Render each instance with a template picked at random from ``templates``.

    The pick comes from a random generator seeded with the instance's input and
    reference fields (reference fields win on key clashes).
    """

    templates: List[Template]

    def get_template(self, instance: Dict[str, Any]) -> Template:
        seed_fields = dict(instance["input_fields"])
        seed_fields.update(instance["reference_fields"])
        return new_random_generator(seed_fields).choice(self.templates)
245

246

247
class InputFormatTemplate(Template):
    """Template whose "source" is ``input_format`` filled with the input fields."""

    input_format: str

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        fmt, attr_name = self.input_format, "input_format"
        return self.apply_formatting(input_fields, "input field", fmt, attr_name)
257

258

259
class OutputFormatTemplate(Template):
    """Template whose "target" is ``output_format`` filled with the reference fields.

    The formatted string is used both as the target and as the single reference.
    """

    output_format: str = None

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Format ``output_format`` with ``reference_fields``.

        Returns:
            A ``(target, [target])`` tuple.

        Raises:
            UnitxtError: if ``output_format`` is not set.
            TemplateFormatKeyError: if ``output_format`` references a missing field.
        """
        target = self.apply_formatting(
            reference_fields,
            "reference field",
            self.output_format,
            "output_format",
        )
        return target, [target]
273

274

275
class JsonOutputFormatTemplate(Template):
    """Template whose target is a JSON object built from selected reference fields.

    Args:
        output_fields: maps a reference field name to the key it is stored under in
            the JSON target.
        wrap_with_list_fields: reference field names whose value is wrapped in a
            single-element list before being dumped.
    """

    output_fields: Dict[str, str]
    wrap_with_list_fields: List[str]

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Dump the selected reference fields as JSON.

        When called from ``Template.process`` the values have already been serialized
        to strings. Non-ASCII characters are kept as-is (``ensure_ascii=False``).

        Returns:
            A ``(target, [target])`` tuple where ``target`` is the JSON string.

        Raises:
            KeyError: if a key of ``output_fields`` is missing from ``reference_fields``.
        """
        data = {}
        for field, target_field in self.output_fields.items():
            value = reference_fields[field]
            if field in self.wrap_with_list_fields:
                value = [value]
            data[target_field] = value
        target = json.dumps(data, ensure_ascii=False)
        return target, [target]
291

292

293
class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate):
    """Build "source" from the input fields and "target"/"references" from the reference fields.

    ``input_format`` is filled with the (serialized) input fields to give the source;
    ``output_format`` is filled with the (serialized) reference fields to give the
    target, which is also the only element of the references list.
    """
300

301

302
class JsonOutputTemplate(InputFormatTemplate, JsonOutputFormatTemplate):
    """Build "source" from ``input_format`` and a JSON "target" from the reference fields.

    The source is ``input_format`` filled with the (serialized) input fields. The
    target, also used as the only reference, is a JSON object whose keys are given by
    ``output_fields``.
    """
309

310

311
class InputOutputTemplateWithCustomTarget(InputOutputTemplate):
    """Like ``InputOutputTemplate``, but the reference is formatted separately from the target.

    Args:
        reference: format string, filled with the reference fields, that produces the
            single reference. ``output_format`` still produces the target.
    """

    reference: str

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Format ``output_format`` (target) and ``reference`` (sole reference).

        Returns:
            A ``(target, [reference])`` tuple.

        Raises:
            UnitxtError: if ``output_format`` or ``reference`` is None.
            TemplateFormatKeyError: if either format references a missing field.
        """
        target = self.apply_formatting(
            reference_fields,
            "reference field",
            self.output_format,
            "output_format",
        )
        reference = self.apply_formatting(
            reference_fields,
            "reference field",
            self.reference,
            "reference",
        )
        return target, [reference]
330

331

332
class PairwiseChoiceTemplate(InputOutputTemplate):
    """PairwiseChoiceTemplate.

    Requirements:
     The answer field value should be of type Literal["choice_a", "choice_b", "tie"]

    Args:
         choice_a_field (str):
            The field which contains choice_a value
         choice_b_field (str):
            The field which contains choice_b value
         answer_field (str):
            The field which contains the answer value.
            Should be of type Literal["choice_a", "choice_b", "tie"]
         choice_a_label (str):
            The label of choice A answer as it is verbalized in the template.
         choice_b_label (str):
            The label of choice B answer as it is verbalized in the template.
         choice_tie_label (str):
            The label of a tie answer as it should be verbalized in the template.
         shuffle (bool):
            whether to shuffle the choices or not. This is done to take into account position bias.

    shuffle: 50% of the time:
     1. The values of choice_a_field and choice_b_field will be swapped.
     2. If the values of answer_field is choice_a_label, set it to choice_b_label.
        Else if the values of answer_field is choice_b_label, set it to choice_a_label.
        Else if the value of answer_field is choice_tie_label, do nothing.

    Whether to swap is decided by a random generator seeded from the instance's
    input and reference fields, not by the process-global ``random`` state.
    """

    choice_a_field: str
    choice_b_field: str
    answer_field: str
    choice_a_label: str
    choice_b_label: str
    choice_tie_label: str
    shuffle: bool

    def verbalize_answer_field(self, reference_fields: Dict[str, object]):
        """Replace the raw answer ("choice_a" / "choice_b" / "tie") with its configured label.

        ``reference_fields`` is modified in place and returned.
        """
        answer = reference_fields[self.answer_field]
        assert answer in ["choice_a", "choice_b", "tie"], (
            f"Expected the value of '{self.answer_field}' to be one of "
            f"'choice_a', 'choice_b' or 'tie', got: {answer!r}"
        )
        if answer == "choice_a":
            reference_fields[self.answer_field] = self.choice_a_label
        elif answer == "choice_b":
            reference_fields[self.answer_field] = self.choice_b_label
        else:
            reference_fields[self.answer_field] = self.choice_tie_label

        return reference_fields

    def shuffle_values(
        self, input_fields: Dict[str, object], reference_fields: Dict[str, object]
    ):
        """With probability 1/2, swap the two choices and mirror the verbalized answer.

        Does nothing when ``shuffle`` is False. Both dicts are modified in place and
        returned as ``(input_fields, reference_fields)``.
        """
        if not self.shuffle:
            return input_fields, reference_fields

        # Seed from the instance content (as ApplyRandomTemplate and
        # MultipleChoiceTemplate do) so the swap decision depends on the instance
        # rather than on the global random state.
        random_generator = new_random_generator({**input_fields, **reference_fields})
        if random_generator.random() <= 0.5:
            choice_a_value = input_fields[self.choice_a_field]
            choice_b_value = input_fields[self.choice_b_field]

            input_fields[self.choice_a_field] = choice_b_value
            input_fields[self.choice_b_field] = choice_a_value

            answer = reference_fields[self.answer_field]
            assert answer in [
                self.choice_a_label,
                self.choice_b_label,
                self.choice_tie_label,
            ], f"Unexpected verbalized answer {answer!r} in '{self.answer_field}'"
            if answer == self.choice_a_label:
                reference_fields[self.answer_field] = self.choice_b_label
            elif answer == self.choice_b_label:
                reference_fields[self.answer_field] = self.choice_a_label

        return input_fields, reference_fields

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Verbalize the answer, then optionally swap the two choices."""
        reference_fields = self.verbalize_answer_field(reference_fields)
        input_fields, reference_fields = self.shuffle_values(
            input_fields, reference_fields
        )
        return input_fields, reference_fields
420

421

422
class DialogFieldsData(Artifact):
    """Describes one dialog input field and the role labels used to render its turns.

    ``DialogTemplate`` renders each turn of the dialog stored in ``dialog_field``
    prefixed by the label that matches the turn's role.
    """

    # Label written before "user" turns.
    user_role_label: str
    # Label written before "assistant" turns.
    assistant_role_label: str
    # Label written before "system" turns.
    system_role_label: str
    # Name of the input field holding the dialog, a list of (role, text) pairs.
    dialog_field: str
427

428

429
class DialogTemplate(InputOutputTemplate):
    """InputOutputTemplate that first flattens dialog input fields into plain strings.

    Args:
        dialog_fields: one ``DialogFieldsData`` per dialog input field to flatten.
        turns_separator: string placed between consecutive rendered turns.
        label_separator: string placed between a role label and the turn text.
    """

    dialog_fields: List[DialogFieldsData]
    turns_separator: str = "\n\n"
    label_separator: str = " "

    def process_dialog(self, input_fields: Dict[str, object]):
        """Replace each configured dialog field in ``input_fields`` with its rendered text.

        Every turn is rendered as ``<role label><label_separator><text>`` and the
        rendered turns are joined with ``turns_separator``. Turns whose role is not
        "user", "assistant" or "system" are skipped. ``input_fields`` is modified in
        place and returned.
        """
        for dialog_fields in self.dialog_fields:
            dialog = input_fields[dialog_fields.dialog_field]
            # TODO: update isoftype method to support Literal verification and check
            #  it's List[Tuple[Literal["user", "assistant", "system"], str]] (Issue #799)
            assert isoftype(dialog, List[Tuple[str, str]])

            role_labels = {
                "user": dialog_fields.user_role_label,
                "assistant": dialog_fields.assistant_role_label,
                "system": dialog_fields.system_role_label,
            }
            rendered_turns = [
                f"{role_labels[turn_type]}{self.label_separator}{turn_text}"
                for turn_type, turn_text in dialog
                if turn_type in role_labels
            ]
            # Joining (instead of choosing the separator by turn index) avoids a
            # leading separator when the first turn has an unrecognized role.
            input_fields[dialog_fields.dialog_field] = self.turns_separator.join(
                rendered_turns
            )
        return input_fields

    def preprocess_input_fields(self, input_fields: Dict[str, Any]):
        """Flatten the configured dialog fields before serialization."""
        return self.process_dialog(input_fields)
461

462

463
class DialogPairwiseChoiceTemplate(DialogTemplate, PairwiseChoiceTemplate):
    """Pairwise-choice template whose input fields include dialogs.

    ``DialogTemplate`` supplies ``preprocess_input_fields``, which flattens the dialog
    fields to text. ``PairwiseChoiceTemplate`` supplies
    ``preprocess_input_and_reference_fields``, which verbalizes the answer and
    optionally swaps the two choices.
    """
465

466

467
class PairwiseComparativeRatingTemplate(InputOutputTemplate):
    """PairwiseComparativeRatingTemplate.

    Args:
         choice_a_field (str): The field which contains choice_a value

         choice_b_field (str): The field which contains choice_b value

         choice_a_id_field (str): The field which contains the id of choice_a

         choice_b_id_field (str): The field which contains the id of choice_b

         answer_field (str): The field which contains the answer value. The value should be an int.
         Positive for preferring choice_a, and negative for preferring choice_b

         shuffle (bool): whether to shuffle the choices or not. This is done to take into account position bias.

    shuffle: 50% of the time:
    | 1) The values of choice_a_field and choice_b_field (and of choice_a_id_field and choice_b_id_field) will be swapped.
    | 2) The value of answer_field is negated, so the preference still points at the same choice.

    Whether to swap is decided by a random generator seeded from the instance's
    input and reference fields, not by the process-global ``random`` state.
    """

    choice_a_field: str
    choice_b_field: str
    choice_a_id_field: str
    choice_b_id_field: str
    answer_field: str
    shuffle: bool

    def shuffle_values(
        self, input_fields: Dict[str, object], reference_fields: Dict[str, object]
    ):
        """With probability 1/2, swap both choices (and their ids) and negate the answer.

        Does nothing when ``shuffle`` is False. Both dicts are modified in place and
        returned as ``(input_fields, reference_fields)``.
        """
        if not self.shuffle:
            return input_fields, reference_fields

        # Seed from the instance content (as ApplyRandomTemplate does) so the swap
        # decision depends on the instance rather than on the global random state.
        random_generator = new_random_generator({**input_fields, **reference_fields})
        if random_generator.random() <= 0.5:
            # Validate before mutating anything, so a bad answer leaves fields untouched.
            answer = reference_fields[self.answer_field]
            assert isinstance(
                answer, int
            ), f"Expected an int in '{self.answer_field}', got: {answer!r}"

            choice_a_value = input_fields[self.choice_a_field]
            choice_b_value = input_fields[self.choice_b_field]
            input_fields[self.choice_a_field] = choice_b_value
            input_fields[self.choice_b_field] = choice_a_value

            choice_a_id_value = input_fields[self.choice_a_id_field]
            choice_b_id_value = input_fields[self.choice_b_id_field]
            input_fields[self.choice_a_id_field] = choice_b_id_value
            input_fields[self.choice_b_id_field] = choice_a_id_value

            reference_fields[self.answer_field] = -int(answer)

        return input_fields, reference_fields

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Optionally swap the two choices (see ``shuffle_values``)."""
        input_fields, reference_fields = self.shuffle_values(
            input_fields, reference_fields
        )
        return input_fields, reference_fields
524

525

526
class MultipleChoiceTemplate(InputFormatTemplate):
    """Formats the input that specifies a multiple-choice question, with a list of possible answers to choose from, and identifies the correct answer.

    Args:
        target_prefix (str): Optional prefix that can be added before the target label in
            generated prompts or outputs.
        choices_field (str): The key under which the multiple choices are stored in the
            input and reference dictionaries.
        target_field (str): The key under which the correct choice is stored in the
            reference dictionary (can be integer index or textual label).
        choices_separator (str): A string used to join formatted
            choices (e.g. ", ").
        source_choice_format (str): A Python format string used for displaying each choice
            in the input fields (e.g. "{choice_numeral}. {choice_text}").
        target_choice_format (str): A Python format string used for displaying each choice
            in the target or final output (e.g. "{choice_numeral}").
        enumerator (str): Determines how choice numerals are enumerated. Possible values
            include "capitals", "lowercase", "numbers", or "roman". Any other value is
            used as-is as an indexable sequence of numerals.
        shuffle_choices (bool): If True, shuffle the choices. The shuffling seed can be
            set with `shuffle_choices_seed`.
        shuffle_choices_seed (int, optional): If provided, the choices are shuffled with
            this fixed integer seed for reproducibility.
        sort_choices_by_length (bool): If True, sorts choices
            by their length (ascending).
        sort_choices_alphabetically (bool): If True, sorts choices
            in alphabetical order.
        reverse_choices (bool): If True, reverses the order of the choices after any
            sorting has been applied. Defaults to False to preserve backward compatibility.
        place_correct_choice_position (int, optional): If set, the correct choice is moved
            to this position. Negative values count from the end, as with Python lists,
            and are resolved separately for every instance. Cannot be combined with
            shuffling, sorting or reversing.
    """

    target_prefix: str = ""
    choices_field: str = "choices"
    target_field: str = "label"
    choices_separator: str = ", "
    source_choice_format: str = "{choice_numeral}. {choice_text}"
    target_choice_format: str = "{choice_numeral}"
    enumerator: str = "capitals"

    shuffle_choices: bool = False
    shuffle_choices_seed: int = None
    sort_choices_by_length: bool = False
    sort_choices_alphabetically: bool = False
    reverse_choices: bool = False  # False by default for backward-compat
    place_correct_choice_position: int = None

    def prepare(self):
        """Expand the named enumerators into concrete numeral sequences."""
        super().prepare()
        if self.enumerator == "capitals":
            self.enumerator = "ABCDEFGHIJKLMNOP"
        if self.enumerator == "lowercase":
            self.enumerator = "abcdefghijklmnop"
        if self.enumerator == "numbers":
            self.enumerator = [str(i + 1) for i in range(20)]
        if self.enumerator == "roman":
            self.enumerator = [
                "I", "II", "III", "IV", "V",
                "VI", "VII", "VIII", "IX", "X",
                "XI", "XII", "XIII", "XIV", "XV",
                "XVI", "XVII", "XVIII", "XIX", "XX",
            ]  # fmt: skip

    def verify(self):
        super().verify()
        if self.shuffle_choices and (
            self.sort_choices_by_length
            or self.sort_choices_alphabetically
            or self.reverse_choices
            or self.place_correct_choice_position is not None
        ):
            raise UnitxtError(
                "You cannot combine shuffle_choices with sorting or reversing flags."
            )

        if self.sort_choices_by_length and self.sort_choices_alphabetically:
            raise UnitxtError(
                "You cannot combine both sort_choices_by_length and sort_choices_alphabetically simultaneously."
            )
        if self.place_correct_choice_position is not None and (
            self.sort_choices_by_length
            or self.sort_choices_alphabetically
            or self.reverse_choices
        ):
            raise UnitxtError(
                "You cannot combine place_correct_choice_position with sorting or reversing flags."
            )

    def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> List[str]:
        """Format every choice in ``data[choices_field]`` with ``choice_format``.

        The format receives ``choice_text`` (the choice) and ``choice_numeral`` (the
        enumerator symbol at the choice's position).
        """
        return [
            choice_format.format(
                choice_text=choice,
                choice_numeral=self.enumerator[i],
            )
            for i, choice in enumerate(data[self.choices_field])
        ]

    def inputs_to_numerals(self, input_fields: Dict[str, Any]) -> List[str]:
        """Return the enumerator symbols used for the choices in ``input_fields``."""
        return self.inputs_to_choices(input_fields, "{choice_numeral}")

    def prepare_multiple_choice_inputs(
        self, input_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return a copy of ``input_fields`` with the choices rendered as one string.

        Adds a "numerals" entry (unless ``input_fields`` already has one) and replaces
        ``choices_field`` with the choices formatted by ``source_choice_format`` and
        joined with ``choices_separator``.
        """
        choices = self.inputs_to_choices(input_fields, self.source_choice_format)
        return {
            "numerals": self.inputs_to_numerals(input_fields),
            **input_fields,
            self.choices_field: self.choices_separator.join(choices),
        }

    def preprocess_input_fields(self, input_fields: Dict[str, Any]) -> Dict[str, Any]:
        return self.prepare_multiple_choice_inputs(input_fields)

    def outputs_to_target_index(self, reference_fields: Dict[str, object]) -> int:
        """Return the index of the correct choice.

        An int target is returned as-is. Any other target is looked up by value in
        ``reference_fields[choices_field]``.

        Raises:
            UnitxtError: if a non-int target is not among the choices.
        """
        target = reference_fields[self.target_field]

        if not isinstance(target, int):
            try:
                return reference_fields[self.choices_field].index(target)
            except ValueError as e:
                raise UnitxtError(
                    f"MultipleChoiceTemplate could not locate textual target '{target}' in choices list: {reference_fields[self.choices_field]}",
                    Documentation.ADDING_TEMPLATE,
                ) from e
        return target

    def preprocess_reference_fields(self, reference_fields: Dict[str, Any]):
        """Replace the reference fields with ``{target_field: <formatted correct choice>}``.

        The correct choice is formatted with ``target_choice_format``.

        Raises:
            UnitxtError: if the target cannot be located among the choices.
        """
        target_index = self.outputs_to_target_index(reference_fields)

        choices = self.inputs_to_choices(reference_fields, self.target_choice_format)

        try:
            target = choices[target_index]
        except IndexError as e:
            raise UnitxtError(
                f"MultipleChoiceTemplate cannot find index number {target_index} in choices: {choices}",
                Documentation.ADDING_TEMPLATE,
            ) from e

        return {self.target_field: target}

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, [target])`` from the already formatted target field."""
        target = reference_fields[self.target_field]
        return target, [target]

    def preprocess_input_and_reference_fields(
        self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Reorder the choices as configured and re-point the target at the correct choice.

        Both dicts are updated in place: they end up sharing the reordered choices list,
        and ``reference_fields[target_field]`` becomes the correct choice's new index.

        Raises:
            ValueError: if ``place_correct_choice_position`` is out of range for the
                instance's number of choices.
        """
        if (
            not self.shuffle_choices
            and not self.sort_choices_by_length
            and not self.sort_choices_alphabetically
            and not self.reverse_choices
            and self.place_correct_choice_position is None
        ):
            return input_fields, reference_fields

        # Reorder a copy so the caller's list object is never modified in place.
        choices = list(input_fields[self.choices_field])
        target_index = self.outputs_to_target_index(reference_fields)
        original_label_choice = reference_fields[self.choices_field][target_index]

        if self.sort_choices_by_length:
            choices.sort(key=len)
        if self.sort_choices_alphabetically:
            choices.sort()
        if self.reverse_choices:
            choices.reverse()
        if self.shuffle_choices:
            random_generator = new_random_generator(
                self.shuffle_choices_seed
                if self.shuffle_choices_seed is not None
                else {**input_fields}
            )
            random_generator.shuffle(choices)
        if self.place_correct_choice_position is not None:
            # Resolve negative positions like Python list indexes. A local variable is
            # used on purpose: the configured value must be re-resolved for every
            # instance, because the number of choices can differ between instances.
            fix_pos = self.place_correct_choice_position
            if fix_pos < 0:
                fix_pos += len(choices)
            if not 0 <= fix_pos < len(choices):
                raise ValueError(
                    f"place_correct_choice_position={self.place_correct_choice_position} out of range for {len(choices)} choices (valid: {-len(choices)}..{len(choices) - 1})."
                )
            # Move the correct choice from wherever it is to the requested position.
            choices.remove(original_label_choice)
            choices.insert(fix_pos, original_label_choice)

        # Update both input_fields and reference_fields once at the end
        input_fields[self.choices_field] = choices
        reference_fields[self.choices_field] = choices
        reference_fields[self.target_field] = choices.index(original_label_choice)

        return input_fields, reference_fields

    def post_process_instance(self, instance):
        """Expose the choices formatted with ``target_choice_format`` as input field "options"."""
        instance["input_fields"]["options"] = self.inputs_to_choices(
            instance["input_fields"], self.target_choice_format
        )
        return instance
757

758

759
class YesNoTemplate(InputFormatTemplate):
    """A template for generating binary Yes/No questions asking whether an input text is of a specific class.

    Args:
        input_format:
            Defines the format of the question.
        class_field:
            Defines the reference field that contains the name of the class that
            this template asks about. Its value must be a non-empty string.
        label_field:
            Defines the reference field that contains the list of gold labels of
            the input text. If the queried class name (the value of class_field)
            is one of these gold labels, the correct output is self.yes_answer
            (by default, "Yes"). Otherwise the correct output is self.no_answer
            (by default, "No").
        yes_answer:
            The output value for when the queried class name is among the gold labels.
            Defaults to "Yes".
        no_answer:
            The output value for when the queried class name is not among the gold labels.
            Defaults to "No".
    """

    input_format: str = None
    class_field: str = None
    label_field: str = None
    yes_answer: str = "Yes"
    no_answer: str = "No"

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return the Yes/No target together with a single-element references list.

        Args:
            reference_fields: Must contain ``self.label_field`` (a list of gold
                labels) and ``self.class_field`` (the queried class name).

        Returns:
            ``(self.yes_answer, [self.yes_answer])`` when the queried class name is
            in the gold labels, otherwise ``(self.no_answer, [self.no_answer])``.

        Raises:
            UnitxtError: if either field is missing, if the gold labels are not a
                list, or if the queried class name is empty or not a string.
        """
        try:
            gold_class_names = reference_fields[self.label_field]
        except KeyError as e:
            raise UnitxtError(
                f"Available reference_fields are {list(reference_fields.keys())}, missing required label field: '{self.label_field}'."
            ) from e
        if not isinstance(gold_class_names, list):
            raise UnitxtError(
                f"Unexpected value for gold_class_names: '{gold_class_names}'. Expecting a list."
            )
        try:
            queried_class_name = reference_fields[self.class_field]
        except KeyError as e:
            raise UnitxtError(
                f"Available reference_fields are {list(reference_fields.keys())}, missing required class field: '{self.class_field}'."
            ) from e
        if not queried_class_name or not isinstance(queried_class_name, str):
            raise UnitxtError(
                f"Unexpected value for queried_class_names: '{queried_class_name}'. Expected a string."
            )
        if queried_class_name in gold_class_names:
            return self.yes_answer, [self.yes_answer]
        return self.no_answer, [self.no_answer]
812

813

814
class NullTemplate(Template):
    """Template that produces an empty prompt and no references.

    No postprocessors are configured for it.
    """

    postprocessors = []

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        # The input fields are deliberately ignored: the prompt is always empty.
        empty_source = ""
        return empty_source

    def reference_fields_to_target_and_references(self, reference_fields):
        # Regardless of the reference fields: empty target, empty references list.
        no_references = []
        return "", no_references
824

825

826
class KeyValTemplate(Template):
    """Verbalize instance fields as separator-joined key/value pairs.

    The field 'source' is built from the fields designated as input. The fields
    'target' and 'references' are built from the fields designated as output.

    Args:
        pairs_separator: String placed between consecutive pairs (inputs and
            outputs). Defaults to ", ".
        key_val_separator: String placed between a key and its value when
            building the source. Defaults to ": ".
        use_keys_for_inputs: Whether input keys are included in the source.
            Defaults to True.
        outputs_key_val_separator: String placed between a key and its value when
            building the target. Defaults to ": ".
        use_keys_for_outputs: Whether reference keys are included in the target.
            Defaults to False.
    """

    pairs_separator: str = ", "
    key_val_separator: str = ": "
    use_keys_for_inputs: bool = True
    outputs_key_val_separator: str = ": "
    use_keys_for_outputs: bool = False

    def process_dict(
        self, data: Dict[str, object], key_val_sep, pairs_sep, use_keys
    ) -> str:
        """Join the items of ``data`` into one string, in dict order.

        Each item becomes ``key + key_val_sep + str(value)``, or just ``str(value)``
        when ``use_keys`` is False. Items are joined with ``pairs_sep``.
        """
        pairs = []
        for key, val in data.items():
            # str(key) so non-string keys don't break str.join.
            key_val = [str(key), str(val)] if use_keys else [str(val)]
            pairs.append(key_val_sep.join(key_val))
        return pairs_sep.join(pairs)

    def input_fields_to_source(self, input_fields: Dict[str, object]) -> str:
        """Build the source string from the input fields."""
        return self.process_dict(
            input_fields,
            key_val_sep=self.key_val_separator,
            pairs_sep=self.pairs_separator,
            use_keys=self.use_keys_for_inputs,
        )

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Build the target string and a single-element references list from the reference fields."""
        target = self.process_dict(
            reference_fields,
            # Outputs use their dedicated key/value separator.
            key_val_sep=self.outputs_key_val_separator,
            pairs_sep=self.pairs_separator,
            use_keys=self.use_keys_for_outputs,
        )
        return target, [target]
865

866

867
class OutputQuantizingTemplate(InputOutputTemplate):
    """InputOutputTemplate whose serializer additionally quantizes numbers.

    During ``prepare``, a ``NumberQuantizingSerializer(quantum=self.quantum)`` is
    added to the template's ``MultiTypeSerializer``.
    """

    serializer: MultiTypeSerializer = NonPositionalField(
        default_factory=MultiTypeSerializer
    )
    quantum: Union[float, int] = 0.1

    def prepare(self):
        super().prepare()
        quantizer = NumberQuantizingSerializer(quantum=self.quantum)
        self.serializer.add_serializers([quantizer])
878

879

880
class MultiLabelTemplate(InputOutputTemplate):
    """InputOutputTemplate whose target is a separator-joined list of labels.

    Args:
        labels_field: Reference field holding the list of labels. Defaults to "labels".
        labels_separator: String used to join the labels. Defaults to ", ".
        empty_label: Label emitted in place of an empty label list. Defaults to "None".
    """

    labels_field: str = "labels"
    labels_separator: str = ", "
    postprocessors = ["processors.to_list_by_comma"]
    output_format: str = "{labels}"
    empty_label: str = "None"

    def preprocess_reference_fields(
        self, reference_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Collapse the reference fields into one joined-labels string.

        Returns:
            A dict with the single key ``labels_field``. Its value is the labels
            joined by ``labels_separator``, or ``empty_label`` when there are none.

        Raises:
            UnitxtError: if the labels field does not hold a list.
        """
        labels = reference_fields[self.labels_field]
        if not isinstance(labels, list):
            raise UnitxtError(
                f"MultiLabelTemplate requires labels field '{self.labels_field}' to be a list. Got {self.labels_field}<{type(labels).__name__}>: {labels}",
                Documentation.ADDING_TEMPLATE,
            )
        to_join = labels if labels else [self.empty_label]
        return {self.labels_field: self.labels_separator.join(to_join)}
900

901

902
class MultiReferenceTemplate(InputOutputTemplate):
    """InputOutputTemplate whose references come from a list-of-strings field.

    Args:
        references_field: Reference field holding the list of reference strings.
            Defaults to "references".
        random_reference: When True, the target is picked from the references
            using a random generator built from the reference fields (via
            ``new_random_generator``). When False, the first reference is used.
    """

    references_field: str = "references"
    random_reference: bool = False
    serializer: Serializer = NonPositionalField(default_factory=MultiTypeSerializer)

    def serialize(
        self, data: Dict[str, Any], instance: Dict[str, Any]
    ) -> Dict[str, str]:
        """Serialize every field of ``data``.

        The references field is serialized element by element, so it remains a list.
        """
        return {
            key: (
                [self.serializer.serialize(item, instance) for item in value]
                if key == self.references_field
                else self.serializer.serialize(value, instance)
            )
            for key, value in data.items()
        }

    def reference_fields_to_target_and_references(
        self, reference_fields: Dict[str, object]
    ) -> Tuple[str, List[str]]:
        """Return ``(target, references)``, or ``("", [])`` when there are no references.

        Raises:
            UnitxtError: if the references field is not a ``List[str]``.
        """
        references = reference_fields[self.references_field]
        if not isoftype(references, List[str]):
            raise UnitxtError(
                f"MultiReferenceTemplate requires references field '{self.references_field}' to be List[str]. Got {self.references_field}<{type(references).__name__}>: {references}",
                Documentation.ADDING_TEMPLATE,
            )
        if not references:
            return "", []

        if not self.random_reference:
            return references[0], references

        random_generator = new_random_generator(reference_fields)
        return random_generator.choice(references), references
938

939

940
def escape_chars(s, chars_to_escape):
    """Prefix every occurrence of each entry of ``chars_to_escape`` in ``s`` with a backslash.

    All escapes are applied in a single pass over ``s``. The result therefore does
    not depend on the order of ``chars_to_escape``, and backslashes inserted for one
    entry are never escaped again. For example, escaping both a backslash and
    ``":"`` gives the same result in either order. At a given position, longer
    entries take precedence over shorter ones. Empty entries are ignored.

    Args:
        s: The string to escape.
        chars_to_escape: Iterable of characters (or substrings) to escape.
            ``None`` or an empty iterable returns ``s`` unchanged.

    Returns:
        The escaped string.
    """
    import re

    # Longest first so multi-character entries win over their prefixes; the
    # secondary key makes the compiled pattern deterministic.
    tokens = sorted(
        {token for token in (chars_to_escape or ()) if token},
        key=lambda token: (-len(token), token),
    )
    if not tokens:
        return s
    pattern = re.compile("|".join(re.escape(token) for token in tokens))
    return pattern.sub(lambda match: "\\" + match.group(0), s)
944

945

946
class SpanLabelingBaseTemplate(MultiLabelTemplate):
    """Base template that verbalizes labeled text spans as a multi-label target.

    Subclasses implement ``span_label_pairs_to_targets``. It turns
    ``(span_text, label)`` pairs into a list of target strings, which the
    ``MultiLabelTemplate`` parent then joins.

    Args:
        spans_starts_field: Reference field with the start offsets of the spans.
        spans_ends_field: Reference field with the end offsets of the spans.
        text_field: Reference field with the text that the offsets slice.
        labels_support: Optional collection of labels to keep. Spans with any
            other label are dropped. None keeps all labels.
    """

    spans_starts_field: str = "spans_starts"
    spans_ends_field: str = "spans_ends"
    text_field: str = "text"
    labels_support: list = None

    def extract_span_label_pairs(self, reference_fields):
        """Yield ``(span_text, label)`` pairs sorted by ``(start, end, span_text, label)``.

        Starts, ends and labels are paired positionally. ``zip`` stops at the
        shortest of the three.
        """
        spans_starts = reference_fields[self.spans_starts_field]
        spans_ends = reference_fields[self.spans_ends_field]
        text = reference_fields[self.text_field]
        labels = reference_fields[self.labels_field]

        spans = []
        for span_start, span_end, label in zip(spans_starts, spans_ends, labels):
            if self.labels_support is None or label in self.labels_support:
                spans.append((span_start, span_end, text[span_start:span_end], label))

        # Spans were already filtered by labels_support above.
        for _, _, span_text, label in sorted(spans):
            yield span_text, label

    def preprocess_reference_fields(
        self, reference_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Convert the span fields into targets and delegate joining to MultiLabelTemplate."""
        span_labels_pairs = self.extract_span_label_pairs(reference_fields)
        targets = self.span_label_pairs_to_targets(span_labels_pairs)
        # Key by labels_field (not a hard-coded "labels"): the parent reads the
        # labels back from reference_fields[self.labels_field].
        return super().preprocess_reference_fields({self.labels_field: targets})

    @abstractmethod
    def span_label_pairs_to_targets(self, pairs):
        """Convert an iterable of ``(span_text, label)`` pairs into a list of target strings."""
        pass
977

978

979
class SpanLabelingTemplate(SpanLabelingBaseTemplate):
    """Span-labeling template that renders each pair with ``span_label_format``.

    Before formatting, every entry of ``escape_characters`` found in the span text
    is backslash-escaped via ``escape_chars``. Set ``escape_characters`` to None to
    disable escaping.
    """

    span_label_format: str = "{span}: {label}"
    escape_characters: List[str] = [":", ","]
    postprocessors: List[str] = ["processors.to_span_label_pairs"]

    def span_label_pairs_to_targets(self, span_label_pairs):
        def render(span, label):
            if self.escape_characters is not None:
                span = escape_chars(span, self.escape_characters)
            return self.span_label_format.format(span=span, label=label)

        return [render(span, label) for span, label in span_label_pairs]
992

993

994
class SpanLabelingJsonTemplate(SpanLabelingBaseTemplate):
    """Span-labeling template whose target is one JSON object mapping each label to its spans.

    Labels appear in order of first occurrence. When there are no pairs, no
    target is produced.
    """

    postprocessors = [
        "processors.load_json",
        "processors.dict_of_lists_to_value_key_pairs",
    ]

    def span_label_pairs_to_targets(self, span_label_pairs):
        spans_by_label = {}
        for span, label in span_label_pairs:
            spans_by_label.setdefault(label, []).append(span)
        if not spans_by_label:
            return []
        return [json.dumps(spans_by_label, ensure_ascii=False)]
1011

1012

1013
class TemplatesList(ListCollection):
    """A list collection whose items must all be Template instances."""

    def verify(self):
        """Check that every item is a Template.

        Raises:
            AssertionError: for the first item that is not a Template. It is raised
                explicitly rather than via ``assert``, so the check survives
                ``python -O`` while keeping the original exception type.
        """
        for index, template in enumerate(self.items):
            if not isinstance(template, Template):
                raise AssertionError(
                    f"TemplatesList item at index {index} must be a Template, got {type(template).__name__}."
                )
1017

1018

1019
class TemplatesDict(DictCollection):
    """A dict collection whose values must all be Template instances."""

    def verify(self):
        """Check that every value is a Template.

        Raises:
            AssertionError: for the first value that is not a Template. It is raised
                explicitly rather than via ``assert``, so the check survives
                ``python -O`` while keeping the original exception type.
        """
        for key, template in self.items.items():
            if not isinstance(template, Template):
                raise AssertionError(
                    f"TemplatesDict value for key {key!r} must be a Template, got {type(template).__name__}."
                )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc