• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

kazewong / flowMC / 13655281414

04 Mar 2025 01:55PM UTC coverage: 67.835%. Remained the same
13655281414

push

github

kazewong
format doc strings

1 of 1 new or added line in 1 file covered. (100.0%)

162 existing lines in 15 files now uncovered.

987 of 1455 relevant lines covered (67.84%)

1.36 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/flowMC/utils/postprocessing.py
1
import jax.numpy as jnp
×
2
import matplotlib.pyplot as plt
×
3

4
from flowMC import Sampler
×
5

6

7
def plot_summary(sampler: Sampler, **plotkwargs) -> None:
    """Create plots of the most important quantities in the summary.

    Writes one loss curve (``loss.png``) from the training-phase sampler state
    and, for each of ``local_accs``, ``global_accs`` and ``log_prob``, a
    two-panel figure (training on top, production below) named ``<key>.png``.
    Figures go to ``sampler.outdir`` when the sampler has that attribute and it
    is not ``None``, otherwise to ``./outdir/``. The directory is created if it
    does not exist yet.

    Args:
        sampler (Sampler): Sampler whose ``get_sampler_state(training=...)``
            result can be indexed by ``"loss_vals"`` (training state only),
            ``"local_accs"``, ``"global_accs"`` and ``"log_prob"``.
        **plotkwargs: Plot options forwarded to the plotting helpers, which
            read ``figsize`` and ``alpha``.
    """
    import os  # local import: only needed for output-path handling here

    keys = ["local_accs", "global_accs", "log_prob"]

    # Use the sampler's output directory if it defines a usable one.
    outdir = getattr(sampler, "outdir", None)
    if outdir is None:
        outdir = "./outdir/"

    # The helpers build file names as f"{outdir}{name}.png", so outdir must end
    # with a path separator. os.path.join(..., "") appends one when missing,
    # accepts path-like objects, and leaves "" as "" (the current directory)
    # instead of failing on outdir[-1].
    outdir = os.path.join(outdir, "")
    if outdir:
        # savefig does not create missing directories.
        os.makedirs(outdir, exist_ok=True)

    training_sampler_state = sampler.get_sampler_state(training=True)
    _loss_val_plot(training_sampler_state["loss_vals"], outdir=outdir, **plotkwargs)

    production_sampler_state = sampler.get_sampler_state(training=False)
    for key in keys:
        _stacked_plot(
            training_sampler_state[key],
            production_sampler_state[key],
            key,
            outdir=outdir,
            **plotkwargs,
        )
40

41

UNCOV
42
def _stacked_plot(
    training_data: jnp.ndarray,
    production_data: jnp.ndarray,
    name: str,
    outdir: str = "./outdir/",
    **plotkwargs,
):
    """Plot the axis-0 mean of a quantity for the training and production phases.

    Draws two vertically stacked panels that share both axes (training on top,
    production below) and saves them to ``f"{outdir}{name}.png"``. The figure
    is closed afterwards, even if saving fails, so repeated calls do not
    accumulate open matplotlib figures.

    Args:
        training_data: Array averaged over axis 0 before plotting. It is
            presumably shaped (n_chains, n_steps); TODO confirm with the sampler.
        production_data: Same as ``training_data``, for the production phase.
        name: Quantity name, used in the y-axis labels and the file name. If it
            contains ``"acc"``, the y-range is fixed to [0, 1] plus a small
            margin.
        outdir: Directory prefix. It must end with a path separator because it
            is concatenated directly with the file name.
        **plotkwargs: ``figsize`` (default ``(15, 10)``) and ``alpha``
            (default ``1``) are read. Other keys are ignored.
    """
    figsize = plotkwargs.get("figsize", (15, 10))
    alpha = plotkwargs.get("alpha", 1)
    # Margin so values at exactly 0 or 1 are not drawn on the axis boundary.
    eps = 1e-3

    training_mean = jnp.mean(training_data, axis=0)
    production_mean = jnp.mean(production_data, axis=0)
    # 1-based iteration indices for the x-axis.
    x_training = list(range(1, len(training_mean) + 1))
    x_production = list(range(1, len(production_mean) + 1))

    fig, (ax_train, ax_prod) = plt.subplots(
        2, 1, figsize=figsize, sharex=True, sharey=True
    )
    try:
        ax_train.plot(
            x_training, training_mean, linestyle="-", color="#3498DB", alpha=alpha
        )
        ax_prod.plot(
            x_production, production_mean, linestyle="-", color="#3498DB", alpha=alpha
        )
        ax_train.set_ylabel(f"{name} (training)")
        ax_prod.set_ylabel(f"{name} (production)")
        # The x-axis is shared, so only the bottom panel needs a label.
        ax_prod.set_xlabel("Iteration")
        if "acc" in name:
            # The y-axis is shared, so this limits both panels.
            ax_prod.set_ylim(0 - eps, 1 + eps)
        fig.savefig(f"{outdir}{name}.png", bbox_inches="tight")
    finally:
        plt.close(fig)
×
71

72

UNCOV
73
def _loss_val_plot(
    data,
    outdir: str = "./outdir/",
    **plotkwargs,
):
    """Plot a loss history as a single curve and save it to ``f"{outdir}loss.png"``.

    The figure is closed afterwards, even if saving fails, so repeated calls do
    not accumulate open matplotlib figures.

    Args:
        data: Array-like object with a ``reshape`` method. It is flattened
            (``reshape(-1)``) before plotting, so multi-dimensional input becomes
            one curve. Presumably this is one row of losses per training loop;
            TODO confirm.
        outdir: Directory prefix. It must end with a path separator because it
            is concatenated directly with the file name.
        **plotkwargs: ``figsize`` (default ``(12, 8)``) and ``alpha``
            (default ``1``) are read. Other keys are ignored.
    """
    figsize = plotkwargs.get("figsize", (12, 8))
    alpha = plotkwargs.get("alpha", 1)

    data_to_plot = data.reshape(-1)
    # 1-based iteration indices for the x-axis.
    x = list(range(1, len(data_to_plot) + 1))

    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.plot(x, data_to_plot, linestyle="-", color="#3498DB", alpha=alpha)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("loss")
        fig.savefig(f"{outdir}loss.png", bbox_inches="tight")
    finally:
        plt.close(fig)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc