• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 15530158291

09 Jun 2025 08:23AM UTC coverage: 80.267% (+0.03%) from 80.242%
15530158291

Pull #1644

github

web-flow
Merge d45934119 into b3a894d7c
Pull Request #1644: Use elaborated cache key and use it for filelock semaphore

1696 of 2089 branches covered (81.19%)

Branch coverage included in aggregate %.

10511 of 13119 relevant lines covered (80.12%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.08
src/unitxt/dataclass.py
1
import copy
1✔
2
import dataclasses
1✔
3
import functools
1✔
4
import inspect
1✔
5
from abc import ABCMeta
1✔
6
from inspect import Parameter, Signature
1✔
7
from typing import Any, Dict, List, Optional, final
1✔
8

9
_FIELDS = "__fields__"
1✔
10

11

12
class Undefined:
    """Sentinel meaning "no default value was given" for a ``Field``.

    The class object itself (not an instance) is used as the marker and is
    compared by identity (``field.default is not Undefined``).
    """
14

15

16
@dataclasses.dataclass
class Field:
    """Specification of a single field of a ``Dataclass``.

    This is the counterpart of ``dataclasses.field`` with extra flags
    (final / abstract / required / internal / positional).

    Args:
        default (Any, optional):
            Default value for the field. Defaults to the ``Undefined`` sentinel,
            meaning that no default value was given.
        name (str, optional):
            Name of the field. Defaults to None.
        type (type, optional):
            Type of the field. Defaults to None.
        init (bool, optional):
            Whether the field can be set through the constructor. Defaults to True.
        also_positional (bool, optional):
            Whether the field may also be passed positionally (it can always be passed
            by keyword when ``init`` is True). Defaults to True.
        default_factory (Any, optional):
            A zero-argument callable producing the default value. When set, it takes
            precedence over ``default``. Defaults to None.
        final (bool, optional):
            Whether the field is final (subclasses may not redefine it). Defaults to False.
        abstract (bool, optional):
            Whether the field is abstract (a subclass must give it a value before the
            class can be instantiated). Defaults to False.
        required (bool, optional):
            Whether a value must be passed when instantiating. Defaults to False.
        internal (bool, optional):
            Whether the field is internal bookkeeping. Internal fields are excluded from
            ``repr``, from the external field names, and from the generated ``__init__``
            signature. Defaults to False.
        origin_cls (type, optional):
            Identifies the class that declared the field. ``get_fields`` stores that
            class's ``__qualname__`` string here. Defaults to None.
        metadata (Dict[str, str], optional):
            Free-form metadata. Defaults to a new empty dict.
    """

    default: Any = Undefined
    name: str = None
    type: type = None
    init: bool = True
    also_positional: bool = True
    default_factory: Any = None
    final: bool = False
    abstract: bool = False
    required: bool = False
    internal: bool = False
    origin_cls: type = None
    metadata: Dict[str, str] = dataclasses.field(default_factory=dict)

    def get_default(self):
        """Return the default value of this field.

        Returns:
            The result of calling ``default_factory`` when it is set (a fresh value on
            every call), otherwise ``default`` as is.
        """
        if self.default_factory is not None:
            return self.default_factory()
        return self.default
56

57

58
@dataclasses.dataclass
class FinalField(Field):
    """A ``Field`` that subclasses are not allowed to redefine."""

    def __post_init__(self):
        # The flag is forced on regardless of what the caller passed.
        self.final = True
62

63

64
@dataclasses.dataclass
class RequiredField(Field):
    """A ``Field`` that must be supplied when the owning class is instantiated."""

    def __post_init__(self):
        # The flag is forced on regardless of what the caller passed.
        self.required = True
68

69

70
class MissingDefaultError(TypeError):
    """Raised when an ``OptionalField`` is declared with neither a default nor a default_factory."""
72

73

74
@dataclasses.dataclass
class OptionalField(Field):
    """A ``Field`` that is never required and therefore must carry a default.

    Raises:
        MissingDefaultError: If neither ``default`` nor ``default_factory`` is given.
    """

    def __post_init__(self):
        self.required = False
        has_default = self.default is not Undefined
        has_factory = self.default_factory is not None
        if not (has_default or has_factory):
            raise MissingDefaultError(
                "OptionalField must have default or default_factory"
            )
82

83

84
@dataclasses.dataclass
class AbstractField(Field):
    """A ``Field`` that a subclass must assign before the class can be instantiated."""

    def __post_init__(self):
        # The flag is forced on regardless of what the caller passed.
        self.abstract = True
88

89

90
@dataclasses.dataclass
class NonPositionalField(Field):
    """A ``Field`` that can only be passed to the constructor by keyword."""

    def __post_init__(self):
        # Excluded from the positional-argument mapping in Dataclass.__init__.
        self.also_positional = False
94

95

96
@dataclasses.dataclass
class InternalField(Field):
    """A bookkeeping ``Field``: not settable through the constructor and hidden from ``repr``."""

    def __post_init__(self):
        # Assigned left to right, just like three separate statements.
        self.internal, self.init, self.also_positional = True, False, False
102

103

104
class FinalFieldError(TypeError):
    """Raised at class creation when a subclass redefines a field declared final."""
106

107

108
class RequiredFieldError(TypeError):
    """Raised when a required field is not supplied when instantiating a Dataclass."""
110

111

112
class AbstractFieldError(TypeError):
    """Raised when instantiating a Dataclass that still has an abstract field."""
114

115

116
class TypeMismatchError(TypeError):
    """Error for a subclass value whose type does not match the inherited field's type.

    NOTE(review): being a ``TypeError`` subclass, it is caught by the
    ``except TypeError`` that surrounds its ``raise`` in ``get_fields``, so it does
    not currently propagate to callers.
    """
118

119

120
class UnexpectedArgumentError(TypeError):
    """Raised when a constructor gets positional or keyword arguments that match no field.

    Only raised when the class does not set ``__allow_unexpected_arguments__``.
    """
122

123

124
# Attribute names every object has (as listed by ``dir(object)``); class attributes
# with these names are never considered fields.
standard_variables = [attribute_name for attribute_name in dir(object)]
125

126

127
def is_class_method(func):
    """Heuristically decide whether ``func`` is a method rather than a plain value.

    Returns:
        bool: True for bound methods, and for plain functions whose first parameter
        is named ``self`` or ``cls``; False for anything else.
    """
    if inspect.ismethod(func):
        return True
    if not inspect.isfunction(func):
        return False
    first_param = next(iter(inspect.signature(func).parameters.values()), None)
    return first_param is not None and first_param.name in ("self", "cls")
136

137

138
def is_possible_field(field_name, field_value):
    """Decide whether a class attribute could be a field.

    Names that every object has (``standard_variables``) and method-like values
    (as judged by ``is_class_method``) are excluded.

    Args:
        field_name (str): The attribute name.
        field_value: The attribute value.

    Returns:
        bool: False for standard object attribute names and method-like values,
        True otherwise.
    """
    if field_name not in standard_variables and not is_class_method(field_value):
        return True
    return False
153

154

155
def get_fields(cls, attrs):
    """Compute the fields of a class being created, merging inherited and new ones.

    Inherited fields come from the ``__fields__`` of the direct bases; for a name
    defined by several bases, the earlier base wins. Every annotated name in
    ``attrs`` becomes a field. So does every un-annotated assignment that overrides
    an inherited field, and it keeps the inherited type.

    Args:
        cls (type): The class to get the fields for.
        attrs (dict): The class-body namespace of ``cls``.

    Returns:
        dict: A dictionary mapping field names to Field instances.

    Raises:
        FinalFieldError: If the class redefines a field that a base marked final.
    """
    fields = {}
    for base in cls.__bases__:
        # Already-collected fields (from earlier bases) override a later base's.
        fields = {**getattr(base, _FIELDS, {}), **fields}
    annotations = {**attrs.get("__annotations__", {})}

    # An un-annotated attribute that overrides an inherited field is still a field.
    for attr_name, attr_value in attrs.items():
        if attr_name not in annotations and is_possible_field(attr_name, attr_value):
            if attr_name in fields:
                try:
                    if not isinstance(attr_value, fields[attr_name].type):
                        raise TypeMismatchError(
                            f"Type mismatch for field '{attr_name}' of class '{fields[attr_name].origin_cls}'. Expected {fields[attr_name].type}, got {type(attr_value)}"
                        )
                except TypeError:
                    # isinstance() raises TypeError for "types" it cannot check
                    # (e.g. None, string annotations, subscripted generics).
                    # NOTE(review): TypeMismatchError is itself a TypeError, so it is
                    # swallowed here too and mismatches are never reported. Kept as-is
                    # on purpose: enabling it could reject existing subclasses.
                    pass
                annotations[attr_name] = fields[attr_name].type

    for field_name, field_type in annotations.items():
        if field_name in fields and fields[field_name].final:
            raise FinalFieldError(
                f"Final field {field_name} defined in {fields[field_name].origin_cls} overridden in {cls}"
            )

        args = {
            "name": field_name,
            "type": field_type,
            "origin_cls": attrs["__qualname__"],
        }

        if field_name in attrs:
            field_value = attrs[field_name]
            if isinstance(field_value, Field):
                args = {**dataclasses.asdict(field_value), **args}
            elif isinstance(field_value, dataclasses.Field):
                # dataclasses uses MISSING (not None) for "no factory". Field treats
                # any non-None factory as callable, so MISSING must become None.
                has_factory = field_value.default_factory is not dataclasses.MISSING
                args = {
                    "default": field_value.default,
                    "name": field_value.name,
                    "type": field_value.type,
                    "init": field_value.init,
                    "default_factory": (
                        field_value.default_factory if has_factory else None
                    ),
                    **args,
                }
                if field_value.default is dataclasses.MISSING and not has_factory:
                    # No default at all: required, like a bare annotation below.
                    args["required"] = True
            else:
                args["default"] = field_value
                args["default_factory"] = None
        else:
            # Bare annotation without a value: no default, so the field is required.
            args["default"] = dataclasses.MISSING
            args["default_factory"] = None
            args["required"] = True

        fields[field_name] = Field(**args)

    # Storage for the extra arguments accepted by Dataclass.__init__. This is done
    # once per class, even for classes that declare no annotations of their own.
    if getattr(cls, "__allow_unexpected_arguments__", False):
        fields["_argv"] = InternalField(name="_argv", type=tuple, default=())
        # A factory (not a shared {} literal) gives every instance its own dict.
        fields["_kwargs"] = InternalField(
            name="_kwargs", type=dict, default_factory=dict
        )

    return fields
223

224

225
def is_dataclass(obj):
    """Tell whether ``obj`` is a class built by DataclassMeta, or an instance of one.

    The test is whether the class (``obj`` itself when it is a type, otherwise
    ``type(obj)``) carries the ``__fields__`` attribute. Stdlib ``@dataclass``
    classes without it are not recognized.
    """
    if isinstance(obj, type):
        return hasattr(obj, _FIELDS)
    return hasattr(type(obj), _FIELDS)
229

230

231
def class_fields(obj):
    """Return the fields declared directly by the class of ``obj`` (not inherited ones).

    A field is kept when its ``origin_cls`` equals the class's ``__qualname__``,
    which is what ``get_fields`` records for fields declared in a class body.

    Args:
        obj: A Dataclass instance or a Dataclass class.

    Returns:
        list: The matching Field instances, in the stored field order.
    """
    # Accept a class as well as an instance; for a class, ``obj.__class__`` would be
    # the metaclass and no field would ever match.
    cls = obj if isinstance(obj, type) else type(obj)
    return [field for field in fields(cls) if field.origin_cls == cls.__qualname__]
236

237

238
def fields(cls):
    """Return the Field objects of a Dataclass class or instance, in stored order.

    Raises AttributeError if ``cls`` has no ``__fields__``.
    """
    field_map = getattr(cls, _FIELDS)
    return [*field_map.values()]
240

241

242
def fields_names(cls):
    """Return the names of all fields of a Dataclass class or instance, in stored order."""
    field_map = getattr(cls, _FIELDS)
    return [field_name for field_name in field_map]
244

245

246
def external_fields_names(cls):
    """Return the ``name`` of every field that is not internal, in field order."""
    names = []
    for field in fields(cls):
        if field.internal:
            continue
        names.append(field.name)
    return names
248

249

250
def final_fields(cls):
    """Return the fields whose ``final`` flag is truthy, in field order."""
    return list(filter(lambda field: field.final, fields(cls)))
252

253

254
def required_fields(cls):
    """Return the fields whose ``required`` flag is truthy, in field order."""
    required = []
    for field in fields(cls):
        if field.required:
            required.append(field)
    return required
256

257

258
def abstract_fields(cls):
    """Return the fields whose ``abstract`` flag is truthy, in field order."""
    return list(field for field in fields(cls) if field.abstract)
260

261

262
def is_abstract_field(field):
    """Return the field's ``abstract`` flag as stored (not coerced to bool)."""
    abstract_flag = field.abstract
    return abstract_flag
264

265

266
def is_final_field(field):
    """Return the field's ``final`` flag as stored (not coerced to bool)."""
    final_flag = field.final
    return final_flag
268

269

270
def get_field_default(field):
    """Produce the value a field gets when the constructor does not receive it.

    When ``default_factory`` is not None it is called (a fresh value per call);
    otherwise ``default`` is returned as is.
    """
    factory = field.default_factory
    return field.default if factory is None else factory()
275

276

277
def asdict(obj):
    """Convert a Dataclass object with ``_asdict_inner`` (which calls its ``to_dict()``).

    Raises:
        AssertionError: If ``obj`` is not recognized by ``is_dataclass``. This is an
            ``assert``, so the check is skipped under ``python -O``.
    """
    valid = is_dataclass(obj)
    assert valid, f"{obj} must be a dataclass, got {type(obj)} with bases {obj.__class__.__bases__}"
    return _asdict_inner(obj)
282

283

284
def _asdict_inner(obj):
    """Recursively convert ``obj`` into plain containers.

    Dataclasses delegate to their own ``to_dict()``. Named tuples, lists, tuples and
    dicts are rebuilt with their original type and converted element-wise (dict keys
    included). Anything else is deep-copied.
    """
    if is_dataclass(obj):
        return obj.to_dict()

    is_named_tuple = isinstance(obj, tuple) and hasattr(obj, "_fields")
    if is_named_tuple:
        # Named tuples take their items positionally.
        converted_items = [_asdict_inner(item) for item in obj]
        return type(obj)(*converted_items)

    if isinstance(obj, (list, tuple)):
        converted_items = [_asdict_inner(item) for item in obj]
        return type(obj)(converted_items)

    if isinstance(obj, dict):
        converted_pairs = {
            _asdict_inner(key): _asdict_inner(value) for key, value in obj.items()
        }
        return type(obj)(converted_pairs)

    return copy.deepcopy(obj)
298

299
def to_dict(obj, func=copy.deepcopy, _visited=None):
    """Recursively convert an object into a dictionary representation, cutting circular references.

    Args:
        obj: Any Python object to be converted into a dictionary-like structure.
        func (Callable, optional): A function applied to leaf objects, and to objects
            that close a cycle. Defaults to `copy.deepcopy`.
        _visited (set, optional): Internal. IDs of the containers on the current
            recursion path, used to detect cycles. Callers should not pass it.

    Returns:
        dict: A dictionary representation of the input object, with supported
        collections and dataclasses recursively processed.

    Notes:
        - Supports dataclasses, named tuples, lists, tuples, and dictionaries.
        - Only true cycles are cut. An object reached again while it is still being
          converted is replaced by `func(obj)`. An object that is merely shared
          (reachable along several paths) is fully converted at each occurrence.
        - Named tuples retain their original type instead of being converted to dictionaries.
    """
    if _visited is None:
        _visited = set()

    obj_id = id(obj)

    # Re-entering an object that is still being converted means a cycle.
    if obj_id in _visited:
        return func(obj)

    is_named_tuple = isinstance(obj, tuple) and hasattr(obj, "_fields")
    tracked = isinstance(obj, (dict, list)) or is_dataclass(obj) or is_named_tuple
    if tracked:
        _visited.add(obj_id)

    try:
        if is_dataclass(obj):
            return {
                field.name: to_dict(getattr(obj, field.name), func, _visited)
                for field in fields(obj)
            }

        if is_named_tuple:
            return type(obj)(*[to_dict(v, func, _visited) for v in obj])

        if isinstance(obj, (list, tuple)):
            return type(obj)([to_dict(v, func, _visited) for v in obj])

        if isinstance(obj, dict):
            return type(obj)(
                {
                    to_dict(k, func, _visited): to_dict(v, func, _visited)
                    for k, v in obj.items()
                }
            )

        return func(obj)
    finally:
        # Leave the recursion path, so that a later, non-cyclic reference to the
        # same object is converted again instead of being mistaken for a cycle.
        if tracked:
            _visited.discard(obj_id)
344

345
class DataclassMeta(ABCMeta):
    """Metaclass behind ``Dataclass``.

    When a class is created, it collects the class's fields with ``get_fields``
    (which also rejects redefining final fields) and stores them under
    ``__fields__``. It then replaces ``__init__`` with a wrapper that advertises a
    signature built from those fields.
    """

    @final
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        setattr(cls, _FIELDS, get_fields(cls, attrs))
        cls.update_init_signature()

    def update_init_signature(cls):
        """Wrap ``cls.__init__`` so that ``inspect.signature`` reports the class's fields.

        Every field with ``init`` set and ``internal`` unset becomes a
        positional-or-keyword parameter, with its default when it has one. When the
        class allows unexpected arguments, ``*_argv`` and ``**_kwargs`` are appended.
        """

        def default_for(field):
            # Resolve the default shown in the signature (calls default_factory).
            if field.default is not Undefined:
                candidate = field.default
            elif field.default_factory is not None:
                candidate = field.default_factory()
            else:
                return Parameter.empty
            # get_fields stores dataclasses.MISSING as the default of required fields.
            if isinstance(candidate, dataclasses._MISSING_TYPE):
                return Parameter.empty
            return candidate

        parameters = [
            Parameter(
                field_name,
                Parameter.POSITIONAL_OR_KEYWORD,
                default=default_for(field),
                annotation=field.type,
            )
            for field_name, field in getattr(cls, _FIELDS).items()
            if field.init and not field.internal
        ]

        if getattr(cls, "__allow_unexpected_arguments__", False):
            parameters += [
                Parameter("_argv", Parameter.VAR_POSITIONAL),
                Parameter("_kwargs", Parameter.VAR_KEYWORD),
            ]

        wrapped_init = cls.__init__

        @functools.wraps(wrapped_init)
        def custom_cls_init(self, *args, **kwargs):
            wrapped_init(self, *args, **kwargs)

        # __validate_parameters__=False: skip inspect's checks on parameter
        # kind/default ordering.
        custom_cls_init.__signature__ = Signature(
            parameters, __validate_parameters__=False
        )
        cls.__init__ = custom_cls_init
394

395

396
class Dataclass(metaclass=DataclassMeta):
    """Base class for data-like classes with extra control over their fields.

    Every subclass automatically becomes a data class. Fields are declared with class
    annotations, as with Python's built-in ``@dataclasses.dataclass`` decorator, but
    no decorator is needed. The ``DataclassMeta`` metaclass collects the fields when
    each class is created.

    Compared to the built-in ``dataclasses`` module, this adds:

    1. Automatic data class creation: all subclasses are data classes.

    2. Final fields (``FinalField``): subclasses may not redefine them
       (``FinalFieldError`` at class creation).

    3. Required fields (``RequiredField``, or an annotation without a value): they
       must be passed when instantiating (``RequiredFieldError``).

    4. Abstract fields (``AbstractField``): a subclass must assign them before the
       class can be instantiated (``AbstractFieldError``). This is similar to abstract
       methods of an ``abc.ABC`` class, applied to fields.

    5. A type check for un-annotated subclass values that override an inherited field.
       NOTE(review): the ``TypeMismatchError`` it raises is a ``TypeError`` and is
       caught inside ``get_fields``, so mismatches are currently not reported.

    6. Dedicated error types (FinalFieldError, RequiredFieldError, AbstractFieldError,
       TypeMismatchError, UnexpectedArgumentError).

    7. Class-creation-time processing through the ``DataclassMeta`` metaclass.

    Arguments that match no field raise ``UnexpectedArgumentError``, unless the class
    sets ``__allow_unexpected_arguments__ = True``. In that case they are passed on as
    the ``_argv`` / ``_kwargs`` values.

    :Example:

    .. code-block:: python

        class Parent(Dataclass):
            final_field: int = FinalField(1)  # this field cannot be overridden
            required_field: str = RequiredField()
            also_required_field: float
            abstract_field: int = AbstractField()

        class Child(Parent):
            abstract_field = 3  # now once overridden, this is no longer abstract
            required_field = Field(name="required_field", default="provided", type=str)

        class Mixin(Dataclass):
            mixin_field = Field(name="mixin_field", default="mixin", type=str)

        class GrandChild(Child, Mixin):
            pass

        grand_child = GrandChild()
        logger.info(grand_child.to_dict())

        ...
    """

    __allow_unexpected_arguments__ = False

    @final
    def __init__(self, *argv, **kwargs):
        """Validate the constructor arguments and assign every field.

        Positional arguments map, in order, onto the fields that have both ``init``
        and ``also_positional`` set. The constructor then checks abstract and required
        fields and calls ``__pre_init__`` with the collected keyword arguments. It sets
        each field (from the arguments, or from its default) and finally runs
        ``__post_init__``.

        Raises:
            TypeError: A field was given both positionally and by keyword.
            UnexpectedArgumentError: Extra positional/keyword arguments were given
                while ``__allow_unexpected_arguments__`` is False.
            AbstractFieldError: The class still has an abstract field.
            RequiredFieldError: A required field was not given.
            AssertionError: An explicit ``_argv``/``_kwargs`` conflicts with extra
                arguments passed directly.
        """
        super().__init__()
        init_fields = [field for field in fields(self) if field.init]
        init_names = [field.name for field in init_fields]
        positional_names = [
            field.name for field in init_fields if field.also_positional
        ]
        positional_count = len(positional_names)

        # A field filled positionally must not also arrive as a keyword.
        for name in positional_names[: len(argv)]:
            if name in kwargs:
                raise TypeError(
                    f"{self.__class__.__name__} got multiple values for argument '{name}'"
                )

        # Positional arguments beyond the positional fields (empty tuple when none).
        explicit_argv = kwargs.pop("_argv", None)
        extra_argv = argv[positional_count:]
        if explicit_argv is not None:
            assert (
                len(extra_argv) == 0
            ), f"Cannot specify both _argv and unexpected positional arguments. Got {extra_argv}"
            extra_argv = tuple(explicit_argv)

        # Keyword arguments that match no constructor field.
        explicit_kwargs = kwargs.pop("_kwargs", None)
        accepted_names = set(init_names) | {"_argv", "_kwargs"}
        extra_kwargs = {
            key: value for key, value in kwargs.items() if key not in accepted_names
        }
        if explicit_kwargs is not None:
            intersection = set(extra_kwargs.keys()) & set(explicit_kwargs.keys())
            assert (
                len(intersection) == 0
            ), f"Cannot specify the same arguments in both _kwargs and in unexpected keyword arguments. Got {intersection} in both."
            extra_kwargs = {**extra_kwargs, **explicit_kwargs}

        if self.__allow_unexpected_arguments__:
            if extra_argv:
                kwargs["_argv"] = extra_argv
            if extra_kwargs:
                kwargs["_kwargs"] = extra_kwargs
        elif extra_argv:
            raise UnexpectedArgumentError(
                f"Too many positional arguments {extra_argv} for class {self.__class__.__name__}.\nShould be only {len(positional_names)} positional arguments: {', '.join(positional_names)}"
            )
        elif extra_kwargs:
            raise UnexpectedArgumentError(
                f"Unexpected keyword argument(s) {extra_kwargs} for class {self.__class__.__name__}.\nShould be one of: {external_fields_names(self)}"
            )

        kwargs.update(zip(positional_names, argv))

        pending_abstract = abstract_fields(self)
        if pending_abstract:
            field = pending_abstract[0]
            raise AbstractFieldError(
                f"Abstract field '{field.name}' of class {field.origin_cls} not implemented in {self.__class__.__name__}"
            )

        missing = next(
            (field for field in required_fields(self) if field.name not in kwargs),
            None,
        )
        if missing is not None:
            raise RequiredFieldError(
                f"Required field '{missing.name}' of class {missing.origin_cls} not set in {self.__class__.__name__}"
            )

        self.__pre_init__(**kwargs)

        for field in fields(self):
            # The default is only computed when no value was given.
            value = (
                kwargs[field.name] if field.name in kwargs else get_field_default(field)
            )
            setattr(self, field.name, value)

        self.__post_init__()

    @property
    def __is_dataclass__(self) -> bool:
        return True

    def __pre_init__(self, **kwargs):
        """Hook called with the collected keyword arguments before fields are set."""
        pass

    def __post_init__(self):
        """Hook called after all fields have been set."""
        pass

    def _to_raw_dict(self):
        """Map every field name to its current value, without any conversion."""
        return dict((field.name, getattr(self, field.name)) for field in fields(self))

    def to_dict(self, classes: Optional[List] = None, keep_empty: bool = True):
        """Convert to dict.

        Args:
            classes (List, optional): Classes whose own annotated attributes should be
                returned. If empty or None, all fields are returned, converted
                recursively.
            keep_empty (bool): If True, attributes are returned regardless of their
                value; if False, attributes whose value is None are dropped.
        """
        if classes:
            selected = [
                name for klass in classes for name in klass.__annotations__.keys()
            ]
            attributes_dict = {name: getattr(self, name) for name in selected}
        else:
            attributes_dict = _asdict_inner(self._to_raw_dict())

        return {
            name: value
            for name, value in attributes_dict.items()
            if keep_empty or value is not None
        }

    def get_repr_dict(self):
        """Map each non-internal field name to its current value."""
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
            if not field.internal
        }

    def __repr__(self) -> str:
        """String representation."""
        rendered = ", ".join(
            f"{key}={val!r}" for key, val in self.get_repr_dict().items()
        )
        return f"{self.__class__.__name__}({rendered})"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc