• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pantsbuild / pants / 21803785359

08 Feb 2026 07:13PM UTC coverage: 43.3% (-37.0%) from 80.277%
21803785359

Pull #23085

github

web-flow
Merge 7c1cd926d into 40389cc58
Pull Request #23085: A helper method for indexing paths by source root

2 of 6 new or added lines in 1 file covered. (33.33%)

17114 existing lines in 539 files now uncovered.

26075 of 60219 relevant lines covered (43.3%)

0.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.67
/src/python/pants/engine/internals/rule_visitor.py
1
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
2
# Licensed under the Apache License, Version 2.0 (see LICENSE).
3
from __future__ import annotations
1✔
4

5
import ast
1✔
6
import inspect
1✔
7
import itertools
1✔
8
import logging
1✔
9
import sys
1✔
10
from collections.abc import Callable, Iterator, Sequence
1✔
11
from contextlib import contextmanager
1✔
12
from dataclasses import dataclass
1✔
13
from types import ModuleType
1✔
14
from typing import Any, get_type_hints
1✔
15

16
import typing_extensions
1✔
17

18
from pants.base.exceptions import RuleTypeError
1✔
19
from pants.engine.internals.selectors import (
1✔
20
    AwaitableConstraints,
21
    concurrently,
22
)
23
from pants.util.memo import memoized
1✔
24
from pants.util.strutil import softwrap
1✔
25
from pants.util.typing import patch_forward_ref
1✔
26

27
# Module-level logger; `_AwaitableCollector.visit_Assign` emits its debug diagnostics through it.
logger = logging.getLogger(__name__)
1✔
28
# Called once, at import time, purely for its side effect.
# NOTE(review): presumably adjusts `typing.ForwardRef` handling that the `get_type_hints`
# calls in this module rely on -- confirm against `pants.util.typing`.
patch_forward_ref()
1✔
29

30

31
def _get_starting_indent(source: str) -> int:
1✔
32
    """Used to remove leading indentation from `source` so ast.parse() doesn't raise an
33
    exception."""
34
    if source.startswith(" "):
1✔
35
        return sum(1 for _ in itertools.takewhile(lambda c: c in {" ", b" "}, source))
1✔
36
    return 0
1✔
37

38

39
def _node_str(node: Any) -> str:
1✔
40
    if isinstance(node, ast.Name):
×
41
        return node.id
×
42
    if isinstance(node, ast.Attribute):
×
43
        return ".".join([_node_str(node.value), node.attr])
×
44
    if isinstance(node, ast.Call):
×
45
        return _node_str(node.func)
×
46
    if sys.version_info[0:2] < (3, 8):
×
47
        if isinstance(node, ast.Str):
×
48
            return node.s
×
49
    else:
50
        if isinstance(node, ast.Constant):
×
51
            return str(node.value)
×
52
    return str(node)
×
53

54

55
# Attribute name under which `get_module_scope_rules` caches a module's rule descriptors
# on the module object itself.
PANTS_RULE_DESCRIPTORS_MODULE_KEY = "__pants_rule_descriptors__"
1✔
56

57

58
@dataclass(frozen=True)
class RuleDescriptor:
    """A summary of a rule, gleaned from its AST rather than from live objects.

    Descriptors are computed lazily, by the first `@rule` decorator evaluated in a module.
    At that point the module code is not yet fully evaluated, so the rule's return type may
    not exist yet as a parsed type. It is therefore stored by name (a str) and looked up
    later.
    """

    # Name of the module in which the rule is defined.
    module_name: str
    # Name of the rule function within that module.
    rule_name: str
    # Name of the rule's return type, kept as a string until it can be resolved.
    return_type: str

    @property
    def rule_id(self) -> str:
        """The rule's module-qualified name: `<module_name>.<rule_name>`."""
        # TODO: Handle canonical_name/canonical_name_suffix?
        return f"{self.module_name}.{self.rule_name}"
1✔
75

76

77
def get_module_scope_rules(module: ModuleType) -> tuple[RuleDescriptor, ...]:
    """Return descriptors for the @rules defined at the top level of `module`.

    The top-level rules and rule helpers are discovered by examining the module's AST,
    not its evaluated contents. So while the `@rule` decorator of a rule1() is executing,
    the descriptor of a rule2() defined later in the module is already known, which lets
    rule1() and rule2() be mutually recursive.

    Only top-level `async def`s whose return annotation is a plain name are described.
    The result is computed once per module and cached on the module object.

    Note that we don't support recursive rules defined dynamically in inner scopes.
    """
    cached = getattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, None)
    if cached is not None:
        return cached

    top_level_nodes = ast.iter_child_nodes(ast.parse(inspect.getsource(module)))
    found = tuple(
        RuleDescriptor(module.__name__, node.name, node.returns.id)
        for node in top_level_nodes
        if isinstance(node, ast.AsyncFunctionDef) and isinstance(node.returns, ast.Name)
    )
    # Stash the result on the module so later calls skip re-parsing its source.
    setattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, found)
    return found
1✔
97

98

99
class _TypeStack:
1✔
100
    """The types and rules that a @rule can refer to in its input/outputs, or its awaitables.
101

102
    We construct this data through a mix of inspection of types already parsed by Python,
103
    and descriptors we infer from the AST. This allows us to support mutual recursion between
104
    rules defined in the same module (the @rule descriptor of the earlier rule can know enough
105
    about the later rule it calls to set up its own awaitables correctly).
106

107
    This logic is necessarily heuristic. It works for well-behaved code, but may be defeated
108
    by metaprogramming, aliasing, shadowing and so on.
109
    """
110

111
    def __init__(self, func: Callable) -> None:
1✔
112
        self._stack: list[dict[str, Any]] = []
1✔
113
        self.root = sys.modules[func.__module__]
1✔
114

115
        # We fall back to descriptors last, so that we get parsed objects whenever possible,
116
        # as those are less susceptible to limitations of the heuristics.
117
        self.push({descr.rule_name: descr for descr in get_module_scope_rules(self.root)})
1✔
118
        self.push(self.root)
1✔
119
        self._push_function_closures(func)
1✔
120
        # Rule args will be pushed later, as we handle them.
121

122
    def __getitem__(self, name: str) -> Any:
1✔
123
        for ns in reversed(self._stack):
1✔
124
            if name in ns:
1✔
125
                return ns[name]
1✔
126
        return self.root.__builtins__.get(name, None)
1✔
127

128
    def __setitem__(self, name: str, value: Any) -> None:
1✔
129
        self._stack[-1][name] = value
1✔
130

131
    def _push_function_closures(self, func: Callable) -> None:
1✔
132
        try:
1✔
133
            closurevars = [c for c in inspect.getclosurevars(func) if isinstance(c, dict)]
1✔
134
        except ValueError:
1✔
135
            return
1✔
136

137
        for closures in closurevars:
1✔
138
            self.push(closures)
1✔
139

140
    def push(self, frame: object) -> None:
1✔
141
        ns = dict(frame if isinstance(frame, dict) else frame.__dict__)
1✔
142
        self._stack.append(ns)
1✔
143

144
    def pop(self) -> None:
1✔
145
        assert len(self._stack) > 1
1✔
146
        self._stack.pop()
1✔
147

148

149
def _lookup_annotation(obj: Any, attr: str) -> Any:
1✔
150
    """Get type associated with a particular attribute on object. This can get hairy, especially on
151
    Python <3.10.
152

153
    https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
154
    """
155
    if hasattr(obj, attr):
1✔
156
        return getattr(obj, attr)
1✔
157
    else:
158
        try:
1✔
159
            return get_type_hints(obj).get(attr)
1✔
160
        except (NameError, TypeError):
1✔
161
            return None
1✔
162

163

164
def _lookup_return_type(func: Callable, check: bool = False) -> Any:
    """Look up what `func` returns, based on its `return` annotation.

    When the annotation is a parameterized `list`, `set` or `tuple` (or a subclass), the
    result is a tuple of its type arguments. Otherwise the annotation itself is returned,
    which is None when no annotation could be found.

    :raises TypeError: if `check` is true and no return annotation could be found.
    """
    ret = _lookup_annotation(func, "return")
    origin = typing_extensions.get_origin(ret)
    if isinstance(origin, type) and issubclass(origin, (list, set, tuple)):
        return tuple(typing_extensions.get_args(ret))

    if ret is None and check:
        func_file = inspect.getsourcefile(func)
        func_line = func.__code__.co_firstlineno
        raise TypeError(
            f"Failed to look up return type hint for `{func.__name__}` in {func_file}:{func_line}"
        )
    return ret
1✔
178

179

180
class _AwaitableCollector(ast.NodeVisitor):
1✔
181
    def __init__(self, func: Callable):
1✔
182
        self.func = func
1✔
183
        source = inspect.getsource(func) or "<string>"
1✔
184
        beginning_indent = _get_starting_indent(source)
1✔
185
        if beginning_indent:
1✔
186
            source = "\n".join(line[beginning_indent:] for line in source.split("\n"))
1✔
187

188
        self.source_file = inspect.getsourcefile(func) or "<unknown>"
1✔
189

190
        self.types = _TypeStack(func)
1✔
191
        self.awaitables: list[AwaitableConstraints] = []
1✔
192
        self.visit(ast.parse(source))
1✔
193

194
    def _format(self, node: ast.AST, msg: str) -> str:
1✔
UNCOV
195
        lineno: str = "<unknown>"
×
UNCOV
196
        if isinstance(node, (ast.expr, ast.stmt)):
×
UNCOV
197
            lineno = str(node.lineno + self.func.__code__.co_firstlineno - 1)
×
UNCOV
198
        return f"{self.source_file}:{lineno}: {msg}"
×
199

200
    def _lookup(self, attr: ast.expr) -> Any:
1✔
201
        names = []
1✔
202
        while isinstance(attr, ast.Attribute):
1✔
203
            names.append(attr.attr)
1✔
204
            attr = attr.value
1✔
205
        # NB: attr could be a constant, like `",".join()`
206
        id = getattr(attr, "id", None)
1✔
207
        if id is not None:
1✔
208
            names.append(id)
1✔
209

210
        if not names:
1✔
211
            return attr
1✔
212

213
        name = names.pop()
1✔
214
        result = self.types[name]
1✔
215
        while result is not None and names:
1✔
216
            result = _lookup_annotation(result, names.pop())
1✔
217
        return result
1✔
218

219
    def _missing_type_error(self, node: ast.AST, context: str) -> str:
1✔
220
        mod = self.types.root.__name__
×
221
        return self._format(
×
222
            node,
223
            softwrap(
224
                f"""
225
                Could not resolve type for `{_node_str(node)}` in module {mod}.
226

227
                {context}
228
                """
229
            ),
230
        )
231

232
    def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type:
1✔
233
        if resolved is None:
1✔
234
            raise RuleTypeError(
×
235
                self._missing_type_error(
236
                    node, context="This may be a limitation of the Pants rule type inference."
237
                )
238
            )
239
        elif not isinstance(resolved, type):
1✔
240
            raise RuleTypeError(
×
241
                self._format(
242
                    node,
243
                    f"Expected a type, but got: {type(resolved).__name__} {_node_str(resolved)!r}",
244
                )
245
            )
246
        return resolved
1✔
247

248
    def _get_inputs(self, input_nodes: Sequence[Any]) -> tuple[Sequence[Any], list[Any]]:
1✔
249
        if not input_nodes:
1✔
250
            return input_nodes, []
1✔
251
        if len(input_nodes) != 1:
1✔
252
            return input_nodes, [self._lookup(input_nodes[0])]
×
253

254
        input_constructor = input_nodes[0]
1✔
255
        if isinstance(input_constructor, ast.Call):
1✔
256
            cls_or_func = self._lookup(input_constructor.func)
1✔
257
            try:
1✔
258
                type_ = (
1✔
259
                    _lookup_return_type(cls_or_func, check=True)
260
                    if not isinstance(cls_or_func, type)
261
                    else cls_or_func
262
                )
263
            except TypeError as e:
×
264
                raise RuleTypeError(self._missing_type_error(input_constructor, str(e))) from e
×
265
            return [input_constructor.func], [type_]
1✔
266
        elif isinstance(input_constructor, ast.Dict):
1✔
267
            return input_constructor.values, [self._lookup(v) for v in input_constructor.values]
1✔
268
        else:
269
            return input_nodes, [self._lookup(n) for n in input_nodes]
1✔
270

271
    def _get_byname_awaitable(
1✔
272
        self, rule_id: str, rule_func: Callable | RuleDescriptor, call_node: ast.Call
273
    ) -> AwaitableConstraints:
274
        if isinstance(rule_func, RuleDescriptor):
1✔
275
            # At this point we expect the return type to be defined, so its source code
276
            # must precede that of the rule invoking the awaitable that returns it.
277
            output_type = self.types[rule_func.return_type]
1✔
278
        else:
279
            output_type = _lookup_return_type(rule_func, check=True)
1✔
280

281
        # To support explicit positional arguments, we record the number passed positionally.
282
        # TODO: To support keyword arguments, we would additionally need to begin recording the
283
        # argument names of kwargs. But positional-only callsites can avoid those allocations.
284
        explicit_args_arity = len(call_node.args)
1✔
285

286
        input_types: tuple[type, ...]
287
        if not call_node.keywords:
1✔
288
            input_types = ()
1✔
289
        elif (
1✔
290
            len(call_node.keywords) == 1
291
            and not call_node.keywords[0].arg
292
            and isinstance(implicitly_call := call_node.keywords[0].value, ast.Call)
293
            and self._lookup(implicitly_call.func).__name__ == "implicitly"
294
        ):
295
            input_nodes, input_type_nodes = self._get_inputs(implicitly_call.args)
1✔
296
            input_types = tuple(
1✔
297
                self._check_constraint_arg_type(input_type, input_node)
298
                for input_type, input_node in zip(input_type_nodes, input_nodes)
299
            )
300
        else:
301
            explanation = self._format(
×
302
                call_node,
303
                "Expected an `**implicitly(..)` application as the only keyword input.",
304
            )
305
            raise ValueError(
×
306
                f"Invalid call. {explanation} failed in a call to {rule_id} in {self.source_file}."
307
            )
308

309
        return AwaitableConstraints(
1✔
310
            rule_id,
311
            output_type,
312
            explicit_args_arity,
313
            input_types,
314
        )
315

316
    def visit_Call(self, call_node: ast.Call) -> None:
1✔
317
        func = self._lookup(call_node.func)
1✔
318
        if func is not None:
1✔
319
            if (inspect.isfunction(func) or isinstance(func, RuleDescriptor)) and (
1✔
320
                rule_id := getattr(func, "rule_id", None)
321
            ) is not None:
322
                # Is a direct `@rule` call.
323
                self.awaitables.append(self._get_byname_awaitable(rule_id, func, call_node))
1✔
324
            elif inspect.iscoroutinefunction(func):
1✔
325
                # Is a call to a "rule helper".
326
                self.awaitables.extend(collect_awaitables(func))
1✔
327

328
        self.generic_visit(call_node)
1✔
329

330
    def visit_AsyncFunctionDef(self, rule: ast.AsyncFunctionDef) -> None:
1✔
331
        with self._visit_rule_args(rule.args):
1✔
332
            self.generic_visit(rule)
1✔
333

334
    def visit_FunctionDef(self, rule: ast.FunctionDef) -> None:
1✔
335
        with self._visit_rule_args(rule.args):
1✔
336
            self.generic_visit(rule)
1✔
337

338
    @contextmanager
1✔
339
    def _visit_rule_args(self, node: ast.arguments) -> Iterator[None]:
1✔
340
        self.types.push(
1✔
341
            {
342
                a.arg: self.types[a.annotation.id]
343
                for a in node.args
344
                if isinstance(a.annotation, ast.Name)
345
            }
346
        )
347
        try:
1✔
348
            yield
1✔
349
        finally:
350
            self.types.pop()
1✔
351

352
    def visit_Assign(self, assign_node: ast.Assign) -> None:
1✔
353
        awaitables_idx = len(self.awaitables)
1✔
354
        self.generic_visit(assign_node)
1✔
355
        collected_awaitables = self.awaitables[awaitables_idx:]
1✔
356
        value = None
1✔
357
        node: ast.AST = assign_node
1✔
358
        while True:
1✔
359
            if isinstance(node, (ast.Assign, ast.Await)):
1✔
360
                node = node.value
1✔
361
                continue
1✔
362
            if isinstance(node, ast.Call):
1✔
363
                f = self._lookup(node.func)
1✔
364
                if f is concurrently:
1✔
365
                    value = tuple(get.output_type for get in collected_awaitables)
1✔
366
                elif f is not None:
1✔
367
                    value = _lookup_return_type(f)
1✔
368
            elif isinstance(node, (ast.Name, ast.Attribute)):
1✔
369
                value = self._lookup(node)
1✔
370
            break
1✔
371

372
        for tgt in assign_node.targets:
1✔
373
            if isinstance(tgt, ast.Name):
1✔
374
                names = [tgt.id]
1✔
375
                values = [value]
1✔
376
            elif isinstance(tgt, ast.Tuple):
1✔
377
                names = [el.id for el in tgt.elts if isinstance(el, ast.Name)]
1✔
378
                values = value or itertools.cycle([None])  # type: ignore[assignment]
1✔
379
            else:
380
                # subscript, etc..
381
                continue
1✔
382
            try:
1✔
383
                for name, value in zip(names, values):
1✔
384
                    self.types[name] = value
1✔
UNCOV
385
            except TypeError as e:
×
UNCOV
386
                logger.debug(
×
387
                    self._format(
388
                        node,
389
                        softwrap(
390
                            f"""
391
                            Rule visitor failed to inspect assignment expression for
392
                            {names} - {values}:
393

394
                            {e}
395
                            """
396
                        ),
397
                    )
398
                )
399

400

401
@memoized
def collect_awaitables(func: Callable) -> list[AwaitableConstraints]:
    """Return the awaitables that `_AwaitableCollector` finds in the source of `func`.

    NOTE(review): the `@memoized` decorator presumably caches the result per `func` --
    confirm against `pants.util.memo`.
    """
    collector = _AwaitableCollector(func)
    return collector.awaitables
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc