• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12484941056

24 Dec 2024 06:13PM UTC coverage: 80.039% (-0.3%) from 80.313%
12484941056

push

github

web-flow
make demos_pool a local var rather than a separate stream (#1436)

* make demos_pool a variable rather than a separate stream. This requires restarting the train stream repeatedly; a fix is needed in the Loaders. Allow either a given demos_pool or input stream instances that are already loaded with demos.

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* renamed Sample->AssignDemosToInstance and emptied all but BaseRecipe

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* StandardRecipe->DatasetRecipe and standard_recipe->dataset_recipe

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* all recipes are DatasetRecipe

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* separate AddDemosPool from CreateDemosPool

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* add deprecation for old recipes, and update docs

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* allow consuming a whole stream

Signed-off-by: dafnapension <dafnashein@yahoo.com>

---------

Signed-off-by: dafnapension <dafnashein@yahoo.com>

1337 of 1665 branches covered (80.3%)

Branch coverage included in aggregate %.

8439 of 10549 relevant lines covered (80.0%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.44
src/unitxt/splitters.py
1
import itertools
1✔
2
from abc import abstractmethod
1✔
3
from difflib import get_close_matches
1✔
4
from typing import Any, Dict, List, Optional
1✔
5

6
from .artifact import Artifact
1✔
7
from .dict_utils import dict_get
1✔
8
from .operator import InstanceOperator, MultiStreamOperator
1✔
9
from .random_utils import new_random_generator
1✔
10
from .split_utils import (
1✔
11
    parse_random_mix_string,
12
    parse_slices_string,
13
    random_mix_streams,
14
    rename_split,
15
    slice_streams,
16
)
17
from .stream import MultiStream
1✔
18
from .type_utils import isoftype
1✔
19
from .utils import recursive_copy
1✔
20

21

22
class Splitter(MultiStreamOperator):
    """Common base class for the multistream operators in this module that rearrange splits."""
24

25

26
class RenameSplits(Splitter):
    """Renames the streams (splits) of a multistream according to ``mapper``.

    ``mapper`` is handed to ``rename_split`` together with the input multistream;
    presumably it maps current split names to new ones (confirm in split_utils).
    """

    mapper: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        renamed_streams = rename_split(multi_stream, self.mapper)
        # NOTE(review): unlike the other splitters here, the result is wrapped with
        # MultiStream(...) directly rather than MultiStream.from_generators(...);
        # presumably rename_split returns ready-made streams - confirm before changing.
        return MultiStream(renamed_streams)
32

33

34
class SplitRandomMix(Splitter):
    """Builds new streams (splits) by randomly mixing portions of the input streams, as specified by ``mix``.

    Each key of ``mix`` names an output stream. Each value has the form
    'name-of-source-stream[percentage-of-source-stream]', and several such terms can be joined with '+'.
    Each input instance, of any input stream, is selected exactly once for inclusion in any of the output streams.

    Examples:
    Given a multistream with the streams 'train' and 'test', processing it with
    SplitRandomMix(mix={"train": "train[99%]", "validation": "train[1%]", "test": "test"})
    yields the three streams 'train', 'validation', and 'test':
    'train' holds a random 99% of the input 'train' instances,
    'validation' holds the remaining 1% of the input 'train' instances, and
    'test' holds all of the input 'test' instances.

    Processing the same input with
    SplitRandomMix(mix={"train": "train[50%]+test[0.1]", "validation": "train[50%]+test[0.2]", "test": "test[0.7]"})
    again yields 'train', 'validation', and 'test':
    'train' holds a random 50% of the input 'train' instances plus a random 0.1 (i.e., 10%) of the input 'test' instances;
    'validation' holds the other 50% of the input 'train' instances plus a random 0.2 (i.e., 20%) of the
    original 'test' instances, chosen among those not already taken for 'train';
    and 'test' holds the input 'test' instances that remain.
    """

    mix: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        parsed_mix = {}
        for target_split, mix_spec in self.mix.items():
            parsed_mix[target_split] = parse_random_mix_string(mix_spec)
        mixed_generators = random_mix_streams(multi_stream, parsed_mix)
        return MultiStream.from_generators(mixed_generators)
64

65

66
class SeparateSplit(Splitter):
    """Separates a split (e.g. train) into several splits (e.g. train1, train2).

    ``to_split_sizes`` must give the size of every split in ``to_split_names``,
    except possibly the last one. If no size is given for the last split, it
    receives all the examples of ``from_split`` not allocated to an earlier split.

    Every other input split is copied to the output as it is. When
    ``remove_targets_from_source_split`` is True (the default), ``from_split``
    itself is left out of the output unless its name is also one of the target
    names. When it is False, ``from_split`` is kept whole, even if a target name
    coincides with it.
    """

    from_split: str
    to_split_names: List[str]
    to_split_sizes: List[int]
    remove_targets_from_source_split: bool = True

    def verify(self):
        # Either one size per target name, or one fewer (the last target takes the rest).
        assert len(self.to_split_names) in (
            len(self.to_split_sizes),
            len(self.to_split_sizes) + 1,
        ), f"Examples num should be specified to all or all but the last splits, instead given {len(self.to_split_names)} split names and {len(self.to_split_sizes)} split sizes. \n split names:{self.to_split_names} split sizes {self.to_split_sizes}"
        return super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        # Existing splits are mapped onto themselves with the (None, None) slice,
        # the same whole-stream slice spec used for pass-through splits.
        slicing_plan = {}
        for split_name in multi_stream.keys():
            if self.remove_targets_from_source_split and split_name == self.from_split:
                continue
            slicing_plan[split_name] = {split_name: [(None, None)]}

        # Consecutive slices of from_split; a missing (None) size for the last
        # target leaves its end open.
        offset = 0
        for target_name, target_size in itertools.zip_longest(
            self.to_split_names, self.to_split_sizes
        ):
            keeps_full_source = (
                not self.remove_targets_from_source_split
                and target_name == self.from_split
            )
            if not keeps_full_source:
                slicing_plan[target_name] = {self.from_split: [(offset, target_size)]}
            if target_size:
                offset += target_size

        sliced_generators = slice_streams(multi_stream, slicing_plan)
        return MultiStream.from_generators(sliced_generators)
101

102

103
class SliceSplit(Splitter):
    """Builds output splits from slices of the input splits.

    ``slices`` maps each output split name to a slice-specification string. Each
    string is parsed with ``parse_slices_string``, and the resulting plan is
    applied with ``slice_streams``.
    """

    slices: Dict[str, str]

    def process(self, multi_stream: MultiStream) -> MultiStream:
        slicing_plan = {
            split_name: parse_slices_string(slice_spec)
            for split_name, slice_spec in self.slices.items()
        }
        sliced_generators = slice_streams(multi_stream, slicing_plan)
        return MultiStream.from_generators(sliced_generators)
110

111

112
def get_random_generator_based_on_instance(instance):
    """Return a random generator whose sub-seed is the instance's ``input_fields``.

    A shallow copy of ``instance["input_fields"]`` is passed as ``sub_seed`` to
    ``new_random_generator``. Presumably this makes random choices reproducible
    per instance (confirm in random_utils).
    """
    seed_fields = {**instance["input_fields"]}
    return new_random_generator(sub_seed=seed_fields)
114

115

116
class Sampler(Artifact):
    """Base class for strategies that choose demonstrations out of a pool of instances."""

    @abstractmethod
    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, Any]],
        instance: Dict[str, Any],
    ) -> List[Dict[str, Any]]:
        """Select instances from ``instances_pool`` to serve as demos for ``instance``.

        Args:
            sample_size: number of instances requested.
            instances_pool: candidate instances to choose from.
            instance: the instance the demos are being chosen for.

        Returns:
            The selected instances.
        """
        pass

    def filter_source_by_instance(
        self, instances_pool: List[Dict[str, Any]], instance: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Return the pool members whose ``input_fields`` differ from those of ``instance``.

        This keeps ``instance`` itself (or any copy of its inputs) out of the
        candidates it could be given as demonstrations.

        Raises:
            ValueError: if ``instance`` has no ``input_fields``.
            KeyError: if a pool member has no ``input_fields``.
        """
        if "input_fields" not in instance:
            raise ValueError(f"'input_fields' field is missing from '{instance}'.")
        own_inputs = instance["input_fields"]
        # Any exception here propagates unchanged; the former
        # `except Exception as e: raise e` wrapper only added a traceback frame.
        return [item for item in instances_pool if item["input_fields"] != own_inputs]
139

140

141
class RandomSampler(Sampler):
    """Selects a random sample of instances from the pool.

    The random generator is derived from the target instance (via its
    ``input_fields``), so the draw is tied to that instance.
    """

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        # Materialize the pool first so that any iterable is accepted.
        candidates = list(instances_pool)
        # NOTE(review): `instance` is typed Optional, but the generator helper
        # subscripts it, so None would fail here.
        rng = get_random_generator_based_on_instance(instance)
        return rng.sample(candidates, sample_size)
153

154

155
class FixedIndicesSampler(Sampler):
    """Selects a fixed set of instances from the pool, by position.

    The first ``sample_size`` entries of ``indices`` are used, in order, as
    positions into the pool. The target instance does not affect the choice.
    """

    indices: List[int]

    def verify(self):
        indices_ok = isoftype(self.indices, List[int])
        assert (
            indices_ok
        ), f"'indices' of {self.__class__.__name__} must be List[int]. Value {self.indices} is of type {type(self.indices)}"
        super().verify()

    def sample(
        self,
        sample_size,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        num_instances = len(instances_pool)
        picked = []
        for index in self.indices[:sample_size]:
            # Only indices beyond the end of the pool are rejected here.
            if index < num_instances:
                picked.append(instances_pool[index])
                continue
            raise ValueError(
                f"FixedIndicesSampler 'indices' field contains index ({index}) which is out of bounds of the instance pool ( of size {num_instances})"
            )
        return picked
182

183

184
class CloseTextSampler(Sampler):
    """Selects the pool instances whose text is the closest match to the given instance.

    Texts are compared on the input field named by ``field``
    (read as ``input_fields/<field>``), using ``difflib.get_close_matches``.
    """

    field: str

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Dict[str, object],
    ) -> List[Dict[str, object]]:
        field = f"input_fields/{self.field}"
        value = dict_get(instance, field)

        candidates = list(instances_pool)

        # Texts of the candidates in the compared field, in pool order.
        candidate_texts = [dict_get(candidate, field) for candidate in candidates]
        # cutoff=0 disables the similarity threshold, so up to `sample_size`
        # best-scoring texts are returned however weak the match.
        closest_matches = get_close_matches(
            value, candidate_texts, n=sample_size, cutoff=0
        )

        # Several candidates may share one of the matching texts, so keep every
        # candidate whose text is among the closest matches and draw randomly
        # from them (this also randomizes the order of the result).
        shortlisted = [
            candidate
            for candidate in candidates
            if dict_get(candidate, field) in closest_matches
        ]
        rng = get_random_generator_based_on_instance(instance)
        return rng.sample(shortlisted, sample_size)
219

220

221
class DiverseLabelsSampler(Sampler):
    """Selects a balanced sample of instances based on an output field.

    (used for selecting demonstrations for in-context learning)

    The labels field must contain a list of values, e.g. ['dog'], ['cat'], ['dog', 'cat', 'cow'].
    The balancing is done such that each value, or combination of values,
    appears as equally as possible in the sample.

    The input field named by ``choices`` determines which values are considered.
    A string there is treated as a one-element list.

    Example:
        If the choices are ['dog', 'cat'], then the following combinations are considered:
        []
        ['cat']
        ['dog']
        ['dog', 'cat']

        A label that is not among the choices is ignored. For example, if the
        choices are ['dog', 'cat'] and the labels field is ['dog', 'cat', 'cow'],
        then 'cow' is ignored and the instance is considered as ['dog', 'cat'].

    Args:
        choices (str):
            name of the input field that contains the list of values to balance on
        labels (str):
            name of the reference (output) field with the labels that must be balanced
        include_empty_label (bool):
            whether exemplars whose label combination is empty (no label among
            the choices) may be sampled
    """

    choices: str = "choices"
    labels: str = "labels"
    include_empty_label: bool = True

    def prepare(self):
        super().prepare()
        # Grouping of the most recently sampled pool, by label combination.
        self.labels_cache = None

    def exemplar_repr(self, exemplar):
        """Return the string form of the exemplar's label combination.

        The combination is the exemplar's choices (``input_fields[self.choices]``)
        that also appear in ``reference_fields[self.labels]``, in choices order.

        Raises:
            ValueError: if a required field is missing or has an unexpected type.
        """
        if "input_fields" not in exemplar:
            raise ValueError(f"'input_fields' field is missing from '{exemplar}'.")
        inputs = exemplar["input_fields"]
        if self.choices not in inputs:
            raise ValueError(f"'{self.choices}' field is missing from '{inputs}'.")
        choices = inputs[self.choices]
        if not isinstance(choices, list):
            if isinstance(choices, str):
                choices = [choices]
            else:
                raise ValueError(
                    f"Unexpected input choices value '{choices}'. Expected a list or a string."
                )

        if "reference_fields" not in exemplar:
            raise ValueError(f"'reference_fields' field is missing from '{exemplar}'.")
        outputs = exemplar["reference_fields"]
        if self.labels not in outputs:
            raise ValueError(f"'{self.labels}' field is missing from '{outputs}'.")

        exemplar_outputs = exemplar["reference_fields"][self.labels]
        if not isinstance(exemplar_outputs, list):
            raise ValueError(
                f"Unexpected exemplar_outputs value '{exemplar_outputs}'. Expected a list."
            )

        return str([choice for choice in choices if choice in exemplar_outputs])

    def divide_by_repr(self, exemplars_pool):
        """Group the pool by label combination (see ``exemplar_repr``).

        Exemplars with an empty combination ("[]") are dropped when
        ``include_empty_label`` is False.
        """
        labels = {}
        for exemplar in exemplars_pool:
            label_repr = self.exemplar_repr(exemplar)
            if label_repr == "[]" and not self.include_empty_label:
                continue
            if label_repr not in labels:
                labels[label_repr] = []
            labels[label_repr].append(exemplar)
        return labels

    def sample(
        self,
        sample_size: int,
        instances_pool: List[Dict[str, object]],
        instance: Optional[Dict[str, object]],
    ) -> List[Dict[str, object]]:
        """Draw ``sample_size`` exemplars, spreading them across label combinations.

        Labels are visited round-robin in an instance-seeded random order, one
        exemplar per label per pass, until ``sample_size`` exemplars are allocated.

        Raises:
            ValueError: if ``sample_size`` exceeds the pool size, or the number of
                exemplars eligible for sampling.
        """
        from collections import Counter

        # Group the pool actually given. The grouping must not be reused across
        # calls: callers may pass a differently filtered pool each time (e.g. with
        # the current instance removed), and a stale grouping could hand an
        # instance back as its own demonstration.
        self.labels_cache = self.divide_by_repr(instances_pool)
        all_labels = list(self.labels_cache.keys())
        random_generator = get_random_generator_based_on_instance(instance)
        random_generator.shuffle(all_labels)

        if sample_size > len(instances_pool):
            raise ValueError(
                f"Request sample size {sample_size} is greater than number of instances {len(instances_pool)}"
            )
        # Without this check the allocation loop below never terminates when
        # include_empty_label=False leaves fewer eligible exemplars than requested.
        num_eligible = sum(len(group) for group in self.labels_cache.values())
        if sample_size > num_eligible:
            raise ValueError(
                f"Request sample size {sample_size} is greater than the number of instances with a non-empty label combination {num_eligible} (include_empty_label is False)"
            )

        total_allocated = 0
        allocations = Counter()

        while total_allocated < sample_size:
            for label in all_labels:
                if total_allocated < sample_size:
                    # Skip labels whose exemplars are all allocated already.
                    if len(self.labels_cache[label]) - allocations[label] > 0:
                        allocations[label] += 1
                        total_allocated += 1
                else:
                    break

        result = []
        for label, allocation in allocations.items():
            sample = random_generator.sample(self.labels_cache[label], allocation)
            result.extend(sample)

        random_generator.shuffle(result)
        return result
337

338

339
class AssignDemosToInstance(InstanceOperator):
    """Samples demonstrations for each instance from the demos pool it carries.

    The pool is read from ``instance[from_field]``. Candidates sharing the
    instance's ``input_fields`` are filtered out, and ``sampler`` chooses
    ``get_sample_size(instance)`` of the rest. Copies of the chosen demos (made
    with ``recursive_copy``) are stored in ``instance[to_field]``, and
    ``from_field`` is removed.

    Args:
        from_field: name of the instance field that holds the demos pool.
        to_field: name of the instance field that receives the sampled demos.
        sampler: the Sampler that chooses demos out of the pool.
        skip_demoed_instances: when True, instances that already have ``to_field``
            are returned as they are, apart from dropping ``from_field``.
    """

    from_field: str
    to_field: str
    sampler: Sampler
    skip_demoed_instances: bool = False

    def prepare(self):
        # Run the base-class preparation, as other prepare() overrides in this
        # module (e.g. DiverseLabelsSampler.prepare) do.
        super().prepare()
        self.local_cache = None
        # Re-prepare the sampler; presumably this resets any state it keeps
        # between samples.
        self.sampler.prepare()

    @abstractmethod
    def get_sample_size(self, instance) -> int:
        """Return how many demos to assign to ``instance``."""
        pass

    def process(
        self, instance: Dict[str, Any], multi_stream: MultiStream
    ) -> Dict[str, Any]:
        # NOTE(review): `multi_stream` is unused; presumably InstanceOperator
        # passes the stream name here - confirm before relying on the annotation.
        if self.skip_demoed_instances and self.to_field in instance:
            if self.from_field in instance:
                instance.pop(self.from_field)
            return instance

        demos_pool = instance[self.from_field]
        sample_size = self.get_sample_size(instance)
        source_stream = self.sampler.filter_source_by_instance(demos_pool, instance)
        if len(source_stream) < sample_size:
            raise ValueError(
                f"Size of population to sample from: {len(source_stream)} is smaller than the needed sample_size: {sample_size}. Please consider increasing the demos pool, for which you may need to increase loader_limit or employ a less strict stream filtering."
            )
        sampled_instances = self.sampler.sample(
            sample_size=sample_size, instances_pool=source_stream, instance=instance
        )
        instance[self.to_field] = recursive_copy(sampled_instances)
        instance.pop(self.from_field)  # drop the field pointing to the demos_pool
        return instance
374

375

376
class ConstantSizeSample(AssignDemosToInstance):
    """Assigns every instance the same number of demos: ``sample_size``."""

    sample_size: int

    def get_sample_size(self, instance) -> int:
        """Return the fixed ``sample_size``; ``instance`` is not consulted."""
        return self.sample_size
381

382

383
class RandomSizeSample(AssignDemosToInstance):
    """Assigns each instance a number of demos picked from ``sample_sizes``.

    The pick uses a random generator derived from the instance's
    ``input_fields``, so it is tied to that instance.
    """

    sample_sizes: List[int]

    def get_sample_size(self, instance) -> int:
        """Return one entry of ``sample_sizes``, chosen with the instance-based generator."""
        return get_random_generator_based_on_instance(instance).choice(
            self.sample_sizes
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc