IBM / unitxt / build 13122603213

03 Feb 2025 08:51PM UTC. Coverage: 79.304% (remained the same).

Pull Request #1572: remove inference time and evaluation time from performance report
(merge 561c8cdc2 into b457d1f4e)

1451 of 1823 branches covered (79.59%). Branch coverage included in aggregate %.
9163 of 11561 relevant lines covered (79.26%). 0.79 hits per line.

Source file: src/unitxt/operators.py (91.99% covered)

"""This section describes unitxt operators.

Operators: Building Blocks of Unitxt Processing Pipelines
==========================================================

Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
These operators are callable entities that receive a MultiStream as input.
The output is a MultiStream, augmented with the operator's manipulations, which are then systematically applied to each instance in the stream when pulled.

Creating Custom Operators
-------------------------

To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.

General or Specialized Operators
--------------------------------

Some operators are specialized in specific data or specific operations, such as:

- :class:`loaders<unitxt.loaders>` for accessing data from various sources.
- :class:`splitters<unitxt.splitters>` for fixing data splits.
- :class:`stream_operators<unitxt.stream_operators>` for changing, joining, and mixing streams.
- :class:`struct_data_operators<unitxt.struct_data_operators>` for handling structured data.
- :class:`collections_operators<unitxt.collections_operators>` for handling collections such as lists and dictionaries.
- :class:`dialog_operators<unitxt.dialog_operators>` for handling dialogs.
- :class:`string_operators<unitxt.string_operators>` for handling strings.
- :class:`span_labeling_operators<unitxt.span_labeling_operators>` for handling span labeling.
- :class:`fusion<unitxt.fusion>` for fusing and mixing datasets.

Other specialized operators are used by unitxt internally:

- :class:`templates<unitxt.templates>` for verbalizing data examples.
- :class:`formats<unitxt.formats>` for preparing data for models.

The rest of this section is dedicated to general operators.

General Operators List:
------------------------
"""

import operator
import uuid
import warnings
import zipfile
from abc import abstractmethod
from collections import Counter, defaultdict
from dataclasses import field
from itertools import zip_longest
from random import Random
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

import requests

from .artifact import Artifact, fetch_artifact
from .dataclass import NonPositionalField, OptionalField
from .deprecation_utils import deprecation
from .dict_utils import dict_delete, dict_get, dict_set, is_subpath
from .error_utils import UnitxtError
from .generator_utils import ReusableGenerator
from .operator import (
    InstanceOperator,
    MultiStream,
    MultiStreamOperator,
    PagedStreamOperator,
    SequentialOperator,
    SideEffectOperator,
    SingleStreamReducer,
    SourceOperator,
    StreamingOperator,
    StreamInitializerOperator,
    StreamOperator,
)
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string, to_pretty_string
from .type_utils import isoftype
from .utils import (
    LRUCache,
    deep_copy,
    flatten_dict,
    recursive_copy,
    recursive_shallow_copy,
    shallow_copy,
)

settings = get_settings()

class FromIterables(StreamInitializerOperator):
    """Creates a MultiStream from a dict of named iterables.

    Example:
        operator = FromIterables()
        ms = operator.process(iterables)

    """

    def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
        return MultiStream.from_iterables(iterables)


class IterableSource(SourceOperator):
    """Creates a MultiStream from a dict of named iterables.

    It is a callable.

    Args:
        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.

    Example:
        operator = IterableSource(input_dict)
        ms = operator()

    """

    iterables: Dict[str, Iterable]

    def process(self) -> MultiStream:
        return MultiStream.from_iterables(self.iterables)

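# Illustrative usage (a sketch, not part of the module; the stream name and
# instances are made up). A SourceOperator like IterableSource is callable
# and yields a MultiStream:
#
#     source = IterableSource(iterables={"test": [{"x": 1}, {"x": 2}]})
#     multi_stream = source()
#     list(multi_stream["test"])  # -> roughly [{"x": 1}, {"x": 2}]
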
class MapInstanceValues(InstanceOperator):
    """A class used to map instance values into other values.

    This class is a type of ``InstanceOperator``:
    it maps values of instances in a stream using predefined mappers.

    Args:
        mappers (Dict[str, Dict[str, Any]]):
            The mappers to use for mapping instance values.
            Keys are the names of the fields to undergo mapping, and values are dictionaries
            that define the mapping from old values to new values.
            Note that mapped values are defined by their string representation, so values
            are converted to strings before being looked up in the mappers.
        strict (bool):
            If True, the mapping is applied strictly. That means if a value
            does not exist in the mapper, it will raise a KeyError. If False, values
            that are not present in the mapper are kept as they are.
        process_every_value (bool):
            If True, all fields to be mapped should be lists, and the mapping
            is to be applied to their individual elements.
            If False, mapping is only applied to a field containing a single value.

    Examples:
        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}})``
        replaces ``"1"`` with ``"hi"`` and ``"2"`` with ``"bye"`` in field ``"a"`` in all instances of all streams:
        instance ``{"a": 1, "b": 2}`` becomes ``{"a": "hi", "b": 2}``. Note that the value of ``"b"`` remained intact,
        since field-name ``"b"`` does not participate in the mappers, and that ``1`` was cast to ``"1"`` before being
        looked up in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, process_every_value=True)``:
        assuming field ``"a"`` is a list of values, potentially including ``"1"``-s and ``"2"``-s, this replaces
        each such ``"1"`` with ``"hi"`` and each ``"2"`` with ``"bye"`` in all instances of all streams:
        instance ``{"a": ["1", "2"], "b": 2}`` becomes ``{"a": ["hi", "bye"], "b": 2}``.

        ``MapInstanceValues(mappers={"a": {"1": "hi", "2": "bye"}}, strict=True)``:
        to ensure that all values of field ``"a"`` are mapped in every instance, use ``strict=True``.
        Input instance ``{"a": "3", "b": 2}`` will raise an exception per the above call,
        because ``"3"`` is not a key in the mapper of ``"a"``.

        ``MapInstanceValues(mappers={"a": {str([1,2,3,4]): "All", str([]): "None"}}, strict=True)``
        replaces a list ``[1,2,3,4]`` with the string ``"All"`` and an empty list with the string ``"None"``.

    """

    mappers: Dict[str, Dict[str, str]]
    strict: bool = True
    process_every_value: bool = False

    def verify(self):
        # make sure the mappers are valid
        for key, mapper in self.mappers.items():
            assert isinstance(
                mapper, dict
            ), f"Mapper for given field {key} should be a dict, got {type(mapper)}"
            for k in mapper.keys():
                assert isinstance(
                    k, str
                ), f'Key "{k}" in mapper for field "{key}" should be a string, got {type(k)}'

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, mapper in self.mappers.items():
            value = dict_get(instance, key)
            if value is not None:
                if (self.process_every_value is True) and (not isinstance(value, list)):
                    raise ValueError(
                        f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{key}' is '{value}'"
                    )
                if isinstance(value, list) and self.process_every_value:
                    for i, val in enumerate(value):
                        value[i] = self.get_mapped_value(instance, key, mapper, val)
                else:
                    value = self.get_mapped_value(instance, key, mapper, value)
                dict_set(
                    instance,
                    key,
                    value,
                )

        return instance

    def get_mapped_value(self, instance, key, mapper, val):
        val_as_str = str(val)  # make sure the value is a string
        if val_as_str in mapper:
            return recursive_copy(mapper[val_as_str])
        if self.strict:
            raise KeyError(
                f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
            )
        return val

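# Illustrative usage (a sketch; the field name and values are made up). The
# lookup key is the string representation of the field's value:
#
#     op = MapInstanceValues(mappers={"label": {"0": "negative", "1": "positive"}})
#     op.process({"label": 0})  # -> {"label": "negative"}
#
# With strict=True (the default), a value absent from the mapper raises a
# KeyError; with strict=False it is passed through unchanged.
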
class FlattenInstances(InstanceOperator):
    """Flattens each instance in a stream, making nested dictionary entries into top-level entries.

    Args:
        parent_key (str): A prefix to use for the flattened keys. Defaults to an empty string.
        sep (str): The separator to use when concatenating nested keys. Defaults to "_".
    """

    parent_key: str = ""
    sep: str = "_"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return flatten_dict(instance, parent_key=self.parent_key, sep=self.sep)

class Set(InstanceOperator):
    """Sets specified fields in each instance, in a given stream or all streams (default), to specified values. If a field exists, its value is updated; if it does not exist, it is added.

    Args:
        fields (Dict[str, object]): The fields to add to each instance. Use '/' to access inner fields.

        use_deepcopy (bool): Deep copy the input value, to protect it from later modifications.

    Examples:
        # Set field "classes" to a list consisting of "positive" and "negative" in each and every instance of all streams
        ``Set(fields={"classes": ["positive", "negative"]})``

        # In each and every instance of all streams, field "span" is to become a dictionary containing a field "start", in which the value 0 is to be set
        ``Set(fields={"span/start": 0})``

        # In all instances of stream "train" only, set field "classes" to a list consisting of "positive" and "negative"
        ``Set(fields={"classes": ["positive", "negative"]}, apply_to_stream=["train"])``

        # Set field "classes" to a given list, preventing modification of the original list from changing the instance
        ``Set(fields={"classes": alist}, use_deepcopy=True)``  if alist is later modified, the instances remain intact.
    """

    fields: Dict[str, object]
    use_query: Optional[bool] = None
    use_deepcopy: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for key, value in self.fields.items():
            if self.use_deepcopy:
                value = deep_copy(value)
            dict_set(instance, key, value)
        return instance

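# Illustrative usage (a sketch; the field names are made up). A '/' in a key
# is interpreted by dict_set as nesting:
#
#     op = Set(fields={"span/start": 0})
#     op.process({"text": "abc"})  # -> {"text": "abc", "span": {"start": 0}}
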
@deprecation(version="2.0.0", alternative=Set)
class AddFields(Set):
    pass

class RemoveFields(InstanceOperator):
    """Remove specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to remove from each instance.
    """

    fields: List[str]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            del instance[field_name]
        return instance


class SelectFields(InstanceOperator):
    """Keep only specified fields from each instance in a stream.

    Args:
        fields (List[str]): The fields to keep from each instance.
    """

    fields: List[str]

    def prepare(self):
        super().prepare()
        self.fields.extend(["data_classification_policy", "recipe_metadata"])

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        new_instance = {}
        for selected_field in self.fields:
            new_instance[selected_field] = instance[selected_field]
        return new_instance


class DefaultPlaceHolder:
    pass


default_place_holder = DefaultPlaceHolder()

class InstanceFieldOperator(InstanceOperator):
    """A general stream instance operator that processes the values of a field (or multiple ones).

    Args:
        field (Optional[str]):
            The field to process, if only a single one is passed. Defaults to None.
        to_field (Optional[str]):
            Field name to save the result into, if only one field is processed. If None is passed, the
            operation happens in-place and its result replaces the value of ``field``. Defaults to None.
        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]):
            Mapping from names of fields to process,
            to names of fields to save the results into. Inner List, if used, should be of length 2.
            A field is processed by feeding its value into method ``process_value`` and storing the result in the ``to_field`` that
            is mapped to the field. When the type of argument ``field_to_field`` is List, the order by which the fields are processed is their order
            in the (outer) List. But when the type of argument ``field_to_field`` is Dict, there is no uniquely determined
            order. The end result might depend on that order if either (1) two different fields are mapped to the same
            to_field, or (2) a field shows both as a key and as a value in different mappings.
            The operator throws an AssertionError in either of these cases. ``field_to_field``
            defaults to None.
        process_every_value (bool):
            Processes the values in a list instead of the list as a value, similar to python's ``*var``. Defaults to False.

    Note: if ``field`` and ``to_field`` (or both members of a pair in ``field_to_field``) are equal (or share a common
    prefix, if ``field`` and ``to_field`` contain a /), then the result of the operation is saved within ``field``.

    """

    field: Optional[str] = None
    to_field: Optional[str] = None
    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
    use_query: Optional[bool] = None
    process_every_value: bool = False
    get_default: Any = None
    not_exist_ok: bool = False
    not_exist_do_nothing: bool = False

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def verify_field_definition(self):
        if hasattr(self, "_field_to_field") and self._field_to_field is not None:
            return
        assert (
            (self.field is None) != (self.field_to_field is None)
        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"

        if self.field_to_field is None:
            self._field_to_field = [
                (self.field, self.to_field if self.to_field is not None else self.field)
            ]
        else:
            self._field_to_field = (
                list(self.field_to_field.items())
                if isinstance(self.field_to_field, dict)
                else self.field_to_field
            )
        assert (
            self.field is not None or self.field_to_field is not None
        ), "Must supply a field to work on"
        assert (
            self.to_field is None or self.field_to_field is None
        ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
        assert (
            self.field is None or self.field_to_field is None
        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
        assert (
            self._field_to_field is not None
        ), f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
        assert (
            len(self._field_to_field) > 0
        ), f"input argument '{self.__class__.__name__}.field_to_field' should convey at least one field to process. Got {self.field_to_field}"
        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
        if self.field_to_field is None:
            return
        # for backward compatibility also allow list of tuples of two strings
        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
            self.field_to_field, List[Tuple[str, str]]
        ):
            for pair in self._field_to_field:
                assert (
                    len(pair) == 2
                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
            # order of field processing is uniquely determined by the input field_to_field when a list
            return
        if isoftype(self.field_to_field, Dict[str, str]):
            if len(self.field_to_field) < 2:
                return
            for ff, tt in self.field_to_field.items():
                for f, t in self.field_to_field.items():
                    if f == ff:
                        continue
                    assert (
                        t != ff
                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
                    assert (
                        tt != t
                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
            return
        raise ValueError(
            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
        )

    @abstractmethod
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        pass

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        self.verify_field_definition()
        for from_field, to_field in self._field_to_field:
            try:
                old_value = dict_get(
                    instance,
                    from_field,
                    default=default_place_holder,
                    not_exist_ok=self.not_exist_ok or self.not_exist_do_nothing,
                )
                if old_value is default_place_holder:
                    if self.not_exist_do_nothing:
                        continue
                    old_value = self.get_default
            except Exception as e:
                raise ValueError(
                    f"Failed to get '{from_field}' from instance due to the exception above."
                ) from e
            try:
                if self.process_every_value:
                    new_value = [
                        self.process_instance_value(value, instance)
                        for value in old_value
                    ]
                else:
                    new_value = self.process_instance_value(old_value, instance)
            except Exception as e:
                raise ValueError(
                    f"Failed to process field '{from_field}' from instance due to the exception above."
                ) from e
            dict_set(
                instance,
                to_field,
                new_value,
                not_exist_ok=True,
            )
        return instance

class FieldOperator(InstanceFieldOperator):
    def process_instance_value(self, value: Any, instance: Dict[str, Any]):
        return self.process_value(value)

    @abstractmethod
    def process_value(self, value: Any) -> Any:
        pass

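# A minimal custom operator, in the spirit of the module docstring above:
# subclass FieldOperator and implement process_value. (A sketch; the class
# and field name are hypothetical.)
#
#     class Lowercase(FieldOperator):
#         def process_value(self, value: Any) -> Any:
#             return value.lower()
#
#     Lowercase(field="text").process({"text": "HELLO"})  # -> {"text": "hello"}
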
class MapValues(FieldOperator):
    mapping: Dict[str, str]

    def process_value(self, value: Any) -> Any:
        return self.mapping[str(value)]

class Rename(FieldOperator):
    """Renames fields.

    Moves the value from one field to another, potentially, if the field name contains a /, from one branch into another.
    Removes the from-field, or, in the case of a / in from_field, the relevant part of it.

    Examples:
        Rename(field_to_field={"b": "c"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]

        Rename(field_to_field={"b": "c/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]

        Rename(field_to_field={"b": "b/d"})
        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]

        Rename(field_to_field={"b/c/e": "b/d"})
        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]

    """

    def process_value(self, value: Any) -> Any:
        return value

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        res = super().process(instance=instance, stream_name=stream_name)
        for from_field, to_field in self._field_to_field:
            if (not is_subpath(from_field, to_field)) and (
                not is_subpath(to_field, from_field)
            ):
                dict_delete(res, from_field, remove_empty_ancestors=True)

        return res


@deprecation(version="2.0.0", alternative=Rename)
class RenameFields(Rename):
    pass

class AddConstant(FieldOperator):
    """Adds a constant, given as argument 'add', to the processed value.

    Args:
        add: the constant to add.
    """

    add: Any

    def process_value(self, value: Any) -> Any:
        return self.add + value


class ShuffleFieldValues(FieldOperator):
    """Shuffles a list of values found in a field."""

    def process_value(self, value: Any) -> Any:
        res = list(value)
        random_generator = new_random_generator(sub_seed=res)
        random_generator.shuffle(res)
        return res

class JoinStr(FieldOperator):
    """Joins a list of strings (contents of a field), similar to str.join().

    Args:
        separator (str): text to put between values
    """

    separator: str = ","

    def process_value(self, value: Any) -> Any:
        return self.separator.join(str(x) for x in value)

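# Illustrative usage (a sketch; the field name and values are made up):
#
#     JoinStr(field="words", separator=" ").process({"words": ["join", "me"]})
#     # -> {"words": "join me"}
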
class Apply(InstanceOperator):
    """A class used to apply a python function and store the result in a field.

    Args:
        function (str): name of the function.
        to_field (str): the field to store the result in.

    Any additional arguments are field names whose values will be passed directly to the function specified.

    Examples:
    Store in field "b" the uppercase string of the value in field "a":
    ``Apply("a", function=str.upper, to_field="b")``

    Dump the json representation of field "t" and store back in the same field:
    ``Apply("t", function=json.dumps, to_field="t")``

    Set the time in a field 'b':
    ``Apply(function=time.time, to_field="b")``

    """

    __allow_unexpected_arguments__ = True
    function: Callable = NonPositionalField(required=True)
    to_field: str = NonPositionalField(required=True)

    def function_to_str(self, function: Callable) -> str:
        parts = []

        if hasattr(function, "__module__"):
            parts.append(function.__module__)
        if hasattr(function, "__qualname__"):
            parts.append(function.__qualname__)
        else:
            parts.append(function.__name__)

        return ".".join(parts)

    def str_to_function(self, function_str: str) -> Callable:
        parts = function_str.split(".", 1)
        if len(parts) == 1:
            return __builtins__[parts[0]]

        module_name, function_name = parts
        if module_name in __builtins__:
            obj = __builtins__[module_name]
        elif module_name in globals():
            obj = globals()[module_name]
        else:
            obj = __import__(module_name)
        for part in function_name.split("."):
            obj = getattr(obj, part)
        return obj

    def prepare(self):
        super().prepare()
        if isinstance(self.function, str):
            self.function = self.str_to_function(self.function)
        self._init_dict["function"] = self.function_to_str(self.function)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        argv = [instance[arg] for arg in self._argv]
        kwargs = {key: instance[val] for key, val in self._kwargs}

        result = self.function(*argv, **kwargs)

        instance[self.to_field] = result
        return instance

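# Illustrative usage (a sketch; the field names are made up). Positional
# arguments name the fields whose values are passed to the function:
#
#     Apply("a", "b", function=max, to_field="m").process({"a": 3, "b": 7})
#     # -> {"a": 3, "b": 7, "m": 7}
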
class ListFieldValues(InstanceOperator):
    """Concatenates values of multiple fields into a list, and assigns it to a new field."""

    fields: List[str]
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))

        dict_set(instance, self.to_field, values)

        return instance

class ZipFieldValues(InstanceOperator):
    """Zips values of multiple fields in a given instance, similar to ``list(zip(*fields))``.

    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
    are zipped, and stored into 'to_field'.

    | If 'longest'=False, the length of the zipped result is determined by the shortest input value.
    | If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter inputs with None-s.

    """

    fields: List[str]
    to_field: str
    longest: bool = False
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        values = []
        for field_name in self.fields:
            values.append(dict_get(instance, field_name))
        if self.longest:
            zipped = zip_longest(*values)
        else:
            zipped = zip(*values)
        dict_set(instance, self.to_field, list(zipped))
        return instance

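# Illustrative usage (a sketch; the field names are made up):
#
#     op = ZipFieldValues(fields=["a", "b"], to_field="pairs")
#     op.process({"a": [1, 2], "b": ["x", "y"]})
#     # -> {"a": [1, 2], "b": ["x", "y"], "pairs": [(1, "x"), (2, "y")]}
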
class InterleaveListsToDialogOperator(InstanceOperator):
    """Interleaves two lists, one of user dialog turns and one of assistant dialog turns, into a single list of tuples, alternating between "user" and "assistant".

    The list of tuples is of the format (role, turn_content), where the role label is specified by
    the 'user_role_label' and 'assistant_role_label' fields (defaulting to "user" and "assistant").

    The user-turns and assistant-turns fields are specified in the arguments.
    The value of each of these fields is assumed to be a list.

    """

    user_turns_field: str
    assistant_turns_field: str
    user_role_label: str = "user"
    assistant_role_label: str = "assistant"
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        user_turns = instance[self.user_turns_field]
        assistant_turns = instance[self.assistant_turns_field]

        assert (
            len(user_turns) == len(assistant_turns)
            or (len(user_turns) - len(assistant_turns) == 1)
        ), "user_turns must have either the same length as assistant_turns or one more turn."

        interleaved_dialog = []
        i, j = 0, 0  # Indices for the user and assistant lists
        # While either list has elements left, continue interleaving
        while i < len(user_turns) or j < len(assistant_turns):
            if i < len(user_turns):
                interleaved_dialog.append((self.user_role_label, user_turns[i]))
                i += 1
            if j < len(assistant_turns):
                interleaved_dialog.append(
                    (self.assistant_role_label, assistant_turns[j])
                )
                j += 1

        instance[self.to_field] = interleaved_dialog
        return instance

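# Illustrative usage (a sketch; the field names and turns are made up). With
# one more user turn than assistant turns, the dialog ends with the user:
#
#     op = InterleaveListsToDialogOperator(
#         user_turns_field="user", assistant_turns_field="assistant", to_field="dialog"
#     )
#     op.process({"user": ["hi", "thanks"], "assistant": ["hello"]})["dialog"]
#     # -> [("user", "hi"), ("assistant", "hello"), ("user", "thanks")]
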
class IndexOf(InstanceOperator):
    """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""

    search_in: str
    index_of: str
    to_field: str
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        lst = dict_get(instance, self.search_in)
        item = dict_get(instance, self.index_of)
        instance[self.to_field] = lst.index(item)
        return instance

class TakeByField(InstanceOperator):
    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

    field: str
    index: str
    to_field: str = None
    use_query: Optional[bool] = None

    def verify(self):
        super().verify()
        if self.use_query is not None:
            depr_message = "Field 'use_query' is deprecated. From now on, default behavior is compatible to use_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def prepare(self):
        if self.to_field is None:
            self.to_field = self.field

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        value = dict_get(instance, self.field)
        index_value = dict_get(instance, self.index)
        instance[self.to_field] = value[index_value]
        return instance

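# Illustrative usage of IndexOf and TakeByField together (a sketch; the field
# names are made up):
#
#     instance = {"options": ["yes", "no"], "answer": "no"}
#     IndexOf(search_in="options", index_of="answer", to_field="label").process(instance)
#     # -> adds "label": 1
#     TakeByField(field="options", index="label", to_field="picked").process(instance)
#     # -> adds "picked": "no"
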
class Perturb(FieldOperator):
    """Slightly perturbs the contents of ``field``. Could be handy for imitating a prediction from a given target.

    When the task is classification, argument ``select_from`` can be used to list the other potential classes, as a
    relevant perturbation.

    Args:
        percentage_to_perturb (int):
            the percentage of the instances for which to apply this perturbation. Defaults to 1 (1 percent).
        select_from (List[Any]):
            a list of values to select from, as a perturbation of the field's value. Defaults to [].
    """

    select_from: List[Any] = []
    percentage_to_perturb: int = 1  # 1 percent

    def verify(self):
        assert (
            0 <= self.percentage_to_perturb and self.percentage_to_perturb <= 100
        ), f"'percentage_to_perturb' should be in the range 0..100. Received {self.percentage_to_perturb}"

    def prepare(self):
        super().prepare()
        self.random_generator = new_random_generator(sub_seed="CopyWithPerturbation")

    def process_value(self, value: Any) -> Any:
        perturb = self.random_generator.randint(1, 100) <= self.percentage_to_perturb
        if not perturb:
            return value

        if value in self.select_from:
            # in 80% of these cases, return another decent class; otherwise, perturb the value itself as follows
            if self.random_generator.random() < 0.8:
                return self.random_generator.choice(self.select_from)

        if isinstance(value, float):
            return value * (0.5 + self.random_generator.random())

        if isinstance(value, int):
            perturb = 1 if self.random_generator.random() < 0.5 else -1
            return value + perturb

        if isinstance(value, str):
            if len(value) < 2:
                # give up perturbation
                return value
            # throw one char out
            prefix_len = self.random_generator.randint(1, len(value) - 1)
            return value[:prefix_len] + value[prefix_len + 1 :]

        # and in any other case:
        return value

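# Illustrative usage (a sketch; the values are made up). Roughly 10% of the
# processed values are perturbed here; a value found in select_from is usually
# swapped for another entry of that list, numbers are shifted, and a string
# loses one character:
#
#     op = Perturb(field="label", select_from=["cat", "dog"], percentage_to_perturb=10)
#     op.process({"label": "cat"})  # usually unchanged; occasionally {"label": "dog"}
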
class Copy(FieldOperator):
    """Copies values from specified fields to specified fields.

    Args (of parent class):
        field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.

    Examples:
        An input instance {"a": 2, "b": 3}, when processed by
        ``Copy(field_to_field={"a": "b"})``
        would yield {"a": 2, "b": 2}, and when processed by
        ``Copy(field_to_field={"a": "c"})`` would yield
        {"a": 2, "b": 3, "c": 2}.

        With field names containing / , we can also copy inside the field:
        ``Copy(field="a/0", to_field="a")``
        would process instance {"a": [1, 3]} into {"a": 1}.

    """

    def process_value(self, value: Any) -> Any:
        return value


class RecursiveCopy(FieldOperator):
    def process_value(self, value: Any) -> Any:
        return recursive_copy(value)


@deprecation(version="2.0.0", alternative=Copy)
class CopyFields(Copy):
    pass

class GetItemByIndex(FieldOperator):
    """Gets an item from the 'items_list' list, selected by the index found in the field."""

    items_list: List[Any]

    def process_value(self, value: Any) -> Any:
        return self.items_list[value]


class AddID(InstanceOperator):
    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""

    id_field_name: str = "id"

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.id_field_name] = str(uuid.uuid4()).replace("-", "")
        return instance

class Cast(FieldOperator):
    """Casts the value of a field to a specified type.

    Args:
        to (str): the name of the type to cast to: one of "int", "float", "str", "bool".
        failure_default (Any): the default value to use if casting fails. If not set, a casting failure raises a ValueError.
    """

    to: str
    failure_default: Optional[Any] = "__UNDEFINED__"

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def process_value(self, value):
        try:
            return self.types[self.to](value)
        except ValueError as e:
            if self.failure_default == "__UNDEFINED__":
                raise ValueError(
                    f'Failed to cast value {value} to type "{self.to}", and no default value is provided.'
                ) from e
            return self.failure_default

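# Illustrative usage (a sketch; the field name and values are made up):
#
#     Cast(field="n", to="int").process({"n": "7"})  # -> {"n": 7}
#     Cast(field="n", to="int", failure_default=0).process({"n": "abc"})  # -> {"n": 0}
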
class CastFields(InstanceOperator):
    """Casts specified fields to specified types.

    Args:
        fields (Dict[str, str]):
            A dictionary mapping field names to the names of the types to cast the fields to,
            e.g.: "int", "str", "float", "bool". Basic names of types.
        failure_defaults (Dict[str, object]):
            A dictionary mapping field names to default values for cases of casting failure.
        process_every_value (bool):
            If true, all fields involved must contain lists, and each value in the list is then cast. Defaults to False.

    Example:
        .. code-block:: python

                CastFields(
                    fields={"a/d": "float", "b": "int"},
                    failure_defaults={"a/d": 0.0, "b": 0},
                    process_every_value=True,
                )

    would process the input instance: ``{"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}``
    into ``{"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}``.

    """

    fields: Dict[str, str] = field(default_factory=dict)
    failure_defaults: Dict[str, object] = field(default_factory=dict)
    use_nested_query: bool = None  # deprecated field
    process_every_value: bool = False

    def prepare(self):
        self.types = {"int": int, "float": float, "str": str, "bool": bool}

    def verify(self):
        super().verify()
        if self.use_nested_query is not None:
            depr_message = "Field 'use_nested_query' is deprecated. From now on, default behavior is compatible to use_nested_query=True. Please remove this field from your code."
            warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

    def _cast_single(self, value, type, field):
        try:
            return self.types[type](value)
        except Exception as e:
            if field not in self.failure_defaults:
                raise ValueError(
                    f'Failed to cast field "{field}" with value {value} to type "{type}", and no default value is provided.'
                ) from e
            return self.failure_defaults[field]

    def _cast_multiple(self, values, type, field):
        return [self._cast_single(value, type, field) for value in values]

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name, type in self.fields.items():
            value = dict_get(instance, field_name)
            if self.process_every_value:
                assert isinstance(
                    value, list
                ), f"'process_every_value' == True is allowed only for fields whose values are lists, but value of field '{field_name}' is '{value}'"
                casted_value = self._cast_multiple(value, type, field_name)
            else:
                casted_value = self._cast_single(value, type, field_name)

            dict_set(instance, field_name, casted_value)
        return instance

class DivideAllFieldsBy(InstanceOperator):
    """Recursively reaches down to all fields that are float, and divides each by 'divisor'.

    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and whose
    leaves are either floats, which are then divided, or of another basic type, in which case a ValueError
    is raised if input flag 'strict' is True, or the leaf is left alone if 'strict' is False.

    Args:
        divisor (float): the value to divide by
        strict (bool): whether to raise an error upon visiting a leaf that is not float. Defaults to False.

    Example:
        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
        operator = DivideAllFieldsBy(divisor=2.0)
        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
        If the operator were defined with strict=True, through:
        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
        the processing of the above instance would raise a ValueError, for the integer at "c".
    """

    divisor: float = 1.0
    strict: bool = False

    def _recursive_divide(self, instance, divisor):
        if isinstance(instance, dict):
            for key, value in instance.items():
                instance[key] = self._recursive_divide(value, divisor)
        elif isinstance(instance, list):
            for i, value in enumerate(instance):
                instance[i] = self._recursive_divide(value, divisor)
        elif isinstance(instance, float):
            instance /= divisor
        elif self.strict:
            raise ValueError(f"Cannot divide instance of type {type(instance)}")
        return instance

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        return self._recursive_divide(instance, self.divisor)

class ArtifactFetcherMixin:
    """Provides a way to fetch and cache artifacts in the system.

    Args:
        cache (Dict[str, Artifact]): A cache for storing fetched artifacts.
    """

    _artifacts_cache = LRUCache(max_size=1000)

    @classmethod
    def get_artifact(cls, artifact_identifier: str) -> Artifact:
        if str(artifact_identifier) not in cls._artifacts_cache:
            artifact, catalog = fetch_artifact(artifact_identifier)
            cls._artifacts_cache[str(artifact_identifier)] = artifact
        return shallow_copy(cls._artifacts_cache[str(artifact_identifier)])

class ApplyOperatorsField(InstanceOperator):
    """Applies value operators to each instance in a stream based on specified fields.

    Args:
        operators_field (str): name of the field that contains a single name, or a list of names, of the operators to be applied,
            one after the other, for the processing of the instance. Each operator is equipped with a 'process_instance()'
            method.

        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"prediction": 111, "references": [222, 333], "c": ["processors.to_string", "processors.first_character"]}
        is processed by the operator (please look these operators up in the catalog; they are tuned to process fields "prediction" and
        "references"):
        operator = ApplyOperatorsField(operators_field="c"),
        the resulting instance is: {"prediction": "1", "references": ["2", "3"], "c": ["processors.to_string", "processors.first_character"]}

    """

    operators_field: str
    default_operators: List[str] = None

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        operator_names = instance.get(self.operators_field)
        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):
            operator_names = [operator_names]
        # otherwise, operator_names is already a list

        # we now have a list of names of operators, each equipped with a process_instance method.
        operator = SequentialOperator(steps=operator_names)
        return operator.process_instance(instance, stream_name=stream_name)

class FilterByCondition(StreamOperator):
    """Filters a stream, yielding only instances in which the values in the required fields follow the required condition operator.

    Raises an error if a required field name is missing from the input instance.

    Args:
       values (Dict[str, Any]): Field names and respective values that instances must match, according to the condition, to be included in the output.

       condition: the name of the desired condition operator between the specified (sub) field's value and the provided constant value. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in").

       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       | ``FilterByCondition(values = {"a":4}, condition = "gt")`` will yield only instances where field ``"a"`` contains a value ``> 4``
       | ``FilterByCondition(values = {"a":4}, condition = "le")`` will yield only instances where ``"a"<=4``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "in")`` will yield only instances where ``"a"`` is ``4`` or ``8``
       | ``FilterByCondition(values = {"a":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is different from ``4`` or ``8``
       | ``FilterByCondition(values = {"a/b":[4,8]}, condition = "not in")`` will yield only instances where ``"a"`` is a dict in which key ``"b"`` is mapped to a value that is neither ``4`` nor ``8``
       | ``FilterByCondition(values = {"a[2]":4}, condition = "le")`` will yield only instances where ``"a"`` is a list whose third element is ``<= 4``

    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be list but is not : '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance."
                ) from ve
            if self.condition == "in":
                if instance_key not in value:
                    return False
            elif self.condition == "not in":
                if instance_key in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, value):
                    return False
        return True

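# Illustrative usage (a sketch; the stream content is made up). As stream
# operators are callable on a MultiStream, the filter below keeps only the
# instances whose field "a" is greater than 4:
#
#     ms = MultiStream.from_iterables({"train": [{"a": 3}, {"a": 5}]})
#     filtered = FilterByCondition(values={"a": 4}, condition="gt")(ms)
#     list(filtered["train"])  # -> [{"a": 5}]
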
class FilterByConditionBasedOnFields(FilterByCondition):
    """Filters a stream based on a condition between the values of two fields.

    Raises an error if either of the required field names is missing from the input instance.

    Args:
       values (Dict[str, str]): The field names that the filter operation is based on.
       condition: the name of the desired condition operator between the specified fields' values. Supported conditions are ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")
       error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
       FilterByConditionBasedOnFields(values={"a": "b"}, condition="gt") will yield only instances where field "a" contains a value greater than the value in field "b".
       FilterByConditionBasedOnFields(values={"a": "b"}, condition="le") will yield only instances where "a" <= "b"
    """

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            try:
                instance_key = dict_get(instance, key)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in instance"
                ) from ve
            try:
                instance_value = dict_get(instance, value)
            except ValueError as ve:
                raise ValueError(
                    f"Required filter field ('{value}') in FilterByCondition is not found in instance"
                ) from ve
            if self.condition == "in":
                if instance_key not in instance_value:
                    return False
            elif self.condition == "not in":
                if instance_key in instance_value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance_key, instance_value):
                    return False
        return True


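# Illustrative sketch (editor's addition; the field names are hypothetical):
#
#     keep_improved = FilterByConditionBasedOnFields(
#         values={"final_score": "initial_score"}, condition="gt"
#     )
#
# keeps only instances where instance["final_score"] > instance["initial_score"].
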
class ComputeExpressionMixin(Artifact):
    """Computes an expression expressed over fields of an instance.

    Args:
        expression (str): the expression, in terms of names of fields of an instance
        imports_list (List[str]): list of names of imports needed for the evaluation of the expression
    """

    expression: str
    imports_list: List[str] = OptionalField(default_factory=list)

    def prepare(self):
        # cannot do the imports here, because the object does not pickle with imports
        self.globals = {
            module_name: __import__(module_name) for module_name in self.imports_list
        }

    def compute_expression(self, instance: dict) -> Any:
        if settings.allow_unverified_code:
            return eval(self.expression, {**self.globals, **instance})

        raise ValueError(
            f"Cannot evaluate expression in {self} when unitxt.settings.allow_unverified_code=False - either set it to True or set the {settings.allow_unverified_code_key} environment variable."
            "\nNote: If using test_card() with the default setting, increase loader_limit to avoid missing conditions due to limited data sampling."
        )


class FilterByExpression(StreamOperator, ComputeExpressionMixin):
    """Filters a stream, yielding only instances which fulfill a condition specified as a string to be evaluated by Python's eval().

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        expression (str):
            a condition over fields of the instance, to be processed by Python's eval()
        imports_list (List[str]):
            names of imports needed for the eval of the query (e.g. 're', 'json')
        error_on_filtered_all (bool, optional):
            If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        | ``FilterByExpression(expression = "a > 4")`` will yield only instances where "a">4
        | ``FilterByExpression(expression = "a <= 4 and b > 5")`` will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
        | ``FilterByExpression(expression = "a in [4, 8]")`` will yield only instances where "a" is 4 or 8
        | ``FilterByExpression(expression = "a not in [4, 8]")`` will yield only instances where "a" is neither 4 nor 8
        | ``FilterByExpression(expression = "a['b'] not in [4, 8]")`` will yield only instances where "a" is a dict in which key 'b' is mapped to a value that is neither 4 nor 8
    """

    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self.compute_expression(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
            )


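# Illustrative sketch of imports_list (editor's addition; assumes
# settings.allow_unverified_code is enabled and a hypothetical field "text"):
#
#     op = FilterByExpression(
#         expression="re.search(r'\\d', text) is not None",
#         imports_list=["re"],
#     )
#
# 're' is available inside the expression because it is listed in imports_list.
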
class ExecuteExpression(InstanceOperator, ComputeExpressionMixin):
    """Compute an expression, specified as a string to be evaluated via Python's eval(), over the instance's fields, and store the result in field to_field.

    Raises an error if a field mentioned in the query is missing from the instance.

    Args:
       expression (str): an expression to be evaluated over the fields of the instance
       to_field (str): the field where the result is to be stored
       imports_list (List[str]): names of imports needed for the eval of the query (e.g. 're', 'json')

    Examples:
       When instance {"a": 2, "b": 3} is processed by operator
       ExecuteExpression(expression="a+b", to_field="c")
       the result is {"a": 2, "b": 3, "c": 5}

       When instance {"a": "hello", "b": "world"} is processed by operator
       ExecuteExpression(expression="a+' '+b", to_field="c")
       the result is {"a": "hello", "b": "world", "c": "hello world"}

    """

    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = self.compute_expression(instance)
        return instance


class ExtractMostCommonFieldValues(MultiStreamOperator):
    """Extract the unique values of a field ('field') of a given stream ('stream_name') and store (the most frequent of) them as a list in a new field ('to_field') in all streams.

    More specifically, sort all the unique values encountered in field 'field' by decreasing order of frequency.
    When 'overall_top_frequency_percent' is smaller than 100, trim the list from the bottom, so that the total frequency of
    the remaining values makes 'overall_top_frequency_percent' of the total number of instances in the stream.
    When 'min_frequency_percent' is larger than 0, remove from the list any value whose relative frequency makes
    less than 'min_frequency_percent' of the total number of instances in the stream.
    At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from their default values.

    Examples:

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
    field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
    every instance in all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
    in case field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
    value members in these lists, and report the most frequent values.
    If process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
    'to_field' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
    extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
    and stores them in field 'classes' of each instance of all streams.

    ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
    extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
    Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """

    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100
    min_frequency_percent: Optional[int] = 0
    to_field: str
    process_every_value: Optional[bool] = False

    def verify(self):
        assert (
            self.overall_top_frequency_percent <= 100
            and self.overall_top_frequency_percent >= 0
        ), "'overall_top_frequency_percent' must be between 0 and 100"
        assert (
            self.min_frequency_percent <= 100 and self.min_frequency_percent >= 0
        ), "'min_frequency_percent' must be between 0 and 100"
        assert not (
            self.overall_top_frequency_percent < 100 and self.min_frequency_percent > 0
        ), "At most one of 'overall_top_frequency_percent' and 'min_frequency_percent' is allowed to move from their default value"
        super().verify()

    def process(self, multi_stream: MultiStream) -> MultiStream:
        stream = multi_stream[self.stream_name]
        counter = Counter()
        for instance in stream:
            if (not isinstance(instance[self.field], list)) and (
                self.process_every_value is True
            ):
                raise ValueError(
                    "'process_every_value' is allowed to change to 'True' only for fields whose contents are lists"
                )
            if (not isinstance(instance[self.field], list)) or (
                self.process_every_value is False
            ):
                # either not a list, or a list with process_every_value == False: view the contents of 'field' as one entity whose occurrences are counted.
                counter.update(
                    [(*instance[self.field],)]
                    if isinstance(instance[self.field], list)
                    else [instance[self.field]]
                )  # convert a list to a tuple, making it hashable so Counter can count its occurrences
            else:
                # the content of 'field' is a list and process_every_value == True: add one occurrence on behalf of each individual value
                counter.update(instance[self.field])
        # here counter counts occurrences of individual values, or tuples.
        values_and_counts = counter.most_common()
        if self.overall_top_frequency_percent < 100:
            top_frequency = (
                sum(counter.values()) * self.overall_top_frequency_percent / 100.0
            )
            sum_counts = 0
            for _i, p in enumerate(values_and_counts):
                sum_counts += p[1]
                if sum_counts >= top_frequency:
                    break
            values_and_counts = counter.most_common(_i + 1)
        if self.min_frequency_percent > 0:
            min_frequency = self.min_frequency_percent * sum(counter.values()) / 100.0
            while values_and_counts[-1][1] < min_frequency:
                values_and_counts.pop()
        values_to_keep = [
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = Set(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)


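# A worked sketch of the trimming logic above (hypothetical counts): if "train"
# holds 10 instances with labels A:5, B:3, C:2, then
# overall_top_frequency_percent=80 keeps [A, B] (5+3 >= 8, which is 80% of 10),
# and min_frequency_percent=25 also keeps [A, B] (C's 2/10 = 20% is below 25%).
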
class ExtractFieldValues(ExtractMostCommonFieldValues):
    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0


class Intersect(FieldOperator):
    """Intersects the value of a field, which must be a list, with a given list.

    Args:
        allowed_values (list): the list to intersect with.
    """

    allowed_values: List[Any]

    def verify(self):
        super().verify()
        if self.process_every_value:
            raise ValueError(
                "'process_every_value=True' is not supported in Intersect operator"
            )

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{self.allowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]


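# Illustrative sketch (editor's addition; field contents are hypothetical):
# applying Intersect(field="labels", allowed_values=["b", "f"]) to an instance
# holding {"labels": ["a", "b", "f", "g"]} yields {"labels": ["b", "f"]}.
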
class IntersectCorrespondingFields(InstanceOperator):
    """Intersects the value of a field, which must be a list, with a given list, and removes the corresponding elements from other list fields.

    For example:

    Assume the instances contain a field of 'label' and a field with the labels' corresponding 'position' in the text.

    .. code-block:: text

        IntersectCorrespondingFields(field="label",
                                    allowed_values=["b", "f"],
                                    corresponding_fields_to_intersect=["position"])

    would keep only the "b" and "f" values in the 'label' field and
    their respective values in the 'position' field.
    (All other fields are not affected)

    .. code-block:: text

        Given this input:

        [
            {"label": ["a", "b"],"position": [0,1],"other" : "not"},
            {"label": ["a", "c", "d"], "position": [0,1,2], "other" : "relevant"},
            {"label": ["a", "b", "f"], "position": [0,1,2], "other" : "field"}
        ]

        The output would be:
        [
                {"label": ["b"], "position":[1],"other" : "not"},
                {"label": [], "position": [], "other" : "relevant"},
                {"label": ["b", "f"],"position": [1,2], "other" : "field"},
        ]

    Args:
        field - the field to intersect (must contain list values)
        allowed_values (list) - the list of values to keep
        corresponding_fields_to_intersect (list) - additional list fields from which values
        are removed based on the corresponding indices of the values removed from 'field'
    """

    field: str
    allowed_values: List[str]
    corresponding_fields_to_intersect: List[str]

    def verify(self):
        super().verify()

        if not isinstance(self.allowed_values, list):
            raise ValueError(
                f"The allowed_values is not a list but '{type(self.allowed_values)}'"
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.field not in instance:
            raise ValueError(
                f"Field '{self.field}' is not in provided instance.\n"
                + to_pretty_string(instance)
            )

        for corresponding_field in self.corresponding_fields_to_intersect:
            if corresponding_field not in instance:
                raise ValueError(
                    f"Field '{corresponding_field}' is not in provided instance.\n"
                    + to_pretty_string(instance)
                )

        if not isinstance(instance[self.field], list):
            raise ValueError(
                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
                + to_pretty_string(instance, keys=[self.field])
            )

        num_values_in_field = len(instance[self.field])

        if set(self.allowed_values) == set(instance[self.field]):
            return instance

        indices_to_keep = [
            i
            for i, value in enumerate(instance[self.field])
            if value in set(self.allowed_values)
        ]

        result_instance = {}
        for field_name, field_value in instance.items():
            if (
                field_name in self.corresponding_fields_to_intersect
                or field_name == self.field
            ):
                if not isinstance(field_value, list):
                    raise ValueError(
                        f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
                    )
                if len(field_value) != num_values_in_field:
                    raise ValueError(
                        f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
                        + to_pretty_string(instance, keys=[self.field, field_name])
                    )
                result_instance[field_name] = [
                    value
                    for index, value in enumerate(field_value)
                    if index in indices_to_keep
                ]
            else:
                result_instance[field_name] = field_value
        return result_instance


class RemoveValues(FieldOperator):
    """Removes elements in a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list) - values to be removed.
    """

    unallowed_values: List[Any]

    def verify(self):
        super().verify()

        if not isinstance(self.unallowed_values, list):
            raise ValueError(
                f"The unallowed_values is not a list but '{self.unallowed_values}'"
            )

    def process_value(self, value: Any) -> Any:
        if not isinstance(value, list):
            raise ValueError(f"The value in field is not a list but '{value}'")
        return [e for e in value if e not in self.unallowed_values]


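# Illustrative sketch (editor's addition; field contents are hypothetical):
# applying RemoveValues(field="labels", unallowed_values=["none"]) to an
# instance holding {"labels": ["pos", "none", "neg"]} yields
# {"labels": ["pos", "neg"]}.
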
class Unique(SingleStreamReducer):
    """Reduces a stream to unique instances based on specified fields.

    Args:
        fields (List[str]): The fields that should be unique in each instance.
    """

    fields: List[str] = field(default_factory=list)

    @staticmethod
    def to_tuple(instance: dict, fields: List[str]) -> tuple:
        result = []
        for field_name in fields:
            value = instance[field_name]
            if isinstance(value, list):
                value = tuple(value)
            result.append(value)
        return tuple(result)

    def process(self, stream: Stream) -> Stream:
        seen = set()
        for instance in stream:
            values = self.to_tuple(instance, self.fields)
            if values not in seen:
                seen.add(values)
        return list(seen)


class SplitByValue(MultiStreamOperator):
    """Splits a MultiStream into multiple streams based on unique values in specified fields.

    Args:
        fields (List[str]): The fields to use when splitting the MultiStream.
    """

    fields: List[str] = field(default_factory=list)

    def process(self, multi_stream: MultiStream) -> MultiStream:
        uniques = Unique(fields=self.fields)(multi_stream)

        result = {}

        for stream_name, stream in multi_stream.items():
            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)
                )
                result[filtered_stream_name] = filtered_streams

        return MultiStream(result)


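# Illustrative sketch (editor's addition; field values are hypothetical): for a
# "test" stream whose instances hold "easy" or "hard" in field "level",
# SplitByValue(fields=["level"]) yields streams named like "test_easy" and
# "test_hard" (the exact suffix comes from nested_tuple_to_string), each built
# with FilterByCondition(..., condition="eq") as above.
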
class SplitByNestedGroup(MultiStreamOperator):
    """Splits a small MultiStream (e.g. for metrics, where the whole stream can sit in memory) by the value of field 'group'.

    Args:
        number_of_fusion_generations: int

    The value in field 'group' is of the form "source_n/source_n-1/..." describing the sources in which the instance sat
    when these were fused, potentially over several phases of fusion. The name of the most recent source sits first in this value.
    (See BaseFusion and its extensions.)
    number_of_fusion_generations specifies the length of the prefix by which to split the stream.
    E.g. for number_of_fusion_generations = 1, only the most recent fusion in creating this multi_stream affects the splitting.
    For number_of_fusion_generations = -1, take the whole history written in this field, ignoring the number of generations.
    """

    field_name_of_group: str = "group"
    number_of_fusion_generations: int = 1

    def process(self, multi_stream: MultiStream) -> MultiStream:
        result = defaultdict(list)

        for stream_name, stream in multi_stream.items():
            for instance in stream:
                if self.field_name_of_group not in instance:
                    raise ValueError(
                        f"Field {self.field_name_of_group} is missing from instance. Available fields: {instance.keys()}"
                    )
                signature = (
                    stream_name
                    + "~"  # a sign that does not show within group values
                    + (
                        "/".join(
                            instance[self.field_name_of_group].split("/")[
                                : self.number_of_fusion_generations
                            ]
                        )
                        if self.number_of_fusion_generations >= 0
                        # for values with a smaller number of generations - take up to their last generation
                        else instance[self.field_name_of_group]
                        # for each instance - take all its generations
                    )
                )
                result[signature].append(instance)

        return MultiStream.from_iterables(result)


class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
    """Applies stream operators to a stream based on specified fields in each instance.

    Args:
        field (str): The field containing the operators to be applied.
        reversed (bool): Whether to apply the operators in reverse order.
    """

    field: str
    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):
            operators = [operators]

        if self.reversed:
            operators = list(reversed(operators))

        for operator_name in operators:
            operator = self.get_artifact(operator_name)
            assert isinstance(
                operator, StreamingOperator
            ), f"Operator {operator_name} must be a StreamingOperator"

            stream = operator(MultiStream({stream_name: stream}))[stream_name]

        yield from stream


def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
    for instance, score in zip(stream, scores):
        instance["score"] = recursive_copy(score)
        yield instance


class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
    """Applies metric operators to a stream based on a metric field specified in each instance.

    Args:
        metric_field (str): The field containing the metrics to be applied.
        calc_confidence_intervals (bool): Whether the applied metric should calculate confidence intervals or not.
    """

    metric_field: str
    calc_confidence_intervals: bool

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricsList

        # to be populated only when there are two or more metrics
        accumulated_scores = []

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
            raise RuntimeError(
                f"Missing metric names in field '{self.metric_field}' and instance '{first_instance}'."
            )

        if isinstance(metric_names, str):
            metric_names = [metric_names]

        metrics_list = []
        for metric_name in metric_names:
            metric = self.get_artifact(metric_name)
            if isinstance(metric, MetricsList):
                metrics_list.extend(list(metric.items))
            elif isinstance(metric, Metric):
                metrics_list.append(metric)
            else:
                raise ValueError(
                    f"Operator {metric_name} must be a Metric or MetricsList"
                )

        for metric in metrics_list:
            if not self.calc_confidence_intervals:
                metric.disable_confidence_interval_calculation()
        # Each metric operator computes its score and then sets the main score, overwriting
        # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
        # This will cause the first listed metric to run last, and the main score will be set
        # by the first listed metric (as desired).
        metrics_list = list(reversed(metrics_list))

        for i, metric in enumerate(metrics_list):
            if i == 0:  # first metric
                multi_stream = MultiStream({"tmp": stream})
            else:  # subsequent metrics, which build on the previously accumulated scores
                reusable_generator = ReusableGenerator(
                    generator=update_scores_of_stream_instances,
                    gen_kwargs={"stream": stream, "scores": accumulated_scores},
                )
                multi_stream = MultiStream.from_generators({"tmp": reusable_generator})

            multi_stream = metric(multi_stream)

            if i < len(metrics_list) - 1:  # not the last metric: stash its scores for the next one
                accumulated_scores = []
                for inst in multi_stream["tmp"]:
                    accumulated_scores.append(recursive_copy(inst["score"]))

        yield from multi_stream["tmp"]


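# An illustration of the ordering above (hypothetical metric names): with
# metric_field contents ["accuracy", "f1_macro"], the list is reversed so that
# "f1_macro" runs first and "accuracy" runs last; since each metric overwrites
# the main score, the first listed metric ("accuracy") is the one that ends up
# setting it.
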
class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.

    Args:
        streams_to_merge (List[str]): The names of the streams to merge. If None, all streams are merged.
        new_stream_name (str): The name of the new stream resulting from the merge.
        add_origin_stream_name (bool): Whether to add the origin stream name to each instance.
        origin_stream_name_field_name (str): The field name for the origin stream name.
    """

    streams_to_merge: List[str] = None
    new_stream_name: str = "all"
    add_origin_stream_name: bool = True
    origin_stream_name_field_name: str = "origin"

    def merge(self, multi_stream) -> Generator:
        for stream_name, stream in multi_stream.items():
            if self.streams_to_merge is None or stream_name in self.streams_to_merge:
                for instance in stream:
                    if self.add_origin_stream_name:
                        instance[self.origin_stream_name_field_name] = stream_name
                    yield instance

    def process(self, multi_stream: MultiStream) -> MultiStream:
        return MultiStream(
            {
                self.new_stream_name: DynamicStream(
                    self.merge, gen_kwargs={"multi_stream": multi_stream}
                )
            }
        )


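# Illustrative sketch (editor's addition): MergeStreams(streams_to_merge=
# ["train", "test"]) yields a single "all" stream in which each instance also
# carries an "origin" field holding the name of the stream it came from
# (add_origin_stream_name defaults to True).
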
class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page


class FeatureGroupedShuffle(Shuffle):
    """Class for shuffling an input dataset by instance 'blocks', not on the individual instance level.

    For example, suppose the dataset consists of questions, each with several paraphrases, and each question falls into a topic.
    All paraphrases have the same ID value as the original.
    In this case, we may want to shuffle on grouping_features = ['question ID'],
    to keep the paraphrases and original question together.
    We may also want to group by both 'question ID' and 'topic', if the question IDs are repeated between topics.
    In this case, grouping_features = ['question ID', 'topic']

    Args:
        grouping_features (list of strings): list of feature names to use to define the groups.
            a group is defined by each unique observed combination of data values for features in grouping_features
        shuffle_within_group (bool): whether to further shuffle the instances within each group block, keeping the block order

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
            Note: shuffle_by_grouping_features determines the unique groups (unique combinations of values of grouping_features)
            separately by page (determined by page_size).  If a block of instances in the same group are split
            into separate pages (either by a page break falling in the group, or the dataset was not sorted by
            grouping_features), these instances will be shuffled separately and thus the grouping may be
            broken up by pages.  If the user wants to ensure the shuffle does the grouping and shuffling
            across all pages, set the page_size to be larger than the dataset size.
            See outputs_2features_bigpage and outputs_2features_smallpage in test_grouped_shuffle.
    """

    grouping_features: List[str] = None
    shuffle_within_group: bool = False

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        if self.grouping_features is None:
            yield from super().process(page, stream_name)
        else:
            yield from self.shuffle_by_grouping_features(page)

    def shuffle_by_grouping_features(self, page):
        import itertools
        from collections import defaultdict

        groups_to_instances = defaultdict(list)
        for item in page:
            groups_to_instances[
                tuple(item[ff] for ff in self.grouping_features)
            ].append(item)
        # now extract the groups (i.e., lists of dicts with order preserved)
        page_blocks = list(groups_to_instances.values())
        # and now shuffle the blocks
        self.random_generator.shuffle(page_blocks)
        if self.shuffle_within_group:
            blocks = []
            # reshuffle the instances within each block, but keep the blocks in order
            for block in page_blocks:
                self.random_generator.shuffle(block)
                blocks.append(block)
            page_blocks = blocks

        # now flatten the list so it consists of individual dicts, but in (randomized) block order
        return list(itertools.chain(*page_blocks))


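# A minimal configuration sketch (editor's addition; field names follow the
# docstring's scenario):
#
#     shuffler = FeatureGroupedShuffle(
#         grouping_features=["question ID", "topic"],
#         shuffle_within_group=False,
#         page_size=10**6,  # larger than the dataset, so grouping spans all pages
#     )
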
class EncodeLabels(InstanceOperator):
    """Encode each value encountered in any field in 'fields' into the integers 0,1,...

    Encoding is determined by a str->int map that is built on the fly, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example:
        applying ``EncodeLabels(fields = ["a", "b/*"])``
        on input stream = ``[{"a": "red", "b": ["red", "blue"], "c":"bread"},
        {"a": "blue", "b": ["green"], "c":"water"}]``   will yield the
        output stream = ``[{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]``

        Note: dict_utils are applied here, and hence, fields that are lists should be included in
        input 'fields' with the suffix ``"/*"``  as in the above example.

    """

    fields: List[str]

    def _process_multi_stream(self, multi_stream: MultiStream) -> MultiStream:
        self.encoder = {}
        return super()._process_multi_stream(multi_stream)

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        for field_name in self.fields:
            values = dict_get(instance, field_name)
            values_was_a_list = isinstance(values, list)
            if not isinstance(values, list):
                values = [values]
            for value in values:
                if value not in self.encoder:
                    self.encoder[value] = len(self.encoder)
            new_values = [self.encoder[value] for value in values]
            if not values_was_a_list:
                new_values = new_values[0]
            dict_set(
                instance,
                field_name,
                new_values,
                not_exist_ok=False,  # the values to encode were just taken from there
                set_multiple="*" in field_name
                and isinstance(new_values, list)
                and len(new_values) > 0,
            )

        return instance


class StreamRefiner(StreamOperator):
    """Discard from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int)
        apply_to_streams (optional, list(str)):
            names of streams to refine.

    Examples:
        when input = ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}]`` is fed into
        ``StreamRefiner(max_instances=4)``
        the resulting stream is ``[{"a": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:
            yield from stream.take(self.max_instances)
        else:
            yield from stream


class Deduplicate(StreamOperator):
    """Deduplicate the stream based on the given fields.

    Args:
        by (List[str]): A list of field names to deduplicate by. The combination of these fields' values will be used to determine uniqueness.

    Examples:
        >>> dedup = Deduplicate(by=["field1", "field2"])
    """

    by: List[str]

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        seen = set()

        for instance in stream:
            # Compute a lightweight hash for the signature
            signature = hash(str(tuple(dict_get(instance, field) for field in self.by)))

            if signature not in seen:
                seen.add(signature)
                yield instance


class Balance(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
    By discarding instances from the input stream, the balancer maintains an equal number of instances for all signatures.
    When input 'max_instances' is also specified, the balancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is as few as possible.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            overall max.

    Usage:
        ``balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)``
        ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}]`` is fed into
        ``DeterministicBalancer(fields=["a"])``
        the resulting stream will be: ``[{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]``
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        counter = Counter()

        for instance in stream:
            counter[self.signature(instance)] += 1

        if len(counter) == 0:
            return

        lowest_count = counter.most_common()[-1][-1]

        max_total_instances_per_sign = lowest_count
        if self.max_instances is not None:
            max_total_instances_per_sign = min(
                lowest_count, self.max_instances // len(counter)
            )

        counter = Counter()

        for instance in stream:
            sign = self.signature(instance)
            if counter[sign] < max_total_instances_per_sign:
                counter[sign] += 1
                yield instance


class DeterministicBalancer(Balance):
    pass


class MinimumOneExamplePerLabelRefiner(StreamRefiner):
    """A class used to return a specified number of instances, ensuring at least one example per label.

    For each instance, a signature value is constructed from the values of the instance in the specified input ``fields``.
    ``MinimumOneExamplePerLabelRefiner`` takes the first instance that appears for each label (each unique signature), and then adds more elements up to the max_instances limit. In general, the refiner takes the first elements in the stream that meet the required conditions.
    ``MinimumOneExamplePerLabelRefiner`` then shuffles the results to avoid having one instance
    from each class first and then the rest. If max_instances is not set, the original stream is used.

    Args:
        fields (List[str]):
            A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int):
            Number of elements to select. Note that max_instances of StreamRefiners
            that are passed to the recipe (e.g. ``train_refiner``, ``test_refiner``) are overridden
            by the recipe parameters (``max_train_instances``, ``max_test_instances``)

    Usage:
        | ``balancer = MinimumOneExamplePerLabelRefiner(fields=["field1", "field2"], max_instances=200)``
        | ``balanced_stream = balancer.process(stream)``

    Example:
        When input ``[{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 1, "b": 3},{"a": 1, "b": 4},{"a": 2, "b": 5}]`` is fed into
        ``MinimumOneExamplePerLabelRefiner(fields=["a"], max_instances=3)``
        the resulting stream will be:
        ``[{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 5}]`` (order may be different)
    """

    fields: List[str]

    def signature(self, instance):
        return str(tuple(dict_get(instance, field) for field in self.fields))

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is None:
            yield from stream
            return

        counter = Counter()
        for instance in stream:
            counter[self.signature(instance)] += 1
        all_keys = counter.keys()
        if len(counter) == 0:
            return

        if self.max_instances is not None and len(all_keys) > self.max_instances:
            raise Exception(
                f"Can not generate a stream with at least one example per label, because the requested max_instances ({self.max_instances}) is smaller than the number of different labels ({len(all_keys)})"
            )

        counter = Counter()
        used_indices = set()
        selected_elements = []
        # select at least one instance per class
        for idx, instance in enumerate(stream):
            sign = self.signature(instance)
            if counter[sign] == 0:
                counter[sign] += 1
                used_indices.add(idx)
                selected_elements.append(
                    instance
                )  # collect all elements first to allow shuffling of both groups

        # select more to reach self.max_instances examples
        for idx, instance in enumerate(stream):
            if idx not in used_indices:
                if self.max_instances is None or len(used_indices) < self.max_instances:
                    used_indices.add(idx)
                    selected_elements.append(
                        instance
                    )  # collect all elements first to allow shuffling of both groups

        # shuffle elements to avoid having one element from each class appear first
        random_generator = new_random_generator(sub_seed=selected_elements)
        random_generator.shuffle(selected_elements)
        yield from selected_elements


class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]):
            distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the given total length.
            (If none exceeds -- into one index beyond, namely, the length of segments_boundaries)
        fields (Optional, List[str]):
            the total length of the values of these fields goes through the quantization described above


    Example:
        when input ``[{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}]``
        is fed into ``LengthBalancer(fields=["a"], segments_boundaries=[1])``,
        input instances will be counted and balanced against two categories:
        empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0
        for field_name in self.fields:
            total_len += len(dict_get(instance, field_name))
        for i, val in enumerate(self.segments_boundaries):
            if total_len < val:
                return i
        return i + 1


class DownloadError(Exception):
    def __init__(
        self,
        message,
    ):
        super().__init__(message)


class UnexpectedHttpCodeError(Exception):
    def __init__(self, http_code):
        super().__init__(f"unexpected http code {http_code}")


class DownloadOperator(SideEffectOperator):
    """Operator for downloading a file from a given URL to a specified local path.

    Args:
        source (str):
            URL of the file to be downloaded.
        target (str):
            Local path where the downloaded file should be saved.
    """

    source: str
    target: str

    def process(self):
        try:
            response = requests.get(self.source, allow_redirects=True)
        except Exception as e:
            raise DownloadError(f"Unable to download {self.source}") from e
        if response.status_code != 200:
            raise UnexpectedHttpCodeError(response.status_code)
        with open(self.target, "wb") as f:
            f.write(response.content)


class ExtractZipFile(SideEffectOperator):
    """Operator for extracting files from a zip archive.

    Args:
        zip_file (str):
            Path of the zip file to be extracted.
        target_dir (str):
            Directory where the contents of the zip file will be extracted.
    """

    zip_file: str
    target_dir: str

    def process(self):
        with zipfile.ZipFile(self.zip_file) as zf:
            zf.extractall(self.target_dir)


class DuplicateInstances(StreamOperator):
    """Operator which duplicates each instance in stream a given number of times.

    Args:
        num_duplications (int):
            How many times each instance should be duplicated (1 means no duplication).
        duplication_index_field (Optional[str]):
            If given, then an additional field with the specified name is added to each duplicated instance,
            containing the index of the given duplication. Defaults to None, so no field is added.
    """

    num_duplications: int
    duplication_index_field: Optional[str] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        for instance in stream:
            for idx in range(self.num_duplications):
                duplicate = recursive_shallow_copy(instance)
                if self.duplication_index_field:
                    duplicate.update({self.duplication_index_field: idx})
                yield duplicate

    def verify(self):
        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
            raise ValueError(
                f"num_duplications must be an integer equal to or greater than 1. "
                f"Got: {self.num_duplications}."
            )

        if self.duplication_index_field is not None and not isinstance(
            self.duplication_index_field, str
        ):
            raise ValueError(
                f"If given, duplication_index_field must be a string. "
                f"Got: {self.duplication_index_field}"
            )


class CollateInstances(StreamOperator):
    """Operator which collates values from multiple instances into a single instance.

    Each field of the resulting instance holds the list of that field's values across a batch of `batch_size` instances.

    Attributes:
        batch_size (int)

    Example:
        .. code-block:: text

            CollateInstances(batch_size=2)

            Given inputs = [
                {"a": 1, "b": 2},
                {"a": 2, "b": 2},
                {"a": 3, "b": 2},
                {"a": 4, "b": 2},
                {"a": 5, "b": 2}
            ]

            Returns targets = [
                {"a": [1,2], "b": [2,2]},
                {"a": [3,4], "b": [2,2]},
                {"a": [5], "b": [2]},
            ]


    """

    batch_size: int

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        stream = list(stream)
        for i in range(0, len(stream), self.batch_size):
            batch = stream[i : i + self.batch_size]
            new_instance = {}
            for a_field in batch[0]:
                if a_field == "data_classification_policy":
                    flattened_list = [
                        classification
                        for instance in batch
                        for classification in instance[a_field]
                    ]
                    new_instance[a_field] = sorted(set(flattened_list))
                else:
                    new_instance[a_field] = [instance[a_field] for instance in batch]
            yield new_instance

    def verify(self):
        if not isinstance(self.batch_size, int) or self.batch_size < 1:
            raise ValueError(
                f"batch_size must be an integer equal to or greater than 1. "
                f"Got: {self.batch_size}."
            )


class CollateInstancesByField(StreamOperator):
    """Groups a list of instances by a specified field, aggregates specified fields into lists, and ensures consistency for all other non-aggregated fields.

    Args:
        by_field (str): the name of the field to group data by.
        aggregate_fields (List[str]): the field names to aggregate into lists.

    Returns:
        A stream of instances grouped and aggregated by the specified field.

    Raises:
        UnitxtError: If non-aggregate fields have inconsistent values.

    Example:
        Collate the instances based on field "category" and aggregate fields "value" and "id".

        .. code-block:: text

            CollateInstancesByField(by_field="category", aggregate_fields=["value", "id"])

            given input:
            [
                {"id": 1, "category": "A", "value": 10, "flag": True},
                {"id": 2, "category": "B", "value": 20, "flag": False},
                {"id": 3, "category": "A", "value": 30, "flag": True},
                {"id": 4, "category": "B", "value": 40, "flag": False}
            ]

            the output is:
            [
                {"category": "A", "id": [1, 3], "value": [10, 30], "flag": True},
                {"category": "B", "id": [2, 4], "value": [20, 40], "flag": False}
            ]

        Note that the "flag" field is not aggregated, and must be the same
        in all instances in the same category, or an error is raised.
    """

    by_field: str = NonPositionalField(required=True)
    aggregate_fields: List[str] = NonPositionalField(required=True)

    def prepare(self):
        super().prepare()

    def verify(self):
        super().verify()
        if not isinstance(self.by_field, str):
            raise UnitxtError(
                f"The 'by_field' value is not a string but '{type(self.by_field)}'"
            )

        if not isinstance(self.aggregate_fields, list):
            raise UnitxtError(
                f"The 'aggregate_fields' value is not a list but '{type(self.aggregate_fields)}'"
            )

    def process(self, stream: Stream, stream_name: Optional[str] = None):
        grouped_data = {}

        for instance in stream:
            if self.by_field not in instance:
                raise UnitxtError(
                    f"The field '{self.by_field}' specified by CollateInstancesByField's 'by_field' argument is not found in instance."
                )
            for k in self.aggregate_fields:
                if k not in instance:
                    raise UnitxtError(
                        f"The field '{k}' specified in CollateInstancesByField's 'aggregate_fields' argument is not found in instance."
                    )
            key = instance[self.by_field]

            if key not in grouped_data:
                grouped_data[key] = {
                    k: v for k, v in instance.items() if k not in self.aggregate_fields
                }
                # Add empty lists for the fields to aggregate
                for agg_field in self.aggregate_fields:
                    if agg_field in instance:
                        grouped_data[key][agg_field] = []

            for k, v in instance.items():
                # Merge the classification policy lists across instances with the same key
                if k == "data_classification_policy" and instance[k]:
                    grouped_data[key][k] = sorted(set(grouped_data[key][k] + v))
                # Check consistency of all non-aggregated fields
                elif k != self.by_field and k not in self.aggregate_fields:
                    if k in grouped_data[key] and grouped_data[key][k] != v:
                        raise ValueError(
                            f"Inconsistent value for field '{k}' in group '{key}': "
                            f"'{grouped_data[key][k]}' vs '{v}'. Ensure that all non-aggregated fields in CollateInstancesByField are consistent across all instances."
                        )
                # Aggregate the requested fields
                elif k in self.aggregate_fields:
                    grouped_data[key][k].append(instance[k])

        yield from grouped_data.values()


class WikipediaFetcher(FieldOperator):
    mode: Literal["summary", "text"] = "text"
    _requirements_list = ["Wikipedia-API"]

    def prepare(self):
        super().prepare()
        import wikipediaapi

        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")

    def process_value(self, value: Any) -> Any:
        title = value.split("/")[-1]
        page = self.wikipedia.page(title)

        return {"title": page.title, "body": getattr(page, self.mode)}