src/unitxt/operators.py
"""This section describes unitxt operators.

Operators: Building Blocks of Unitxt Processing Pipelines
==============================================================

Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
These operators are callable entities that receive a MultiStream as input.
The output is a MultiStream, augmented with the operator's manipulations, which are applied to each instance in the stream when it is pulled.

Creating Custom Operators
-------------------------------
To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.

General or Specialized Operators
--------------------------------
Some operators are specialized in specific data or specific operations, such as:

- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
- :class:`splitters<unitxt.splitters>` for fixing data splits.
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining, and mixing streams.
- :class:`struct_data_operators<unitxt.struct_data_operators>` for handling structured data.
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
- :class:`string_operators<unitxt.string_operators>` for handling strings.
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling span labeling.
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.

Other specialized operators are used by unitxt internally:

- :class:`templates<unitxt.templates>` for verbalizing data examples.
- :class:`formats<unitxt.formats>` for preparing data for models.

The rest of this section is dedicated to general operators.

General Operators List:
------------------------
"""

import operator
import uuid
import warnings
import zipfile
from abc import abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from itertools import zip_longest
from random import Random
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .generator_utils import ReusableGenerator
from .operator import (
    InstanceOperator,
    MultiStream,
    MultiStreamOperator,
    PagedStreamOperator,
    SequentialOperator,
    SideEffectOperator,
    SingleStreamReducer,
    SourceOperator,
    StreamingOperator,
    StreamInitializerOperator,
    StreamOperator,
)
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string
from .type_utils import isoftype
from .utils import (
    LRUCache,
    deep_copy,
    flatten_dict,
    recursive_copy,
    recursive_shallow_copy,
    shallow_copy,
)

settings = get_settings()


class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from a dict of named iterables.

    Example:
        operator = FromIterables()
        ms = operator.process(iterables)

    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        return MultiStream.from_iterables(iterables)


class IterableSource(SourceOperator):
    """Creates a MultiStream from a dict of named iterables.

    It is a callable.

    Args:
        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.

    Example:
        operator = IterableSource(input_dict)
        ms = operator()

    """

    iterables: Dict[str, Iterable]

    def process(self) -> MultiStream:
        return MultiStream.from_iterables(self.iterables)


class MapInstanceValues(InstanceOperator):
    """A class used to map instance values into other values.

    This class is a type of ``InstanceOperator``;
    it maps values of instances in a stream using predefined mappers.

    Args:
        mappers (Dict[str, Dict[str, Any]]):
            The mappers to use for mapping instance values.
            Keys are the names of the fields to undergo mapping, and values are dictionaries
            that define the mapping from old values to new values.
            Note that mapped values are defined by their string representation, so mapped values
            are converted to strings before being looked up in the mappers.
        strict (bool):
            If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool):
            If True, all fields to be mapped should be lists, and the mapping
            is applied to their individual elements.
            If False, mapping is only applied to a field containing a single value.

    Examples:
        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remains intact,
        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being looked
        up in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
        Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
        each such ``"1"`` with ``"hi"`` and each ``"2"`` with ``"bye"`` in all instances of all streams:
        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
        To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
        Input instance ``{"a": "3", "b": 2}`` will raise an exception per the above call,
        because ``"3"`` is not a key in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
        replaces the list ``[1,2,3,4]`` with the string ``"All"`` and an empty list with the string ``"None"``.

    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
                    )
                if isinstance(value, list) and self.process_every_value:
                    for i, val in enumerate(value):
                        value[i] = self.get_mapped_value(instance, key, mapper, val)
                else:
                    value = self.get_mapped_value(instance, key, mapper, value)
                dict_set(
                    instance,
                    key,
                    value,
                )

        return instance

    def get_mapped_value(self, instance, key, mapper, val):
        val_as_str = str(val)  # make sure the value is a string
        if val_as_str in mapper:
            return recursive_copy(mapper[val_as_str])
        if self.strict:
            raise KeyError(
                f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
            )
        return val


class FlattenInstances(InstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
        sep (str): The separator to use when concatenating nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)


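# A minimal usage sketch for FlattenInstances (illustrative only, not part of
# the module): nested dictionary entries become top-level entries joined by `sep`.
#
#     FlattenInstances(sep="_").process({"a": {"b": 1, "c": 2}, "d": 3})
#     # -> {"a_b": 1, "a_c": 2, "d": 3}
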
class Set(InstanceOperator):
    """Sets specified fields in each instance, in a given stream or all streams (default), to specified values. If a field exists, its value is updated; if it does not exist, it is added.

    Args:
        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields.

        use_deepcopy (bool): Deep copy the input value, so that later modifications of it do not change the instance.

    Examples:
        # Set field "classes" to a list consisting of "positive" and "negative", in each and every instance of all streams
        ``Set(fields={"classes": ["positive", "negative"]})``

        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
        ``Set(fields={"span/start": 0})``

        # In all instances of stream "train" only, set field "classes" to a list consisting of "positive" and "negative"
        ``Set(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])``

        # Set field "classes" to the value of a given list, preventing modification of the original list from changing the instance.
        ``Set(fields={"classes": alist}, use_deepcopy=True)``  if alist is later modified, the instances remain intact.
    """

    fields: Dict[str, object]
    use_query: Optional[bool] = None
    use_deepcopy: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, value in self.fields.items():
            if self.use_deepcopy:
                value = deep_copy(value)
            dict_set(instance, key, value)
        return instance


@deprecation(version="2.0.0", alternative=Set)
1✔
286
class AddFields(Set):
1✔
287
    pass
1✔
288

289

290
class RemoveFields(InstanceOperator):
    """Remove specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance


class SelectFields(InstanceOperator):
    """Keep only specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to keep from each instance.
    """

    fields: List[str]

    def prepare(self):
        super().prepare()
        self.fields.extend(["data_classification_policy", "recipe_metadata"])

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        new_instance = {}
        for selected_field in self.fields:
            new_instance[selected_field] = instance[selected_field]
        return new_instance


class DefaultPlaceHolder:
    pass


default_place_holder = DefaultPlaceHolder()


class InstanceFieldOperator(InstanceOperator):
    """A general stream instance operator that processes the values of a field (or multiple ones).

    Args:
        field (Optional[str]):
            The field to process, if only a single one is passed. Defaults to None
        to_field (Optional[str]):
            Field name to save the result into, if only one field is processed. If None is passed, the
            operation happens in-place and its result replaces the value of ``field``. Defaults to None
        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
            Mapping from names of fields to process,
            to names of fields to save the results into. Inner List, if used, should be of length 2.
            A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
            is mapped to the field. When the type of argument ``field_to_field`` is List, the order in which the fields are processed is their order
            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field shows up both as a key and as a value in different mappings.
            The operator throws an AssertionError in either of these cases. ``field_to_field``
            defaults to None.
        process_every_value (bool):
            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False

    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
    prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field``.

    """

    field: Optional[str] = None
    to_field: Optional[str] = None
    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
    use_query: Optional[bool] = None
    process_every_value: bool = False
    get_default: Any = None
    not_exist_ok: bool = False
    not_exist_do_nothing: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify_field_definition(self):
        if hasattr(self, "_field_to_field") and self._field_to_field is not None:
            return
        assert (
            (self.field is None) != (self.field_to_field is None)
        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"

        if self.field_to_field is None:
            self._field_to_field = [
                (self.field, self.to_field if self.to_field is not None else self.field)
            ]
        else:
            self._field_to_field = (
                list(self.field_to_field.items())
                if isinstance(self.field_to_field, dict)
                else self.field_to_field
            )
        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
        assert (
            self._field_to_field is not None
        ), f"the from and to fields must be defined or implied from the other inputs, got: {self._field_to_field}"
        assert (
            len(self._field_to_field) > 0
        ), f"input argument '{self.__class__.__name__}.field_to_field' should convey at least one field to process. Got {self.field_to_field}"
        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
        if self.field_to_field is None:
            return
        # for backward compatibility also allow list of tuples of two strings
        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
            self.field_to_field, List[Tuple[str, str]]
        ):
            for pair in self._field_to_field:
                assert (
                    len(pair) == 2
                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
            # order of field processing is uniquely determined by the input field_to_field when a list
            return
        if isoftype(self.field_to_field, Dict[str, str]):
            if len(self.field_to_field) < 2:
                return
            for ff, tt in self.field_to_field.items():
                for f, t in self.field_to_field.items():
                    if f == ff:
                        continue
                    assert (
                        t != ff
                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
                    assert (
                        tt != t
                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
            return
        raise ValueError(
            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
        )

    @abstractmethod
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        pass

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        self.verify_field_definition()
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
                    instance,
                    from_field,
                    default=default_place_holder,
                    not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
                )
                if old_value is default_place_holder:
                    if self.not_exist_do_nothing:
                        continue
                    old_value = self.get_default
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from instance due to the exception above."
                ) from e
            try:
                if self.process_every_value:
                    new_value = [
                        self.process_instance_value(value, instance)
                        for value in old_value
                    ]
                else:
                    new_value = self.process_instance_value(old_value, instance)
            except Exception as e:
                raise ValueError(
                    f"Failed to process field '{from_field}' from instance due to the exception above."
                ) from e
            dict_set(
                instance,
                to_field,
                new_value,
                not_exist_ok=True,
            )
        return instance


class FieldOperator(InstanceFieldOperator):
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        return self.process_value(value)

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass


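# A minimal sketch of a custom operator (illustrative only, not part of the
# module): as the module docstring suggests, a new operator only needs to
# inherit from an existing base and implement its processing method. The class
# name below is hypothetical.
#
#     class Lowercase(FieldOperator):
#         def process_value(self, value: Any) -> Any:
#             return str(value).lower()
#
#     Lowercase(field="text", to_field="text_lower").process({"text": "HELLO"})
#     # -> {"text": "HELLO", "text_lower": "hello"}
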
class MapValues(FieldOperator):
    mapping: Dict[str, str]

    def process_value(self, value: Any) -> Any:
        return self.mapping[str(value)]


class Rename(FieldOperator):
    """Renames fields.

    Moves the value from one field to another and, if a field name contains a /, potentially from one branch into another.
    Removes the from-field, or, in case of a / in from_field, the relevant part of it.

    Examples:
        Rename(field_to_field={"b": "c"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]

        Rename(field_to_field={"b": "c/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]

        Rename(field_to_field={"b": "b/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]

        Rename(field_to_field={"b/c/e": "b/d"})
        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]

    """

    def process_value(self, value: Any) -> Any:
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        res = super().process(instance=instance, stream_name=stream_name)
        for from_field, to_field in self._field_to_field:
            if (not is_subpath(from_field, to_field)) and (
                not is_subpath(to_field, from_field)
            ):
                dict_delete(res, from_field, remove_empty_ancestors=True)

        return res


@deprecation(version="2.0.0", alternative=Rename)
1✔
543
class RenameFields(Rename):
1✔
544
    pass
1✔
545

546

547
class AddConstant(FieldOperator):
    """Adds a constant, given as argument 'add', to the processed value.

    Args:
        add: the constant to add.
    """

    add: Any

    def process_value(self, value: Any) -> Any:
        return self.add + value


class ShuffleFieldValues(FieldOperator):
    """Shuffles a list of values found in a field."""

    def process_value(self, value: Any) -> Any:
        res = list(value)
        random_generator = new_random_generator(sub_seed=res)
        random_generator.shuffle(res)
        return res


class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """

    separator: str = ","

    def process_value(self, value: Any) -> Any:
        return self.separator.join(str(x) for x in value)


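# A minimal usage sketch for JoinStr (illustrative only, not part of the module):
#
#     JoinStr(field="words", to_field="sentence", separator=" ").process(
#         {"words": ["hello", "world"]}
#     )
#     # -> {"words": ["hello", "world"], "sentence": "hello world"}
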
class Apply(InstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (str): name of the function.
        to_field (str): the field to store the result in

    any additional arguments are field names whose values will be passed directly to the function specified

    Examples:
    Store in field "b" the uppercase string of the value in field "a":
    ``Apply("a", function=str.upper, to_field="b")``

    Dump the json representation of field "t" and store back in the same field:
    ``Apply("t", function=json.dumps, to_field="t")``

    Set the time in a field 'b':
    ``Apply(function=time.time, to_field="b")``

    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        parts = []

        if hasattr(function, "__module__"):
            parts.append(function.__module__)
        if hasattr(function, "__qualname__"):
            parts.append(function.__qualname__)
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        parts = function_str.split(".", 1)
        if len(parts) == 1:
            return __builtins__[parts[0]]

        module_name, function_name = parts
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = __import__(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance


class ListFieldValues(InstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))

        dict_set(instance, self.to_field, values)

        return instance


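# A minimal usage sketch for ListFieldValues (illustrative only, not part of
# the module):
#
#     ListFieldValues(fields=["a", "b"], to_field="ab").process({"a": 1, "b": 2})
#     # -> {"a": 1, "b": 2, "ab": [1, 2]}
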
class ZipFieldValues(InstanceOperator):
    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.

    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
    are zipped, and stored into 'to_field'.

    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.

    """

    fields: List[str]
    to_field: str
    longest: bool = False
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))
        if self.longest:
            zipped = zip_longest(*values)
        else:
            zipped = zip(*values)
        dict_set(instance, self.to_field, list(zipped))
        return instance


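# A minimal usage sketch for ZipFieldValues (illustrative only, not part of the
# module): with longest=True the shorter list is padded with None.
#
#     ZipFieldValues(fields=["a", "b"], to_field="z", longest=True).process(
#         {"a": [1, 2, 3], "b": ["x", "y"]}
#     )
#     # -> field "z" becomes [(1, "x"), (2, "y"), (3, None)]
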
class InterleaveListsToDialogOperator(InstanceOperator):
    """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".

    The list of tuples is of the format (role, turn_content), where the role label is specified by
    the 'user_role_label' and 'assistant_role_label' fields (defaulting to "user" and "assistant").

    The user turns and assistant turns fields are specified in the arguments.
    The value of each of the 'fields' is assumed to be a list.

    """

    user_turns_field: str
    assistant_turns_field: str
    user_role_label: str = "user"
    assistant_role_label: str = "assistant"
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        user_turns = instance[self.user_turns_field]
        assistant_turns = instance[self.assistant_turns_field]

        assert (
            len(user_turns) == len(assistant_turns)
            or (len(user_turns) - len(assistant_turns) == 1)
        ), "user_turns must have either the same length as assistant_turns or one more turn."

        interleaved_dialog = []
        i, j = 0, 0  # Indices for the user and assistant lists
        # While either list has elements left, continue interleaving
        while i < len(user_turns) or j < len(assistant_turns):
            if i < len(user_turns):
                interleaved_dialog.append((self.user_role_label, user_turns[i]))
                i += 1
            if j < len(assistant_turns):
                interleaved_dialog.append(
                    (self.assistant_role_label, assistant_turns[j])
                )
                j += 1

        instance[self.to_field] = interleaved_dialog
        return instance


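# A minimal usage sketch for InterleaveListsToDialogOperator (illustrative
# only, not part of the module): user turns may exceed assistant turns by one.
#
#     InterleaveListsToDialogOperator(
#         user_turns_field="user",
#         assistant_turns_field="assistant",
#         to_field="dialog",
#     ).process({"user": ["hi", "thanks"], "assistant": ["hello"]})
#     # -> field "dialog" becomes
#     #    [("user", "hi"), ("assistant", "hello"), ("user", "thanks")]
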
class IndexOf(InstanceOperator):
    """For a given instance, finds the offset of the value of field 'index_of' within the value of field 'search_in'."""

    search_in: str
    index_of: str
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in)
        item = dict_get(instance, self.index_of)
        instance[self.to_field] = lst.index(item)
        return instance


class TakeByField(InstanceOperator):
    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

    field: str
    index: str
    to_field: str = None
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def prepare(self):
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field)
        index_value = dict_get(instance, self.index)
        instance[self.to_field] = value[index_value]
        return instance


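# A minimal usage sketch for TakeByField (illustrative only, not part of the
# module): the value of field "answer" indexes into the list in field "choices".
#
#     TakeByField(field="choices", index="answer", to_field="answer_text").process(
#         {"choices": ["yes", "no"], "answer": 1}
#     )
#     # -> field "answer_text" becomes "no"
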
class Perturb(FieldOperator):
    """Slightly perturbs the contents of ``field``. Can be handy for imitating a prediction from a given target.

    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
    relevant perturbation.

    Args:
        percentage_to_perturb (int):
            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent)
        select_from (List[Any]):
            a list of values to select from, as a perturbation of the field's value. Defaults to [].
    """

    select_from: List[Any] = []
    percentage_to_perturb: int = 1  # 1 percent

    def verify(self):
        assert (
            0 <= self.percentage_to_perturb and self.percentage_to_perturb <= 100
        ), f"'percentage_to_perturb' should be in the range 0..100. Received {self.percentage_to_perturb}"

    def prepare(self):
        super().prepare()
        self.random_generator = new_random_generator(sub_seed="CopyWithPerturbation")

    def process_value(self, value: Any) -> Any:
        perturb = self.random_generator.randint(1, 100) <= self.percentage_to_perturb
        if not perturb:
            return value

        if value in self.select_from:
            # in 80% of cases, return another decent class; otherwise, perturb the value itself as follows
            if self.random_generator.random() < 0.8:
                return self.random_generator.choice(self.select_from)

        if isinstance(value, float):
            return value * (0.5 + self.random_generator.random())

        if isinstance(value, int):
            perturb = 1 if self.random_generator.random() < 0.5 else -1
            return value + perturb

        if isinstance(value, str):
            if len(value) < 2:
                # give up perturbation
                return value
            # throw one char out
            prefix_len = self.random_generator.randint(1, len(value) - 1)
            return value[:prefix_len] + value[prefix_len + 1 :]

        # and in any other case:
        return value


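# A minimal usage sketch for Perturb (illustrative only, not part of the
# module): with percentage_to_perturb=100 every processed value is perturbed,
# e.g. a class label is usually swapped for another member of select_from,
# an int moves by 1, and a string loses one character.
#
#     Perturb(
#         field="label", select_from=["positive", "negative"], percentage_to_perturb=100
#     ).process({"label": "positive"})
#     # -> field "label" may become "negative"
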
class Copy(FieldOperator):
    """Copies values from specified fields to specified fields.

    Args (of parent class):
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.

    Examples:
        An input instance {"a": 2, "b": 3}, when processed by
        ``Copy(field_to_field={"a": "b"})``
        would yield {"a": 2, "b": 2}, and when processed by
        ``Copy(field_to_field={"a": "c"})`` would yield
        {"a": 2, "b": 3, "c": 2}

        with field names containing / , we can also copy inside the field:
        ``Copy(field="a/0", to_field="a")``
        would process instance {"a": [1, 3]} into {"a": 1}

    """

    def process_value(self, value: Any) -> Any:
        return value


class RecursiveCopy(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return recursive_copy(value)


@deprecation(version="2.0.0", alternative=Copy)
1✔
894
class CopyFields(Copy):
1✔
895
    pass
1✔
896

897

898
class GetItemByIndex(FieldOperator):
    """Gets an item from the given items list, by the index found in the processed field."""

    items_list: List[Any]

    def process_value(self, value: Any) -> Any:
        return self.items_list[value]


class AddID(InstanceOperator):
    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""

    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance


class Cast(FieldOperator):
    """Casts the value of the processed field to the specified type.

    Args:
        to (str): the name of the type to cast to: "int", "float", "str", or "bool".
        failure_default (Optional[Any]): the default value to use in case of a casting failure. If not provided, a casting failure raises a ValueError.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def process_value(self, value):
        try:
            return self.types[self.to](value)
        except ValueError as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default


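# A minimal usage sketch for Cast (illustrative only, not part of the module):
#
#     Cast(field="a", to="int", failure_default=0).process({"a": "7"})
#     # -> {"a": 7}; for {"a": "oops"} the fallback yields {"a": 0}
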
class CastFields(InstanceOperator):
    """Casts specified fields to specified types.

    Args:
        fields (Dict[str, str]):
            A dictionary mapping field names to the names of the types to cast the fields to,
            e.g.: "int", "str", "float", "bool". Basic names of types.
        failure_defaults (Dict[str, object]):
            A dictionary mapping field names to default values for cases of casting failure.
        process_every_value (bool):
            If true, all fields involved must contain lists, and each value in the list is then cast. Defaults to False.

    Example:
        .. code-block:: python

                CastFields(
                    fields={"a/d": "float", "b": "int"},
                    failure_defaults={"a/d": 0.0, "b": 0},
                    process_every_value=True,
                )

    would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
    into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.

    """

    fields: Dict[str, str] = field(default_factory=dict)
    failure_defaults: Dict[str, object] = field(default_factory=dict)
    use_nested_query: bool = None  # deprecated field
    process_every_value: bool = False

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def verify(self):
        super().verify()
        if self.use_nested_query is not None:
            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        return [self._cast_single(value, type, field) for value in values]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name)
            if self.process_every_value:
                assert isinstance(
                    value, list
                ), f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)

            dict_set(instance, field_name, casted_value)
        return instance


class DivideAllFieldsBy(InstanceOperator):
    """Recursively reaches down to all fields that are float, and divides each by 'divisor'.

    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
    whose leaves are either floats, which are then divided, or other basic types, for which a ValueError is raised
    if input flag 'strict' is True, or which are left alone if 'strict' is False.

    Args:
        divisor (float): the value to divide by
        strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.

    Example:
        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
        operator = DivideAllFieldsBy(divisor=2.0)
        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
        If the operator were defined with strict=True, through:
        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
        the processing of the above instance would raise a ValueError, for the integer at "c".
    """

    divisor: float = 1.0
    strict: bool = False

    def _recursive_divide(self, instance, divisor):
        if isinstance(instance, dict):
            for key, value in instance.items():
                instance[key] = self._recursive_divide(value, divisor)
        elif isinstance(instance, list):
            for i, value in enumerate(instance):
                instance[i] = self._recursive_divide(value, divisor)
        elif isinstance(instance, float):
            instance /= divisor
        elif self.strict:
            raise ValueError(f"Cannot divide instance of type {type(instance)}")
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return self._recursive_divide(instance, self.divisor)


class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Args:
        cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
    """

    _artifacts_cache = LRUCache(max_size=1000)

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        if str(artifact_identifier) not in cls._artifacts_cache:
            artifact, catalog = fetch_artifact(artifact_identifier)
            cls._artifacts_cache[str(artifact_identifier)] = artifact
        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])


class ApplyOperatorsField(InstanceOperator):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        operators_field (str): name of the field that contains a single name, or a list of names, of the operators to be applied,
            one after the other, for the processing of the instance. Each operator is equipped with a 'process_instance()'
            method.

        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"prediction": 111, "references": [222, 333], "c": ["processors.to_string", "processors.first_character"]}
        is processed by operator (these operators, to be looked up in the catalog, are tuned to process fields "prediction" and
        "references"):
        operator = ApplyOperatorsField(operators_field="c"),
        the resulting instance is: {"prediction": "1", "references": ["2", "3"], "c": ["processors.to_string", "processors.first_character"]}

    """

    operators_field: str
    default_operators: List[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]
        # otherwise, operator_names is already a list

        # we now have a list of names of operators, each equipped with a process_instance method.
        operator = SequentialOperator(steps=operator_names)
        return operator.process_instance(instance, stream_name=stream_name)


class FilterByCondition(StreamOperator):
    """Filters a stream, yielding only instances in which the values in the required fields follow the required condition operator.

    Raises an error if a required field name is missing from the input instance.

    Args:
       values (Dict[str, Any]): Field names and respective values that instances must match, according to the condition, to be included in the output.

       condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")

       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
       | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
       | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose 3rd element is ``<= 4``

    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be list but is not : '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance."
                ) from ve
            if self.condition == "in":
                if instance_key not in value:
                    return False
            elif self.condition == "not in":
                if instance_key in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, value):
                    return False
        return True


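# A minimal usage sketch for FilterByCondition (illustrative only, not part of
# the module): like every StreamingOperator, it can be called on a MultiStream.
#
#     ms = MultiStream.from_iterables({"train": [{"a": 4}, {"a": 5}, {"a": 8}]})
#     filtered = FilterByCondition(values={"a": [4, 8]}, condition="in")(ms)
#     # -> list(filtered["train"]) yields [{"a": 4}, {"a": 8}]
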
class FilterByConditionBasedOnFields(FilterByCondition):
1✔
1202
    """Filters a stream based on a condition between 2 fields values.
1203

1204
    Raises an error if either of the required fields names is missing from the input instance.
1205

1206
    Args:
1207
       values (Dict[str, str]): The fields names that the filter operation is based on.
1208
       condition: the name of the desired condition operator between the specified field's values.  Supported conditions are  ("gt", "ge", "lt", "le", "ne", "eq", "in","not in")
1209
       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
1210

1211
    Examples:
1212
       FilterByCondition(values = {"a":"b}, condition = "gt") will yield only instances where field "a" contains a value greater then the value in field "b".
1213
       FilterByCondition(values = {"a":"b}, condition = "le") will yield only instances where "a"<="b"
1214
    """
1215

1216
    def _is_required(self, instance: dict) -> bool:
1✔
1217
        for key, value in self.values.items():
1✔
1218
            try:
1✔
1219
                instance_key = dict_get(instance, key)
1✔
1220
            except ValueError as ve:
×
1221
                raise ValueError(
×
1222
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance"
1223
                ) from ve
1224
            try:
1✔
1225
                instance_value = dict_get(instance, value)
1✔
1226
            except ValueError as ve:
×
1227
                raise ValueError(
×
1228
                    f"Required filter field ('{value}') in FilterByCondition is not found in instance"
1229
                ) from ve
1230
            if self.condition == "in":
1✔
1231
                if instance_key not in instance_value:
×
1232
                    return False
×
1233
            elif self.condition == "not in":
1✔
1234
                if instance_key in instance_value:
×
1235
                    return False
×
1236
            else:
1237
                func = self.condition_to_func[self.condition]
1✔
1238
                if func is None:
1✔
1239
                    raise ValueError(
×
1240
                        f"Function not defined for condition '{self.condition}'"
1241
                    )
1242
                if not func(instance_key, instance_value):
1✔
1243
                    return False
×
1244
        return True
1✔
1245

1246

1247
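
# Illustrative usage sketch for FilterByConditionBasedOnFields (not part of the
# module; the field names and values below are hypothetical):
#
#     op = FilterByConditionBasedOnFields(values={"prediction": "reference"}, condition="eq")
#     # {"prediction": "cat", "reference": "cat"} is kept;
#     # {"prediction": "cat", "reference": "dog"} is filtered out.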


class ComputeExpressionMixin(Artifact):
    """Computes an expression expressed over fields of an instance.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # cannot do the imports here, because the object does not pickle with imports
        self.globals = {
            module_name: __import__(module_name) for module_name in self.imports_list
        }

    def compute_expression(self, instance: dict) -> Any:
        if settings.allow_unverified_code:
            return eval(self.expression, {**self.globals, **instance})

        raise ValueError(
            f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set the {settings.allow_unverified_code_key} environment variable."
            "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
        )


class FilterByExpression(StreamOperator, ComputeExpressionMixin):
    """Filters a stream, yielding only instances that fulfill a condition specified as a string to be evaluated by python's eval().

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        expression (str):
            a condition over fields of the instance, to be processed by python's eval()
        imports_list (List[str]):
            names of imports needed for the eval of the expression (e.g. 're', 'json')
        error_on_filtered_all (bool, optional):
            If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
    """

    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self.compute_expression(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended, set error_on_filtered_all=False."
            )
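
# Illustrative usage sketch (not part of the module; the stream contents are
# hypothetical). Evaluation requires unitxt.settings.allow_unverified_code=True:
#
#     op = FilterByExpression(expression="a <= 4 and b > 5", error_on_filtered_all=False)
#     filtered_multi_stream = op(multi_stream)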


class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
    """Computes an expression, specified as a string to be evaluated by python's eval(), over the instance's fields, and stores the result in field 'to_field'.

    Raises an error if a field mentioned in the expression is missing from the instance.

    Args:
       expression (str): an expression to be evaluated over the fields of the instance
       to_field (str): the field into which the result is stored
       imports_list (List[str]): names of imports needed for the eval of the expression (e.g. 're', 'json')

    Examples:
       When instance {"a": 2, "b": 3} is processed by operator
       ExecuteExpression(expression="a+b", to_field="c")
       the result is {"a": 2, "b": 3, "c": 5}

       When instance {"a": "hello", "b": "world"} is processed by operator
       ExecuteExpression(expression="a+' '+b", to_field="c")
       the result is {"a": "hello", "b": "world", "c": "hello world"}

    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = self.compute_expression(instance)
        return instance
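
# Illustrative usage sketch (not part of the module; field names are
# hypothetical). As with FilterByExpression, evaluation requires
# unitxt.settings.allow_unverified_code=True:
#
#     op = ExecuteExpression(expression="len(words)", to_field="num_words")
#     # {"words": ["a", "b"]} becomes {"words": ["a", "b"], "num_words": 2}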


class ExtractMostCommonFieldValues(MultiStreamOperator):
    """Extracts the unique values of a field ('field') of a given stream ('stream_name') and stores (the most frequent of) them as a list in a new field ('to_field') in all streams.

    More specifically, all the unique values encountered in field 'field' are sorted by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, the list is trimmed from the bottom, so that the total
    frequency of the remaining values amounts to 'overall_top_frequency_percent' percent of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, any value whose relative frequency is less than
    'min_frequency_percent' percent of the total number of instances in the stream is removed from the list.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value.

    Examples:

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
    every instance in all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
    in case field 'labels' contains a list of values (and not a single value) - tracks the occurrences of all the possible
    value members in these lists, and reports the most frequent values.
    If process_every_value=False, tracks the most frequent whole lists, and reports those (as a list of lists) in field
    'to_field' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
    and stores them in field 'classes' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
    extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    def verify(self):
        assert (
            self.overall_top_frequency_percent <= 100
            and self.overall_top_frequency_percent >= 0
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        counter = Counter()
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
                counter.update(
                    [(*instance[self.field],)]
                    if isinstance(instance[self.field], list)
                    else [instance[self.field]]
                )  # convert a list to a tuple, to enable the use of Counter, which would not accept
                # a list as a hashable entity whose occurrences it counts
            else:
                # the content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                counter.update(instance[self.field])
        # here counter counts occurrences of individual values, or tuples.
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = (
                sum(counter.values()) * self.overall_top_frequency_percent / 100.0
            )
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * sum(counter.values()) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = Set(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)
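
# Illustrative usage sketch (not part of the module; stream and field names are
# hypothetical). If labels "positive" and "negative" together cover 80% of the
# "train" instances and the remaining labels are rarer:
#
#     op = ExtractMostCommonFieldValues(
#         stream_name="train", field="label", to_field="classes",
#         overall_top_frequency_percent=80,
#     )
#     # every instance in every stream gains {"classes": ["positive", "negative"]}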


class ExtractFieldValues(ExtractMostCommonFieldValues):
    """Extracts all the unique values of field 'field' in stream 'stream_name', and stores them, sorted by decreasing frequency, in field 'to_field' of all streams."""

    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0


class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list) - the list to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]


class RemoveValues(FieldOperator):
    """Removes elements from a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list) - the values to be removed.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]
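
# Illustrative sketch contrasting the two list filters above (not part of the
# module; 'field' is assumed to be FieldOperator's usual argument naming the
# field to operate on):
#
#     Intersect(field="labels", allowed_values=["a", "b"])
#     # {"labels": ["a", "c"]} -> {"labels": ["a"]}
#     RemoveValues(field="labels", unallowed_values=["a", "b"])
#     # {"labels": ["a", "c"]} -> {"labels": ["c"]}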


class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> Stream:
        seen = set()
        for instance in stream:
            values = self.to_tuple(instance, self.fields)
            if values not in seen:
                seen.add(values)
        return list(seen)


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        uniques = Unique(fields=self.fields)(multi_stream)

        result = {}

        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)


class SplitByNestedGroup(MultiStreamOperator):
    """Splits a MultiStream -- small enough to sit in memory, as is the case for metrics -- by the value of field 'group'.

    Args:
        number_of_fusion_generations: int

    The value in field 'group' is of the form "source_n/source_n-1/..." describing the sources in which the instance sat
    when these were fused, potentially over several phases of fusion. The name of the most recent source sits first in this value.
    (See BaseFusion and its extensions.)
    'number_of_fusion_generations' specifies the length of the prefix by which to split the stream.
    E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream affects the splitting.
    For number_of_fusion_generations = -1, take the whole history written in this field, ignoring the number of generations.
    """

    field_name_of_group: str = "group"
    number_of_fusion_generations: int = 1

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)

        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.field_name_of_group not in instance:
                    raise ValueError(
                        f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
                    )
                signature = (
                    stream_name
                    + "~"  # a sign that does not show within group values
                    + (
                        "/".join(
                            instance[self.field_name_of_group].split("/")[
                                : self.number_of_fusion_generations
                            ]
                        )
                        if self.number_of_fusion_generations >= 0
                        # for values with a smaller number of generations - take up to their last generation
                        else instance[self.field_name_of_group]
                        # for each instance - take all its generations
                    )
                )
                result[signature].append(instance)

        return MultiStream.from_iterables(result)


class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):
            operators = [operators]

        if self.reversed:
            operators = list(reversed(operators))

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({stream_name: stream}))[stream_name]

        yield from stream


def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
    for instance, score in zip(stream, scores):
        instance["score"] = recursive_copy(score)
        yield instance


class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricsList

        # to be populated only when there are two or more metrics
        accumulated_scores = []

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        metrics_list = []
        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            if isinstance(metric, MetricsList):
                metrics_list.extend(list(metric.items))
            elif isinstance(metric, Metric):
                metrics_list.append(metric)
            else:
                raise ValueError(
                    f"Operator {metric_name} must be a Metric or MetricsList"
                )

        for metric in metrics_list:
            if not self.calc_confidence_intervals:
                metric.disable_confidence_interval_calculation()
        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metrics_list = list(reversed(metrics_list))

        for i, metric in enumerate(metrics_list):
            if i == 0:  # first metric
                multi_stream = MultiStream({"tmp": stream})
            else:  # metrics with previous scores
                reusable_generator = ReusableGenerator(
                    generator=update_scores_of_stream_instances,
                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
                )
                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})

            multi_stream = metric(multi_stream)

            if i < len(metrics_list) - 1:  # not the last metric: keep scores for the next one
                accumulated_scores = []
                for inst in multi_stream["tmp"]:
                    accumulated_scores.append(recursive_copy(inst["score"]))

        yield from multi_stream["tmp"]


class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        streams_to_merge (List[str]): The names of the streams to merge. If None, all streams are merged.
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream) -> Generator:
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: DynamicStream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )
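
# Illustrative sketch (not part of the module; stream names are hypothetical):
#
#     merged = MergeStreams(streams_to_merge=["train", "test"], new_stream_name="all")(multi_stream)
#     # merged["all"] yields every train/test instance, each tagged with
#     # {"origin": "train"} or {"origin": "test"}.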


class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page


class FeatureGroupedShuffle(Shuffle):
    """Shuffles an input dataset by blocks of instances, not at the individual instance level.

    An example is a dataset consisting of questions and paraphrases of them, where each question falls into a topic.
    All paraphrases have the same ID value as the original.
    In this case, we may want to shuffle on grouping_features = ['question ID'],
    to keep the paraphrases and the original question together.
    We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
    In this case, grouping_features = ['question ID', 'topic'].

    Args:
        grouping_features (list of strings): list of feature names to use to define the groups.
            a group is defined by each unique observed combination of data values for features in grouping_features
        shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
            Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
            separately by page (determined by page_size).  If a block of instances in the same group is split
            into separate pages (either by a page break falling inside the group, or because the dataset was not sorted by
            grouping_features), these instances will be shuffled separately and thus the grouping may be
            broken up by pages.  If the user wants to ensure the shuffle does the grouping and shuffling
            across all pages, set the page_size to be larger than the dataset size.
            See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
    """

    grouping_features: List[str] = None
    shuffle_within_group: bool = False

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        if self.grouping_features is None:
            yield from super().process(page, stream_name)
        else:
            yield from self.shuffle_by_grouping_features(page)

    def shuffle_by_grouping_features(self, page):
        import itertools
        from collections import defaultdict

        groups_to_instances = defaultdict(list)
        for item in page:
            groups_to_instances[
                tuple(item[ff] for ff in self.grouping_features)
            ].append(item)
        # now extract the groups (i.e., lists of dicts with order preserved)
        page_blocks = list(groups_to_instances.values())
        # and now shuffle the blocks
        self.random_generator.shuffle(page_blocks)
        if self.shuffle_within_group:
            blocks = []
            # reshuffle the instances within each block, keeping the (already shuffled) block order
            for block in page_blocks:
                self.random_generator.shuffle(block)
                blocks.append(block)
            page_blocks = blocks

        # now flatten the list so it consists of individual dicts, but in (randomized) block order
        return list(itertools.chain(*page_blocks))
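
# Illustrative sketch (not part of the module; the feature name is
# hypothetical). Instances sharing the same "question ID" stay contiguous;
# only the order of the blocks (and, if shuffle_within_group=True, the order
# inside each block) is randomized:
#
#     op = FeatureGroupedShuffle(grouping_features=["question ID"], shuffle_within_group=False)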


class EncodeLabels(InstanceOperator):
    """Encodes each value encountered in any field in 'fields' into the integers 0, 1, ...

    Encoding is determined by a str->int map that is built on the fly, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example:
        applying ``EncodeLabels(fields = ["a", "b/*"])``
        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
        {"a": "blue", "b": ["green"], "c":"water"}]``   will yield the
        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``

        Note: dict_utils are applied here, and hence fields that are lists should be included in
        input 'fields' with the suffix ``"/*"``, as in the above example.

    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name)
            values_was_a_list = isinstance(values, list)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            if not values_was_a_list:
                new_values = new_values[0]
            dict_set(
                instance,
                field_name,
                new_values,
                not_exist_ok=False,  # the values to encode were just taken from there
                set_multiple="*" in field_name
                and isinstance(new_values, list)
                and len(new_values) > 0,
            )

        return instance


class StreamRefiner(StreamOperator):
    """Discards from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int): the maximum number of instances to keep.
        apply_to_streams (optional, list(str)):
            names of the streams to refine.

    Examples:
        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
        ``StreamRefiner(max_instances=4)``
        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream
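
# Illustrative sketch (not part of the module; stream names are hypothetical,
# and the stream selection via apply_to_streams is assumed to be handled by
# the surrounding operator machinery):
#
#     refiner = StreamRefiner(max_instances=100, apply_to_streams=["train", "test"])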


class DeterministicBalancer(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
    By discarding instances from the input stream, DeterministicBalancer maintains an equal number of instances for all signatures.
    When 'max_instances' is also specified, DeterministicBalancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is as few as possible.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            the overall maximum number of instances to yield.

    Usage:
        ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
        ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
        ``DeterministicBalancer(fields=["a"])``
        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = Counter()

        for instance in stream:
            sign = self.signature(instance)
            if counter[sign] < max_total_instances_per_sign:
                counter[sign] += 1
                yield instance
1✔
1958
    """A class used to return a specified number instances ensuring at least one example  per label.
1959

1960
    For each instance, a signature value is constructed from the values of the instance in specified input ``fields``.
1961
    ``MinimumOneExamplePerLabelRefiner`` takes first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit.  In general, the refiner takes the first elements in the stream that meet the required conditions.
1962
    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
1963
    from each class first and then the rest . If max instance is not set, the original stream will be used
1964

1965
    Args:
1966
        fields (List[str]):
1967
            A list of field names to be used in producing the instance's signature.
1968
        max_instances (Optional, int):
1969
            Number of elements to select. Note that max_instances of StreamRefiners
1970
            that are passed to the recipe (e.g. ``train_refiner``. ``test_refiner``) are overridden
1971
            by the recipe parameters ( ``max_train_instances``, ``max_test_instances``)
1972

1973
    Usage:
1974
        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
1975
        | ``balanced_stream = balancer.process(stream)``
1976

1977
    Example:
1978
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
1979
        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
1980
        the resulting stream will be:
1981
        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
1982
    """
1983

1984
    fields: List[str]
1✔
1985

1986
    def signature(self, instance):
1✔
1987
        return str(tuple(dict_get(instance, field) for field in self.fields))
1✔
1988

1989
    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
1✔
1990
        if self.max_instances is None:
1✔
1991
            for instance in stream:
×
1992
                yield instance
×
1993

1994
        counter = Counter()
1✔
1995
        for instance in stream:
1✔
1996
            counter[self.signature(instance)] += 1
1✔
1997
        all_keys = counter.keys()
1✔
1998
        if len(counter) == 0:
1✔
1999
            return
×
2000

2001
        if self.max_instances is not None and len(all_keys) > self.max_instances:
1✔
2002
            raise Exception(
×
2003
                f"Can not generate a stream with at least one example per label, because the max instances requested  {self.max_instances} is smaller than the number of different labels {len(all_keys)}"
2004
                f" ({len(all_keys)}"
2005
            )
2006

2007
        counter = Counter()
1✔
2008
        used_indices = set()
1✔
2009
        selected_elements = []
1✔
2010
        # select at least one per class
2011
        for idx, instance in enumerate(stream):
1✔
2012
            sign = self.signature(instance)
1✔
2013
            if counter[sign] == 0:
1✔
2014
                counter[sign] += 1
1✔
2015
                used_indices.add(idx)
1✔
2016
                selected_elements.append(
1✔
2017
                    instance
2018
                )  # collect all elements first to allow shuffling of both groups
2019

2020
        # select more to reach self.max_instances examples
2021
        for idx, instance in enumerate(stream):
1✔
2022
            if idx not in used_indices:
1✔
2023
                if self.max_instances is None or len(used_indices) < self.max_instances:
1✔
2024
                    used_indices.add(idx)
1✔
2025
                    selected_elements.append(
1✔
2026
                        instance
2027
                    )  # collect all elements first to allow shuffling of both groups
2028

2029
        # shuffle elements to avoid having one element from each class appear first
2030
        random_generator = new_random_generator(sub_seed=selected_elements)
1✔
2031
        random_generator.shuffle(selected_elements)
1✔
2032
        yield from selected_elements
1✔
2033

2034

2035


class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]):
            distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the given total length.
            (If none exceeds it -- into one index beyond, namely, the length of segments_boundaries.)
        fields (Optional, List[str]):
            the total length of the values of these fields goes through the quantization described above


    Example:
        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
        input instances will be counted and balanced against two categories:
        empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
        return i + 1


class DownloadError(Exception):
    def __init__(
        self,
        message,
    ):
        super().__init__(message)


class UnexpectedHttpCodeError(Exception):
    def __init__(self, http_code):
        super().__init__(f"unexpected http code {http_code}")


class DownloadOperator(SideEffectOperator):
    """Operator for downloading a file from a given URL to a specified local path.

    Args:
        source (str):
            URL of the file to be downloaded.
        target (str):
            Local path where the downloaded file should be saved.
    """

    source: str
    target: str

    def process(self):
        try:
            response = requests.get(self.source, allow_redirects=True)
        except Exception as e:
            raise DownloadError(f"Unable to download {self.source}") from e
        if response.status_code != 200:
            raise UnexpectedHttpCodeError(response.status_code)
        with open(self.target, "wb") as f:
            f.write(response.content)


class ExtractZipFile(SideEffectOperator):
    """Operator for extracting files from a zip archive.

    Args:
        zip_file (str):
            Path of the zip file to be extracted.
        target_dir (str):
            Directory where the contents of the zip file will be extracted.
    """

    zip_file: str
    target_dir: str

    def process(self):
        with zipfile.ZipFile(self.zip_file) as zf:
            zf.extractall(self.target_dir)
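
# Illustrative sketch chaining the two side-effect operators above (not part
# of the module; the URL and paths are hypothetical):
#
#     DownloadOperator(source="https://example.com/data.zip", target="/tmp/data.zip").process()
#     ExtractZipFile(zip_file="/tmp/data.zip", target_dir="/tmp/data").process()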


class DuplicateInstances(StreamOperator):
    """Operator which duplicates each instance in a stream a given number of times.

    Args:
        num_duplications (int):
            How many times each instance should be duplicated (1 means no duplication).
        duplication_index_field (Optional[str]):
            If given, then an additional field with the specified name is added to each duplicated instance,
            containing the index of the duplication. Defaults to None, so no field is added.
    """

    num_duplications: int
    duplication_index_field: Optional[str] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            for idx in range(self.num_duplications):
                duplicate = recursive_shallow_copy(instance)
                if self.duplication_index_field:
                    duplicate.update({self.duplication_index_field: idx})
                yield duplicate

    def verify(self):
        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
            raise ValueError(
                f"num_duplications must be an integer equal to or greater than 1. "
                f"Got: {self.num_duplications}."
            )

        if self.duplication_index_field is not None and not isinstance(
            self.duplication_index_field, str
        ):
            raise ValueError(
                f"If given, duplication_index_field must be a string. "
                f"Got: {self.duplication_index_field}"
            )
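
# Illustrative sketch (not part of the module):
#
#     op = DuplicateInstances(num_duplications=2, duplication_index_field="dup_id")
#     # {"a": 1} becomes {"a": 1, "dup_id": 0} followed by {"a": 1, "dup_id": 1}.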


class CollateInstances(StreamOperator):
    """Operator which collates values from multiple instances into a single instance.

    Each field of the resulting instance holds the list of that field's values across a batch of `batch_size` input instances.

    Attributes:
        batch_size (int)

    Example:
        .. code-block:: text

            CollateInstances(batch_size=2)

            Given inputs = [
                {"a": 1, "b": 2},
                {"a": 2, "b": 2},
                {"a": 3, "b": 2},
                {"a": 4, "b": 2},
                {"a": 5, "b": 2}
            ]

            Returns targets = [
                {"a": [1,2], "b": [2,2]},
                {"a": [3,4], "b": [2,2]},
                {"a": [5], "b": [2]},
            ]


    """

    batch_size: int

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        stream = list(stream)
        for i in range(0, len(stream), self.batch_size):
            batch = stream[i : i + self.batch_size]
            new_instance = {}
            for a_field in batch[0]:
                if a_field == "data_classification_policy":
                    flattened_list = [
                        classification
                        for instance in batch
                        for classification in instance[a_field]
                    ]
                    new_instance[a_field] = sorted(set(flattened_list))
                else:
                    new_instance[a_field] = [instance[a_field] for instance in batch]
            yield new_instance

    def verify(self):
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(
                f"batch_size must be an integer equal to or greater than 1. "
                f"Got: {self.batch_size}."
            )
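
# Illustrative sketch (not part of the module): collating ten instances with
# batch_size=3 yields four instances, the last holding lists of length 1.
# Note the special handling of "data_classification_policy" above, whose
# per-batch values are flattened and deduplicated rather than nested:
#
#     collated_multi_stream = CollateInstances(batch_size=3)(multi_stream)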