• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pomponchik / transfunctions / 16879851942

11 Aug 2025 12:20PM UTC coverage: 98.366% (+0.03%) from 98.339%
16879851942

Pull #10

github

esblinov
another typing shit
Pull Request #10: 0.0.9

33 of 36 new or added lines in 6 files covered. (91.67%)

4 existing lines in 1 file now uncovered.

301 of 306 relevant lines covered (98.37%)

5.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.75
/transfunctions/transformer.py
1
import ast
6✔
2
from ast import (
6✔
3
    AST,
4
    Assign,
5
    AsyncFunctionDef,
6
    Await,
7
    Call,
8
    Constant,
9
    FunctionDef,
10
    Module,
11
    Load,
12
    Name,
13
    NodeTransformer,
14
    Pass,
15
    Return,
16
    Store,
17
    With,
18
    arguments,
19
    increment_lineno,
20
    parse,
21
    YieldFrom,
22
)
23
from functools import update_wrapper, wraps
6✔
24
from inspect import getfile, getsource, iscoroutinefunction, isfunction
6✔
25
from sys import version_info
6✔
26
from types import FunctionType, MethodType, FrameType
6✔
27
from typing import Any, Dict, Generic, List, Optional, Union, Type, cast
6✔
28

29
from dill.source import getsource as dill_getsource  # type: ignore[import-untyped]
6✔
30

31
from transfunctions.errors import (
6✔
32
    AliasedDecoratorSyntaxError,
33
    CallTransfunctionDirectlyError,
34
    DualUseOfDecoratorError,
35
    WrongDecoratorSyntaxError,
36
    WrongMarkerSyntaxError,
37
)
38
from transfunctions.typing import Coroutine, Callable, Generator, FunctionParams, ReturnType, SomeClassInstance
6✔
39
from transfunctions.universal_namespace import UniversalNamespaceAroundFunction
6✔
40

41

42
class FunctionTransformer(Generic[FunctionParams, ReturnType]):
6✔
43
    def __init__(
6✔
44
        self, function: Callable[FunctionParams, ReturnType], decorator_lineno: int, decorator_name: str, frame: FrameType,
45
    ) -> None:
46
        if isinstance(function, type(self)):
6✔
47
            raise DualUseOfDecoratorError(f"You cannot use the '{decorator_name}' decorator twice for the same function.")
6✔
48
        if not isfunction(function):
6✔
49
            raise ValueError(f"Only regular or generator functions can be used as a template for @{decorator_name}.")
6✔
50
        if iscoroutinefunction(function):
6✔
51
            raise ValueError(f"Only regular or generator functions can be used as a template for @{decorator_name}. You can't use async functions.")
6✔
52
        if self.is_lambda(function):
6✔
53
            raise ValueError(f"Only regular or generator functions can be used as a template for @{decorator_name}. Don't use lambdas here.")
6✔
54

55
        self.function = function
6✔
56
        self.decorator_lineno = decorator_lineno
6✔
57
        self.decorator_name = decorator_name
6✔
58
        self.frame = frame
6✔
59
        self.base_object: Optional[SomeClassInstance] = None  # type: ignore[valid-type]
6✔
60
        self.cache: Dict[str, Callable] = {}
6✔
61

62
    def __call__(self, *args: Any, **kwargs: Any) -> None:
6✔
63
        raise CallTransfunctionDirectlyError("You can't call a transfunction object directly, create a function, a generator function or a coroutine function from it.")
6✔
64

65
    def __get__(
6✔
66
        self,
67
        base_object: SomeClassInstance,
68
        owner: Type[SomeClassInstance],
69
    ) -> 'FunctionTransformer[FunctionParams, ReturnType]':
70
        self.base_object = base_object
6✔
71
        return self
6✔
72

73
    @staticmethod
6✔
74
    def is_lambda(function: Callable) -> bool:
6✔
75
        # https://stackoverflow.com/a/3655857/14522393
76
        lambda_example = lambda: 0  # noqa: E731
6✔
77
        return isinstance(function, type(lambda_example)) and function.__name__ == lambda_example.__name__
6✔
78

79
    def get_usual_function(self, addictional_transformers: Optional[List[NodeTransformer]] = None) -> Callable[FunctionParams, ReturnType]:
6✔
80
        return self.extract_context('sync_context', addictional_transformers=addictional_transformers)
6✔
81

82
    def get_async_function(self) -> Callable[FunctionParams, Coroutine[Any, Any, ReturnType]]:
6✔
83
        original_function = self.function
6✔
84

85
        class ConvertSyncFunctionToAsync(NodeTransformer):
6✔
86
            def visit_FunctionDef(self, node: FunctionDef) -> Optional[Union[AST, List[AST]]]:
6✔
87
                if node.name == original_function.__name__:
6✔
88
                    return AsyncFunctionDef(
6✔
89
                        name=original_function.__name__,
90
                        args=node.args,
91
                        body=node.body,
92
                        decorator_list=node.decorator_list,
93
                        lineno=node.lineno,
94
                        end_lineno=node.end_lineno,
95
                        col_offset=node.col_offset,
96
                        end_col_offset=node.end_col_offset,
97
                    )
UNCOV
98
                return node
×
99

100
        class ExtractAwaitExpressions(NodeTransformer):
6✔
101
            def visit_Call(self, node: Call) -> Optional[Union[AST, List[AST]]]:
6✔
102
                if isinstance(node.func, Name) and node.func.id == 'await_it':
6✔
103
                    if len(node.args) != 1 or node.keywords:
6✔
104
                        raise WrongMarkerSyntaxError('The "await_it" marker can be used with only one positional argument.')
6✔
105

106
                    return Await(
6✔
107
                        value=node.args[0],
108
                        lineno=node.lineno,
109
                        end_lineno=node.end_lineno,
110
                        col_offset=node.col_offset,
111
                        end_col_offset=node.end_col_offset,
112
                    )
113
                return node
6✔
114

115
        return self.extract_context(
6✔
116
            'async_context',
117
            addictional_transformers=[
118
                ConvertSyncFunctionToAsync(),
119
                ExtractAwaitExpressions(),
120
            ],
121
        )
122

123
    def get_generator_function(self) -> Callable[FunctionParams, Generator[ReturnType, None, None]]:
6✔
124
        class ConvertYieldFroms(NodeTransformer):
6✔
125
            def visit_Call(self, node: Call) -> Optional[Union[AST, List[AST]]]:
6✔
126
                if isinstance(node.func, Name) and node.func.id == 'yield_from_it':
6✔
127
                    if len(node.args) != 1 or node.keywords:
6✔
128
                        raise WrongMarkerSyntaxError('The "yield_from_it" marker can be used with only one positional argument.')
6✔
129

130
                    return YieldFrom(
6✔
131
                        value=node.args[0],
132
                        lineno=node.lineno,
133
                        end_lineno=node.end_lineno,
134
                        col_offset=node.col_offset,
135
                        end_col_offset=node.end_col_offset,
136
                    )
137
                return node
6✔
138

139
        return self.extract_context(
6✔
140
            'generator_context',
141
            addictional_transformers=[
142
                ConvertYieldFroms(),
143
            ],
144
        )
145

146
    @staticmethod
6✔
147
    def clear_spaces_from_source_code(source_code: str) -> str:
6✔
148
        splitted_source_code = source_code.split('\n')
6✔
149

150
        indent = 0
6✔
151
        for letter in splitted_source_code[0]:
6✔
152
            if letter.isspace():
6✔
153
                indent += 1
6✔
154
            else:
155
                break
6✔
156

157
        new_splitted_source_code = [x[indent:] for x in splitted_source_code]
6✔
158

159
        return '\n'.join(new_splitted_source_code)
6✔
160

161

162
    def extract_context(self, context_name: str, addictional_transformers: Optional[List[NodeTransformer]] = None):
6✔
163
        if context_name in self.cache:
6✔
164
            return self.cache[context_name]
6✔
165
        try:
6✔
166
            source_code: str = getsource(self.function)
6✔
NEW
167
        except OSError:
×
UNCOV
168
            source_code = dill_getsource(self.function)
×
169

170
        converted_source_code = self.clear_spaces_from_source_code(source_code)
6✔
171
        tree = parse(converted_source_code)
6✔
172
        original_function = self.function
6✔
173
        transfunction_decorator = None
6✔
174
        decorator_name = self.decorator_name
6✔
175

176
        class RewriteContexts(NodeTransformer):
6✔
177
            def visit_With(self, node: With) -> Optional[Union[AST, List[AST]]]:
6✔
178
                if len(node.items) == 1:
6✔
179
                    if isinstance(node.items[0].context_expr, Name):
6✔
180
                        context_expr = node.items[0].context_expr
6✔
181
                    elif isinstance(node.items[0].context_expr, Call) and isinstance(node.items[0].context_expr.func, ast.Name):
6✔
182
                        context_expr = node.items[0].context_expr.func
6✔
183

184
                    if context_expr.id == context_name:
6✔
185
                        return cast(List[AST], node.body)
6✔
186
                    if context_expr.id != context_name and context_expr.id in ('async_context', 'sync_context', 'generator_context'):
6✔
187
                        return None
6✔
188
                return node
6✔
189

190
        class DeleteDecorator(NodeTransformer):
6✔
191
            def visit_FunctionDef(self, node: FunctionDef) -> Optional[Union[AST, List[AST]]]:
6✔
192
                if node.name == original_function.__name__:
6✔
193
                    nonlocal transfunction_decorator
194
                    transfunction_decorator = None
6✔
195

196
                    if not node.decorator_list:
6✔
197
                        raise WrongDecoratorSyntaxError(f"The @{decorator_name} decorator can only be used with the '@' symbol. Don't use it as a regular function. Also, don't rename it.")
6✔
198

199
                    for decorator in node.decorator_list:
6✔
200
                        if isinstance(decorator, Call):
6✔
201
                            decorator = decorator.func
6✔
202

203
                        if (
6✔
204
                            isinstance(decorator, Name)
205
                            and decorator.id != decorator_name
206
                        ):
207
                            raise WrongDecoratorSyntaxError(f'The @{decorator_name} decorator cannot be used in conjunction with other decorators.')
6✔
208
                        else:
209
                            if transfunction_decorator is not None:
6✔
UNCOV
210
                                raise DualUseOfDecoratorError(f"You cannot use the '{decorator_name}' decorator twice for the same function.")
×
211
                            transfunction_decorator = decorator
6✔
212

213
                    node.decorator_list = []
6✔
214
                return node
6✔
215

216
        RewriteContexts().visit(tree)
6✔
217
        DeleteDecorator().visit(tree)
6✔
218

219
        if transfunction_decorator is None:
6✔
UNCOV
220
            raise AliasedDecoratorSyntaxError(
×
221
                "The transfunction decorator must have been renamed."
222
            )
223

224
        function_def = cast(FunctionDef, tree.body[0])
6✔
225
        if not function_def.body:
6✔
226
            function_def.body.append(
6✔
227
                Pass(
228
                    col_offset=tree.body[0].col_offset,
229
                ),
230
            )
231

232
        if addictional_transformers is not None:
6✔
233
            for addictional_transformer in addictional_transformers:
6✔
234
                addictional_transformer.visit(tree)
6✔
235

236
        tree = self.wrap_ast_by_closures(tree)
6✔
237

238
        if version_info.minor > 10:
6✔
239
            increment_lineno(tree, n=(self.decorator_lineno - transfunction_decorator.lineno))
3✔
240
        else:
NEW
241
            increment_lineno(tree, n=(self.decorator_lineno - transfunction_decorator.lineno - 1))
3✔
242

243
        code = compile(tree, filename=getfile(self.function), mode='exec')
6✔
244
        namespace = UniversalNamespaceAroundFunction(self.function, self.frame)
6✔
245
        exec(code, namespace)
6✔
246
        function_factory = namespace['wrapper']
6✔
247
        result = function_factory()
6✔
248
        result = self.rewrite_globals_and_closure(result)
6✔
249
        result = wraps(self.function)(result)
6✔
250

251
        if self.base_object is not None:
6✔
252
            result = MethodType(
6✔
253
                result,
254
                self.base_object,
255
            )
256

257
        self.cache[context_name] = result
6✔
258

259
        return result
6✔
260

261
    def wrap_ast_by_closures(self, tree: Module) -> Module:
6✔
262
        old_functiondef = tree.body[0]
6✔
263

264
        tree.body[0] = FunctionDef(
6✔
265
            name='wrapper',
266
            body=[Assign(targets=[Name(id=name, ctx=Store(), col_offset=0)], value=Constant(value=None, col_offset=0), col_offset=0) for name in self.function.__code__.co_freevars] + [
267
                old_functiondef,
268
                Return(value=Name(id=self.function.__name__, ctx=Load(), col_offset=0), col_offset=0),
269
            ],
270
            col_offset=0,
271
            args=arguments(
272
                posonlyargs=[],
273
                args=[],
274
                kwonlyargs=[],
275
                kw_defaults=[],
276
                defaults=[],
277
            ),
278
            decorator_list=[],
279
        )
280

281
        return tree
6✔
282

283

284
    def rewrite_globals_and_closure(self, function: FunctionType) -> FunctionType:
6✔
285
        # https://stackoverflow.com/a/13503277/14522393
286
        all_new_closure_names = set(self.function.__code__.co_freevars)
6✔
287

288
        if self.function.__closure__ is not None:
6✔
289
            old_function_closure_variables = {name: cell for name, cell in zip(self.function.__code__.co_freevars, self.function.__closure__)}
6✔
290
            filtered_closure = tuple([cell for name, cell in old_function_closure_variables.items() if name in all_new_closure_names])
6✔
291
            names = tuple([name for name, cell in old_function_closure_variables.items() if name in all_new_closure_names])
6✔
292
            new_code = function.__code__.replace(co_freevars=names)
6✔
293
        else:
294
            filtered_closure = None
6✔
295
            new_code = function.__code__
6✔
296

297
        new_function = FunctionType(
6✔
298
            new_code,
299
            self.function.__globals__,
300
            name=self.function.__name__,
301
            argdefs=self.function.__defaults__,
302
            closure=filtered_closure,
303
        )
304

305
        new_function = cast(FunctionType, update_wrapper(new_function, function))
6✔
306
        new_function.__kwdefaults__ = function.__kwdefaults__
6✔
307
        return new_function
6✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc