• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12955904068

24 Jan 2025 07:03PM UTC coverage: 79.154%. Remained the same
12955904068

Pull #1545

github

web-flow
Merge 5a0994b83 into 7eb177a47
Pull Request #1545: Renamed criterias in LLM-as-a-Judge metrics to criteria.

1425 of 1796 branches covered (79.34%)

Branch coverage included in aggregate %.

9070 of 11463 relevant lines covered (79.12%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.44
src/unitxt/splitters.py
1
import itertools
1✔
2
from abc import abstractmethod
1✔
3
from difflib import get_close_matches
1✔
4
from typing import Any, Dict, List, Optional
1✔
5

6
from .artifact import Artifact
1✔
7
from .dict_utils import dict_get
1✔
8
from .operator import InstanceOperator, MultiStreamOperator
1✔
9
from .random_utils import new_random_generator
1✔
10
from .split_utils import (
1✔
11
    parse_random_mix_string,
12
    parse_slices_string,
13
    random_mix_streams,
14
    rename_split,
15
    slice_streams,
16
)
17
from .stream import MultiStream
1✔
18
from .type_utils import isoftype
1✔
19
from .utils import recursive_copy
1✔
20

21

22
class Splitter(MultiStreamOperator):
    """Common base class for the MultiStream split-manipulating operators in this module."""
24

25

26
class RenameSplits(Splitter):
    """Renames streams of a MultiStream as directed by ``mapper``.

    Args:
        mapper: mapping handed to ``split_utils.rename_split`` (presumably
            existing stream name -> new stream name; confirm in split_utils).
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # rename_split's result is wrapped directly (not via from_generators).
        return MultiStream(rename_split(multi_stream, self.mapper))
32

33

34
class SplitRandomMix(Splitter):
    """Builds new streams (splits) by randomly mixing portions of the input streams, as described by ``mix``.

    Each key of ``mix`` names an output stream. Its value names the source stream(s)
    and the portion to draw from each, in the form ``'source-stream[portion]'``.
    Several sources can be joined with ``+``. A portion is written as a percentage
    (``99%``) or a fraction (``0.1``), and a bare stream name takes the whole
    stream. No input instance is placed in more than one output stream.

    Examples:
        Given input streams 'train' and 'test',
        ``SplitRandomMix(mix={"train": "train[99%]", "validation": "train[1%]", "test": "test"})``
        outputs three streams:

        - 'train': a random 99% of input 'train'.
        - 'validation': the remaining 1% of input 'train'.
        - 'test': all of input 'test'.

        On the same input,
        ``SplitRandomMix(mix={"train": "train[50%]+test[0.1]", "validation": "train[50%]+test[0.2]", "test": "test[0.7]"})``
        outputs three streams:

        - 'train': a random 50% of input 'train' plus a random 0.1 (10%) of input 'test'.
        - 'validation': the other 50% of input 'train' plus 0.2 (20%) of the
          original input 'test', drawn from instances not already taken for 'train'.
        - 'test': the remaining instances of input 'test'.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Turn every textual mix specification into its parsed form.
        parsed_mix = {}
        for output_name, spec in self.mix.items():
            parsed_mix[output_name] = parse_random_mix_string(spec)
        return MultiStream.from_generators(random_mix_streams(multi_stream, parsed_mix))
64

65

66
class SeparateSplit(Splitter):
    """Separates one split (e.g. train) into several new splits (e.g. train1, train2).

    ``to_split_sizes`` gives a size for every split in ``to_split_names``, or for
    all but the last one. If the last split has no size, it receives all the
    examples of ``from_split`` not allocated to any earlier split. The slice for
    each target starts at the running total of the sizes of the targets before it.

    Args:
        from_split: name of the stream to separate.
        to_split_names: names of the streams to create.
        to_split_sizes: sizes of the streams to create (all, or all but the last).
        remove_targets_from_source_split: if True (default), ``from_split`` itself
            is dropped from the output unless it is also one of ``to_split_names``.
            If False, ``from_split`` is kept whole, and a target that shares its
            name does not replace it.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]
    remove_targets_from_source_split: bool = True

    def verify(self):
        names_count = len(self.to_split_names)
        sizes_count = len(self.to_split_sizes)
        # Either every target has a size, or only the last one is left open-ended.
        assert names_count in (sizes_count, sizes_count + 1), (
            f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        )
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        keep_source = not self.remove_targets_from_source_split

        # Pass-through entries: every existing stream maps to itself in full,
        # except the source stream when it is being removed.
        mapping = {}
        for stream_name in multi_stream.keys():
            if keep_source or stream_name != self.from_split:
                mapping[stream_name] = {stream_name: [(None, None)]}

        # Target entries: consecutive slices of the source stream. A missing
        # (None) size for the last target leaves its slice open-ended.
        offset = 0
        for target_name, target_size in itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        ):
            if not keep_source or target_name != self.from_split:
                mapping[target_name] = {self.from_split: [(offset, target_size)]}
            if target_size:
                offset += target_size

        return MultiStream.from_generators(slice_streams(multi_stream, mapping))
103

104

105
class SliceSplit(Splitter):
    """Builds output streams from slices of the input streams, as described by ``slices``.

    Each key of ``slices`` names an output stream. Its value is a slice
    specification string parsed by ``split_utils.parse_slices_string``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_slices = {}
        for output_name, spec in self.slices.items():
            parsed_slices[output_name] = parse_slices_string(spec)
        return MultiStream.from_generators(slice_streams(multi_stream, parsed_slices))
112

113

114
def get_random_generator_based_on_instance(instance):
    """Return a random generator whose sub-seed is built from ``instance["input_fields"]``.

    A shallow copy of the input fields is passed as ``sub_seed`` to
    ``new_random_generator``. Presumably equal input fields therefore give
    reproducible draws; confirm in random_utils.
    """
    seed_fields = {**instance["input_fields"]}
    return new_random_generator(sub_seed=seed_fields)
116

117

118
class Sampler(Artifact):
    """Base class for strategies that choose demonstration instances from a pool."""

    @abstractmethod
    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, Any]],
        instance: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """Select up to ``sample_size`` instances from ``instances_pool`` for ``instance``."""
        pass

    def filter_source_by_instance(
        self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Return the pool items whose ``input_fields`` differ from those of ``instance``.

        Presumably this ensures an instance is never offered itself (or an exact
        duplicate of its inputs) as a demonstration.

        Raises:
            ValueError: if ``instance`` has no ``input_fields`` entry.
            KeyError: if a pool item has no ``input_fields`` entry.
        """
        if "input_fields" not in instance:
            raise ValueError(f"'input_fields' field is missing from '{instance}'.")
        instance_inputs = instance["input_fields"]
        return [
            item for item in instances_pool if item["input_fields"] != instance_inputs
        ]
141

142

143
class RandomSampler(Sampler):
    """Draws ``sample_size`` instances uniformly at random from the pool.

    The generator is seeded from the target instance's input fields.
    """

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        candidates = list(instances_pool)
        return get_random_generator_based_on_instance(instance).sample(
            candidates, sample_size
        )
155

156

157
class FixedIndicesSampler(Sampler):
    """Selects a fixed set of samples based on a list of indices into the pool.

    Only the first ``sample_size`` entries of ``indices`` are used. Indices follow
    Python sequence indexing, so negative values count from the end of the pool.

    Args:
        indices: positions in the instances pool to pick, in order.
    """

    indices: List[int]

    def verify(self):
        assert isoftype(self.indices, List[int]), (
            f"'indices' of {self.__class__.__name__} must be List[int]. Value {self.indices} is of type {type(self.indices)}"
        )
        super().verify()

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Return the pool items at the first ``sample_size`` configured indices.

        Raises:
            ValueError: if an index is out of bounds for the pool.
        """
        num_instances = len(instances_pool)

        selected = []
        for index in self.indices[:sample_size]:
            # Reject both too-large and too-negative indices here. Otherwise a
            # too-negative index would escape as a bare IndexError.
            if not -num_instances <= index < num_instances:
                raise ValueError(
                    f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})"
                )
            selected.append(instances_pool[index])
        return selected
184

185

186
class CloseTextSampler(Sampler):
    """Selects the pool instances whose text is the closest match to the given instance.

    Similarity is measured with ``difflib.get_close_matches`` on the value found at
    ``input_fields/<field>`` of each instance.

    Args:
        field: name of the input field whose text is compared.
    """

    field: str

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        path = f"input_fields/{self.field}"
        target_text = dict_get(instance, path)
        pool = list(instances_pool)
        pool_texts = [dict_get(candidate, path) for candidate in pool]

        # cutoff=0 makes every option eligible, so up to ``sample_size`` best
        # matching texts are returned.
        nearest_texts = get_close_matches(
            target_text, pool_texts, n=sample_size, cutoff=0
        )

        # Several pool instances may share one of the nearest texts. Keep them
        # all and let the instance-seeded generator pick among them. This also
        # randomizes the order of the returned sample.
        shortlisted = [
            candidate
            for candidate, text in zip(pool, pool_texts)
            if text in nearest_texts
        ]
        return get_random_generator_based_on_instance(instance).sample(
            shortlisted, sample_size
        )
221

222

223
class DiverseLabelsSampler(Sampler):
    """Selects a balanced sample of instances based on an output field.

    (used for selecting demonstrations for in-context learning)

    The field must contain a list of values, e.g. ['dog'], ['cat'], ['dog','cat','cow'].
    The balancing is done so that each value, or combination of values, appears
    as equally as possible in the samples.

    The `choices` param determines which values should be considered.

    Example:
        If choices is ['dog','cat'], then the following combinations will be considered:
        []
        ['cat']
        ['dog']
        ['dog','cat']

        If the instance contains a value not in the `choices` param, it is ignored. For example,
        if choices is ['dog','cat'] and the instance field is ['dog','cat','cow'], then 'cow' is
        ignored and the instance is considered as ['dog','cat'].

    Args:
        choices (str):
            name of input field that contains the list of values to balance on
        labels (str):
            name of output field with labels that must be balanced
        include_empty_label (bool):
            whether pool instances whose labels contain none of their choices
            (label representation "[]") may be sampled. Defaults to True.
    """

    choices: str = "choices"
    labels: str = "labels"
    include_empty_label: bool = True

    def prepare(self):
        super().prepare()
        # Label grouping of the most recent pool passed to ``sample``.
        self.labels_cache = None

    def exemplar_repr(self, exemplar):
        """Return ``str`` of the exemplar's choices that also appear in its labels.

        The choices keep their order in the choices field, e.g. ``"['dog', 'cat']"``
        or ``"[]"``. A string choices value is treated as a one-element list.

        Raises:
            ValueError: if a required field is missing or has an unexpected type.
        """
        if "input_fields" not in exemplar:
            raise ValueError(f"'input_fields' field is missing from '{exemplar}'.")
        inputs = exemplar["input_fields"]
        if self.choices not in inputs:
            raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            if isinstance(choices, str):
                choices = [choices]
            else:
                raise ValueError(
                    f"Unexpected input choices value '{choices}'. Expected a list or a string."
                )

        if "reference_fields" not in exemplar:
            raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.")
        outputs = exemplar["reference_fields"]
        if self.labels not in outputs:
            raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")

        exemplar_outputs = outputs[self.labels]
        if not isinstance(exemplar_outputs, list):
            raise ValueError(
                f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in exemplar_outputs])

    def divide_by_repr(self, exemplars_pool):
        """Group exemplars by ``exemplar_repr``, preserving first-seen group order.

        Exemplars whose representation is "[]" are skipped when
        ``include_empty_label`` is False.
        """
        labels = {}
        for exemplar in exemplars_pool:
            label_repr = self.exemplar_repr(exemplar)
            if label_repr == "[]" and not self.include_empty_label:
                continue
            labels.setdefault(label_repr, []).append(exemplar)
        return labels

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Pick ``sample_size`` instances, spreading them round-robin over label groups.

        Raises:
            ValueError: if fewer than ``sample_size`` instances can be sampled.
        """
        from collections import Counter

        # Group the pool on every call. The pool differs between calls
        # (``filter_source_by_instance`` removes the current instance from it).
        # A grouping cached from an earlier call could therefore offer the
        # current instance as its own demonstration, and would never offer the
        # instance of that earlier call. Grouping is O(pool), the same order as
        # the filtering already done per instance.
        self.labels_cache = self.divide_by_repr(instances_pool)
        all_labels = list(self.labels_cache.keys())
        random_generator = get_random_generator_based_on_instance(instance)
        random_generator.shuffle(all_labels)

        if sample_size > len(instances_pool):
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances {len(instances_pool)}"
            )
        # With include_empty_label=False, some pool instances belong to no group.
        # Without this check, the allocation loop below would never terminate
        # when the groups hold fewer than sample_size instances.
        available = sum(len(group) for group in self.labels_cache.values())
        if sample_size > available:
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances "
                f"available for sampling {available} (instances with an empty label "
                f"representation are excluded since include_empty_label is False)"
            )

        # Round-robin over the shuffled labels, one instance per label per pass,
        # skipping labels whose group is exhausted.
        total_allocated = 0
        allocations = Counter()
        while total_allocated < sample_size:
            for label in all_labels:
                if total_allocated >= sample_size:
                    break
                if len(self.labels_cache[label]) > allocations[label]:
                    allocations[label] += 1
                    total_allocated += 1

        result = []
        for label, allocation in allocations.items():
            result.extend(random_generator.sample(self.labels_cache[label], allocation))

        random_generator.shuffle(result)
        return result
339

340

341
class AssignDemosToInstance(InstanceOperator):
    """Samples demonstrations for each instance from a demos pool stored on the instance.

    Args:
        from_field: instance field holding the pool of candidate demonstrations.
            It is removed from the instance after processing.
        to_field: instance field that receives a deep copy of the sampled demonstrations.
        sampler: strategy used to pick the demonstrations.
        skip_demoed_instances: if True, instances that already have ``to_field``
            are returned unchanged, except that ``from_field`` is dropped.
    """

    from_field: str
    to_field: str
    sampler: Sampler
    skip_demoed_instances: bool = False

    def prepare(self):
        # Run the base-class preparation hook, as DiverseLabelsSampler.prepare does.
        super().prepare()
        self.local_cache = None
        # Re-prepare the sampler so any state it keeps is reset
        # (e.g. DiverseLabelsSampler resets its labels_cache).
        self.sampler.prepare()

    @abstractmethod
    def get_sample_size(self, instance) -> int:
        """Return how many demonstrations to assign to ``instance``."""
        pass

    def process(
        self, instance: Dict[str, Any], multi_stream: MultiStream
    ) -> Dict[str, Any]:
        """Fill ``to_field`` with demonstrations sampled from ``from_field``.

        Raises:
            ValueError: if, after removing the instance itself, the pool is
                smaller than the requested sample size.
        """
        if self.skip_demoed_instances and self.to_field in instance:
            instance.pop(self.from_field, None)
            return instance

        demos_pool = instance[self.from_field]
        sample_size = self.get_sample_size(instance)
        source_stream = self.sampler.filter_source_by_instance(demos_pool, instance)
        if len(source_stream) < sample_size:
            raise ValueError(
                f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering."
            )
        sampled_instances = self.sampler.sample(
            sample_size=sample_size, instances_pool=source_stream, instance=instance
        )
        # Copy so that later edits to one instance's demos cannot leak into the
        # shared pool or into other instances.
        instance[self.to_field] = recursive_copy(sampled_instances)
        instance.pop(self.from_field)  # drop the field pointing to the demos pool
        return instance
376

377

378
class ConstantSizeSample(AssignDemosToInstance):
    """Assigns the same number of demonstrations, ``sample_size``, to every instance."""

    sample_size: int

    def get_sample_size(self, instance) -> int:
        # The size is fixed, so the instance itself is not consulted.
        fixed_size = self.sample_size
        return fixed_size
383

384

385
class RandomSizeSample(AssignDemosToInstance):
    """Assigns each instance a number of demonstrations chosen from ``sample_sizes``.

    The choice uses a generator seeded from the instance's input fields.
    """

    sample_sizes: List[int]

    def get_sample_size(self, instance) -> int:
        generator = get_random_generator_based_on_instance(instance)
        chosen_size = generator.choice(self.sample_sizes)
        return chosen_size
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc