• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 16144996899

08 Jul 2025 01:43PM UTC coverage: 81.077% (+0.001%) from 81.076%
16144996899

push

github

web-flow
Add demos_sampling_seed to recipe api (#1858)

Signed-off-by: elronbandel <elronbandel@gmail.com>

1541 of 1913 branches covered (80.55%)

Branch coverage included in aggregate %.

10494 of 12931 relevant lines covered (81.15%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.77
src/unitxt/splitters.py
1
import itertools
1✔
2
from abc import abstractmethod
1✔
3
from difflib import get_close_matches
1✔
4
from typing import Any, Dict, List, Optional
1✔
5

6
from .artifact import Artifact
1✔
7
from .dict_utils import dict_get
1✔
8
from .operator import InstanceOperator, MultiStreamOperator
1✔
9
from .random_utils import new_random_generator
1✔
10
from .split_utils import (
1✔
11
    parse_random_mix_string,
12
    parse_slices_string,
13
    random_mix_streams,
14
    rename_split,
15
    slice_streams,
16
)
17
from .stream import MultiStream
1✔
18
from .type_utils import isoftype
1✔
19
from .utils import recursive_copy
1✔
20

21

22
class Splitter(MultiStreamOperator):
    """Common base for the split-manipulating operators in this module (rename, random mix, separate, slice)."""
24

25

26
class RenameSplits(Splitter):
    """Rename the streams of a MultiStream.

    Args:
        mapper: mapping handed to ``rename_split``; keys are presumably the
            current stream names and values the new names (confirm in split_utils).
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # rename_split yields the renamed streams; wrap them back into a MultiStream.
        return MultiStream(rename_split(multi_stream, self.mapper))
32

33

34
class SplitRandomMix(Splitter):
    """Split a multistream into new streams whose names, sources and sizes are given by ``mix``.

    Each key of ``mix`` names an output stream. Each value has the form
    ``'name-of-source-stream[portion-of-source-stream]'``, and several such
    terms may be joined with ``+``. Each input instance, of any input stream,
    is selected exactly once for inclusion in any of the output streams.

    Examples:
        Consider a multistream made of two streams named 'train' and 'test',
        processed by::

            SplitRandomMix(mix={"train": "train[99%]", "validation": "train[1%]", "test": "test"})

        The output multistream has three streams: 'train', 'validation' and 'test'.
        Output 'train' holds a random 99% of the instances of input 'train'.
        Output 'validation' holds the remaining 1% of input 'train'. Output
        'test' holds all of input 'test'.

        Processing the same input with::

            SplitRandomMix(mix={"train": "train[50%]+test[0.1]", "validation": "train[50%]+test[0.2]", "test": "test[0.7]"})

        again yields 'train', 'validation' and 'test'.

        - Output 'train' holds a random 50% of input 'train' plus a random 0.1
          (i.e., 10%) of input 'test'.
        - Output 'validation' holds the remaining 50% of input 'train' plus a
          random 0.2 (i.e., 20%) of the original instances of input 'test'
          that were not selected for output 'train'.
        - Output 'test' holds the remaining instances of input 'test'.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Turn every textual mix specification into its parsed form first.
        parsed_mix = {
            split_name: parse_random_mix_string(spec)
            for split_name, spec in self.mix.items()
        }
        mixed_generators = random_mix_streams(multi_stream, parsed_mix)
        return MultiStream.from_generators(mixed_generators)
64

65

66
class SeparateSplit(Splitter):
    """Separate one split (e.g. train) into several splits (e.g. train1, train2).

    ``to_split_sizes`` must give the size of every target split except,
    optionally, the last. If no size is given for the last split, it
    includes all the examples not allocated to any other split.

    Args:
        from_split: name of the stream to separate.
        to_split_names: names of the new streams, in slicing order.
        to_split_sizes: size of each new stream; the last one may be omitted.
        remove_targets_from_source_split: when True (default), ``from_split``
            does not appear in the output unless it is also one of the target
            names, in which case it is replaced by its slice. When False, the
            original ``from_split`` stream is kept untouched, even if its name
            is among the targets.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]
    remove_targets_from_source_split: bool = True

    def verify(self):
        # There must be as many names as sizes, or exactly one more name
        # (whose split then takes the remainder).
        assert len(self.to_split_names) - len(self.to_split_sizes) in (
            0,
            1,
        ), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        keep_source = not self.remove_targets_from_source_split

        # Streams other than the source pass through whole; the source itself
        # passes through only when it is being kept.
        mapping = {}
        for stream_name in multi_stream.keys():
            if keep_source or stream_name != self.from_split:
                mapping[stream_name] = {stream_name: [(None, None)]}

        offset = 0
        for target_name, target_size in itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        ):
            # When the source is kept, a target sharing its name must not
            # overwrite the untouched source stream.
            if not (keep_source and target_name == self.from_split):
                mapping[target_name] = {self.from_split: [(offset, target_size)]}
            if target_size:
                offset += target_size

        return MultiStream.from_generators(slice_streams(multi_stream, mapping))
101

102

103
class SliceSplit(Splitter):
    """Build output streams out of slices of the input streams.

    Args:
        slices: maps each output stream name to a slice specification string,
            which is parsed by ``parse_slices_string``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_slices = {}
        for target_name, spec in self.slices.items():
            parsed_slices[target_name] = parse_slices_string(spec)
        return MultiStream.from_generators(slice_streams(multi_stream, parsed_slices))
110

111

112
def get_random_generator_based_on_instance(instance, local_seed=None):
    """Return a random generator seeded from ``instance["input_fields"]``.

    When ``local_seed`` is given, it is added to the seed material under the
    ``"local_seed"`` key. Presumably this lets the same instance draw different
    randomness for different seeds (depends on ``new_random_generator``).
    """
    seed_material = {**instance["input_fields"]}
    if local_seed is None:
        return new_random_generator(sub_seed=seed_material)
    seed_material["local_seed"] = local_seed
    return new_random_generator(sub_seed=seed_material)
117

118

119
class Sampler(Artifact):
    """Abstract strategy for choosing demonstration instances from a pool."""

    @abstractmethod
    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, Any]],
        instance: Dict[str, Any],
        sampling_seed: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Choose demonstrations for ``instance`` from ``instances_pool``.

        Args:
            sample_size: number of demonstrations requested.
            instances_pool: candidate demonstrations.
            instance: the instance the demonstrations are chosen for.
            sampling_seed: optional extra seed that implementations may mix
                into their randomness.
        """
        pass

    def filter_source_by_instance(
        self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Return the pool items whose ``input_fields`` differ from ``instance``'s.

        This keeps an instance from being offered as its own demonstration.

        Raises:
            ValueError: if ``instance`` has no ``input_fields``.
            KeyError: if a pool item has no ``input_fields``.
        """
        if "input_fields" not in instance:
            raise ValueError(f"'input_fields' field is missing from '{instance}'.")
        instance_inputs = instance["input_fields"]
        return [
            item for item in instances_pool if item["input_fields"] != instance_inputs
        ]
143

144

145
class RandomSampler(Sampler):
    """Select a random sample of instances from the pool.

    The generator is seeded from the target instance's input fields, plus
    ``sampling_seed`` when one is given.
    """

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
        sampling_seed: Optional[int] = None,
    ) -> List[Dict[str, object]]:
        # Materialize the pool first; it may be any iterable.
        candidates = list(instances_pool)
        rng = get_random_generator_based_on_instance(instance, local_seed=sampling_seed)
        return rng.sample(candidates, sample_size)
160

161

162
class FixedIndicesSampler(Sampler):
    """Select a fixed set of samples, given by their positions in the pool.

    Args:
        indices: pool positions to pick; only the first ``sample_size`` are used.
    """

    indices: List[int]

    def verify(self):
        assert isoftype(
            self.indices, List[int]
        ), f"'indices' of {self.__class__.__name__} must be List[int]. Value {self.indices} is of type {type(self.indices)}"
        super().verify()

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
        sampling_seed: Optional[int] = None,
    ) -> List[Dict[str, object]]:
        pool_size = len(instances_pool)
        picked = []
        # Validate and pick one index at a time, so that the first
        # out-of-range index is the one reported.
        for position in self.indices[:sample_size]:
            if position >= pool_size:
                raise ValueError(
                    f"FixedIndicesSampler 'indices' field contains index ({position}) which is out of bounds of the instance pool ( of size {pool_size})"
                )
            picked.append(instances_pool[position])
        return picked
190

191

192
class CloseTextSampler(Sampler):
    """Select the instances whose text most closely matches the given instance.

    Matching uses ``difflib.get_close_matches`` on ``input_fields/<field>`` of
    the instance and of every pool item.

    Args:
        field: name of the input field to compare on.
    """

    field: str

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
        sampling_seed: Optional[int] = None,
    ) -> List[Dict[str, object]]:
        field = f"input_fields/{self.field}"
        value = dict_get(instance, field)

        instances_pool = list(instances_pool)

        # Get the 'sample_size' closest matching texts. cutoff=0 accepts any
        # similarity, so the best available matches are always returned.
        options = [dict_get(instance_in_pool, field) for instance_in_pool in instances_pool]
        closest_matches = get_close_matches(value, options, n=sample_size, cutoff=0)

        # Several pool instances may share a matching text, so keep all of them
        # and randomly pick 'sample_size' instances. The returned order is
        # randomized as well. The options computed above are reused.
        instances_pool = [
            instance_in_pool
            for instance_in_pool, option in zip(instances_pool, options)
            if option in closest_matches
        ]
        # Honor sampling_seed, as RandomSampler does, so the seed passed by the
        # caller actually influences which tied candidates are chosen.
        random_generator = get_random_generator_based_on_instance(
            instance, local_seed=sampling_seed
        )
        return random_generator.sample(instances_pool, sample_size)
228

229

230
class DiverseLabelsSampler(Sampler):
    """Select a balanced sample of instances based on an output field.

    (used for selecting demonstrations for in-context learning)

    The labels field must contain a list of values, e.g. ['dog'], ['cat'], ['dog','cat','cow'].
    The balancing is done such that each value, or combination of values,
    appears as equally as possible in the sample.

    Each exemplar's ``input_fields`` must contain the field named by ``choices``.
    Its values determine which labels are considered.

    Example:
        If choices is ['dog','cat'], then the following combinations are considered:
        []
        ['cat']
        ['dog']
        ['dog','cat']

        If the labels contain a value not in choices, it is ignored. For example,
        if choices is ['dog','cat'] and the labels field is ['dog','cat','cow'],
        then 'cow' is ignored and the instance is considered as ['dog','cat'].

    Args:
        choices (str):
            name of the input field that contains the list of values to balance on
        labels (str):
            name of the reference field with the labels that must be balanced
        include_empty_label (bool):
            whether exemplars whose labels match none of the choices (the ``[]``
            combination) take part in sampling
    """

    choices: str = "choices"
    labels: str = "labels"
    include_empty_label: bool = True

    def prepare(self):
        super().prepare()
        # Grouping of the pool by label combination; built lazily in sample().
        self.labels_cache = None

    def exemplar_repr(self, exemplar):
        """Return the string form of the exemplar's label combination, e.g. "['dog', 'cat']".

        The combination lists the exemplar's choices (in choices order) that
        also appear in its labels. A single string choice is treated as a
        one-element list.

        Raises:
            ValueError: if a required field is missing, or has an unexpected type.
        """
        if "input_fields" not in exemplar:
            raise ValueError(f"'input_fields' field is missing from '{exemplar}'.")
        inputs = exemplar["input_fields"]
        if self.choices not in inputs:
            raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            if isinstance(choices, str):
                choices = [choices]
            else:
                raise ValueError(
                    f"Unexpected input choices value '{choices}'. Expected a list or a string."
                )

        if "reference_fields" not in exemplar:
            raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.")
        outputs = exemplar["reference_fields"]
        if self.labels not in outputs:
            raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")

        exemplar_outputs = exemplar["reference_fields"][self.labels]
        if not isinstance(exemplar_outputs, list):
            raise ValueError(
                f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in exemplar_outputs])

    def divide_by_repr(self, exemplars_pool):
        """Group exemplars by their label-combination string, in first-seen order.

        Exemplars with the empty combination ("[]") are dropped when
        include_empty_label is False.
        """
        labels = {}
        for exemplar in exemplars_pool:
            label_repr = self.exemplar_repr(exemplar)
            if label_repr == "[]" and not self.include_empty_label:
                continue
            if label_repr not in labels:
                labels[label_repr] = []
            labels[label_repr].append(exemplar)
        return labels

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
        sampling_seed: Optional[int] = None,
    ) -> List[Dict[str, object]]:
        """Return ``sample_size`` exemplars spread as evenly as possible across label combinations.

        Args:
            sample_size: number of exemplars to return.
            instances_pool: candidate exemplars.
            instance: the instance being demonstrated; seeds the randomness.
            sampling_seed: optional extra seed mixed into the randomness.

        Raises:
            ValueError: if ``sample_size`` exceeds the pool size, or the number
                of exemplars available in the label cache.
        """
        from collections import Counter

        # NOTE(review): the cache is built from the pool of the first call and
        # reused by later calls until prepare() resets it. Later calls
        # therefore sample from that first pool, not from their own
        # instances_pool.
        if self.labels_cache is None:
            self.labels_cache = self.divide_by_repr(instances_pool)
        all_labels = list(self.labels_cache.keys())
        random_generator = get_random_generator_based_on_instance(
            instance, local_seed=sampling_seed
        )
        random_generator.shuffle(all_labels)

        if sample_size > len(instances_pool):
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances {len(instances_pool)}"
            )
        # The cache can hold fewer exemplars than the pool: empty labels may be
        # excluded, or the cache came from an earlier pool. Without this check
        # the allocation loop below would never terminate.
        num_cached = sum(len(exemplars) for exemplars in self.labels_cache.values())
        if sample_size > num_cached:
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of exemplars available for balancing {num_cached}"
            )

        total_allocated = 0
        allocations = Counter()

        # Round-robin over the shuffled labels. Each pass gives every label one
        # more exemplar while it still has unallocated exemplars.
        while total_allocated < sample_size:
            for label in all_labels:
                if total_allocated >= sample_size:
                    break
                if len(self.labels_cache[label]) - allocations[label] > 0:
                    allocations[label] += 1
                    total_allocated += 1

        result = []
        for label, allocation in allocations.items():
            sample = random_generator.sample(self.labels_cache[label], allocation)
            result.extend(sample)

        random_generator.shuffle(result)
        return result
346

347

348
class AssignDemosToInstance(InstanceOperator):
    """Sample demonstrations for an instance from a demos pool stored on the instance.

    The pool is read from ``instance[from_field]`` and filtered with
    ``sampler.filter_source_by_instance``. The pool is then sampled with
    ``sampler``. A copy of the sampled demos (made by ``recursive_copy``) is
    stored in ``instance[to_field]``, and ``from_field`` is removed.

    Args:
        from_field: instance field holding the demos pool; removed after processing.
        to_field: instance field that receives the sampled demos.
        sampler: strategy used to choose the demos.
        skip_demoed_instances: if True, instances that already have ``to_field``
            are returned unchanged apart from dropping ``from_field``.
        sampling_seed: forwarded to ``sampler.sample`` as ``sampling_seed``.
    """

    from_field: str
    to_field: str
    sampler: Sampler
    skip_demoed_instances: bool = False
    sampling_seed: Optional[int] = None

    def prepare(self):
        # Run the base-class preparation too, as other prepare() overrides in
        # this module do.
        super().prepare()
        self.local_cache = None
        # Reset any per-run state held by the sampler.
        self.sampler.prepare()

    @abstractmethod
    def get_sample_size(self, instance) -> int:
        """Return how many demonstrations to sample for ``instance``."""
        pass

    def process(
        self, instance: Dict[str, Any], multi_stream: MultiStream
    ) -> Dict[str, Any]:
        if self.skip_demoed_instances and self.to_field in instance:
            instance.pop(self.from_field, None)
            return instance

        demos_pool = instance[self.from_field]
        sample_size = self.get_sample_size(instance)
        source_stream = self.sampler.filter_source_by_instance(demos_pool, instance)
        if len(source_stream) < sample_size:
            raise ValueError(
                f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering."
            )
        sampled_instances = self.sampler.sample(
            sample_size=sample_size,
            instances_pool=source_stream,
            instance=instance,
            sampling_seed=self.sampling_seed,
        )
        instance[self.to_field] = recursive_copy(sampled_instances)
        instance.pop(self.from_field)  # pop the field pointing to the demos_pool
        return instance
387

388

389
class ConstantSizeSample(AssignDemosToInstance):
    """Request the same number of demonstrations (``sample_size``) for every instance."""

    sample_size: int

    def get_sample_size(self, instance) -> int:
        # The requested size does not depend on the instance.
        return self.sample_size
394

395

396
class RandomSizeSample(AssignDemosToInstance):
    """Request a number of demonstrations drawn from ``sample_sizes`` for each instance.

    The draw uses a generator seeded from the instance's input fields.
    """

    sample_sizes: List[int]

    def get_sample_size(self, instance) -> int:
        return get_random_generator_based_on_instance(instance).choice(
            self.sample_sizes
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc