• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12809765279

16 Jan 2025 01:19PM UTC coverage: 79.393% (-0.01%) from 79.403%
12809765279

Pull #1518

github

web-flow
Merge cb8776ccb into 5506f9c77
Pull Request #1518: Ensure fusion do not call streams before use

1389 of 1738 branches covered (79.92%)

Branch coverage included in aggregate %.

8767 of 11054 relevant lines covered (79.31%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.24
src/unitxt/fusion.py
1
from abc import abstractmethod
1✔
2
from typing import Dict, Generator, List, Optional, Union
1✔
3

4
from .dataclass import NonPositionalField
1✔
5
from .operator import SourceOperator
1✔
6
from .random_utils import new_random_generator
1✔
7
from .stream import DynamicStream, MultiStream
1✔
8
from .type_utils import isoftype
1✔
9

10

11
class BaseFusion(SourceOperator):
    """BaseFusion operator that combines multiple multistreams into one.

    Subclasses decide how the instances of the subsets are interleaved by
    implementing ``fusion_generator``.

    Args:
        subsets: a dict of named SourceOperator objects (each to yield a MultiStream) or a list thereof,
          each is specified along with its input, so can generate a MultiStream
        include_splits: List of splits to include from each input MultiStream.
                If None, the splits "train", "test" and "validation" are used.
    """

    subsets: Union[List[SourceOperator], Dict[str, SourceOperator]]
    include_splits: Optional[List[str]] = NonPositionalField(default=None)

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused instances of ``split``; implemented by subclasses."""

    def prepare_subsets(self):
        """Build ``self.named_subsets``, a mapping from subset name to its operator.

        Subsets given as a list are keyed by their position (int); subsets given
        as a dict keep their names (str). No subset is called here.

        Raises:
            AssertionError: if ``subsets`` is neither a dict of named
                SourceOperator objects nor a list of them.
        """
        assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
            self.subsets, List[SourceOperator]
        ), "subsets must be a dict of named SourceOperator objects or a list thereof"
        if isinstance(self.subsets, list):
            self.named_subsets = dict(enumerate(self.subsets))
        else:
            self.named_subsets = dict(self.subsets)

    def splits(self) -> List[str]:
        """Return the names of the splits this operator produces.

        Also (re)builds ``self.named_subsets`` as a side effect, so the
        generators have it available.
        """
        self.prepare_subsets()
        if self.include_splits is not None:
            return self.include_splits
        # The split names are fixed up front rather than read from the subsets,
        # so that no subset has to be called just to list the splits.
        return ["train", "test", "validation"]

    def process(
        self,
    ) -> MultiStream:
        """Return a MultiStream holding one DynamicStream per split.

        Each stream is backed by ``fusion_generator`` bound to its split name;
        the subsets themselves are only called from within that generator.
        """
        return MultiStream(
            {
                split: DynamicStream(
                    self.fusion_generator, gen_kwargs={"split": split}
                )
                for split in self.splits()
            }
        )
1✔
58

59

60
class FixedFusion(BaseFusion):
    """Concatenate the subsets split by split, optionally capping how much each subset contributes.

    For a given split, the subsets are visited in their given order and the
    instances of each are emitted one subset after the other. Subsets that do
    not have the split are skipped. Instances coming from a named (dict)
    subset get that name prepended to their "subset" list field.

    Args:
        subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
        include_splits: List of splits (stream_names) to include, over all input multistreams.
            If None, the splits "train", "test" and "validation" are used.
        max_instances_per_subset: Number of instances to take from each input split of each input multistream.
            If None, all instances of each included split are included in the result.

    """

    max_instances_per_subset: Optional[int] = None

    def prepare(self):
        super().prepare()

    # flake8: noqa: C901
    def fusion_generator(self, split) -> Generator:
        """Yield the instances of ``split`` from every subset, one subset after another.

        Raises:
            RuntimeError: if iterating or tagging the instances of a subset
                fails; the original exception is chained as the cause.
        """
        for subset_name, subset in self.named_subsets.items():
            streams = subset()
            if split in streams:
                taken = 0
                try:
                    for item in streams[split]:
                        limit = self.max_instances_per_subset
                        if limit is not None and taken >= limit:
                            break
                        # Only subsets given by name (dict form) leave a trace
                        # on the instance; positional ones are keyed by int.
                        if isinstance(subset_name, str):
                            item.setdefault("subset", []).insert(0, subset_name)
                        taken += 1
                        yield item
                except Exception as err:
                    raise RuntimeError(f"Exception in subset: {subset_name}") from err
×
98

99

100
class WeightedFusion(BaseFusion):
    """Fusion operator that combines multiple MultiStream-s by weighted random sampling.

    For a given split, the next instance is repeatedly drawn from a subset
    chosen at random according to the weights, until every subset is exhausted
    or ``max_total_samples`` instances were produced. Subsets that do not have
    the split are skipped. Instances coming from a named (dict) subset get
    that name prepended to their "subset" list field.

    Args:
        subsets: Dict of named SourceOperator objects (each to yield a MultiStream), or a list thereof
        weights: Dict of named weights for each origin, or a list thereof
        max_total_samples: Total number of instances to return per returned split.
            If None, all instances are returned
    """

    subsets: Union[Dict[str, SourceOperator], List[SourceOperator]] = None
    weights: Union[Dict[str, Union[float, int]], List[Union[int, float]]] = None
    max_total_samples: Optional[int] = None

    def verify(self):
        """Check that subsets and weights are given, well typed, and match each other.

        Raises:
            AssertionError: if any of the checks fails.
        """
        super().verify()
        assert self.subsets is not None, "subsets must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.subsets) == len(
            self.weights
        ), "subsets and weights must have the same length"
        assert isoftype(self.subsets, Dict[str, SourceOperator]) or isoftype(
            self.subsets, List[SourceOperator]
        ), "subsets must be a dict of named SourceOperator objects or a list thereof"
        assert isoftype(self.weights, Dict[str, Union[int, float]]) or isoftype(
            self.weights, List[Union[int, float]]
        ), "weights must be a dict of named numbers or a list of numbers"
        assert isinstance(self.subsets, dict) == isinstance(
            self.weights, dict
        ), "subsets and weights must both be dicts or both be lists"
        if isinstance(self.subsets, dict):
            # Every subset name is looked up in the weights while sampling, so
            # a mismatch is reported here instead of as a KeyError later on.
            assert set(self.subsets) == set(
                self.weights
            ), "subsets and weights must have the same names"

    def prepare(self):
        """Normalize the weights into ``self.named_weights``: name (or position) -> float."""
        super().prepare()
        self.named_weights = (
            {i: float(weight) for i, weight in enumerate(self.weights)}
            if isinstance(self.weights, list)
            else {k: float(v) for (k, v) in self.weights.items()}
        )

    def fusion_generator(self, split) -> Generator:
        """Yield instances of ``split``, each drawn from a randomly chosen subset.

        Raises:
            RuntimeError: if fetching or tagging an instance of a subset
                fails; the original exception is chained as the cause.
        """
        iterators = {}
        for origin_name, origin in self.named_subsets.items():
            multi_stream = origin()
            # Not every subset has every split (the default split list is
            # fixed), so subsets lacking this split simply do not take part.
            if split in multi_stream:
                iterators[origin_name] = iter(multi_stream[split])
        total_examples = 0
        # NOTE(review): the per-split sub_seed presumably makes the sampling
        # reproducible for each split -- confirm against new_random_generator.
        random_generator = new_random_generator(sub_seed="weighted_fusion_" + split)
        while (
            self.max_total_samples is None or total_examples < self.max_total_samples
        ) and len(iterators) > 0:
            population = list(iterators.keys())
            origin_name = random_generator.choices(
                population=population,
                weights=[self.named_weights[name] for name in population],
            )[0]
            iterator = iterators[origin_name]
            try:
                instance = next(iterator)
                # Only subsets given by name (dict form) leave a trace on the
                # instance; positional ones are keyed by int.
                if isinstance(origin_name, str):
                    if "subset" not in instance:
                        instance["subset"] = []
                    instance["subset"].insert(0, origin_name)
                total_examples += 1
                yield instance

            except StopIteration:
                # An exhausted subset leaves the population; sampling goes on
                # among the remaining ones with their weights.
                iterators.pop(origin_name)
            except Exception as e:
                raise RuntimeError(f"Exception in subset: {origin_name}") from e
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc