• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12484941056

24 Dec 2024 06:13PM UTC coverage: 80.039% (-0.3%) from 80.313%
12484941056

push

github

web-flow
make demos_pool a local var rather than a separate stream (#1436)

* make demos_pool a variable rather than a separate stream. This requires restarting the train stream repeatedly; a fix is needed in Loaders. Allow either a given demos_pool, or input stream instances already loaded with demos.

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* renamed Sample->AssignDemosToInstance and emptied all but BaseRecipe

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* StandardRecipe->DatasetRecipe and standard_recipe->dataset_recipe

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* all recipes are DatasetRecipe

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* separate AddDemosPool from CreateDemosPool

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* add deprecation for old recipes, and update docs

Signed-off-by: dafnapension <dafnashein@yahoo.com>

* allow consuming a whole stream

Signed-off-by: dafnapension <dafnashein@yahoo.com>

---------

Signed-off-by: dafnapension <dafnashein@yahoo.com>

1337 of 1665 branches covered (80.3%)

Branch coverage included in aggregate %.

8439 of 10549 relevant lines covered (80.0%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.53
src/unitxt/split_utils.py
1
import itertools
1✔
2
import re
1✔
3
from typing import Dict
1✔
4

5
from .generator_utils import ReusableGenerator
1✔
6
from .logging_utils import get_logger
1✔
7
from .random_utils import new_random_generator
1✔
8
from .stream import MissingStreamError, Stream
1✔
9

10
# Module-level logger; used by the ``__main__`` demo at the bottom of this file.
logger = get_logger()
1✔
11

12

13
def parse_random_mix_string(input_str):
    """Parses a string of format "source1[percentage1%]+source2[value2]+..." and returns a dictionary.

    Args:
        input_str (str): A string containing source names and their respective proportions. The format is
                         "source[proportion%]" or "source[proportion]", with multiple sources separated by "+".
                         The proportion can be a percentage (e.g., "90%") or a decimal number (e.g., "0.7").
                         If the proportion is not provided, it assumes 100%.

    Returns:
        dict: A dictionary where the keys are the source names and the values are the proportions converted to floats.
              If the proportion was given as a percentage, the value is divided by 100.
              If a source name appears more than once, its last proportion is kept.

    Raises:
        ValueError: If the input string is not in the correct format, including brackets that
            contain no digits at all (e.g. "source[]", "source[.]" or "source[%]").

    Example:
        >>> parse_random_mix_string("dale[90%]+oren[0.7]+mike")
        {'dale': 0.9, 'oren': 0.7, 'mike': 1.0}
    """
    error_message = f"Invalid input format for split '{input_str}'"
    if not re.fullmatch(
        r"((\w+\[\d*\.?\d*%?\]|\w+)\+)*(\w+\[\d*\.?\d*%?\]|\w+)",
        input_str,
    ):
        raise ValueError(error_message)

    pattern = re.compile(r"(\w+)(\[\d*\.?\d*%?\])?")
    result = {}
    for name, value in pattern.findall(input_str):
        if not value:
            # No bracket given: the whole source is taken.
            result[name] = 1.0
            continue
        number = value.strip("[]%")
        # The format regex above also accepts "[]", "[.]" and "[%]". float() would
        # fail on those with an unhelpful message, so report them as a format error.
        if not re.search(r"\d", number):
            raise ValueError(error_message)
        proportion = float(number)
        result[name] = proportion / 100 if "%" in value else proportion
    return result
48

49

50
def parse_slices_string(input_str):
    """Parses a string of format "source1[value1:value2]+source2[value2:]+source3+..." and returns a dictionary.

    {"source1": [(value1, value2)], "source2": [(value2, None)], "source3": [(None, None)], ...}.

    If a source appears multiple times with different indices, all index pairs are included
    in its list, in the order they appear in the input.

    Args:
        input_str (str): A string containing source names and their respective indices. Each part is
                         "source[start:end]", "source[:end]", "source[start:]", "source[:]" or just
                         "source". Multiple parts are separated by "+" with no surrounding whitespace.
                         The indices select the items to be taken from the source.

    Returns:
        dict: A dictionary where the keys are the source names and the values are lists of
              (start, end) tuples. A missing index is represented as None. For example,
              "source[:end]" gives (None, end), "source[start:]" gives (start, None), and a
              bare "source" gives (None, None).

    Raises:
        ValueError: If the input string is not in the correct format.

    Example:
        >>> parse_slices_string("oren[:50]+jake[24:]+test+oren[5:10]")
        {'oren': [(None, 50), (5, 10)], 'jake': [(24, None)], 'test': [(None, None)]}
    """
    result_dict = {}

    for source in input_str.split("+"):
        sliced = re.fullmatch(r"(\w+)\[(\d*):(\d*)\]", source)
        if sliced:
            name, start, end = sliced.groups()
            # An empty side of the colon means "unbounded" on that side.
            start = int(start) if start else None
            end = int(end) if end else None
        elif re.fullmatch(r"\w+", source):
            # No slice: the whole source is taken.
            name, start, end = source, None, None
        else:
            raise ValueError(
                f'The input string "{input_str}" is not in the correct format.'
            )
        result_dict.setdefault(name, []).append((start, end))

    return result_dict
×
100

101

102
def slice_stream(stream, start, end):
    """Yields the items of ``stream`` whose index lies in ``[start, end)``.

    This mirrors Python slice semantics ``stream[start:end]`` for any iterable, including
    iterables that cannot be indexed.

    Args:
        stream: Any iterable.
        start (int or None): Index of the first item to yield. None means from the beginning.
        end (int or None): Index one past the last item to yield, counted from the beginning
            of ``stream`` (not from ``start``). None means until the stream is exhausted.

    Raises:
        ValueError: If ``start`` or ``end`` is negative. Because this is a generator, the
            error is raised by ``itertools.islice`` when iteration begins.
    """
    # A single islice applies both bounds relative to the original stream. Applying
    # ``islice(stream, end)`` after first skipping ``start`` items would instead take
    # ``end`` further items, i.e. ``stream[start:start + end]``.
    yield from itertools.islice(stream, start, end)
112

113

114
def slice_streams(input_streams, mapping):
    """Slices multiple input streams according to a mapping and chains the results together.

    Args:
        input_streams (dict): A dictionary where the keys are the names of the input streams
                              and the values are the input streams themselves. Each input
                              stream is iterated afresh for every slice taken from it, so it
                              must be re-iterable (e.g. a list or a Stream). A one-shot iterator
                              would not restart for later slices.
        mapping (dict): A dictionary where the keys are the names of the new streams
                        and the values are dictionaries mapping old stream names
                        to lists of (start, end) tuples representing slices.

    Returns:
        dict: A dictionary where the keys are the names of the new streams and the values are
              ReusableGenerator objects. When iterated, each one yields the requested slices of
              the old streams, chained in mapping order.

    Raises:
        MissingStreamError: When a new stream is iterated (not when this function is called)
            and one of its source streams is not in ``input_streams``.
        ValueError: When a new stream is iterated and one of its slice indices is negative
            (raised by ``itertools.islice`` inside ``slice_stream``). Indices past the end of
            a stream are not an error; the slice simply stops at the end of the stream.

    Example:
        >>> old_streams = {"train": [1, 2, 3, 4, 5, 6, 7, 8, 9], "test": [10, 11, 12, 13, 14]}
        >>> mapping = {"new_train": {"train": [(None, 5), (7, 9)]}, "new_test": {"test": [(2, None)]}}
        >>> {name: list(stream) for name, stream in slice_streams(old_streams, mapping).items()}
        {'new_train': [1, 2, 3, 4, 5, 8, 9], 'new_test': [12, 13, 14]}
    """

    def generator(new_stream, sources):
        # ``new_stream`` is not used in the body; it is accepted because it is supplied
        # through ``gen_kwargs`` below.
        for old_stream, slices in sources.items():
            if old_stream not in input_streams:
                raise MissingStreamError(
                    f"'{old_stream}' is not available in input streams, but need to slice there from"
                )
            old_stream_content = input_streams[old_stream]
            for start, end in slices:
                yield from slice_stream(old_stream_content, start, end)

    return {
        new_stream: ReusableGenerator(
            generator, gen_kwargs={"new_stream": new_stream, "sources": sources}
        )
        for new_stream, sources in mapping.items()
    }
×
155

156

157
def build_stream_routing(mapping):
    """Builds the stream routing table based on the provided mapping.

    The routing table inverts ``mapping``: for every old stream it lists the new streams that
    draw from it, together with their weights. If the weights of an old stream sum to less
    than one, a null stream (``None``) is added that carries the remaining weight;
    ``random_mix_generator`` never yields items routed to it. Sums greater than one are
    neither validated nor normalized here; the weights are returned as given.

    The order of the returned names is significant: it is the order of the choices handed
    to the seeded random draw in ``random_mix_generator``. ``None`` is placed right after the
    first new stream that references the old stream.

    Args:
        mapping (dict): A dictionary mapping each new stream name to a dictionary of
                        ``{old stream name: probability}``.

    Returns:
        dict: A dictionary mapping each old stream name to a tuple ``(new_stream_names, weights)``
              of two parallel lists.

    Example:
        >>> mapping = {
        ...     "my_new_stream": {"my_old_stream1": 0.5, "my_old_stream2": 0.25},
        ...     "my_new_stream2": {"my_old_stream1": 0.5, "my_old_stream2": 0.5},
        ... }
        >>> build_stream_routing(mapping)
        {'my_old_stream1': (['my_new_stream', 'my_new_stream2'], [0.5, 0.5]), 'my_old_stream2': (['my_new_stream', None, 'my_new_stream2'], [0.25, 0.25, 0.5])}
    """
    # Total weight that each old stream distributes across all new streams.
    total_weights = {}
    for old_streams in mapping.values():
        for old_stream, weight in old_streams.items():
            if old_stream in total_weights:
                total_weights[old_stream] += weight
            else:
                total_weights[old_stream] = weight

    routes_by_old_stream = {}
    for new_stream, old_streams in mapping.items():
        for old_stream, weight in old_streams.items():
            routes = routes_by_old_stream.setdefault(old_stream, {})
            routes[new_stream] = weight
            # Give any unassigned probability mass to the null stream.
            if total_weights[old_stream] < 1:
                routes[None] = 1 - total_weights[old_stream]

    return {
        old_stream: (list(routes.keys()), list(routes.values()))
        for old_stream, routes in routes_by_old_stream.items()
    }
1✔
213

214

215
def rename_split(input_streams: Dict[str, "Stream"], mapping: Dict[str, str]):
    """Renames the streams named in ``mapping`` without modifying ``input_streams``.

    Args:
        input_streams (dict): A dictionary containing the input streams, where each key is
                              the name of the stream and the value is an iterable or generator
                              representing the stream.

        mapping (dict): A dictionary mapping old stream names to new stream names.

    Returns:
        dict: A new dictionary. It holds every stream not named in ``mapping`` under its
              original name, and each renamed stream under its new name. If a new name
              coincides with a stream that was not renamed, the renamed stream's value wins.

    Raises:
        ValueError: If a key of ``mapping`` is not the name of a stream in ``input_streams``.
    """
    # Work on a shallow copy so the caller's dict is left untouched, including when a
    # missing name makes this function raise partway through the mapping.
    remaining = dict(input_streams)
    renamed = {}
    for old_name, new_name in mapping.items():
        if old_name not in remaining:
            raise ValueError(
                f"Stream '{old_name}' is not in input_streams '{remaining.keys()}'"
            )
        renamed[new_name] = remaining.pop(old_name)
    return {**remaining, **renamed}
1✔
237

238

239
def random_mix_generator(
    new_stream_name, new_stream_sources, stream_routing, input_streams
):
    """Yields the items of the source streams that are randomly routed to ``new_stream_name``.

    For each old stream feeding ``new_stream_name``, one weighted random draw is made per
    item over the candidate new streams listed in ``stream_routing`` (as built by
    ``build_stream_routing``). Only items whose draw equals ``new_stream_name`` are yielded.
    The random generator is created per old stream with ``sub_seed=old_stream_name``.
    Presumably this makes every new stream that reads the same old stream replay the same
    draws, so each item lands in at most one new stream. Verify this against
    ``new_random_generator``.

    Args:
        new_stream_name: Name of the new stream whose items are yielded.
        new_stream_sources: Iterable of old stream names feeding this new stream.
        stream_routing (dict): Old stream name -> ``(candidate_new_streams, weights)``.
        input_streams (dict): Old stream name -> iterable of items.

    Raises:
        AssertionError: When iteration reaches an old stream name that is not in
            ``input_streams``.
    """
    for old_stream_name in new_stream_sources:
        optional_streams, weights = stream_routing[old_stream_name]
        random_generator = new_random_generator(sub_seed=old_stream_name)
        # An explicit raise rather than ``assert`` keeps this check under ``python -O``.
        # AssertionError is kept so callers still see the same exception type.
        if old_stream_name not in input_streams:
            raise AssertionError(
                f"'{old_stream_name}' split not found.  Possibles options: {input_streams.keys()}"
            )
        for item in input_streams[old_stream_name]:
            choice = random_generator.choices(optional_streams, weights=weights, k=1)[0]
            if choice == new_stream_name:
                yield item
1✔
252

253

254
def random_mix_streams(input_streams, mapping):
    """Creates new streams by randomly routing the items of the input streams.

    Each new stream includes a random selection of items from its old streams, according to
    the probabilities in ``mapping``. A single routing table, built once by
    ``build_stream_routing``, is shared by all new streams. Each item is intended to be
    included in at most one new stream. Items routed to the null stream (when an old
    stream's probabilities sum to less than one) are dropped.

    Args:
        input_streams (dict): A dictionary containing the input streams, where each key is
                              the name of the stream and the value is an iterable representing
                              the stream. Every new stream iterates its source streams on its
                              own, so the values must be re-iterable (e.g. Stream objects or
                              lists). A one-shot generator would be used up by the first new
                              stream that reads it.

        mapping (dict): A dictionary mapping each new stream name to a dictionary of
                        ``{old stream name: probability}``.

    Returns:
        dict: A dictionary mapping each new stream name to a ReusableGenerator that yields the
              items routed to that stream.

    Example:
        >>> input_streams = {
        ...     "my_old_stream1": [1, 2, 3, 4],
        ...     "my_old_stream2": [5, 6, 7, 8],
        ... }
        >>> mapping = {
        ...     "my_new_stream": {"my_old_stream1": 0.6, "my_old_stream2": 0.2},
        ...     "my_new_stream2": {"my_old_stream1": 0.4, "my_old_stream2": 0.8},
        ... }
        >>> new_streams = random_mix_streams(input_streams, mapping)
        >>> for new_stream_name, new_stream in new_streams.items():
        ...     logger.info(f"{new_stream_name}: {list(new_stream)}")
    """
    stream_routing = build_stream_routing(mapping)

    return {
        new_stream_name: ReusableGenerator(
            random_mix_generator,
            gen_kwargs={
                "new_stream_name": new_stream_name,
                "new_stream_sources": new_stream_sources,
                "stream_routing": stream_routing,
                "input_streams": input_streams,
            },
        )
        for new_stream_name, new_stream_sources in mapping.items()
    }
1✔
312

313

314
if __name__ == "__main__":
    # Quick manual check of the mix-string parser when the module is run as a script.
    parsed_mix = parse_random_mix_string("dale[90%]+oren[0.7]+mike")
    logger.info(parsed_mix)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc