• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alexmojaki / executing / 10623179361

29 Aug 2024 09:57PM UTC coverage: 97.297% (-0.004%) from 97.301%
10623179361

push

github

239 of 250 branches covered (95.6%)

Branch coverage included in aggregate %.

517 of 527 relevant lines covered (98.1%)

4.35 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.3
/executing/executing.py
1
"""
2
MIT License
3

4
Copyright (c) 2021 Alex Hall
5

6
Permission is hereby granted, free of charge, to any person obtaining a copy
7
of this software and associated documentation files (the "Software"), to deal
8
in the Software without restriction, including without limitation the rights
9
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
copies of the Software, and to permit persons to whom the Software is
11
furnished to do so, subject to the following conditions:
12

13
The above copyright notice and this permission notice shall be included in all
14
copies or substantial portions of the Software.
15

16
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
SOFTWARE.
23
"""
24

25
import __future__
6✔
26
import ast
6✔
27
import dis
6✔
28
import inspect
6✔
29
import io
6✔
30
import linecache
6✔
31
import re
6✔
32
import sys
6✔
33
import types
6✔
34
from collections import defaultdict
6✔
35
from copy import deepcopy
6✔
36
from functools import lru_cache
6✔
37
from itertools import islice
6✔
38
from itertools import zip_longest
6✔
39
from operator import attrgetter
6✔
40
from pathlib import Path
6✔
41
from threading import RLock
6✔
42
from tokenize import detect_encoding
6✔
43
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Sized, Tuple, Type, TypeVar, Union, cast
6✔
44
from ._utils import mangled_name,assert_, EnhancedAST,EnhancedInstruction,Instruction,get_instructions
45

46
if TYPE_CHECKING:  # pragma: no cover
47
    from asttokens import ASTTokens, ASTText
48
    from asttokens.asttokens import ASTTextBase
49

50

51
# AST node classes for `def` and `async def` function definitions.
function_node_types = (
    ast.FunctionDef,
    ast.AsyncFunctionDef,
)  # type: Tuple[Type, ...]

# Unbounded memoizing decorator shared by this module.
cache = lru_cache(maxsize=None)

# When truthy, internal failures that would otherwise be swallowed are
# re-raised so that test runs surface them.
TESTING = 0
56

57
class NotOneValueFound(Exception):
    """
    Raised by `only` when an iterable does not yield exactly one value.

    `values` holds the values that were found if the raiser supplied them,
    and an empty list otherwise.
    """

    def __init__(self, msg, values=None):
        # type: (str, Optional[Sequence]) -> None
        # A literal `[]` default is evaluated once and would be one list
        # shared by every instance; create a fresh list per exception instead.
        self.values = [] if values is None else values
        super(NotOneValueFound, self).__init__(msg)
6✔
62

6✔
63
T = TypeVar('T')


def only(it):
    # type: (Iterable[T]) -> T
    """
    Return the single element of `it`.

    Raises NotOneValueFound if `it` has zero elements or more than one.
    Sized inputs are checked with `len`. Other iterables are consumed for
    at most two elements.
    """
    if not isinstance(it, Sized):
        # Pull at most two items: enough to tell "none", "one" and "several" apart.
        head = tuple(islice(it, 2))
        count = len(head)
        if not count:
            raise NotOneValueFound('Expected one value, found 0')
        if count >= 2:
            raise NotOneValueFound('Expected one value, found several', head)
        return head[0]

    size = len(it)
    if size != 1:
        raise NotOneValueFound('Expected one value, found %s' % size)
    # noinspection PyTypeChecker
    return list(it)[0]
6✔
80

3✔
81

82
class Source(object):
    """
    The source code of a single file, together with metadata derived from it.

    The main entry point is the classmethod `executing(frame)`.

    Do not construct instances directly. Prefer the classmethod
    `for_frame(frame)`. Use `for_filename(filename [, module_globals])` when no
    frame is available. Both cache instances keyed on the filename and its
    lines, so repeated lookups of an unchanged file return the same object.

    Attributes:
        - filename
        - text
        - lines
        - tree: AST parsed from text, or None if text is not valid Python
            All nodes in the tree have an extra `parent` attribute

    Other methods of interest:
        - statements_at_line
        - asttokens
        - code_qualname
    """

    def __init__(self, filename, lines):
        # type: (str, Sequence[str]) -> None
        """
        Internal constructor. See the class docstring for the supported ways
        to obtain an instance.
        """
        self.filename = filename
        self.text = ''.join(lines)
        self.lines = [raw.rstrip('\r\n') for raw in lines]

        self._nodes_by_line = defaultdict(list)
        self.tree = None
        self._qualnames = {}
        self._asttokens = None  # type: Optional[ASTTokens]
        self._asttext = None  # type: Optional[ASTText]

        try:
            parsed = ast.parse(self.text, filename=filename)
        except (SyntaxError, ValueError):
            # Not parseable: `tree` stays None and the indexes stay empty.
            return

        self.tree = parsed
        for parent in ast.walk(parsed):
            # Give every child node a back-reference to its parent.
            for child in ast.iter_child_nodes(parent):
                cast(EnhancedAST, child).parent = cast(EnhancedAST, parent)
            # Index the node under each line number reported by node_linenos.
            for line_number in node_linenos(parent):
                self._nodes_by_line[line_number].append(parent)

        qualname_visitor = QualnameVisitor()
        qualname_visitor.visit(parsed)
        self._qualnames = qualname_visitor.qualnames

    @classmethod
    def for_frame(cls, frame, use_cache=True):
        # type: (types.FrameType, bool) -> "Source"
        """
        Returns the `Source` object for the file that `frame` is executing in.
        """
        code = frame.f_code
        module_globals = frame.f_globals or {}
        return cls.for_filename(code.co_filename, module_globals, use_cache)

    @classmethod
    def for_filename(
        cls,
        filename,
        module_globals=None,
        use_cache=True,  # noqa no longer used
    ):
        # type: (Union[str, Path], Optional[Dict[str, Any]], bool) -> "Source"
        """
        Returns the `Source` object for `filename`, reading its lines through
        `linecache` (with `module_globals` passed along to `linecache.getlines`).
        """
        if isinstance(filename, Path):
            filename = str(filename)

        # Remember whatever linecache holds now, then refresh it.
        saved_entry = linecache.cache.get(filename)  # type: ignore[attr-defined]
        linecache.checkcache(filename)
        lines = linecache.getlines(filename, module_globals)

        if not lines and saved_entry is not None:
            # checkcache removed an existing entry and nothing replaced it.
            # A file that had merely changed would still produce lines, so the
            # file most likely does not exist (e.g. `filename` is fake).
            # Put the original entry back so there is still something to use.
            linecache.cache[filename] = saved_entry  # type: ignore[attr-defined]
            lines = linecache.getlines(filename, module_globals)

        return cls._for_filename_and_lines(filename, tuple(lines))

    @classmethod
    def _for_filename_and_lines(cls, filename, lines):
        # type: (str, Sequence[str]) -> "Source"
        """Returns the cached instance for (filename, lines), creating it on first use."""
        source_cache = cls._class_local('__source_cache_with_lines', {})  # type: Dict[Tuple[str, Sequence[str]], Source]
        cache_key = (filename, lines)
        if cache_key not in source_cache:
            source_cache[cache_key] = cls(filename, lines)
        return source_cache[cache_key]

    @classmethod
    def lazycache(cls, frame):
        # type: (types.FrameType) -> None
        """Seeds `linecache` for the file of `frame` using its module globals."""
        linecache.lazycache(frame.f_code.co_filename, frame.f_globals)

    @classmethod
    def executing(cls, frame_or_tb):
        # type: (Union[types.TracebackType, types.FrameType]) -> "Executing"
        """
        Returns an `Executing` object describing the operation currently
        executing in the given frame or traceback object.
        """
        if isinstance(frame_or_tb, types.TracebackType):
            # https://docs.python.org/3/reference/datamodel.html#traceback-objects
            # "tb_lineno gives the line number where the exception occurred;
            #  tb_lasti indicates the precise instruction.
            #  The line number and last instruction in the traceback may differ
            #  from the line number of its frame object
            #  if the exception occurred in a try statement with no matching except clause
            #  or with a finally clause."
            frame = frame_or_tb.tb_frame
            lineno = frame_or_tb.tb_lineno
            lasti = frame_or_tb.tb_lasti
        else:
            frame = frame_or_tb
            lineno = frame.f_lineno
            lasti = frame.f_lasti

        code = frame.f_code
        cache_key = (code, id(code), lasti)
        executing_cache = cls._class_local('__executing_cache', {})  # type: Dict[Tuple[types.CodeType, int, int], Any]

        args = executing_cache.get(cache_key)
        if not args:
            node = stmts = decorator = None
            source = cls.for_frame(frame)
            if source.tree:
                try:
                    stmts = source.statements_at_line(lineno)
                    if stmts:
                        if is_ipython_cell_code(code):
                            decorator, node = find_node_ipython(frame, lasti, stmts, source)
                        else:
                            finder = NodeFinder(frame, stmts, source.tree, lasti, source)
                            node = finder.result
                            decorator = finder.decorator

                    if node:
                        # Narrow the statements to the one holding the found node.
                        narrowed = {statement_containing_node(node)}
                        assert_(narrowed <= stmts)
                        stmts = narrowed
                except Exception:
                    # Best effort: fall back to whatever was found so far.
                    if TESTING:
                        raise

            args = source, node, stmts, decorator
            executing_cache[cache_key] = args

        return Executing(frame, *args)

    @classmethod
    def _class_local(cls, name, default):
        # type: (str, T) -> T
        """
        Returns an attribute stored directly on this class (not inherited
        from a base class), storing `default` there first if it is missing.
        """
        # `cls.__dict__` is a read-only mappingproxy without `setdefault`,
        # so read the class's own namespace and write back with setattr.
        value = cls.__dict__.get(name, default)
        setattr(cls, name, value)
        return value

    @cache
    def statements_at_line(self, lineno):
        # type: (int) -> Set[EnhancedAST]
        """
        Returns the statement nodes that overlap line `lineno`.

        At most one statement is returned unless semicolons are present.

        If `text` is not valid Python (so `tree` is None), the result is an
        empty set.

        Otherwise, `Source.for_frame(frame).statements_at_line(frame.f_lineno)`
        should contain at least one statement.
        """
        return set(map(statement_containing_node, self._nodes_by_line[lineno]))

    def asttext(self):
        # type: () -> ASTText
        """
        Returns an ASTText object for getting the source of specific AST nodes.
        The object is built on first use and then reused.

        See http://asttokens.readthedocs.io/en/latest/api-index.html
        """
        from asttokens import ASTText  # must be installed separately

        if self._asttext is None:
            self._asttext = ASTText(self.text, tree=self.tree, filename=self.filename)
        return self._asttext

    def asttokens(self):
        # type: () -> ASTTokens
        """
        Returns an ASTTokens object for getting the source of specific AST nodes.
        The object is built on first use and then reused.

        See http://asttokens.readthedocs.io/en/latest/api-index.html
        """
        import asttokens  # must be installed separately

        if self._asttokens is not None:
            return self._asttokens

        if hasattr(asttokens, 'ASTText'):
            self._asttokens = self.asttext().asttokens
        else:  # pragma: no cover
            self._asttokens = asttokens.ASTTokens(self.text, tree=self.tree, filename=self.filename)
        return self._asttokens

    def _asttext_base(self):
        # type: () -> ASTTextBase
        """Returns `asttext()` when asttokens provides ASTText, else `asttokens()`."""
        import asttokens  # must be installed separately

        if not hasattr(asttokens, 'ASTText'):  # pragma: no cover
            return self.asttokens()
        return self.asttext()

    @staticmethod
    def decode_source(source):
        # type: (Union[str, bytes]) -> str
        """Returns `source` as text, decoding bytes with their detected encoding."""
        if not isinstance(source, bytes):
            return source
        return source.decode(Source.detect_encoding(source))

    @staticmethod
    def detect_encoding(source):
        # type: (bytes) -> str
        """Returns the encoding name `tokenize.detect_encoding` finds for `source`."""
        readline = io.BytesIO(source).readline
        encoding, _consumed_lines = detect_encoding(readline)
        return encoding

    def code_qualname(self, code):
        # type: (types.CodeType) -> str
        """
        Imitates the `__qualname__` attribute of functions, but for code objects.

        Given a function `func` and a frame `frame` executing it
        (so that `frame.f_code is func.__code__`),
        `Source.for_frame(frame).code_qualname(frame.f_code)`
        equals `func.__qualname__`*. This also works on Python 2, which has
        no `__qualname__` attribute.

        When no qualname is known, `code.co_name` is returned instead.

        Based on https://github.com/wbolster/qualname

        (* except when `func` is a lambda nested inside another lambda on the
        same line. In that case the outer lambda's qualname is returned for
        the codes of both lambdas.)
        """
        assert_(code.co_filename == self.filename)
        lookup_key = (code.co_name, code.co_firstlineno)
        return self._qualnames.get(lookup_key, code.co_name)
360

361

6✔
362
class Executing(object):
    """
    Information about the operation a frame is currently executing.

    Usually only `node` is needed. It is the AST node being executed, or None
    if that could not be determined.

    If a decorator is currently being called, then:
        - `node` is a function or class definition
        - `decorator` is the expression in `node.decorator_list` being called
        - `statements == {node}`
    """

    def __init__(self, frame, source, node, stmts, decorator):
        # type: (types.FrameType, Source, EnhancedAST, Set[ast.stmt], Optional[EnhancedAST]) -> None
        self.frame, self.source = frame, source
        self.node, self.statements, self.decorator = node, stmts, decorator

    def code_qualname(self):
        # type: () -> str
        """Qualified name of the code object running in `self.frame`."""
        running_code = self.frame.f_code
        return self.source.code_qualname(running_code)

    def text(self):
        # type: () -> str
        """Source text of `self.node`."""
        text_provider = self.source._asttext_base()
        return text_provider.get_text(self.node)

    def text_range(self):
        # type: () -> Tuple[int, int]
        """The `get_text_range` result for `self.node`."""
        text_provider = self.source._asttext_base()
        return text_provider.get_text_range(self.node)
394

395

396
class QualnameVisitor(ast.NodeVisitor):
    """
    Walks an AST and records a dotted qualified name for every function,
    async function, lambda and class definition.

    Results live in `self.qualnames`, keyed by ``(name, lineno)``. `lineno`
    is the line of the first decorator when the definition is decorated,
    and the node's own line otherwise. When two definitions share a key,
    the first one visited wins.
    """

    def __init__(self):
        # type: () -> None
        super(QualnameVisitor, self).__init__()
        self.stack = []  # type: List[str]
        self.qualnames = {}  # type: Dict[Tuple[str, int], str]

    def add_qualname(self, node, name=None):
        # type: (ast.AST, Optional[str]) -> None
        """Push `name` (or `node.name`) onto the stack and record the resulting dotted path."""
        name = name or node.name  # type: ignore[attr-defined]
        self.stack.append(name)
        decorators = getattr(node, 'decorator_list', ())
        # Key decorated definitions by their first decorator's line.
        anchor = decorators[0] if decorators else node
        lineno = anchor.lineno  # type: ignore[attr-defined]
        self.qualnames.setdefault((name, lineno), ".".join(self.stack))

    def visit_FunctionDef(self, node, name=None):
        # type: (ast.AST, Optional[str]) -> None
        assert isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)), node
        self.add_qualname(node, name)
        # Anything defined in the body is nested under "<name>.<locals>".
        self.stack.append('<locals>')
        body = [node.body] if isinstance(node, ast.Lambda) else node.body  # type: Sequence[ast.AST]
        for statement in body:
            self.visit(statement)
        # Drop both '<locals>' and this definition's own name.
        del self.stack[-2:]

        # Parts of the definition outside its body (decorators, default
        # arguments, ...) belong to the enclosing scope. Visit them only now
        # that this definition's names are off the stack.
        # Mirrors ast.iter_child_nodes, minus the 'body' field.
        for field, value in ast.iter_fields(node):
            if field == 'body':
                continue
            if isinstance(value, ast.AST):
                self.visit(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item)

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Lambda(self, node):
        # type: (ast.AST) -> None
        assert isinstance(node, ast.Lambda)
        self.visit_FunctionDef(node, '<lambda>')

    def visit_ClassDef(self, node):
        # type: (ast.AST) -> None
        assert isinstance(node, ast.ClassDef)
        self.add_qualname(node)
        self.generic_visit(node)
        self.stack.pop()
454

455

6✔
456

6✔
457

6✔
458

6✔
459
# Combined compiler flags of every feature known to the __future__ module.
future_flags = sum(
    getattr(__future__, feature).compiler_flag
    for feature in __future__.all_feature_names
)


def compile_similar_to(source, matching_code):
    # type: (ast.Module, types.CodeType) -> Any
    """
    Compile `source` in 'exec' mode as if it came from the same file as
    `matching_code` and used the same `__future__` features.

    Only future-feature bits are copied from `matching_code.co_flags`. The
    caller's own compiler flags are not inherited (`dont_inherit=True`).
    """
    copied_flags = matching_code.co_flags & future_flags
    return compile(
        source,
        matching_code.co_filename,
        'exec',
        flags=copied_flags,
        dont_inherit=True,
    )
6✔
473

6✔
474

6✔
475
# Marker constant planted into the AST by SentinelNodeFinder.matching_nodes.
# Presumably chosen so it cannot collide with a constant in real code.
sentinel = 'io8urthglkjdghvljusketgIYRFYUVGHFRTBGVHKGF78678957647698'


def is_rewritten_by_pytest(code):
    # type: (types.CodeType) -> bool
    """
    Return True if any non-LOAD_CONST instruction in `code` has a string
    argument starting with "@py".

    Those names come from pytest's assertion rewriting. String constants are
    excluded so that a literal "@py..." in user code does not count.
    """
    for instruction in get_instructions(code):
        if instruction.opname == "LOAD_CONST":
            continue
        argument = instruction.argval
        if isinstance(argument, str) and argument.startswith("@py"):
            return True
    return False
483

6✔
484

485
class SentinelNodeFinder(object):
    """
    Finds the AST node for a frame's current instruction.

    It recompiles the source with one candidate node at a time wrapped as
    ``candidate ** sentinel``. It then checks whether the instruction just
    before the sentinel lands at the same index as the frame's current
    instruction.
    """

    result = None  # type: EnhancedAST

    # Opcode name -> ast operator class, used to filter BinOp/UnaryOp candidates.
    _binary_op_types = dict(
        BINARY_POWER=ast.Pow,
        BINARY_MULTIPLY=ast.Mult,
        BINARY_MATRIX_MULTIPLY=getattr(ast, "MatMult", ()),
        BINARY_FLOOR_DIVIDE=ast.FloorDiv,
        BINARY_TRUE_DIVIDE=ast.Div,
        BINARY_MODULO=ast.Mod,
        BINARY_ADD=ast.Add,
        BINARY_SUBTRACT=ast.Sub,
        BINARY_LSHIFT=ast.LShift,
        BINARY_RSHIFT=ast.RShift,
        BINARY_AND=ast.BitAnd,
        BINARY_XOR=ast.BitXor,
        BINARY_OR=ast.BitOr,
    )
    _unary_op_types = dict(
        UNARY_POSITIVE=ast.UAdd,
        UNARY_NEGATIVE=ast.USub,
        UNARY_NOT=ast.Not,
        UNARY_INVERT=ast.Invert,
    )

    def __init__(self, frame, stmts, tree, lasti, source):
        # type: (types.FrameType, Set[EnhancedAST], ast.Module, int, Source) -> None
        assert_(stmts)
        self.frame = frame
        self.tree = tree
        self.code = code = frame.f_code
        self.is_pytest = is_rewritten_by_pytest(code)

        # In pytest-rewritten code, instructions on assert lines are left
        # out of every comparison (see clean_instructions).
        self.ignore_linenos = (
            frozenset(assert_linenos(tree)) if self.is_pytest else frozenset()
        )

        self.decorator = None

        self.instruction = instruction = self.get_actual_current_instruction(lasti)
        op_name = instruction.opname

        def same_name(e):
            # type: (Any) -> bool
            return mangled_name(e) == instruction.argval

        # Defaults: `ctx` of NoneType accepts nodes without a ctx attribute,
        # because getattr(node, "ctx", None) is then None.
        extra_filter = lambda e: True  # type: Callable[[Any], bool]
        ctx = type(None)  # type: Type
        typ = type(None)  # type: Type

        if op_name.startswith('CALL_'):
            typ = ast.Call
        elif op_name.startswith(('BINARY_SUBSCR', 'SLICE+')):
            typ, ctx = ast.Subscript, ast.Load
        elif op_name.startswith('BINARY_'):
            typ = ast.BinOp
            binary_type = self._binary_op_types[op_name]
            extra_filter = lambda e: isinstance(e.op, binary_type)
        elif op_name.startswith('UNARY_'):
            typ = ast.UnaryOp
            unary_type = self._unary_op_types[op_name]
            extra_filter = lambda e: isinstance(e.op, unary_type)
        elif op_name in ('LOAD_ATTR', 'LOAD_METHOD', 'LOOKUP_METHOD'):
            typ, ctx = ast.Attribute, ast.Load
            extra_filter = same_name
        elif op_name in ('LOAD_NAME', 'LOAD_GLOBAL', 'LOAD_FAST', 'LOAD_DEREF', 'LOAD_CLASSDEREF'):
            typ, ctx = ast.Name, ast.Load
            extra_filter = same_name
        elif op_name in ('COMPARE_OP', 'IS_OP', 'CONTAINS_OP'):
            typ = ast.Compare
            extra_filter = lambda e: len(e.ops) == 1
        elif op_name.startswith(('STORE_SLICE', 'STORE_SUBSCR')):
            typ, ctx = ast.Subscript, ast.Store
        elif op_name.startswith('STORE_ATTR'):
            typ, ctx = ast.Attribute, ast.Store
            extra_filter = same_name
        else:
            raise RuntimeError(op_name)

        # matching_nodes temporarily mutates the shared tree, so the whole
        # search runs under the module-level `lock`.
        with lock:
            candidates = set()
            for stmt in stmts:
                for node in ast.walk(stmt):
                    if (
                        isinstance(node, typ)
                        and isinstance(getattr(node, "ctx", None), ctx)
                        and extra_filter(node)
                        and statement_containing_node(node) == stmt
                    ):
                        candidates.add(cast(EnhancedAST, node))

            if ctx == ast.Store:
                # No special bytecode tricks here.
                # We can handle multiple assigned attributes with different names,
                # but only one assigned subscript.
                self.result = only(candidates)
                return

            matching = list(self.matching_nodes(candidates))
            if not matching and typ == ast.Call:
                # A call with no matching Call node: presumably a decorator.
                self.find_decorator(stmts)
            else:
                self.result = only(matching)

    def find_decorator(self, stmts):
        # type: (Union[List[EnhancedAST], Set[EnhancedAST]]) -> None
        """
        Treat the current CALL_FUNCTION instruction as a decorator call on the
        single decorated definition in `stmts`. Sets `self.decorator` and
        `self.result`.
        """
        stmt = only(stmts)
        assert_(isinstance(stmt, (ast.ClassDef, function_node_types)))
        decorators = stmt.decorator_list  # type: ignore[attr-defined]
        assert_(decorators)

        line_instructions = [
            inst
            for inst in self.clean_instructions(self.code)
            if inst.lineno == self.frame.f_lineno
        ]
        call_positions = [
            position
            for position, inst in enumerate(line_instructions)
            if inst.opname == "CALL_FUNCTION"
        ]
        last_call = call_positions[-1]
        # The final decorator call must be followed directly by a STORE_*.
        assert_(line_instructions[last_call + 1].opname.startswith("STORE_"))

        first_call = last_call - len(decorators) + 1
        decorator_calls = line_instructions[first_call:last_call + 1]
        assert_({inst.opname for inst in decorator_calls} == {"CALL_FUNCTION"})

        # Decorators are applied bottom-up, so the n-th call matches the n-th
        # decorator counted from the end of decorator_list.
        call_number = decorator_calls.index(self.instruction)
        self.decorator = decorators[::-1][call_number]
        self.result = stmt

    def clean_instructions(self, code):
        # type: (types.CodeType) -> List[EnhancedInstruction]
        """Instructions of `code` without EXTENDED_ARG/NOP or anything on an ignored line."""
        noise = ("EXTENDED_ARG", "NOP")
        return [
            inst
            for inst in get_instructions(code)
            if inst.opname not in noise and inst.lineno not in self.ignore_linenos
        ]

    def get_original_clean_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """Cleaned instructions of the frame's own code object."""
        result = self.clean_instructions(self.code)

        # pypy sometimes (when is not clear)
        # inserts JUMP_IF_NOT_DEBUG instructions in bytecode
        # If they're not present in our compiled instructions,
        # ignore them in the original bytecode
        recompiled_has_debug_jumps = any(
            inst.opname == "JUMP_IF_NOT_DEBUG"
            for inst in self.compile_instructions()
        )
        if recompiled_has_debug_jumps:
            return result
        return [inst for inst in result if inst.opname != "JUMP_IF_NOT_DEBUG"]

    def matching_nodes(self, exprs):
        # type: (Set[EnhancedAST]) -> Iterator[EnhancedAST]
        """
        Yield each candidate in `exprs` for which, after recompiling with the
        candidate wrapped in ``** sentinel``, the instruction right before the
        sentinel sits at the index of the frame's current instruction.
        """
        original_instructions = self.get_original_clean_instructions()
        original_index = only(
            i
            for i, inst in enumerate(original_instructions)
            if inst == self.instruction
        )
        last_candidate_index = len(exprs) - 1

        for expr_index, expr in enumerate(exprs):
            setter = get_setter(expr)
            assert setter is not None
            # Swap `expr` for `expr ** sentinel` just long enough to recompile.
            # noinspection PyArgumentList
            replacement = ast.BinOp(
                left=expr,
                op=ast.Pow(),
                right=ast.Str(s=sentinel),
            )
            ast.fix_missing_locations(replacement)
            setter(replacement)
            try:
                instructions = self.compile_instructions()
            finally:
                setter(expr)

            if sys.version_info >= (3, 10):
                try:
                    handle_jumps(instructions, original_instructions)
                except Exception:
                    # Give other candidates a chance
                    if TESTING or expr_index < last_candidate_index:
                        continue
                    raise

            sentinel_positions = [
                i
                for i, inst in enumerate(instructions)
                if inst.argval == sentinel
            ]

            # There can be several positions when the bytecode is duplicated,
            # as happens in a finally block in 3.9+.
            # First strip the LOAD_CONST/BINARY_POWER pair each one introduced.
            # Every earlier removal shifts later positions left by two.
            for removed_pairs, position in enumerate(sentinel_positions):
                shifted = position - 2 * removed_pairs
                assert_(instructions.pop(shifted).opname == 'LOAD_CONST')
                assert_(instructions.pop(shifted).opname == 'BINARY_POWER')

            # Then see whether any wrapped expression ends where the frame is.
            for removed_pairs, position in enumerate(sentinel_positions):
                new_index = position - 2 * removed_pairs - 1
                if new_index != original_index:
                    continue

                original_inst = original_instructions[original_index]
                new_inst = instructions[new_index]

                # In Python 3.9+, changing 'not x in y' to 'not sentinel_transformation(x in y)'
                # changes a CONTAINS_OP(invert=1) to CONTAINS_OP(invert=0),<sentinel stuff>,UNARY_NOT
                if (
                        original_inst.opname == new_inst.opname in ('CONTAINS_OP', 'IS_OP')
                        and original_inst.arg != new_inst.arg  # type: ignore[attr-defined]
                        and (
                            original_instructions[original_index + 1].opname
                            != instructions[new_index + 1].opname == 'UNARY_NOT'
                        )
                ):
                    # Remove the difference for the upcoming assert
                    instructions.pop(new_index + 1)

                # Check that the modified instructions don't have anything unexpected
                # 3.10 is a bit too weird to assert this in all cases but things still work
                if sys.version_info < (3, 10):
                    for inst1, inst2 in zip_longest(original_instructions, instructions):
                        assert_(inst1 and inst2 and opnames_match(inst1, inst2))

                yield expr

    def compile_instructions(self):
        # type: () -> List[EnhancedInstruction]
        """Recompile `self.tree` and return the cleaned instructions of the counterpart of `self.code`."""
        recompiled_module = compile_similar_to(self.tree, self.code)
        counterpart = only(self.find_codes(recompiled_module))
        return self.clean_instructions(counterpart)

    def find_codes(self, root_code):
        # type: (types.CodeType) -> list
        """
        Return every code object in `root_code`, and recursively in its
        constants, that matches `self.code` on the attributes listed in `checks`.
        """
        checks = [
            attrgetter('co_firstlineno'),
            attrgetter('co_freevars'),
            attrgetter('co_cellvars'),
            lambda c: is_ipython_cell_code_name(c.co_name) or c.co_name,
        ]  # type: List[Callable]
        if not self.is_pytest:
            # Only compared for non-pytest code; presumably pytest's assertion
            # rewriting alters names/varnames of the original code object.
            checks += [
                attrgetter('co_names'),
                attrgetter('co_varnames'),
            ]

        def matches(candidate):
            # type: (types.CodeType) -> bool
            return all(check(candidate) == check(self.code) for check in checks)

        found = []
        if matches(root_code):
            found.append(root_code)

        def collect(code):
            # type: (types.CodeType) -> None
            for const in code.co_consts:
                if inspect.iscode(const):
                    if matches(const):
                        found.append(const)
                    collect(const)

        collect(root_code)
        return found

    def get_actual_current_instruction(self, lasti):
        # type: (int) -> EnhancedInstruction
        """
        Get the instruction at frame offset `lasti`, moving past any
        EXTENDED_ARG prefixes to the instruction they belong to.
        """
        # Raw instructions are needed here, not the cleaned ones, because
        # the offset may point at an EXTENDED_ARG.
        instructions = list(get_instructions(self.code))
        index = only(
            i
            for i, inst in enumerate(instructions)
            if inst.offset == lasti
        )
        while instructions[index].opname == "EXTENDED_ARG":
            index += 1
        return instructions[index]
796

797

3✔
798

3✔
799
def non_sentinel_instructions(instructions, start):
    # type: (List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction]]
    """
    Iterate over ``instructions`` from index ``start`` onwards, yielding
    ``(index, instruction)`` pairs while hiding the instructions introduced
    by the sentinel transformation: a ``LOAD_CONST`` whose argval is
    ``sentinel`` and the ``BINARY_POWER`` that follows it.
    Yielded indices are positions in the full ``instructions`` list.
    """
    # True right after the sentinel LOAD_CONST was skipped, meaning the next
    # instruction must be the matching BINARY_POWER and is skipped too.
    awaiting_power = False
    for position, instruction in islice(enumerate(instructions), start, None):
        if instruction.argval == sentinel:
            assert_(instruction.opname == "LOAD_CONST")
            awaiting_power = True
        elif awaiting_power:
            assert_(instruction.opname == "BINARY_POWER")
            awaiting_power = False
        else:
            yield position, instruction
816

817

818
def walk_both_instructions(original_instructions, original_start, instructions, start):
    # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Iterator[Tuple[int, EnhancedInstruction, int, EnhancedInstruction]]
    """
    Yields matching indices and instructions from the new and original instructions,
    leaving out changes made by the sentinel transformation.

    Yields ``(original_i, original_inst, new_i, new_inst)`` tuples, walking
    ``original_instructions`` from ``original_start`` and ``instructions``
    (minus sentinel instructions) from ``start`` in lockstep. Iteration stops
    as soon as either side is exhausted.
    """
    original_iter = islice(enumerate(original_instructions), original_start, None)
    new_iter = non_sentinel_instructions(instructions, start)
    # Set when the previous pair was a CONTAINS_OP/IS_OP whose arg differs
    # between the two streams; the new stream may then contain an extra
    # UNARY_NOT that has no counterpart in the original.
    inverted_comparison = False
    while True:
        try:
            original_i, original_inst = next(original_iter)
            new_i, new_inst = next(new_iter)
        except StopIteration:
            return
        if (
            inverted_comparison
            and original_inst.opname != new_inst.opname == "UNARY_NOT"
        ):
            # Skip the extra UNARY_NOT. The new stream may end right here, so
            # this next() must also be guarded: an unhandled StopIteration
            # inside a generator becomes RuntimeError (PEP 479).
            try:
                new_i, new_inst = next(new_iter)
            except StopIteration:
                return
        inverted_comparison = (
            original_inst.opname == new_inst.opname in ("CONTAINS_OP", "IS_OP")
            and original_inst.arg != new_inst.arg # type: ignore[attr-defined]
        )
        yield original_i, original_inst, new_i, new_inst
843

844

1✔
845
def handle_jumps(instructions, original_instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> None
    """
    Transform ``instructions`` in place until it looks more like
    ``original_instructions``.
    This is only needed in 3.10+ where optimisations lead to more drastic
    changes after the sentinel transformation.
    JUMP instructions that aren't also present in original_instructions are
    replaced with the sections they jump to, up to a raise or return.
    In some other cases, duplication found in ``original_instructions`` is
    replicated in ``instructions``.
    """
    while True:
        # Locate the first position where the two streams disagree.
        mismatch = next(
            (
                step
                for step in walk_both_instructions(
                    original_instructions, 0, instructions, 0
                )
                if not opnames_match(step[1], step[3])
            ),
            None,
        )
        if mismatch is None:
            # Everything lines up: no mismatched jumps remain.
            return
        original_i, original_inst, new_i, new_inst = mismatch

        if "JUMP" in new_inst.opname and "JUMP" not in original_inst.opname:
            # Find where the new instruction is jumping to, ignoring
            # instructions which have been copied in previous iterations.
            target = only(
                idx
                for idx, candidate in enumerate(instructions)
                if candidate.offset == new_inst.argval
                and not getattr(candidate, "_copied", False)
            )
            # Replace the jump instruction with the jumped-to section of
            # instructions. That section may also be deleted if it's not
            # similarly duplicated in original_instructions.
            replacement = handle_jump(
                original_instructions, original_i, instructions, target
            )
            assert replacement is not None
            instructions[new_i : new_i + 1] = replacement
        else:
            # Take the section of original_instructions from original_i up to
            # and including the first return/raise.
            exit_at = next(
                (
                    idx
                    for idx in range(original_i, len(original_instructions))
                    if original_instructions[idx].opname
                    in ("RETURN_VALUE", "RAISE_VARARGS")
                ),
                None,
            )
            if exit_at is None:
                # No return/raise - this is just a mismatch we can't handle
                raise AssertionError
            orig_section = original_instructions[original_i : exit_at + 1]
            instructions[new_i:new_i] = only(
                find_new_matching(orig_section, instructions)
            )
        # instructions has just been modified, so the walk is restarted from
        # the beginning on the next pass of the while loop.
899

1✔
900

1✔
901
def find_new_matching(orig_section, instructions):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> Iterator[List[EnhancedInstruction]]
    """
    Yields sections of `instructions` which match `orig_section`.
    The yielded sections include sentinel instructions, but these
    are ignored when checking for matches.

    Every start position is tried, including the last one at which a
    full-length match still fits. An empty ``orig_section`` yields nothing.
    """
    section_len = len(orig_section)
    if not section_len:
        # Nothing to compare against; unpacking an empty window below would
        # otherwise raise ValueError.
        return
    # "+ 1" so the final window, ending exactly at the last instruction, is
    # also checked.
    for start in range(len(instructions) - section_len + 1):
        window = list(
            islice(non_sentinel_instructions(instructions, start), section_len)
        )
        if len(window) < section_len:
            # Not enough non-sentinel instructions remain for a full match.
            return
        indices, dup_section = zip(*window)
        if sections_match(orig_section, dup_section):
            # Slice through the last matched index so that any sentinel
            # instructions inside the span are included.
            yield instructions[start : indices[-1] + 1]
1✔
919

920

921
def handle_jump(original_instructions, original_start, instructions, start):
    # type: (List[EnhancedInstruction], int, List[EnhancedInstruction], int) -> Optional[List[EnhancedInstruction]]
    """
    Return a deep copy of the section of ``instructions`` that begins at
    ``start`` and ends with a RETURN_VALUE or RAISE_VARARGS instruction.
    Each copied instruction is marked with ``_copied = True``.
    A matching section is expected in ``original_instructions`` beginning at
    ``original_start``. If that section is not duplicated elsewhere in
    ``original_instructions``, the section is also removed from
    ``instructions``. Returns None if the streams run out before a return/raise.
    """
    pairs = walk_both_instructions(
        original_instructions, original_start, instructions, start
    )
    for orig_end, orig_inst, new_end, new_inst in pairs:
        assert_(opnames_match(orig_inst, new_inst))
        if orig_inst.opname not in ("RETURN_VALUE", "RAISE_VARARGS"):
            continue
        copied_section = deepcopy(instructions[start : new_end + 1])
        for copied in copied_section:
            # Marked so later jump-target lookups can ignore these copies.
            copied._copied = True
        is_duplicated = check_duplicates(
            original_start,
            original_instructions[original_start : orig_end + 1],
            original_instructions,
        )
        if not is_duplicated:
            del instructions[start : new_end + 1]
        return copied_section
    return None
946

947

1✔
948
def check_duplicates(original_i, orig_section, original_instructions):
    # type: (int, List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
    """
    Report whether ``orig_section`` is duplicated in ``original_instructions``:
    True when some full-length slice starting at an index other than
    ``original_i`` matches it according to ``sections_match``.
    """
    width = len(orig_section)
    total = len(original_instructions)
    # Only starts at which a full-width slice fits are candidates (and never
    # beyond the last index of the list).
    candidate_starts = range(min(total, total - width + 1))
    return any(
        sections_match(orig_section, original_instructions[pos : pos + width])
        for pos in candidate_starts
        if pos != original_i
    )
964

965
def sections_match(orig_section, dup_section):
    # type: (List[EnhancedInstruction], List[EnhancedInstruction]) -> bool
    """
    Compare two instruction lists pairwise (up to the shorter length) and
    return True when every pair has equal linenos and matching opnames.
    """
    for left, right in zip(orig_section, dup_section):
        # POP_BLOCKs have been found to have differing linenos in innocent cases
        same_line = (
            left.lineno == right.lineno
            or "POP_BLOCK" == left.opname == right.opname
        )
        if not (same_line and opnames_match(left, right)):
            return False
    return True
1✔
979

980

981
def opnames_match(inst1, inst2):
    # type: (Instruction, Instruction) -> bool
    """
    Return True if two instructions should be treated as the same operation:
    identical opnames, both some kind of JUMP, or one of a fixed set of
    (inst1, inst2) opname pairs considered equivalent.
    """
    first = inst1.opname
    second = inst2.opname
    if first == second:
        return True
    if "JUMP" in first and "JUMP" in second:
        return True
    # Directional pairs: inst1's opname on the left, inst2's on the right.
    return (first, second) in {
        ("PRINT_EXPR", "POP_TOP"),
        ("LOAD_METHOD", "LOAD_ATTR"),
        ("LOOKUP_METHOD", "LOAD_ATTR"),
        ("CALL_METHOD", "CALL_FUNCTION"),
    }
1!
994

1✔
995

1✔
996
def get_setter(node):
    # type: (EnhancedAST) -> Optional[Callable[[ast.AST], None]]
    """
    Return a function that replaces ``node`` inside ``node.parent``.
    It handles both a direct field of the parent and an item inside a list
    field. Returns None if ``node`` is not found among the parent's fields.
    """
    owner = node.parent
    for attr, value in ast.iter_fields(owner):
        if value is node:
            # The closure is returned immediately, so binding attr is safe.
            return lambda replacement: setattr(owner, attr, replacement)
        if isinstance(value, list):
            for position, child in enumerate(value):
                if child is node:
                    def replace_item(replacement):
                        # type: (ast.AST) -> None
                        value[position] = replacement

                    return replace_item
    return None
1014

1015
# Module-level reentrant lock. RLock is imported outside this view (presumably
# threading.RLock). It is not used in this part of the file; NOTE(review):
# presumably it guards shared state used elsewhere in the module - confirm.
lock = RLock()
1016

1017

1018
@cache
def statement_containing_node(node):
    # type: (ast.AST) -> EnhancedAST
    """
    Return the nearest ``ast.stmt`` ancestor of ``node``, following its
    ``.parent`` links. If ``node`` is itself a statement, it is returned.
    Results are wrapped by the ``cache`` decorator defined elsewhere in the
    module (presumably memoising per node).
    """
    current = node
    while not isinstance(current, ast.stmt):
        current = cast(EnhancedAST, current).parent
    return cast(EnhancedAST, current)
1024

1025

1026
def assert_linenos(tree):
    # type: (ast.AST) -> Iterator[int]
    """
    Yield the line numbers (via node_linenos) of every node in ``tree`` that
    has a ``parent`` link and whose containing statement is an ``assert``.
    """
    for node in ast.walk(tree):
        if not hasattr(node, 'parent'):
            # Nodes without a parent link can't be traced to a statement.
            continue
        if isinstance(statement_containing_node(node), ast.Assert):
            yield from node_linenos(node)
1035

6✔
1036

1037
def _extract_ipython_statement(stmt):
3✔
1038
    # type: (EnhancedAST) -> ast.Module
3!
1039
    # IPython separates each statement in a cell to be executed separately
3✔
1040
    # So NodeFinder should only compile one statement at a time or it
3✔
1041
    # will find a code mismatch.
1042
    while not isinstance(stmt.parent, ast.Module):
3✔
1043
        stmt = stmt.parent
3✔
1044
    # use `ast.parse` instead of `ast.Module` for better portability
3✔
1045
    # python3.8 changes the signature of `ast.Module`
3✔
1046
    # Inspired by https://github.com/pallets/werkzeug/pull/1552/files
3✔
1047
    tree = ast.parse("")
3✔
1048
    tree.body = [cast(ast.stmt, stmt)]
1049
    ast.copy_location(tree, stmt)
3✔
1050
    return tree
1051

3✔
1052

×
1053
def is_ipython_cell_code_name(code_name):
    # type: (str) -> bool
    """Return True if ``code_name`` is ``<module>`` or ``<cell line: N>``."""
    found = re.match(r"(<module>|<cell line: \d+>)$", code_name)
    return found is not None
1056

1057

6✔
1058
def is_ipython_cell_filename(filename):
    # type: (str) -> bool
    """
    Return True if ``filename`` contains ``<ipython-input-`` or an
    ``ipykernel_<digits>`` directory component (either slash style).
    """
    found = re.search(r"<ipython-input-|[/\\]ipykernel_\d+[/\\]", filename)
    return found is not None
6✔
1061

6✔
1062

6✔
1063
def is_ipython_cell_code(code_obj):
    # type: (types.CodeType) -> bool
    """
    Return True when ``code_obj`` has an IPython cell filename and a
    cell-level code name.
    """
    if not is_ipython_cell_filename(code_obj.co_filename):
        return False
    return is_ipython_cell_code_name(code_obj.co_name)
3✔
1069

1070

1071
def find_node_ipython(frame, lasti, stmts, source):
    # type: (types.FrameType, int, Set[EnhancedAST], Source) -> Tuple[Optional[Any], Optional[Any]]
    """
    Locate the executing node for IPython code by trying each candidate
    statement on its own.

    Each statement in ``stmts`` is wrapped in its own module (see
    ``_extract_ipython_statement``) and given to NodeFinder. Returns
    ``(decorator, node)``; note the order. Both are None when no statement
    produces a result, or when results come from more than one statement
    (ambiguous).
    """
    node = decorator = None
    for stmt in stmts:
        tree = _extract_ipython_statement(stmt)
        try:
            node_finder = NodeFinder(frame, stmts, tree, lasti, source)
            found_node = node_finder.result
            found_decorator = node_finder.decorator
        except Exception:
            # Best effort: a statement NodeFinder can't handle is simply not
            # a candidate.
            continue

        if not (found_node or found_decorator):
            # Nothing found in this statement. Keep what an earlier statement
            # found rather than overwriting it with None; otherwise the result
            # would depend on the iteration order of the ``stmts`` set.
            continue

        if node or decorator:
            # Found potential nodes in separate statements,
            # cannot resolve ambiguity, give up here
            return None, None

        node = found_node
        decorator = found_decorator
    return decorator, node
6✔
1088

6✔
1089

6✔
1090

1091
def node_linenos(node):
    # type: (ast.AST) -> Iterator[int]
    """
    Yield the line numbers covered by ``node``.
    Expression nodes with ``end_lineno`` yield every line from ``lineno`` to
    ``end_lineno`` inclusive. Other nodes with a ``lineno`` yield just that
    line, and nodes without ``lineno`` yield nothing.
    """
    if not hasattr(node, "lineno"):
        return
    if hasattr(node, "end_lineno") and isinstance(node, ast.expr):
        assert node.end_lineno is not None # type: ignore[attr-defined]
        yield from range(node.lineno, node.end_lineno + 1) # type: ignore[attr-defined]
    else:
        yield node.lineno # type: ignore[attr-defined]
1102

6✔
1103

1104
# Choose the NodeFinder implementation by interpreter version.
# On Python 3.11+ the finder comes from the local _position_node_finder module.
# NOTE(review): presumably it relies on the instruction position info added in
# 3.11; confirm. Older versions use SentinelNodeFinder, defined elsewhere in
# this file.
if sys.version_info >= (3, 11):
    from ._position_node_finder import PositionNodeFinder as NodeFinder
else:
    NodeFinder = SentinelNodeFinder
1108

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc