• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pantsbuild / pants / 18252174847

05 Oct 2025 01:36AM UTC coverage: 43.382% (-36.9%) from 80.261%
18252174847

push

github

web-flow
run tests on mac arm (#22717)

Just doing the minimal to pull forward the x86_64 pattern.

ref #20993

25776 of 59416 relevant lines covered (43.38%)

1.3 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.28
/src/python/pants/testutil/rule_runner.py
1
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
2
# Licensed under the Apache License, Version 2.0 (see LICENSE).
3

4
from __future__ import annotations
3✔
5

6
import atexit
3✔
7
import dataclasses
3✔
8
import difflib
3✔
9
import functools
3✔
10
import inspect
3✔
11
import os
3✔
12
import re
3✔
13
import sys
3✔
14
import warnings
3✔
15
from collections.abc import Callable, Coroutine, Generator, Iterable, Iterator, Mapping, Sequence
3✔
16
from contextlib import contextmanager
3✔
17
from dataclasses import dataclass
3✔
18
from io import StringIO
3✔
19
from pathlib import Path, PurePath
3✔
20
from tempfile import mkdtemp
3✔
21
from typing import Any, Generic, TypeVar, cast, overload
3✔
22

23
from pants.base.build_environment import get_buildroot
3✔
24
from pants.base.build_root import BuildRoot
3✔
25
from pants.base.specs_parser import SpecsParser
3✔
26
from pants.build_graph import build_configuration
3✔
27
from pants.build_graph.build_configuration import BuildConfiguration
3✔
28
from pants.build_graph.build_file_aliases import BuildFileAliases
3✔
29
from pants.core.goals.run import generate_run_request
3✔
30
from pants.core.util_rules import adhoc_binaries, misc
3✔
31
from pants.engine.addresses import Address
3✔
32
from pants.engine.console import Console
3✔
33
from pants.engine.env_vars import CompleteEnvironmentVars
3✔
34
from pants.engine.environment import EnvironmentName
3✔
35
from pants.engine.fs import CreateDigest, Digest, FileContent, Snapshot, Workspace
3✔
36
from pants.engine.goal import CurrentExecutingGoals, Goal
3✔
37
from pants.engine.internals import native_engine, options_parsing
3✔
38
from pants.engine.internals.native_engine import ProcessExecutionEnvironment, PyExecutor
3✔
39
from pants.engine.internals.scheduler import ExecutionError, Scheduler, SchedulerSession
3✔
40
from pants.engine.internals.selectors import Call, Effect, Get, Params
3✔
41
from pants.engine.internals.session import SessionValues
3✔
42
from pants.engine.platform import Platform
3✔
43
from pants.engine.process import InteractiveProcess, InteractiveProcessResult
3✔
44
from pants.engine.rules import QueryRule as QueryRule
3✔
45
from pants.engine.target import AllTargets, Target, WrappedTarget, WrappedTargetRequest
3✔
46
from pants.engine.unions import UnionMembership, UnionRule
3✔
47
from pants.goal.auxiliary_goal import AuxiliaryGoal
3✔
48
from pants.init.engine_initializer import EngineInitializer
3✔
49
from pants.init.logging import initialize_stdio, initialize_stdio_raw, stdio_destination
3✔
50
from pants.option.bootstrap_options import DynamicRemoteOptions, ExecutionOptions, LocalStoreOptions
3✔
51
from pants.option.global_options import GlobalOptions
3✔
52
from pants.option.options_bootstrapper import OptionsBootstrapper
3✔
53
from pants.source import source_root
3✔
54
from pants.testutil.option_util import create_options_bootstrapper
3✔
55
from pants.util.collections import assert_single_element
3✔
56
from pants.util.contextutil import pushd, temporary_dir, temporary_file
3✔
57
from pants.util.dirutil import recursive_dirname, safe_mkdir, safe_mkdtemp, safe_open
3✔
58
from pants.util.logging import LogLevel
3✔
59
from pants.util.ordered_set import OrderedSet
3✔
60
from pants.util.strutil import softwrap
3✔
61

62

63
def logging(original_function=None, *, level: LogLevel = LogLevel.INFO):
    """Decorator that runs the decorated function with stdio logging set up at `level`.

    It can be applied bare:

        ```
        @logging
        def test_function():
            ...
        ```

    or with an explicit level:

        ```
        @logging(level=LogLevel.DEBUG)
        def test_function():
            ...
        ```

    Each call of the decorated function happens inside a temporary directory (handed to
    `initialize_stdio_raw`), with a stdin obtained from `stdin_context()`, and with
    `stdio_destination` pointed at that stdin plus the current stdout/stderr file descriptors.
    """

    def _with_stdio_logging(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            # Read the descriptors up front; they are handed to `stdio_destination` below.
            out_fd = sys.stdout.fileno()
            err_fd = sys.stderr.fileno()
            with temporary_dir() as log_dir:
                with initialize_stdio_raw(level, False, False, {}, True, [], log_dir):
                    with stdin_context() as stdin:
                        with stdio_destination(stdin.fileno(), out_fd, err_fd):
                            return func(*args, **kwargs)

        return _wrapped

    # Bare `@logging` passes the function directly; `@logging(level=...)` passes nothing.
    return _with_stdio_logging(original_function) if original_function else _with_stdio_logging
100

101

102
@contextmanager
def engine_error(
    expected_underlying_exception: type[Exception] = Exception,
    *,
    contains: str | None = None,
    normalize_tracebacks: bool = False,
) -> Iterator[None]:
    """A context manager to catch `ExecutionError`s in tests and check that the underlying exception
    is expected.

    Use like this:

        with engine_error(ValueError, contains="foo"):
            rule_runner.request(OutputType, [input])

    Will raise AssertionError in three cases:
    - no ExecutionError occurred;
    - the underlying exception is not an instance of `expected_underlying_exception`;
    - `contains` is given but not found in the error message.

    Will raise ValueError if the ExecutionError does not wrap exactly one underlying exception.

    Set `normalize_tracebacks=True` to replace file locations and addresses in the error message
    with fixed values for testability, and check `contains` against the `ExecutionError` message
    instead of the underlying error only.
    """
    try:
        yield
    except ExecutionError as exec_error:
        wrapped = exec_error.wrapped_exceptions
        if not wrapped:
            # The empty case gets its own message: the one below only describes two or more
            # wrapped exceptions.
            raise ValueError(
                softwrap(
                    f"""
                    ExecutionError had no underlying exceptions, but this helper function expected
                    exactly one. Use `with pytest.raises(ExecutionError) as exc` directly and
                    inspect `exc.value`.

                    Error: {exec_error}
                    """
                )
            )
        if len(wrapped) != 1:
            formatted_errors = "\n\n".join(repr(e) for e in wrapped)
            raise ValueError(
                softwrap(
                    f"""
                    Multiple underlying exceptions, but this helper function expected only one.
                    Use `with pytest.raises(ExecutionError) as exc` directly and inspect
                    `exc.value.wrapped_exceptions`.

                    Errors: {formatted_errors}
                    """
                )
            )
        underlying = wrapped[0]
        if not isinstance(underlying, expected_underlying_exception):
            raise AssertionError(
                softwrap(
                    f"""
                    ExecutionError occurred as expected, but the underlying exception had type
                    {type(underlying)} rather than the expected type
                    {expected_underlying_exception}:

                    {underlying}
                    """
                )
            )
        if contains is not None:
            # With normalization, match against the full (scrubbed) ExecutionError text so that
            # traceback content can be asserted on; otherwise only the underlying error's message.
            if normalize_tracebacks:
                errmsg = remove_locations_from_traceback(str(exec_error))
            else:
                errmsg = str(underlying)
            if contains not in errmsg:
                diff = "\n".join(
                    difflib.Differ().compare(contains.splitlines(), errmsg.splitlines())
                )
                raise AssertionError(
                    softwrap(
                        f"""
                        Expected value not found in exception.

                        => Expected: {contains}

                        => Actual: {errmsg}

                        => Diff: {diff}
                        """
                    )
                )
    else:
        # The body completed without raising, which is itself a test failure.
        raise AssertionError(
            softwrap(
                f"""
                DID NOT RAISE ExecutionError with underlying exception type
                {expected_underlying_exception}.
                """
            )
        )
183

184

185
def remove_locations_from_traceback(trace: str) -> str:
    """Scrub volatile details from a traceback so it can be compared in tests.

    Quoted absolute file locations of the form `"/path", line N` are replaced with
    `LOCATION-INFO` (the `.*` is greedy, so it spans to the last such location on a line), and
    lowercase hexadecimal addresses such as `0x7f3a` are replaced with `0xEEEEEEEEE`. Locations
    are scrubbed first, then addresses.
    """
    substitutions = (
        (r'"/.*", line \d+', "LOCATION-INFO"),
        (r"0x[0-9a-f]+", "0xEEEEEEEEE"),
    )
    scrubbed = trace
    for pattern, replacement in substitutions:
        scrubbed = re.sub(pattern, replacement, scrubbed)
    return scrubbed
191

192

193
# -----------------------------------------------------------------------------------------------
194
# `RuleRunner`
195
# -----------------------------------------------------------------------------------------------
196

197

198
_I = TypeVar("_I")
_O = TypeVar("_O")


# One executor shared by every Scheduler that unit tests create. Sharing it lets threads be reused
# and avoids waiting on straggling tasks while each test tears down. It is shut down once, at
# interpreter exit, via the `atexit` registration below.
EXECUTOR = PyExecutor(
    # Stay near the minimum parallelism: integration tests using RuleRunner are already run by
    # Pants with an appropriate Parallelism. max_threads must exceed core_threads, so 2 is the
    # floor; trial and error showed that 3 minimizes test times on average.
    core_threads=1,
    max_threads=3,
)


def _shutdown_executor_at_exit() -> None:
    """Shut down the shared `EXECUTOR` when the interpreter exits."""
    # NOTE(review): 5 is presumably a timeout in seconds; confirm against PyExecutor.shutdown.
    EXECUTOR.shutdown(5)


atexit.register(_shutdown_executor_at_exit)


# Names of environment variables needed to locate Python interpreters; pass them through
# RuleRunner's `env_inherit` arguments.
# TODO: This is verbose and redundant: see https://github.com/pantsbuild/pants/issues/13350.
PYTHON_BOOTSTRAP_ENV = {"HOME", "PATH", "PYENV_ROOT"}
218

219

220
@dataclass(frozen=True)
class GoalRuleResult:
    """The outcome of running a goal via `RuleRunner.run_goal_rule`."""

    # Exit code returned by the goal rule.
    exit_code: int
    # Everything the goal wrote to its console's stdout.
    stdout: str
    # Everything the goal wrote to its console's stderr.
    stderr: str

    @staticmethod
    def noop() -> GoalRuleResult:
        """Return a successful result (exit code 0) with empty stdout and stderr."""
        return GoalRuleResult(exit_code=0, stdout="", stderr="")
229

230

231
# This is not frozen because we need to update the `scheduler` when setting options.
232
@dataclass
class RuleRunner:
    """A harness for running Pants @rules against a temporary build root in unit tests.

    Construction creates a fresh build root directory, registers the given rules, target types,
    BUILD file aliases and auxiliary goals, and opens an engine session over them. Tests then
    populate the build root (`write_files`, `create_dir`, ...) and evaluate rules with `request`
    or `run_goal_rule`.
    """

    build_root: str
    options_bootstrapper: OptionsBootstrapper
    extra_session_values: dict[Any, Any]
    max_workunit_verbosity: LogLevel
    build_config: BuildConfiguration
    scheduler: SchedulerSession
    rules: tuple[Any, ...]

    def __init__(
        self,
        *,
        rules: Iterable | None = None,
        target_types: Iterable[type[Target]] | None = None,
        objects: dict[str, Any] | None = None,
        aliases: Iterable[BuildFileAliases] | None = None,
        context_aware_object_factories: dict[str, Any] | None = None,
        isolated_local_store: bool = False,
        preserve_tmpdirs: bool = False,
        ca_certs_path: str | None = None,
        bootstrap_args: Iterable[str] = (),
        extra_session_values: dict[Any, Any] | None = None,
        max_workunit_verbosity: LogLevel = LogLevel.DEBUG,
        inherent_environment: EnvironmentName | None = EnvironmentName(None),
        is_bootstrap: bool = False,
        auxiliary_goals: Iterable[type[AuxiliaryGoal]] | None = None,
    ) -> None:
        """Set up a build root, register rules and build a scheduler session.

        rules: Extra rules to register. When `inherent_environment` is truthy, each `QueryRule`
            is rewritten to also take an `EnvironmentName` input.
        target_types: Target types to register.
        objects, context_aware_object_factories: Registered together as BUILD file aliases.
        aliases: Additional `BuildFileAliases` to register.
        isolated_local_store: Use a fresh local store directory instead of the configured one.
        preserve_tmpdirs: Create the build root (and any isolated store) under a `RuleRunner.`
            temp directory whose location is printed to stderr, and run with
            `--keep-sandboxes=always` and a local execution root inside it.
        ca_certs_path: Passed through to the engine initializer.
        bootstrap_args: Arguments used to create the options bootstrapper.
        extra_session_values: Extra entries added to every session's `SessionValues`.
        max_workunit_verbosity: The `max_workunit_level` for sessions.
        inherent_environment: If truthy, appended to the params of every `request`.
        is_bootstrap: Passed through to the engine initializer.
        auxiliary_goals: Auxiliary goals to register.
        """
        bootstrap_args = [*bootstrap_args]

        root_dir: Path | None = None
        if preserve_tmpdirs:
            root_dir = Path(mkdtemp(prefix="RuleRunner."))
            print(f"Preserving rule runner temporary directories at {root_dir}.", file=sys.stderr)
            bootstrap_args.extend(
                ["--keep-sandboxes=always", f"--local-execution-root-dir={root_dir}"]
            )
            build_root = (root_dir / "BUILD_ROOT").resolve()
            build_root.mkdir()
            self.build_root = str(build_root)
        else:
            self.build_root = os.path.realpath(safe_mkdtemp(prefix="_BUILD_ROOT"))

        safe_mkdir(self.pants_workdir)
        # NOTE(review): BuildRoot() is presumably a process-wide singleton; point it at this root.
        BuildRoot().path = self.build_root

        def rewrite_rule_for_inherent_environment(rule):
            if not inherent_environment or not isinstance(rule, QueryRule):
                return rule
            return QueryRule(rule.output_type, OrderedSet((*rule.input_types, EnvironmentName)))

        # TODO: Redesign rule registration for tests to be more ergonomic and to make this less
        #  special-cased.
        self.rules = tuple(rewrite_rule_for_inherent_environment(rule) for rule in (rules or ()))
        all_rules = (
            *self.rules,
            *build_configuration.rules(),
            *source_root.rules(),
            *options_parsing.rules(),
            *misc.rules(),
            *adhoc_binaries.rules(),
            # Many tests indirectly rely on this rule.
            generate_run_request,
            QueryRule(WrappedTarget, [WrappedTargetRequest]),
            QueryRule(AllTargets, []),
            QueryRule(UnionMembership, []),
        )
        build_config_builder = BuildConfiguration.Builder()
        build_config_builder.register_aliases(
            BuildFileAliases(
                objects=objects, context_aware_object_factories=context_aware_object_factories
            )
        )
        aliases = aliases or ()
        for build_file_aliases in aliases:
            build_config_builder.register_aliases(build_file_aliases)

        build_config_builder.register_rules("_dummy_for_test_", all_rules)
        build_config_builder.register_target_types("_dummy_for_test_", target_types or ())
        build_config_builder.register_auxiliary_goals("_dummy_for_test_", auxiliary_goals or ())
        self.build_config = build_config_builder.create()

        self.environment = CompleteEnvironmentVars({})
        self.extra_session_values = extra_session_values or {}
        self.inherent_environment = inherent_environment
        self.max_workunit_verbosity = max_workunit_verbosity

        # Change cwd and add sentinel file (BUILDROOT) so NativeOptionParser can find build_root.
        with self.pushd():
            Path("BUILDROOT").touch()
            self.options_bootstrapper = self.create_options_bootstrapper(
                args=bootstrap_args, env=None
            )
            options = self.options_bootstrapper.full_options(
                known_scope_infos=self.build_config.known_scope_infos,
                union_membership=UnionMembership.from_rules(
                    rule for rule in self.rules if isinstance(rule, UnionRule)
                ),
                allow_unknown_options=self.build_config.allow_unknown_options,
            )
            global_options = self.options_bootstrapper.bootstrap_options.for_global_scope()

        dynamic_remote_options, _ = DynamicRemoteOptions.from_options(options, self.environment)
        local_store_options = LocalStoreOptions.from_options(global_options)
        if isolated_local_store:
            if root_dir:
                lmdb_store_dir = root_dir / "lmdb_store"
                lmdb_store_dir.mkdir()
                store_dir = str(lmdb_store_dir)
            else:
                store_dir = safe_mkdtemp(prefix="lmdb_store.")
            local_store_options = dataclasses.replace(local_store_options, store_dir=store_dir)

        local_execution_root_dir = global_options.local_execution_root_dir
        named_caches_dir = global_options.named_caches_dir

        self._set_new_session(
            EngineInitializer.setup_graph_extended(
                pants_ignore_patterns=GlobalOptions.compute_pants_ignore(
                    self.build_root, global_options
                ),
                use_gitignore=False,
                local_store_options=local_store_options,
                local_execution_root_dir=local_execution_root_dir,
                named_caches_dir=named_caches_dir,
                pants_workdir=self.pants_workdir,
                build_root=self.build_root,
                build_configuration=self.build_config,
                # Each Scheduler that is created borrows the global executor, which is shut down `atexit`.
                executor=EXECUTOR.to_borrowed(),
                execution_options=ExecutionOptions.from_options(
                    global_options, dynamic_remote_options
                ),
                ca_certs_path=ca_certs_path,
                engine_visualize_to=None,
                is_bootstrap=is_bootstrap,
            ).scheduler
        )

    def __repr__(self) -> str:
        return f"RuleRunner(build_root={self.build_root})"

    def _set_new_session(self, scheduler: Scheduler, build_id: str = "buildid_for_test") -> None:
        """Replace `self.scheduler` with a new session of `scheduler`.

        The session carries the current options bootstrapper, environment variables, a fresh
        `CurrentExecutingGoals`, and any extra session values, and uses
        `max_workunit_verbosity` as its maximum workunit level.
        """
        self.scheduler = scheduler.new_session(
            build_id=build_id,
            session_values=SessionValues(
                {
                    OptionsBootstrapper: self.options_bootstrapper,
                    CompleteEnvironmentVars: self.environment,
                    CurrentExecutingGoals: CurrentExecutingGoals(),
                    **self.extra_session_values,
                }
            ),
            max_workunit_level=self.max_workunit_verbosity,
        )

    @property
    def pants_workdir(self) -> str:
        return os.path.join(self.build_root, ".pants.d", "workdir")

    @property
    def target_types(self) -> tuple[type[Target], ...]:
        return self.build_config.target_types

    @property
    def union_membership(self) -> UnionMembership:
        """An instance of `UnionMembership` with all the test's registered `UnionRule`s."""
        return self.request(UnionMembership, [])

    def new_session(self, build_id: str) -> None:
        """Mutates this RuleRunner to begin a new Session with the same Scheduler.

        The new session uses `build_id` and carries the same session values as the sessions
        created at construction or by `set_options`: options, environment variables and extra
        session values. It also uses the same workunit verbosity, so rules that read those values
        keep working.
        """
        self._set_new_session(self.scheduler.scheduler, build_id=build_id)

    def request(self, output_type: type[_O], inputs: Iterable[Any]) -> _O:
        """Compute a single `output_type` from `inputs` using the current session.

        The inherent environment, if set, is added to the params. The request runs with the
        build root as the working directory, and exactly one result is expected.
        """
        params = (
            Params(*inputs, self.inherent_environment)
            if self.inherent_environment
            else Params(*inputs)
        )
        with self.pushd():
            result = assert_single_element(self.scheduler.product_request(output_type, [params]))
        return cast(_O, result)

    def run_goal_rule(
        self,
        goal: type[Goal],
        *,
        global_args: Iterable[str] | None = None,
        args: Iterable[str] | None = None,
        env: Mapping[str, str] | None = None,
        env_inherit: set[str] | None = None,
    ) -> GoalRuleResult:
        """Run `goal` as if invoked as `<global_args> <goal name> <args>` on the command line.

        This resets the session's options via `set_options`, replacing any previously set
        options. It returns the goal's exit code together with the console output it captured.
        """
        merged_args = (*(global_args or []), goal.name, *(args or []))
        self.set_options(merged_args, env=env, env_inherit=env_inherit)

        with self.pushd():
            raw_specs = self.options_bootstrapper.full_options_for_scopes(
                [GlobalOptions.get_scope_info(), goal.subsystem_cls.get_scope_info()],
                self.union_membership,
            ).specs
        specs = SpecsParser(root_dir=self.build_root).parse_specs(
            raw_specs, description_of_origin="RuleRunner.run_goal_rule()"
        )

        stdout, stderr = StringIO(), StringIO()
        console = Console(stdout=stdout, stderr=stderr, use_colors=False, session=self.scheduler)

        with self.pushd():
            exit_code = self.scheduler.run_goal_rule(
                goal,
                Params(
                    specs,
                    console,
                    Workspace(self.scheduler),
                    *([self.inherent_environment] if self.inherent_environment else []),
                ),
            )

        console.flush()
        return GoalRuleResult(exit_code, stdout.getvalue(), stderr.getvalue())

    @contextmanager
    def pushd(self):
        """Temporarily change the working directory to the build root."""
        with pushd(self.build_root):
            yield

    def create_options_bootstrapper(
        self, args: Iterable[str], env: Mapping[str, str] | None
    ) -> OptionsBootstrapper:
        return create_options_bootstrapper(args=args, env=env)

    def set_options(
        self,
        args: Iterable[str],
        *,
        env: Mapping[str, str] | None = None,
        env_inherit: set[str] | None = None,
    ) -> None:
        """Update the engine session with new options and/or environment variables.

        The environment variables will be used to set the `CompleteEnvironmentVars`, which is the
        environment variables captured by the parent Pants process. Some rules use this to be able
        to read arbitrary env vars. Any options that start with `PANTS_` will also be used to set
        options.

        Environment variables listed in `env_inherit` and not in `env` will be inherited from the test
        runner's environment (os.environ).

        This will override any previously configured values.
        """
        env = {
            **{k: os.environ[k] for k in (env_inherit or set()) if k in os.environ},
            **(env or {}),
        }
        with self.pushd():
            self.options_bootstrapper = self.create_options_bootstrapper(args=args, env=env)
        self.environment = CompleteEnvironmentVars(env)
        self._set_new_session(self.scheduler.scheduler)

    def set_session_values(
        self,
        extra_session_values: dict[Any, Any],
    ) -> None:
        """Update the engine Session with new session_values."""
        self.extra_session_values = extra_session_values
        self._set_new_session(self.scheduler.scheduler)

    def _invalidate_for(self, *relpaths: str):
        """Invalidates all files from the relpath, recursively up to the root.

        Many python operations implicitly create parent directories, so we assume that touching a
        file located below directories that do not currently exist will result in their creation.
        """
        files = {f for relpath in relpaths for f in recursive_dirname(relpath)}
        return self.scheduler.invalidate_files(files)

    def chmod(self, relpath: str | PurePath, mode: int) -> None:
        """Change the file mode and permissions.

        relpath: The relative path to the file or directory from the build root.
        mode: The file mode to set, preferably in octal representation, e.g. `mode=0o750`.
        """
        Path(self.build_root, relpath).chmod(mode)
        self._invalidate_for(str(relpath))

    def create_dir(self, relpath: str | PurePath) -> str:
        """Creates a directory under the buildroot.

        :API: public

        relpath: The relative path to the directory from the build root.
        returns: The absolute path of the created directory.
        """
        path = os.path.join(self.build_root, relpath)
        safe_mkdir(path)
        self._invalidate_for(str(relpath))
        return path

    def _create_file(
        self, relpath: str | PurePath, contents: bytes | str = "", mode: str = "w"
    ) -> str:
        """Writes to a file under the buildroot.

        relpath: The relative path to the file from the build root.
        contents: The contents of the file (bytes require a binary `mode`) - '' by default.
        mode: The mode to write to the file in - over-write by default.
        returns: The absolute path of the written file.
        """
        path = os.path.join(self.build_root, relpath)
        with safe_open(path, mode=mode) as fp:
            fp.write(contents)
        self._invalidate_for(str(relpath))
        return path

    @overload
    def write_files(self, files: Mapping[str, str | bytes]) -> tuple[str, ...]: ...

    @overload
    def write_files(self, files: Mapping[PurePath, str | bytes]) -> tuple[str, ...]: ...

    def write_files(
        self, files: Mapping[PurePath, str | bytes] | Mapping[str, str | bytes]
    ) -> tuple[str, ...]:
        """Write the files to the build root.

        :API: public

        files: A mapping of file names to contents. `bytes` contents are written in binary mode.
        returns: A tuple of absolute file paths created.
        """
        paths = []
        for path, content in files.items():
            paths.append(
                self._create_file(path, content, mode="wb" if isinstance(content, bytes) else "w")
            )
        return tuple(paths)

    def read_file(self, file: str | PurePath, mode: str = "r") -> str | bytes:
        """Read a file that was written to the build root, useful for testing.

        Returns `bytes` when `mode` is binary (contains "b"), otherwise `str`.
        """
        path = os.path.join(self.build_root, file)
        # Plain `open`: `safe_open` ensures the parent directory exists, which is a side effect
        # a read must not have on the build root.
        with open(path, mode=mode) as fp:
            if "b" in mode:
                return bytes(fp.read())
            return str(fp.read())

    def make_snapshot(self, files: Mapping[str, str | bytes]) -> Snapshot:
        """Makes a snapshot from a map of file name to file content.

        `str` contents are encoded with the default encoding of `str.encode`.

        :API: public
        """
        file_contents = [
            FileContent(path, content.encode() if isinstance(content, str) else content)
            for path, content in files.items()
        ]
        digest = self.request(Digest, [CreateDigest(file_contents)])
        return self.request(Snapshot, [digest])

    def make_snapshot_of_empty_files(self, files: Iterable[str]) -> Snapshot:
        """Makes a snapshot with empty content for each file.

        This is a convenience around `make_snapshot`, which allows specifying the content
        for each file.

        :API: public
        """
        return self.make_snapshot(dict.fromkeys(files, ""))

    def get_target(self, address: Address) -> Target:
        """Find the target for a given address.

        This requires that the target actually exists, i.e. that you set up its BUILD file.

        :API: public
        """
        return self.request(
            WrappedTarget,
            [WrappedTargetRequest(address, description_of_origin="RuleRunner.get_target()")],
        ).target

    def write_digest(
        self, digest: Digest, *, path_prefix: str | None = None, clear_paths: Sequence[str] = ()
    ) -> None:
        """Write a digest to disk, relative to the test's build root.

        Access the written files by using `os.path.join(rule_runner.build_root, <relpath>)`.
        """
        native_engine.write_digest(
            self.scheduler.py_scheduler,
            self.scheduler.py_session,
            digest,
            path_prefix or "",
            clear_paths,
        )

    def run_interactive_process(self, request: InteractiveProcess) -> InteractiveProcessResult:
        """Run `request` in the current session, with the build root as the working directory.

        The process runs locally for the current platform: no docker image, no remote execution,
        not in the workspace, and sandboxes are never kept.
        """
        with self.pushd():
            return native_engine.session_run_interactive_process(
                self.scheduler.py_session,
                request,
                ProcessExecutionEnvironment(
                    environment_name=None,
                    platform=Platform.create_for_localhost().value,
                    docker_image=None,
                    remote_execution=False,
                    remote_execution_extra_platform_properties=[],
                    execute_in_workspace=False,
                    keep_sandboxes="never",
                ),
            )

    def do_not_use_mock(self, output_type: type[Any], input_types: Iterable[type[Any]]) -> MockGet:
        """Returns a `MockGet` whose behavior is to run the actual rule using this `RuleRunner`."""
        return MockGet(
            output_type=output_type,
            input_types=tuple(input_types),
            mock=lambda *input_values: self.request(output_type, input_values),
        )
647

648

649
# -----------------------------------------------------------------------------------------------
650
# `run_rule_with_mocks()`
651
# -----------------------------------------------------------------------------------------------
652

653

654
@dataclasses.dataclass(frozen=True)
class MockEffect(Generic[_O]):
    """A mocked side-effecting rule request, supplied to `run_rule_with_mocks` through its
    `mock_gets` argument alongside `MockGet`s.

    NOTE(review): presumably matched against `Effect(...)` requests made by the rule under test;
    confirm in `run_rule_with_mocks`.
    """

    # Product type the mocked request produces.
    output_type: type[_O]
    # Types of the request's inputs.
    input_types: tuple[type, ...]
    # Stand-in implementation; presumably called with the request's input values.
    mock: Callable[..., _O]
659

660

661
@dataclasses.dataclass(frozen=True)
class MockGet(Generic[_O]):
    """A mocked rule request, supplied to `run_rule_with_mocks` through its `mock_gets` argument.

    Pairs the requested product type and input (subject) types with a function that produces the
    product from the input values. `RuleRunner.do_not_use_mock` builds one that delegates to a real
    `RuleRunner.request`.
    """

    # Product type the mocked request returns.
    output_type: type[_O]
    # Types of the request's inputs (subjects).
    input_types: tuple[type, ...]
    # Called with the request's input values; its return value stands in for the product.
    mock: Callable[..., _O]
666

667

668
def run_rule_with_mocks(
    rule: Callable[..., Coroutine[Any, Any, _O]],
    *,
    rule_args: Sequence[Any] = (),
    mock_gets: Sequence[MockGet | MockEffect] = (),
    mock_calls: Mapping[str, Callable] | None = None,
    union_membership: UnionMembership | None = None,
    show_warnings: bool = True,
) -> _O:
    """A test helper that runs an @rule with a set of args and mocked underlying @rule invocations.

    An @rule named `my_rule` that takes one argument and invokes no other @rules (by name or via
    `Get` requests) can be invoked like so:

    ```
    return_value = run_rule_with_mocks(my_rule, rule_args=[arg1])
    ```

    In the case of an @rule that invokes other @rules, either by name or via `Get` requests, things
    get more interesting: either or both of the `mock_calls` and `mock_gets` arguments must be
    provided.

    - `mock_calls` is a mapping of fully-qualified rule name to the function that mocks that rule,
      and mocks out calls by name to the corresponding rules. The mock receives the call's
      arguments positionally. Calls made inside `concurrently()` also forward their keyword
      arguments, including `__implicitly`.
    - `mock_gets` is a sequence of `MockGet`s and `MockEffect`s. Each one takes the output type and
      a tuple of input types, along with a function that takes the input values positionally and
      returns an output value.

    So in the case of an @rule named `my_co_rule` that takes one argument and calls the @rule
    `path.to.module.list_dir` by name to produce a `Listing` from a `Dir`, the invocation might
    look like:

    ```
    return_value = run_rule_with_mocks(
      my_co_rule,
      rule_args=[arg1],
      mock_calls={
        "path.to.module.list_dir": lambda dir_subject: Listing(..),
      },
    )
    ```

    And if that same rule uses a Get request for a product type `Listing` with subject type `Dir`,
    the invocation might look like:

    ```
    return_value = run_rule_with_mocks(
      my_co_rule,
      rule_args=[arg1],
      mock_gets=[
        MockGet(
          output_type=Listing,
          input_types=(Dir,),
          mock=lambda dir_subject: Listing(..),
        ),
      ],
    )
    ```

    If any of the @rule's Get requests involve union members, you should pass a `UnionMembership`
    mapping the union base to any union members you'd like to test. For example, if your rule has
    `await Get(TestResult, TargetAdaptor, target_adaptor)`, you may pass
    `UnionMembership({TargetAdaptor: PythonTestsTargetAdaptor})` to this function.

    When the rule's coroutine finishes, any mocks that were never used are reported via
    `warnings.warn`, unless `show_warnings` is False.

    :param rule: The @rule to run. A plain callable is also accepted; if calling it does not
        produce a coroutine or generator, its result is returned as-is.
    :param rule_args: Positional arguments passed to the rule.
    :param mock_gets: `MockGet`/`MockEffect` objects satisfying `Get`/`Effect` requests.
    :param mock_calls: Mapping of rule id to mock function for calls by name.
    :param union_membership: Used to match union-typed request inputs against mock input types.
    :param show_warnings: Whether to emit warnings about unused mocks and legacy fallbacks.
    :returns: The return value of the completed @rule.
    :raises ValueError: If `rule` is an @rule whose parameter count differs from `len(rule_args)`.
    :raises AssertionError: If the rule makes a request that no mock can satisfy.
    """
    mock_calls = mock_calls or {}

    task_rule = getattr(rule, "rule", None)

    func: Callable[..., Coroutine[Any, Any, _O]] | Callable[..., _O]

    # Perform additional validation on `@rule` that the correct args are provided. We don't have
    # an easy way to do this for async helper calls yet.
    if task_rule:
        if len(rule_args) != len(task_rule.parameters):
            raise ValueError(
                "Error running rule with mocks:\n"
                f"Rule {task_rule.func.__qualname__} expected to receive arguments of the "
                f"form: {task_rule.parameters}; got: {rule_args}"
            )

        # Access the original function, rather than the trampoline that we would get by calling
        # it directly.
        func = task_rule.func
    else:
        func = rule

    res = func(*(rule_args or ()))
    if not isinstance(res, (Coroutine, Generator)):
        # Nothing to drive: the callable produced its result synchronously.
        return res

    unconsumed_mock_calls = set(mock_calls.keys())
    unconsumed_mock_gets = set(mock_gets)

    def get(request: Get | Effect | Call | Coroutine):
        """Resolve a single request yielded by the rule using the provided mocks."""
        if isinstance(request, Coroutine):
            # A call-by-name element in a concurrently() is a Coroutine whose frame is
            # the trampoline wrapper that creates and immediately awaits the Call.
            frame_locals = inspect.getcoroutinelocals(request)
            assert frame_locals is not None
            rule_id = frame_locals["rule_id"]
            args = frame_locals["args"]
            kwargs = dict(frame_locals["kwargs"])
            implicitly = frame_locals.get("__implicitly")
            if implicitly:
                kwargs["__implicitly"] = implicitly
            # Close the original, unmocked coroutine (whether or not a mock exists) to prevent the
            # "was never awaited" warning polluting stderr data that the test may examine. Its
            # arguments have already been captured above.
            request.close()
            mock_call = mock_calls.get(rule_id)
            if mock_call:
                unconsumed_mock_calls.discard(rule_id)
                return mock_call(*args, **kwargs)
            raise AssertionError(f"No mock_call provided for {rule_id}.")
        elif isinstance(request, Call):
            mock_call = mock_calls.get(request.rule_id)
            if mock_call:
                unconsumed_mock_calls.discard(request.rule_id)
                return mock_call(*request.inputs)
            # For now we fall through, to allow an old-style MockGet to mock a call-by-name, for
            # legacy reasons. But we will deprecate and then remove this in the future, at which
            # point we should AssertionError error here as well.
            # Note that this fallthrough only works for single call-by-names. When wrapped in a
            # concurrently() call, mock_calls *must* be used, hence the error above.
            if show_warnings:
                # Note that we used `warnings` instead of `logger.warning` because the latter may
                # get captured or swallowed by the test framework. These warnings will go away
                # once we're fully on call-by-name, anyway.
                warnings.warn(
                    f"No mock_call provided for {request.rule_id}, attempting to find a MockGet to "
                    "satisfy it. Note that this will soon be deprecated, so we recommend switching "
                    "to mock_call ASAP."
                )

        provider = next(
            (
                mock_get
                for mock_get in mock_gets
                if mock_get.output_type == request.output_type
                and all(
                    # Either the input type is directly provided.
                    input_type in mock_get.input_types
                    or (
                        # Or the input type is a union and the mock has an input whose
                        # type is one of the union members.
                        union_membership
                        and input_type in union_membership
                        and any(
                            union_membership.is_member(input_type, t) for t in mock_get.input_types
                        )
                    )
                    for input_type in request.input_types
                )
            ),
            None,
        )
        if provider is None:
            raise AssertionError(f"Rule requested: {request}, which cannot be satisfied.")
        unconsumed_mock_gets.discard(provider)
        return provider.mock(*request.inputs)

    rule_coroutine = res
    rule_input = None

    def warn_on_unconsumed_mocks():
        # Note that we used `warnings` instead of `logger.warning` because the latter may
        # get captured or swallowed by the test framework.
        if show_warnings:
            if unconsumed_mock_calls:
                warnings.warn(f"Unconsumed mock_calls: {unconsumed_mock_calls}")
            if unconsumed_mock_gets:
                warnings.warn(f"Unconsumed mock_gets: {unconsumed_mock_gets}")

    # Drive the rule by hand: each yielded request (or tuple/list of requests, from
    # concurrently()) is resolved via the mocks and the result is sent back in.
    while True:
        try:
            res = rule_coroutine.send(rule_input)
            if isinstance(res, (Get, Effect, Call)):
                rule_input = get(res)
            elif type(res) in (tuple, list):
                rule_input = [get(g) for g in res]  # type: ignore[union-attr]
            else:
                warn_on_unconsumed_mocks()
                return res  # type: ignore[return-value]
        except StopIteration as e:
            warn_on_unconsumed_mocks()
            return e.value  # type: ignore[no-any-return]
3✔
855

856

857
@contextmanager
def stdin_context(content: bytes | str | None = None):
    """Yield a readable file object suitable for use as a process's stdin.

    :param content: Data the yielded file should contain. If `None`, the null device
        (`os.devnull`) is opened instead, so reads see immediate EOF.
    :yields: A file object opened for reading. It is closed when the context exits, so its file
        descriptor is not leaked.
    """
    if content is None:
        # `os.devnull` instead of a hard-coded "/dev/null" keeps this portable across platforms.
        with open(os.devnull) as null_file:
            yield null_file
    else:
        with temporary_file(binary_mode=isinstance(content, bytes)) as stdin_file:
            stdin_file.write(content)
            # Close (and thereby flush) the writer so the data is on disk before re-opening it.
            stdin_file.close()
            with open(stdin_file.name) as reader:
                yield reader
×
866

867

868
@contextmanager
def mock_console(
    options_bootstrapper: OptionsBootstrapper,
    *,
    stdin_content: bytes | str | None = None,
) -> Iterator[tuple[Console, StdioReader]]:
    """Redirect stdio to temporary files and yield a `Console` plus a reader for those files.

    :param options_bootstrapper: Provides the global bootstrap options (used to initialize stdio)
        and the global `colors` option (used for the yielded `Console`).
    :param stdin_content: Optional data to expose on stdin, passed through to `stdin_context`.
    :yields: A `(Console, StdioReader)` pair whose reader exposes the redirected stdout and stderr.
    """
    # Options are read with the build root as the working directory.
    with pushd(get_buildroot()):
        bootstrap_global_scope = options_bootstrapper.bootstrap_options.for_global_scope()
        global_options = options_bootstrapper.full_options_for_scopes(
            [GlobalOptions.get_scope_info()],
            UnionMembership.empty(),
            allow_unknown_options=True,
        ).for_global_scope()
        use_colors = global_options.colors

    with (
        initialize_stdio(bootstrap_global_scope),
        stdin_context(stdin_content) as stdin_file,
        temporary_file(binary_mode=False) as stdout_file,
        temporary_file(binary_mode=False) as stderr_file,
        stdio_destination(
            stdin_fileno=stdin_file.fileno(),
            stdout_fileno=stdout_file.fileno(),
            stderr_fileno=stderr_file.fileno(),
        ),
    ):
        # The Console deliberately gets no explicit destination: stdio was already replaced at the
        # sys.std* level above, which InteractiveProcess needs in order to have native file
        # handles to interact with.
        console = Console(use_colors=use_colors)
        reader = StdioReader(_stdout=Path(stdout_file.name), _stderr=Path(stderr_file.name))
        yield console, reader
        )
904

905

906
@dataclass
class StdioReader:
    """Reads back the files that captured a redirected stdout and stderr."""

    # File capturing the redirected stdout.
    _stdout: Path
    # File capturing the redirected stderr.
    _stderr: Path

    def get_stdout(self) -> str:
        """Return all data that has been flushed to stdout so far."""
        return self._slurp(self._stdout)

    def get_stderr(self) -> str:
        """Return all data that has been flushed to stderr so far."""
        return self._slurp(self._stderr)

    @staticmethod
    def _slurp(path: Path) -> str:
        # Re-open on every call so each read reflects whatever has been flushed since the last one.
        with path.open() as handle:
            return handle.read()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc