# src/unitxt/operators.py
"""This section describes unitxt operators.

Operators: Building Blocks of Unitxt Processing Pipelines
==============================================================

Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
These operators are callable entities that receive a MultiStream as input.
The output is a MultiStream, augmented with the operator's manipulations, which are then systematically applied to each instance in the stream when pulled.

Creating Custom Operators
-------------------------------
To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.

General or Specialized Operators
--------------------------------
Some operators are specialized in specific data or specific operations, such as:

- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
- :class:`splitters<unitxt.splitters>` for fixing data splits.
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining, and mixing streams.
- :class:`struct_data_operators<unitxt.struct_data_operators>` for handling structured data.
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
- :class:`string_operators<unitxt.string_operators>` for handling strings.
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling span labeling.
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.

Other specialized operators are used by unitxt internally:

- :class:`templates<unitxt.templates>` for verbalizing data examples.
- :class:`formats<unitxt.formats>` for preparing data for models.

The rest of this section is dedicated to general operators.

General Operators List:
------------------------
"""

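# A minimal sketch of a custom operator (illustrative, not part of the original
# module): an InstanceOperator subclass implements process() to manipulate each
# instance dictionary in the stream. The class and field names are hypothetical.
#
#   class AddGreeting(InstanceOperator):
#       def process(self, instance, stream_name=None):
#           instance["greeting"] = "hello"
#           return instance
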
import operator
import uuid
import warnings
import zipfile
from abc import abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from itertools import zip_longest
from random import Random
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .generator_utils import ReusableGenerator
from .operator import (
    InstanceOperator,
    MultiStream,
    MultiStreamOperator,
    PagedStreamOperator,
    SequentialOperator,
    SideEffectOperator,
    SingleStreamReducer,
    SourceOperator,
    StreamingOperator,
    StreamInitializerOperator,
    StreamOperator,
)
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string
from .type_utils import isoftype
from .utils import (
    LRUCache,
    deep_copy,
    flatten_dict,
    recursive_copy,
    recursive_shallow_copy,
    shallow_copy,
)

settings = get_settings()

class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from a dict of named iterables.

    Example:
        operator = FromIterables()
        ms = operator.process(iterables)

    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        return MultiStream.from_iterables(iterables)


class IterableSource(SourceOperator):
    """Creates a MultiStream from a dict of named iterables.

    It is a callable.

    Args:
        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.

    Example:
        operator = IterableSource(input_dict)
        ms = operator()

    """

    iterables: Dict[str, Iterable]

    def process(self) -> MultiStream:
        return MultiStream.from_iterables(self.iterables)

class MapInstanceValues(InstanceOperator):
    """A class used to map instance values into other values.

    This class is a type of ``InstanceOperator``,
    it maps values of instances in a stream using predefined mappers.

    Args:
        mappers (Dict[str, Dict[str, Any]]):
            The mappers to use for mapping instance values.
            Keys are the names of the fields to undergo mapping, and values are dictionaries
            that define the mapping from old values to new values.
            Note that mapped values are defined by their string representation, so mapped values
            are converted to strings before being looked up in the mappers.
        strict (bool):
            If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool):
            If True, all fields to be mapped should be lists, and the mapping
            is applied to their individual elements.
            If False, mapping is only applied to a field containing a single value.

    Examples:
        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remains intact,
        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being looked
        up in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
        Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
        each such ``"1"`` with ``"hi"`` and ``"2"`` -- with ``"bye"`` in all instances of all streams:
        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
        To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
        Input instance ``{"a":"3", "b": 2}`` will raise an exception per the above call,
        because ``"3"`` is not a key in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
        replaces a list ``[1,2,3,4]`` with the string ``"All"`` and an empty list with the string ``"None"``.

    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
                    )
                if isinstance(value, list) and self.process_every_value:
                    for i, val in enumerate(value):
                        value[i] = self.get_mapped_value(instance, key, mapper, val)
                else:
                    value = self.get_mapped_value(instance, key, mapper, value)
                dict_set(
                    instance,
                    key,
                    value,
                )

        return instance

    def get_mapped_value(self, instance, key, mapper, val):
        val_as_str = str(val)  # make sure the value is a string
        if val_as_str in mapper:
            return recursive_copy(mapper[val_as_str])
        if self.strict:
            raise KeyError(
                f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
            )
        return val

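# Usage sketch for MapInstanceValues (illustrative; mirrors the first docstring example):
#   >>> op = MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})
#   >>> op.process({"a": 1, "b": 2})
#   {'a': 'hi', 'b': 2}
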
class FlattenInstances(InstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
        sep (str): The separator to use when concatenating nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)

class Set(InstanceOperator):
    """Sets specified fields in each instance, in a given stream or all streams (default), with specified values. If the fields exist, their values are updated; if they do not exist, they are added.

    Args:
        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields.

        use_deepcopy (bool): Deep copy the input value to avoid later modifications.

    Examples:
        # Set a value of a list consisting of "positive" and "negative" to field "classes" in each and every instance of all streams
        ``Set(fields={"classes": ["positive", "negative"]})``

        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
        ``Set(fields={"span/start": 0})``

        # In all instances of stream "train" only, set field "classes" to have the value of a list consisting of "positive" and "negative"
        ``Set(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])``

        # Set field "classes" to have the value of a given list, preventing modification of the original list from changing the instance.
        ``Set(fields={"classes": alist}, use_deepcopy=True)``  if alist is now modified, the instances still remain intact.
    """

    fields: Dict[str, object]
    use_query: Optional[bool] = None
    use_deepcopy: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, value in self.fields.items():
            if self.use_deepcopy:
                value = deep_copy(value)
            dict_set(instance, key, value)
        return instance


@deprecation(version="2.0.0", alternative=Set)
class AddFields(Set):
    pass

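# Usage sketch for Set (illustrative; follows the docstring examples, the instance is hypothetical):
#   >>> Set(fields={"span/start": 0}).process({"text": "abc"})
#   {'text': 'abc', 'span': {'start': 0}}
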
class RemoveFields(InstanceOperator):
    """Remove specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance

class SelectFields(InstanceOperator):
    """Keep only specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to keep from each instance.
    """

    fields: List[str]

    def prepare(self):
        super().prepare()
        self.fields.extend(["data_classification_policy", "recipe_metadata"])

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        new_instance = {}
        for selected_field in self.fields:
            new_instance[selected_field] = instance[selected_field]
        return new_instance


class DefaultPlaceHolder:
    pass


default_place_holder = DefaultPlaceHolder()

class InstanceFieldOperator(InstanceOperator):
    """A general stream instance operator that processes the values of a field (or multiple ones).

    Args:
        field (Optional[str]):
            The field to process, if only a single one is passed. Defaults to None
        to_field (Optional[str]):
            Field name to save the result into. If only one field is processed and None is passed, the
            operation happens in-place and its result replaces the value of ``field``. Defaults to None
        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
            Mapping from names of fields to process,
            to names of fields to save the results into. Inner List, if used, should be of length 2.
            A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
            is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field shows both as a key and as a value in different mappings.
            The operator throws an AssertionError in either of these cases. ``field_to_field``
            defaults to None.
        process_every_value (bool):
            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False

    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
    prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .

    """

    field: Optional[str] = None
    to_field: Optional[str] = None
    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
    use_query: Optional[bool] = None
    process_every_value: bool = False
    get_default: Any = None
    not_exist_ok: bool = False
    not_exist_do_nothing: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify_field_definition(self):
        if hasattr(self, "_field_to_field") and self._field_to_field is not None:
            return
        assert (
            (self.field is None) != (self.field_to_field is None)
        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"

        if self.field_to_field is None:
            self._field_to_field = [
                (self.field, self.to_field if self.to_field is not None else self.field)
            ]
        else:
            self._field_to_field = (
                list(self.field_to_field.items())
                if isinstance(self.field_to_field, dict)
                else self.field_to_field
            )
        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
        assert (
            self._field_to_field is not None
        ), f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
        assert (
            len(self._field_to_field) > 0
        ), f"input argument '{self.__class__.__name__}.field_to_field' should convey at least one field to process. Got {self.field_to_field}"
        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
        if self.field_to_field is None:
            return
        # for backward compatibility also allow list of tuples of two strings
        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
            self.field_to_field, List[Tuple[str, str]]
        ):
            for pair in self._field_to_field:
                assert (
                    len(pair) == 2
                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
            # order of field processing is uniquely determined by the input field_to_field when a list
            return
        if isoftype(self.field_to_field, Dict[str, str]):
            if len(self.field_to_field) < 2:
                return
            for ff, tt in self.field_to_field.items():
                for f, t in self.field_to_field.items():
                    if f == ff:
                        continue
                    assert (
                        t != ff
                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
                    assert (
                        tt != t
                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
            return
        raise ValueError(
            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
        )

    @abstractmethod
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        pass

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        self.verify_field_definition()
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
                    instance,
                    from_field,
                    default=default_place_holder,
                    not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
                )
                if old_value is default_place_holder:
                    if self.not_exist_do_nothing:
                        continue
                    old_value = self.get_default
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from instance due to the exception above."
                ) from e
            try:
                if self.process_every_value:
                    new_value = [
                        self.process_instance_value(value, instance)
                        for value in old_value
                    ]
                else:
                    new_value = self.process_instance_value(old_value, instance)
            except Exception as e:
                raise ValueError(
                    f"Failed to process field '{from_field}' from instance due to the exception above."
                ) from e
            dict_set(
                instance,
                to_field,
                new_value,
                not_exist_ok=True,
            )
        return instance

class FieldOperator(InstanceFieldOperator):
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        return self.process_value(value)

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass

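# A minimal sketch of a custom FieldOperator (illustrative, not part of the
# original module): subclasses only need to implement process_value(). The
# Lowercase class below is hypothetical.
#
#   class Lowercase(FieldOperator):
#       def process_value(self, value: Any) -> Any:
#           return str(value).lower()
#
#   >>> Lowercase(field="text").process({"text": "HELLO"})
#   {'text': 'hello'}
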
class MapValues(FieldOperator):
    mapping: Dict[str, str]

    def process_value(self, value: Any) -> Any:
        return self.mapping[str(value)]

class Rename(FieldOperator):
    """Renames fields.

    Moves the value from one field to another -- potentially, if the field name contains a /, from one branch into another.
    The source field is removed; when from_field contains a /, only the relevant part of it is removed.

    Examples:
        Rename(field_to_field={"b": "c"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]

        Rename(field_to_field={"b": "c/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]

        Rename(field_to_field={"b": "b/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]

        Rename(field_to_field={"b/c/e": "b/d"})
        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]

    """

    def process_value(self, value: Any) -> Any:
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        res = super().process(instance=instance, stream_name=stream_name)
        for from_field, to_field in self._field_to_field:
            if (not is_subpath(from_field, to_field)) and (
                not is_subpath(to_field, from_field)
            ):
                dict_delete(res, from_field, remove_empty_ancestors=True)

        return res


@deprecation(version="2.0.0", alternative=Rename)
class RenameFields(Rename):
    pass

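# Usage sketch for Rename (illustrative; follows the first docstring example):
#   >>> Rename(field_to_field={"b": "c"}).process({"a": 1, "b": 2})
#   {'a': 1, 'c': 2}
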
class AddConstant(FieldOperator):
    """Adds a constant, given as argument 'add', to the processed value.

    Args:
        add: the constant to add.
    """

    add: Any

    def process_value(self, value: Any) -> Any:
        return self.add + value

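# Usage sketch for AddConstant (illustrative; field name and values are hypothetical):
#   >>> AddConstant(field="a", add=5).process({"a": 1})
#   {'a': 6}
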
class ShuffleFieldValues(FieldOperator):
    """Shuffles a list of values found in a field."""

    def process_value(self, value: Any) -> Any:
        res = list(value)
        random_generator = new_random_generator(sub_seed=res)
        random_generator.shuffle(res)
        return res


class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """

    separator: str = ","

    def process_value(self, value: Any) -> Any:
        return self.separator.join(str(x) for x in value)

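# Usage sketch for JoinStr (illustrative; field name and values are hypothetical):
#   >>> JoinStr(field="words", separator=" ").process({"words": ["hello", "world"]})
#   {'words': 'hello world'}
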
class Apply(InstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (Callable or str): the function to apply, or its importable name.
        to_field (str): the field to store the result

    Any additional arguments are field names whose values will be passed directly to the function specified.

    Examples:
    Store in field "b" the uppercase string of the value in field "a":
    ``Apply("a", function=str.upper, to_field="b")``

    Dump the json representation of field "t" and store back in the same field:
    ``Apply("t", function=json.dumps, to_field="t")``

    Set the time in a field 'b':
    ``Apply(function=time.time, to_field="b")``

    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        parts = []

        if hasattr(function, "__module__"):
            parts.append(function.__module__)
        if hasattr(function, "__qualname__"):
            parts.append(function.__qualname__)
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        parts = function_str.split(".", 1)
        if len(parts) == 1:
            return __builtins__[parts[0]]

        module_name, function_name = parts
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = __import__(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance

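# Usage sketch for Apply (illustrative; follows the first docstring example):
#   >>> Apply("a", function=str.upper, to_field="b").process({"a": "hi"})
#   {'a': 'hi', 'b': 'HI'}
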
class ListFieldValues(InstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))

        dict_set(instance, self.to_field, values)

        return instance

class ZipFieldValues(InstanceOperator):
    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.

    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
    are zipped, and stored into 'to_field'.

    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.

    """

    fields: List[str]
    to_field: str
    longest: bool = False
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))
        if self.longest:
            zipped = zip_longest(*values)
        else:
            zipped = zip(*values)
        dict_set(instance, self.to_field, list(zipped))
        return instance

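# Usage sketch for ZipFieldValues (illustrative; fields and values are hypothetical):
#   >>> ZipFieldValues(fields=["a", "b"], to_field="z").process({"a": [1, 2], "b": [3, 4]})
#   {'a': [1, 2], 'b': [3, 4], 'z': [(1, 3), (2, 4)]}
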
class InterleaveListsToDialogOperator(InstanceOperator):
    """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".

    The list of tuples is of the format (role, turn_content), where the role label is specified by
    the 'user_role_label' and 'assistant_role_label' fields (defaulting to "user" and "assistant").

    The user turns and assistant turns fields are specified in the arguments.
    The value of each of the 'fields' is assumed to be a list.

    """

    user_turns_field: str
    assistant_turns_field: str
    user_role_label: str = "user"
    assistant_role_label: str = "assistant"
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        user_turns = instance[self.user_turns_field]
        assistant_turns = instance[self.assistant_turns_field]

        assert (
            len(user_turns) == len(assistant_turns)
            or (len(user_turns) - len(assistant_turns) == 1)
        ), "user_turns must have either the same length as assistant_turns or one more turn."

        interleaved_dialog = []
        i, j = 0, 0  # Indices for the user and assistant lists
        # While either list has elements left, continue interleaving
        while i < len(user_turns) or j < len(assistant_turns):
            if i < len(user_turns):
                interleaved_dialog.append((self.user_role_label, user_turns[i]))
                i += 1
            if j < len(assistant_turns):
                interleaved_dialog.append(
                    (self.assistant_role_label, assistant_turns[j])
                )
                j += 1

        instance[self.to_field] = interleaved_dialog
        return instance

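# Usage sketch for InterleaveListsToDialogOperator (illustrative; field names and turns are hypothetical):
#   >>> op = InterleaveListsToDialogOperator(
#   ...     user_turns_field="user", assistant_turns_field="assistant", to_field="dialog"
#   ... )
#   >>> op.process({"user": ["hi", "thanks"], "assistant": ["hello"]})["dialog"]
#   [('user', 'hi'), ('assistant', 'hello'), ('user', 'thanks')]
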
class IndexOf(InstanceOperator):
    """For a given instance, finds the offset of the value of field 'index_of' within the value of field 'search_in'."""

    search_in: str
    index_of: str
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in)
        item = dict_get(instance, self.index_of)
        instance[self.to_field] = lst.index(item)
        return instance

class TakeByField(InstanceOperator):
    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

    field: str
    index: str
    to_field: str = None
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def prepare(self):
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field)
        index_value = dict_get(instance, self.index)
        instance[self.to_field] = value[index_value]
        return instance

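# Usage sketch for TakeByField (illustrative; fields and values are hypothetical):
#   >>> TakeByField(field="options", index="i", to_field="choice").process({"options": ["x", "y"], "i": 1})
#   {'options': ['x', 'y'], 'i': 1, 'choice': 'y'}
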
class Perturb(FieldOperator):
    """Slightly perturbs the contents of ``field``. Can be handy for imitating a prediction from a given target.

    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
    relevant perturbation.

    Args:
        percentage_to_perturb (int):
            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent)
        select_from (List[Any]):
            a list of values to select from, as a perturbation of the field's value. Defaults to [].
    """

    select_from: List[Any] = []
    percentage_to_perturb: int = 1  # 1 percent

    def verify(self):
        assert (
            0 <= self.percentage_to_perturb and self.percentage_to_perturb <= 100
        ), f"'percentage_to_perturb' should be in the range 0..100. Received {self.percentage_to_perturb}"

    def prepare(self):
        super().prepare()
        self.random_generator = new_random_generator(sub_seed="CopyWithPerturbation")

    def process_value(self, value: Any) -> Any:
        perturb = self.random_generator.randint(1, 100) <= self.percentage_to_perturb
        if not perturb:
            return value

        if value in self.select_from:
            # in 80% of cases, return a class from select_from; otherwise, perturb the value itself as follows
            if self.random_generator.random() < 0.8:
                return self.random_generator.choice(self.select_from)

        if isinstance(value, float):
            return value * (0.5 + self.random_generator.random())

        if isinstance(value, int):
            perturb = 1 if self.random_generator.random() < 0.5 else -1
            return value + perturb

        if isinstance(value, str):
            if len(value) < 2:
                # give up perturbation
                return value
            # throw one char out
            prefix_len = self.random_generator.randint(1, len(value) - 1)
            return value[:prefix_len] + value[prefix_len + 1 :]

        # and in any other case:
        return value

class Copy(FieldOperator):
    """Copies values from specified fields to specified fields.

    Args (of parent class):
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.

    Examples:
        An input instance {"a": 2, "b": 3}, when processed by
        ``Copy(field_to_field={"a": "b"})``
        would yield {"a": 2, "b": 2}, and when processed by
        ``Copy(field_to_field={"a": "c"})`` would yield
        {"a": 2, "b": 3, "c": 2}

        with field names containing / , we can also copy inside the field:
        ``Copy(field="a/0", to_field="a")``
        would process instance {"a": [1, 3]} into {"a": 1}

    """

    def process_value(self, value: Any) -> Any:
        return value


class RecursiveCopy(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return recursive_copy(value)


@deprecation(version="2.0.0", alternative=Copy)
class CopyFields(Copy):
    pass

class GetItemByIndex(FieldOperator):
    """Gets an item from 'items_list', selected by the index that is the processed field's value."""

    items_list: List[Any]

    def process_value(self, value: Any) -> Any:
        return self.items_list[value]


class AddID(InstanceOperator):
    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""

    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance

class Cast(FieldOperator):
    """Casts the processed field to a specified type.

    Args:
        to (str): the name of the type to cast to: "int", "float", "str", or "bool".
        failure_default (object): a default value to use in case of casting failure. If not set, a casting failure raises an error.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def process_value(self, value):
        try:
            return self.types[self.to](value)
        except ValueError as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default

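# Usage sketch for Cast (illustrative; field name and values are hypothetical):
#   >>> Cast(field="n", to="int").process({"n": "7"})
#   {'n': 7}
#   >>> Cast(field="n", to="int", failure_default=0).process({"n": "abc"})
#   {'n': 0}
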
class CastFields(InstanceOperator):
    """Casts specified fields to specified types.

    Args:
        fields (Dict[str, str]):
            A dictionary mapping field names to the names of the types to cast the fields to,
            e.g. "int", "str", "float", "bool" -- basic names of types.
        failure_defaults (Dict[str, object]):
            A dictionary mapping field names to default values for cases of casting failure.
        process_every_value (bool):
            If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.

    Example:
        .. code-block:: python

                CastFields(
                    fields={"a/d": "float", "b": "int"},
                    failure_defaults={"a/d": 0.0, "b": 0},
                    process_every_value=True,
                )

    would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
    into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.

    """

    fields: Dict[str, str] = field(default_factory=dict)
    failure_defaults: Dict[str, object] = field(default_factory=dict)
    use_nested_query: bool = None  # deprecated field
    process_every_value: bool = False

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def verify(self):
        super().verify()
        if self.use_nested_query is not None:
            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        return [self._cast_single(value, type, field) for value in values]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name)
            if self.process_every_value:
                assert isinstance(
                    value, list
                ), f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)

            dict_set(instance, field_name, casted_value)
        return instance

class DivideAllFieldsBy(InstanceOperator):
    """Recursively reach down to all fields that are float, and divide each by 'divisor'.

    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
    the leaves are either 'float' and then divided, or of another basic type, in which case a ValueError is raised
    if input flag 'strict' is True, or -- left alone, if 'strict' is False.

    Args:
        divisor (float): the value to divide by
        strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.

    Example:
        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
        operator = DivideAllFieldsBy(divisor=2.0)
        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
        If the operator were defined with strict=True, through:
        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
        the processing of the above instance would raise a ValueError, for the integer at "c".
    """

    divisor: float = 1.0
    strict: bool = False

    def _recursive_divide(self, instance, divisor):
        if isinstance(instance, dict):
            for key, value in instance.items():
                instance[key] = self._recursive_divide(value, divisor)
        elif isinstance(instance, list):
            for i, value in enumerate(instance):
                instance[i] = self._recursive_divide(value, divisor)
        elif isinstance(instance, float):
            instance /= divisor
        elif self.strict:
            raise ValueError(f"Cannot divide instance of type {type(instance)}")
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return self._recursive_divide(instance, self.divisor)

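# Usage sketch for DivideAllFieldsBy (illustrative; taken from the Example in the docstring above):
#   >>> DivideAllFieldsBy(divisor=2.0).process({"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5})
#   {'a': 5.0, 'b': [1.0, 2.0, 3.5], 'c': 5}
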
class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Args:
        _artifacts_cache (LRUCache): A cache for storing fetched artifacts.
    """

    _artifacts_cache = LRUCache(max_size=1000)

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        if str(artifact_identifier) not in cls._artifacts_cache:
            artifact, catalog = fetch_artifact(artifact_identifier)
            cls._artifacts_cache[str(artifact_identifier)] = artifact
        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])

class ApplyOperatorsField(InstanceOperator):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        operators_field (str): name of the field that contains a single name, or a list of names, of the operators to be applied,
            one after the other, for the processing of the instance. Each operator is equipped with a 'process_instance()'
            method.

        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"prediction": 111, "references": [222, 333], "c": ["processors.to_string", "processors.first_character"]}
        is processed by operator (please look up these operators in the catalog; they are tuned to process fields "prediction" and
        "references"):
        operator = ApplyOperatorsField(operators_field="c"),
        the resulting instance is: {"prediction": "1", "references": ["2", "3"], "c": ["processors.to_string", "processors.first_character"]}

    """

    operators_field: str
    default_operators: List[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]
        # otherwise, operator_names is already a list

        # we now have a list of names of operators, each equipped with a process_instance method.
        operator = SequentialOperator(steps=operator_names)
        return operator.process_instance(instance, stream_name=stream_name)

class FilterByCondition(StreamOperator):
    """Filters a stream, yielding only instances in which the values in the required fields follow the required condition operator.

    Raises an error if a required field name is missing from the input instance.

    Args:
       values (Dict[str, Any]): Field names and respective values that instances must match, according to the condition, to be included in the output.

       condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in").

       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
       | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose 3rd element is ``<= 4``


    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be a list but is not: '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance."
                ) from ve
            if self.condition == "in":
                if instance_key not in value:
                    return False
            elif self.condition == "not in":
                if instance_key in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, value):
                    return False
        return True

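# Usage sketch for FilterByCondition (illustrative; the stream contents are
# hypothetical). In a pipeline the operator is applied to a MultiStream; here,
# for illustration only, process() is fed a plain iterator of instances:
#   >>> op = FilterByCondition(values={"a": 4}, condition="gt")
#   >>> list(op.process(iter([{"a": 5}, {"a": 3}, {"a": 7}])))
#   [{'a': 5}, {'a': 7}]
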
class FilterByConditionBasedOnFields(FilterByCondition):
    """Filters a stream based on a condition between the values of 2 fields.

    Raises an error if either of the required field names is missing from the input instance.

    Args:
       values (Dict[str, str]): The field names that the filter operation is based on.
       condition: the name of the desired condition operator between the specified fields' values. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in").
       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "gt") will yield only instances where field "a" contains a value greater than the value in field "b".
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "le") will yield only instances where "a"<="b"
    """

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance"
                ) from ve
            try:
                instance_value = dict_get(instance, value)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{value}') in FilterByCondition is not found in instance"
                ) from ve
            if self.condition == "in":
                if instance_key not in instance_value:
                    return False
            elif self.condition == "not in":
                if instance_key in instance_value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, instance_value):
                    return False
        return True

1248
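# Illustrative sketch (comments only, not executed): comparing two fields of the
# same instance; the field names "prediction" and "reference" are hypothetical.
#
#   op = FilterByConditionBasedOnFields(values={"prediction": "reference"}, condition="eq")
#   ms = op(ms)  # keeps instances where instance["prediction"] == instance["reference"]

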
class ComputeExpressionMixin(Artifact):
    """Computes an expression defined over fields of an instance.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # cannot do the imports here, because the object does not pickle with imports
        self.globals = {
            module_name: __import__(module_name) for module_name in self.imports_list
        }

    def compute_expression(self, instance: dict) -> Any:
        if settings.allow_unverified_code:
            return eval(self.expression, {**self.globals, **instance})

        raise ValueError(
            f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set the {settings.allow_unverified_code_key} environment variable."
            "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
        )


class FilterByExpression(StreamOperator, ComputeExpressionMixin):
    """Filters a stream, yielding only instances which fulfil a condition specified as a string to be evaluated by Python's eval().

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        expression (str):
            a condition over fields of the instance, to be processed by Python's eval()
        imports_list (List[str]):
            names of imports needed for the eval of the query (e.g. 're', 'json')
        error_on_filtered_all (bool, optional):
            If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
    """

    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self.compute_expression(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended, set error_on_filtered_all=False."
            )


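# Illustrative sketch (comments only, not executed): the fields "text" and
# "label" are hypothetical, and evaluation requires
# unitxt.settings.allow_unverified_code=True (see ComputeExpressionMixin above).
#
#   op = FilterByExpression(expression="len(text) > 100 and label != 'skip'")
#   ms = op(ms)  # "text" and "label" are resolved as fields of each instance

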
class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
    """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field.

    Raises an error if a field mentioned in the query is missing from the instance.

    Args:
       expression (str): an expression to be evaluated over the fields of the instance
       to_field (str): the field where the result is to be stored into
       imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')

    Examples:
       When instance {"a": 2, "b": 3} is processed by operator
       ExecuteExpression(expression="a+b", to_field="c")
       the result is {"a": 2, "b": 3, "c": 5}

       When instance {"a": "hello", "b": "world"} is processed by operator
       ExecuteExpression(expression="a+' '+b", to_field="c")
       the result is {"a": "hello", "b": "world", "c": "hello world"}

    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = self.compute_expression(instance)
        return instance


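# Illustrative sketch (comments only, not executed): deriving a new hypothetical
# field "clean_text" from a hypothetical field "text", using an import.
#
#   op = ExecuteExpression(
#       expression="re.sub(r'\s+', ' ', text).strip()",
#       to_field="clean_text",
#       imports_list=["re"],
#   )
#   ms = op(ms)  # each instance gains a "clean_text" field

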
class ExtractMostCommonFieldValues(MultiStreamOperator):
    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    """
    Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them
    as a list in a new field ('to_field') in all streams.

    More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, trim the list from the bottom, so that the total frequency of
    the remaining values makes up 'overall_top_frequency_percent' of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes up
    less than 'min_frequency_percent' of the total number of instances in the stream.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value.

    Examples:

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
    every instance in all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
    in case field 'labels' contains a list of values (and not a single value) - tracks the occurrences of all the possible
    value members in these lists, and reports the most frequent values.
    If process_every_value=False, tracks the most frequent whole lists, and reports those (as a list of lists) in field
    'to_field' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
    and stores them in field 'classes' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
    extracts all possible values of field 'label' that each cover at least 5% of the instances.
    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    def verify(self):
        assert (
            self.overall_top_frequency_percent <= 100
            and self.overall_top_frequency_percent >= 0
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        counter = Counter()
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
                counter.update(
                    [(*instance[self.field],)]
                    if isinstance(instance[self.field], list)
                    else [instance[self.field]]
                )  # convert a list to a tuple, to enable the use of Counter, which would not
                # accept a list as a hashable entity whose occurrences it counts
            else:
                # the content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                counter.update(instance[self.field])
        # here counter counts occurrences of individual values, or tuples.
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = (
                sum(counter.values()) * self.overall_top_frequency_percent / 100.0
            )
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * sum(counter.values()) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = Set(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)


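# Illustrative sketch (comments only, not executed): collecting the frequent
# values of a hypothetical "label" field from the "train" stream.
#
#   op = ExtractMostCommonFieldValues(
#       stream_name="train", field="label", to_field="classes", min_frequency_percent=5
#   )
#   ms = op(ms)  # every instance in every stream now carries these labels in "classes"

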
class ExtractFieldValues(ExtractMostCommonFieldValues):
    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0


class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list): the list to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]


class RemoveValues(FieldOperator):
    """Removes elements in a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list): the values to be removed.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]


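# Illustrative sketch (comments only, not executed): cleaning a hypothetical
# list-valued field "labels" with Intersect and RemoveValues (both are
# FieldOperators, so the field to act on is passed via 'field'):
#
#   keep = Intersect(field="labels", allowed_values=["pos", "neg"])
#   drop = RemoveValues(field="labels", unallowed_values=["unknown"])
#   ms = drop(keep(ms))

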
class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> Stream:
        seen = set()
        for instance in stream:
            values = self.to_tuple(instance, self.fields)
            if values not in seen:
                seen.add(values)
        return list(seen)


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        uniques = Unique(fields=self.fields)(multi_stream)

        result = {}

        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)


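# Illustrative sketch (comments only, not executed): splitting by a hypothetical
# "topic" field. Each resulting stream is named by the original stream name plus
# the value, e.g. a "train" stream with topics "math" and "bio" yields streams
# "train_math" and "train_bio".
#
#   ms = SplitByValue(fields=["topic"])(ms)

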
class SplitByNestedGroup(MultiStreamOperator):
    """Splits a MultiStream that is small (used for metrics, hence the whole stream can sit in memory) by the value of field 'group'.

    Args:
        number_of_fusion_generations: int

    The value in field 'group' is of the form "sourcen/sourcenminus1/..." describing the sources in which the instance sat
    when these were fused, potentially over several phases of fusion. The name of the most recent source sits first in this value.
    (See BaseFusion and its extensions.)
    number_of_fusion_generations specifies the length of the prefix by which to split the stream.
    E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream affects the splitting.
    For number_of_fusion_generations = -1, take the whole history written in this field, ignoring the number of generations.
    """

    field_name_of_group: str = "group"
    number_of_fusion_generations: int = 1

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)

        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.field_name_of_group not in instance:
                    raise ValueError(
                        f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
                    )
                signature = (
                    stream_name
                    + "~"  # a sign that does not show within group values
                    + (
                        "/".join(
                            instance[self.field_name_of_group].split("/")[
                                : self.number_of_fusion_generations
                            ]
                        )
                        if self.number_of_fusion_generations >= 0
                        # for values with a smaller number of generations - take up to their last generation
                        else instance[self.field_name_of_group]
                        # for each instance - take all its generations
                    )
                )
                result[signature].append(instance)

        return MultiStream.from_iterables(result)


class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):
            operators = [operators]

        if self.reversed:
            operators = list(reversed(operators))

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({stream_name: stream}))[stream_name]

        yield from stream


def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
    for instance, score in zip(stream, scores):
        instance["score"] = recursive_copy(score)
        yield instance


class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricsList

        # to be populated only when there are two or more metrics
        accumulated_scores = []

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        metrics_list = []
        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            if isinstance(metric, MetricsList):
                metrics_list.extend(list(metric.items))
            elif isinstance(metric, Metric):
                metrics_list.append(metric)
            else:
                raise ValueError(
                    f"Operator {metric_name} must be a Metric or MetricsList"
                )

        for metric in metrics_list:
            if not self.calc_confidence_intervals:
                metric.disable_confidence_interval_calculation()
        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metrics_list = list(reversed(metrics_list))

        for i, metric in enumerate(metrics_list):
            if i == 0:  # first metric
                multi_stream = MultiStream({"tmp": stream})
            else:  # subsequent metrics, carrying the previously accumulated scores
                reusable_generator = ReusableGenerator(
                    generator=update_scores_of_stream_instances,
                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
                )
                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})

            multi_stream = metric(multi_stream)

            if i < len(metrics_list) - 1:  # not the last metric
                accumulated_scores = []
                for inst in multi_stream["tmp"]:
                    accumulated_scores.append(recursive_copy(inst["score"]))

        yield from multi_stream["tmp"]


class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream) -> Generator:
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: DynamicStream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )


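# Illustrative sketch (comments only, not executed): folding all splits into one
# stream named "all", with each instance recording its source split in "origin".
#
#   merge = MergeStreams(new_stream_name="all", add_origin_stream_name=True)
#   ms = merge(ms)

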
class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page


class FeatureGroupedShuffle(Shuffle):
    """Class for shuffling an input dataset by instance 'blocks', not on the individual instance level.

    An example is a dataset consisting of questions along with their paraphrases, where each question falls into a topic.
    All paraphrases have the same ID value as the original.
    In this case, we may want to shuffle on grouping_features = ['question ID'],
    to keep the paraphrases and original question together.
    We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
    In this case, grouping_features = ['question ID', 'topic']

    Args:
        grouping_features (list of strings): list of feature names to use to define the groups.
            a group is defined by each unique observed combination of data values for features in grouping_features
        shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
            Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
            separately by page (determined by page_size). If a block of instances in the same group is split
            into separate pages (either by a page break falling in the group, or because the dataset was not sorted by
            grouping_features), these instances will be shuffled separately and thus the grouping may be
            broken up by pages. If the user wants to ensure the shuffle does the grouping and shuffling
            across all pages, set the page_size to be larger than the dataset size.
            See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
    """

    grouping_features: List[str] = None
    shuffle_within_group: bool = False

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        if self.grouping_features is None:
            yield from super().process(page, stream_name)
        else:
            yield from self.shuffle_by_grouping_features(page)

    def shuffle_by_grouping_features(self, page):
        import itertools
        from collections import defaultdict

        groups_to_instances = defaultdict(list)
        for item in page:
            groups_to_instances[
                tuple(item[ff] for ff in self.grouping_features)
            ].append(item)
        # now extract the groups (i.e., lists of dicts with order preserved)
        page_blocks = list(groups_to_instances.values())
        # and now shuffle the blocks
        self.random_generator.shuffle(page_blocks)
        if self.shuffle_within_group:
            blocks = []
            # reshuffle the instances within each block, but keep the blocks in order
            for block in page_blocks:
                self.random_generator.shuffle(block)
                blocks.append(block)
            page_blocks = blocks

        # now flatten the list so it consists of individual dicts, but in (randomized) block order
        return list(itertools.chain(*page_blocks))


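# Illustrative sketch (comments only, not executed): shuffling blocks of
# paraphrases while keeping each block together; "question ID" is a hypothetical
# grouping field, and page_size is enlarged so grouping spans the whole dataset.
#
#   op = FeatureGroupedShuffle(grouping_features=["question ID"], page_size=10_000)
#   ms = op(ms)

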
class EncodeLabels(InstanceOperator):
    """Encode each value encountered in any field in 'fields' into the integers 0, 1, ...

    Encoding is determined by a str->int map that is built on the fly, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example:
        applying ``EncodeLabels(fields = ["a", "b/*"])``
        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
        {"a": "blue", "b": ["green"], "c":"water"}]``   will yield the
        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``

        Note: dict_utils are applied here, and hence, fields that are lists should be included in
        input 'fields' with the suffix ``"/*"``  as in the above example.

    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name)
            values_was_a_list = isinstance(values, list)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            if not values_was_a_list:
                new_values = new_values[0]
            dict_set(
                instance,
                field_name,
                new_values,
                not_exist_ok=False,  # the values to encode were just taken from there
                set_multiple="*" in field_name
                and isinstance(new_values, list)
                and len(new_values) > 0,
            )

        return instance


class StreamRefiner(StreamOperator):
    """Discard from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int)
        apply_to_streams (optional, list(str)):
            names of streams to refine.

    Examples:
        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
        ``StreamRefiner(max_instances=4)``
        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream


class Deduplicate(StreamOperator):
    """Deduplicate the stream based on the given fields.

    Args:
        by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness.

    Examples:
        >>> dedup = Deduplicate(by=["field1", "field2"])
    """

    by: List[str]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        seen = set()

        for instance in stream:
            # Compute a lightweight hash for the signature
            signature = hash(str(tuple(dict_get(instance, field) for field in self.by)))

            if signature not in seen:
                seen.add(signature)
                yield instance


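# Illustrative sketch (comments only, not executed): dropping repeated pairs of
# the hypothetical fields "question" and "answer"; only the first instance
# carrying each signature survives.
#
#   dedup = Deduplicate(by=["question", "answer"])
#   ms = dedup(ms)

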
class Balance(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
    By discarding instances from the input stream, the balancer maintains an equal number of instances for all signatures.
    When input 'max_instances' is also specified, the balancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is as small as possible.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            overall max.

    Usage:
        ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
        ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
        ``DeterministicBalancer(fields=["a"])``
        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = Counter()

        for instance in stream:
            sign = self.signature(instance)
            if counter[sign] < max_total_instances_per_sign:
                counter[sign] += 1
                yield instance


class DeterministicBalancer(Balance):
    pass


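# Illustrative sketch (comments only, not executed): balancing a hypothetical
# "label" field to equal per-label counts, capped at 1000 instances overall.
#
#   balancer = DeterministicBalancer(fields=["label"], max_instances=1000)
#   ms = balancer(ms)

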
class MinimumOneExamplePerLabelRefiner(StreamRefiner):
    """A class used to return a specified number of instances while ensuring at least one example per label.

    For each instance, a signature value is constructed from the values of the instance in the specified input ``fields``.
    ``MinimumOneExamplePerLabelRefiner`` takes the first instance that appears for each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
    from each class first and then the rest. If max_instances is not set, the original stream is returned unchanged.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            Number of elements to select. Note that max_instances of StreamRefiners
            that are passed to the recipe (e.g. ``train_refiner``, ``test_refiner``) are overridden
            by the recipe parameters ( ``max_train_instances``, ``max_test_instances``)

    Usage:
        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
        | ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
        the resulting stream will be:
        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is None:
            for instance in stream:
                yield instance
            return

        counter = Counter()
        for instance in stream:
            counter[self.signature(instance)] += 1
        all_keys = counter.keys()
        if len(counter) == 0:
            return

        if self.max_instances is not None and len(all_keys) > self.max_instances:
            raise Exception(
                f"Can not generate a stream with at least one example per label, because the requested max instances ({self.max_instances}) is smaller than the number of different labels ({len(all_keys)})"
            )

        counter = Counter()
        used_indices = set()
        selected_elements = []
        # select at least one instance per class
        for idx, instance in enumerate(stream):
            sign = self.signature(instance)
            if counter[sign] == 0:
                counter[sign] += 1
                used_indices.add(idx)
                selected_elements.append(
                    instance
                )  # collect all elements first to allow shuffling of both groups

        # select more to reach self.max_instances examples
        for idx, instance in enumerate(stream):
            if idx not in used_indices:
                if self.max_instances is None or len(used_indices) < self.max_instances:
                    used_indices.add(idx)
                    selected_elements.append(
                        instance
                    )  # collect all elements first to allow shuffling of both groups

        # shuffle elements to avoid having one element from each class appear first
        random_generator = new_random_generator(sub_seed=selected_elements)
        random_generator.shuffle(selected_elements)
        yield from selected_elements


class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]):
            distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the given total length.
            (If none exceeds it -- into one index beyond, namely, the length of segments_boundaries)
        fields (Optional, List[str]):
            the total length of the values of these fields goes through the quantization described above


    Example:
        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
        input instances will be counted and balanced against two categories:
        empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
        return i + 1


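# Illustrative sketch (comments only, not executed): balancing a hypothetical
# "text" field across three length buckets: len < 100, 100 <= len < 500, len >= 500.
#
#   balancer = LengthBalancer(fields=["text"], segments_boundaries=[100, 500])
#   ms = balancer(ms)

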
class DownloadError(Exception):
    def __init__(
        self,
        message,
    ):
        super().__init__(message)


class UnexpectedHttpCodeError(Exception):
    def __init__(self, http_code):
        super().__init__(f"unexpected http code {http_code}")


class DownloadOperator(SideEffectOperator):
    """Operator for downloading a file from a given URL to a specified local path.

    Args:
        source (str):
            URL of the file to be downloaded.
        target (str):
            Local path where the downloaded file should be saved.
    """

    source: str
    target: str

    def process(self):
        try:
            response = requests.get(self.source, allow_redirects=True)
        except Exception as e:
            raise DownloadError(f"Unable to download {self.source}") from e
        if response.status_code != 200:
            raise UnexpectedHttpCodeError(response.status_code)
        with open(self.target, "wb") as f:
            f.write(response.content)


class ExtractZipFile(SideEffectOperator):
    """Operator for extracting files from a zip archive.

    Args:
        zip_file (str):
            Path of the zip file to be extracted.
        target_dir (str):
            Directory where the contents of the zip file will be extracted.
    """

    zip_file: str
    target_dir: str

    def process(self):
        with zipfile.ZipFile(self.zip_file) as zf:
            zf.extractall(self.target_dir)


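# Illustrative sketch (comments only, not executed): the URL and paths are
# placeholders. Both side-effect operators can be invoked directly via process():
#
#   DownloadOperator(source="https://example.com/data.zip", target="/tmp/data.zip").process()
#   ExtractZipFile(zip_file="/tmp/data.zip", target_dir="/tmp/data").process()

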
class DuplicateInstances(StreamOperator):
    """Operator which duplicates each instance in the stream a given number of times.

    Args:
        num_duplications (int):
            How many times each instance should be duplicated (1 means no duplication).
        duplication_index_field (Optional[str]):
            If given, then an additional field with the specified name is added to each duplicated instance,
            containing the index of the given duplication. Defaults to None, so no field is added.
    """

    num_duplications: int
    duplication_index_field: Optional[str] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            for idx in range(self.num_duplications):
                duplicate = recursive_shallow_copy(instance)
                if self.duplication_index_field:
                    duplicate.update({self.duplication_index_field: idx})
                yield duplicate

    def verify(self):
        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
            raise ValueError(
                f"num_duplications must be an integer equal to or greater than 1. "
                f"Got: {self.num_duplications}."
            )

        if self.duplication_index_field is not None and not isinstance(
            self.duplication_index_field, str
        ):
            raise ValueError(
                f"If given, duplication_index_field must be a string. "
                f"Got: {self.duplication_index_field}"
            )


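# Illustrative sketch (comments only, not executed): three copies of each
# instance, tagged 0/1/2 in a hypothetical "dup_idx" field.
#
#   dup = DuplicateInstances(num_duplications=3, duplication_index_field="dup_idx")
#   ms = dup(ms)

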
class CollateInstances(StreamOperator):
    """Operator which collates values from multiple instances into a single instance.

    Each field becomes a list holding the values of that field across a batch of `batch_size` instances.

    Attributes:
        batch_size (int)

    Example:
        .. code-block:: text

            CollateInstances(batch_size=2)

            Given inputs = [
                {"a": 1, "b": 2},
                {"a": 2, "b": 2},
                {"a": 3, "b": 2},
                {"a": 4, "b": 2},
                {"a": 5, "b": 2}
            ]

            Returns targets = [
                {"a": [1,2], "b": [2,2]},
                {"a": [3,4], "b": [2,2]},
                {"a": [5], "b": [2]},
            ]

    """

    batch_size: int

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        stream = list(stream)
        for i in range(0, len(stream), self.batch_size):
            batch = stream[i : i + self.batch_size]
            new_instance = {}
            for a_field in batch[0]:
                if a_field == "data_classification_policy":
                    flattened_list = [
                        classification
                        for instance in batch
                        for classification in instance[a_field]
                    ]
                    new_instance[a_field] = sorted(set(flattened_list))
                else:
                    new_instance[a_field] = [instance[a_field] for instance in batch]
            yield new_instance

    def verify(self):
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(
                f"batch_size must be an integer equal to or greater than 1. "
                f"Got: {self.batch_size}."
            )


class WikipediaFetcher(FieldOperator):
    mode: Literal["summary", "text"] = "text"
    _requirements_list = ["Wikipedia-API"]

    def prepare(self):
        super().prepare()
        import wikipediaapi

        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")

    def process_value(self, value: Any) -> Any:
        title = value.split("/")[-1]
        page = self.wikipedia.page(title)

        return {"title": page.title, "body": getattr(page, self.mode)}