• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pantsbuild / pants / 25441711719

06 May 2026 02:31PM UTC coverage: 92.915%. Remained the same
25441711719

push

github

web-flow
use sha pin (with comment) format for generated actions (#23312)

Per the GitHub Action best practices we recently enabled at #23249, we
should pin each action to a SHA so that the reference is actually
immutable.

This will -- I hope -- knock out a large chunk of the 421 alerts we
currently get from zizmor. The next followup would then be upgrades and
harmonizing the generated and non-generated pins.

Notice: This idea was suggested by Claude while going over pinact output
and I was surprised to see that post-processing the YAML wasn't too
gross.

92206 of 99237 relevant lines covered (92.91%)

4.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.38
/src/python/pants/engine/internals/rule_visitor.py
1
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
2
# Licensed under the Apache License, Version 2.0 (see LICENSE).
3
from __future__ import annotations
12✔
4

5
import ast
12✔
6
import inspect
12✔
7
import itertools
12✔
8
import logging
12✔
9
import sys
12✔
10
from collections.abc import Callable, Iterator, Sequence
12✔
11
from contextlib import contextmanager
12✔
12
from dataclasses import dataclass
12✔
13
from types import ModuleType
12✔
14
from typing import Any, get_type_hints
12✔
15

16
import typing_extensions
12✔
17

18
from pants.base.exceptions import RuleTypeError
12✔
19
from pants.engine.internals.native_engine import RuleCallTrampoline
12✔
20
from pants.engine.internals.selectors import (
12✔
21
    AwaitableConstraints,
22
    concurrently,
23
)
24
from pants.util.memo import memoized
12✔
25
from pants.util.strutil import softwrap
12✔
26

27
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
12✔
28

29

30
def _get_starting_indent(source: str) -> int:
12✔
31
    """Used to remove leading indentation from `source` so ast.parse() doesn't raise an
32
    exception."""
33
    if source.startswith(" "):
12✔
34
        return sum(1 for _ in itertools.takewhile(lambda c: c in {" ", b" "}, source))
12✔
35
    return 0
12✔
36

37

38
def _node_str(node: Any) -> str:
12✔
39
    if isinstance(node, ast.Name):
×
40
        return node.id
×
41
    if isinstance(node, ast.Attribute):
×
42
        return ".".join([_node_str(node.value), node.attr])
×
43
    if isinstance(node, ast.Call):
×
44
        return _node_str(node.func)
×
45
    if sys.version_info[0:2] < (3, 8):
×
46
        if isinstance(node, ast.Str):
×
47
            return node.s
×
48
    else:
49
        if isinstance(node, ast.Constant):
×
50
            return str(node.value)
×
51
    return str(node)
×
52

53

54
# Name of the module attribute under which the rule descriptors parsed from that module's
# source are cached (read and written by `get_module_scope_rules`).
PANTS_RULE_DESCRIPTORS_MODULE_KEY = "__pants_rule_descriptors__"
12✔
55

56

57
@dataclass(frozen=True)
class RuleDescriptor:
    """What we learn about a rule purely from its AST.

    Descriptors are built lazily, from the first `@rule` decorator evaluated in a module. At
    that point the module body has not finished executing, so the rule's return type may not
    exist yet as a real object. It is therefore kept as the annotation's name (a str) and
    resolved later.
    """

    # `__name__` of the module that defines the rule.
    module_name: str
    # Name of the rule function.
    rule_name: str
    # Name of the return annotation, as written in the source.
    return_type: str

    @property
    def rule_id(self) -> str:
        # TODO: Handle canonical_name/canonical_name_suffix?
        return "{}.{}".format(self.module_name, self.rule_name)
12✔
74

75

76
def get_module_scope_rules(module: ModuleType) -> tuple[RuleDescriptor, ...]:
    """Return descriptors for the top-level `async def`s of `module`, found via its AST.

    Because the module's source is parsed rather than executed, the descriptor of a rule2()
    defined later in the module is already known while the `@rule` decorator of an earlier
    rule1() runs. This allows rule1() and rule2() to be mutually recursive.

    Only top-level `async def` functions whose return annotation is a plain name are
    recorded. The result is cached on the module object, so the source is parsed at most
    once per module.

    Note that we don't support recursive rules defined dynamically in inner scopes.
    """
    cached = getattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, None)
    if cached is not None:
        return cached

    module_ast = ast.parse(inspect.getsource(module))
    found = tuple(
        RuleDescriptor(module.__name__, child.name, child.returns.id)
        for child in ast.iter_child_nodes(module_ast)
        if isinstance(child, ast.AsyncFunctionDef) and isinstance(child.returns, ast.Name)
    )
    # Cache on the module itself so later lookups skip the parse.
    setattr(module, PANTS_RULE_DESCRIPTORS_MODULE_KEY, found)
    return found
12✔
96

97

98
class _TypeStack:
12✔
99
    """The types and rules that a @rule can refer to in its input/outputs, or its awaitables.
100

101
    We construct this data through a mix of inspection of types already parsed by Python,
102
    and descriptors we infer from the AST. This allows us to support mutual recursion between
103
    rules defined in the same module (the @rule descriptor of the earlier rule can know enough
104
    about the later rule it calls to set up its own awaitables correctly).
105

106
    This logic is necessarily heuristic. It works for well-behaved code, but may be defeated
107
    by metaprogramming, aliasing, shadowing and so on.
108
    """
109

110
    def __init__(self, func: Callable) -> None:
12✔
111
        self._stack: list[dict[str, Any]] = []
12✔
112
        self.root = sys.modules[func.__module__]
12✔
113

114
        # We fall back to descriptors last, so that we get parsed objects whenever possible,
115
        # as those are less susceptible to limitations of the heuristics.
116
        self.push({descr.rule_name: descr for descr in get_module_scope_rules(self.root)})
12✔
117
        self.push(self.root)
12✔
118
        self._push_function_closures(func)
12✔
119
        # Rule args will be pushed later, as we handle them.
120

121
    def __getitem__(self, name: str) -> Any:
12✔
122
        for ns in reversed(self._stack):
12✔
123
            if name in ns:
12✔
124
                return ns[name]
12✔
125
        return self.root.__builtins__.get(name, None)
12✔
126

127
    def __setitem__(self, name: str, value: Any) -> None:
12✔
128
        self._stack[-1][name] = value
12✔
129

130
    def _push_function_closures(self, func: Callable) -> None:
12✔
131
        try:
12✔
132
            closurevars = [c for c in inspect.getclosurevars(func) if isinstance(c, dict)]
12✔
133
        except ValueError:
12✔
134
            return
12✔
135

136
        for closures in closurevars:
12✔
137
            self.push(closures)
12✔
138

139
    def push(self, frame: object) -> None:
12✔
140
        ns = dict(frame if isinstance(frame, dict) else frame.__dict__)
12✔
141
        self._stack.append(ns)
12✔
142

143
    def pop(self) -> None:
12✔
144
        assert len(self._stack) > 1
12✔
145
        self._stack.pop()
12✔
146

147

148
def _lookup_annotation(obj: Any, attr: str) -> Any:
12✔
149
    """Get type associated with a particular attribute on object. This can get hairy, especially on
150
    Python <3.10.
151

152
    https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
153
    """
154
    if hasattr(obj, attr):
12✔
155
        return getattr(obj, attr)
12✔
156
    else:
157
        try:
12✔
158
            return get_type_hints(obj).get(attr)
12✔
159
        except (NameError, TypeError):
12✔
160
            return None
12✔
161

162

163
def _lookup_return_type(func: Callable, check: bool = False) -> Any:
    """Look up what `func` is annotated to return.

    If the annotation is a parameterised list, set or tuple (or a subclass of one), the tuple
    of its type arguments is returned instead of the annotation itself. Otherwise the
    annotation is returned as-is, which may be None when nothing could be looked up.

    :param func: The callable whose "return" annotation is inspected.
    :param check: When true, raise instead of returning None.
    :raises TypeError: If `check` is true and no return annotation could be found.
    """
    annotation = _lookup_annotation(func, "return")
    origin = typing_extensions.get_origin(annotation)
    if isinstance(origin, type) and issubclass(origin, (list, set, tuple)):
        return tuple(typing_extensions.get_args(annotation))

    if annotation is not None or not check:
        return annotation

    source_path = inspect.getsourcefile(func)
    first_line = func.__code__.co_firstlineno
    raise TypeError(
        f"Failed to look up return type hint for `{func.__name__}` in {source_path}:{first_line}"
    )
12✔
177

178

179
class _AwaitableCollector(ast.NodeVisitor):
12✔
180
    def __init__(self, func: Callable):
12✔
181
        # `func` may be a RuleCallTrampoline (the return value of an `@rule`-decorated
182
        # function). `inspect.getsource` and friends only know about real Python functions,
183
        # so follow `__wrapped__` to reach the underlying implementation.
184
        if isinstance(func, RuleCallTrampoline):
12✔
185
            func = func.__wrapped__
1✔
186
        self.func = func
12✔
187
        source = inspect.getsource(func) or "<string>"
12✔
188
        beginning_indent = _get_starting_indent(source)
12✔
189
        if beginning_indent:
12✔
190
            source = "\n".join(line[beginning_indent:] for line in source.split("\n"))
12✔
191

192
        self.source_file = inspect.getsourcefile(func) or "<unknown>"
12✔
193

194
        self.types = _TypeStack(func)
12✔
195
        self.awaitables: list[AwaitableConstraints] = []
12✔
196
        self.visit(ast.parse(source))
12✔
197

198
    def _format(self, node: ast.AST, msg: str) -> str:
12✔
199
        lineno: str = "<unknown>"
5✔
200
        if isinstance(node, (ast.expr, ast.stmt)):
5✔
201
            lineno = str(node.lineno + self.func.__code__.co_firstlineno - 1)
5✔
202
        return f"{self.source_file}:{lineno}: {msg}"
5✔
203

204
    def _lookup(self, attr: ast.expr) -> Any:
12✔
205
        names = []
12✔
206
        while isinstance(attr, ast.Attribute):
12✔
207
            names.append(attr.attr)
12✔
208
            attr = attr.value
12✔
209
        # NB: attr could be a constant, like `",".join()`
210
        id = getattr(attr, "id", None)
12✔
211
        if id is not None:
12✔
212
            names.append(id)
12✔
213

214
        if not names:
12✔
215
            return attr
12✔
216

217
        name = names.pop()
12✔
218
        result = self.types[name]
12✔
219
        while result is not None and names:
12✔
220
            result = _lookup_annotation(result, names.pop())
12✔
221
        return result
12✔
222

223
    def _missing_type_error(self, node: ast.AST, context: str) -> str:
12✔
224
        mod = self.types.root.__name__
×
225
        return self._format(
×
226
            node,
227
            softwrap(
228
                f"""
229
                Could not resolve type for `{_node_str(node)}` in module {mod}.
230

231
                {context}
232
                """
233
            ),
234
        )
235

236
    def _check_constraint_arg_type(self, resolved: Any, node: ast.AST) -> type:
12✔
237
        if resolved is None:
12✔
238
            raise RuleTypeError(
×
239
                self._missing_type_error(
240
                    node, context="This may be a limitation of the Pants rule type inference."
241
                )
242
            )
243
        elif not isinstance(resolved, type):
12✔
244
            raise RuleTypeError(
×
245
                self._format(
246
                    node,
247
                    f"Expected a type, but got: {type(resolved).__name__} {_node_str(resolved)!r}",
248
                )
249
            )
250
        return resolved
12✔
251

252
    def _get_inputs(self, input_nodes: Sequence[Any]) -> tuple[Sequence[Any], list[Any]]:
12✔
253
        if not input_nodes:
12✔
254
            return input_nodes, []
12✔
255
        if len(input_nodes) != 1:
12✔
256
            return input_nodes, [self._lookup(input_nodes[0])]
×
257

258
        input_constructor = input_nodes[0]
12✔
259
        if isinstance(input_constructor, ast.Call):
12✔
260
            cls_or_func = self._lookup(input_constructor.func)
12✔
261
            try:
12✔
262
                type_ = (
12✔
263
                    _lookup_return_type(cls_or_func, check=True)
264
                    if not isinstance(cls_or_func, type)
265
                    else cls_or_func
266
                )
267
            except TypeError as e:
×
268
                raise RuleTypeError(self._missing_type_error(input_constructor, str(e))) from e
×
269
            return [input_constructor.func], [type_]
12✔
270
        elif isinstance(input_constructor, ast.Dict):
12✔
271
            return input_constructor.values, [self._lookup(v) for v in input_constructor.values]
12✔
272
        else:
273
            return input_nodes, [self._lookup(n) for n in input_nodes]
12✔
274

275
    def _get_byname_awaitable(
12✔
276
        self, rule_id: str, rule_func: Callable | RuleDescriptor, call_node: ast.Call
277
    ) -> AwaitableConstraints:
278
        if isinstance(rule_func, RuleDescriptor):
12✔
279
            # At this point we expect the return type to be defined, so its source code
280
            # must precede that of the rule invoking the awaitable that returns it.
281
            output_type = self.types[rule_func.return_type]
12✔
282
        else:
283
            output_type = _lookup_return_type(rule_func, check=True)
12✔
284

285
        # To support explicit positional arguments, we record the number passed positionally.
286
        # TODO: To support keyword arguments, we would additionally need to begin recording the
287
        # argument names of kwargs. But positional-only callsites can avoid those allocations.
288
        explicit_args_arity = len(call_node.args)
12✔
289

290
        input_types: tuple[type, ...]
291
        if not call_node.keywords:
12✔
292
            input_types = ()
12✔
293
        elif (
12✔
294
            len(call_node.keywords) == 1
295
            and not call_node.keywords[0].arg
296
            and isinstance(implicitly_call := call_node.keywords[0].value, ast.Call)
297
            and self._lookup(implicitly_call.func).__name__ == "implicitly"
298
        ):
299
            input_nodes, input_type_nodes = self._get_inputs(implicitly_call.args)
12✔
300
            input_types = tuple(
12✔
301
                self._check_constraint_arg_type(input_type, input_node)
302
                for input_type, input_node in zip(input_type_nodes, input_nodes)
303
            )
304
        else:
305
            explanation = self._format(
×
306
                call_node,
307
                "Expected an `**implicitly(..)` application as the only keyword input.",
308
            )
309
            raise ValueError(
×
310
                f"Invalid call. {explanation} failed in a call to {rule_id} in {self.source_file}."
311
            )
312

313
        return AwaitableConstraints(
12✔
314
            rule_id,
315
            output_type,
316
            explicit_args_arity,
317
            input_types,
318
        )
319

320
    def visit_Call(self, call_node: ast.Call) -> None:
12✔
321
        func = self._lookup(call_node.func)
12✔
322
        if func is not None:
12✔
323
            if (
12✔
324
                inspect.isfunction(func) or isinstance(func, (RuleDescriptor, RuleCallTrampoline))
325
            ) and (rule_id := getattr(func, "rule_id", None)) is not None:
326
                # Is a direct `@rule` call.
327
                self.awaitables.append(self._get_byname_awaitable(rule_id, func, call_node))
12✔
328
            elif inspect.iscoroutinefunction(func):
12✔
329
                # Is a call to a "rule helper".
330
                self.awaitables.extend(collect_awaitables(func))
12✔
331

332
        self.generic_visit(call_node)
12✔
333

334
    def visit_AsyncFunctionDef(self, rule: ast.AsyncFunctionDef) -> None:
12✔
335
        with self._visit_rule_args(rule.args):
12✔
336
            self.generic_visit(rule)
12✔
337

338
    def visit_FunctionDef(self, rule: ast.FunctionDef) -> None:
12✔
339
        with self._visit_rule_args(rule.args):
12✔
340
            self.generic_visit(rule)
12✔
341

342
    @contextmanager
12✔
343
    def _visit_rule_args(self, node: ast.arguments) -> Iterator[None]:
12✔
344
        self.types.push(
12✔
345
            {
346
                a.arg: self.types[a.annotation.id]
347
                for a in node.args
348
                if isinstance(a.annotation, ast.Name)
349
            }
350
        )
351
        try:
12✔
352
            yield
12✔
353
        finally:
354
            self.types.pop()
12✔
355

356
    def visit_Assign(self, assign_node: ast.Assign) -> None:
12✔
357
        awaitables_idx = len(self.awaitables)
12✔
358
        self.generic_visit(assign_node)
12✔
359
        collected_awaitables = self.awaitables[awaitables_idx:]
12✔
360
        value = None
12✔
361
        node: ast.AST = assign_node
12✔
362
        while True:
12✔
363
            if isinstance(node, (ast.Assign, ast.Await)):
12✔
364
                node = node.value
12✔
365
                continue
12✔
366
            if isinstance(node, ast.Call):
12✔
367
                f = self._lookup(node.func)
12✔
368
                if f is concurrently:
12✔
369
                    value = tuple(get.output_type for get in collected_awaitables)
12✔
370
                elif f is not None:
12✔
371
                    value = _lookup_return_type(f)
12✔
372
            elif isinstance(node, (ast.Name, ast.Attribute)):
12✔
373
                value = self._lookup(node)
12✔
374
            break
12✔
375

376
        for tgt in assign_node.targets:
12✔
377
            if isinstance(tgt, ast.Name):
12✔
378
                names = [tgt.id]
12✔
379
                values = [value]
12✔
380
            elif isinstance(tgt, ast.Tuple):
12✔
381
                names = [el.id for el in tgt.elts if isinstance(el, ast.Name)]
12✔
382
                values = value or itertools.cycle([None])  # type: ignore[assignment]
12✔
383
            else:
384
                # subscript, etc..
385
                continue
12✔
386
            try:
12✔
387
                for name, value in zip(names, values):
12✔
388
                    self.types[name] = value
12✔
389
            except TypeError as e:
5✔
390
                logger.debug(
5✔
391
                    self._format(
392
                        node,
393
                        softwrap(
394
                            f"""
395
                            Rule visitor failed to inspect assignment expression for
396
                            {names} - {values}:
397

398
                            {e}
399
                            """
400
                        ),
401
                    )
402
                )
403

404

405
@memoized
def collect_awaitables(func: Callable) -> list[AwaitableConstraints]:
    """Return the awaitables that `_AwaitableCollector` finds in the source of `func`.

    NOTE(review): `@memoized` presumably caches the result per `func` -- confirm against
    `pants.util.memo`.
    """
    collector = _AwaitableCollector(func)
    return collector.awaitables
12✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc