• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 12886938037

21 Jan 2025 12:28PM UTC coverage: 79.576% (-0.008%) from 79.584%
12886938037

Pull #1535

github

web-flow
Merge 600b1ab85 into bcc7b6afe
Pull Request #1535: Update version to 1.17.0

1395 of 1742 branches covered (80.08%)

Branch coverage included in aggregate %.

8825 of 11101 relevant lines covered (79.5%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.17
src/unitxt/type_utils.py
1
import ast
1✔
2
import collections.abc
1✔
3
import io
1✔
4
import itertools
1✔
5
import re
1✔
6
import typing
1✔
7
from functools import lru_cache
1✔
8
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
1✔
9

10
from .utils import safe_eval
1✔
11

12
_registered_types = {
1✔
13
    "Any": typing.Any,
14
    "List": typing.List,
15
    "Dict": typing.Dict,
16
    "Tuple": typing.Tuple,
17
    "Union": typing.Union,
18
    "Optional": typing.Optional,
19
    "Literal": typing.Literal,
20
    "int": int,
21
    "str": str,
22
    "float": float,
23
    "bool": bool,
24
}
25

26

27
def register_type(new_type):
    """Add a ``typing.NewType`` or ``typing.TypedDict`` to the type registry.

    The type is stored in ``_registered_types`` under its ``__name__``; an
    existing entry with the same name is overwritten.

    Args:
        new_type: The NewType or TypedDict to register.

    Raises:
        AssertionError: If ``new_type`` is neither a NewType nor a TypedDict.
    """
    registrable = is_new_type(new_type) or is_typed_dict(new_type)
    assert registrable, "Can register only typing.NewType or typing.TypedDict"
    _registered_types[new_type.__name__] = new_type
32

33

34
# Alias of typing.Any. NOTE(review): presumably imported elsewhere to annotate
# "a typing type" -- no use is visible in this part of the file; confirm.
Type = typing.Any
35

36

37
class UnsupportedTypeError(ValueError):
    """Raised for an object that is not recognized as a supported type.

    The message lists the names currently present in the type registry.
    """

    def __init__(self, type_object):
        supported_types = ", ".join(_registered_types)
        message = f"Type: '{type_object!s}' is not supported type. Use one of {supported_types}"
        super().__init__(message)
43

44

45
class GenericTypedDict(TypedDict):
    """A TypedDict that declares no fields."""
47

48

49
_generics = [
1✔
50
    List[Any],
51
    Dict[Any, Any],
52
    Tuple[Any],
53
    Union[Any, Any],
54
    Optional[Any],
55
    Any,
56
    Literal,
57
]
58

59
_generics_types = [type(t) for t in _generics]
1✔
60

61

62
def is_new_type(object):
    """Tell whether ``object`` is callable and carries a ``__supertype__`` attribute.

    That combination is how a type created with ``typing.NewType`` is recognized.
    """
    if not callable(object):
        return False
    return hasattr(object, "__supertype__")
64

65

66
def is_typed_dict(object):
    """Tell whether ``object`` is a TypedDict class.

    The check is made against the metaclass of ``GenericTypedDict``.
    """
    typed_dict_meta = type(GenericTypedDict)
    return isinstance(object, typed_dict_meta)
68

69

70
def is_type(object):
    """Checks if the provided object is a type, including generics, Literal, TypedDict, and NewType."""
    # Plain classes and the typing constructs listed in `_generics`.
    if isinstance(object, (type, *_generics_types)):
        return True
    if is_new_type(object):
        return True
    return is_typed_dict(object)
77

78

79
def is_type_dict(object):
    """Tell whether every leaf value of a (possibly nested) dict is a type.

    Args:
        object: The dictionary to inspect; nested dictionaries are checked recursively.

    Returns:
        bool: True if all leaf values satisfy ``is_type``, False otherwise.

    Raises:
        ValueError: If ``object`` is not a dict.
    """
    if not isinstance(object, dict):
        raise ValueError("Should be dict.")
    return all(
        is_type_dict(value) if isinstance(value, dict) else is_type(value)
        for value in object.values()
    )
89

90

91
def convert_union_type(type_string: str) -> str:
    """Converts Python 3.10 union type hints into form compatible with Python 3.9 version.

    Args:
        type_string (str): A string representation of a Python type hint, for example
            'List[int|float]' or 'str|float|bool'.

            Assuming the input is a valid type hint, the function does not check that a
            'word' is 'str', 'bool', 'List' etc. It only relies on the following general
            structure (spaces ignored):
            type -> word OR type( | type)* OR word[type( , type)*]
            where word is a sequence of (0 or more) chars, each being any char but: [ ] , |

            The subscript of 'Literal' is the one place where these four chars may occur
            inside string constants. It is therefore consumed verbatim and kept out of
            the way until the end, so neither the parsing nor the '|' -> 'Union'
            rewriting can alter the constants.

    Returns:
        str: A type string with converted union types, which is compatible with typing module.

    Raises:
        ValueError: If a 'Literal[' is not followed by a well-formed list of constants.
        AssertionError: If the input leaves a type unfinished (e.g. a trailing comma).

    Examples:
        convert_union_type('List[int|float]') -> 'List[Union[int,float]]'
        convert_union_type('Optional[int|float|bool]') -> 'Optional[Union[int,float,bool]]'
        convert_union_type("Literal['a|b']|int") -> "Union[Literal['a|b'],int]"

    """

    def consume_literal(string: str) -> str:
        # Returns the prefix of `string` that is the full subscript of a Literal,
        # including constants that contain [ ] , or |.
        # `string` starts with the [ that follows 'Literal'.
        candidate_end = string.find("]")
        while candidate_end != -1:
            try:
                ast.literal_eval(string[: candidate_end + 1])
                break
            except Exception:
                # This ']' was inside a constant (or the prefix is otherwise not yet
                # a complete list); try the next one.
                candidate_end = string.find("]", candidate_end + 1)

        if candidate_end == -1:
            raise ValueError("invalid Literal in input type_string")
        return string[: candidate_end + 1]

    # Literal subscripts are parked here and represented on the stack by a
    # placeholder free of meta chars, so that replacing '|' by ',' cannot touch them.
    literals = []
    stack = [""]  # the start of a type
    remaining = type_string.strip()
    next_word = re.compile(r"([^\[\],|]*)([\[\],|]|$)")
    while remaining:
        word = next_word.match(remaining)
        remaining = remaining[len(word.group(0)) :].strip()
        name, delimiter = word.group(1), word.group(2)
        stack[-1] += name
        if delimiter in ("]", ",", ""):  # "" for eol:$
            # top of stack is now complete to a whole type
            completed = stack.pop()
            if "|" in completed:
                # the | -s are only at the top level of completed, not inside any subtype
                completed = "Union[" + completed.replace("|", ",") + "]"
            completed += delimiter
            if stack:
                stack[-1] += completed
            else:
                stack = [completed]
            if delimiter == ",":
                stack.append("")  # to start the expected next type
        elif delimiter == "|":
            # top of stack is the last whole element(s) to be union-ed,
            # and more are expected
            stack[-1] += "|"
        elif name == "Literal":  # delimiter is "["
            literal_ops = consume_literal("[" + remaining)
            stack[-1] += f"\x00{len(literals)}\x00"
            literals.append(literal_ops)
            # literal_ops includes the "[" that was already consumed from remaining
            remaining = remaining[len(literal_ops) - 1 :]
        else:  # delimiter is "["
            stack[-1] += "["
            stack.append("")  # start type (,type)* inside the []

    assert len(stack) == 1
    result = stack[0]
    if "|" in result:  # these belong to the top level only
        result = "Union[" + result.replace("|", ",") + "]"
    # Put the parked Literal subscripts back, untouched.
    return re.sub(r"\x00(\d+)\x00", lambda m: literals[int(m.group(1))], result)
181

182

183
def format_type_string(type_string: str) -> str:
    """Formats a string representing a valid Python type hint so that it is compatible with Python 3.9 notation.

    Args:
        type_string (str): A string representation of a Python type hint.
            Examples include 'List[int]', 'Dict[str, Any]', 'Optional[List[str]]', etc.

    Returns:
        str: A formatted type string.

    Examples:
        format_type_string('list[int | float]') -> 'List[Union[int,float]]'
        format_type_string('dict[str, Optional[str]]') -> 'Dict[str,Optional[str]]'

    The function formats a valid type string (either after or before Python 3.10) into a
    form compatible with 3.9. This is done by capitalizing the lower-cased type names
    'list', 'tuple' and 'dict', and by transferring the 'bitwise or operator' into
    'Union' notation. The function also removes whitespaces and the redundant module
    name of types imported from the 'typing' module, e.g. 'typing.Tuple' -> 'Tuple'.

    Only whole names are capitalized, so an identifier that merely contains one of
    them (e.g. 'Checklist') is left as it is.

    Whitespace removal applies to the whole string, including string constants
    inside 'Literal[...]'.
    """
    # \b keeps longer identifiers such as 'Checklist' or 'my_dict' intact.
    type_string = re.sub(
        r"\b(list|tuple|dict)\b",
        lambda match: match.group(1).capitalize(),
        type_string,
    )
    type_string = type_string.replace("typing.", "").replace(" ", "")
    return convert_union_type(type_string)
218

219

220
def parse_type_string(type_string: str) -> typing.Any:
    """Turns a string representing a Python type hint into the corresponding type object.

    The string is first normalized with ``format_type_string``, so the Python 3.10
    notation (lower-cased 'list', 'dict', 'tuple' and 'a|b' unions) is accepted as
    well, e.g. 'list[int|float]' is treated as 'List[Union[int,float]]'.

    The normalized string is then passed to ``safe_eval`` together with the registry
    of known type names as the evaluation context and a short list of allowed
    tokens.

    Args:
        type_string (str): A string representation of a Python type hint. Examples include
                           'List[int]', 'Dict[str, Any]', 'Optional[List[str]]', etc.

    Returns:
        typing.Any: The Python type object corresponding to the given type string.

    Raises:
        ValueError: If the type string contains elements not allowed in the safe context
                    or tokens list.
    """
    normalized = format_type_string(type_string)
    permitted_tokens = ["[", "]", ",", " "]
    return safe_eval(
        normalized, context=_registered_types, allowed_tokens=permitted_tokens
    )
251

252

253
def replace_class_names(full_string: str) -> str:
    """Strip the qualifying prefixes from every dotted name in ``full_string``.

    A dotted name ending in ``<locals>.Name`` is reduced to ``Name``; any other
    dotted name is reduced to its last component.
    """
    # First alternative: "<prefix>.<locals>.Name"; second: any other dotted name.
    qualified_name = r"(?:\w+\.)*<locals>\.(\w+)|(?:\w+\.)*(\w+)"

    def bare_name(match):
        # Exactly one of the two groups took part in the match.
        return match.group(1) or match.group(2)

    return re.sub(qualified_name, bare_name, full_string)
267

268

269
def to_type_string(typing_type):
    """Render a typing type as a string and check that the string parses back.

    Raises:
        AssertionError: If parsing the rendered string yields a falsy result.
    """
    rendered = strtype(typing_type)
    parsed = parse_type_string(rendered)
    assert parsed, "Is not parsed well"
    return rendered
273

274

275
def to_type_dict(dict_of_typing_types):
    """Convert a (possibly nested) dict of typing types into a dict of type strings.

    Nested dictionaries are converted recursively; every other value goes through
    ``to_type_string``.
    """
    return {
        key: to_type_dict(val) if isinstance(val, dict) else to_type_string(val)
        for key, val in dict_of_typing_types.items()
    }
283

284

285
def parse_type_dict(type_dict):
    """Convert a (possibly nested) dict of type strings into a dict of type objects.

    Raises:
        ValueError: If a value is neither a string nor a dictionary.
    """
    results = {}
    for k, v in type_dict.items():
        if isinstance(v, dict):
            results[k] = parse_type_dict(v)
            continue
        if not isinstance(v, str):
            raise ValueError(
                f"Can parse only nested dictionary with type strings, got {type(v)}"
            )
        results[k] = parse_type_string(v)
    return results
297

298

299
def infer_type(obj) -> typing.Any:
    """Infer the type of ``obj`` as a type object.

    The type is first inferred as a string by ``infer_type_string`` and then
    parsed with ``parse_type_string``.
    """
    type_string = infer_type_string(obj)
    return parse_type_string(type_string)
301

302

303
def infer_type_string(obj: typing.Any) -> str:
    """Encodes the type of a given object into a string.

    Args:
        obj:Any

    Returns:
      a string representation of the type of the object. e.g. ``"str"``, ``"List[int]"``, ``"Dict[str, Any]"``

    | formal definition of the returned string:
    | Type -> basic | List[Type] | Dict[Type, Type] | Union[Type(, Type)*] | Tuple[Type(, Type)*]
    | basic -> ``bool`` | ``str`` | ``int`` | ``float`` | ``Any``

    The result does not depend on the iteration order of the sets used internally,
    so the same object always yields the same string.

    Examples:
        | ``infer_type_string({"how_much": 7})`` returns ``"Dict[str,int]"``
        | ``infer_type_string([1, 2])`` returns ``"List[int]"``
        | ``infer_type_string([])`` returns ``"List[Any]"``    no contents to list to indicate any type
        | ``infer_type_string([[], [7]])`` returns ``"List[List[int]]"``  type of parent list indicated
          by the type of the non-empty child list. The empty child list is indeed, by default, also of
          that type of the non-empty child.
        | ``infer_type_string([[], 7, True])`` returns ``"List[Union[List[Any],int]]"``
          because ``bool`` is also an ``int``

    """

    def consume_arg(args_list: str) -> typing.Tuple[str, str]:
        # Splits off the first complete type at the head of args_list and returns
        # (that type, the rest of the string).
        first_word = re.search(r"^(List\[|Dict\[|Union\[|Tuple\[)", args_list)
        if not first_word:
            first_word = re.search(r"^(str|bool|int|float|Any)", args_list)
            assert first_word, "parsing error"
            return first_word.group(), args_list[first_word.span()[1] :]
        arg_to_ret = first_word.group()
        args_list = args_list[first_word.span()[1] :]
        arg, args_list = consume_arg(args_list)
        arg_to_ret += arg
        while args_list.startswith(","):
            arg, args_list = consume_arg(args_list[1:])
            arg_to_ret = arg_to_ret + "," + arg
        assert args_list.startswith("]"), "parsing error"
        return arg_to_ret + "]", args_list[1:]

    def find_args_in(args: str) -> typing.List[str]:
        # Splits a comma separated list of types into its top-level members,
        # leaving commas nested inside brackets alone.
        to_ret = []
        while len(args) > 0:
            arg, args = consume_arg(args)
            to_ret.append(arg)
            if args.startswith(","):
                args = args[1:]
        return to_ret

    def is_covered_by(left: str, right: str) -> bool:
        # True when every object described by `left` is also described by `right`.
        if left == right:
            return True
        if left.startswith("Union["):
            return all(
                is_covered_by(left_el, right) for left_el in find_args_in(left[6:-1])
            )
        if right.startswith("Union["):
            return any(
                is_covered_by(left, right_el) for right_el in find_args_in(right[6:-1])
            )
        if left.startswith("List[") and right.startswith("List["):
            # un-wrap the leading List[  and the trailing ]
            return is_covered_by(left[5:-1], right[5:-1])
        if left.startswith("Dict[") and right.startswith("Dict["):
            # Split with find_args_in rather than on the first comma, which would
            # cut a key type such as Tuple[int,str] in half.
            left_key, left_val = find_args_in(left[5:-1])
            right_key, right_val = find_args_in(right[5:-1])
            return is_covered_by(left_key, right_key) and is_covered_by(
                left_val, right_val
            )
        if left.startswith("Tuple[") and right.startswith("Tuple["):
            # Same here: members may themselves contain commas.
            left_args = find_args_in(left[6:-1])
            right_args = find_args_in(right[6:-1])
            if len(left_args) != len(right_args):
                return False
            return all(
                is_covered_by(left_el, right_el)
                for left_el, right_el in zip(left_args, right_args)
            )
        if left == "bool" and right == "int":
            return True
        # 'Any' only arises from empty containers, which fit any element type.
        return left == "Any"

    def merge_into(left: str, right: typing.List[str]):
        # merge the type `left` into the set of types `right`, yielding a set that
        # covers both. None of the inputs is a Union as main element. Union may reside
        # inside List, or Dict, or Tuple.
        # This is needed when building a parent List, e.g. from its elements, and the
        # type of that list needs to be the union of the types of its elements.
        # if all elements have same type -- this is the type to write in List[type]
        # if not -- we write List[Union[type1, type2,...]].
        covered = [right_el for right_el in right if is_covered_by(right_el, left)]
        if covered:
            # `left` replaces every member it covers, not just the first one found.
            for right_el in covered:
                right.remove(right_el)
            right.append(left)
        elif not any(is_covered_by(left, right_el) for right_el in right):
            right.append(left)

    def encode_a_list_of_type_names(list_of_type_names: typing.List[str]) -> str:
        # The type_names in the input are the names of the types of all the elements of one
        # list object, or all the keys of one dict object, or all the values thereof.
        # The result is a name of a type that covers them all.
        # So if, for example, the input contains both 'bool' and 'int', then 'int' suffices to
        # cover both. 'List[Any]' shows for an element that is an empty list. If there are other
        # elements that are more specific, e.g. 'List[str]', the latter is taken and 'List[Any]'
        # is discarded in order to get a result as narrow as possible that still covers all.
        merged = []
        for type_name in list_of_type_names:
            merge_into(type_name, merged)

        if len(merged) == 1:
            return merged[0]
        return "Union[" + ",".join(sorted(merged)) + "]"

    # bool must be tested before int, because bool is a subtype of int
    basic_types = ((bool, "bool"), (int, "int"), (str, "str"), (float, "float"))
    for basic_type, name_of_basic_type in basic_types:
        if isinstance(obj, basic_type):
            return name_of_basic_type

    if isinstance(obj, list):
        if not obj:
            return "List[Any]"
        # sorted() makes the merge order independent of set iteration order
        element_types = sorted({infer_type_string(list_el) for list_el in obj})
        return "List[" + encode_a_list_of_type_names(element_types) + "]"

    if isinstance(obj, dict):
        if not obj:
            return "Dict[Any,Any]"
        key_types = sorted({infer_type_string(key) for key in obj})
        val_types = sorted({infer_type_string(val) for val in obj.values()})
        return (
            "Dict["
            + encode_a_list_of_type_names(key_types)
            + ","
            + encode_a_list_of_type_names(val_types)
            + "]"
        )

    if isinstance(obj, tuple):
        if not obj:
            return "Tuple[Any]"
        return "Tuple[" + ",".join(infer_type_string(member) for member in obj) + "]"

    return "Any"
471

472

473
def isoftype(object, typing_type):
    """Checks if an object is of a certain typing type, including nested types.

    This function supports simple types, typing types (List[int], Tuple[str, int]),
    nested typing types (List[List[int]], Tuple[List[str], int]), homogeneous
    tuples of any length (Tuple[int, ...]), Literal, TypedDict, and NewType
    (including a NewType derived from another NewType).

    Args:
        object: The object to check.
        typing_type: The typing type to check against.

    Returns:
        bool: True if the object is of the specified type, False otherwise.

    Raises:
        UnsupportedTypeError: If ``typing_type`` (or a type nested in it) is not
            recognized as a type.
    """
    if not is_type(typing_type):
        raise UnsupportedTypeError(typing_type)

    # A NewType may be derived from another NewType; peel all layers so that the
    # checks below never see a NewType (isinstance() cannot handle one).
    while is_new_type(typing_type):
        typing_type = typing_type.__supertype__

    if is_typed_dict(typing_type):
        if not isinstance(object, dict):
            return False
        # Every annotated key must be present and hold a value of its declared type.
        return all(
            key in object and isoftype(object[key], expected_type)
            for key, expected_type in typing_type.__annotations__.items()
        )

    if typing_type == typing.Any:
        return True

    if hasattr(typing_type, "__origin__"):
        origin = typing_type.__origin__
        type_args = typing.get_args(typing_type)

        if origin is Literal:
            return object in type_args

        if origin is typing.Union:
            return any(isoftype(object, sub_type) for sub_type in type_args)

        if not isinstance(object, origin):
            return False
        if origin is list or origin is set:
            return all(isoftype(element, type_args[0]) for element in object)
        if origin is dict:
            return all(
                isoftype(key, type_args[0]) and isoftype(value, type_args[1])
                for key, value in object.items()
            )
        if origin is tuple:
            if len(type_args) == 2 and type_args[1] is Ellipsis:
                # Tuple[X, ...]: any number of elements, all of type X
                return all(isoftype(element, type_args[0]) for element in object)
            # NOTE(review): lengths are not compared, so extra or missing elements
            # go unnoticed. Left as is because the empty tuple is inferred as
            # 'Tuple[Any]' by infer_type_string and has to keep matching it.
            return all(
                isoftype(element, type_arg)
                for element, type_arg in zip(object, type_args)
            )

    return isinstance(object, typing_type)
530

531

532
def strtype(typing_type) -> str:
    """Converts a typing type to its string representation.

    Args:
        typing_type (Any): The typing type to be converted. This can include standard types,
            custom types, or types from the `typing` module, such as `Literal`, `Union`,
            `List`, `Dict`, `Tuple`, `TypedDict`, and `NewType`.

    Returns:
        str: The string representation of the provided typing type.

    Raises:
        UnsupportedTypeError: If the provided `typing_type` is not a recognized type.

    Notes:
        - A string is returned unchanged.
        - If `typing_type` is a `NewType` or a `TypedDict`, the function returns
          the name of the type.
        - If `typing_type` is `Any`, it returns the string `"Any"`.
        - A `Literal` is rendered by `str()` with the `typing.` prefix removed.
        - A `Union` that has `None` among its members is rendered as `Optional[...]`;
          when more than one member remains they are wrapped in `Union[...]`, e.g.
          `Optional[Union[int, str]]`.
        - For other typing constructs like `Union`, `List`, `Dict`, and `Tuple`, the function
          recursively converts each part of the type to its string representation.
          A `Set` is rendered as `List[...]`.
        - Anything else is rendered by its `__name__`.
    """
    if isinstance(typing_type, str):
        return typing_type

    if not is_type(typing_type):
        raise UnsupportedTypeError(typing_type)

    if is_new_type(typing_type) or is_typed_dict(typing_type):
        return typing_type.__name__

    if typing_type == typing.Any:
        return "Any"

    if hasattr(typing_type, "__origin__"):
        origin = typing_type.__origin__
        type_args = typing.get_args(typing_type)

        if origin is Literal:
            return str(typing_type).replace("typing.", "")
        if origin is typing.Union:
            none_type = type(None)
            members = [arg for arg in type_args if arg is not none_type]
            rendered = ", ".join([strtype(sub_type) for sub_type in members])
            if len(members) == len(type_args):
                return "Union[" + rendered + "]"
            # None is one of the members. Optional takes exactly one argument,
            # so several remaining members have to be wrapped in a Union.
            if len(members) > 1:
                rendered = "Union[" + rendered + "]"
            return "Optional[" + rendered + "]"
        if origin is list or origin is set:
            # Sets are deliberately rendered as List: this is what the original
            # code returned for them (its separate 'Set[' branch was unreachable).
            return "List[" + strtype(type_args[0]) + "]"
        if origin is dict:
            return "Dict[" + strtype(type_args[0]) + ", " + strtype(type_args[1]) + "]"
        if origin is tuple:
            return (
                "Tuple["
                + ", ".join([strtype(sub_type) for sub_type in type_args])
                + "]"
            )

    return typing_type.__name__
600

601

602
# copied from: https://github.com/bojiang/typing_utils/blob/main/typing_utils/__init__.py
603
# licensed under Apache License 2.0
604

605

606
# Pick up the forward-reference class under whichever name the running
# Python's typing module exposes it.
ForwardRef = getattr(typing, "ForwardRef", None)  # python3.8
if ForwardRef is None:
    ForwardRef = getattr(typing, "_ForwardRef", None)  # python3.6
if ForwardRef is None:
    raise NotImplementedError()
612

613

614
# The "cannot decide" answer returned by optional_all / optional_any.
unknown = None


# Maps typing aliases to the builtin / abstract classes they stand for.
# Aliases are looked up by name so that one which the running Python no longer
# provides (typing.ByteString is deprecated and scheduled for removal) is simply
# left out instead of breaking the import of this module.
BUILTINS_MAPPING = {
    getattr(typing, alias_name): target
    for alias_name, target in (
        ("List", list),
        ("Set", set),
        ("Dict", dict),
        ("Tuple", tuple),
        # https://docs.python.org/3/library/typing.html#typing.ByteString
        ("ByteString", bytes),
        ("Callable", collections.abc.Callable),
        ("Sequence", collections.abc.Sequence),
    )
    if hasattr(typing, alias_name)
}
BUILTINS_MAPPING[type(None)] = None


# Maps concrete io classes to the typing IO type associated with them.
STATIC_SUBTYPE_MAPPING: typing.Dict[type, typing.Type] = {
    io.TextIOWrapper: typing.TextIO,
    io.TextIOBase: typing.TextIO,
    io.StringIO: typing.TextIO,
    io.BufferedReader: typing.BinaryIO,
    io.BufferedWriter: typing.BinaryIO,
    io.BytesIO: typing.BinaryIO,
}
637

638

639
def optional_all(elements) -> typing.Optional[bool]:
    """Three-valued version of ``all``.

    Args:
        elements: Any iterable, including single-pass ones such as generators.

    Returns:
        True if every element is truthy (also for an empty iterable), False if
        every element is the object ``False``, and ``unknown`` otherwise.
    """
    every_truthy = True
    every_false = True
    # One pass only: a generator must not be iterated a second time.
    for element in elements:
        if element:
            every_false = False
        else:
            every_truthy = False
            if element is not False:
                every_false = False
        if not every_truthy and not every_false:
            return unknown
    # At least one of the two flags survived; every_truthy also covers "empty".
    return every_truthy
645

646

647
def optional_any(elements) -> typing.Optional[bool]:
    """Three-valued version of ``any``.

    Args:
        elements: Any iterable, including single-pass ones such as generators.

    Returns:
        True if some element is truthy; otherwise ``unknown`` if some element is
        ``None``; otherwise False.
    """
    saw_none = False
    # One pass only: a generator must not be iterated a second time.
    for element in elements:
        if element:
            return True
        if element is None:
            saw_none = True
    if saw_none:
        return unknown
    return False
653

654

655
def _hashable(value):
1✔
656
    """Determine whether `value` can be hashed."""
657
    try:
1✔
658
        hash(value)
1✔
659
    except TypeError:
×
660
        return False
×
661
    return True
1✔
662

663

664
# Re-export of typing.get_type_hints under this module's name.
get_type_hints = typing.get_type_hints

# Classes of the typing.List and typing.Union objects, used for isinstance()
# checks in the fallback branches of get_origin / get_args.
GenericClass = type(typing.List)
UnionClass = type(typing.Union)

# Annotation aliases used by the helpers below.
_Type = typing.Union[None, type, "typing.TypeVar"]
OriginType = typing.Union[None, type]
TypeArgs = typing.Union[type, typing.AbstractSet[type], typing.Sequence[type]]
672

673

674
def _normalize_aliases(type_: _Type) -> _Type:
    """Replace a typing alias by the class registered for it in ``BUILTINS_MAPPING``.

    A ``TypeVar`` and anything without an entry in the mapping is returned unchanged.

    Raises:
        AssertionError: If ``type_`` is not hashable.
    """
    if isinstance(type_, typing.TypeVar):
        return type_

    assert _hashable(type_), "_normalize_aliases should only be called on element types"

    return BUILTINS_MAPPING.get(type_, type_)
683

684

685
def get_origin(type_):
    """Return the unsubscripted version of a type, with typing aliases normalized.

    Generic types, Callable, Tuple, Union, Literal, Final and ClassVar are
    supported; for anything else the origin is None. Where ``typing.get_origin``
    exists it does the work; the remaining branches are fallbacks for typing
    modules that lack it. The result is passed through ``_normalize_aliases``.

    Examples:
        Here are some code examples using `get_origin` from the `typing_utils` module:

        .. code-block:: python

            from typing_utils import get_origin

            # Examples of get_origin usage
            get_origin(Literal[42]) is Literal  # True
            get_origin(int) is None  # True
            get_origin(ClassVar[int]) is ClassVar  # True
            get_origin(Generic) is Generic  # True
            get_origin(Generic[T]) is Generic  # True
            get_origin(Union[T, int]) is Union  # True
            get_origin(List[Tuple[T, T]][int]) == list  # True

    """
    if hasattr(typing, "get_origin"):  # python 3.8+
        origin = typing.get_origin(type_)
    elif hasattr(typing.List, "_special"):  # python 3.7
        if isinstance(type_, GenericClass) and not type_._special:
            origin = type_.__origin__
        elif getattr(type_, "_special", False):
            # an unsubscripted alias is its own origin
            origin = type_
        elif type_ is typing.Generic:
            origin = typing.Generic
        else:
            origin = None
    else:  # python 3.6
        if isinstance(type_, GenericClass):
            origin = type_.__origin__
            if origin is None:
                origin = type_
        elif isinstance(type_, UnionClass):
            origin = type_.__origin__
        elif type_ is typing.Generic:
            origin = typing.Generic
        else:
            origin = None
    return _normalize_aliases(origin)
732

733

734
def get_args(type_) -> typing.Tuple:
    """Return the type arguments of ``type_`` with all substitutions performed.

    For unions, basic simplifications used by Union constructor are performed.
    Where ``typing.get_args`` exists it does the work; the remaining branches are
    fallbacks for typing modules that lack it. An empty tuple is returned when
    there are no arguments.

    Examples:
        Here are some code examples using `get_args` from the `typing_utils` module:

        .. code-block:: python

            from typing_utils import get_args

            # Examples of get_args usage
            get_args(Dict[str, int]) == (str, int)  # True
            get_args(int) == ()  # True
            get_args(Union[int, Union[T, int], str][int]) == (int, str)  # True
            get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])  # True
            get_args(Callable[[], T][int]) == ([], int)  # True
    """
    if hasattr(typing, "get_args"):  # python 3.8+
        args = typing.get_args(type_)
    else:
        if hasattr(typing.List, "_special"):  # python 3.7
            subscripted = isinstance(type_, GenericClass) and not type_._special
        else:  # python 3.6
            subscripted = isinstance(type_, (GenericClass, UnionClass))

        if subscripted:  # backport for python 3.8
            args = type_.__args__
            if get_origin(type_) is collections.abc.Callable and args[0] is not Ellipsis:
                # present Callable as ([parameter types], return type)
                args = (list(args[:-1]), args[-1])
        else:
            args = ()
    if args is None:
        return ()
    return args
773

774

775
def eval_forward_ref(ref, forward_refs=None):
    """Evaluate a forward reference into the object it names.

    Args:
        ref: The forward reference to evaluate.
        forward_refs: Optional mapping used as the local namespace during
            evaluation; an empty dict is used when it is missing or empty.
            This module's globals are always used as the global namespace.

    Returns:
        Whatever the interpreter's type evaluator returns for ``ref``.

    Raises:
        NotImplementedError: If neither ``typing`` nor ``ref`` exposes an
            ``_eval_type`` hook.
    """
    namespace = forward_refs if forward_refs else {}
    module_globals = globals()

    # Preferred path: the private evaluator exposed by the ``typing`` module.
    if hasattr(typing, "_eval_type"):
        return typing._eval_type(ref, module_globals, namespace)

    # Fallback path: the evaluator is a method on the reference object itself.
    if hasattr(ref, "_eval_type"):
        return ref._eval_type(module_globals, namespace)

    raise NotImplementedError()
×
788

789

790
class NormalizedType(typing.NamedTuple):
    """Normalized type that makes it possible to compare and hash types.

    Instances pair a type's origin with its (already normalized) arguments so
    that two types can be compared with ``==`` and used as dictionary keys or
    cache keys.
    """

    # The origin of the type (or the type itself when it has no arguments).
    origin: _Type
    # Normalized arguments: a tuple when order matters, a frozenset when it
    # does not. Empty when the type carries no arguments.
    args: typing.Union[tuple, frozenset] = ()

    def __eq__(self, other):
        """Compare with another ``NormalizedType`` or with a bare origin.

        Two normalized types are equal when origins and arguments match;
        frozenset arguments are compared as sets. An argument-less instance
        also compares equal to an object equal to its origin.
        """
        if isinstance(other, NormalizedType):
            if self.origin != other.origin:
                return False
            if isinstance(self.args, frozenset) and isinstance(other.args, frozenset):
                # Mutual inclusion == set equality, independent of ordering.
                return self.args <= other.args and other.args <= self.args
            return self.origin == other.origin and self.args == other.args
        if not self.args:
            # Allow ``NormalizedType(int) == int``.
            return self.origin == other
        return False

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``.

        An argument-less instance hashes like its origin, matching the fact
        that it compares equal to that origin.
        """
        if not self.args:
            return hash(self.origin)
        return hash((self.origin, self.args))

    def __repr__(self):
        """Return ``origin`` or ``origin[args]`` for display purposes."""
        if not self.args:
            return f"{self.origin}"
        # Fixed: the previous format string emitted a stray closing
        # parenthesis after the bracket ("origin[args])").
        return f"{self.origin}[{self.args}]"
×
816

817

818
@lru_cache(maxsize=None)
def _normalize_args(tps: TypeArgs):
    """Recursively normalize a type or a container of types.

    Strings are returned unchanged, sequences become tuples of normalized
    items, sets become frozensets of normalized items, and anything else is
    passed to ``normalize``. Results are memoized with ``lru_cache``, so the
    argument must be hashable.
    """
    # ``str`` is itself a Sequence, so it has to be ruled out before the
    # generic sequence check below.
    if isinstance(tps, str):
        return tps

    if isinstance(tps, collections.abc.Sequence):
        container = tuple
    elif isinstance(tps, collections.abc.Set):
        container = frozenset
    else:
        return normalize(tps)

    return container(_normalize_args(item) for item in tps)
1✔
827

828

829
def normalize(type_: _Type) -> NormalizedType:
    """Convert a type into a ``NormalizedType`` instance.

    A type without an origin is wrapped on its own (after alias
    normalization). Otherwise the alias-normalized origin is paired with the
    normalized arguments; for ``typing.Union`` the arguments are collected in
    a frozenset so that their order does not matter.
    """
    type_args = get_args(type_)
    raw_origin = get_origin(type_)
    if not raw_origin:
        return NormalizedType(_normalize_aliases(type_))

    canonical_origin = _normalize_aliases(raw_origin)

    # Union members are unordered, hence the frozenset.
    members = frozenset(type_args) if canonical_origin is typing.Union else type_args
    return NormalizedType(canonical_origin, _normalize_args(members))
1✔
842

843

844
def _is_origin_subtype(left: OriginType, right: OriginType) -> bool:
    """Tell whether origin ``left`` is a subtype of origin ``right``.

    Checks are tried in order: identity, the ``STATIC_SUBTYPE_MAPPING``
    table, a walk over ``left.mro()``, ``issubclass`` when both are classes,
    and finally plain equality.
    """
    if left is right:
        return True

    # Relationships declared explicitly in the static table.
    if left is not None and left in STATIC_SUBTYPE_MAPPING:
        if right == STATIC_SUBTYPE_MAPPING[left]:
            return True

    # Any ancestor in the method resolution order equal to ``right`` counts.
    if hasattr(left, "mro") and any(base == right for base in left.mro()):
        return True

    both_are_classes = isinstance(left, type) and isinstance(right, type)
    return issubclass(left, right) if both_are_classes else left == right
1✔
864

865

866
# Shape of normalized type arguments: a (possibly nested) tuple of arguments,
# a frozenset of normalized types, or a single normalized type.
NormalizedTypeArgs = typing.Union[
    typing.Tuple["NormalizedTypeArgs", ...], typing.FrozenSet[NormalizedType], NormalizedType
]
871

872

873
def _is_origin_subtype_args(
    left: NormalizedTypeArgs,
    right: NormalizedTypeArgs,
    forward_refs: typing.Optional[typing.Mapping[str, type]],
) -> typing.Optional[bool]:
    """Compare two sets of normalized type arguments for subtype compatibility.

    Frozensets are matched member by member regardless of order, sequences
    are matched positionally (with special handling when only ``right`` ends
    in ``Ellipsis``), and single normalized types are delegated to
    ``_is_normal_subtype``.
    """
    if isinstance(left, frozenset):
        if not isinstance(right, frozenset):
            return False

        # Members present on both sides match trivially
        # (Union[str, int] <> Union[int, str]); only the leftovers need a
        # compatible counterpart (Union[list, int] <> Union[typing.Sequence, int]).
        unmatched = left - right
        return all(
            any(_is_normal_subtype(member, candidate, forward_refs) for candidate in right)
            for member in unmatched
        )

    left_is_plain_sequence = isinstance(
        left, collections.abc.Sequence
    ) and not isinstance(left, NormalizedType)
    if left_is_plain_sequence:
        right_is_plain_sequence = isinstance(
            right, collections.abc.Sequence
        ) and not isinstance(right, NormalizedType)
        if not right_is_plain_sequence:
            return False

        if left and left[-1].origin is not Ellipsis:
            if right and right[-1].origin is Ellipsis:
                # Tuple[type, type] <> Tuple[type, ...]
                element_type = right[0]
                return all(
                    _is_origin_subtype_args(item, element_type, forward_refs)
                    for item in left
                )

        if len(left) != len(right):
            return False

        # Lengths are equal here, so the two sides can be paired one to one.
        return all(
            l_arg is not None
            and r_arg is not None
            and _is_origin_subtype_args(l_arg, r_arg, forward_refs)
            for l_arg, r_arg in zip(left, right)
        )

    assert isinstance(left, NormalizedType)
    assert isinstance(right, NormalizedType)

    return _is_normal_subtype(left, right, forward_refs)
1✔
925

926

927
def _is_normal_subtype_impl(
    left: NormalizedType,
    right: NormalizedType,
    forward_refs: typing.Optional[typing.Mapping[str, type]],
) -> typing.Optional[bool]:
    """Uncached subtype check between two normalized types.

    Recursive checks go through ``_is_normal_subtype`` so that they benefit
    from memoization whenever the arguments are hashable.
    """
    # Resolve forward references before looking at the origins.
    if isinstance(left.origin, ForwardRef):
        left = normalize(eval_forward_ref(left.origin, forward_refs=forward_refs))

    if isinstance(right.origin, ForwardRef):
        right = normalize(eval_forward_ref(right.origin, forward_refs=forward_refs))

    # Any
    if right.origin is typing.Any:
        return True

    # Union
    if right.origin is typing.Union and left.origin is typing.Union:
        return _is_origin_subtype_args(left.args, right.args, forward_refs)
    if right.origin is typing.Union:
        return optional_any(
            _is_normal_subtype(left, a, forward_refs) for a in right.args
        )
    if left.origin is typing.Union:
        return optional_all(
            _is_normal_subtype(a, right, forward_refs) for a in left.args
        )

    # TypeVar
    if isinstance(left.origin, typing.TypeVar) and isinstance(
        right.origin, typing.TypeVar
    ):
        if left.origin is right.origin:
            return True

        left_bound = getattr(left.origin, "__bound__", None)
        right_bound = getattr(right.origin, "__bound__", None)
        if right_bound is None or left_bound is None:
            return unknown
        return _is_normal_subtype(
            normalize(left_bound), normalize(right_bound), forward_refs
        )
    if isinstance(right.origin, typing.TypeVar):
        return unknown
    if isinstance(left.origin, typing.TypeVar):
        left_bound = getattr(left.origin, "__bound__", None)
        if left_bound is None:
            return unknown
        return _is_normal_subtype(normalize(left_bound), right, forward_refs)

    # Without arguments on the right only the origins matter. (This also
    # covers the case where neither side has arguments.)
    if not right.args:
        return _is_origin_subtype(left.origin, right.origin)

    if _is_origin_subtype(left.origin, right.origin):
        return _is_origin_subtype_args(left.args, right.args, forward_refs)

    return False


# Memoized variant; usable only when every argument is hashable.
_is_normal_subtype_cached = lru_cache(maxsize=None)(_is_normal_subtype_impl)


def _is_normal_subtype(
    left: NormalizedType,
    right: NormalizedType,
    forward_refs: typing.Optional[typing.Mapping[str, type]],
) -> typing.Optional[bool]:
    """Check whether normalized type ``left`` is a subtype of ``right``.

    Args:
        left: Normalized candidate subtype.
        right: Normalized candidate supertype.
        forward_refs: Optional mapping used to resolve forward references.

    Returns:
        ``True`` / ``False`` when the relation can be decided, or the
        module-level ``unknown`` value for undecidable ``TypeVar`` cases.

    Results are memoized with ``lru_cache``. The cache hashes every argument,
    so decorating this function directly made any call with a plain ``dict``
    of forward references fail with ``TypeError: unhashable type`` before the
    body ran. Unhashable ``forward_refs`` now bypass the cache instead.
    """
    try:
        hash(forward_refs)
    except TypeError:
        # e.g. a regular dict: compute the answer without memoization.
        return _is_normal_subtype_impl(left, right, forward_refs)
    return _is_normal_subtype_cached(left, right, forward_refs)
1✔
987

988

989
def issubtype(
    left: _Type,
    right: _Type,
    forward_refs: typing.Optional[dict] = None,
) -> typing.Optional[bool]:
    """Check whether ``left`` is a subtype of ``right``.

    Both types are normalized first and then compared. For unions, the type
    arguments of the left side must be a subset of those of the right side.
    Nested types and ForwardRefs are handled as well.

    Args:
        left: The candidate subtype.
        right: The candidate supertype.
        forward_refs: Optional mapping from forward-reference names to the
            types they stand for.

    Returns:
        The result of comparing the two normalized types.

    Examples:
        Here are some code examples using `issubtype` from the `typing_utils` module:

        .. code-block:: python

            from typing_utils import issubtype

            # Examples of issubtype checks
            issubtype(typing.List, typing.Any)  # True
            issubtype(list, list)  # True
            issubtype(list, typing.List)  # True
            issubtype(list, typing.Sequence)  # True
            issubtype(typing.List[int], list)  # True
            issubtype(typing.List[typing.List], list)  # True
            issubtype(list, typing.List[int])  # False
            issubtype(list, typing.Union[typing.Tuple, typing.Set])  # False
            issubtype(typing.List[typing.List], typing.List[typing.Sequence])  # True

            # Example with custom JSON type
            JSON = typing.Union[
                int, float, bool, str, None, typing.Sequence["JSON"],
                typing.Mapping[str, "JSON"]
            ]
            issubtype(str, JSON, forward_refs={'JSON': JSON})  # True
            issubtype(typing.Dict[str, str], JSON, forward_refs={'JSON': JSON})  # True
            issubtype(typing.Dict[str, bytes], JSON, forward_refs={'JSON': JSON})  # False
    """
    normalized_left = normalize(left)
    normalized_right = normalize(right)
    return _is_normal_subtype(normalized_left, normalized_right, forward_refs)
1✔
1027

1028

1029
def to_float_or_default(v, failure_default=0):
    """Convert ``v`` to ``float``, falling back to a default on failure.

    Args:
        v: The value to convert.
        failure_default: Value returned when the conversion raises. Passing
            ``None`` disables the fallback, so the conversion error propagates.

    Returns:
        ``float(v)`` on success, otherwise ``failure_default``.
    """
    try:
        converted = float(v)
    except Exception:
        if failure_default is not None:
            return failure_default
        # No fallback requested: let the original conversion error escape.
        raise
    return converted
1✔
1036

1037

1038
def verify_required_schema(
    required_schema_dict: Dict[str, type],
    input_dict: Dict[str, Any],
    class_name: str,
    id: Optional[str] = "",
    description: Optional[str] = "",
) -> None:
    """Check that ``input_dict`` has every required field with a value of the required type.

    Parameters:
        required_schema_dict (Dict[str, type]):
            Schema where a key is the name of a field and a value is the
            type its value must conform to.
        input_dict (Dict[str, Any]):
            Dict with input fields and their respective values.
        class_name (str):
            Name of the verifying class, used in error messages.
        id (Optional[str]):
            Identifier reported next to ``class_name`` in error messages.
        description (Optional[str]):
            Description appended to error messages.

    Raises:
        Exception: If a required field is missing from ``input_dict``.
        ValueError: If a field's value is not of the required type.
    """
    for field_name, data_type in required_schema_dict.items():
        try:
            value = input_dict[field_name]
        except KeyError as missing_field:
            available_fields = list(input_dict.keys())
            message = (
                f"The {class_name} ('{id}') expected a field '{field_name}' which the input instance did not contain.\n"
                f"The input instance fields are  : {available_fields}.\n"
                f"{class_name} description: {description}"
            )
            raise Exception(message) from missing_field

        if isoftype(value, data_type):
            continue

        raise ValueError(
            f"Passed value '{value}' of field '{field_name}' is not "
            f"of required type: ({to_type_string(data_type)}) in {class_name} ('{id}').\n"
            f"{class_name} description: {description}"
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc