• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pantsbuild / pants / 19381742489

15 Nov 2025 12:52AM UTC coverage: 49.706% (-30.6%) from 80.29%
19381742489

Pull #22890

github

web-flow
Merge d961abf79 into 42e1ebd41
Pull Request #22890: Updated all python subsystem constraints to 3.14

4 of 5 new or added lines in 5 files covered. (80.0%)

14659 existing lines in 485 files now uncovered.

31583 of 63540 relevant lines covered (49.71%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

28.41
/src/python/pants/goal/migrate_call_by_name.py
1
# Copyright 2024 Pants project contributors (see CONTRIBUTORS.md).
2
# Licensed under the Apache License, Version 2.0 (see LICENSE).
3

4
from __future__ import annotations
2✔
5

6
import importlib.util
2✔
7
import json
2✔
8
import logging
2✔
9
from collections.abc import Callable, Iterable
2✔
10
from dataclasses import dataclass
2✔
11
from functools import partial
2✔
12
from pathlib import Path, PurePath
2✔
13
from typing import TypedDict
2✔
14

15
import libcst as cst
2✔
16
import libcst.helpers as h
2✔
17
import libcst.matchers as m
2✔
18
import libcst.metadata
2✔
19
from libcst.display import dump
2✔
20

21
from pants.base.build_environment import get_buildroot
2✔
22
from pants.base.exiter import PANTS_SUCCEEDED_EXIT_CODE, ExitCode
2✔
23
from pants.base.specs import Specs
2✔
24
from pants.build_graph.build_configuration import BuildConfiguration
2✔
25
from pants.engine.fs import Paths
2✔
26
from pants.engine.unions import UnionMembership
2✔
27
from pants.goal.builtin_goal import BuiltinGoal
2✔
28
from pants.init.engine_initializer import GraphSession
2✔
29
from pants.option.option_types import BoolOption
2✔
30
from pants.option.options import Options
2✔
31
from pants.util import cstutil
2✔
32
from pants.util.strutil import softwrap
2✔
33

34
# Module-scoped logger for this file's diagnostics.
logger = logging.getLogger(name=__name__)
35

36

37
class MigrateCallByNameBuiltinGoal(BuiltinGoal):
    # Builtin goal that rewrites `Get`/`MultiGet` usages in the requested files into
    # call-by-name syntax, driven by a migration plan derived from the engine's rule graph.

    name = "migrate-call-by-name"
    help = softwrap(
        """
        Migrate from `Get` syntax to call-by-name syntax (#19730). This is a **destructive** operation,
        so only run this on source controlled files that you are prepared to revert if necessary.

        This goal will attempt to migrate the set of paths/targets specified at the command line
        if they are part of the "migration plan". This migration does not add any new files, but
        instead modifies existing files in-place without any formatting. The resulting changes should
        be reviewed, tested, and formatted/linted before committing.

        The migration plan is a JSON representation of the rule graph, which is generated by the
        engine based on the active backends/rules in the project.

        Each item in the migration plan is a rule that contains the old `Get` syntax, the associated
        input/output types, and the new function to directly call. The migration plan can be dumped as
        JSON using the `--json` flag, which can be useful for debugging. For example:

        {
            "filepath": "src/python/pants/source/source_root.py",
            "function": "get_source_roots",
            "gets": [{
                "input_types": [{ "module": "pants.source.source_root", "name": "SourceRootsRequest" }],
                "output_type": { "module": "pants.source.source_root", "name": "OptionalSourceRootsResult" },
                "rule_dep": { "function": "get_optional_source_roots", "module": "pants.source.source_root" }
            }],
            "module": "pants.source.source_root"
        }
        """
    )

    should_dump_json = BoolOption(
        flag_name="--json", help=softwrap("Dump the migration plan as JSON"), default=False
    )

    def run(
        self,
        *,
        build_config: BuildConfiguration,
        graph_session: GraphSession,
        options: Options,
        specs: Specs,
        union_membership: UnionMembership,
    ) -> ExitCode:
        """Build the migration plan, then rewrite each requested file that appears in it.

        Returns PANTS_SUCCEEDED_EXIT_CODE whenever no exception is raised, including when no
        paths were requested or none of them are part of the plan.
        """
        migration_plan = self._create_migration_plan(graph_session, PurePath(get_buildroot()))
        if self.should_dump_json:
            print(json.dumps(migration_plan, indent=2, sort_keys=True))

        path_globs = specs.includes.to_specs_paths_path_globs()
        if not path_globs.globs:
            return PANTS_SUCCEEDED_EXIT_CODE

        plan_files = {item["filepath"] for item in migration_plan}

        paths: list[Paths] = graph_session.scheduler_session.product_request(Paths, path_globs)
        requested_files = set(paths[0].files)

        files_to_migrate = requested_files.intersection(plan_files)
        if not files_to_migrate:
            logger.info(
                f"None of the {len(requested_files)} requested files are part of the {len(plan_files)} files in the migration plan. Please ensure the backend containing these files is activated in pants.toml."
            )
            return PANTS_SUCCEEDED_EXIT_CODE

        syntax_mapper = CallByNameSyntaxMapper(migration_plan)
        for f in sorted(files_to_migrate):
            file = Path(f)
            logger.info(f"Processing {file}")

            transformer = CallByNameTransformer(file, syntax_mapper)
            # Parse raw bytes rather than text: text mode applies universal-newline translation
            # (CRLF files would be rewritten as LF) and the locale encoding. Given bytes, libcst
            # detects the source encoding and newline style, and `Module.bytes` preserves both.
            source_tree = cst.parse_module(file.read_bytes())
            new_tree = source_tree.visit(transformer)

            if not new_tree.deep_equals(source_tree):
                file.write_bytes(new_tree.bytes)

        return PANTS_SUCCEEDED_EXIT_CODE

    def _create_migration_plan(
        self, session: GraphSession, build_root: PurePath
    ) -> list[RuleGraphGet]:
        """Use the rule graph to create a migration plan for each "active" file that uses the old
        Get() syntax.

        This function is mostly about creating a stable-sorted collection of items with metadata for
        downstream consumers (the `--json` dump and `CallByNameSyntaxMapper`), so that the plan is
        identical across runs.

        Rules and rule dependencies that are `functools.partial` objects are skipped, as are rules
        whose module cannot be located on disk or whose file lies outside of `build_root`.
        """

        def type_key(rule_type: RuleType) -> tuple[str, str]:
            return (rule_type["module"], rule_type["name"])

        def dep_key(dep: RuleGraphGetDep) -> tuple:
            # Primary order is the called rule (module, function). Ties, such as one rule being
            # reached via several Gets, are broken by output and input types rather than by
            # rule-graph iteration order.
            return (
                dep["rule_dep"]["module"],
                dep["rule_dep"]["function"],
                *type_key(dep["output_type"]),
                tuple(type_key(t) for t in dep["input_types"]),
            )

        items: list[RuleGraphGet] = []
        for rule, deps in session.scheduler_session.rule_graph_rule_gets().items():
            if isinstance(rule, partial):
                # Ignoring partials, see https://github.com/pantsbuild/pants/issues/20744
                continue

            # Bind `spec` outside of any `assert`: an assignment inside an assert is stripped
            # under `python -O`, which would leave `spec` undefined.
            spec = importlib.util.find_spec(rule.__module__)
            if spec is None or spec.origin is None:
                logger.debug(
                    f"Ignoring migration plan item for {rule.__module__}: unable to locate its source file"
                )
                continue

            try:
                spec_origin = str(PurePath(spec.origin).relative_to(build_root))
            except ValueError:
                logger.debug(
                    f"Ignoring migration plan item located outside of build_root ({build_root}) - file was located at {spec.origin}"
                )
                continue

            unsorted_deps: list[RuleGraphGetDep] = []
            for output_type, input_types, rule_dep in deps:
                if isinstance(rule_dep, partial):
                    # Ignoring partials, see https://github.com/pantsbuild/pants/issues/20744
                    continue

                unsorted_deps.append(
                    {
                        "input_types": sorted(
                            (
                                {"module": input_type.__module__, "name": input_type.__name__}
                                for input_type in input_types
                            ),
                            key=type_key,
                        ),
                        "output_type": {
                            "module": output_type.__module__,
                            "name": output_type.__name__,
                        },
                        "rule_dep": {
                            "function": rule_dep.__name__,
                            "module": rule_dep.__module__,
                        },
                    }
                )

            item: RuleGraphGet = {
                "filepath": spec_origin,
                "module": rule.__module__,
                "function": rule.__name__,
                "gets": sorted(unsorted_deps, key=dep_key),
            }
            items.append(item)

        return sorted(items, key=lambda x: (x["filepath"], x["function"]))
187

188

189
# ------------------------------------------------------------------------------------------
190
# Migration Plan Typed Dicts
191
# ------------------------------------------------------------------------------------------
192

193

194
class RuleType(TypedDict):
    """A type referenced by the rule graph, identified by its defining module and its name."""

    module: str
    name: str


class RuleFunction(TypedDict):
    """A rule function referenced by the rule graph, identified by function name and module."""

    function: str
    module: str


class RuleGraphGetDep(TypedDict):
    """One `Get` made by a rule: its input types, its output type, and the rule that serves it."""

    input_types: list[RuleType]
    output_type: RuleType
    rule_dep: RuleFunction


class RuleGraphGet(TypedDict):
    """One migration-plan entry: a rule function, where it lives, and all of the Gets it makes."""

    filepath: str
    function: str
    module: str
    gets: list[RuleGraphGetDep]
215

216

217
# ------------------------------------------------------------------------------------------
218
# Replacement container
219
# ------------------------------------------------------------------------------------------
220

221

222
@dataclass
class Replacement:
    """A single `Get` call rewritten into call-by-name form, plus the imports it requires.

    Attributes:
        filename: The file containing the call.
        module: The module of the rule whose body contains the call.
        current_source: The original `Get(...)` call node.
        new_source: The call-by-name replacement node (may be renamed by `sanitize`).
        additional_imports: `from ... import ...` statements the replacement needs.
    """

    filename: PurePath
    module: str
    current_source: cst.Call
    new_source: cst.Call
    additional_imports: list[cst.ImportFrom]

    def sanitized_imports(self) -> list[cst.ImportFrom]:
        """Remove any circular or self-imports."""
        cst_module = cstutil.make_importfrom_attr(self.module)
        return [
            i for i in self.additional_imports if i.module and not cst_module.deep_equals(i.module)
        ]

    def sanitize(self, unavailable_names: set[str]):
        """Remove any shadowing of names, except if the new_func is in the current file.

        If the called function's name is already bound in the current file AND it would be
        brought in by one of the (sanitized) imports, the call is renamed to `<name>_get`. The
        matching import alias becomes `<name> as <name>_get`. Both `self.new_source` and
        `self.additional_imports` are updated in place.
        """

        func_name = cst.ensure_type(self.new_source.func, cst.Name)
        if func_name.value not in unavailable_names:
            return

        # If the new func_name is not in the sanitized imports, it must already be in the current file
        imported_names: set[str] = set()
        for imp in self.sanitized_imports():
            assert isinstance(imp.names, Iterable)
            for import_alias in imp.names:
                alias_name = cst.ensure_type(import_alias.name, cst.Name)
                imported_names.add(alias_name.value)

        if func_name.value not in imported_names:
            return

        # In-place update this replacement and additional imports
        bound_name = f"{func_name.value}_get"
        self.new_source = self.new_source.with_deep_changes(self.new_source.func, value=bound_name)

        def is_target(alias: cst.ImportAlias) -> bool:
            return cst.ensure_type(alias.name, cst.Name).value == func_name.value

        renamed = cst.AsName(cst.Name(bound_name))
        for i, imp in enumerate(self.additional_imports):
            assert isinstance(imp.names, Iterable)
            names = list(imp.names)
            if not any(is_target(alias) for alias in names):
                continue
            # Alias only the shadowing name; any other names in the same import are kept as-is.
            self.additional_imports[i] = imp.with_changes(
                names=[alias.with_changes(asname=renamed) if is_target(alias) else alias for alias in names]
            )

        # Use the module logger (not the root logger) and log the name itself, not the node repr.
        logger.warning(f"Renamed {func_name.value} to {bound_name} to avoid shadowing")

    def __str__(self) -> str:
        return f"""
        Replacement(
            filename={self.filename},
            module={self.module},
            current_source={dump(self.current_source)},
            new_source={dump(self.new_source)},
            additional_imports={[dump(i) for i in self.additional_imports]},
        )
        """
280

281

282
# ------------------------------------------------------------------------------------------
283
# Call-by-name transformer
284
# ------------------------------------------------------------------------------------------
285

286

287
class CallByNameTransformer(m.MatcherDecoratableTransformer):
    """libcst transformer that migrates one module from `Get`/`MultiGet` to call-by-name syntax.

    `Get(...)` calls are mapped through the `CallByNameSyntaxMapper`, attributed to the name of the
    enclosing function. `MultiGet` names are renamed to `concurrently`. Imports needed by the
    replacements are inserted when leaving the module.
    """

    def __init__(self, filename: PurePath, syntax_mapper: CallByNameSyntaxMapper) -> None:
        super().__init__()

        self.filename = filename
        self.syntax_mapper = syntax_mapper
        # Name of the innermost function being visited ("" at module level).
        self.calling_function: str = ""
        # Enclosing function names, so that leaving a nested `def` restores the outer name.
        self._function_stack: list[str] = []
        self.additional_imports: list[cst.ImportFrom] = []
        self.unavailable_names: set[str] = set()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Record the function whose body is about to be visited."""
        self._function_stack.append(node.name.value)
        self.calling_function = node.name.value

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Restore the enclosing function name (or "" at module level) once a `def` ends.

        Unconditionally resetting to "" would misattribute any Get that follows a nested helper
        function inside a rule body.
        """
        if self._function_stack:
            self._function_stack.pop()
        self.calling_function = self._function_stack[-1] if self._function_stack else ""
        return updated_node

    @m.leave(m.Call(func=m.Name("Get")))
    def handle_get(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Replace a `Get(...)` call, collecting its imports. Leave it untouched if unmappable."""
        replacement = self.syntax_mapper.map_get_to_new_syntax(
            original_node, self.filename, self.calling_function
        )
        if not replacement:
            return updated_node

        replacement.sanitize(self.unavailable_names)
        self.additional_imports.extend(replacement.sanitized_imports())
        return replacement.new_source

    @m.leave(m.Name("MultiGet"))
    def handle_multiget(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        """Rename every `MultiGet` name to `concurrently`."""
        return updated_node.with_changes(value="concurrently")

    def visit_Module(self, node: cst.Module):
        """Collects all names we risk shadowing."""
        wrapper = libcst.metadata.MetadataWrapper(module=node)
        scopes = set(wrapper.resolve(libcst.metadata.ScopeProvider).values())
        self.unavailable_names = {
            assignment.name for scope in scopes if scope for assignment in scope.assignments
        }

    @staticmethod
    def _default_import_index(module: cst.Module) -> int:
        """Insertion index for new imports when there is no `pants.engine.rules` import.

        This is normally 1 (after the first statement). It moves past a leading docstring and any
        `from __future__` imports, because those must remain at the very top of the module.
        """
        body = module.body
        index = 0
        if index < len(body) and m.matches(
            body[index], m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])
        ):
            index += 1
        while index < len(body) and m.matches(
            body[index],
            m.SimpleStatementLine(body=[m.ImportFrom(module=m.Name("__future__"))]),
        ):
            index += 1
        return max(index, 1)

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Performs final updates on imports, and sanitization.

        New imports (plus `implicitly`) go directly after the `pants.engine.rules` import if there
        is one. Otherwise they go after the module preamble. Exact duplicates are dropped.
        """
        if not self.additional_imports:
            return updated_node

        rules_import_index = self._default_import_index(updated_node)
        for i, statement in enumerate(updated_node.body):
            if m.matches(
                statement,
                matcher=m.SimpleStatementLine(
                    body=[
                        m.ImportFrom(
                            module=cstutil.make_importfrom_attr_matcher("pants.engine.rules")
                        )
                    ]
                ),
            ):
                rules_import_index = i + 1
                break

        self.additional_imports.append(cstutil.make_importfrom("pants.engine.rules", "implicitly"))
        additional_import_statements = [
            cst.SimpleStatementLine(body=[i]) for i in self.additional_imports
        ]
        # Order-preserving de-duplication; CST nodes are compared structurally.
        additional_import_statements = [
            v1
            for i, v1 in enumerate(additional_import_statements)
            if not any(v1.deep_equals(v2) for v2 in additional_import_statements[:i])
        ]

        return updated_node.with_changes(
            body=[
                *updated_node.body[:rules_import_index],
                *additional_import_statements,
                *updated_node.body[rules_import_index:],
            ]
        )
367

368

369
# ------------------------------------------------------------------------------------------
370
# Call-by-name syntax mapping
371
# ------------------------------------------------------------------------------------------
372

373

374
class CallByNameSyntaxMapper:
    """Translates individual `Get(...)` calls into call-by-name calls using the migration plan.

    Supported `Get` shapes, keyed in `self.mapping` by (number of args, type of the 2nd arg):
        Get(Out)                  -> (1, None)
        Get(Out, In(...))         -> (2, cst.Call)
        Get(Out, {v: In, ...})    -> (2, cst.Dict)
        Get(Out, In, v)           -> (3, None)
    """

    def __init__(self, graphs: list[RuleGraphGet]) -> None:
        self.graphs = graphs
        # (filepath, function) -> plan item, so each Get lookup is O(1) rather than a plan scan.
        self._graph_index = self._index_graphs(graphs)

        self.mapping: dict[
            tuple[int, type[cst.Call] | type[cst.Dict] | None],
            Callable[[cst.Call, list[RuleGraphGetDep]], tuple[cst.Call, list[cst.ImportFrom]]],
        ] = {
            (1, None): self.map_no_args_get_to_new_syntax,
            (2, cst.Call): self.map_short_form_get_to_new_syntax,
            (2, cst.Dict): self.map_dict_form_get_to_new_syntax,
            (3, None): self.map_long_form_get_to_new_syntax,
        }

    @staticmethod
    def _index_graphs(graphs: list[RuleGraphGet]) -> dict[tuple[str, str], RuleGraphGet]:
        """Index plan items by (filepath, function). The first item for a key wins, matching a
        first-match linear search over `graphs`."""
        index: dict[tuple[str, str], RuleGraphGet] = {}
        for item in graphs:
            index.setdefault((item["filepath"], item["function"]), item)
        return index

    def _get_graph_item(self, filename: PurePath, calling_func: str) -> RuleGraphGet | None:
        """Return the plan item for `calling_func` in `filename`, or None if not in the plan."""
        return self._graph_index.get((str(filename), calling_func))

    @staticmethod
    def _find_dep(
        deps: list[RuleGraphGetDep],
        output_type: str,
        input_matches: Callable[[list[RuleType]], bool],
        description: str,
    ) -> RuleGraphGetDep:
        """Return the first dep producing `output_type` whose input types satisfy `input_matches`.

        Raises:
            ValueError: if no dep matches. The message names the requested types, so the warning
                logged by `map_get_to_new_syntax` is actionable. A bare `next()` would raise a
                message-less `StopIteration` instead.
        """
        for dep in deps:
            if dep["output_type"]["name"] == output_type and input_matches(dep["input_types"]):
                return dep
        raise ValueError(f"no rule found producing {output_type} from {description}")

    @staticmethod
    def _finalize(
        new_call: cst.Call, dep: RuleGraphGetDep
    ) -> tuple[cst.Call, list[cst.ImportFrom]]:
        """Simplify `implicitly` usage against the target rule's signature (when its definition can
        be found), and pair the call with an import of the target rule."""
        module, new_function = dep["rule_dep"]["module"], dep["rule_dep"]["function"]
        if called_funcdef := cstutil.extract_functiondef_from_module(module, new_function):
            new_call = fix_implicitly_usage(new_call, called_funcdef)
        return new_call, [cstutil.make_importfrom(module, new_function)]

    def map_get_to_new_syntax(
        self, get: cst.Call, filename: PurePath, calling_func: str
    ) -> Replacement | None:
        """Map one `Get(...)` call to a `Replacement`, or return None (after logging a warning)
        if the enclosing rule is not in the plan or the Get cannot be migrated."""
        graph_item = self._get_graph_item(filename, calling_func)
        if graph_item is None:
            logger.warning(f"Failed to find dependencies for {calling_func} in {filename}")
            return None

        get_deps = graph_item["gets"]
        num_args = len(get.args)

        arg_type = None
        if num_args == 2 and (arg_type := type(get.args[1].value)) not in [cst.Call, cst.Dict]:
            logger.warning(f"Failed to migrate: Unknown arg type {get.args[1]}")
            return None

        mapper = self.mapping.get((num_args, arg_type))  # type: ignore[arg-type]
        if mapper is None:
            logger.warning(
                f"Failed to migrate Get ({num_args}, {arg_type}) in {filename}:{calling_func} due to: unsupported number of arguments\n"
            )
            return None

        try:
            new_source, imports = mapper(get, get_deps)
        except Exception as e:
            # Best effort: a Get that cannot be migrated is left in place and reported.
            logger.warning(
                f"Failed to migrate Get ({num_args}, {arg_type}) in {filename}:{calling_func} due to: {e}\n"
            )
            return None

        return Replacement(
            filename=filename,
            module=graph_item["module"],
            current_source=get,
            new_source=new_source,
            additional_imports=imports,
        )

    def map_no_args_get_to_new_syntax(
        self, get: cst.Call, deps: list[RuleGraphGetDep]
    ) -> tuple[cst.Call, list[cst.ImportFrom]]:
        """Get(<OutputType>) => the_rule_to_call(**implicitly())"""

        output_type = cst.ensure_type(get.args[0].value, cst.Name).value
        dep = self._find_dep(deps, output_type, lambda inputs: not inputs, "no inputs")

        new_call = cst.Call(
            func=cst.Name(dep["rule_dep"]["function"]),
            args=[cst.Arg(value=cst.Call(cst.Name("implicitly")), star="**")],
        )
        return self._finalize(new_call, dep)

    def map_long_form_get_to_new_syntax(
        self, get: cst.Call, deps: list[RuleGraphGetDep]
    ) -> tuple[cst.Call, list[cst.ImportFrom]]:
        """Get(<OutputType>, <InputType>, input) => the_rule_to_call(**implicitly(input))"""

        output_type = cst.ensure_type(get.args[0].value, cst.Name).value
        input_type = cst.ensure_type(get.args[1].value, cst.Name).value
        dep = self._find_dep(
            deps,
            output_type,
            lambda inputs: len(inputs) == 1 and inputs[0]["name"] == input_type,
            f"input {input_type}",
        )

        # `**implicitly({input: InputType})`
        implicitly_arg = cst.Arg(
            value=cst.Call(
                func=cst.Name("implicitly"),
                args=[
                    cst.Arg(
                        value=cst.Dict(
                            [cst.DictElement(key=get.args[2].value, value=cst.Name(input_type))]
                        )
                    )
                ],
            ),
            star="**",
        )
        new_call = cst.Call(func=cst.Name(dep["rule_dep"]["function"]), args=[implicitly_arg])
        return self._finalize(new_call, dep)

    def map_short_form_get_to_new_syntax(
        self, get: cst.Call, deps: list[RuleGraphGetDep]
    ) -> tuple[cst.Call, list[cst.ImportFrom]]:
        """Get(<OutType>, <InputType>(<input args>)) => the_rule_to_call(input, **implicitly())"""

        output_type = cst.ensure_type(get.args[0].value, cst.Name).value
        input_call = cst.ensure_type(get.args[1].value, cst.Call)
        input_type = cst.ensure_type(input_call.func, cst.Name).value
        dep = self._find_dep(
            deps,
            output_type,
            lambda inputs: len(inputs) == 1 and inputs[0]["name"] == input_type,
            f"input {input_type}",
        )

        new_call = cst.Call(
            func=cst.Name(dep["rule_dep"]["function"]),
            args=[
                cst.Arg(value=input_call),
                cst.Arg(value=cst.Call(cst.Name("implicitly")), star="**"),
            ],
        )
        return self._finalize(new_call, dep)

    def map_dict_form_get_to_new_syntax(
        self, get: cst.Call, deps: list[RuleGraphGetDep]
    ) -> tuple[cst.Call, list[cst.ImportFrom]]:
        """Get(<OutputType>, {input1: <Input1Type>, ..inputN: <InputNType>}) =>
        the_rule_to_call(**implicitly(input))"""

        output_type = cst.ensure_type(get.args[0].value, cst.Name).value
        input_dict = cst.ensure_type(get.args[1].value, cst.Dict)
        # Only plain-name values are collected; e.g. `mod.Type` values are not recognized here.
        input_types = {
            element.value.value
            for element in input_dict.elements
            if isinstance(element.value, cst.Name)
        }
        dep = self._find_dep(
            deps,
            output_type,
            lambda inputs: {i["name"] for i in inputs} == input_types,
            f"inputs {sorted(input_types)}",
        )

        new_call = cst.Call(
            func=cst.Name(dep["rule_dep"]["function"]),
            args=[
                cst.Arg(
                    value=cst.Call(func=cst.Name("implicitly"), args=[cst.Arg(input_dict)]),
                    star="**",
                )
            ],
        )
        return self._finalize(new_call, dep)
571

572

573
# ------------------------------------------------------------------------------------------
574
# Implicitly helpers
575
# ------------------------------------------------------------------------------------------
576

577

578
def fix_implicitly_usage(call: cst.Call, target_func: cst.FunctionDef) -> cst.Call:
    """The CallByNameSyntaxMapper aggressively adds `implicitly` for safety. This function removes
    unnecessary ones, and attempts to cleanup usage.

    Examples:
        find_all_targets(**implicitly()) -> find_all_targets()
        create_pex(**implicitly({req: PexRequest})) -> create_pex(req)
        create_venv_pex(**implicitly({req: PexRequest})) -> create_venv_pex(**implicitly(req))

    Refer to `migrate_call_by_name_test.py` for more examples.

    Any call shape that is not recognized is returned unchanged, since the aggressive form
    produced by the mapper is kept as-is in that case.

    Parameters:
        call: The replaced `Get` which now uses the migrated call-by-name syntax
        target_func: The target called-by-name function

    Returns:
        The (possibly simplified) call.
    """
    call_func_name = cst.ensure_type(call.func, cst.Name).value
    if call_func_name != target_func.name.value:
        return call

    # If there are no `implicitly`s, there is nothing to do
    implicit_calls = m.findall(call, m.Call(func=m.Name("implicitly")))
    if not implicit_calls:
        return call

    # Only handling the 1-arg case (plus implicitly) for now, which is the overwhelming majority of usage
    number_of_call_args = len(call.args)
    if number_of_call_args > 2:
        return call

    # If the target function takes no positional arguments, then there is nothing to `implicit`.
    # Positional-only parameters count too; stripping arguments for them would break the call.
    params = target_func.params
    if not params.posonly_params and not params.params:
        return call.with_changes(args=[])

    target_annotations = [
        cst.ensure_type(a, cst.Annotation).annotation
        for a in m.findall(target_func.params, m.Annotation())
    ]
    target_types = [
        target_type for a in target_annotations if (target_type := h.get_full_name_for_node(a))
    ]
    # Without resolvable annotations there is nothing to compare the arguments against.
    if not target_types:
        return call

    # Positionally compare the target function's annotations with the call's arguments
    # If they match, then there is no need for `implicitly`

    # Check if `implicitly` contains dict - as that needs special handling
    implicit_call = cst.ensure_type(implicit_calls[0], cst.Call)
    if implicit_call.args and isinstance(d := implicit_call.args[0].value, cst.Dict):
        # Only the single `{input: InputType}` element case is handled; empty dicts and
        # `**splat` elements are left alone.
        if len(d.elements) != 1 or not isinstance(element := d.elements[0], cst.DictElement):
            return call

        if h.get_full_name_for_node(element.value) == target_types[0]:
            # If arg and target match, we can strip `implicitly` call
            return call.with_changes(args=[cst.Arg(element.key)])
        # If arg and target don't match, keep `implicitly` call, but remove dict for normal call
        new_arg = cst.Arg(
            cst.Call(cst.Name("implicitly"), args=[cst.Arg(element.key)]), star="**"
        )
        return call.with_changes(args=[new_arg])

    # The remaining cases compare an explicit positional input with the target's first parameter.
    # A call whose only argument is `**implicitly()` (the no-input `Get(Out)` form) has no such
    # input. Wrapping it would yield `**implicitly(**implicitly())`, so leave it alone.
    first_arg = call.args[0]
    if first_arg.star or not isinstance(first_arg.value, cst.Call):
        return call

    input_arg = first_arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
    input_matches_target = h.get_full_name_for_node(first_arg.value.func) == target_types[0]
    wrapped_input = cst.Arg(cst.Call(cst.Name("implicitly"), args=[input_arg]), star="**")
    number_of_inputs = number_of_call_args - 1

    # If the target function takes in the same number of arguments as we've already passed in,
    # and they are of the same type, then we don't need to pass in `implicitly`. Otherwise keep
    # `implicitly`, wrapping the input.
    if number_of_inputs == len(target_types):
        return call.with_changes(args=[input_arg if input_matches_target else wrapped_input])

    # If the target function takes in more arguments than we've passed in, then we need to pass in `implicitly`
    # This checks if it should be a trailing implicitly, or if we should wrap the first arg
    if number_of_inputs < len(target_types):
        return call if input_matches_target else call.with_changes(args=[wrapped_input])

    return call
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc