src/unitxt/operators.py
"""This section describes unitxt operators.

Operators: Building Blocks of Unitxt Processing Pipelines
==============================================================

Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
These operators are callable entities that receive a MultiStream as input.
The output is a MultiStream, augmented with the operator's manipulations, which are then systematically applied to each instance in the stream when pulled.

Creating Custom Operators
-------------------------------
To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
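
A minimal sketch (illustrative only; the class and field names are hypothetical,
not operators shipped with unitxt):

.. code-block:: python

    from unitxt.operator import InstanceOperator

    class LowercaseField(InstanceOperator):
        field: str

        def process(self, instance, stream_name=None):
            # Replace the value of `field` with its lowercased string form.
            instance[self.field] = str(instance[self.field]).lower()
            return instance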

General or Specialized Operators
--------------------------------
Some operators are specialized in specific data or specific operations such as:

- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
- :class:`splitters<unitxt.splitters>` for fixing data splits.
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining, and mixing streams.
- :class:`struct_data_operators<unitxt.struct_data_operators>` for handling structured data.
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
- :class:`string_operators<unitxt.string_operators>` for handling strings.
- :class:`span_labeling_operators<unitxt.span_lableing_operators>` for handling span labeling.
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.

Other specialized operators are used by unitxt internally:

- :class:`templates<unitxt.templates>` for verbalizing data examples.
- :class:`formats<unitxt.formats>` for preparing data for models.

The rest of this section is dedicated to general operators.

General Operators List:
------------------------
"""

import operator
import uuid
import warnings
import zipfile
from abc import abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from itertools import zip_longest
from random import Random
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .error_utils import UnitxtError
from .generator_utils import ReusableGenerator
from .operator import (
    InstanceOperator,
    MultiStream,
    MultiStreamOperator,
    PagedStreamOperator,
    SequentialOperator,
    SideEffectOperator,
    SingleStreamReducer,
    SourceOperator,
    StreamingOperator,
    StreamInitializerOperator,
    StreamOperator,
)
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string, to_pretty_string
from .type_utils import isoftype
from .utils import (
    LRUCache,
    deep_copy,
    flatten_dict,
    recursive_copy,
    recursive_shallow_copy,
    shallow_copy,
)

settings = get_settings()

class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from a dict of named iterables.

    Example:
        operator = FromIterables()
        ms = operator.process(iterables)

    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        return MultiStream.from_iterables(iterables)


class IterableSource(SourceOperator):
    """Creates a MultiStream from a dict of named iterables.

    It is a callable.

    Args:
        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.

    Example:
        operator = IterableSource(input_dict)
        ms = operator()

    """

    iterables: Dict[str, Iterable]

    def process(self) -> MultiStream:
        return MultiStream.from_iterables(self.iterables)


class MapInstanceValues(InstanceOperator):
    """A class used to map instance values into other values.

    This class is a type of ``InstanceOperator``,
    it maps values of instances in a stream using predefined mappers.

    Args:
        mappers (Dict[str, Dict[str, Any]]):
            The mappers to use for mapping instance values.
            Keys are the names of the fields to undergo mapping, and values are dictionaries
            that define the mapping from old values to new values.
            Note that mapped values are defined by their string representation, so mapped values
            are converted to strings before being looked up in the mappers.
        strict (bool):
            If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool):
            If True, all fields to be mapped should be lists, and the mapping
            is to be applied to their individual elements.
            If False, mapping is only applied to a field containing a single value.

    Examples:
        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remained intact,
        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being looked
        up in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
        Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
        each such ``"1"`` with ``"hi"`` and each ``"2"`` with ``"bye"`` in all instances of all streams:
        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
        To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
        Input instance ``{"a": "3", "b": 2}`` will raise an exception per the above call,
        because ``"3"`` is not a key in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
        replaces a list ``[1,2,3,4]`` with the string ``"All"`` and an empty list with the string ``"None"``.

    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_field' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
                    )
                if isinstance(value, list) and self.process_every_value:
                    for i, val in enumerate(value):
                        value[i] = self.get_mapped_value(instance, key, mapper, val)
                else:
                    value = self.get_mapped_value(instance, key, mapper, value)
                dict_set(
                    instance,
                    key,
                    value,
                )

        return instance

    def get_mapped_value(self, instance, key, mapper, val):
        val_as_str = str(val)  # make sure the value is a string
        if val_as_str in mapper:
            return recursive_copy(mapper[val_as_str])
        if self.strict:
            raise KeyError(
                f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
            )
        return val
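

# Illustrative usage (an assumption based on the module docstring, not code
# from this file): a source operator builds a MultiStream, and instance
# operators are applied by calling them on it.
#
#     ms = IterableSource(iterables={"train": [{"a": "1"}, {"a": "2"}]})()
#     mapped = MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})(ms)
#     # list(mapped["train"]) == [{"a": "hi"}, {"a": "bye"}]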


class FlattenInstances(InstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
        sep (str): The separator to use when concatenating nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)
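

# Illustrative (assuming the common flatten_dict semantics of joining nested
# keys with `sep`): FlattenInstances(sep="_") turns {"a": {"b": 1}, "c": 2}
# into {"a_b": 1, "c": 2}.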


class Set(InstanceOperator):
    """Sets specified fields in each instance, in a given stream or all streams (default), with specified values. If the fields exist, they are updated; if they do not exist, they are added.

    Args:
        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields.

        use_deepcopy (bool): Deep copy the input value, to avoid later modifications of it affecting the instance.

    Examples:
        # Set a value of a list consisting of "positive" and "negative" to field "classes" in each and every instance of all streams
        ``Set(fields={"classes": ["positive", "negative"]})``

        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
        ``Set(fields={"span/start": 0})``

        # In all instances of stream "train" only, set field "classes" to have the value of a list consisting of "positive" and "negative"
        ``Set(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])``

        # Set field "classes" to have the value of a given list, preventing modification of the original list from changing the instance.
        ``Set(fields={"classes": alist}, use_deepcopy=True)``  if alist is now modified, the instances remain intact.
    """

    fields: Dict[str, object]
    use_query: Optional[bool] = None
    use_deepcopy: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, value in self.fields.items():
            if self.use_deepcopy:
                value = deep_copy(value)
            dict_set(instance, key, value)
        return instance


def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
    """Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.

    Args:
        data: The data structure (dict or list) to traverse.
        target_key: The specific key whose value needs to be checked and replaced or removed.
        value_map: A dictionary mapping old values to new values.
        value_remove: A list of values to completely remove if found as values of target_key.

    Returns:
        The modified data structure. Modification is done in-place.
    """
    if value_remove is None:
        value_remove = []

    if isinstance(data, dict):
        keys_to_delete = []
        for key, value in data.items():
            if key == target_key:
                if isinstance(value, list):
                    data[key] = [
                        value_map.get(item, item)
                        for item in value
                        if not isinstance(item, dict) and item not in value_remove
                    ]
                elif isinstance(value, dict):
                    pass  # Skip or handle dict values if needed
                elif value in value_remove:
                    keys_to_delete.append(key)
                elif value in value_map:
                    data[key] = value_map[value]
            else:
                recursive_key_value_replace(value, target_key, value_map, value_remove)
        for key in keys_to_delete:
            del data[key]
    elif isinstance(data, list):
        for item in data:
            recursive_key_value_replace(item, target_key, value_map, value_remove)
    return data


class RecursiveReplace(InstanceOperator):
    # Assisted by watsonx Code Assistant
    """An operator to recursively replace values in dictionary fields of instances based on a key and a mapping of values.

    Attributes:
        key (str): The key in the dictionary to start the replacement process.
        map_values (dict): A dictionary containing the key-value pairs to replace the original values.
        remove_values (Optional[list]): An optional list of values to remove from the dictionary. Defaults to None.

    Example:
        ``RecursiveReplace(key="a", map_values={"1": "hi", "2": "bye"}, remove_values=["3"])``
        replaces the values of key "a" in all instances of all streams:
        instance ``{"field": [{"a": "1", "b": "2"}, {"a": "3", "b": "4"}]}`` becomes ``{"field": [{"a": "hi", "b": "2"}, {"b": "4"}]}``.

        Notice how the value of key ``"a"`` in the first list item is replaced with ``"hi"``, and the ``"a"`` entry in the second item is removed, because its value ``"3"`` is listed in ``remove_values``.
    """

    key: str
    map_values: dict
    remove_values: Optional[list] = None

    def process(self, instance: Dict[str, Any], stream_name: Optional[str] = None) -> Dict[str, Any]:
        return recursive_key_value_replace(instance, self.key, self.map_values, self.remove_values)


@deprecation(version="2.0.0", alternative=Set)
class AddFields(Set):
    pass


class RemoveFields(InstanceOperator):
    """Remove specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance


class SelectFields(InstanceOperator):
    """Keep only specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to keep from each instance.
    """

    fields: List[str]

    def prepare(self):
        super().prepare()
        self.fields.extend(["data_classification_policy", "recipe_metadata"])

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        new_instance = {}
        for selected_field in self.fields:
            new_instance[selected_field] = instance[selected_field]
        return new_instance


class DefaultPlaceHolder:
    pass


default_place_holder = DefaultPlaceHolder()


class InstanceFieldOperator(InstanceOperator):
    """A general stream instance operator that processes the values of a field (or multiple ones).

    Args:
        field (Optional[str]):
            The field to process, if only a single one is passed. Defaults to None
        to_field (Optional[str]):
            Field name to save result into, if only one field is processed, if None is passed the
            operation would happen in-place and its result would replace the value of ``field``. Defaults to None
        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
            Mapping from names of fields to process,
            to names of fields to save the results into. Inner List, if used, should be of length 2.
            A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
            is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field shows both as a key and as a value in different mappings.
            The operator throws an AssertionError in either of these cases. ``field_to_field``
            defaults to None.
        process_every_value (bool):
            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False

    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
    prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .

    """

    field: Optional[str] = None
    to_field: Optional[str] = None
    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
    use_query: Optional[bool] = None
    process_every_value: bool = False
    get_default: Any = None
    not_exist_ok: bool = False
    not_exist_do_nothing: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify_field_definition(self):
        if hasattr(self, "_field_to_field") and self._field_to_field is not None:
            return
        assert (
            (self.field is None) != (self.field_to_field is None)
        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"

        if self.field_to_field is None:
            self._field_to_field = [
                (self.field, self.to_field if self.to_field is not None else self.field)
            ]
        else:
            self._field_to_field = (
                list(self.field_to_field.items())
                if isinstance(self.field_to_field, dict)
                else self.field_to_field
            )
        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
        assert (
            self._field_to_field is not None
        ), f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
        assert (
            len(self._field_to_field) > 0
        ), f"'input argument '{self.__class__.__name__}.field_to_field' should convey at least one field to process. Got {self.field_to_field}"
        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
        if self.field_to_field is None:
            return
        # for backward compatibility also allow list of tuples of two strings
        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
            self.field_to_field, List[Tuple[str, str]]
        ):
            for pair in self._field_to_field:
                assert (
                    len(pair) == 2
                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
            # order of field processing is uniquely determined by the input field_to_field when a list
            return
        if isoftype(self.field_to_field, Dict[str, str]):
            if len(self.field_to_field) < 2:
                return
            for ff, tt in self.field_to_field.items():
                for f, t in self.field_to_field.items():
                    if f == ff:
                        continue
                    assert (
                        t != ff
                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
                    assert (
                        tt != t
                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
            return
        raise ValueError(
            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
        )

    @abstractmethod
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        pass

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        self.verify_field_definition()
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
                    instance,
                    from_field,
                    default=default_place_holder,
                    not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
                )
                if old_value is default_place_holder:
                    if self.not_exist_do_nothing:
                        continue
                    old_value = self.get_default
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from instance due to the exception above."
                ) from e
            try:
                if self.process_every_value:
                    new_value = [
                        self.process_instance_value(value, instance)
                        for value in old_value
                    ]
                else:
                    new_value = self.process_instance_value(old_value, instance)
            except Exception as e:
                raise ValueError(
                    f"Failed to process field '{from_field}' from instance due to the exception above."
                ) from e
            dict_set(
                instance,
                to_field,
                new_value,
                not_exist_ok=True,
            )
        return instance


class FieldOperator(InstanceFieldOperator):
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        return self.process_value(value)

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass
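

# A sketch of the intended extension point (hypothetical subclass, not part of
# this file): only process_value needs implementing; field selection, to_field
# routing, and process_every_value handling are inherited.
#
#     class Truncate(FieldOperator):
#         max_chars: int = 10
#
#         def process_value(self, value: Any) -> Any:
#             return str(value)[: self.max_chars]
#
#     # e.g. Truncate(field="text", to_field="short", max_chars=5)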


class MapValues(FieldOperator):
    mapping: Dict[str, str]

    def process_value(self, value: Any) -> Any:
        return self.mapping[str(value)]


class Rename(FieldOperator):
    """Renames fields.

    Moves the value from one field to another; if a field name contains a /, the move can be from one branch into another.
    The source field is removed (or the relevant part of it, in case of a / in from_field).

    Examples:
        Rename(field_to_field={"b": "c"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]

        Rename(field_to_field={"b": "c/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]

        Rename(field_to_field={"b": "b/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]

        Rename(field_to_field={"b/c/e": "b/d"})
        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]

    """

    def process_value(self, value: Any) -> Any:
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        res = super().process(instance=instance, stream_name=stream_name)
        for from_field, to_field in self._field_to_field:
            if (not is_subpath(from_field, to_field)) and (
                not is_subpath(to_field, from_field)
            ):
                dict_delete(res, from_field, remove_empty_ancestors=True)

        return res


@deprecation(version="2.0.0", alternative=Rename)
class RenameFields(Rename):
    pass


class AddConstant(FieldOperator):
    """Adds a constant, being argument 'add', to the processed value.

    Args:
        add: the constant to add.
    """

    add: Any

    def process_value(self, value: Any) -> Any:
        return self.add + value


class ShuffleFieldValues(FieldOperator):
    # Assisted by watsonx Code Assistant
    """An operator that shuffles the values of a list field.

    The seed for shuffling is determined by the elements of the input field,
    ensuring that the shuffling operation produces different results for different input lists,
    but also that it is deterministic and reproducible.

    process_value(value) shuffles the elements of the input list and returns the shuffled list.
    """

    def process_value(self, value: Any) -> Any:
        res = list(value)
        random_generator = new_random_generator(sub_seed=res)
        random_generator.shuffle(res)
        return res


class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """

    separator: str = ","

    def process_value(self, value: Any) -> Any:
        return self.separator.join(str(x) for x in value)


class Apply(InstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (str): name of function.
        to_field (str): the field to store the result

    any additional arguments are field names whose values will be passed directly to the function specified

    Examples:
        Store in field "b" the uppercase string of the value in field "a":
        ``Apply("a", function=str.upper, to_field="b")``

        Dump the json representation of field "t" and store back in the same field:
        ``Apply("t", function=json.dumps, to_field="t")``

        Set the time in a field 'b':
        ``Apply(function=time.time, to_field="b")``

    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        parts = []

        if hasattr(function, "__module__"):
            parts.append(function.__module__)
        if hasattr(function, "__qualname__"):
            parts.append(function.__qualname__)
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        parts = function_str.split(".", 1)
        if len(parts) == 1:
            return __builtins__[parts[0]]

        module_name, function_name = parts
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = __import__(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance
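

# Illustrative, following the docstring above: Apply("a", function=str.upper,
# to_field="b") turns {"a": "hi"} into {"a": "hi", "b": "HI"}; positional
# arguments name the fields whose values are passed to the function.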


class ListFieldValues(InstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))

        dict_set(instance, self.to_field, values)

        return instance
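

# Illustrative: ListFieldValues(fields=["a", "b"], to_field="ab") turns
# {"a": 1, "b": 2} into {"a": 1, "b": 2, "ab": [1, 2]}.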
759

760
class ZipFieldValues(InstanceOperator):
1✔
761
    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.
762

763
    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
764
    are zipped, and stored into 'to_field'.
765

766
    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
767
    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.
768

769
    """
770

771
    fields: List[str]
1✔
772
    to_field: str
1✔
773
    longest: bool = False
1✔
774
    use_query: Optional[bool] = None
1✔
775

776
    def verify(self):
1✔
777
        super().verify()
1✔
778
        if self.use_query is not None:
1✔
779
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
×
780
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
×
781

782
    def process(
1✔
783
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
784
    ) -> Dict[str, Any]:
785
        values = []
1✔
786
        for field_name in self.fields:
1✔
787
            values.append(dict_get(instance, field_name))
1✔
788
        if self.longest:
1✔
789
            zipped = zip_longest(*values)
1✔
790
        else:
791
            zipped = zip(*values)
1✔
792
        dict_set(instance, self.to_field, list(zipped))
1✔
793
        return instance
1✔
794

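

# Illustrative: ZipFieldValues(fields=["a", "b"], to_field="ab") turns
# {"a": [1, 2], "b": ["x", "y", "z"]} into
# {"a": [1, 2], "b": ["x", "y", "z"], "ab": [(1, "x"), (2, "y")]};
# with longest=True, (None, "z") is appended as well.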


class InterleaveListsToDialogOperator(InstanceOperator):
    """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".

    The list of tuples is of the format (role, turn_content), where the role label is specified by
    the 'user_role_label' and 'assistant_role_label' fields (defaulting to "user" and "assistant").

    The user turns and assistant turns fields are specified in the arguments.
    The value of each of the 'fields' is assumed to be a list.

    """

    user_turns_field: str
    assistant_turns_field: str
    user_role_label: str = "user"
    assistant_role_label: str = "assistant"
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        user_turns = instance[self.user_turns_field]
        assistant_turns = instance[self.assistant_turns_field]

        assert (
            len(user_turns) == len(assistant_turns)
            or (len(user_turns) - len(assistant_turns) == 1)
        ), "user_turns must have either the same length as assistant_turns or one more turn."

        interleaved_dialog = []
        i, j = 0, 0  # Indices for the user and assistant lists
        # While either list has elements left, continue interleaving
        while i < len(user_turns) or j < len(assistant_turns):
            if i < len(user_turns):
                interleaved_dialog.append((self.user_role_label, user_turns[i]))
                i += 1
            if j < len(assistant_turns):
                interleaved_dialog.append(
                    (self.assistant_role_label, assistant_turns[j])
                )
                j += 1

        instance[self.to_field] = interleaved_dialog
        return instance
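

# Illustrative: with user_turns_field="u", assistant_turns_field="a", and
# to_field="d", the instance {"u": ["hi", "and you?"], "a": ["hello"]} gets
# d == [("user", "hi"), ("assistant", "hello"), ("user", "and you?")].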


class IndexOf(InstanceOperator):
    """For a given instance, finds the offset of the value of field 'index_of' within the value of field 'search_in'."""

    search_in: str
    index_of: str
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in)
        item = dict_get(instance, self.index_of)
        instance[self.to_field] = lst.index(item)
        return instance


class TakeByField(InstanceOperator):
    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

    field: str
    index: str
    to_field: str = None
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def prepare(self):
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field)
        index_value = dict_get(instance, self.index)
        instance[self.to_field] = value[index_value]
        return instance


class Perturb(FieldOperator):
    """Slightly perturbs the contents of ``field``. Could be handy for imitating a prediction from a given target.

    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
    relevant perturbation.

    Args:
        percentage_to_perturb (int):
            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent)
        select_from (List[Any]):
            a list of values to select from, as a perturbation of the field's value. Defaults to [].
    """

    select_from: List[Any] = []
    percentage_to_perturb: int = 1  # 1 percent

    def verify(self):
        assert (
            0 <= self.percentage_to_perturb and self.percentage_to_perturb <= 100
        ), f"'percentage_to_perturb' should be in the range 0..100. Received {self.percentage_to_perturb}"

    def prepare(self):
        super().prepare()
        self.random_generator = new_random_generator(sub_seed="CopyWithPerturbation")

    def process_value(self, value: Any) -> Any:
        perturb = self.random_generator.randint(1, 100) <= self.percentage_to_perturb
        if not perturb:
            return value

        if value in self.select_from:
            # in 80% of cases, return another decent class; otherwise, perturb the value itself as follows
            if self.random_generator.random() < 0.8:
                return self.random_generator.choice(self.select_from)

        if isinstance(value, float):
            return value * (0.5 + self.random_generator.random())

        if isinstance(value, int):
            perturb = 1 if self.random_generator.random() < 0.5 else -1
            return value + perturb

        if isinstance(value, str):
            if len(value) < 2:
                # give up perturbation
                return value
            # throw one char out
            prefix_len = self.random_generator.randint(1, len(value) - 1)
            return value[:prefix_len] + value[prefix_len + 1 :]

        # and in any other case:
        return value


class Copy(FieldOperator):
    """Copies values from specified fields to specified fields.

    Args (of parent class):
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.

    Examples:
        An input instance {"a": 2, "b": 3}, when processed by
        ``Copy(field_to_field={"a": "b"})``
        would yield {"a": 2, "b": 2}, and when processed by
        ``Copy(field_to_field={"a": "c"})`` would yield
        {"a": 2, "b": 3, "c": 2}

        with field names containing / , we can also copy inside the field:
        ``Copy(field="a/0", to_field="a")``
        would process instance {"a": [1, 3]} into {"a": 1}

    """

    def process_value(self, value: Any) -> Any:
        return value
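

# Note (added for clarity): Copy above returns the value by reference, so
# nested structures stay shared; RecursiveCopy below deep-copies them via
# recursive_copy, so later mutation of the source does not affect the copy.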
class RecursiveCopy(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return recursive_copy(value)


@deprecation(version="2.0.0", alternative=Copy)
class CopyFields(Copy):
    pass


class GetItemByIndex(FieldOperator):
    """Get the element from the fixed list by the index in the given field and store in another field.

    Example:
        GetItemByIndex(items_list=["dog", "cat"], field="animal_index", to_field="animal")

        on instance {"animal_index": 1} will change the instance to {"animal_index": 1, "animal": "cat"}

    """

    items_list: List[Any]

    def process_value(self, value: Any) -> Any:
        return self.items_list[value]


class AddID(InstanceOperator):
    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""

    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance


class Cast(FieldOperator):
    """Casts the value of a field to a specified type.

    Args:
        to (str): the name of the type to cast to: one of "int", "float", "str", "bool", "tuple".
        failure_default (Optional[Any]): a default value to use when casting fails. If left as "__UNDEFINED__", a casting failure raises a ValueError.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool, "tuple": tuple}

    def process_value(self, value):
        try:
            return self.types[self.to](value)
        except ValueError as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default
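

# Illustrative: Cast(field="a", to="int") turns {"a": "7"} into {"a": 7};
# with failure_default=0, a non-castable value such as "x" becomes 0 instead
# of raising a ValueError.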


class CastFields(InstanceOperator):
    """Casts specified fields to specified types.

    Args:
        fields (Dict[str, str]):
            A dictionary mapping field names to the names of the types to cast the fields to,
            e.g. "int", "str", "float", "bool" (basic type names).
        failure_defaults (Dict[str, object]):
            A dictionary mapping field names to default values for cases of casting failure.
        process_every_value (bool):
            If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.

    Example:
        .. code-block:: python

                CastFields(
                    fields={"a/d": "float", "b": "int"},
                    failure_defaults={"a/d": 0.0, "b": 0},
                    process_every_value=True,
                )

    would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
    into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.

    """

    fields: Dict[str, str] = field(default_factory=dict)
    failure_defaults: Dict[str, object] = field(default_factory=dict)
    use_nested_query: bool = None  # deprecated field
    process_every_value: bool = False

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def verify(self):
        super().verify()
        if self.use_nested_query is not None:
            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        return [self._cast_single(value, type, field) for value in values]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name)
            if self.process_every_value:
                assert isinstance(
                    value, list
                ), f"'process_every_field' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)

            dict_set(instance, field_name, casted_value)
        return instance


class DivideAllFieldsBy(InstanceOperator):
    """Recursively reach down to all fields that are float, and divide each by 'divisor'.

    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
    whose leaves are either floats, which are then divided, or of another basic type, in which case a ValueError is
    raised if input flag 'strict' is True, or the leaf is left alone if 'strict' is False.

    Args:
        divisor (float): the value to divide by
        strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.

    Example:
        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
        operator = DivideAllFieldsBy(divisor=2.0)
        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
        If the operator were defined with strict=True, through:
        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
        the processing of the above instance would raise a ValueError, for the integer at "c".
    """

    divisor: float = 1.0
    strict: bool = False

    def _recursive_divide(self, instance, divisor):
        if isinstance(instance, dict):
            for key, value in instance.items():
                instance[key] = self._recursive_divide(value, divisor)
        elif isinstance(instance, list):
            for i, value in enumerate(instance):
                instance[i] = self._recursive_divide(value, divisor)
        elif isinstance(instance, float):
            instance /= divisor
        elif self.strict:
            raise ValueError(f"Cannot divide instance of type {type(instance)}")
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return self._recursive_divide(instance, self.divisor)


class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Attributes:
        _artifacts_cache (LRUCache): A cache for storing fetched artifacts.
    """

    _artifacts_cache = LRUCache(max_size=1000)

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        if str(artifact_identifier) not in cls._artifacts_cache:
            artifact, catalog = fetch_artifact(artifact_identifier)
            cls._artifacts_cache[str(artifact_identifier)] = artifact
        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])


class ApplyOperatorsField(InstanceOperator):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        operators_field (str): name of the field that contains a single name, or a list of names, of the operators to be applied,
            one after the other, for the processing of the instance. Each operator is equipped with a 'process_instance()'
            method.

        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"prediction": 111, "references": [222, 333], "c": ["processors.to_string", "processors.first_character"]}
        is processed by the operator (please look these operators up in the catalog; they are tuned to process fields "prediction" and
        "references"):
        operator = ApplyOperatorsField(operators_field="c"),
        the resulting instance is: {"prediction": "1", "references": ["2", "3"], "c": ["processors.to_string", "processors.first_character"]}

    """

    operators_field: str
    default_operators: List[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]
        # otherwise, operator_names is already a list

        # we now have a list of names of operators, each equipped with a process_instance method.
        operator = SequentialOperator(steps=operator_names)
        return operator.process_instance(instance, stream_name=stream_name)


class FilterByCondition(StreamOperator):
    """Filters a stream, yielding only instances in which the values in the required fields follow the required condition operator.

    Raises an error if a required field name is missing from the input instance.

    Args:
       values (Dict[str, Any]): Field names and respective values that instances must match, according to the condition, to be included in the output.

       condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in").

       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
       | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
       | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose 3rd element is ``<= 4``


    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be list but is not : '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance."
                ) from ve
            if self.condition == "in":
                if instance_key not in value:
                    return False
            elif self.condition == "not in":
                if instance_key in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, value):
                    return False
        return True
1289
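# A minimal usage sketch for FilterByCondition (kept as a comment; the values
# are illustrative):
#
#     ms = MultiStream.from_iterables({"test": [{"a": 3}, {"a": 5}, {"a": 7}]})
#     filtered = FilterByCondition(values={"a": 4}, condition="gt")(ms)
#     assert [instance["a"] for instance in filtered["test"]] == [5, 7]
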
class FilterByConditionBasedOnFields(FilterByCondition):
    """Filters a stream based on a condition between the values of two fields.

    Raises an error if either of the required field names is missing from the input instance.

    Args:
       values (Dict[str, str]): The field names that the filter operation is based on.
       condition: the name of the desired condition operator between the specified fields' values. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")
       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "gt") will yield only instances where the value of field "a" is greater than the value of field "b".
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "le") will yield only instances where "a"<="b"
    """

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance"
                ) from ve
            try:
                instance_value = dict_get(instance, value)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{value}') in FilterByCondition is not found in instance"
                ) from ve
            if self.condition == "in":
                if instance_key not in instance_value:
                    return False
            elif self.condition == "not in":
                if instance_key in instance_value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, instance_value):
                    return False
        return True

class ComputeExpressionMixin(Artifact):
    """Computes an expression expressed over fields of an instance.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # cannot do the imports here, because the object does not pickle with imports
        self.globals = {
            module_name: __import__(module_name) for module_name in self.imports_list
        }

    def compute_expression(self, instance: dict) -> Any:
        if settings.allow_unverified_code:
            return eval(self.expression, {**self.globals, **instance})

        raise ValueError(
            f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set the {settings.allow_unverified_code_key} environment variable."
            "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
        )

class FilterByExpression(StreamOperator, ComputeExpressionMixin):
    """Filters a stream, yielding only instances that fulfil a condition specified as a string to be evaluated by Python's eval().

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        expression (str):
            a condition over fields of the instance, to be processed by Python's eval()
        imports_list (List[str]):
            names of imports needed for the eval of the query (e.g. 're', 'json')
        error_on_filtered_all (bool, optional):
            If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
    """

    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self.compute_expression(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended, set error_on_filtered_all=False."
            )

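# A minimal usage sketch for FilterByExpression (kept as a comment; note that
# the expression is eval()-ed, so unitxt.settings.allow_unverified_code must
# be enabled for this to run):
#
#     ms = MultiStream.from_iterables(
#         {"test": [{"a": 3, "b": 6}, {"a": 5, "b": 2}]}
#     )
#     filtered = FilterByExpression(expression="a <= 4 and b > 5")(ms)
#     assert [instance["a"] for instance in filtered["test"]] == [3]
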
class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
    """Compute an expression, specified as a string to be evaluated via Python's eval(), over the instance's fields, and store the result in field to_field.

    Raises an error if a field mentioned in the query is missing from the instance.

    Args:
       expression (str): an expression to be evaluated over the fields of the instance
       to_field (str): the field where the result is to be stored
       imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')

    Examples:
       When instance {"a": 2, "b": 3} is processed by operator
       ExecuteExpression(expression="a+b", to_field="c")
       the result is {"a": 2, "b": 3, "c": 5}

       When instance {"a": "hello", "b": "world"} is processed by operator
       ExecuteExpression(expression="a+' '+b", to_field="c")
       the result is {"a": "hello", "b": "world", "c": "hello world"}

    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = self.compute_expression(instance)
        return instance

class ExtractMostCommonFieldValues(MultiStreamOperator):
    """Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them as a list in a new field ('to_field') in all streams.

    More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, trim the list from the bottom, so that the total frequency of
    the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes
    less than 'min_frequency_percent' of the total number of instances in the stream.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value.

    Examples:

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
    every instance in all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
    in case field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
    value members in these lists, and report the most frequent values.
    If process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
    'to_field' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
    and stores them in field 'classes' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
    extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    def verify(self):
        assert (
            self.overall_top_frequency_percent <= 100
            and self.overall_top_frequency_percent >= 0
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from its default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        counter = Counter()
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
                counter.update(
                    [(*instance[self.field],)]
                    if isinstance(instance[self.field], list)
                    else [instance[self.field]]
                )  # convert to a tuple if list, to enable the use of Counter, which would not accept
                # a list as a hashable entity to count its occurrences
            else:
                # content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                counter.update(instance[self.field])
        # here counter counts occurrences of individual values, or tuples.
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = (
                sum(counter.values()) * self.overall_top_frequency_percent / 100.0
            )
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * sum(counter.values()) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = Set(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)


class ExtractFieldValues(ExtractMostCommonFieldValues):
    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0

class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list): the list to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]

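# A minimal usage sketch for Intersect (kept as a comment; assumes the
# standard FieldOperator 'field' argument selects the list-valued field):
#
#     ms = MultiStream.from_iterables({"test": [{"labels": ["a", "b", "f"]}]})
#     out = Intersect(field="labels", allowed_values=["b", "f"])(ms)
#     assert next(iter(out["test"]))["labels"] == ["b", "f"]
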
class IntersectCorrespondingFields(InstanceOperator):
    """Intersects the value of a field, which must be a list, with a given list, and removes corresponding elements from other list fields.

    For example:

    Assume the instances contain a field of 'labels' and a field with the labels' corresponding 'positions' in the text.

    .. code-block:: text

        IntersectCorrespondingFields(field="label",
                                     allowed_values=["b", "f"],
                                     corresponding_fields_to_intersect=["position"])

    would keep only the "b" and "f" values in the 'label' field and
    their respective values in the 'position' field.
    (All other fields are not affected)

    .. code-block:: text

        Given this input:

        [
            {"label": ["a", "b"],"position": [0,1],"other" : "not"},
            {"label": ["a", "c", "d"], "position": [0,1,2], "other" : "relevant"},
            {"label": ["a", "b", "f"], "position": [0,1,2], "other" : "field"}
        ]

        the output would be:
        [
            {"label": ["b"], "position":[1],"other" : "not"},
            {"label": [], "position": [], "other" : "relevant"},
            {"label": ["b", "f"],"position": [1,2], "other" : "field"},
        ]

    Args:
        field (str): the field to intersect (must contain list values)
        allowed_values (list): the list of values to keep
        corresponding_fields_to_intersect (list): additional list fields from which values
            are removed based on the corresponding indices of values removed from 'field'
    """

    field: str
    allowed_values: List[str]
    corresponding_fields_to_intersect: List[str]

    def verify(self):
        super().verify()

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not of type list but '{type(self.allowed_values)}'"
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.field not in instance:
            raise ValueError(
                f"Field '{self.field}' is not in provided instance.\n"
                + to_pretty_string(instance)
            )

        for corresponding_field in self.corresponding_fields_to_intersect:
            if corresponding_field not in instance:
                raise ValueError(
                    f"Field '{corresponding_field}' is not in provided instance.\n"
                    + to_pretty_string(instance)
                )

        if not isinstance(instance[self.field], list):
            raise ValueError(
                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
                + to_pretty_string(instance, keys=[self.field])
            )

        num_values_in_field = len(instance[self.field])

        if set(self.allowed_values) == set(instance[self.field]):
            return instance

        indices_to_keep = [
            i
            for i, value in enumerate(instance[self.field])
            if value in set(self.allowed_values)
        ]

        result_instance = {}
        for field_name, field_value in instance.items():
            if (
                field_name in self.corresponding_fields_to_intersect
                or field_name == self.field
            ):
                if not isinstance(field_value, list):
                    raise ValueError(
                        f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
                    )
                if len(field_value) != num_values_in_field:
                    raise ValueError(
                        f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}', so IntersectCorrespondingFields can not remove corresponding values.\n"
                        + to_pretty_string(instance, keys=[self.field, field_name])
                    )
                result_instance[field_name] = [
                    value
                    for index, value in enumerate(field_value)
                    if index in indices_to_keep
                ]
            else:
                result_instance[field_name] = field_value
        return result_instance

class RemoveValues(FieldOperator):
    """Removes elements in a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list): the values to be removed.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]

class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> Stream:
        seen = set()
        for instance in stream:
            values = self.to_tuple(instance, self.fields)
            if values not in seen:
                seen.add(values)
        return list(seen)


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        uniques = Unique(fields=self.fields)(multi_stream)

        result = {}

        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)

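# A minimal usage sketch for SplitByValue (kept as a comment; the exact suffix
# of the resulting stream names is produced by nested_tuple_to_string):
#
#     ms = MultiStream.from_iterables(
#         {"test": [{"group": "x", "id": 0}, {"group": "y", "id": 1}]}
#     )
#     result = SplitByValue(fields=["group"])(ms)
#     # result now holds one stream per unique value of "group",
#     # named roughly "test_x" and "test_y"
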
class SplitByNestedGroup(MultiStreamOperator):
    """Splits a MultiStream that is small (e.g., for metrics, so the whole stream can sit in memory) by the value of field 'group'.

    Args:
        number_of_fusion_generations: int

    The value in field 'group' is of the form "source_n/source_n-1/..." describing the sources in which the instance sat
    when these were fused, potentially over several phases of fusion. The name of the most recent source sits first in this value.
    (See BaseFusion and its extensions.)
    number_of_fusion_generations specifies the length of the prefix by which to split the stream.
    E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream affects the splitting.
    For number_of_fusion_generations = -1, take the whole history written in this field, ignoring the number of generations.
    """

    field_name_of_group: str = "group"
    number_of_fusion_generations: int = 1

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)

        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.field_name_of_group not in instance:
                    raise ValueError(
                        f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
                    )
                signature = (
                    stream_name
                    + "~"  # a sign that does not show within group values
                    + (
                        "/".join(
                            instance[self.field_name_of_group].split("/")[
                                : self.number_of_fusion_generations
                            ]
                        )
                        if self.number_of_fusion_generations >= 0
                        # for values with a smaller number of generations - take up to their last generation
                        else instance[self.field_name_of_group]
                        # for each instance - take all its generations
                    )
                )
                result[signature].append(instance)

        return MultiStream.from_iterables(result)

class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):
            operators = [operators]

        if self.reversed:
            operators = list(reversed(operators))

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({stream_name: stream}))[stream_name]

        yield from stream


def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
    for instance, score in zip(stream, scores):
        instance["score"] = recursive_copy(score)
        yield instance

class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricsList

        # to be populated only when there are two or more metrics
        accumulated_scores = []

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        metrics_list = []
        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            if isinstance(metric, MetricsList):
                metrics_list.extend(list(metric.items))
            elif isinstance(metric, Metric):
                metrics_list.append(metric)
            else:
                raise ValueError(
                    f"Operator {metric_name} must be a Metric or MetricsList"
                )

        for metric in metrics_list:
            if not self.calc_confidence_intervals:
                metric.disable_confidence_interval_calculation()
        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metrics_list = list(reversed(metrics_list))

        for i, metric in enumerate(metrics_list):
            if i == 0:  # first metric
                multi_stream = MultiStream({"tmp": stream})
            else:  # metrics with previous scores
                reusable_generator = ReusableGenerator(
                    generator=update_scores_of_stream_instances,
                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
                )
                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})

            multi_stream = metric(multi_stream)

            if i < len(metrics_list) - 1:  # not the last metric
                accumulated_scores = []
                for inst in multi_stream["tmp"]:
                    accumulated_scores.append(recursive_copy(inst["score"]))

        yield from multi_stream["tmp"]

class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        streams_to_merge (List[str], optional): The names of the streams to merge. If None, all streams are merged.
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream) -> Generator:
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: DynamicStream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )

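# A minimal usage sketch for MergeStreams (kept as a comment):
#
#     ms = MultiStream.from_iterables({"train": [{"a": 1}], "test": [{"a": 2}]})
#     merged = MergeStreams()(ms)
#     # merged["all"] yields {"a": 1, "origin": "train"}
#     # and {"a": 2, "origin": "test"}
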
class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page

class FeatureGroupedShuffle(Shuffle):
    """Class for shuffling an input dataset by instance 'blocks', not on the individual instance level.

    An example is a dataset consisting of questions, each with several paraphrases, where each question falls into a topic.
    All paraphrases have the same ID value as the original.
    In this case, we may want to shuffle on grouping_features = ['question ID'],
    to keep the paraphrases and the original question together.
    We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
    In that case, grouping_features = ['question ID', 'topic'].

    Args:
        grouping_features (list of strings): list of feature names to use to define the groups.
            A group is defined by each unique observed combination of data values for the features in grouping_features.
        shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
            Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
            separately by page (determined by page_size). If a block of instances in the same group is split
            into separate pages (either by a page break falling in the group, or because the dataset was not sorted by
            grouping_features), these instances will be shuffled separately, and thus the grouping may be
            broken up by pages. If the user wants to ensure the shuffle does the grouping and shuffling
            across all pages, set the page_size to be larger than the dataset size.
            See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
    """

    grouping_features: List[str] = None
    shuffle_within_group: bool = False

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        if self.grouping_features is None:
            # delegate to the plain Shuffle; 'yield from' is required here so that
            # the superclass generator is actually consumed and its output yielded
            yield from super().process(page, stream_name)
        else:
            yield from self.shuffle_by_grouping_features(page)

    def shuffle_by_grouping_features(self, page):
        import itertools
        from collections import defaultdict

        groups_to_instances = defaultdict(list)
        for item in page:
            groups_to_instances[
                tuple(item[ff] for ff in self.grouping_features)
            ].append(item)
        # now extract the groups (i.e., lists of dicts with order preserved)
        page_blocks = list(groups_to_instances.values())
        # and now shuffle the blocks
        self.random_generator.shuffle(page_blocks)
        if self.shuffle_within_group:
            blocks = []
            # reshuffle the instances within each block, but keep the blocks in order
            for block in page_blocks:
                self.random_generator.shuffle(block)
                blocks.append(block)
            page_blocks = blocks

        # now flatten the list so it consists of individual dicts, but in (randomized) block order
        return list(itertools.chain(*page_blocks))

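# A minimal usage sketch for FeatureGroupedShuffle (kept as a comment; given a
# hypothetical list `instances` of dicts that all carry a "question ID" field):
#
#     shuffler = FeatureGroupedShuffle(grouping_features=["question ID"])
#     shuffled = shuffler(MultiStream.from_iterables({"train": instances}))
#     # instances sharing a "question ID" stay contiguous; only the order of
#     # the blocks (and, with shuffle_within_group=True, the order within
#     # each block) is randomized
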
class EncodeLabels(InstanceOperator):
    """Encode each value encountered in any field in 'fields' into the integers 0, 1, ...

    Encoding is determined by a str->int map that is built on the go, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example:
        applying ``EncodeLabels(fields = ["a", "b/*"])``
        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
        {"a": "blue", "b": ["green"], "c":"water"}]`` will yield the
        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``

        Note: dict_utils are applied here, and hence fields that are lists should be included in
        the input 'fields' with the suffix ``"/*"``, as in the above example.

    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name)
            values_was_a_list = isinstance(values, list)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            if not values_was_a_list:
                new_values = new_values[0]
            dict_set(
                instance,
                field_name,
                new_values,
                not_exist_ok=False,  # the values to encode were just taken from there
                set_multiple="*" in field_name
                and isinstance(new_values, list)
                and len(new_values) > 0,
            )

        return instance

class StreamRefiner(StreamOperator):
    """Discard from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int)
        apply_to_streams (optional, list(str)):
            names of the streams to refine.

    Examples:
        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
        ``StreamRefiner(max_instances=4)``
        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream

class Deduplicate(StreamOperator):
    """Deduplicate the stream based on the given fields.

    Args:
        by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness.

    Examples:
        >>> dedup = Deduplicate(by=["field1", "field2"])
    """

    by: List[str]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        seen = set()

        for instance in stream:
            # Compute a lightweight hash for the signature
            signature = hash(str(tuple(dict_get(instance, field) for field in self.by)))

            if signature not in seen:
                seen.add(signature)
                yield instance

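# A minimal usage sketch for Deduplicate (kept as a comment; the first
# instance seen for each signature is the one kept):
#
#     ms = MultiStream.from_iterables(
#         {"test": [{"k": 1, "v": "x"}, {"k": 1, "v": "y"}, {"k": 2, "v": "z"}]}
#     )
#     deduped = Deduplicate(by=["k"])(ms)
#     assert [instance["v"] for instance in deduped["test"]] == ["x", "z"]
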
class Balance(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
    By discarding instances from the input stream, the balancer maintains an equal number of instances for all signatures.
    When 'max_instances' is also specified, the balancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is kept as small as possible.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            overall max.

    Usage:
        ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
        ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
        ``DeterministicBalancer(fields=["a"])``
        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = Counter()

        for instance in stream:
            sign = self.signature(instance)
            if counter[sign] < max_total_instances_per_sign:
                counter[sign] += 1
                yield instance


class DeterministicBalancer(Balance):
    pass

class MinimumOneExamplePerLabelRefiner(StreamRefiner):
    """A class used to return a specified number of instances, ensuring at least one example per label.

    For each instance, a signature value is constructed from the values of the instance in specified input ``fields``.
    ``MinimumOneExamplePerLabelRefiner`` takes the first instance that appears for each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
    from each class first and then the rest. If max_instances is not set, the original stream is returned.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            Number of elements to select. Note that max_instances of StreamRefiners
            that are passed to the recipe (e.g. ``train_refiner``, ``test_refiner``) are overridden
            by the recipe parameters (``max_train_instances``, ``max_test_instances``)

    Usage:
        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
        | ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
        the resulting stream will be:
        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is None:
            for instance in stream:
                yield instance
            return  # without this return, the instances would be yielded a second time below

        counter = Counter()
        for instance in stream:
            counter[self.signature(instance)] += 1
        all_keys = counter.keys()
        if len(counter) == 0:
            return

        if self.max_instances is not None and len(all_keys) > self.max_instances:
            raise Exception(
                f"Can not generate a stream with at least one example per label, because the number of requested max instances ({self.max_instances}) is smaller than the number of different labels ({len(all_keys)})"
            )

        counter = Counter()
        used_indices = set()
        selected_elements = []
        # select at least one per class
        for idx, instance in enumerate(stream):
            sign = self.signature(instance)
            if counter[sign] == 0:
                counter[sign] += 1
                used_indices.add(idx)
                selected_elements.append(
                    instance
                )  # collect all elements first to allow shuffling of both groups
        # select more to reach self.max_instances examples
        for idx, instance in enumerate(stream):
            if idx not in used_indices:
                if self.max_instances is None or len(used_indices) < self.max_instances:
                    used_indices.add(idx)
                    selected_elements.append(
                        instance
                    )  # collect all elements first to allow shuffling of both groups

        # shuffle elements to avoid having one element from each class appear first
        random_generator = new_random_generator(sub_seed=selected_elements)
        random_generator.shuffle(selected_elements)
        yield from selected_elements

class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]):
            distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the given total length.
            (If none exceeds -- into one index beyond, namely, the length of segments_boundaries)
        fields (Optional, List[str]):
            the total length of the values of these fields goes through the quantization described above

    Example:
        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
        input instances will be counted and balanced against two categories:
        empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
        return i + 1

class DownloadError(Exception):
    def __init__(self, message):
        super().__init__(message)


class UnexpectedHttpCodeError(Exception):
    def __init__(self, http_code):
        super().__init__(f"unexpected http code {http_code}")

class DownloadOperator(SideEffectOperator):
    """Operator for downloading a file from a given URL to a specified local path.

    Args:
        source (str):
            URL of the file to be downloaded.
        target (str):
            Local path where the downloaded file should be saved.
    """

    source: str
    target: str

    def process(self):
        try:
            response = requests.get(self.source, allow_redirects=True)
        except Exception as e:
            raise DownloadError(f"Unable to download {self.source}") from e
        if response.status_code != 200:
            raise UnexpectedHttpCodeError(response.status_code)
        with open(self.target, "wb") as f:
            f.write(response.content)


class ExtractZipFile(SideEffectOperator):
    """Operator for extracting files from a zip archive.

    Args:
        zip_file (str):
            Path of the zip file to be extracted.
        target_dir (str):
            Directory where the contents of the zip file will be extracted.
    """

    zip_file: str
    target_dir: str

    def process(self):
        with zipfile.ZipFile(self.zip_file) as zf:
            zf.extractall(self.target_dir)

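# A minimal usage sketch chaining the two side-effect operators above (kept as
# a comment; the URL and paths are placeholders):
#
#     DownloadOperator(
#         source="https://example.com/data.zip", target="/tmp/data.zip"
#     ).process()
#     ExtractZipFile(zip_file="/tmp/data.zip", target_dir="/tmp/data").process()
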
class DuplicateInstances(StreamOperator):
    """Operator which duplicates each instance in the stream a given number of times.

    Args:
        num_duplications (int):
            How many times each instance should be duplicated (1 means no duplication).
        duplication_index_field (Optional[str]):
            If given, an additional field with the specified name is added to each duplicated instance,
            containing the index of the duplication. Defaults to None, so no field is added.
    """

    num_duplications: int
    duplication_index_field: Optional[str] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            for idx in range(self.num_duplications):
                duplicate = recursive_shallow_copy(instance)
                if self.duplication_index_field:
                    duplicate.update({self.duplication_index_field: idx})
                yield duplicate

    def verify(self):
        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
            raise ValueError(
                f"num_duplications must be an integer equal to or greater than 1. "
                f"Got: {self.num_duplications}."
            )

        if self.duplication_index_field is not None and not isinstance(
            self.duplication_index_field, str
        ):
            raise ValueError(
                f"If given, duplication_index_field must be a string. "
                f"Got: {self.duplication_index_field}"
            )

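# A minimal usage sketch for DuplicateInstances (kept as a comment):
#
#     ms = MultiStream.from_iterables({"test": [{"a": 1}]})
#     dup = DuplicateInstances(
#         num_duplications=2, duplication_index_field="dup_id"
#     )(ms)
#     # dup["test"] yields {"a": 1, "dup_id": 0} and {"a": 1, "dup_id": 1}
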
class CollateInstances(StreamOperator):
    """Operator which collates values from multiple instances into a single instance.

    Each field becomes a list of the values of the corresponding field across the collated `batch_size` instances.

    Attributes:
        batch_size (int)

    Example:
        .. code-block:: text

            CollateInstances(batch_size=2)

            Given inputs = [
                {"a": 1, "b": 2},
                {"a": 2, "b": 2},
                {"a": 3, "b": 2},
                {"a": 4, "b": 2},
                {"a": 5, "b": 2}
            ]

            Returns targets = [
                {"a": [1,2], "b": [2,2]},
                {"a": [3,4], "b": [2,2]},
                {"a": [5], "b": [2]},
            ]

    """

    batch_size: int

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        stream = list(stream)
        for i in range(0, len(stream), self.batch_size):
            batch = stream[i : i + self.batch_size]
            new_instance = {}
            for a_field in batch[0]:
                if a_field == "data_classification_policy":
                    flattened_list = [
                        classification
                        for instance in batch
                        for classification in instance[a_field]
                    ]
                    new_instance[a_field] = sorted(set(flattened_list))
                else:
                    new_instance[a_field] = [instance[a_field] for instance in batch]
            yield new_instance

    def verify(self):
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(
                f"batch_size must be an integer equal to or greater than 1. "
                f"Got: {self.batch_size}."
            )



class CollateInstancesByField(StreamOperator):
    """Groups instances by a specified field, aggregates the specified fields into lists, and ensures consistency of all other non-aggregated fields.

    Args:
        by_field (str): the name of the field to group data by.
        aggregate_fields (List[str]): the field names to aggregate into lists.

    Returns:
        A stream of instances grouped and aggregated by the specified field.

    Raises:
        UnitxtError: If ``by_field`` or one of ``aggregate_fields`` is missing from an instance.
        ValueError: If non-aggregated fields have inconsistent values within a group.

    Example:
        Collate the instances based on field "category" and aggregate fields "value" and "id".

        .. code-block:: text

            CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])

            given input:
            [
                {"id": 1, "category": "A", "value": 10, "flag": True},
                {"id": 2, "category": "B", "value": 20, "flag": False},
                {"id": 3, "category": "A", "value": 30, "flag": True},
                {"id": 4, "category": "B", "value": 40, "flag": False}
            ]

            the output is:
            [
                {"category": "A", "id": [1, 3], "value": [10, 30], "flag": True},
                {"category": "B", "id": [2, 4], "value": [20, 40], "flag": False}
            ]

        Note that the "flag" field is not aggregated, and must be the same
        in all instances in the same category, or an error is raised.
    """

    by_field: str = NonPositionalField(required=True)
    aggregate_fields: List[str] = NonPositionalField(required=True)

    def prepare(self):
        super().prepare()

    def verify(self):
        super().verify()
        if not isinstance(self.by_field, str):
            raise UnitxtError(
                f"The 'by_field' value is not a string but '{type(self.by_field)}'"
            )

        if not isinstance(self.aggregate_fields, list):
            raise UnitxtError(
                f"The 'aggregate_fields' value is not a list but '{type(self.aggregate_fields)}'"
            )

    def process(self, stream: Stream, stream_name: Optional[str] = None):
        grouped_data = {}

        for instance in stream:
            if self.by_field not in instance:
                raise UnitxtError(
                    f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
                )
            for k in self.aggregate_fields:
                if k not in instance:
                    raise UnitxtError(
                        f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
                    )
            key = instance[self.by_field]

            if key not in grouped_data:
                grouped_data[key] = {
                    k: v for k, v in instance.items() if k not in self.aggregate_fields
                }
                # Add empty lists for fields to aggregate
                for agg_field in self.aggregate_fields:
                    if agg_field in instance:
                        grouped_data[key][agg_field] = []

            for k, v in instance.items():
                # Merge classification policy lists across instances with the same key
                if k == "data_classification_policy" and instance[k]:
                    grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
                # Check consistency of all non-aggregated fields
                elif k != self.by_field and k not in self.aggregate_fields:
                    if k in grouped_data[key] and grouped_data[key][k] != v:
                        raise ValueError(
                            f"Inconsistent value for field '{k}' in group '{key}': "
                            f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
                        )
                # Aggregate fields
                elif k in self.aggregate_fields:
                    grouped_data[key][k].append(instance[k])

        yield from grouped_data.values()
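

# Behavior sketch for inconsistent non-aggregated fields (values assumed for
# illustration): with by_field="category" and aggregate_fields=["value"], the
# instances
#   {"category": "A", "value": 1, "flag": True}
#   {"category": "A", "value": 2, "flag": False}
# raise ValueError, because "flag" is neither the group key nor an aggregated
# field, yet differs within group "A".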


class WikipediaFetcher(FieldOperator):
    """Fetches a Wikipedia page, given its URL or title, and returns its title and body.

    The page title is taken as the last "/"-separated segment of the input value.

    Attributes:
        mode (Literal["summary", "text"]): whether the returned body holds the
            page summary or its full text.
    """

    mode: Literal["summary", "text"] = "text"
    _requirements_list = ["Wikipedia-API"]

    def prepare(self):
        super().prepare()
        import wikipediaapi

        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")

    def process_value(self, value: Any) -> Any:
        title = value.split("/")[-1]
        page = self.wikipedia.page(title)

        return {"title": page.title, "body": getattr(page, self.mode)}
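

# Usage sketch for WikipediaFetcher (illustrative; the URL is an assumption):
#
#   WikipediaFetcher(mode="summary").process_value(
#       "https://en.wikipedia.org/wiki/Python_(programming_language)"
#   )
#
# takes the last path segment as the page title and returns a dict of the form
# {"title": ..., "body": ...}, where "body" holds the page summary.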


class Fillna(FieldOperator):
    """Replaces NaN values with a given fill value.

    Values for which the NaN check is undefined (e.g. strings) are returned unchanged.
    """

    value: Any

    def process_value(self, value: Any) -> Any:
        import numpy as np

        try:
            if np.isnan(value):
                return self.value
        except TypeError:
            # np.isnan is undefined for non-numeric types; leave such values as-is.
            return value
        return value
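

# Usage sketch for Fillna (values assumed for illustration):
#
#   Fillna(value=0).process_value(float("nan"))  # -> 0
#   Fillna(value=0).process_value(5)             # -> 5
#   Fillna(value=0).process_value("text")        # -> "text" (NaN check not applicable)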