• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pantsbuild / pants / 26260209689

21 May 2026 11:59PM UTC coverage: 75.453% (-15.7%) from 91.156%
26260209689

Pull #23365

github

web-flow
Merge 5fe873b58 into 7ea655ba0
Pull Request #23365: uv.lock -> pex optimization

5 of 16 new or added lines in 1 file covered. (31.25%)

10118 existing lines in 378 files now uncovered.

54669 of 72454 relevant lines covered (75.45%)

2.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.94
/src/python/pants/engine/internals/rule_visitor.py
1
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
2
# Licensed under the Apache License, Version 2.0 (see LICENSE).
3
from __future__ import annotations
5✔
4

5
import ast
5✔
6
import inspect
5✔
7
import itertools
5✔
8
import logging
5✔
9
import sys
5✔
10
from collections.abc import Callable, Iterator, Sequence
5✔
11
from contextlib import contextmanager
5✔
12
from dataclasses import dataclass
5✔
13
from types import ModuleType
5✔
14
from typing import Any, get_type_hints
5✔
15

16
import typing_extensions
5✔
17

18
from pants.base.exceptions import RuleTypeError
5✔
19
from pants.engine.internals.native_engine import RuleCallTrampoline
5✔
20
from pants.engine.internals.selectors import (
5✔
21
    AwaitableConstraints,
22
    concurrently,
23
)
24
from pants.util.memo import memoized
5✔
25
from pants.util.strutil import softwrap
5✔
26

27
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
5✔
28

29

30
def _get_starting_indent(source: str) -> int:
5✔
31
    """Used to remove leading indentation from `source` so ast.parse() doesn't raise an
32
    exception."""
33
    if source.startswith(" "):
5✔
34
        return sum(1 for _ in itertools.takewhile(lambda c: c in {" ", b" "}, source))
5✔
35
    return 0
5✔
36

37

38
def _node_str(node: Any) -> str:
5✔
39
    if isinstance(node, ast.Name):
×
40
        return node.id
×
41
    if isinstance(node, ast.Attribute):
×
42
        return ".".join([_node_str(node.value), node.attr])
×
43
    if isinstance(node, ast.Call):
×
44
        return _node_str(node.func)
×
45
    if sys.version_info[0:2] < (3, 8):
×
46
        if isinstance(node, ast.Str):
×
47
            return node.s
×
48
    else:
49
        if isinstance(node, ast.Constant):
×
50
            return str(node.value)
×
51
    return str(node)
×
52

53

54
# Attribute name under which a module's tuple of rule descriptors is cached on the module object.
PANTS_RULE_DESCRIPTORS_MODULE_KEY = "__pants_rule_descriptors__"
5✔
55

56

57
@dataclass(frozen=True)
class RuleDescriptor:
    """What can be learned about a rule from its AST alone.

    Descriptors are built lazily, by the first `@rule` decorator evaluated in a module. At that
    point the module body has only been partially executed, so the rule's return type may not
    exist yet as a real object. The return type is therefore kept as a plain string and resolved
    to a type later.
    """

    module_name: str
    rule_name: str
    return_type: str

    @property
    def rule_id(self) -> str:
        """The dotted `<module_name>.<rule_name>` identifier of the rule."""
        # TODO: Handle canonical_name/canonical_name_suffix?
        module, rule = self.module_name, self.rule_name
        return f"{module}.{rule}"
5✔
74

75

76
def get_module_scope_rules(module: ModuleType) -> tuple[RuleDescriptor, ...]:
    """Return descriptors for the @rules defined at the top level of `module`.

    Top-level rules and rule helpers are discovered by parsing the module's source, not by
    looking at already-evaluated objects. As a result, while the `@rule` decorator of rule1()
    is executing, the descriptor of a rule2() defined further down the module is already
    available, which is what lets rule1() and rule2() be mutually recursive.

    Only top-level `async def` functions whose return annotation is a bare name are recorded.
    The resulting tuple is cached on the module object, so the source is parsed at most once
    per module.

    Note that recursive rules defined dynamically in inner scopes are not supported.
    """
    cached = getattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, None)
    if cached is not None:
        return cached

    tree = ast.parse(inspect.getsource(module))
    found = tuple(
        RuleDescriptor(module.__name__, child.name, child.returns.id)
        for child in ast.iter_child_nodes(tree)
        if isinstance(child, ast.AsyncFunctionDef) and isinstance(child.returns, ast.Name)
    )
    setattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, found)
    return found
5✔
96

97

98
class _TypeStack:
5✔
99
    """The types and rules that a @rule can refer to in its input/outputs, or its awaitables.
100

101
    We construct this data through a mix of inspection of types already parsed by Python,
102
    and descriptors we infer from the AST. This allows us to support mutual recursion between
103
    rules defined in the same module (the @rule descriptor of the earlier rule can know enough
104
    about the later rule it calls to set up its own awaitables correctly).
105

106
    This logic is necessarily heuristic. It works for well-behaved code, but may be defeated
107
    by metaprogramming, aliasing, shadowing and so on.
108
    """
109

110
    def __init__(self, func: Callable) -> None:
5✔
111
        self._stack: list[dict[str, Any]] = []
5✔
112
        self.root = sys.modules[func.__module__]
5✔
113

114
        # We fall back to descriptors last, so that we get parsed objects whenever possible,
115
        # as those are less susceptible to limitations of the heuristics.
116
        self.push({descr.rule_name: descr for descr in get_module_scope_rules(self.root)})
5✔
117
        self.push(self.root)
5✔
118
        self._push_function_closures(func)
5✔
119
        # Rule args will be pushed later, as we handle them.
120

121
    def __getitem__(self, name: str) -> Any:
5✔
122
        for ns in reversed(self._stack):
5✔
123
            if name in ns:
5✔
124
                return ns[name]
5✔
125
        return self.root.__builtins__.get(name, None)
5✔
126

127
    def __setitem__(self, name: str, value: Any) -> None:
5✔
128
        self._stack[-1][name] = value
5✔
129

130
    def _push_function_closures(self, func: Callable) -> None:
5✔
131
        try:
5✔
132
            closurevars = [c for c in inspect.getclosurevars(func) if isinstance(c, dict)]
5✔
133
        except ValueError:
5✔
134
            return
5✔
135

136
        for closures in closurevars:
5✔
137
            self.push(closures)
5✔
138

139
    def push(self, frame: object) -> None:
5✔
140
        ns = dict(frame if isinstance(frame, dict) else frame.__dict__)
5✔
141
        self._stack.append(ns)
5✔
142

143
    def pop(self) -> None:
5✔
144
        assert len(self._stack) > 1
5✔
145
        self._stack.pop()
5✔
146

147

148
def _lookup_annotation(obj: Any, attr: str) -> Any:
5✔
149
    """Get type associated with a particular attribute on object. This can get hairy, especially on
150
    Python <3.10.
151

152
    https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
153
    """
154
    if hasattr(obj, attr):
5✔
155
        return getattr(obj, attr)
5✔
156
    else:
157
        try:
5✔
158
            return get_type_hints(obj).get(attr)
5✔
159
        except (NameError, TypeError):
5✔
160
            return None
5✔
161

162

163
def _lookup_return_type(func: Callable, check: bool = False) -> Any:
    """Return what `func` is declared to return.

    If the return annotation is a parameterized generic whose origin is a subclass of
    `list`, `set` or `tuple`, its type arguments are returned as a tuple. Otherwise the
    annotation itself is returned (which may be `None` if it could not be looked up).

    :param check: When true, raise instead of returning `None` for a missing annotation.
    :raises TypeError: If `check` is set and no return annotation could be looked up.
    """
    ret = _lookup_annotation(func, "return")
    origin = typing_extensions.get_origin(ret)
    if isinstance(origin, type) and issubclass(origin, (list, set, tuple)):
        return tuple(typing_extensions.get_args(ret))

    if ret is not None or not check:
        return ret

    func_file = inspect.getsourcefile(func)
    func_line = func.__code__.co_firstlineno
    raise TypeError(
        f"Failed to look up return type hint for `{func.__name__}` in {func_file}:{func_line}"
    )
5✔
177

178

179
class _AwaitableCollector(ast.NodeVisitor):
    """Collects the awaitables used by a rule (or rule helper) by walking its AST.

    Constructing the collector parses the source of `func` and visits it immediately; the
    result is available afterwards in `self.awaitables`. While visiting, a `_TypeStack` is
    maintained so that names seen in calls and assignments can be resolved to types or rules.
    """

    def __init__(self, func: Callable):
        # `func` may be a RuleCallTrampoline (the return value of an `@rule`-decorated
        # function). `inspect.getsource` and friends only know about real Python functions,
        # so follow `__wrapped__` to reach the underlying implementation.
        if isinstance(func, RuleCallTrampoline):
            func = func.__wrapped__
        self.func = func
        source = inspect.getsource(func) or "<string>"
        # Dedent the source of nested/indented functions so that it parses on its own.
        beginning_indent = _get_starting_indent(source)
        if beginning_indent:
            source = "\n".join(line[beginning_indent:] for line in source.split("\n"))

        self.source_file = inspect.getsourcefile(func) or "<unknown>"

        self.types = _TypeStack(func)
        self.awaitables: list[AwaitableConstraints] = []
        self.visit(ast.parse(source))

    def _format(self, node: ast.AST, msg: str) -> str:
        """Prefix `msg` with `<source file>:<line>:`, using the file line of `node` if known."""
        lineno: str = "<unknown>"
        if isinstance(node, (ast.expr, ast.stmt)):
            # `node.lineno` is relative to the parsed function source; make it file-relative.
            lineno = str(node.lineno + self.func.__code__.co_firstlineno - 1)
        return f"{self.source_file}:{lineno}: {msg}"

    def _lookup(self, attr: ast.expr) -> Any:
        """Resolve a (possibly dotted) name expression against the type stack.

        Returns None if any component cannot be resolved. If the expression contains no
        names at all, the innermost node is returned unchanged.
        """
        names = []
        while isinstance(attr, ast.Attribute):
            names.append(attr.attr)
            attr = attr.value
        # NB: attr could be a constant, like `",".join()`
        root_name = getattr(attr, "id", None)
        if root_name is not None:
            names.append(root_name)

        if not names:
            return attr

        # `names` was collected right-to-left, so pop() walks the dotted path left-to-right.
        name = names.pop()
        result = self.types[name]
        while result is not None and names:
            result = _lookup_annotation(result, names.pop())
        return result

    def _missing_type_error(self, node: ast.AST, context: str) -> str:
        """Build the located "could not resolve type" message for `node`."""
        mod = self.types.root.__name__
        return self._format(
            node,
            softwrap(
                f"""
                Could not resolve type for `{_node_str(node)}` in module {mod}.

                {context}
                """
            ),
        )

    def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type:
        """Return `resolved` if it is a type.

        :raises RuleTypeError: If `resolved` is None (unresolvable) or is not a type.
        """
        if resolved is None:
            raise RuleTypeError(
                self._missing_type_error(
                    node, context="This may be a limitation of the Pants rule type inference."
                )
            )
        elif not isinstance(resolved, type):
            raise RuleTypeError(
                self._format(
                    node,
                    f"Expected a type, but got: {type(resolved).__name__} {_node_str(resolved)!r}",
                )
            )
        return resolved

    def _get_inputs(self, input_nodes: Sequence[Any]) -> tuple[Sequence[Any], list[Any]]:
        """Resolve the argument nodes of an `implicitly(..)` call.

        Returns a pair of (nodes to report errors against, resolved objects for those nodes).
        A single call argument resolves to the called class or to the callee's return type; a
        single dict argument resolves each of its values.

        :raises RuleTypeError: If the return type of a called function cannot be looked up.
        """
        if not input_nodes:
            return input_nodes, []
        if len(input_nodes) != 1:
            return input_nodes, [self._lookup(input_nodes[0])]

        input_constructor = input_nodes[0]
        if isinstance(input_constructor, ast.Call):
            cls_or_func = self._lookup(input_constructor.func)
            try:
                type_ = (
                    _lookup_return_type(cls_or_func, check=True)
                    if not isinstance(cls_or_func, type)
                    else cls_or_func
                )
            except TypeError as e:
                raise RuleTypeError(self._missing_type_error(input_constructor, str(e))) from e
            return [input_constructor.func], [type_]
        elif isinstance(input_constructor, ast.Dict):
            return input_constructor.values, [self._lookup(v) for v in input_constructor.values]
        else:
            return input_nodes, [self._lookup(n) for n in input_nodes]

    def _get_byname_awaitable(
        self, rule_id: str, rule_func: Callable | RuleDescriptor, call_node: ast.Call
    ) -> AwaitableConstraints:
        """Build the constraints for a direct call to the rule identified by `rule_id`.

        :raises ValueError: If the call has keyword inputs other than a single
            `**implicitly(..)` application.
        """
        if isinstance(rule_func, RuleDescriptor):
            # At this point we expect the return type to be defined, so its source code
            # must precede that of the rule invoking the awaitable that returns it.
            output_type = self.types[rule_func.return_type]
        else:
            output_type = _lookup_return_type(rule_func, check=True)

        # To support explicit positional arguments, we record the number passed positionally.
        # TODO: To support keyword arguments, we would additionally need to begin recording the
        # argument names of kwargs. But positional-only callsites can avoid those allocations.
        explicit_args_arity = len(call_node.args)

        input_types: tuple[type, ...]
        if not call_node.keywords:
            input_types = ()
        elif (
            len(call_node.keywords) == 1
            and not call_node.keywords[0].arg
            and isinstance(implicitly_call := call_node.keywords[0].value, ast.Call)
            # NB: The callee may be unresolvable (None) or lack a `__name__`; such calls must
            # fall through to the explanatory ValueError below rather than an AttributeError.
            and getattr(self._lookup(implicitly_call.func), "__name__", None) == "implicitly"
        ):
            input_nodes, input_type_nodes = self._get_inputs(implicitly_call.args)
            input_types = tuple(
                self._check_constraint_arg_type(input_type, input_node)
                for input_type, input_node in zip(input_type_nodes, input_nodes)
            )
        else:
            explanation = self._format(
                call_node,
                "Expected an `**implicitly(..)` application as the only keyword input.",
            )
            raise ValueError(
                f"Invalid call. {explanation} failed in a call to {rule_id} in {self.source_file}."
            )

        return AwaitableConstraints(
            rule_id,
            output_type,
            explicit_args_arity,
            input_types,
        )

    def visit_Call(self, call_node: ast.Call) -> None:
        """Record awaitables for direct rule calls and for calls to coroutine rule helpers."""
        func = self._lookup(call_node.func)
        if func is not None:
            if (
                inspect.isfunction(func) or isinstance(func, (RuleDescriptor, RuleCallTrampoline))
            ) and (rule_id := getattr(func, "rule_id", None)) is not None:
                # Is a direct `@rule` call.
                self.awaitables.append(self._get_byname_awaitable(rule_id, func, call_node))
            elif inspect.iscoroutinefunction(func):
                # Is a call to a "rule helper".
                self.awaitables.extend(collect_awaitables(func))

        self.generic_visit(call_node)

    def visit_AsyncFunctionDef(self, rule: ast.AsyncFunctionDef) -> None:
        """Visit an async function body with its annotated arguments in scope."""
        with self._visit_rule_args(rule.args):
            self.generic_visit(rule)

    def visit_FunctionDef(self, rule: ast.FunctionDef) -> None:
        """Visit a function body with its annotated arguments in scope."""
        with self._visit_rule_args(rule.args):
            self.generic_visit(rule)

    @contextmanager
    def _visit_rule_args(self, node: ast.arguments) -> Iterator[None]:
        """Push a namespace mapping each simply-annotated argument to its resolved type.

        Only arguments in `node.args` whose annotation is a bare name are included. The
        namespace is popped again when the context exits.
        """
        self.types.push(
            {
                a.arg: self.types[a.annotation.id]
                for a in node.args
                if isinstance(a.annotation, ast.Name)
            }
        )
        try:
            yield
        finally:
            self.types.pop()

    def visit_Assign(self, assign_node: ast.Assign) -> None:
        """Record the inferred type of each simple name bound by an assignment.

        The assigned expression is inspected (looking through `await`) to infer a value:
        the output types of the awaitables collected for a `concurrently(..)` call, the return
        type of any other resolvable call, or the resolved object of a name/attribute.
        Failures to bind names are logged at debug level rather than raised.
        """
        awaitables_idx = len(self.awaitables)
        self.generic_visit(assign_node)
        collected_awaitables = self.awaitables[awaitables_idx:]
        value = None
        node: ast.AST = assign_node
        while True:
            if isinstance(node, (ast.Assign, ast.Await)):
                node = node.value
                continue
            if isinstance(node, ast.Call):
                f = self._lookup(node.func)
                if f is concurrently:
                    value = tuple(get.output_type for get in collected_awaitables)
                elif f is not None:
                    value = _lookup_return_type(f)
            elif isinstance(node, (ast.Name, ast.Attribute)):
                value = self._lookup(node)
            break

        for tgt in assign_node.targets:
            names: list[str | None]
            if isinstance(tgt, ast.Name):
                names = [tgt.id]
                values = [value]
            elif isinstance(tgt, ast.Tuple):
                # Keep one entry per tuple element so that names stay positionally aligned
                # with `values`; elements that are not plain names are skipped below.
                names = [el.id if isinstance(el, ast.Name) else None for el in tgt.elts]
                values = value or itertools.cycle([None])  # type: ignore[assignment]
            else:
                # subscript, etc..
                continue
            try:
                # NB: Use a dedicated loop variable: rebinding `value` here would corrupt the
                # inferred value for any later targets of a chained assignment (`a = b = ...`).
                for name, name_value in zip(names, values):
                    if name is not None:
                        self.types[name] = name_value
            except TypeError as e:
                logger.debug(
                    self._format(
                        node,
                        softwrap(
                            f"""
                            Rule visitor failed to inspect assignment expression for
                            {names} - {values}:

                            {e}
                            """
                        ),
                    )
                )
403

404

405
@memoized
def collect_awaitables(func: Callable) -> list[AwaitableConstraints]:
    """Return the awaitables found by running an `_AwaitableCollector` over `func`.

    NOTE(review): the `@memoized` decorator presumably caches the result per `func` - confirm
    against `pants.util.memo`.
    """
    collector = _AwaitableCollector(func)
    return collector.awaitables
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc