• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 16704320175

03 Aug 2025 11:05AM UTC coverage: 80.829% (-0.4%) from 81.213%
16704320175

Pull #1845

github

web-flow
Merge 59428aa88 into 5372aa6df
Pull Request #1845: Allow using python functions instead of operators (e.g in pre-processing pipeline)

1576 of 1970 branches covered (80.0%)

Branch coverage included in aggregate %.

10685 of 13199 relevant lines covered (80.95%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.01
src/unitxt/dataclass.py
1
import copy
1✔
2
import dataclasses
1✔
3
import functools
1✔
4
import inspect
1✔
5
import json
1✔
6
from abc import ABCMeta
1✔
7
from inspect import Parameter, Signature
1✔
8
from typing import Any, Dict, List, Optional, final
1✔
9

10
# Name of the class attribute under which DataclassMeta stores each class's
# {field name: Field} mapping; read back by fields(), fields_names() and is_dataclass().
_FIELDS = "__fields__"
1✔
11

12

13
class Undefined:
    """Sentinel class used as ``Field.default`` when no default value was given."""
15

16

17
@dataclasses.dataclass
class Field:
    """Specification of a single field of a Dataclass.

    A Field (or one of its specialised subclasses such as FinalField,
    RequiredField, OptionalField, AbstractField, NonPositionalField or
    InternalField) can be assigned to a class attribute to configure how
    that field behaves.

    Args:
        default (Any, optional):
            Default value for the field. Defaults to the ``Undefined`` sentinel,
            which means that no default value was given.
        name (str, optional):
            Name of the field. Defaults to None.
        type (type, optional):
            Type of the field. Defaults to None.
        init (bool, optional):
            Whether the field can be passed to ``__init__``. Defaults to True.
        also_positional (bool, optional):
            Whether the field may also be passed as a positional argument
            (when ``init`` is True). Defaults to True.
        default_factory (Any, optional):
            A zero-argument callable returning the default value. When set, it
            takes precedence over ``default`` in ``get_default``. Defaults to None.
        final (bool, optional):
            A boolean indicating if the field is final (cannot be overridden). Defaults to False.
        abstract (bool, optional):
            A boolean indicating if the field is abstract (must be implemented by subclasses). Defaults to False.
        required (bool, optional):
            A boolean indicating if the field is required. Defaults to False.
        internal (bool, optional):
            Whether the field is internal: excluded from the generated ``__init__``
            signature and from ``repr``. Defaults to False.
        origin_cls (type, optional):
            Identifies the class that defined the field (the metaclass stores the
            class's ``__qualname__`` string here). Defaults to None.
        metadata (Dict[str, str], optional):
            Additional metadata attached to the field. Defaults to an empty dict.
    """

    default: Any = Undefined
    name: str = None
    type: type = None
    init: bool = True
    also_positional: bool = True
    default_factory: Any = None
    final: bool = False
    abstract: bool = False
    required: bool = False
    internal: bool = False
    origin_cls: type = None
    metadata: Dict[str, str] = dataclasses.field(default_factory=dict)

    def get_default(self):
        """Return the field's default value.

        Calls ``default_factory`` (producing a fresh value on every call) when it
        is set; otherwise returns ``default``, which may be ``Undefined``.
        """
        if self.default_factory is not None:
            return self.default_factory()
        return self.default
57

58

59
@dataclasses.dataclass
class FinalField(Field):
    """A Field that subclasses may not redeclare; it forces ``final=True``."""

    def __post_init__(self):
        # Whatever the caller passed, a FinalField is always final.
        self.final = True
63

64

65
@dataclasses.dataclass
class RequiredField(Field):
    """A Field that must be supplied at construction time; it forces ``required=True``."""

    def __post_init__(self):
        # Whatever the caller passed, a RequiredField is always required.
        self.required = True
69

70

71
class MissingDefaultError(TypeError):
    """Raised when an OptionalField is declared without ``default`` or ``default_factory``."""
73

74

75
@dataclasses.dataclass
class OptionalField(Field):
    """A Field that is never required and therefore must carry a default.

    Raises:
        MissingDefaultError: If neither ``default`` nor ``default_factory`` was given.
    """

    def __post_init__(self):
        self.required = False
        has_default = (
            self.default is not Undefined or self.default_factory is not None
        )
        if not has_default:
            raise MissingDefaultError(
                "OptionalField must have default or default_factory"
            )
83

84

85
@dataclasses.dataclass
class AbstractField(Field):
    """A Field that concrete subclasses must override; it forces ``abstract=True``."""

    def __post_init__(self):
        # Whatever the caller passed, an AbstractField is always abstract.
        self.abstract = True
89

90

91
@dataclasses.dataclass
class NonPositionalField(Field):
    """A Field that can only be passed to ``__init__`` by keyword (``also_positional=False``)."""

    def __post_init__(self):
        # Exclude this field from the positional-argument mapping.
        self.also_positional = False
95

96

97
@dataclasses.dataclass
class InternalField(Field):
    """A Field for internal state: not accepted by ``__init__`` and hidden from ``repr``."""

    def __post_init__(self):
        # Internal, not an ``__init__`` argument, and therefore never positional.
        self.internal, self.init, self.also_positional = True, False, False
103

104

105
class FinalFieldError(TypeError):
    """Raised when a class redeclares a field that an ancestor marked as final."""
107

108

109
class RequiredFieldError(TypeError):
    """Raised when an instance is created without a value for a required field."""
111

112

113
class AbstractFieldError(TypeError):
    """Raised when an instance is created while one of its fields is still abstract."""
115

116

117
class TypeMismatchError(TypeError):
    """Signals that a value overriding an inherited field does not match the field's type."""
119

120

121
class UnexpectedArgumentError(TypeError):
    """Raised when a Dataclass receives arguments it does not accept."""
123

124

125
# Attribute names that every object already has (e.g. __init__, __repr__);
# is_possible_field() never treats these names as fields.
standard_variables = dir(object)
1✔
126

127

128
def is_class_method(func):
    """Tell whether ``func`` looks like a method rather than a field value.

    Returns True for bound methods, and for plain functions whose first
    parameter is named ``self`` or ``cls``; False for everything else.
    """
    if inspect.ismethod(func):
        return True
    if not inspect.isfunction(func):
        return False
    first_param = next(iter(inspect.signature(func).parameters.values()), None)
    return first_param is not None and first_param.name in ("self", "cls")
137

138

139
def is_possible_field(field_name, field_value):
    """Decide whether a class-body name/value pair could describe a field.

    A pair is rejected when the name is one every object already has (listed in
    ``standard_variables``) or when the value looks like a method (see
    ``is_class_method``).

    Args:
        field_name (str): The name of the field.
        field_value: The value of the field.

    Returns:
        bool: True if the name-value pair can represent a field, False otherwise.
    """
    return field_name not in standard_variables and not is_class_method(field_value)
154

155

156
def get_fields(cls, attrs):
    """Collect the fields of a class from its bases and its own class body.

    Fields inherited from ``cls.__bases__`` are gathered first; when several
    bases define the same field name, the one from the earlier base is kept.
    Then every name declared in ``attrs`` becomes a Field. A name counts as
    declared when it is annotated, or when it is a plain (non-method,
    non-``object``) attribute that overrides an inherited field. In the latter
    case the inherited field's type is reused.

    Args:
        cls (type): The class being created.
        attrs (dict): The class-body namespace given to the metaclass.

    Returns:
        dict: A dictionary mapping field names to Field instances.

    Raises:
        FinalFieldError: If ``attrs`` redeclares a field an ancestor marked final.
    """
    fields = {}
    for base in cls.__bases__:
        # The accumulated (earlier-base) fields win over this base's fields.
        fields = {**getattr(base, _FIELDS, {}), **fields}
    annotations = {**attrs.get("__annotations__", {})}

    for attr_name, attr_value in attrs.items():
        if attr_name not in annotations and is_possible_field(attr_name, attr_value):
            if attr_name in fields:
                try:
                    if not isinstance(attr_value, fields[attr_name].type):
                        raise TypeMismatchError(
                            f"Type mismatch for field '{attr_name}' of class '{fields[attr_name].origin_cls}'. Expected {fields[attr_name].type}, got {type(attr_value)}"
                        )
                except TypeError:
                    # isinstance() raises TypeError for non-class types (e.g.
                    # typing generics or None). NOTE(review): TypeMismatchError is
                    # itself a TypeError, so it is swallowed here as well and
                    # mismatches are never reported. Existing subclasses (e.g.
                    # overriding a field with a Field instance) rely on this.
                    pass
                # Treat the override as if it re-annotated the inherited type.
                annotations[attr_name] = fields[attr_name].type

    for field_name, field_type in annotations.items():
        if field_name in fields and fields[field_name].final:
            raise FinalFieldError(
                f"Final field {field_name} defined in {fields[field_name].origin_cls} overridden in {cls}"
            )

        args = {
            "name": field_name,
            "type": field_type,
            "origin_cls": attrs["__qualname__"],
        }

        if field_name in attrs:
            field_value = attrs[field_name]
            if isinstance(field_value, Field):
                args = {**dataclasses.asdict(field_value), **args}
            elif isinstance(field_value, dataclasses.Field):
                # The stdlib uses MISSING (not None) for "no factory"; normalize it
                # so get_field_default() does not try to call the sentinel.
                default_factory = field_value.default_factory
                if default_factory is dataclasses.MISSING:
                    default_factory = None
                args = {
                    "default": field_value.default,
                    "name": field_value.name,
                    "type": field_value.type,
                    "init": field_value.init,
                    "default_factory": default_factory,
                    # Like stdlib dataclasses: no default and no factory => required.
                    "required": field_value.default is dataclasses.MISSING
                    and default_factory is None,
                    **args,
                }
            else:
                args["default"] = field_value
                args["default_factory"] = None
        else:
            # Annotated without a value: the field must be provided.
            args["default"] = dataclasses.MISSING
            args["default_factory"] = None
            args["required"] = True

        fields[field_name] = Field(**args)

    # Added once per class (not per annotation) so that classes enabling
    # unexpected arguments without declaring new annotations still get them.
    if getattr(cls, "__allow_unexpected_arguments__", False):
        fields["_argv"] = InternalField(name="_argv", type=tuple, default=())
        # A factory gives every instance its own dict instead of one shared dict.
        fields["_kwargs"] = InternalField(
            name="_kwargs", type=dict, default_factory=dict
        )

    return fields
224

225

226
def is_dataclass(obj):
    """Return True if obj is a class, or an instance of a class, that has a field registry (``_FIELDS``)."""
    if not isinstance(obj, type):
        obj = type(obj)
    return hasattr(obj, _FIELDS)
230

231

232
def class_fields(obj):
    """Return the fields whose ``origin_cls`` is the qualified name of obj's class.

    ``obj`` may be a Dataclass instance or a Dataclass itself. Previously, for a
    class argument, ``obj.__class__`` was the metaclass, so nothing ever matched.

    NOTE(review): ``origin_cls`` holds only ``__qualname__``, so two classes with
    the same qualified name in different modules are indistinguishable here.
    """
    cls = obj if isinstance(obj, type) else obj.__class__
    return [field for field in fields(obj) if field.origin_cls == cls.__qualname__]
237

238

239
def fields(cls):
    """Return a list of the Field objects registered on cls (a Dataclass or an instance of one), in registry order."""
    registry = getattr(cls, _FIELDS)
    return [*registry.values()]
241

242

243
def fields_names(cls):
    """Return a list of the field names registered on cls, in registry order."""
    registry = getattr(cls, _FIELDS)
    return [*registry]
245

246

247
def external_fields_names(cls):
    """Return the names of all fields of cls that are not marked internal."""
    names = []
    for candidate in fields(cls):
        if candidate.internal:
            continue
        names.append(candidate.name)
    return names
249

250

251
def final_fields(cls):
    """Return the fields of cls whose ``final`` flag is truthy."""
    return list(filter(is_final_field, fields(cls)))
253

254

255
def required_fields(cls):
    """Return the fields of cls whose ``required`` flag is truthy."""
    return [candidate for candidate in fields(cls) if candidate.required]
257

258

259
def abstract_fields(cls):
    """Return the fields of cls whose ``abstract`` flag is truthy."""
    return list(filter(is_abstract_field, fields(cls)))
261

262

263
def is_abstract_field(field):
    """Return the field's ``abstract`` flag as stored (not coerced to bool)."""
    abstract_flag = field.abstract
    return abstract_flag
265

266

267
def is_final_field(field):
    """Return the field's ``final`` flag as stored (not coerced to bool)."""
    final_flag = field.final
    return final_flag
269

270

271
def get_field_default(field):
    """Return the default for ``field``.

    A fresh value is built from ``default_factory`` when one is set; otherwise
    ``default`` is returned as is.
    """
    factory = field.default_factory
    return field.default if factory is None else factory()
276

277

278
def asdict(obj):
    """Convert a Dataclass instance into a deep-copied dictionary.

    Raises:
        AssertionError: If obj is not a Dataclass (checked via ``assert``, so the
            check is skipped when Python runs with ``-O``).
    """
    assert is_dataclass(obj), (
        f"{obj} must be a dataclass, got {type(obj)} with bases {obj.__class__.__bases__}"
    )
    return _asdict_inner(obj)
283

284

285
def _asdict_inner(obj):
    """Recursively convert obj into plain data.

    Dataclass instances are converted with their ``to_dict()`` method. Named
    tuples, lists, tuples and dicts are rebuilt element by element with their
    original type. Anything else is returned as ``copy.deepcopy(obj)``.
    """
    if is_dataclass(obj):
        return obj.to_dict()

    container_type = type(obj)

    if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
        return container_type(*[_asdict_inner(item) for item in obj])

    if isinstance(obj, (list, tuple)):
        return container_type([_asdict_inner(item) for item in obj])

    if isinstance(obj, dict):
        return container_type(
            {_asdict_inner(key): _asdict_inner(value) for key, value in obj.items()}
        )

    return copy.deepcopy(obj)
299

300

301
def to_dict(obj, func=copy.deepcopy, _visited=None):
    """Recursively converts an object into a dictionary representation while avoiding infinite recursion due to circular references.

    Args:
        obj: Any Python object to be converted into a dictionary-like structure.
        func (Callable, optional): Applied to every object that is not recursed
            into, and to any object that closes a reference cycle. Defaults to
            `copy.deepcopy`.
        _visited (set, optional): Internal. IDs of the dicts, lists, named tuples
            and dataclasses on the current recursion path.

    Returns:
        Dataclasses become dicts mapping field names to converted values. Dicts,
        lists, tuples and named tuples keep their type, with converted contents.
        Any other object is returned as ``func(obj)``.

    Notes:
        - Supports dataclasses, named tuples, lists, tuples, and dictionaries.
        - ``int``/``float``/``bool`` values are round-tripped through JSON first,
          which normalizes subclasses such as ``re.DOTALL`` to plain values.
        - A cycle is detected when an object is reached again while it is still
          being converted; that occurrence is replaced by `func(obj)`. An object
          that is merely shared (referenced more than once without forming a
          cycle) is converted normally at every occurrence.
        - Named tuples retain their original type instead of being converted to dictionaries.
    """
    # Initialize visited set on first call
    if _visited is None:
        _visited = set()

    obj_id = id(obj)

    if isinstance(obj, (int, float, bool)):
        # normalize constants like re.DOTALL
        obj = json.loads(json.dumps(obj))

    # The object is still on the current recursion path: this is a cycle.
    if obj_id in _visited:
        return func(obj)

    # Only containers we recurse into can take part in a cycle.
    track = (
        isinstance(obj, (dict, list))
        or is_dataclass(obj)
        or (isinstance(obj, tuple) and hasattr(obj, "_fields"))
    )
    if track:
        _visited.add(obj_id)

    try:
        if is_dataclass(obj):
            return {
                field.name: to_dict(getattr(obj, field.name), func, _visited)
                for field in fields(obj)
            }

        if isinstance(obj, tuple) and hasattr(obj, "_fields"):  # named tuple
            return type(obj)(*[to_dict(v, func, _visited) for v in obj])

        if isinstance(obj, (list, tuple)):
            return type(obj)([to_dict(v, func, _visited) for v in obj])

        if isinstance(obj, dict):
            return type(obj)(
                {
                    to_dict(k, func, _visited): to_dict(v, func, _visited)
                    for k, v in obj.items()
                }
            )

        return func(obj)
    finally:
        # Leaving this object: later non-cyclic references are converted again.
        # Keeping only the current path also means the tracked ids belong to
        # live objects, so recycled ids of freed temporaries cannot match.
        if track:
            _visited.discard(obj_id)
362

363

364
class DataclassMeta(ABCMeta):
    """Metaclass for Dataclass.

    When a class is created, it collects the class's fields (inherited ones
    plus those declared in the class body) with ``get_fields``, which is also
    where overriding a final field is rejected. The fields are stored under
    ``_FIELDS``, and the class's ``__init__`` is then given a signature that
    lists them.
    """

    @final
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        setattr(cls, _FIELDS, get_fields(cls, attrs))
        cls.update_init_signature()

    def update_init_signature(cls):
        """Wrap ``cls.__init__`` and attach an ``inspect.Signature`` to the wrapper.

        The signature lists every field with ``init`` set and ``internal``
        unset, as positional-or-keyword parameters annotated with the field
        type and defaulted to the field's default. When
        ``__allow_unexpected_arguments__`` is truthy, ``*_argv`` and
        ``**_kwargs`` are appended. The wrapper simply forwards to the previous
        ``__init__``, so argument handling is unchanged.

        NOTE: a field's ``default_factory`` is called once here, only to build
        the default shown in the signature.
        """

        def signature_default(field):
            # Prefer an explicit default, then a factory-built one; the MISSING
            # sentinel (and "no default at all") map to Parameter.empty.
            if field.default is not Undefined:
                candidate = field.default
            elif field.default_factory is not None:
                candidate = field.default_factory()
            else:
                return Parameter.empty
            if isinstance(candidate, dataclasses._MISSING_TYPE):
                return Parameter.empty
            return candidate

        parameters = [
            Parameter(
                field_name,
                Parameter.POSITIONAL_OR_KEYWORD,
                default=signature_default(field),
                annotation=field.type,
            )
            for field_name, field in getattr(cls, _FIELDS).items()
            if field.init and not field.internal
        ]

        if getattr(cls, "__allow_unexpected_arguments__", False):
            parameters += [
                Parameter("_argv", Parameter.VAR_POSITIONAL),
                Parameter("_kwargs", Parameter.VAR_KEYWORD),
            ]

        # Skip ordering validation: fields with defaults may precede required ones.
        signature = Signature(parameters, __validate_parameters__=False)

        wrapped_init = cls.__init__

        @functools.wraps(wrapped_init)
        def init_with_signature(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)

        init_with_signature.__signature__ = signature
        cls.__init__ = init_with_signature
413

414

415
class Dataclass(metaclass=DataclassMeta):
    """Base class for data-like classes that provides additional functionality and control.

    It offers more control than Python's built-in @dataclasses.dataclass decorator.
    Other classes inherit from this class to get the benefits of this implementation,
    and every subclass automatically becomes a data class.

    The usage and field definitions are similar to Python's built-in @dataclasses.dataclass decorator.
    However, this implementation provides additional classes for defining "final", "required",
    and "abstract" fields.

    Key enhancements of this custom implementation:

    1. Automatic Data Class Creation: All subclasses automatically become data classes,
       without needing to use the @dataclasses.dataclass decorator.

    2. Field Immutability: Supports creation of "final" fields (using FinalField class) that
       cannot be overridden by subclasses. This functionality is not natively supported in
       Python or in the built-in dataclasses module.

    3. Required Fields: Supports creation of "required" fields (using RequiredField class) that
       must be provided when creating an instance of the class, adding a level of validation
       not present in the built-in dataclasses module.

    4. Abstract Fields: Supports creation of "abstract" fields (using AbstractField class) that
       must be overridden by any non-abstract subclass. This is similar to abstract methods in
       an abc.ABC class, but applied to fields.

    5. Type Checking: When a subclass overrides an inherited field with a plain value
       (without re-annotating it), the field keeps its inherited type and the value is
       compared against that type. Note that the resulting TypeMismatchError is itself a
       TypeError and is currently swallowed in ``get_fields``, so mismatches are not reported.

    6. Error Definitions: Defines specific error types (FinalFieldError, RequiredFieldError,
       AbstractFieldError, TypeMismatchError, UnexpectedArgumentError) for providing detailed
       error information during debugging.

    7. MetaClass Usage: Uses a metaclass (DataclassMeta) for customization of class creation,
       allowing checks and alterations to be made at the time of class creation, providing more control.

    8. Unexpected Arguments: When a class sets ``__allow_unexpected_arguments__ = True``, extra
       positional arguments are stored in ``_argv`` and extra keyword arguments in ``_kwargs``
       instead of raising UnexpectedArgumentError.

    :Example:

    .. code-block:: python

        class Parent(Dataclass):
            final_field: int = FinalField(1)  # this field cannot be overridden
            required_field: str = RequiredField()
            also_required_field: float
            abstract_field: int = AbstractField()

        class Child(Parent):
            abstract_field = 3  # now once overridden, this is no longer abstract
            required_field = Field(name="required_field", default="provided", type=str)

        class Mixin(Dataclass):
            mixin_field = Field(name="mixin_field", default="mixin", type=str)

        class GrandChild(Child, Mixin):
            pass

        grand_child = GrandChild()
        logger.info(grand_child.to_dict())

        ...
    """

    # Subclasses set this to True to collect extra arguments in _argv/_kwargs.
    __allow_unexpected_arguments__ = False

    @final
    def __init__(self, *argv, **kwargs):
        """Initialize the instance's fields from positional and keyword arguments.

        Positional arguments are matched, in order, to the init-able fields that
        are also positional. Extra positional or keyword arguments (and explicit
        ``_argv``/``_kwargs`` keywords) are stored in ``_argv``/``_kwargs`` when
        ``__allow_unexpected_arguments__`` is True, and rejected otherwise.
        Abstract and required fields are then checked, ``__pre_init__`` is called
        with the resolved keyword arguments, and every field is assigned (its
        given value, or its default). Finally ``__post_init__`` is called.

        Raises:
            TypeError: If a field is given both positionally and by keyword.
            AssertionError: If ``_argv`` is combined with extra positional
                arguments, or ``_kwargs`` shares keys with extra keyword arguments.
            UnexpectedArgumentError: If extra arguments are given while
                ``__allow_unexpected_arguments__`` is False.
            AbstractFieldError: If the class still has an abstract field.
            RequiredFieldError: If a required field was not provided.
        """
        super().__init__()
        _init_fields = [field for field in fields(self) if field.init]
        _init_fields_names = [field.name for field in _init_fields]
        _init_positional_fields_names = [
            field.name for field in _init_fields if field.also_positional
        ]

        # A field filled positionally must not also be given by keyword.
        for name in _init_positional_fields_names[: len(argv)]:
            if name in kwargs:
                raise TypeError(
                    f"{self.__class__.__name__} got multiple values for argument '{name}'"
                )

        expected_unexpected_argv = kwargs.pop("_argv", None)

        # Positional arguments beyond the positional fields are "unexpected".
        if len(argv) <= len(_init_positional_fields_names):
            unexpected_argv = []
        else:
            unexpected_argv = argv[len(_init_positional_fields_names) :]

        if expected_unexpected_argv is not None:
            assert (
                len(unexpected_argv) == 0
            ), f"Cannot specify both _argv and unexpected positional arguments. Got {unexpected_argv}"
            unexpected_argv = tuple(expected_unexpected_argv)

        expected_unexpected_kwargs = kwargs.pop("_kwargs", None)
        unexpected_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in _init_fields_names and k not in ["_argv", "_kwargs"]
        }

        if expected_unexpected_kwargs is not None:
            intersection = set(unexpected_kwargs.keys()) & set(
                expected_unexpected_kwargs.keys()
            )
            assert (
                len(intersection) == 0
            ), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
            unexpected_kwargs = {**unexpected_kwargs, **expected_unexpected_kwargs}

        if self.__allow_unexpected_arguments__:
            # Route the extras into the internal _argv/_kwargs fields.
            if len(unexpected_argv) > 0:
                kwargs["_argv"] = unexpected_argv
            if len(unexpected_kwargs) > 0:
                kwargs["_kwargs"] = unexpected_kwargs

        else:
            if len(unexpected_argv) > 0:
                raise UnexpectedArgumentError(
                    f"Too many positional arguments {unexpected_argv} for class {self.__class__.__name__}.\nShould be only {len(_init_positional_fields_names)} positional arguments: {', '.join(_init_positional_fields_names)}"
                )

            if len(unexpected_kwargs) > 0:
                raise UnexpectedArgumentError(
                    f"Unexpected keyword argument(s) {unexpected_kwargs} for class {self.__class__.__name__}.\nShould be one of: {external_fields_names(self)}"
                )

        # From here on, positional arguments are handled as keyword arguments.
        for name, arg in zip(_init_positional_fields_names, argv):
            kwargs[name] = arg

        for field in abstract_fields(self):
            raise AbstractFieldError(
                f"Abstract field '{field.name}' of class {field.origin_cls} not implemented in {self.__class__.__name__}"
            )

        for field in required_fields(self):
            if field.name not in kwargs:
                raise RequiredFieldError(
                    f"Required field '{field.name}' of class {field.origin_cls} not set in {self.__class__.__name__}"
                )

        self.__pre_init__(**kwargs)

        for field in fields(self):
            if field.name in kwargs:
                setattr(self, field.name, kwargs[field.name])
            else:
                setattr(self, field.name, get_field_default(field))

        self.__post_init__()

    @property
    def __is_dataclass__(self) -> bool:
        return True

    def __pre_init__(self, **kwargs):
        """Pre initialization hook.

        Called before any field is assigned, with the resolved keyword arguments
        (positional arguments already mapped to their field names).
        """
        pass

    def __post_init__(self):
        """Post initialization hook, called after all fields have been assigned."""
        pass

    def _to_raw_dict(self):
        """Return ``{field name: current value}`` for every field, without copying values."""
        return {field.name: getattr(self, field.name) for field in fields(self)}

    def to_dict(self, classes: Optional[List] = None, keep_empty: bool = True):
        """Convert to dict.

        Args:
            classes (List, optional): List of parent classes whose annotated
                attributes should be returned. If set to None (or empty), all
                fields are returned, recursively converted and deep-copied.
                When given, the attribute values are returned as they are,
                without conversion or copying.
            keep_empty (bool): If True, then parameters are returned regardless of
                whether their values are None or not; if False, None values are dropped.
        """
        if not classes:
            attributes_dict = _asdict_inner(self._to_raw_dict())
        else:
            attributes = []
            for cls in classes:
                attributes.extend(cls.__annotations__)
            attributes_dict = {
                attribute: getattr(self, attribute) for attribute in attributes
            }

        return {
            attribute: value
            for attribute, value in attributes_dict.items()
            if keep_empty or value is not None
        }

    def get_repr_dict(self):
        """Return ``{field name: value}`` for all non-internal fields, used by ``__repr__``."""
        result = {}
        for field in fields(self):
            if not field.internal:
                result[field.name] = getattr(self, field.name)
        return result

    def __repr__(self) -> str:
        """String representation."""
        return f"{self.__class__.__name__}({', '.join([f'{key}={val!r}' for key, val in self.get_repr_dict().items()])})"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc