• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

abhisrkckl / Vela.jl / 14772600346

01 May 2025 08:41AM UTC coverage: 94.118% (+0.1%) from 94.002%
14772600346

push

github

web-flow
Merge pull request #211 from abhisrkckl/priors

Plot priors in `pyvela-plot`

22 of 22 new or added lines in 2 files covered. (100.0%)

1 existing line in 1 file now uncovered.

1024 of 1088 relevant lines covered (94.12%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.68
/pyvela/pyvela/pyvela_plot_script.py
1
import json
1✔
2
from argparse import ArgumentParser
1✔
3
from typing import Iterable
1✔
4

5
import corner
1✔
6
import matplotlib.pyplot as plt
1✔
7
import numpy as np
1✔
8
from astropy import units as u
1✔
9
from pint import DMconst, dmu
1✔
10

11

12
def parse_args(argv):
    """Parse the ``pyvela-plot`` command line.

    Parameters
    ----------
    argv : list of str or None
        Arguments to parse, excluding the program name. ``None`` lets
        argparse fall back to ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Has ``result_dir`` (str) and ``ignore_params`` (list of str, empty
        when ``-I/--ignore_params`` is not given).
    """
    cli = ArgumentParser(
        description="Create a corner plot from pyvela results.",
        prog="pyvela-plot",
    )
    cli.add_argument(
        "result_dir",
        help="A directory containing the output of the `pyvela` script.",
    )
    cli.add_argument(
        "-I",
        "--ignore_params",
        default=[],
        nargs="+",
        help="Parameters to exclude from the corner plot.",
    )
    parsed = cli.parse_args(argv)
    return parsed
1✔
29

30

31
# Parameter names / prefixes that are always left out of the corner plot;
# entries given via --ignore_params are excluded in addition to these.
_DEFAULT_IGNORED_PARAMS = frozenset(
    {
        "PHOFF",
        "PLREDSIN_",
        "PLREDCOS_",
        "PLDMSIN_",
        "PLDMCOS_",
        "PLCHROMSIN_",
        "PLCHROMCOS_",
        "DMX_",
        "WXSIN_",
        "WXCOS_",
        "DMWXSIN_",
        "DMWXCOS_",
        "CMWXSIN_",
        "CMWXCOS_",
    }
)


def get_param_plot_mask(
    param_names: Iterable[str], param_prefixes: Iterable[str], args
) -> "list[int]":
    """Return the indices of the parameters to show in the corner plot.

    A parameter is dropped when either its full name or its prefix appears in
    the built-in ignore set or in ``args.ignore_params``.

    Parameters
    ----------
    param_names : iterable of str
        Full parameter names.
    param_prefixes : iterable of str
        Prefix of each parameter, aligned element-wise with ``param_names``.
    args : argparse.Namespace
        Parsed command-line arguments (as returned by ``parse_args``); only
        ``args.ignore_params`` is read.

    Returns
    -------
    list of int
        Positions into ``param_names`` of the parameters to plot, in their
        original order. Suitable for NumPy fancy indexing.
    """
    ignore_params = _DEFAULT_IGNORED_PARAMS.union(args.ignore_params)

    return [
        idx
        for idx, (pname, pprefix) in enumerate(zip(param_names, param_prefixes))
        if pname not in ignore_params and pprefix not in ignore_params
    ]
57

58

59
def read_true_values(args):
    """Return the true parameter values stored with the results, or ``None``.

    Reads ``summary.json`` from ``args.result_dir``. If its ``input`` section
    has no ``truth_par_file`` entry, or that entry is null, returns ``None``.
    Otherwise loads ``param_true_values.txt`` and divides it element-wise by
    ``param_scale_factors.txt`` (presumably bringing the values into the same
    scaled units as the stored samples — confirm against the writer).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments; only ``args.result_dir`` is read.

    Returns
    -------
    numpy.ndarray or None
        A 1-D array (even when there is a single parameter), or ``None``.
    """
    with open(
        f"{args.result_dir}/summary.json", "r", encoding="utf-8"
    ) as summary_file:
        summary = json.load(summary_file)

    # Covers both a missing key and an explicit null.
    if summary["input"].get("truth_par_file") is None:
        return None

    # genfromtxt returns a 0-d array for a single-value file; keep the result
    # 1-D so callers can index it with a list of parameter positions.
    true_values_raw = np.atleast_1d(
        np.genfromtxt(f"{args.result_dir}/param_true_values.txt")
    )
    scale_factors = np.atleast_1d(
        np.genfromtxt(f"{args.result_dir}/param_scale_factors.txt")
    )

    return true_values_raw / scale_factors
1✔
73

74

75
def _to_dmu(values):
    """Convert values stored in frequency units (Hz) to ``dmu`` via ``DMconst``."""
    return (values * u.Hz / DMconst).to_value(dmu)


def _plot_residuals(fig, residuals_data):
    """Draw residual-vs-MJD panels in the upper-right region of ``fig``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The corner-plot figure to draw into.
    residuals_data : numpy.ndarray
        2-D array read from ``residuals.txt``. With 7 columns the data are
        treated as wideband: (MJD, pre-fit time res, post-fit whitened time
        res, time err, pre-fit DM res, post-fit whitened DM res, DM err).
        Otherwise exactly the first four of those columns are expected.

    Returns
    -------
    matplotlib.axes.Axes
        The lowest residual panel, whose x-axis the caller labels.
    """
    wb = residuals_data.shape[1] == 7
    if wb:
        mjds, tres, tres_w, terr, dres, dres_w, derr = residuals_data.T
        dres, dres_w, derr = _to_dmu(dres), _to_dmu(dres_w), _to_dmu(derr)
    else:
        mjds, tres, tres_w, terr = residuals_data.T

    # Time residuals: pre-fit on the left y-axis, post-fit whitened on a twin
    # y-axis with its own scale.
    ax = fig.add_subplot(5, 2, 2)
    ax.errorbar(
        mjds, tres, terr, marker="+", ls="", alpha=1, color="orange", label="Pre-fit"
    )
    ax.set_ylabel("Time res (pre) (s)")

    ax1 = ax.twinx()
    # Empty proxy so the pre-fit series (drawn on `ax`) shows up in ax1's legend.
    ax1.errorbar([], [], [], ls="", marker="+", color="orange", label="Pre-fit")
    ax1.errorbar(
        mjds,
        tres_w,
        terr,
        marker="+",
        ls="",
        alpha=0.4,
        color="blue",
        label="Post fit whitened",
    )
    ax1.legend()
    ax1.set_ylabel("Time res (post) (s)")
    ax1.axhline(0, ls="dotted", color="k")

    if wb:
        # Hide x ticks on the time panel; the DM panel below gets the MJD axis.
        ax.set_xticks([])

        ax = fig.add_subplot(5, 2, 4)
        ax.errorbar(
            mjds,
            dres,
            derr,
            marker="+",
            ls="",
            alpha=1,
            color="orange",
            label="Pre-fit",
        )
        ax.set_ylabel("DM res (pre) (dmu)")

        ax1 = ax.twinx()
        ax1.errorbar(
            mjds,
            dres_w,
            derr,
            marker="+",
            ls="",
            alpha=0.4,
            color="blue",
            label="Post fit whitened",
        )
        ax1.set_ylabel("DM res (post) (dmu)")
        ax1.axhline(0, ls="dotted", color="k")

    return ax


def main(argv=None):
    """Entry point of ``pyvela-plot``.

    Reads the files written by the ``pyvela`` script into ``args.result_dir``
    and shows one figure. The figure has a corner plot of the posterior
    samples, with the pre-evaluated priors overlaid on the 1-D histograms and
    true values marked when available. It also has panels for the time
    residuals and, for wideband data, the DM residuals.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    result_dir = args.result_dir

    samples = np.load(f"{result_dir}/samples.npy")
    # genfromtxt returns a 0-d array for a single-entry file; atleast_1d keeps
    # the fancy indexing with `param_plot_mask` below valid in that case.
    param_names = np.atleast_1d(
        np.genfromtxt(f"{result_dir}/param_names.txt", dtype=str)
    )
    param_prefixes = np.atleast_1d(
        np.genfromtxt(f"{result_dir}/param_prefixes.txt", dtype=str)
    )
    # Read line-by-line rather than with genfromtxt, presumably because unit
    # strings may contain whitespace.
    with open(f"{result_dir}/param_units.txt", "r") as f:
        param_units = np.array([s.strip() for s in f.readlines()])
    # A single-row residuals file would otherwise come back 1-D and break
    # the column-count check in _plot_residuals.
    residuals_data = np.atleast_2d(np.genfromtxt(f"{result_dir}/residuals.txt"))
    prior_evals = np.load(f"{result_dir}/prior_evals.npy")

    param_plot_mask = get_param_plot_mask(param_names, param_prefixes, args)
    nplots = len(param_plot_mask)

    # A unit of '1' is shown as an empty second label line.
    plot_labels = [
        f"{pname}\n{punit if punit != '1' else ''}"
        for pname, punit in zip(
            param_names[param_plot_mask], param_units[param_plot_mask]
        )
    ]

    true_values_all = read_true_values(args)
    true_values = (
        None
        if true_values_all is None
        else np.atleast_1d(true_values_all)[param_plot_mask]
    )

    fig = corner.corner(
        samples[:, param_plot_mask],
        labels=plot_labels,
        label_kwargs={"fontsize": 9},
        labelpad=0.2,
        max_n_ticks=3,
        plot_datapoints=False,
        hist_kwargs={"density": True},
        range=[0.999] * nplots,
        truths=true_values,
    )

    for ax in fig.get_axes():
        ax.tick_params(axis="both", labelsize=8)
        ax.yaxis.get_offset_text().set_fontsize(8)
        ax.xaxis.get_offset_text().set_fontsize(8)

    # Overlay the pre-evaluated priors on the diagonal histograms.
    # `prior_evals` holds (x, prior) column pairs indexed by the position in
    # the FULL parameter list (ii), not by plotted position (jj).
    # corner lays out its nplots x nplots axes row-major in `fig.axes`, so the
    # diagonal panel for plotted parameter jj is at index jj * (nplots + 1).
    # Addressing it directly avoids depending on plt.subplot() re-activating
    # an existing axes, which relies on pyplot state and matplotlib version.
    for jj, ii in enumerate(param_plot_mask):
        diag_ax = fig.axes[jj * (nplots + 1)]
        diag_ax.plot(prior_evals[:, 2 * ii], prior_evals[:, 2 * ii + 1])

    bottom_ax = _plot_residuals(fig, residuals_data)
    bottom_ax.set_xlabel("MJD - PEPOCH")
    plt.show()
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc