src/unitxt/operators.py
"""This section describes unitxt operators.

Operators: Building Blocks of Unitxt Processing Pipelines
==============================================================

Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
These operators are callable entities that receive a MultiStream as input and return a MultiStream
augmented with the operator's manipulations, which are applied lazily to each instance as it is pulled from the stream.

Creating Custom Operators
-------------------------------
To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
(A minimal sketch of such a custom operator appears right after the imports below.)

General or Specialized Operators
--------------------------------
Some operators are specialized in specific data or specific operations such as:

- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
- :class:`splitters<unitxt.splitters>` for fixing data splits.
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining, and mixing streams.
- :class:`struct_data_operators<unitxt.struct_data_operators>` for handling structured data.
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
- :class:`string_operators<unitxt.string_operators>` for handling strings.
- :class:`span_labeling_operators<unitxt.span_lableing_operators>` for handling span labeling.
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.

Other specialized operators are used by unitxt internally:

- :class:`templates<unitxt.templates>` for verbalizing data examples.
- :class:`formats<unitxt.formats>` for preparing data for models.

The rest of this section is dedicated to general operators.

General Operators List:
------------------------
"""
import operator
import uuid
import warnings
import zipfile
from abc import abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from itertools import zip_longest
from random import Random
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .error_utils import UnitxtError
from .generator_utils import ReusableGenerator
from .operator import (
    InstanceOperator,
    MultiStream,
    MultiStreamOperator,
    PagedStreamOperator,
    SequentialOperator,
    SideEffectOperator,
    SingleStreamReducer,
    SourceOperator,
    StreamingOperator,
    StreamInitializerOperator,
    StreamOperator,
)
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string, to_pretty_string
from .type_utils import isoftype
from .utils import (
    LRUCache,
    deep_copy,
    flatten_dict,
    recursive_copy,
    recursive_shallow_copy,
    shallow_copy,
)

settings = get_settings()

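
# A minimal sketch of a custom operator, as described in the module docstring
# above. This class is a hypothetical, illustrative example and not part of
# the library: subclass an existing base operator and implement `process`.
class AddExclamation(InstanceOperator):
    """Appends '!' to the value of field 'text' in every instance (sketch)."""

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance["text"] = str(instance["text"]) + "!"
        return instance
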
class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from a dict of named iterables.

    Example:
        operator = FromIterables()
        ms = operator.process(iterables)

    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        return MultiStream.from_iterables(iterables)


class IterableSource(SourceOperator):
    """Creates a MultiStream from a dict of named iterables.

    It is a callable.

    Args:
        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.

    Example:
        operator = IterableSource(input_dict)
        ms = operator()

    """

    iterables: Dict[str, Iterable]

    def process(self) -> MultiStream:
        return MultiStream.from_iterables(self.iterables)

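
# Usage sketch (hypothetical in-memory streams): both operators above yield a
# MultiStream whose streams are iterated lazily, e.g.:
# ms = IterableSource(iterables={"train": [{"x": 1}], "test": [{"x": 2}]})()
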
class MapInstanceValues(InstanceOperator):
    """A class used to map instance values into other values.

    This class is a type of ``InstanceOperator``;
    it maps values of instances in a stream using predefined mappers.

    Args:
        mappers (Dict[str, Dict[str, Any]]):
            The mappers to use for mapping instance values.
            Keys are the names of the fields to undergo mapping, and values are dictionaries
            that define the mapping from old values to new values.
            Note that mapped values are defined by their string representation, so mapped values
            are converted to strings before being looked up in the mappers.
        strict (bool):
            If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool):
            If True, all fields to be mapped should be lists, and the mapping
            is to be applied to their individual elements.
            If False, mapping is only applied to a field containing a single value.

    Examples:
        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remained intact,
        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being looked
        up in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
        Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
        each such ``"1"`` with ``"hi"`` and each ``"2"`` with ``"bye"`` in all instances of all streams:
        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
        To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
        Input instance ``{"a": "3", "b": 2}`` will raise an exception per the above call,
        because ``"3"`` is not a key in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
        replaces the list ``[1,2,3,4]`` with the string ``"All"`` and an empty list with the string ``"None"``.

    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
                    )
                if isinstance(value, list) and self.process_every_value:
                    for i, val in enumerate(value):
                        value[i] = self.get_mapped_value(instance, key, mapper, val)
                else:
                    value = self.get_mapped_value(instance, key, mapper, value)
                dict_set(
                    instance,
                    key,
                    value,
                )

        return instance

    def get_mapped_value(self, instance, key, mapper, val):
        val_as_str = str(val)  # make sure the value is a string
        if val_as_str in mapper:
            return recursive_copy(mapper[val_as_str])
        if self.strict:
            raise KeyError(
                f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
            )
        return val

class FlattenInstances(InstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
        sep (str): The separator to use when concatenating nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)

class Set(InstanceOperator):
    """Sets specified fields in each instance, in a given stream or all streams (default), with specified values. If a field exists, its value is updated; if it does not exist, it is added.

    Args:
        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields.

        use_deepcopy (bool): Deep copy the input value to avoid later modifications.

    Examples:
        # Set a value of a list consisting of "positive" and "negative" to field "classes" of each and every instance of all streams
        ``Set(fields={"classes": ["positive", "negative"]})``

        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
        ``Set(fields={"span/start": 0})``

        # In all instances of stream "train" only, set field "classes" to have the value of a list consisting of "positive" and "negative"
        ``Set(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])``

        # Set field "classes" to have the value of a given list, preventing modification of the original list from changing the instance
        ``Set(fields={"classes": alist}, use_deepcopy=True)``  if alist is later modified, the instances still remain intact.
    """

    fields: Dict[str, object]
    use_query: Optional[bool] = None
    use_deepcopy: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, value in self.fields.items():
            if self.use_deepcopy:
                value = deep_copy(value)
            dict_set(instance, key, value)
        return instance


def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
    """Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.

    Args:
        data: The data structure (dict or list) to traverse.
        target_key: The specific key whose value needs to be checked and replaced or removed.
        value_map: A dictionary mapping old values to new values.
        value_remove: A list of values to completely remove if found as values of target_key.

    Returns:
        The modified data structure. Modification is done in-place.
    """
    if value_remove is None:
        value_remove = []

    if isinstance(data, dict):
        keys_to_delete = []
        for key, value in data.items():
            if key == target_key:
                if isinstance(value, list):
                    data[key] = [
                        value_map.get(item, item)
                        for item in value
                        if not isinstance(item, dict) and item not in value_remove
                    ]
                elif isinstance(value, dict):
                    pass  # Skip or handle dict values if needed
                elif value in value_remove:
                    keys_to_delete.append(key)
                elif value in value_map:
                    data[key] = value_map[value]
            else:
                recursive_key_value_replace(value, target_key, value_map, value_remove)
        for key in keys_to_delete:
            del data[key]
    elif isinstance(data, list):
        for item in data:
            recursive_key_value_replace(item, target_key, value_map, value_remove)
    return data


class RecursiveReplace(InstanceOperator):
    """Recursively replaces, anywhere in the instance, the values appearing under 'key', according to 'map_values', and removes those listed in 'remove_values'."""

    key: str
    map_values: dict
    remove_values: Optional[list] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return recursive_key_value_replace(
            instance, self.key, self.map_values, self.remove_values
        )
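
# Example (hypothetical instance): RecursiveReplace(key="type",
# map_values={"int": "integer"}, remove_values=["null"]) turns
# {"tool": {"type": "int"}, "other": {"type": "null"}} into
# {"tool": {"type": "integer"}, "other": {}}.
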


@deprecation(version="2.0.0", alternative=Set)
class AddFields(Set):
    pass


class RemoveFields(InstanceOperator):
    """Remove specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance


class SelectFields(InstanceOperator):
    """Keep only specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to keep from each instance.
    """

    fields: List[str]

    def prepare(self):
        super().prepare()
        self.fields.extend(["data_classification_policy", "recipe_metadata"])

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        new_instance = {}
        for selected_field in self.fields:
            new_instance[selected_field] = instance[selected_field]
        return new_instance


class DefaultPlaceHolder:
    pass


default_place_holder = DefaultPlaceHolder()

class InstanceFieldOperator(InstanceOperator):
    """A general stream instance operator that processes the values of a field (or multiple ones).

    Args:
        field (Optional[str]):
            The field to process, if only a single one is passed. Defaults to None.
        to_field (Optional[str]):
            Field name to save the result into, if only one field is processed. If None is passed, the
            operation happens in-place and its result replaces the value of ``field``. Defaults to None.
        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
            Mapping from names of fields to process,
            to names of fields to save the results into. Inner List, if used, should be of length 2.
            A field is processed by feeding its value into method ``process_value`` and storing the result in the ``to_field`` that
            is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field appears both as a key and as a value in different mappings.
            The operator throws an AssertionError in either of these cases. ``field_to_field``
            defaults to None.
        process_every_value (bool):
            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False.

    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field``) are equal (or share a common
    prefix if ``field`` and ``to_field`` contain a /), then the result of the operation is saved within ``field``.

    """

    field: Optional[str] = None
    to_field: Optional[str] = None
    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
    use_query: Optional[bool] = None
    process_every_value: bool = False
    get_default: Any = None
    not_exist_ok: bool = False
    not_exist_do_nothing: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify_field_definition(self):
        if hasattr(self, "_field_to_field") and self._field_to_field is not None:
            return
        assert (
            (self.field is None) != (self.field_to_field is None)
        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"

        if self.field_to_field is None:
            self._field_to_field = [
                (self.field, self.to_field if self.to_field is not None else self.field)
            ]
        else:
            self._field_to_field = (
                list(self.field_to_field.items())
                if isinstance(self.field_to_field, dict)
                else self.field_to_field
            )
        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
        assert (
            self._field_to_field is not None
        ), f"the from and to fields must be defined or implied from the other inputs, got: {self._field_to_field}"
        assert (
            len(self._field_to_field) > 0
        ), f"Input argument '{self.__class__.__name__}.field_to_field' should convey at least one field to process. Got {self.field_to_field}"
        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
        if self.field_to_field is None:
            return
        # for backward compatibility also allow list of tuples of two strings
        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
            self.field_to_field, List[Tuple[str, str]]
        ):
            for pair in self._field_to_field:
                assert (
                    len(pair) == 2
                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
            # order of field processing is uniquely determined by the input field_to_field when a list
            return
        if isoftype(self.field_to_field, Dict[str, str]):
            if len(self.field_to_field) < 2:
                return
            for ff, tt in self.field_to_field.items():
                for f, t in self.field_to_field.items():
                    if f == ff:
                        continue
                    assert (
                        t != ff
                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
                    assert (
                        tt != t
                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
            return
        raise ValueError(
            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
        )

    @abstractmethod
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        pass

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        self.verify_field_definition()
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
                    instance,
                    from_field,
                    default=default_place_holder,
                    not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
                )
                if old_value is default_place_holder:
                    if self.not_exist_do_nothing:
                        continue
                    old_value = self.get_default
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from instance due to the exception above."
                ) from e
            try:
                if self.process_every_value:
                    new_value = [
                        self.process_instance_value(value, instance)
                        for value in old_value
                    ]
                else:
                    new_value = self.process_instance_value(old_value, instance)
            except Exception as e:
                raise ValueError(
                    f"Failed to process field '{from_field}' from instance due to the exception above."
                ) from e
            dict_set(
                instance,
                to_field,
                new_value,
                not_exist_ok=True,
            )
        return instance


class FieldOperator(InstanceFieldOperator):
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        return self.process_value(value)

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass

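
# Sketch of a concrete FieldOperator (hypothetical, for illustration only):
# implement process_value and inherit all field / field_to_field routing
# from InstanceFieldOperator.
class Lowercase(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return str(value).lower()
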


class MapValues(FieldOperator):
    """Maps the value of the processed field, by its string representation, to a new value per 'mapping'."""

    mapping: Dict[str, str]

    def process_value(self, value: Any) -> Any:
        return self.mapping[str(value)]


class Rename(FieldOperator):
    """Renames fields.

    Moves the value from one field to another, potentially, if a field name contains a /, from one branch into another.
    The source field is removed, or, in case of a / in from_field, the relevant part of it is removed.

    Examples:
        Rename(field_to_field={"b": "c"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]

        Rename(field_to_field={"b": "c/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]

        Rename(field_to_field={"b": "b/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]

        Rename(field_to_field={"b/c/e": "b/d"})
        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]

    """

    def process_value(self, value: Any) -> Any:
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        res = super().process(instance=instance, stream_name=stream_name)
        for from_field, to_field in self._field_to_field:
            if (not is_subpath(from_field, to_field)) and (
                not is_subpath(to_field, from_field)
            ):
                dict_delete(res, from_field, remove_empty_ancestors=True)

        return res


@deprecation(version="2.0.0", alternative=Rename)
class RenameFields(Rename):
    pass


class AddConstant(FieldOperator):
    """Adds a constant, being argument 'add', to the processed value.

    Args:
        add: the constant to add.
    """

    add: Any

    def process_value(self, value: Any) -> Any:
        return self.add + value


class ShuffleFieldValues(FieldOperator):
    """Shuffles a list of values found in a field."""

    def process_value(self, value: Any) -> Any:
        res = list(value)
        random_generator = new_random_generator(sub_seed=res)
        random_generator.shuffle(res)
        return res


class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """

    separator: str = ","

    def process_value(self, value: Any) -> Any:
        return self.separator.join(str(x) for x in value)

class Apply(InstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (str): name of the function.
        to_field (str): the field to store the result in.

    Any additional arguments are field names whose values will be passed directly to the function specified.

    Examples:
    Store in field "b" the uppercase string of the value in field "a":
    ``Apply("a", function=str.upper, to_field="b")``

    Dump the json representation of field "t" and store back in the same field:
    ``Apply("t", function=json.dumps, to_field="t")``

    Set the time in a field 'b':
    ``Apply(function=time.time, to_field="b")``

    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        parts = []

        if hasattr(function, "__module__"):
            parts.append(function.__module__)
        if hasattr(function, "__qualname__"):
            parts.append(function.__qualname__)
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        parts = function_str.split(".", 1)
        if len(parts) == 1:
            return __builtins__[parts[0]]

        module_name, function_name = parts
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = __import__(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance

class ListFieldValues(InstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))

        dict_set(instance, self.to_field, values)

        return instance


class ZipFieldValues(InstanceOperator):
    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.

    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
    are zipped, and stored into 'to_field'.

    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.

    """

    fields: List[str]
    to_field: str
    longest: bool = False
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))
        if self.longest:
            zipped = zip_longest(*values)
        else:
            zipped = zip(*values)
        dict_set(instance, self.to_field, list(zipped))
        return instance
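
# Example (hypothetical instance): ZipFieldValues(fields=["a", "b"], to_field="ab")
# turns {"a": [1, 2], "b": ["x", "y"]} into {..., "ab": [(1, "x"), (2, "y")]}.
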


class InterleaveListsToDialogOperator(InstanceOperator):
    """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".

    The list of tuples is of the format (role, turn_content), where the role label is specified by
    the 'user_role_label' and 'assistant_role_label' fields (defaulting to "user" and "assistant").

    The user-turns field and the assistant-turns field are specified in the arguments.
    The value of each of these fields is assumed to be a list.

    """

    user_turns_field: str
    assistant_turns_field: str
    user_role_label: str = "user"
    assistant_role_label: str = "assistant"
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        user_turns = instance[self.user_turns_field]
        assistant_turns = instance[self.assistant_turns_field]

        assert (
            len(user_turns) == len(assistant_turns)
            or (len(user_turns) - len(assistant_turns) == 1)
        ), "user_turns must have either the same length as assistant_turns or one more turn."

        interleaved_dialog = []
        i, j = 0, 0  # Indices for the user and assistant lists
        # While either list has elements left, continue interleaving
        while i < len(user_turns) or j < len(assistant_turns):
            if i < len(user_turns):
                interleaved_dialog.append((self.user_role_label, user_turns[i]))
                i += 1
            if j < len(assistant_turns):
                interleaved_dialog.append(
                    (self.assistant_role_label, assistant_turns[j])
                )
                j += 1

        instance[self.to_field] = interleaved_dialog
        return instance
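
# Example (hypothetical instance): with user turns ["hi", "thanks"] and
# assistant turns ["hello!"], the interleaved result is
# [("user", "hi"), ("assistant", "hello!"), ("user", "thanks")].
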


class IndexOf(InstanceOperator):
    """For a given instance, finds the offset of the value of field 'index_of' within the value of field 'search_in'."""

    search_in: str
    index_of: str
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in)
        item = dict_get(instance, self.index_of)
        instance[self.to_field] = lst.index(item)
        return instance


class TakeByField(InstanceOperator):
    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

    field: str
    index: str
    to_field: str = None
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def prepare(self):
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field)
        index_value = dict_get(instance, self.index)
        instance[self.to_field] = value[index_value]
        return instance
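
# Example (hypothetical instance): TakeByField(field="options", index="answer_idx",
# to_field="answer") turns {"options": ["yes", "no"], "answer_idx": 1}
# into {"options": ["yes", "no"], "answer_idx": 1, "answer": "no"}.
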


class Perturb(FieldOperator):
    """Slightly perturbs the contents of ``field``. Can be handy for imitating a prediction from a given target.

    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
    relevant perturbation.

    Args:
        percentage_to_perturb (int):
            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent).
        select_from (List[Any]):
            a list of values to select from, as a perturbation of the field's value. Defaults to [].
    """

    select_from: List[Any] = []
    percentage_to_perturb: int = 1  # 1 percent

    def verify(self):
        assert (
            0 <= self.percentage_to_perturb and self.percentage_to_perturb <= 100
        ), f"'percentage_to_perturb' should be in the range 0..100. Received {self.percentage_to_perturb}"

    def prepare(self):
        super().prepare()
        self.random_generator = new_random_generator(sub_seed="CopyWithPerturbation")

    def process_value(self, value: Any) -> Any:
        perturb = self.random_generator.randint(1, 100) <= self.percentage_to_perturb
        if not perturb:
            return value

        if value in self.select_from:
            # in 80% of the cases, return another decent class; otherwise, perturb the value itself as follows
            if self.random_generator.random() < 0.8:
                return self.random_generator.choice(self.select_from)

        if isinstance(value, float):
            return value * (0.5 + self.random_generator.random())

        if isinstance(value, int):
            perturb = 1 if self.random_generator.random() < 0.5 else -1
            return value + perturb

        if isinstance(value, str):
            if len(value) < 2:
                # give up perturbation
                return value
            # throw one char out
            prefix_len = self.random_generator.randint(1, len(value) - 1)
            return value[:prefix_len] + value[prefix_len + 1 :]

        # and in any other case:
        return value


class Copy(FieldOperator):
    """Copies values from specified fields to specified fields.

    Args (of parent class):
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.

    Examples:
        An input instance {"a": 2, "b": 3}, when processed by
        ``Copy(field_to_field={"a": "b"})``
        would yield {"a": 2, "b": 2}, and when processed by
        ``Copy(field_to_field={"a": "c"})`` would yield
        {"a": 2, "b": 3, "c": 2}

        With field names containing /, we can also copy inside the field:
        ``Copy(field="a/0", to_field="a")``
        would process instance {"a": [1, 3]} into {"a": 1}


    """

    def process_value(self, value: Any) -> Any:
        return value


class RecursiveCopy(FieldOperator):
    """Copies values from specified fields to specified fields, as a recursive deep copy."""

    def process_value(self, value: Any) -> Any:
        return recursive_copy(value)


@deprecation(version="2.0.0", alternative=Copy)
class CopyFields(Copy):
    pass


class GetItemByIndex(FieldOperator):
    """Gets the item from 'items_list' at the index given by the value of the processed field."""

    items_list: List[Any]

    def process_value(self, value: Any) -> Any:
        return self.items_list[value]


class AddID(InstanceOperator):
    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""

    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance


class Cast(FieldOperator):
    """Casts the value of a field to a specified type.

    Args:
        to (str): the name of the type to cast to: "int", "float", "str", "bool", or "tuple".
        failure_default (Optional[Any]): a default value to use upon casting failure. If not provided, a casting failure raises a ValueError.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        self.types = {
            "int": int,
            "float": float,
            "str": str,
            "bool": bool,
            "tuple": tuple,
        }

    def process_value(self, value):
        try:
            return self.types[self.to](value)
        except ValueError as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default
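
# Example (sketch): Cast(field="a", to="int", failure_default=0) maps
# {"a": "7"} to {"a": 7}, and {"a": "x"} to {"a": 0}.
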


class CastFields(InstanceOperator):
    """Casts specified fields to specified types.

    Args:
        fields (Dict[str, str]):
            A dictionary mapping field names to the names of the types to cast the fields to,
            e.g.: "int", "str", "float", "bool". Basic names of types.
        failure_defaults (Dict[str, object]):
            A dictionary mapping field names to default values for cases of casting failure.
        process_every_value (bool):
            If true, all fields involved must contain lists, and each value in the list is then cast. Defaults to False.

    Example:
        .. code-block:: python

                CastFields(
                    fields={"a/d": "float", "b": "int"},
                    failure_defaults={"a/d": 0.0, "b": 0},
                    process_every_value=True,
                )

    would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
    into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.

    """

    fields: Dict[str, str] = field(default_factory=dict)
    failure_defaults: Dict[str, object] = field(default_factory=dict)
    use_nested_query: bool = None  # deprecated field
    process_every_value: bool = False

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def verify(self):
        super().verify()
        if self.use_nested_query is not None:
            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        return [self._cast_single(value, type, field) for value in values]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name)
            if self.process_every_value:
                assert isinstance(
                    value, list
                ), f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)

            dict_set(instance, field_name, casted_value)
        return instance


class DivideAllFieldsBy(InstanceOperator):
    """Recursively reach down to all fields that are float, and divide each by 'divisor'.

    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
    whose leaves are either 'float', and then divided, or of another basic type, in which case a ValueError is raised
    if input flag 'strict' is True, or they are left alone if 'strict' is False.

    Args:
        divisor (float): the value to divide by.
        strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.

    Example:
        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
        operator = DivideAllFieldsBy(divisor=2.0)
        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
        If the operator were defined with strict=True, through:
        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
        the processing of the above instance would raise a ValueError, for the integer at "c".
    """

    divisor: float = 1.0
    strict: bool = False

    def _recursive_divide(self, instance, divisor):
        if isinstance(instance, dict):
            for key, value in instance.items():
                instance[key] = self._recursive_divide(value, divisor)
        elif isinstance(instance, list):
            for i, value in enumerate(instance):
                instance[i] = self._recursive_divide(value, divisor)
        elif isinstance(instance, float):
            instance /= divisor
        elif self.strict:
            raise ValueError(f"Cannot divide instance of type {type(instance)}")
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return self._recursive_divide(instance, self.divisor)


class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Attributes:
        _artifacts_cache (LRUCache): A cache for storing fetched artifacts.
    """

    _artifacts_cache = LRUCache(max_size=1000)

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        if str(artifact_identifier) not in cls._artifacts_cache:
            artifact, catalog = fetch_artifact(artifact_identifier)
            cls._artifacts_cache[str(artifact_identifier)] = artifact
        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])


class ApplyOperatorsField(InstanceOperator):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        operators_field (str): name of the field that contains a single name, or a list of names, of the operators to be applied,
            one after the other, for the processing of the instance. Each operator is equipped with a 'process_instance()'
            method.

        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"prediction": 111, "references": [222, 333], "c": ["processors.to_string", "processors.first_character"]}
        is processed by operator (please look these operators up in the catalog; they are tuned to process fields "prediction" and
        "references"):
        operator = ApplyOperatorsField(operators_field="c"),
        the resulting instance is: {"prediction": "1", "references": ["2", "3"], "c": ["processors.to_string", "processors.first_character"]}

    """

    operators_field: str
    default_operators: List[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]
        # otherwise, operator_names is already a list

        # we now have a list of names of operators, each equipped with a process_instance method.
        operator = SequentialOperator(steps=operator_names)
        return operator.process_instance(instance, stream_name=stream_name)


class FilterByCondition(StreamOperator):
    """Filters a stream, yielding only instances in which the values in the required fields follow the required condition operator.

    Raises an error if a required field name is missing from the input instance.

    Args:
       values (Dict[str, Any]): Field names and respective values that instances must match, according to the condition, to be included in the output.

       condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")

       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
       | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
       | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose third element is ``<= 4``


    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be a list but is not: '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance."
                ) from ve
            if self.condition == "in":
                if instance_key not in value:
                    return False
            elif self.condition == "not in":
                if instance_key in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, value):
                    return False
        return True

1250
class FilterByConditionBasedOnFields(FilterByCondition):
    """Filters a stream based on a condition between the values of two fields.

    Raises an error if either of the required field names is missing from the input instance.

    Args:
       values (Dict[str, str]): The field names that the filter operation is based on.
       condition: the name of the desired condition operator between the specified fields' values. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")
       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "gt") will yield only instances where field "a" contains a value greater than the value in field "b".
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "le") will yield only instances where "a"<="b"
    """

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance"
                ) from ve
            try:
                instance_value = dict_get(instance, value)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{value}') in FilterByCondition is not found in instance"
                ) from ve
            if self.condition == "in":
                if instance_key not in instance_value:
                    return False
            elif self.condition == "not in":
                if instance_key in instance_value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, instance_value):
                    return False
        return True


class ComputeExpressionMixin(Artifact):
    """Computes an expression expressed over fields of an instance.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # can not do the imports here, because object does not pickle with imports
        self.globals = {
            module_name: __import__(module_name) for module_name in self.imports_list
        }

    def compute_expression(self, instance: dict) -> Any:
        if settings.allow_unverified_code:
            return eval(self.expression, {**self.globals, **instance})

        raise ValueError(
            f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
            "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
        )


class FilterByExpression(StreamOperator, ComputeExpressionMixin):
    """Filters a stream, yielding only instances which fulfil a condition specified as a string to be python's eval-uated.

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        expression (str):
            a condition over fields of the instance, to be processed by python's eval()
        imports_list (List[str]):
            names of imports needed for the eval of the query (e.g. 're', 'json')
        error_on_filtered_all (bool, optional):
            If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" is at most 4 and the value of field "b" is greater than 5
        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
    """

    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self.compute_expression(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )


class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
    """Compute an expression, specified as a string to be eval-uated, over the instance's fields, and store the result in field to_field.

    Raises an error if a field mentioned in the query is missing from the instance.

    Args:
       expression (str): an expression to be evaluated over the fields of the instance
       to_field (str): the field where the result is to be stored
       imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')

    Examples:
       When instance {"a": 2, "b": 3} is process-ed by operator
       ExecuteExpression(expression="a+b", to_field = "c")
       the result is {"a": 2, "b": 3, "c": 5}

       When instance {"a": "hello", "b": "world"} is process-ed by operator
       ExecuteExpression(expression = "a+' '+b", to_field = "c")
       the result is {"a": "hello", "b": "world", "c": "hello world"}
    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = self.compute_expression(instance)
        return instance


class ExtractMostCommonFieldValues(MultiStreamOperator):
    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    """
    Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them
    as a list in a new field ('to_field') in all streams.

    More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, trim the list from the bottom, so that the total frequency of
    the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes
    less than 'min_frequency_percent' of the total number of instances in the stream.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to deviate from its default value.

    Examples:

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
    every instance in all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
    in case that field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
    value members in these lists, and report the most frequent values.
    If process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
    'to_field' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
    and stores them in field 'classes' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
    extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    def verify(self):
        assert (
            self.overall_top_frequency_percent <= 100
            and self.overall_top_frequency_percent >= 0
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to deviate from its default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        counter = Counter()
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or is a list but process_every_value == False : view contents of 'field' as one entity whose occurrences are counted.
                counter.update(
                    [(*instance[self.field],)]
                    if isinstance(instance[self.field], list)
                    else [instance[self.field]]
                )  # convert to a tuple if list, to enable the use of Counter which would not accept
                # a list as a hashable entity to count its occurrences
            else:
                # content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                counter.update(instance[self.field])
        # here counter counts occurrences of individual values, or tuples.
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = (
                sum(counter.values()) * self.overall_top_frequency_percent / 100.0
            )
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * sum(counter.values()) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = Set(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)


class ExtractFieldValues(ExtractMostCommonFieldValues):
    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0


class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list) - list to intersect with.
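
    Example (an illustrative sketch of the behavior of process_value below):
        ``Intersect(field="label", allowed_values=["b", "f"])``
        applied to instance ``{"label": ["a", "b", "f"]}`` yields ``{"label": ["b", "f"]}``.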
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]


class IntersectCorrespondingFields(InstanceOperator):
    """Intersects the value of a field, which must be a list, with a given list, and removes corresponding elements from other list fields.

    For example:

    Assume the instances contain a field 'label' and a field 'position' holding the labels' corresponding positions in the text.

    .. code-block:: text

        IntersectCorrespondingFields(field="label",
                                    allowed_values=["b", "f"],
                                    corresponding_fields_to_intersect=["position"])

    would keep only the "b" and "f" values in the 'label' field and
    their respective values in the 'position' field.
    (All other fields are not affected)

    .. code-block:: text

        Given this input:

        [
            {"label": ["a", "b"],"position": [0,1],"other" : "not"},
            {"label": ["a", "c", "d"], "position": [0,1,2], "other" : "relevant"},
            {"label": ["a", "b", "f"], "position": [0,1,2], "other" : "field"}
        ]

        the output would be:
        [
                {"label": ["b"], "position":[1],"other" : "not"},
                {"label": [], "position": [], "other" : "relevant"},
                {"label": ["b", "f"],"position": [1,2], "other" : "field"},
        ]

    Args:
        field - the field to intersect (must contain list values)
        allowed_values (list) - list of values to keep
        corresponding_fields_to_intersect (list) - additional list fields from which values
        are removed based on the corresponding indices of values removed from 'field'
    """

    field: str
    allowed_values: List[str]
    corresponding_fields_to_intersect: List[str]

    def verify(self):
        super().verify()

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but of type '{type(self.allowed_values)}'"
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.field not in instance:
            raise ValueError(
                f"Field '{self.field}' is not in provided instance.\n"
                + to_pretty_string(instance)
            )

        for corresponding_field in self.corresponding_fields_to_intersect:
            if corresponding_field not in instance:
                raise ValueError(
                    f"Field '{corresponding_field}' is not in provided instance.\n"
                    + to_pretty_string(instance)
                )

        if not isinstance(instance[self.field], list):
            raise ValueError(
                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
                + to_pretty_string(instance, keys=[self.field])
            )

        num_values_in_field = len(instance[self.field])

        if set(self.allowed_values) == set(instance[self.field]):
            return instance

        indices_to_keep = [
            i
            for i, value in enumerate(instance[self.field])
            if value in set(self.allowed_values)
        ]

        result_instance = {}
        for field_name, field_value in instance.items():
            if (
                field_name in self.corresponding_fields_to_intersect
                or field_name == self.field
            ):
                if not isinstance(field_value, list):
                    raise ValueError(
                        f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
                    )
                if len(field_value) != num_values_in_field:
                    raise ValueError(
                        f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
                        + to_pretty_string(instance, keys=[self.field, field_name])
                    )
                result_instance[field_name] = [
                    value
                    for index, value in enumerate(field_value)
                    if index in indices_to_keep
                ]
            else:
                result_instance[field_name] = field_value
        return result_instance


class RemoveValues(FieldOperator):
    """Removes elements in a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list) - values to be removed.
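
    Example (an illustrative sketch of the behavior of process_value below):
        ``RemoveValues(field="label", unallowed_values=["b"])``
        applied to instance ``{"label": ["a", "b", "c"]}`` yields ``{"label": ["a", "c"]}``.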
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]


class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
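
    Example (an illustrative sketch; the reducer returns a list of value tuples, order not guaranteed):
        ``Unique(fields=["a"])`` applied to stream ``[{"a": 1, "b": 2}, {"a": 1, "b": 3}, {"a": 2, "b": 4}]``
        reduces it to ``[(1,), (2,)]``.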
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> Stream:
        seen = set()
        for instance in stream:
            values = self.to_tuple(instance, self.fields)
            if values not in seen:
                seen.add(values)
        return list(seen)


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
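
    Example (an illustrative sketch; the exact stream names are produced by nested_tuple_to_string in process below):
        ``SplitByValue(fields=["label"])`` applied to a MultiStream whose "train" stream holds
        labels "pos" and "neg" yields streams named along the lines of "train_pos" and "train_neg".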
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        uniques = Unique(fields=self.fields)(multi_stream)

        result = {}

        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)


class SplitByNestedGroup(MultiStreamOperator):
    """Splits a MultiStream that is small (e.g., for metrics, where the whole stream fits in memory) by the value of field 'group'.

    Args:
        number_of_fusion_generations: int

    The value in field 'group' is of the form "source_n/source_n-1/..." describing the sources in which the instance sat
    when these were fused, potentially over several phases of fusion. The name of the most recent source sits first in this value.
    (See BaseFusion and its extensions.)
    number_of_fusion_generations specifies the length of the prefix by which to split the stream.
    E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream affects the splitting.
    For number_of_fusion_generations = -1, take the whole history written in this field, ignoring the number of generations.
    """

    field_name_of_group: str = "group"
    number_of_fusion_generations: int = 1

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)

        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.field_name_of_group not in instance:
                    raise ValueError(
                        f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
                    )
                signature = (
                    stream_name
                    + "~"  # a sign that does not show within group values
                    + (
                        "/".join(
                            instance[self.field_name_of_group].split("/")[
                                : self.number_of_fusion_generations
                            ]
                        )
                        if self.number_of_fusion_generations >= 0
                        # for values with a smaller number of generations - take up to their last generation
                        else instance[self.field_name_of_group]
                        # for each instance - take all its generations
                    )
                )
                result[signature].append(instance)

        return MultiStream.from_iterables(result)


class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
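
    Example (an illustrative sketch; the operator name is a hypothetical artifact name):
        if each instance carries ``{"postprocessors": ["processors.take_first_word"]}``, then
        ``ApplyStreamOperatorsField(field="postprocessors")`` fetches each named operator
        and applies it to the stream, in the listed order.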
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):
            operators = [operators]

        if self.reversed:
            operators = list(reversed(operators))

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({stream_name: stream}))[stream_name]

        yield from stream


def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
    for instance, score in zip(stream, scores):
        instance["score"] = recursive_copy(score)
        yield instance


class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricsList

        # to be populated only when two or more metrics
        accumulated_scores = []

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' of instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        metrics_list = []
        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            if isinstance(metric, MetricsList):
                metrics_list.extend(list(metric.items))
            elif isinstance(metric, Metric):
                metrics_list.append(metric)
            else:
                raise ValueError(
                    f"Operator {metric_name} must be a Metric or MetricsList"
                )

        for metric in metrics_list:
            if not self.calc_confidence_intervals:
                metric.disable_confidence_interval_calculation()
        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metrics_list = list(reversed(metrics_list))

        for i, metric in enumerate(metrics_list):
            if i == 0:  # first metric
                multi_stream = MultiStream({"tmp": stream})
            else:  # metrics with previous scores
                reusable_generator = ReusableGenerator(
                    generator=update_scores_of_stream_instances,
                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
                )
                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})

            multi_stream = metric(multi_stream)

            if i < len(metrics_list) - 1:  # not the last metric
                accumulated_scores = []
                for inst in multi_stream["tmp"]:
                    accumulated_scores.append(recursive_copy(inst["score"]))

        yield from multi_stream["tmp"]


class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        streams_to_merge (List[str]): The names of the streams to merge. If None, all streams are merged.
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
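
    Example (an illustrative sketch, using the defaults below):
        ``MergeStreams()`` applied to a MultiStream with streams "train" and "test" yields a single
        stream named "all", in which each instance carries its source stream name in field "origin".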
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream) -> Generator:
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: DynamicStream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )


class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page


class FeatureGroupedShuffle(Shuffle):
    """Class for shuffling an input dataset by instance 'blocks', not on the individual instance level.

    An example is a dataset consisting of questions, each with paraphrases, where each question falls into a topic.
    All paraphrases have the same ID value as the original.
    In this case, we may want to shuffle on grouping_features = ['question ID'],
    to keep the paraphrases and original question together.
    We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
    In this case, grouping_features = ['question ID', 'topic']

    Args:
        grouping_features (list of strings): list of feature names to use to define the groups.
            a group is defined by each unique observed combination of data values for features in grouping_features
        shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
            Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
            separately by page (determined by page_size). If a block of instances in the same group is split
            into separate pages (either by a page break falling in the group, or because the dataset was not sorted by
            grouping_features), these instances will be shuffled separately and thus the grouping may be
            broken up by pages. If the user wants to ensure the shuffle does the grouping and shuffling
            across all pages, set the page_size to be larger than the dataset size.
            See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
    """

    grouping_features: List[str] = None
    shuffle_within_group: bool = False

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        if self.grouping_features is None:
            yield from super().process(page, stream_name)
        else:
            yield from self.shuffle_by_grouping_features(page)

    def shuffle_by_grouping_features(self, page):
        import itertools
        from collections import defaultdict

        groups_to_instances = defaultdict(list)
        for item in page:
            groups_to_instances[
                tuple(item[ff] for ff in self.grouping_features)
            ].append(item)
        # now extract the groups (i.e., lists of dicts with order preserved)
        page_blocks = list(groups_to_instances.values())
        # and now shuffle the blocks
        self.random_generator.shuffle(page_blocks)
        if self.shuffle_within_group:
            blocks = []
            # reshuffle the instances within each block, but keep the blocks in order
            for block in page_blocks:
                self.random_generator.shuffle(block)
                blocks.append(block)
            page_blocks = blocks

        # now flatten the list so it consists of individual dicts, but in (randomized) block order
        return list(itertools.chain(*page_blocks))


class EncodeLabels(InstanceOperator):
    """Encode each value encountered in any field in 'fields' into the integers 0,1,...

    Encoding is determined by a str->int map that is built on the go, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example:
        applying ``EncodeLabels(fields = ["a", "b/*"])``
        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
        {"a": "blue", "b": ["green"], "c":"water"}]``   will yield the
        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``

        Note: dict_utils are applied here, and hence, fields that are lists, should be included in
        input 'fields' with the appendix ``"/*"``  as in the above example.

    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name)
            values_was_a_list = isinstance(values, list)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            if not values_was_a_list:
                new_values = new_values[0]
            dict_set(
                instance,
                field_name,
                new_values,
                not_exist_ok=False,  # the values to encode were just taken from there
                set_multiple="*" in field_name
                and isinstance(new_values, list)
                and len(new_values) > 0,
            )

        return instance


class StreamRefiner(StreamOperator):
    """Discard from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int)
        apply_to_streams (optional, list(str)):
            names of streams to refine.

    Examples:
        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
        ``StreamRefiner(max_instances=4)``
        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream


class Deduplicate(StreamOperator):
    """Deduplicate the stream based on the given fields.

    Args:
        by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness.

    Examples:
        >>> dedup = Deduplicate(by=["field1", "field2"])
    """

    by: List[str]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        seen = set()

        for instance in stream:
            # Compute a lightweight hash for the signature
            signature = hash(str(tuple(dict_get(instance, field) for field in self.by)))

            if signature not in seen:
                seen.add(signature)
                yield instance


class Balance(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in specified input 'fields'.
    By discarding instances from the input stream, DeterministicBalancer maintains an equal number of instances for all signatures.
    When input 'max_instances' is also specified, DeterministicBalancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is as few as possible.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            overall max.

    Usage:
        ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
        ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
        ``DeterministicBalancer(fields=["a"])``
        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = Counter()

        for instance in stream:
            sign = self.signature(instance)
            if counter[sign] < max_total_instances_per_sign:
                counter[sign] += 1
                yield instance


class DeterministicBalancer(Balance):
    pass


class MinimumOneExamplePerLabelRefiner(StreamRefiner):
    """A class used to return a specified number of instances, ensuring at least one example per label.

    For each instance, a signature value is constructed from the values of the instance in specified input ``fields``.
    ``MinimumOneExamplePerLabelRefiner`` takes the first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
    from each class first and then the rest. If max_instances is not set, the original stream will be used.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            Number of elements to select. Note that max_instances of StreamRefiners
            that are passed to the recipe (e.g. ``train_refiner``, ``test_refiner``) are overridden
            by the recipe parameters ( ``max_train_instances``, ``max_test_instances``)

    Usage:
        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
        | ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
        the resulting stream will be:
        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is None:
            for instance in stream:
                yield instance
            return

        counter = Counter()
        for instance in stream:
            counter[self.signature(instance)] += 1
        all_keys = counter.keys()
        if len(counter) == 0:
            return

        if self.max_instances is not None and len(all_keys) > self.max_instances:
            raise Exception(
                f"Can not generate a stream with at least one example per label, because the max instances requested {self.max_instances} is smaller than the number of different labels ({len(all_keys)})"
            )

        counter = Counter()
        used_indices = set()
        selected_elements = []
        # select at least one per class
        for idx, instance in enumerate(stream):
            sign = self.signature(instance)
            if counter[sign] == 0:
                counter[sign] += 1
                used_indices.add(idx)
                selected_elements.append(
                    instance
                )  # collect all elements first to allow shuffling of both groups

        # select more to reach self.max_instances examples
        for idx, instance in enumerate(stream):
            if idx not in used_indices:
                if self.max_instances is None or len(used_indices) < self.max_instances:
                    used_indices.add(idx)
                    selected_elements.append(
                        instance
                    )  # collect all elements first to allow shuffling of both groups

        # shuffle elements to avoid having one element from each class appear first
        random_generator = new_random_generator(sub_seed=selected_elements)
        random_generator.shuffle(selected_elements)
        yield from selected_elements


class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]):
            distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the given total length.
            (If none exceeds -- into one index beyond, namely, the length of segments_boundaries)
        fields (Optional, List[str]):
            the total length of the values of these fields goes through the quantization described above


    Example:
        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
        input instances will be counted and balanced against two categories:
        empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
        return i + 1


class DownloadError(Exception):
    def __init__(
        self,
        message,
    ):
        super().__init__(message)


class UnexpectedHttpCodeError(Exception):
    def __init__(self, http_code):
        super().__init__(f"unexpected http code {http_code}")


class DownloadOperator(SideEffectOperator):
    """Operator for downloading a file from a given URL to a specified local path.

    Args:
        source (str):
            URL of the file to be downloaded.
        target (str):
            Local path where the downloaded file should be saved.
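
    Example (an illustrative sketch; the URL and path are hypothetical):
        ``DownloadOperator(source="https://example.com/data.csv", target="/tmp/data.csv")``
        downloads the file and writes it to the target path when process() is invoked.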
    """

    source: str
    target: str

    def process(self):
        try:
            response = requests.get(self.source, allow_redirects=True)
        except Exception as e:
            raise DownloadError(f"Unable to download {self.source}") from e
        if response.status_code != 200:
            raise UnexpectedHttpCodeError(response.status_code)
        with open(self.target, "wb") as f:
            f.write(response.content)


class ExtractZipFile(SideEffectOperator):
    """Operator for extracting files from a zip archive.

    Args:
        zip_file (str):
            Path of the zip file to be extracted.
        target_dir (str):
            Directory where the contents of the zip file will be extracted.
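
    Example (an illustrative sketch; the paths are hypothetical):
        ``ExtractZipFile(zip_file="/tmp/data.zip", target_dir="/tmp/data")``
        unpacks the archive into the target directory when process() is invoked.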
    """

    zip_file: str
    target_dir: str

    def process(self):
        with zipfile.ZipFile(self.zip_file) as zf:
            zf.extractall(self.target_dir)


class DuplicateInstances(StreamOperator):
    """Operator which duplicates each instance in stream a given number of times.

    Args:
        num_duplications (int):
            How many times each instance should be duplicated (1 means no duplication).
        duplication_index_field (Optional[str]):
            If given, then an additional field with the specified name is added to each duplicated instance,
            which contains the id of a given duplication. Defaults to None, so no field is added.
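
    Example (an illustrative sketch of the logic in process below):
        ``DuplicateInstances(num_duplications=2, duplication_index_field="dup_id")``
        applied to ``[{"a": 1}]`` yields ``[{"a": 1, "dup_id": 0}, {"a": 1, "dup_id": 1}]``.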
    """

    num_duplications: int
    duplication_index_field: Optional[str] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            for idx in range(self.num_duplications):
                duplicate = recursive_shallow_copy(instance)
                if self.duplication_index_field:
                    duplicate.update({self.duplication_index_field: idx})
                yield duplicate

    def verify(self):
        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
            raise ValueError(
                f"num_duplications must be an integer equal to or greater than 1. "
                f"Got: {self.num_duplications}."
            )

        if self.duplication_index_field is not None and not isinstance(
            self.duplication_index_field, str
        ):
            raise ValueError(
                f"If given, duplication_index_field must be a string. "
                f"Got: {self.duplication_index_field}"
            )


class CollateInstances(StreamOperator):
    """Operator which collates values from multiple instances to a single instance.

    Each field becomes a list of the values of the corresponding field across the collated `batch_size` instances.

    Attributes:
        batch_size (int)

    Example:
        .. code-block:: text

            CollateInstances(batch_size=2)

            Given inputs = [
                {"a": 1, "b": 2},
                {"a": 2, "b": 2},
                {"a": 3, "b": 2},
                {"a": 4, "b": 2},
                {"a": 5, "b": 2}
            ]

            Returns targets = [
                {"a": [1,2], "b": [2,2]},
                {"a": [3,4], "b": [2,2]},
                {"a": [5], "b": [2]},
            ]
    """

    batch_size: int

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        stream = list(stream)
        for i in range(0, len(stream), self.batch_size):
            batch = stream[i : i + self.batch_size]
            new_instance = {}
            for a_field in batch[0]:
                if a_field == "data_classification_policy":
                    flattened_list = [
                        classification
                        for instance in batch
                        for classification in instance[a_field]
                    ]
                    new_instance[a_field] = sorted(set(flattened_list))
                else:
                    new_instance[a_field] = [instance[a_field] for instance in batch]
            yield new_instance

    def verify(self):
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(
                f"batch_size must be an integer equal to or greater than 1. "
                f"Got: {self.batch_size}."
            )


class CollateInstancesByField(StreamOperator):
    """Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.

    Args:
        by_field str: the name of the field to group data by.
        aggregate_fields list(str): the field names to aggregate into lists.

    Returns:
        A stream of instances grouped and aggregated by the specified field.

    Raises:
        UnitxtError: If non-aggregate fields have inconsistent values.

    Example:
        Collate the instances based on field "category" and aggregate fields "value" and "id".

        .. code-block:: text

            CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])

            given input:
            [
                {"id": 1, "category": "A", "value": 10, "flag": True},
                {"id": 2, "category": "B", "value": 20, "flag": False},
                {"id": 3, "category": "A", "value": 30, "flag": True},
                {"id": 4, "category": "B", "value": 40, "flag": False}
            ]

            the output is:
            [
                {"category": "A", "id": [1, 3], "value": [10, 30], "flag": True},
                {"category": "B", "id": [2, 4], "value": [20, 40], "flag": False}
            ]

        Note that the "flag" field is not aggregated, and must be the same
        in all instances in the same category, or an error is raised.
    """

    by_field: str = NonPositionalField(required=True)
    aggregate_fields: List[str] = NonPositionalField(required=True)

    def prepare(self):
        super().prepare()

    def verify(self):
        super().verify()
        if not isinstance(self.by_field, str):
            raise UnitxtError(
                f"The 'by_field' value is not a string but '{type(self.by_field)}'"
            )

        if not isinstance(self.aggregate_fields, list):
            raise UnitxtError(
                f"The 'aggregate_fields' value is not a list but '{type(self.aggregate_fields)}'"
            )

    def process(self, stream: Stream, stream_name: Optional[str] = None):
        grouped_data = {}

        for instance in stream:
            if self.by_field not in instance:
                raise UnitxtError(
                    f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
                )
            for k in self.aggregate_fields:
                if k not in instance:
                    raise UnitxtError(
                        f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
                    )
            key = instance[self.by_field]

            if key not in grouped_data:
                grouped_data[key] = {
                    k: v for k, v in instance.items() if k not in self.aggregate_fields
                }
                # Add empty lists for the fields to aggregate
                for agg_field in self.aggregate_fields:
                    if agg_field in instance:
                        grouped_data[key][agg_field] = []

            for k, v in instance.items():
                # Merge the classification policy lists across instances with the same key
                if k == "data_classification_policy" and instance[k]:
                    grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
                # Check consistency of all non-aggregated fields
                elif k != self.by_field and k not in self.aggregate_fields:
                    if k in grouped_data[key] and grouped_data[key][k] != v:
                        raise ValueError(
                            f"Inconsistent value for field '{k}' in group '{key}': "
                            f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
                        )
                # Aggregate the requested fields
                elif k in self.aggregate_fields:
                    grouped_data[key][k].append(instance[k])

        yield from grouped_data.values()
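

# Illustrative usage sketch: grouping a stream by "category" while collecting
# "value" and "id" into lists, as in the docstring example above.
#
#     CollateInstancesByField(
#         by_field="category",
#         aggregate_fields=["value", "id"],
#     )
#
# Any remaining field (e.g. "flag") stays a scalar and must agree across all
# instances sharing the same "category"; otherwise a ValueError is raised.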


class WikipediaFetcher(FieldOperator):
    """Fetches a Wikipedia page for a given title or URL and returns its title and body.

    Args:
        mode (str): whether to return the page "summary" or its full "text" (the default).
    """

    mode: Literal["summary", "text"] = "text"
    _requirements_list = ["Wikipedia-API"]

    def prepare(self):
        super().prepare()
        import wikipediaapi

        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")

    def process_value(self, value: Any) -> Any:
        # The page title is the last path component of the given URL (or the
        # value itself if it contains no "/").
        title = value.split("/")[-1]
        page = self.wikipedia.page(title)

        return {"title": page.title, "body": getattr(page, self.mode)}


class Fillna(FieldOperator):
    """Replaces NaN values in a field with a given value.

    Args:
        value (Any): the value to substitute for NaN.

    Values for which NaN cannot be tested (e.g. strings) are returned unchanged.
    """

    value: Any

    def process_value(self, value: Any) -> Any:
        import numpy as np

        try:
            if np.isnan(value):
                return self.value
        except TypeError:
            # np.isnan is undefined for non-numeric values; leave them as-is.
            return value
        return value
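

# Illustrative usage sketch (the `field` argument follows the usual
# FieldOperator convention; its exact name is an assumption here):
#
#     Fillna(field="score", value=0.0)
#
# NaN scores become 0.0; values for which NaN cannot be tested (e.g. strings)
# pass through unchanged.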