
IBM / unitxt / 13055168089

30 Jan 2025 03:08PM UTC coverage: 79.264% (+0.04%) from 79.225%
push · github · web-flow
Create new IntersectCorrespondingFields operator (#1531)

* filter for entity types intro

* code update

* optimalisation

* improv

* remove filter and add its functionality to intersect

* typo

* Created a new type of intersect operator

Signed-off-by: Yoav Katz <katz@il.ibm.com>

* Updated documentation

Signed-off-by: Yoav Katz <katz@il.ibm.com>

---------

Signed-off-by: Yoav Katz <katz@il.ibm.com>
Co-authored-by: Przemysław Klocek <przemyslaw.klocek@ibm.com>
Co-authored-by: Yoav Katz <katz@il.ibm.com>
Co-authored-by: Yoav Katz <68273864+yoavkatz@users.noreply.github.com>

1449 of 1821 branches covered (79.57%)

Branch coverage included in aggregate %.

9132 of 11528 relevant lines covered (79.22%)

0.79 hits per line

Source File: src/unitxt/operators.py (91.99% covered)
"""This section describes unitxt operators.

Operators: Building Blocks of Unitxt Processing Pipelines
==============================================================

Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
These operators are callable entities that receive a MultiStream as input.
The output is a MultiStream, augmented with the operator's manipulations, which are then applied to each instance in the stream when pulled.

Creating Custom Operators
-------------------------------
To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.

General or Specialized Operators
--------------------------------
Some operators are specialized for specific data or specific operations, such as:

- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
- :class:`splitters<unitxt.splitters>` for fixing data splits.
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining, and mixing streams.
- :class:`struct_data_operators<unitxt.struct_data_operators>` for handling structured data.
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
- :class:`string_operators<unitxt.string_operators>` for handling strings.
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling span labeling.
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.

Other specialized operators are used by unitxt internally:

- :class:`templates<unitxt.templates>` for verbalizing data examples.
- :class:`formats<unitxt.formats>` for preparing data for models.

The rest of this section is dedicated to general operators.

General Operators List:
------------------------
"""

import operator
import uuid
import warnings
import zipfile
from abc import abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from itertools import zip_longest
from random import Random
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .error_utils import UnitxtError
from .generator_utils import ReusableGenerator
from .operator import (
    InstanceOperator,
    MultiStream,
    MultiStreamOperator,
    PagedStreamOperator,
    SequentialOperator,
    SideEffectOperator,
    SingleStreamReducer,
    SourceOperator,
    StreamingOperator,
    StreamInitializerOperator,
    StreamOperator,
)
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string, to_pretty_string
from .type_utils import isoftype
from .utils import (
    LRUCache,
    deep_copy,
    flatten_dict,
    recursive_copy,
    recursive_shallow_copy,
    shallow_copy,
)

settings = get_settings()


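# Illustrative sketch (not part of the original module): per the module
# docstring, a custom operator inherits from a base operator and implements
# `process`. A hypothetical minimal example using FieldOperator (defined
# further below), whose subclasses implement `process_value` instead:
#
#     class AddExclamation(FieldOperator):
#         """Appends '!' to the processed value."""
#
#         def process_value(self, value: Any) -> Any:
#             return f"{value}!"
#
#     # AddExclamation(field="text").process({"text": "hi"}) -> {"text": "hi!"}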
class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from a dict of named iterables.

    Example:
        operator = FromIterables()
        ms = operator.process(iterables)

    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        return MultiStream.from_iterables(iterables)


class IterableSource(SourceOperator):
    """Creates a MultiStream from a dict of named iterables.

    It is a callable.

    Args:
        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.

    Example:
        operator = IterableSource(input_dict)
        ms = operator()

    """

    iterables: Dict[str, Iterable]

    def process(self) -> MultiStream:
        return MultiStream.from_iterables(self.iterables)


class MapInstanceValues(InstanceOperator):
    """A class used to map instance values into other values.

    This class is a type of ``InstanceOperator``,
    it maps values of instances in a stream using predefined mappers.

    Args:
        mappers (Dict[str, Dict[str, Any]]):
            The mappers to use for mapping instance values.
            Keys are the names of the fields to undergo mapping, and values are dictionaries
            that define the mapping from old values to new values.
            Note that mapped values are defined by their string representation, so mapped values
            are converted to strings before being looked up in the mappers.
        strict (bool):
            If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool):
            If True, all fields to be mapped should be lists, and the mapping
            is to be applied to their individual elements.
            If False, mapping is only applied to a field containing a single value.

    Examples:
        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remained intact,
        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being looked
        up in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
        Assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
        each such ``"1"`` with ``"hi"`` and each ``"2"`` with ``"bye"`` in all instances of all streams:
        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
        To ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
        Input instance ``{"a":"3", "b": 2}`` will raise an exception per the above call,
        because ``"3"`` is not a key in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
        replaces the list ``[1,2,3,4]`` with the string ``"All"`` and an empty list with the string ``"None"``.

    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
                    )
                if isinstance(value, list) and self.process_every_value:
                    for i, val in enumerate(value):
                        value[i] = self.get_mapped_value(instance, key, mapper, val)
                else:
                    value = self.get_mapped_value(instance, key, mapper, value)
                dict_set(
                    instance,
                    key,
                    value,
                )

        return instance

    def get_mapped_value(self, instance, key, mapper, val):
        val_as_str = str(val)  # make sure the value is a string
        if val_as_str in mapper:
            return recursive_copy(mapper[val_as_str])
        if self.strict:
            raise KeyError(
                f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
            )
        return val


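# Usage sketch (illustrative, not part of the original module), following the
# docstring examples above:
#
#     op = MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})
#     op.process({"a": 1, "b": 2})  # -> {"a": "hi", "b": 2}
#     # With strict=True (the default), an unmapped value such as {"a": "3"}
#     # raises a KeyError instead of passing through.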
class FlattenInstances(InstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
        sep (str): The separator to use when concatenating nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)


class Set(InstanceOperator):
    """Sets specified fields in each instance, in a given stream or all streams (default), with specified values. If a field exists, its value is updated; if it does not exist, it is added.

    Args:
        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields

        use_deepcopy (bool) : Deep copy the input value to avoid later modifications

    Examples:
        # Set a value of a list consisting of "positive" and "negative" to field "classes" in each and every instance of all streams
        ``Set(fields={"classes": ["positive","negative"]})``

        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
        ``Set(fields={"span/start": 0})``

        # In all instances of stream "train" only, set field "classes" to have the value of a list consisting of "positive" and "negative"
        ``Set(fields={"classes": ["positive","negative"]}, apply_to_stream=["train"])``

        # Set field "classes" to have the value of a given list, preventing modification of the original list from changing the instance.
        ``Set(fields={"classes": alist}, use_deepcopy=True)``  if alist is now modified, the instances still remain intact.
    """

    fields: Dict[str, object]
    use_query: Optional[bool] = None
    use_deepcopy: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, value in self.fields.items():
            if self.use_deepcopy:
                value = deep_copy(value)
            dict_set(instance, key, value)
        return instance


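# Usage sketch (illustrative, not part of the original module): '/' in a key
# builds nested structure, per the docstring above:
#
#     Set(fields={"classes": ["positive", "negative"]}).process({"text": "x"})
#     # -> {"text": "x", "classes": ["positive", "negative"]}
#     Set(fields={"span/start": 0}).process({})
#     # -> {"span": {"start": 0}}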
@deprecation(version="2.0.0", alternative=Set)
class AddFields(Set):
    pass


class RemoveFields(InstanceOperator):
    """Remove specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance


class SelectFields(InstanceOperator):
    """Keep only specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to keep from each instance.
    """

    fields: List[str]

    def prepare(self):
        super().prepare()
        self.fields.extend(["data_classification_policy", "recipe_metadata"])

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        new_instance = {}
        for selected_field in self.fields:
            new_instance[selected_field] = instance[selected_field]
        return new_instance


class DefaultPlaceHolder:
    pass


default_place_holder = DefaultPlaceHolder()


class InstanceFieldOperator(InstanceOperator):
    """A general stream instance operator that processes the values of a field (or multiple ones).

    Args:
        field (Optional[str]):
            The field to process, if only a single one is passed. Defaults to None
        to_field (Optional[str]):
            Field name to save result into, if only one field is processed, if None is passed the
            operation would happen in-place and its result would replace the value of ``field``. Defaults to None
        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
            Mapping from names of fields to process,
            to names of fields to save the results into. Inner List, if used, should be of length 2.
            A field is processed by feeding its value into method ``process_value`` and storing the result in ``to_field`` that
            is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field shows both as a key and as a value in different mappings.
            The operator throws an AssertionError in either of these cases. ``field_to_field``
            defaults to None.
        process_every_value (bool):
            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False

    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field`` ) are equal (or share a common
    prefix if ``field`` and ``to_field`` contain a / ), then the result of the operation is saved within ``field`` .

    """

    field: Optional[str] = None
    to_field: Optional[str] = None
    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
    use_query: Optional[bool] = None
    process_every_value: bool = False
    get_default: Any = None
    not_exist_ok: bool = False
    not_exist_do_nothing: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify_field_definition(self):
        if hasattr(self, "_field_to_field") and self._field_to_field is not None:
            return
        assert (
            (self.field is None) != (self.field_to_field is None)
        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"

        if self.field_to_field is None:
            self._field_to_field = [
                (self.field, self.to_field if self.to_field is not None else self.field)
            ]
        else:
            self._field_to_field = (
                list(self.field_to_field.items())
                if isinstance(self.field_to_field, dict)
                else self.field_to_field
            )
        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
        assert (
            self._field_to_field is not None
        ), f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
        assert (
            len(self._field_to_field) > 0
        ), f"input argument '{self.__class__.__name__}.field_to_field' should convey at least one field to process. Got {self.field_to_field}"
        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
        if self.field_to_field is None:
            return
        # for backward compatibility also allow list of tuples of two strings
        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
            self.field_to_field, List[Tuple[str, str]]
        ):
            for pair in self._field_to_field:
                assert (
                    len(pair) == 2
                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
            # order of field processing is uniquely determined by the input field_to_field when a list
            return
        if isoftype(self.field_to_field, Dict[str, str]):
            if len(self.field_to_field) < 2:
                return
            for ff, tt in self.field_to_field.items():
                for f, t in self.field_to_field.items():
                    if f == ff:
                        continue
                    assert (
                        t != ff
                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
                    assert (
                        tt != t
                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
            return
        raise ValueError(
            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
        )

    @abstractmethod
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        pass

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        self.verify_field_definition()
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
                    instance,
                    from_field,
                    default=default_place_holder,
                    not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
                )
                if old_value is default_place_holder:
                    if self.not_exist_do_nothing:
                        continue
                    old_value = self.get_default
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from instance due to the exception above."
                ) from e
            try:
                if self.process_every_value:
                    new_value = [
                        self.process_instance_value(value, instance)
                        for value in old_value
                    ]
                else:
                    new_value = self.process_instance_value(old_value, instance)
            except Exception as e:
                raise ValueError(
                    f"Failed to process field '{from_field}' from instance due to the exception above."
                ) from e
            dict_set(
                instance,
                to_field,
                new_value,
                not_exist_ok=True,
            )
        return instance


class FieldOperator(InstanceFieldOperator):
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        return self.process_value(value)

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass


class MapValues(FieldOperator):
    mapping: Dict[str, str]

    def process_value(self, value: Any) -> Any:
        return self.mapping[str(value)]


class Rename(FieldOperator):
    """Renames fields.

    Moves the value from one field to another; when a field name contains a /, the move can be from one branch into another.
    The source field is removed; when from_field contains a /, only the relevant part of it is removed.

    Examples:
        Rename(field_to_field={"b": "c"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]

        Rename(field_to_field={"b": "c/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]

        Rename(field_to_field={"b": "b/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]

        Rename(field_to_field={"b/c/e": "b/d"})
        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]

    """

    def process_value(self, value: Any) -> Any:
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        res = super().process(instance=instance, stream_name=stream_name)
        for from_field, to_field in self._field_to_field:
            if (not is_subpath(from_field, to_field)) and (
                not is_subpath(to_field, from_field)
            ):
                dict_delete(res, from_field, remove_empty_ancestors=True)

        return res


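# Usage sketch (illustrative, not part of the original module), following the
# docstring examples above:
#
#     Rename(field_to_field={"b": "c"}).process({"a": 1, "b": 2})
#     # -> {"a": 1, "c": 2}    ("b" is removed after its value moves to "c")
#     Rename(field_to_field={"b": "c/d"}).process({"a": 1, "b": 2})
#     # -> {"a": 1, "c": {"d": 2}}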
@deprecation(version="2.0.0", alternative=Rename)
class RenameFields(Rename):
    pass


class AddConstant(FieldOperator):
    """Adds a constant, being argument 'add', to the processed value.

    Args:
        add: the constant to add.
    """

    add: Any

    def process_value(self, value: Any) -> Any:
        return self.add + value


class ShuffleFieldValues(FieldOperator):
    """Shuffles a list of values found in a field."""

    def process_value(self, value: Any) -> Any:
        res = list(value)
        random_generator = new_random_generator(sub_seed=res)
        random_generator.shuffle(res)
        return res


class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """

    separator: str = ","

    def process_value(self, value: Any) -> Any:
        return self.separator.join(str(x) for x in value)


class Apply(InstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (str): name of function.
        to_field (str): the field to store the result

    any additional arguments are field names whose values will be passed directly to the function specified

    Examples:
    Store in field "b" the uppercase string of the value in field "a":
    ``Apply("a", function=str.upper, to_field="b")``

    Dump the json representation of field "t" and store back in the same field:
    ``Apply("t", function=json.dumps, to_field="t")``

    Set the time in a field 'b':
    ``Apply(function=time.time, to_field="b")``

    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        parts = []

        if hasattr(function, "__module__"):
            parts.append(function.__module__)
        if hasattr(function, "__qualname__"):
            parts.append(function.__qualname__)
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        parts = function_str.split(".", 1)
        if len(parts) == 1:
            return __builtins__[parts[0]]

        module_name, function_name = parts
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = __import__(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance


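# Usage sketch (illustrative, not part of the original module): positional
# arguments name the instance fields whose values are fed to the function:
#
#     Apply("a", function=str.upper, to_field="b").process({"a": "hello"})
#     # -> {"a": "hello", "b": "HELLO"}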
class ListFieldValues(InstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))

        dict_set(instance, self.to_field, values)

        return instance


class ZipFieldValues(InstanceOperator):
    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.

    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
    are zipped, and stored into 'to_field'.

    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.

    """

    fields: List[str]
    to_field: str
    longest: bool = False
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))
        if self.longest:
            zipped = zip_longest(*values)
        else:
            zipped = zip(*values)
        dict_set(instance, self.to_field, list(zipped))
        return instance


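# Usage sketch (illustrative, not part of the original module):
#
#     op = ZipFieldValues(fields=["a", "b"], to_field="z")
#     op.process({"a": [1, 2], "b": ["x", "y", "w"]})
#     # -> {..., "z": [(1, "x"), (2, "y")]}   (shortest input wins)
#     # With longest=True, the result is padded with None:
#     # [(1, "x"), (2, "y"), (None, "w")]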
class InterleaveListsToDialogOperator(InstanceOperator):
    """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".

    The list of tuples is of the format (role, turn_content), where the role label is specified by
    the 'user_role_label' and 'assistant_role_label' fields (defaulting to "user" and "assistant").

    The user turns and assistant turns fields are specified in the arguments.
    The value of each of these fields is assumed to be a list.

    """

    user_turns_field: str
    assistant_turns_field: str
    user_role_label: str = "user"
    assistant_role_label: str = "assistant"
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        user_turns = instance[self.user_turns_field]
        assistant_turns = instance[self.assistant_turns_field]

        assert (
            len(user_turns) == len(assistant_turns)
            or (len(user_turns) - len(assistant_turns) == 1)
        ), "user_turns must have either the same length as assistant_turns or one more turn."

        interleaved_dialog = []
        i, j = 0, 0  # Indices for the user and assistant lists
        # While either list has elements left, continue interleaving
        while i < len(user_turns) or j < len(assistant_turns):
            if i < len(user_turns):
                interleaved_dialog.append((self.user_role_label, user_turns[i]))
                i += 1
            if j < len(assistant_turns):
                interleaved_dialog.append(
                    (self.assistant_role_label, assistant_turns[j])
                )
                j += 1

        instance[self.to_field] = interleaved_dialog
        return instance


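# Usage sketch (illustrative, not part of the original module): user turns may
# outnumber assistant turns by exactly one, per the assertion above:
#
#     op = InterleaveListsToDialogOperator(
#         user_turns_field="user", assistant_turns_field="assistant", to_field="dialog"
#     )
#     op.process({"user": ["Hi", "Thanks!"], "assistant": ["Hello!"]})["dialog"]
#     # -> [("user", "Hi"), ("assistant", "Hello!"), ("user", "Thanks!")]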
class IndexOf(InstanceOperator):
    """For a given instance, finds the offset of the value of field 'index_of' within the value of field 'search_in'."""

    search_in: str
    index_of: str
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in)
        item = dict_get(instance, self.index_of)
        instance[self.to_field] = lst.index(item)
        return instance


class TakeByField(InstanceOperator):
    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

    field: str
    index: str
    to_field: str = None
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def prepare(self):
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field)
        index_value = dict_get(instance, self.index)
        instance[self.to_field] = value[index_value]
        return instance


class Perturb(FieldOperator):
    """Slightly perturbs the contents of ``field``. Can be handy for imitating a prediction from a given target.

    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
    relevant perturbation.

    Args:
        percentage_to_perturb (int):
            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent)
        select_from (List[Any]):
            a list of values to select from, as a perturbation of the field's value. Defaults to [].
    """

    select_from: List[Any] = []
    percentage_to_perturb: int = 1  # 1 percent

    def verify(self):
        assert (
            0 <= self.percentage_to_perturb and self.percentage_to_perturb <= 100
        ), f"'percentage_to_perturb' should be in the range 0..100. Received {self.percentage_to_perturb}"

    def prepare(self):
        super().prepare()
        self.random_generator = new_random_generator(sub_seed="CopyWithPerturbation")

    def process_value(self, value: Any) -> Any:
        perturb = self.random_generator.randint(1, 100) <= self.percentage_to_perturb
        if not perturb:
            return value

        if value in self.select_from:
            # in 80% of cases, return another class; otherwise, perturb the value itself as follows
            if self.random_generator.random() < 0.8:
                return self.random_generator.choice(self.select_from)

        if isinstance(value, float):
            return value * (0.5 + self.random_generator.random())

        if isinstance(value, int):
            perturb = 1 if self.random_generator.random() < 0.5 else -1
            return value + perturb

        if isinstance(value, str):
            if len(value) < 2:
                # give up perturbation
                return value
            # throw one char out
            prefix_len = self.random_generator.randint(1, len(value) - 1)
            return value[:prefix_len] + value[prefix_len + 1 :]

        # and in any other case:
        return value


class Copy(FieldOperator):
    """Copies values from specified fields to specified fields.

    Args (of parent class):
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.

    Examples:
        An input instance {"a": 2, "b": 3}, when processed by
        ``Copy(field_to_field={"a": "b"})``
        would yield {"a": 2, "b": 2}, and when processed by
        ``Copy(field_to_field={"a": "c"})`` would yield
        {"a": 2, "b": 3, "c": 2}

        with field names containing / , we can also copy inside the field:
        ``Copy(field="a/0", to_field="a")``
        would process instance {"a": [1, 3]} into {"a": 1}

    """

    def process_value(self, value: Any) -> Any:
        return value


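# Usage sketch (illustrative, not part of the original module): unlike Rename,
# the source field is kept:
#
#     Copy(field="a/0", to_field="b").process({"a": [1, 3]})
#     # -> {"a": [1, 3], "b": 1}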
class RecursiveCopy(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return recursive_copy(value)


@deprecation(version="2.0.0", alternative=Copy)
class CopyFields(Copy):
    pass


class GetItemByIndex(FieldOperator):
    """Gets an item from 'items_list' by the index found in the processed field."""

    items_list: List[Any]

    def process_value(self, value: Any) -> Any:
        return self.items_list[value]


class AddID(InstanceOperator):
    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""

    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance


class Cast(FieldOperator):
    """Casts the value of the processed field to a specified type.

    Args:
        to (str): the name of the type to cast to: "int", "float", "str", or "bool".
        failure_default (Any): a default value to use in case of casting failure. If not set, a casting failure raises an error.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def process_value(self, value):
        try:
            return self.types[self.to](value)
        except ValueError as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default


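# Usage sketch (illustrative, not part of the original module; assumes, as
# elsewhere in unitxt, that prepare() has run by instantiation time):
#
#     Cast(field="n", to="int").process({"n": "7"})  # -> {"n": 7}
#     Cast(field="n", to="int", failure_default=0).process({"n": "x"})
#     # -> {"n": 0}   (without failure_default, this raises a ValueError)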
class CastFields(InstanceOperator):
    """Casts specified fields to specified types.

    Args:
        fields (Dict[str, str]):
            A dictionary mapping field names to the names of the types to cast the fields to,
            e.g. "int", "str", "float", "bool". Basic names of types.
        failure_defaults (Dict[str, object]):
            A dictionary mapping field names to default values for cases of casting failure.
        process_every_value (bool):
            If true, all fields involved must contain lists, and each value in the list is then cast. Defaults to False.

    Example:
        .. code-block:: python

                CastFields(
                    fields={"a/d": "float", "b": "int"},
                    failure_defaults={"a/d": 0.0, "b": 0},
                    process_every_value=True,
                )

    would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
    into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.

    """

    fields: Dict[str, str] = field(default_factory=dict)
    failure_defaults: Dict[str, object] = field(default_factory=dict)
    use_nested_query: bool = None  # deprecated field
    process_every_value: bool = False

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def verify(self):
        super().verify()
        if self.use_nested_query is not None:
            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        return [self._cast_single(value, type, field) for value in values]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name)
            if self.process_every_value:
                assert isinstance(
                    value, list
                ), f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)

            dict_set(instance, field_name, casted_value)
        return instance


class DivideAllFieldsBy(InstanceOperator):
    """Recursively reaches down to all fields that are float, and divides each by 'divisor'.

    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and whose
    leaves are either floats, which are then divided, or of another basic type, in which case a ValueError is raised
    if input flag 'strict' is True, or the leaf is left alone if 'strict' is False.

    Args:
        divisor (float): the value to divide by
        strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.

    Example:
        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
        operator = DivideAllFieldsBy(divisor=2.0)
        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
        If the operator were defined with strict=True, through:
        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
        the processing of the above instance would raise a ValueError, for the integer at "c".
    """

    divisor: float = 1.0
    strict: bool = False

    def _recursive_divide(self, instance, divisor):
        if isinstance(instance, dict):
            for key, value in instance.items():
                instance[key] = self._recursive_divide(value, divisor)
        elif isinstance(instance, list):
            for i, value in enumerate(instance):
                instance[i] = self._recursive_divide(value, divisor)
        elif isinstance(instance, float):
            instance /= divisor
        elif self.strict:
            raise ValueError(f"Cannot divide instance of type {type(instance)}")
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return self._recursive_divide(instance, self.divisor)


class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Args:
        cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
    """

    _artifacts_cache = LRUCache(max_size=1000)

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        if str(artifact_identifier) not in cls._artifacts_cache:
            artifact, catalog = fetch_artifact(artifact_identifier)
            cls._artifacts_cache[str(artifact_identifier)] = artifact
        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])


class ApplyOperatorsField(InstanceOperator):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        operators_field (str): name of the field that contains a single name, or a list of names, of the operators to be applied,
            one after the other, for the processing of the instance. Each operator is equipped with a 'process_instance()'
            method.

        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"prediction": 111, "references": [222, 333], "c": ["processors.to_string", "processors.first_character"]}
        is processed by operator (please look these operators up in the catalog; they are tuned to process fields "prediction" and
        "references"):
        operator = ApplyOperatorsField(operators_field="c"),
        the resulting instance is: {"prediction": "1", "references": ["2", "3"], "c": ["processors.to_string", "processors.first_character"]}

    """

    operators_field: str
    default_operators: List[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]
        # otherwise, operator_names is already a list

        # we now have a list of names of operators, each equipped with a process_instance method.
        operator = SequentialOperator(steps=operator_names)
        return operator.process_instance(instance, stream_name=stream_name)


class FilterByCondition(StreamOperator):
    """Filters a stream, yielding only instances in which the values in the required fields follow the required condition operator.

    Raises an error if a required field name is missing from the input instance.

    Args:
       values (Dict[str, Any]): Field names and respective values that instances must match, according to the condition, to be included in the output.

       condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")

       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
       | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
       | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where "a" is a list whose third element is ``<= 4``

    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be a list but is not: '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance."
                ) from ve
            if self.condition == "in":
                if instance_key not in value:
                    return False
            elif self.condition == "not in":
                if instance_key in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, value):
                    return False
        return True


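# Usage sketch (illustrative, not part of the original module): as a
# StreamOperator, it is applied to streams rather than single instances;
# operators are callable on a MultiStream, per the module docstring. Assuming
# a MultiStream `ms`:
#
#     op = FilterByCondition(values={"a": 4}, condition="gt")
#     filtered = op(ms)  # keeps only instances where instance["a"] > 4
#     # If every instance of a stream is filtered out, a RuntimeError is
#     # raised unless error_on_filtered_all=False.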
class FilterByConditionBasedOnFields(FilterByCondition):
    """Filters a stream based on a condition between the values of two fields.

    Raises an error if either of the required field names is missing from the input instance.

    Args:
       values (Dict[str, str]): The field names that the filter operation is based on.
       condition: the name of the desired condition operator between the specified fields' values. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")
       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "gt") will yield only instances where field "a" contains a value greater than the value in field "b".
       FilterByConditionBasedOnFields(values = {"a": "b"}, condition = "le") will yield only instances where "a"<="b"
    """

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance"
                ) from ve
            try:
                instance_value = dict_get(instance, value)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{value}') in FilterByCondition is not found in instance"
                ) from ve
            if self.condition == "in":
                if instance_key not in instance_value:
                    return False
            elif self.condition == "not in":
                if instance_key in instance_value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, instance_value):
                    return False
        return True


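# A minimal usage sketch for FilterByConditionBasedOnFields (illustrative only,
# kept as a comment so it does not run at import time; the field names "a" and
# "b" are hypothetical):
#
#     op = FilterByConditionBasedOnFields(values={"a": "b"}, condition="gt")
#     op._is_required({"a": 5, "b": 3})  # -> True, since 5 > 3; instance is kept
#     op._is_required({"a": 2, "b": 3})  # -> False; instance is filtered out

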
class ComputeExpressionMixin(Artifact):
    """Computes an expression defined over fields of an instance.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # the imports can not be done here once and for all, because the object does not pickle with imported modules
        self.globals = {
            module_name: __import__(module_name) for module_name in self.imports_list
        }

    def compute_expression(self, instance: dict) -> Any:
        if settings.allow_unverified_code:
            return eval(self.expression, {**self.globals, **instance})

        raise ValueError(
            f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set {settings.allow_unverified_code_key} environment variable."
            "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
        )


class FilterByExpression(StreamOperator, ComputeExpressionMixin):
    """Filters a stream, yielding only instances that fulfill a condition specified as a string to be evaluated by Python's eval().

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        expression (str):
            a condition over fields of the instance, to be processed by Python's eval()
        imports_list (List[str]):
            names of imports needed for the eval of the query (e.g. 're', 'json')
        error_on_filtered_all (bool, optional):
            If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
    """

    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self.compute_expression(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )


class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
    """Compute an expression, specified as a string to be evaluated, over the instance's fields, and store the result in field to_field.

    Raises an error if a field mentioned in the query is missing from the instance.

    Args:
       expression (str): an expression to be evaluated over the fields of the instance
       to_field (str): the field into which the result is stored
       imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')

    Examples:
       When instance {"a": 2, "b": 3} is processed by operator
       ExecuteExpression(expression="a+b", to_field="c")
       the result is {"a": 2, "b": 3, "c": 5}

       When instance {"a": "hello", "b": "world"} is processed by operator
       ExecuteExpression(expression="a+' '+b", to_field="c")
       the result is {"a": "hello", "b": "world", "c": "hello world"}

    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = self.compute_expression(instance)
        return instance


class ExtractMostCommonFieldValues(MultiStreamOperator):
    """Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them as a list in a new field ('to_field') in all streams.

    More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, trim the list from the bottom, so that the total frequency of
    the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes
    less than 'min_frequency_percent' of the total number of instances in the stream.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to deviate from its default value.

    Examples:

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
    every instance in all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
    in case field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
    value members in these lists, and report the most frequent values.
    If process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
    'to_field' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
    and stores them in field 'classes' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
    extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    def verify(self):
        assert (
            0 <= self.overall_top_frequency_percent <= 100
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            0 <= self.min_frequency_percent <= 100
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to deviate from its default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        counter = Counter()
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to be set to True only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
                counter.update(
                    [(*instance[self.field],)]
                    if isinstance(instance[self.field], list)
                    else [instance[self.field]]
                )  # convert a list to a tuple, to enable the use of Counter, which would not accept
                # a list as a hashable entity to count its occurrences
            else:
                # the content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                counter.update(instance[self.field])
        # here counter counts occurrences of individual values, or tuples.
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = (
                sum(counter.values()) * self.overall_top_frequency_percent / 100.0
            )
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * sum(counter.values()) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = Set(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)


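# An illustrative input/output sketch for ExtractMostCommonFieldValues (kept as
# a comment so it does not run at import time; the stream contents are
# hypothetical):
#
#     Given a "train" stream [{"label": "x"}, {"label": "x"}, {"label": "y"}],
#     ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes")
#     adds "classes": ["x", "y"] (sorted by decreasing frequency) to every
#     instance in all streams.

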
class ExtractFieldValues(ExtractMostCommonFieldValues):
    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0


class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list) - the list to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]


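# A minimal usage sketch for Intersect (illustrative only, kept as a comment so
# it does not run at import time; the field name "labels" is hypothetical):
#
#     op = Intersect(field="labels", allowed_values=["b", "f"])
#     op.process_value(["a", "b", "c", "f"])  # -> ["b", "f"], order preserved

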
class IntersectCorrespondingFields(InstanceOperator):
    """Intersects the value of a field, which must be a list, with a given list, and removes corresponding elements from other list fields.

    For example:

    Assume the instances contain a field 'label' and a field 'position' holding the labels' corresponding positions in the text.

    IntersectCorrespondingFields(field="label",
                                 allowed_values=["b", "f"],
                                 corresponding_fields_to_intersect=["position"])

    would keep only the "b" and "f" values in the 'label' field and
    their respective values in the 'position' field.
    (All other fields are unaffected.)

    Given this input:

    [
        {"label": ["a", "b"], "position": [0,1], "other": "not"},
        {"label": ["a", "c", "d"], "position": [0,1,2], "other": "relevant"},
        {"label": ["a", "b", "f"], "position": [0,1,2], "other": "field"}
    ]

    the output would be:
    [
        {"label": ["b"], "position": [1], "other": "not"},
        {"label": [], "position": [], "other": "relevant"},
        {"label": ["b", "f"], "position": [1,2], "other": "field"},
    ]

    Args:
        field - the field to intersect (must contain list values)
        allowed_values (list) - the list of values to keep
        corresponding_fields_to_intersect (list) - additional list fields from which values
        are removed based on the corresponding indices of values removed from 'field'
    """

    field: str
    allowed_values: List[str]
    corresponding_fields_to_intersect: List[str]

    def verify(self):
        super().verify()

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{type(self.allowed_values)}'"
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.field not in instance:
            raise ValueError(
                f"Field '{self.field}' is not in provided instance.\n"
                + to_pretty_string(instance)
            )

        for corresponding_field in self.corresponding_fields_to_intersect:
            if corresponding_field not in instance:
                raise ValueError(
                    f"Field '{corresponding_field}' is not in provided instance.\n"
                    + to_pretty_string(instance)
                )

        if not isinstance(instance[self.field], list):
            raise ValueError(
                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
                + to_pretty_string(instance, keys=[self.field])
            )

        num_values_in_field = len(instance[self.field])

        if set(self.allowed_values) == set(instance[self.field]):
            return instance

        indices_to_keep = [
            i
            for i, value in enumerate(instance[self.field])
            if value in set(self.allowed_values)
        ]

        result_instance = {}
        for field_name, field_value in instance.items():
            if (
                field_name in self.corresponding_fields_to_intersect
                or field_name == self.field
            ):
                if not isinstance(field_value, list):
                    raise ValueError(
                        f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
                    )
                if len(field_value) != num_values_in_field:
                    raise ValueError(
                        f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
                        + to_pretty_string(instance, keys=[self.field, field_name])
                    )
                result_instance[field_name] = [
                    value
                    for index, value in enumerate(field_value)
                    if index in indices_to_keep
                ]
            else:
                result_instance[field_name] = field_value
        return result_instance


class RemoveValues(FieldOperator):
    """Removes elements in a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list) - the values to be removed.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]


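# A minimal usage sketch for RemoveValues (illustrative only, kept as a comment
# so it does not run at import time; the field name "labels" is hypothetical):
#
#     op = RemoveValues(field="labels", unallowed_values=["none", "n/a"])
#     op.process_value(["cat", "none", "dog"])  # -> ["cat", "dog"]

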
class Unique(SingleStreamReducer):
    """Reduces a stream to its set of unique instances, as determined by the values of the specified fields.

    Args:
        fields (List[str]): The fields whose combined values determine the uniqueness of each instance.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> Stream:
        seen = set()
        for instance in stream:
            values = self.to_tuple(instance, self.fields)
            if values not in seen:
                seen.add(values)
        return list(seen)


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        uniques = Unique(fields=self.fields)(multi_stream)

        result = {}

        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)


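# An illustrative sketch of the stream naming produced by SplitByValue (kept as
# a comment so it does not run at import time; the field name "category" is
# hypothetical):
#
#     SplitByValue(fields=["category"]) applied to a "train" stream whose
#     instances have category "A" or "B" yields streams named "train_A" and
#     "train_B", each holding only the matching instances.

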
class SplitByNestedGroup(MultiStreamOperator):
    """Splits a MultiStream by the value of field 'group'. Intended for metrics, hence the whole stream is assumed to be small enough to sit in memory.

    Args:
        number_of_fusion_generations: int

    The value in field 'group' is of the form "source_n/source_n-1/..." describing the sources in which the instance sat
    when these were fused, potentially over several phases of fusion. The name of the most recent source sits first in this value.
    (See BaseFusion and its extensions.)
    number_of_fusion_generations specifies the length of the prefix by which to split the stream.
    E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream affects the splitting.
    For number_of_fusion_generations = -1, take the whole history written in this field, ignoring the number of generations.
    """

    field_name_of_group: str = "group"
    number_of_fusion_generations: int = 1

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)

        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.field_name_of_group not in instance:
                    raise ValueError(
                        f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
                    )
                signature = (
                    stream_name
                    + "~"  # a separator that does not show within group values
                    + (
                        "/".join(
                            instance[self.field_name_of_group].split("/")[
                                : self.number_of_fusion_generations
                            ]
                        )
                        if self.number_of_fusion_generations >= 0
                        # for values with a smaller number of generations - take up to their last generation
                        else instance[self.field_name_of_group]
                        # for each instance - take all its generations
                    )
                )
                result[signature].append(instance)

        return MultiStream.from_iterables(result)


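# An illustrative sketch of the signature computed by SplitByNestedGroup (kept
# as a comment so it does not run at import time; the group values are
# hypothetical):
#
#     With number_of_fusion_generations = 1, an instance of stream "test" whose
#     "group" field is "cards.sst2/recipe_a" is routed to the stream whose
#     signature is "test~cards.sst2"; with -1, to "test~cards.sst2/recipe_a".

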
class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):
            operators = [operators]

        if self.reversed:
            operators = list(reversed(operators))

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({stream_name: stream}))[stream_name]

        yield from stream


def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
    for instance, score in zip(stream, scores):
        instance["score"] = recursive_copy(score)
        yield instance


class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricsList

        # to be populated only when two or more metrics are listed
        accumulated_scores = []

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        metrics_list = []
        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            if isinstance(metric, MetricsList):
                metrics_list.extend(list(metric.items))
            elif isinstance(metric, Metric):
                metrics_list.append(metric)
            else:
                raise ValueError(
                    f"Operator {metric_name} must be a Metric or MetricsList"
                )

        for metric in metrics_list:
            if not self.calc_confidence_intervals:
                metric.disable_confidence_interval_calculation()
        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metrics_list = list(reversed(metrics_list))

        for i, metric in enumerate(metrics_list):
            if i == 0:  # first metric
                multi_stream = MultiStream({"tmp": stream})
            else:  # metrics with previous scores
                reusable_generator = ReusableGenerator(
                    generator=update_scores_of_stream_instances,
                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
                )
                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})

            multi_stream = metric(multi_stream)

            if i < len(metrics_list) - 1:  # not the last metric
                accumulated_scores = []
                for inst in multi_stream["tmp"]:
                    accumulated_scores.append(recursive_copy(inst["score"]))

        yield from multi_stream["tmp"]


class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        streams_to_merge (List[str]): The names of the streams to merge. If None, all streams are merged.
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream) -> Generator:
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: DynamicStream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )


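# A minimal usage sketch for MergeStreams (illustrative only, kept as a comment
# so it does not run at import time):
#
#     MergeStreams(new_stream_name="all") turns a MultiStream with streams
#     "train" and "test" into a single "all" stream, where each instance gains
#     an "origin" field holding "train" or "test".

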
class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page


class FeatureGroupedShuffle(Shuffle):
    """Shuffles an input dataset by instance 'blocks', not on the individual instance level.

    An example is a dataset consisting of questions and their paraphrases, where each question falls into a topic.
    All paraphrases have the same ID value as the original.
    In this case, we may want to shuffle on grouping_features = ['question ID'],
    to keep the paraphrases and the original question together.
    We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
    In this case, grouping_features = ['question ID', 'topic'].

    Args:
        grouping_features (list of strings): list of feature names to use to define the groups.
            a group is defined by each unique observed combination of data values for features in grouping_features
        shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
            Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
            separately by page (determined by page_size).  If a block of instances in the same group are split
            into separate pages (either by a page break falling in the group, or the dataset was not sorted by
            grouping_features), these instances will be shuffled separately and thus the grouping may be
            broken up by pages.  If the user wants to ensure the shuffle does the grouping and shuffling
            across all pages, set the page_size to be larger than the dataset size.
            See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
    """

    grouping_features: List[str] = None
    shuffle_within_group: bool = False

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        if self.grouping_features is None:
            yield from super().process(page, stream_name)
        else:
            yield from self.shuffle_by_grouping_features(page)

    def shuffle_by_grouping_features(self, page):
        import itertools
        from collections import defaultdict

        groups_to_instances = defaultdict(list)
        for item in page:
            groups_to_instances[
                tuple(item[ff] for ff in self.grouping_features)
            ].append(item)
        # now extract the groups (i.e., lists of dicts with order preserved)
        page_blocks = list(groups_to_instances.values())
        # and now shuffle the blocks
        self.random_generator.shuffle(page_blocks)
        if self.shuffle_within_group:
            blocks = []
            # reshuffle the instances within each block, but keep the blocks in order
            for block in page_blocks:
                self.random_generator.shuffle(block)
                blocks.append(block)
            page_blocks = blocks

        # now flatten the list so it consists of individual dicts, but in (randomized) block order
        return list(itertools.chain(*page_blocks))


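# A minimal usage sketch for FeatureGroupedShuffle (illustrative only, kept as
# a comment so it does not run at import time; the field name "question ID" is
# hypothetical):
#
#     op = FeatureGroupedShuffle(grouping_features=["question ID"])
#     # instances sharing the same "question ID" stay adjacent; only the order
#     # of the blocks (and, with shuffle_within_group=True, the order within
#     # each block) is randomized.

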
class EncodeLabels(InstanceOperator):
    """Encode each value encountered in any field in 'fields' into the integers 0, 1, ...

    The encoding is determined by a str->int map that is built on the fly, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example:
        applying ``EncodeLabels(fields = ["a", "b/*"])``
        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
        {"a": "blue", "b": ["green"], "c":"water"}]``   will yield the
        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``

        Note: dict_utils are applied here, and hence, fields that are lists should be included in
        input 'fields' with the suffix ``"/*"``  as in the above example.

    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name)
            values_was_a_list = isinstance(values, list)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            if not values_was_a_list:
                new_values = new_values[0]
            dict_set(
                instance,
                field_name,
                new_values,
                not_exist_ok=False,  # the values to encode were just taken from there
                set_multiple="*" in field_name
                and isinstance(new_values, list)
                and len(new_values) > 0,
            )

        return instance


class StreamRefiner(StreamOperator):
    """Discard from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int)
        apply_to_streams (optional, list(str)):
            names of streams to refine.

    Examples:
        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
        ``StreamRefiner(max_instances=4)``
        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream


class Deduplicate(StreamOperator):
    """Deduplicate the stream based on the given fields.

    Args:
        by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness.

    Examples:
        >>> dedup = Deduplicate(by=["field1", "field2"])
    """

    by: List[str]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        seen = set()

        for instance in stream:
            # Compute a lightweight hash for the signature
            signature = hash(str(tuple(dict_get(instance, field) for field in self.by)))

            if signature not in seen:
                seen.add(signature)
                yield instance


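# An illustrative input/output sketch for Deduplicate (kept as a comment so it
# does not run at import time):
#
#     Deduplicate(by=["a"]) applied to
#     [{"a": 1, "b": 1}, {"a": 1, "b": 2}, {"a": 2, "b": 3}]
#     keeps the first instance per signature:
#     [{"a": 1, "b": 1}, {"a": 2, "b": 3}]

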
class Balance(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
    By discarding instances from the input stream, the balancer maintains an equal number of instances for all signatures.
    When input 'max_instances' is also specified, the balancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is as few as possible.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            overall max.

    Usage:
        ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
        ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
        ``DeterministicBalancer(fields=["a"])``
        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = Counter()

        for instance in stream:
            sign = self.signature(instance)
            if counter[sign] < max_total_instances_per_sign:
                counter[sign] += 1
                yield instance


class DeterministicBalancer(Balance):
    pass


class MinimumOneExamplePerLabelRefiner(StreamRefiner):
    """A class used to return a specified number of instances, ensuring at least one example per label.

    For each instance, a signature value is constructed from the values of the instance in the specified input ``fields``.
    ``MinimumOneExamplePerLabelRefiner`` takes the first instance that appears from each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
    from each class first and then the rest. If max_instances is not set, the original stream will be used.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            Number of elements to select. Note that max_instances of StreamRefiners
            that are passed to the recipe (e.g. ``train_refiner``, ``test_refiner``) are overridden
            by the recipe parameters (``max_train_instances``, ``max_test_instances``)

    Usage:
        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
        | ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
        the resulting stream will be:
        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is None:
            yield from stream
            return

        counter = Counter()
        for instance in stream:
            counter[self.signature(instance)] += 1
        all_keys = counter.keys()
        if len(counter) == 0:
            return

        if self.max_instances is not None and len(all_keys) > self.max_instances:
            raise Exception(
                f"Can not generate a stream with at least one example per label, because the max instances requested {self.max_instances} is smaller than the number of different labels {len(all_keys)}"
            )

        counter = Counter()
        used_indices = set()
        selected_elements = []
        # select at least one per class
        for idx, instance in enumerate(stream):
            sign = self.signature(instance)
            if counter[sign] == 0:
                counter[sign] += 1
                used_indices.add(idx)
                selected_elements.append(
                    instance
                )  # collect all elements first to allow shuffling of both groups

        # select more to reach self.max_instances examples
        for idx, instance in enumerate(stream):
            if idx not in used_indices:
                if self.max_instances is None or len(used_indices) < self.max_instances:
                    used_indices.add(idx)
                    selected_elements.append(
                        instance
                    )  # collect all elements first to allow shuffling of both groups

        # shuffle elements to avoid having one element from each class appear first
        random_generator = new_random_generator(sub_seed=selected_elements)
        random_generator.shuffle(selected_elements)
        yield from selected_elements


class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]):
            distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the given total length.
            (If none exceeds -- into one index beyond, namely, the length of segments_boundaries)
        fields (Optional, List[str]):
            the total length of the values of these fields goes through the quantization described above


    Example:
        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
        input instances will be counted and balanced against two categories:
        empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
        return i + 1


class DownloadError(Exception):
    def __init__(
        self,
        message,
    ):
        super().__init__(message)


class UnexpectedHttpCodeError(Exception):
    def __init__(self, http_code):
        super().__init__(f"unexpected http code {http_code}")


class DownloadOperator(SideEffectOperator):
    """Operator for downloading a file from a given URL to a specified local path.

    Args:
        source (str):
            URL of the file to be downloaded.
        target (str):
            Local path where the downloaded file should be saved.
    """

    source: str
    target: str

    def process(self):
        try:
            response = requests.get(self.source, allow_redirects=True)
        except Exception as e:
            raise DownloadError(f"Unable to download {self.source}") from e
        if response.status_code != 200:
            raise UnexpectedHttpCodeError(response.status_code)
        with open(self.target, "wb") as f:
            f.write(response.content)


class ExtractZipFile(SideEffectOperator):
    """Operator for extracting files from a zip archive.

    Args:
        zip_file (str):
            Path of the zip file to be extracted.
        target_dir (str):
            Directory where the contents of the zip file will be extracted.
    """

    zip_file: str
    target_dir: str

    def process(self):
        with zipfile.ZipFile(self.zip_file) as zf:
            zf.extractall(self.target_dir)


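# A minimal usage sketch chaining DownloadOperator and ExtractZipFile
# (illustrative only, kept as a comment so it does not run at import time; the
# URL and paths are hypothetical):
#
#     DownloadOperator(source="https://example.com/data.zip", target="/tmp/data.zip").process()
#     ExtractZipFile(zip_file="/tmp/data.zip", target_dir="/tmp/data").process()

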
class DuplicateInstances(StreamOperator):
    """Operator which duplicates each instance in stream a given number of times.

    Args:
        num_duplications (int):
            How many times each instance should be duplicated (1 means no duplication).
        duplication_index_field (Optional[str]):
            If given, then an additional field with the specified name is added to each duplicated instance,
            containing the index of the duplication. Defaults to None, so no field is added.
    """

    num_duplications: int
    duplication_index_field: Optional[str] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            for idx in range(self.num_duplications):
                duplicate = recursive_shallow_copy(instance)
                if self.duplication_index_field:
                    duplicate.update({self.duplication_index_field: idx})
                yield duplicate

    def verify(self):
        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
            raise ValueError(
                f"num_duplications must be an integer equal to or greater than 1. "
                f"Got: {self.num_duplications}."
            )

        if self.duplication_index_field is not None and not isinstance(
            self.duplication_index_field, str
        ):
            raise ValueError(
                f"If given, duplication_index_field must be a string. "
                f"Got: {self.duplication_index_field}"
            )


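# An illustrative input/output sketch for DuplicateInstances (kept as a comment
# so it does not run at import time):
#
#     DuplicateInstances(num_duplications=2, duplication_index_field="dup")
#     applied to [{"a": 1}] yields [{"a": 1, "dup": 0}, {"a": 1, "dup": 1}].

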
class CollateInstances(StreamOperator):
    """Operator which collates values from multiple instances to a single instance.

    Each field becomes a list of the values of the corresponding field across a batch of `batch_size` instances.

    Attributes:
        batch_size (int)

    Example:
        .. code-block:: text

            CollateInstances(batch_size=2)

            Given inputs = [
                {"a": 1, "b": 2},
                {"a": 2, "b": 2},
                {"a": 3, "b": 2},
                {"a": 4, "b": 2},
                {"a": 5, "b": 2}
            ]

            Returns targets = [
                {"a": [1,2], "b": [2,2]},
                {"a": [3,4], "b": [2,2]},
                {"a": [5], "b": [2]},
            ]
    """

    batch_size: int

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        stream = list(stream)
        for i in range(0, len(stream), self.batch_size):
            batch = stream[i : i + self.batch_size]
            new_instance = {}
            for a_field in batch[0]:
                if a_field == "data_classification_policy":
                    flattened_list = [
                        classification
                        for instance in batch
                        for classification in instance[a_field]
                    ]
                    new_instance[a_field] = sorted(set(flattened_list))
                else:
                    new_instance[a_field] = [instance[a_field] for instance in batch]
            yield new_instance

    def verify(self):
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(
                f"batch_size must be an integer equal to or greater than 1. "
                f"Got: {self.batch_size}."
            )


class CollateInstancesByField(StreamOperator):
    """Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.

    Args:
        by_field (str): the name of the field to group data by.
        aggregate_fields (list(str)): the field names to aggregate into lists.

    Returns:
        A stream of instances grouped and aggregated by the specified field.

    Raises:
        UnitxtError: If non-aggregate fields have inconsistent values.

    Example:
        Collate the instances based on field "category" and aggregate fields "value" and "id".

        CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])

        given input:
        [
            {"id": 1, "category": "A", "value": 10, "flag": True},
            {"id": 2, "category": "B", "value": 20, "flag": False},
            {"id": 3, "category": "A", "value": 30, "flag": True},
            {"id": 4, "category": "B", "value": 40, "flag": False}
        ]

        the output is:
        [
            {"category": "A", "id": [1, 3], "value": [10, 30], "flag": True},
            {"category": "B", "id": [2, 4], "value": [20, 40], "flag": False}
        ]

        Note that the "flag" field is not aggregated, and must be the same
        in all instances in the same category, or an error is raised.
    """

    by_field: str = NonPositionalField(required=True)
    aggregate_fields: List[str] = NonPositionalField(required=True)

    def prepare(self):
        super().prepare()

    def verify(self):
        super().verify()
        if not isinstance(self.by_field, str):
            raise UnitxtError(
                f"The 'by_field' value is not a string but '{type(self.by_field)}'"
            )

        if not isinstance(self.aggregate_fields, list):
            raise UnitxtError(
                f"The 'aggregate_fields' value is not a list but '{type(self.aggregate_fields)}'"
            )

    def process(self, stream: Stream, stream_name: Optional[str] = None):
        grouped_data = {}

        for instance in stream:
            if self.by_field not in instance:
                raise UnitxtError(
                    f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
                )
            for k in self.aggregate_fields:
                if k not in instance:
                    raise UnitxtError(
                        f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
                    )
            key = instance[self.by_field]

            if key not in grouped_data:
                grouped_data[key] = {
                    k: v for k, v in instance.items() if k not in self.aggregate_fields
                }
                # Add empty lists for fields to aggregate
                for agg_field in self.aggregate_fields:
                    if agg_field in instance:
                        grouped_data[key][agg_field] = []

            for k, v in instance.items():
                # Merge classification policy lists across instances with the same key
                if k == "data_classification_policy" and instance[k]:
                    grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
                # Check consistency for all non-aggregate fields
                elif k != self.by_field and k not in self.aggregate_fields:
                    if k in grouped_data[key] and grouped_data[key][k] != v:
                        raise ValueError(
                            f"Inconsistent value for field '{k}' in group '{key}': "
                            f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
                        )
                # Aggregate fields
                elif k in self.aggregate_fields:
                    grouped_data[key][k].append(instance[k])

        yield from grouped_data.values()


class WikipediaFetcher(FieldOperator):
    mode: Literal["summary", "text"] = "text"
    _requirements_list = ["Wikipedia-API"]

    def prepare(self):
        super().prepare()
        import wikipediaapi

        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")

    def process_value(self, value: Any) -> Any:
        title = value.split("/")[-1]
        page = self.wikipedia.page(title)

        return {"title": page.title, "body": getattr(page, self.mode)}
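

# A minimal usage sketch for WikipediaFetcher (illustrative only, kept as a
# comment so it does not run at import time and because it requires the
# Wikipedia-API package and network access; the URL is hypothetical):
#
#     op = WikipediaFetcher(field="url", to_field="page", mode="summary")
#     op.process_value("https://en.wikipedia.org/wiki/Python_(programming_language)")
#     # -> {"title": "Python (programming language)", "body": "..."}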