• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pantsbuild / pants / 22740642519

05 Mar 2026 11:00PM UTC coverage: 52.677% (-40.3%) from 92.931%
22740642519

Pull #23157

github

web-flow
Merge 2aa18e6d4 into f0030f5e7
Pull Request #23157: [pants ng] Partition source files by config.

31678 of 60136 relevant lines covered (52.68%)

0.53 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.55
/src/python/pants/engine/internals/rule_visitor.py
1
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
2
# Licensed under the Apache License, Version 2.0 (see LICENSE).
3
from __future__ import annotations
1✔
4

5
import ast
1✔
6
import inspect
1✔
7
import itertools
1✔
8
import logging
1✔
9
import sys
1✔
10
from collections.abc import Callable, Iterator, Sequence
1✔
11
from contextlib import contextmanager
1✔
12
from dataclasses import dataclass
1✔
13
from types import ModuleType
1✔
14
from typing import Any, get_type_hints
1✔
15

16
import typing_extensions
1✔
17

18
from pants.base.exceptions import RuleTypeError
1✔
19
from pants.engine.internals.selectors import (
1✔
20
    AwaitableConstraints,
21
    concurrently,
22
)
23
from pants.util.memo import memoized
1✔
24
from pants.util.strutil import softwrap
1✔
25

26
# Module-level logger. Used to report, at debug level, assignments that the rule visitor failed to
# inspect.
logger = logging.getLogger(__name__)
1✔
27

28

29
def _get_starting_indent(source: str) -> int:
    """Return the number of leading space characters in `source`.

    Used to remove leading indentation from `source` (e.g. the source of a method or a nested
    function) so that `ast.parse()` doesn't raise an `IndentationError`.

    Only literal spaces are counted: a tab, or any other character, ends the indentation run.
    """
    # `str.lstrip(" ")` strips spaces only, so the length difference is exactly the run of leading
    # spaces. (Comparing characters against `b" "` can never match a `str` character, and it
    # triggers a `BytesWarning` under `python -b`.)
    return len(source) - len(source.lstrip(" "))
35

36

37
def _node_str(node: Any) -> str:
1✔
38
    if isinstance(node, ast.Name):
×
39
        return node.id
×
40
    if isinstance(node, ast.Attribute):
×
41
        return ".".join([_node_str(node.value), node.attr])
×
42
    if isinstance(node, ast.Call):
×
43
        return _node_str(node.func)
×
44
    if sys.version_info[0:2] < (3, 8):
×
45
        if isinstance(node, ast.Str):
×
46
            return node.s
×
47
    else:
48
        if isinstance(node, ast.Constant):
×
49
            return str(node.value)
×
50
    return str(node)
×
51

52

53
# Name of the module attribute under which `get_module_scope_rules()` caches the tuple of
# `RuleDescriptor`s it parsed from that module's source.
PANTS_RULE_DESCRIPTORS_MODULE_KEY = "__pants_rule_descriptors__"
1✔
54

55

56
@dataclass(frozen=True)
class RuleDescriptor:
    """What we can learn about a `@rule` purely from its AST, without evaluating it.

    Descriptors are computed lazily by the first `@rule` decorator that runs in a module, i.e.
    before that module has finished executing. The rule's return type may therefore not be bound
    to a Python object yet, so it is recorded by name, as a `str`, and resolved later.
    """

    module_name: str
    rule_name: str
    # The bare name of the rule's return annotation, to be looked up once the module is loaded.
    return_type: str

    @property
    def rule_id(self) -> str:
        """The rule's dotted identifier, `<module_name>.<rule_name>`."""
        # TODO: Handle canonical_name/canonical_name_suffix?
        return ".".join((self.module_name, self.rule_name))
73

74

75
def get_module_scope_rules(module: ModuleType) -> tuple[RuleDescriptor, ...]:
    """Return descriptors for the `async def` functions defined at the top level of `module`.

    The descriptors come from parsing the module's source rather than from evaluating it. So while
    the `@rule` decorator of some rule1() is running, the descriptor of a rule2() defined further
    down the module is already available, which lets rule1() and rule2() be mutually recursive.

    Only top-level `async def`s whose return annotation is a bare name are described. The result
    is cached on the module object itself, so each module's source is parsed at most once.

    Note that we don't support recursive rules defined dynamically in inner scopes.
    """
    cached = getattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, None)
    if cached is not None:
        return cached

    module_ast = ast.parse(inspect.getsource(module))
    discovered = tuple(
        RuleDescriptor(module.__name__, node.name, node.returns.id)
        for node in ast.iter_child_nodes(module_ast)
        if isinstance(node, ast.AsyncFunctionDef) and isinstance(node.returns, ast.Name)
    )
    setattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, discovered)
    return discovered
95

96

97
class _TypeStack:
    """The types and rules that a @rule can refer to in its input/outputs, or its awaitables.

    We construct this data through a mix of inspection of types already parsed by Python,
    and descriptors we infer from the AST. This allows us to support mutual recursion between
    rules defined in the same module (the @rule descriptor of the earlier rule can know enough
    about the later rule it calls to set up its own awaitables correctly).

    Lookups consult the pushed namespaces from most to least recently pushed. Names found in none
    of them fall back to the root module's builtins, and finally to None.

    This logic is necessarily heuristic. It works for well-behaved code, but may be defeated
    by metaprogramming, aliasing, shadowing and so on.
    """

    def __init__(self, func: Callable) -> None:
        self._stack: list[dict[str, Any]] = []
        self.root = sys.modules[func.__module__]

        # We fall back to descriptors last, so that we get parsed objects whenever possible,
        # as those are less susceptible to limitations of the heuristics.
        self.push({descr.rule_name: descr for descr in get_module_scope_rules(self.root)})
        self.push(self.root)
        self._push_function_closures(func)
        # Rule args will be pushed later, as we handle them.

    def __getitem__(self, name: str) -> Any:
        """Resolve `name` in the innermost namespace defining it, else in builtins, else None."""
        for ns in reversed(self._stack):
            if name in ns:
                return ns[name]
        return self._builtins_namespace().get(name, None)

    def __setitem__(self, name: str, value: Any) -> None:
        """Bind `name` in the innermost (most recently pushed) namespace."""
        self._stack[-1][name] = value

    def _builtins_namespace(self) -> dict[str, Any]:
        """Return the builtins visible to the root module, as a dict.

        A module's `__builtins__` is usually the builtins dict. However, CPython binds the
        `builtins` module itself in `__main__`, and a bare `ModuleType` may have none at all, so
        handle both of those forms too.
        """
        module_builtins = getattr(self.root, "__builtins__", None)
        if module_builtins is None:
            module_builtins = sys.modules["builtins"]
        if isinstance(module_builtins, dict):
            return module_builtins
        return vars(module_builtins)

    def _push_function_closures(self, func: Callable) -> None:
        """Push the names referenced by `func` that `inspect.getclosurevars` can resolve.

        The builtins, globals and nonlocals namespaces are pushed from outermost to innermost scope,
        so lookups follow Python's LEGB order: a closed-over nonlocal shadows a builtin of the same
        name, rather than the other way around.
        """
        try:
            closure_vars = inspect.getclosurevars(func)
        except ValueError:
            # Raised when reading the contents of a closure cell that is still empty.
            return

        for ns in (closure_vars.builtins, closure_vars.globals, closure_vars.nonlocals):
            self.push(ns)

    def push(self, frame: object) -> None:
        """Push a copy of `frame`: a dict, or any object with a `__dict__` (e.g. a module)."""
        ns = dict(frame if isinstance(frame, dict) else frame.__dict__)
        self._stack.append(ns)

    def pop(self) -> None:
        """Pop the innermost namespace. The bottom-most namespace is never popped."""
        assert len(self._stack) > 1
        self._stack.pop()
145

146

147
def _lookup_annotation(obj: Any, attr: str) -> Any:
    """Resolve `attr` on `obj`, falling back to `obj`'s type hints.

    A real attribute always wins. Otherwise the evaluated type hint for `attr` is returned, or
    None when there is no such hint or the hints cannot be evaluated. This can get hairy,
    especially on Python <3.10.

    https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
    """
    if not hasattr(obj, attr):
        try:
            hints = get_type_hints(obj)
        except (NameError, TypeError):
            # Unresolvable forward references, or objects that cannot carry type hints.
            return None
        return hints.get(attr)
    return getattr(obj, attr)
160

161

162
def _lookup_return_type(func: Callable, check: bool = False) -> Any:
    """Return the return-type annotation of `func`, resolved where possible.

    A `list[...]`, `set[...]` or `tuple[...]` return type (or one whose origin subclasses those) is
    flattened to the tuple of its type arguments, e.g. `tuple[A, B]` -> `(A, B)`. Otherwise the
    annotation is returned as-is, or None if it could not be found.

    :raises TypeError: if `check` is set and no return annotation could be found.
    """
    annotation = _lookup_annotation(func, "return")
    origin = typing_extensions.get_origin(annotation)
    if isinstance(origin, type) and issubclass(origin, (list, set, tuple)):
        return tuple(typing_extensions.get_args(annotation))
    if annotation is None and check:
        source_file = inspect.getsourcefile(func)
        first_line = func.__code__.co_firstlineno
        raise TypeError(
            f"Failed to look up return type hint for `{func.__name__}` in {source_file}:{first_line}"
        )
    return annotation
176

177

178
class _AwaitableCollector(ast.NodeVisitor):
    """Collect the awaitables (calls to `@rule`s) made by a `@rule` or rule helper.

    The function's source is parsed and visited once, in the constructor, and the results are left
    in `self.awaitables`. While walking the AST, `self.types` (a `_TypeStack`) tracks what names are
    bound to. Parameter annotations are pushed when entering a function definition, and the
    inferred types of assignment targets are recorded as assignments are visited.
    """

    def __init__(self, func: Callable):
        self.func = func
        source = inspect.getsource(func) or "<string>"
        beginning_indent = _get_starting_indent(source)
        if beginning_indent:
            # Strip the indentation of methods / nested functions so that `ast.parse()` accepts them.
            source = "\n".join(line[beginning_indent:] for line in source.split("\n"))

        self.source_file = inspect.getsourcefile(func) or "<unknown>"

        self.types = _TypeStack(func)
        self.awaitables: list[AwaitableConstraints] = []
        self.visit(ast.parse(source))

    def _format(self, node: ast.AST, msg: str) -> str:
        """Prefix `msg` with `<source file>:<line>` for `node` (line is `<unknown>` if unavailable)."""
        lineno: str = "<unknown>"
        if isinstance(node, (ast.expr, ast.stmt)):
            # `node.lineno` is relative to the parsed function source, which is assumed to start at
            # line `co_firstlineno` of the file.
            lineno = str(node.lineno + self.func.__code__.co_firstlineno - 1)
        return f"{self.source_file}:{lineno}: {msg}"

    def _lookup(self, attr: ast.expr) -> Any:
        """Resolve a name or dotted-attribute expression to a Python object, or None.

        The leftmost name is looked up in `self.types`, then each following attribute is resolved
        in turn with `_lookup_annotation`. An expression that yields no names at all (e.g. a bare
        constant or a call) is returned unchanged, as the AST node itself.
        """
        names = []
        while isinstance(attr, ast.Attribute):
            names.append(attr.attr)
            attr = attr.value
        # NB: attr could be a constant, like `",".join()`. Then only the attribute names are
        # collected, and the first of them (`join`) is looked up as though it were a bare name.
        # NOTE(review): that bare-name lookup is probably unintended; confirm before relying on it.
        base_name = getattr(attr, "id", None)
        if base_name is not None:
            names.append(base_name)

        if not names:
            return attr

        result = self.types[names.pop()]
        while result is not None and names:
            result = _lookup_annotation(result, names.pop())
        return result

    def _missing_type_error(self, node: ast.AST, context: str) -> str:
        """Build a located message saying `node`'s type could not be resolved, plus `context`."""
        mod = self.types.root.__name__
        return self._format(
            node,
            softwrap(
                f"""
                Could not resolve type for `{_node_str(node)}` in module {mod}.

                {context}
                """
            ),
        )

    def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type:
        """Return `resolved` if it is a type; otherwise raise `RuleTypeError` located at `node`."""
        if resolved is None:
            raise RuleTypeError(
                self._missing_type_error(
                    node, context="This may be a limitation of the Pants rule type inference."
                )
            )
        elif not isinstance(resolved, type):
            raise RuleTypeError(
                self._format(
                    node,
                    f"Expected a type, but got: {type(resolved).__name__} {_node_str(resolved)!r}",
                )
            )
        return resolved

    def _get_inputs(self, input_nodes: Sequence[Any]) -> tuple[Sequence[Any], list[Any]]:
        """Resolve the arguments of an `implicitly(...)` call to the types they provide.

        Returns `(nodes, types)`, whose items are paired positionally by the caller. The nodes are
        used to locate error messages.
        - `implicitly(SomeType(...))` or `implicitly(helper(...))`: the constructed class, or the
          callee's return type.
        - `implicitly({value: Type, ...})`: whatever each dict value resolves to.
        - `implicitly(value)`: whatever `value` resolves to in the type stack.

        :raises RuleTypeError: if a called helper's return type cannot be found.
        """
        if not input_nodes:
            return input_nodes, []
        if len(input_nodes) != 1:
            # NOTE(review): only the first of several positional inputs is resolved, so the others
            # are dropped when the caller zips types with nodes. Confirm whether that is intended.
            return input_nodes, [self._lookup(input_nodes[0])]

        input_constructor = input_nodes[0]
        if isinstance(input_constructor, ast.Call):
            cls_or_func = self._lookup(input_constructor.func)
            try:
                type_ = (
                    _lookup_return_type(cls_or_func, check=True)
                    if not isinstance(cls_or_func, type)
                    else cls_or_func
                )
            except TypeError as e:
                raise RuleTypeError(self._missing_type_error(input_constructor, str(e))) from e
            return [input_constructor.func], [type_]
        elif isinstance(input_constructor, ast.Dict):
            return input_constructor.values, [self._lookup(v) for v in input_constructor.values]
        else:
            return input_nodes, [self._lookup(n) for n in input_nodes]

    def _get_byname_awaitable(
        self, rule_id: str, rule_func: Callable | RuleDescriptor, call_node: ast.Call
    ) -> AwaitableConstraints:
        """Build the `AwaitableConstraints` for a direct call to the rule `rule_id`.

        Records the rule's output type, the number of positional arguments passed, and the input
        types provided through an `**implicitly(...)` keyword argument, if any.

        :raises ValueError: if the call has keyword arguments other than a single
            `**implicitly(...)`.
        :raises RuleTypeError: if an `implicitly` input does not resolve to a type.
        """
        if isinstance(rule_func, RuleDescriptor):
            # At this point we expect the return type to be defined, so its source code
            # must precede that of the rule invoking the awaitable that returns it.
            output_type = self.types[rule_func.return_type]
        else:
            output_type = _lookup_return_type(rule_func, check=True)

        # To support explicit positional arguments, we record the number passed positionally.
        # TODO: To support keyword arguments, we would additionally need to begin recording the
        # argument names of kwargs. But positional-only callsites can avoid those allocations.
        explicit_args_arity = len(call_node.args)

        input_types: tuple[type, ...]
        if not call_node.keywords:
            input_types = ()
        elif (
            len(call_node.keywords) == 1
            and not call_node.keywords[0].arg
            and isinstance(implicitly_call := call_node.keywords[0].value, ast.Call)
            # NB: `getattr` rather than `.__name__`, so that an unresolvable callee produces the
            # descriptive error below instead of an `AttributeError` on `None`.
            and getattr(self._lookup(implicitly_call.func), "__name__", None) == "implicitly"
        ):
            input_nodes, input_type_nodes = self._get_inputs(implicitly_call.args)
            input_types = tuple(
                self._check_constraint_arg_type(input_type, input_node)
                for input_type, input_node in zip(input_type_nodes, input_nodes)
            )
        else:
            explanation = self._format(
                call_node,
                "Expected an `**implicitly(..)` application as the only keyword input.",
            )
            raise ValueError(
                f"Invalid call. {explanation} failed in a call to {rule_id} in {self.source_file}."
            )

        return AwaitableConstraints(
            rule_id,
            output_type,
            explicit_args_arity,
            input_types,
        )

    def visit_Call(self, call_node: ast.Call) -> None:
        """Record direct `@rule` calls, and inline the awaitables of called rule helpers."""
        func = self._lookup(call_node.func)
        if func is not None:
            if (inspect.isfunction(func) or isinstance(func, RuleDescriptor)) and (
                rule_id := getattr(func, "rule_id", None)
            ) is not None:
                # Is a direct `@rule` call.
                self.awaitables.append(self._get_byname_awaitable(rule_id, func, call_node))
            elif inspect.iscoroutinefunction(func):
                # Is a call to a "rule helper".
                # NOTE(review): a helper that calls itself (directly or mutually) re-enters
                # `collect_awaitables` before its result is memoized, which looks like unbounded
                # recursion. Confirm whether recursive helpers need guarding.
                self.awaitables.extend(collect_awaitables(func))

        self.generic_visit(call_node)

    def visit_AsyncFunctionDef(self, rule: ast.AsyncFunctionDef) -> None:
        """Visit an `async def` body with its parameter types in scope."""
        with self._visit_rule_args(rule.args):
            self.generic_visit(rule)

    def visit_FunctionDef(self, rule: ast.FunctionDef) -> None:
        """Visit a `def` body with its parameter types in scope."""
        with self._visit_rule_args(rule.args):
            self.generic_visit(rule)

    @contextmanager
    def _visit_rule_args(self, node: ast.arguments) -> Iterator[None]:
        """Scope a function's parameter types to the body being visited.

        Each parameter annotated with a bare name (`x: Foo`) is bound to whatever `Foo` resolves
        to. Parameters with other annotations, or none, are not bound. `*args`/`**kwargs` are not
        bound either.
        """
        self.types.push(
            {
                a.arg: self.types[a.annotation.id]
                # Positional-only and keyword-only parameters are in scope in the body too.
                for a in itertools.chain(node.posonlyargs, node.args, node.kwonlyargs)
                if isinstance(a.annotation, ast.Name)
            }
        )
        try:
            yield
        finally:
            self.types.pop()

    def visit_Assign(self, assign_node: ast.Assign) -> None:
        """Visit an assignment, then record the inferred type(s) of its target name(s).

        The type is inferred from the (possibly awaited) right-hand side:
        - `concurrently(...)`: a tuple of the output types of the awaitables collected from it.
        - Any other resolvable call: the callee's return type (see `_lookup_return_type`).
        - A name or attribute: whatever it resolves to.
        Anything else leaves the targets bound to None.
        """
        awaitables_idx = len(self.awaitables)
        self.generic_visit(assign_node)
        collected_awaitables = self.awaitables[awaitables_idx:]
        value = None
        node: ast.AST = assign_node
        while True:
            if isinstance(node, (ast.Assign, ast.Await)):
                node = node.value
                continue
            if isinstance(node, ast.Call):
                f = self._lookup(node.func)
                if f is concurrently:
                    value = tuple(get.output_type for get in collected_awaitables)
                elif f is not None:
                    value = _lookup_return_type(f)
            elif isinstance(node, (ast.Name, ast.Attribute)):
                value = self._lookup(node)
            break

        names: list[str | None]
        for tgt in assign_node.targets:
            if isinstance(tgt, ast.Name):
                names = [tgt.id]
                values = [value]
            elif isinstance(tgt, ast.Tuple):
                # Keep positional alignment with `values`: non-Name elements (attributes,
                # subscripts, ...) still consume a value, they just don't bind a name.
                names = [el.id if isinstance(el, ast.Name) else None for el in tgt.elts]
                values = value or itertools.cycle([None])  # type: ignore[assignment]
            else:
                # subscript, etc..
                continue
            try:
                # NB: Use a distinct loop variable. Rebinding `value` here would clobber the
                # inferred value seen by any subsequent targets, e.g. in `a, b = c = ...`.
                for name, name_value in zip(names, values):
                    if name is not None:
                        self.types[name] = name_value
            except TypeError as e:
                # E.g. the inferred value is not iterable, so it can't be unpacked into a tuple.
                logger.debug(
                    self._format(
                        assign_node,
                        softwrap(
                            f"""
                            Rule visitor failed to inspect assignment expression for
                            {names} - {values}:

                            {e}
                            """
                        ),
                    )
                )
397

398

399
@memoized
def collect_awaitables(func: Callable) -> list[AwaitableConstraints]:
    """Return the awaitables that `func` (a `@rule` or rule helper) may call.

    The `@memoized` decorator caches the result per `func`, so each function's source is parsed
    and visited only once.
    """
    collector = _AwaitableCollector(func)
    return collector.awaitables
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc