• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nuclear-multimessenger-astronomy / nmma / 21710444803

05 Feb 2026 11:53AM UTC coverage: 48.565% (+16.3%) from 32.286%
21710444803

push

github

web-flow
Merge pull request #411 from nuclear-multimessenger-astronomy/direct_eos_sampling

ENH: Merge to reach nmma 1.0

3464 of 6955 new or added lines in 48 files covered. (49.81%)

9 existing lines in 5 files now uncovered.

4280 of 8813 relevant lines covered (48.56%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

10.4
/nmma/post_processing/marginalisation.py
1
import os
2✔
2
import numpy as np
2✔
3
import h5py
2✔
4
import matplotlib.pyplot as plt
2✔
5

6

7
from ..eos.eos_processing import load_tabulated_macro_eos_set_to_dict, EoSConverter
2✔
8
from gwpy.table import Table
2✔
9

10
from ..em.lightcurve_generation import create_light_curve_data
2✔
11
from ..em import io, model, utils, em_parsing as emp
2✔
12
from ..em.plotting_utils import lc_plot_with_histogram
2✔
13

14
from ..core import conversion as conv 
2✔
15
from ..core.utils import read_trigger_time
2✔
16

17
def marginalised_lightcurve_expectation_from_gw_samples(args=None):
    """Generate a marginalised set of light curves from a set of GW samples.

    The GW samples are read from exactly one of ``args.template_file``,
    ``args.hdf5_file`` or ``args.coinc_file`` (checked in that order). For each
    of ``args.Nmarg`` draws, an EOS is chosen according to the tabulated EOS
    weights and a GW sample according to its ``weight`` column. The pair is
    converted to multimessenger parameters and turned into a light curve.

    For draw ``ii`` the results are written to ``<outdir>/<ii>/lc.dat`` and
    ``<outdir>/<ii>/matter.dat``. If both files already exist, they are re-read
    instead of regenerated. ``matter.dat`` holds a single row
    ``1 log10_mej_dyn log10_mej_wind`` when the dynamical ejecta mass is
    finite, and ``0 0 0`` otherwise.

    Args:
        args: Parsed arguments or argument list, forwarded to
            ``emp.parsing_and_logging`` together with
            ``emp.lc_marginalisation_parser``.

    Raises:
        SystemExit: With status 1 if no GW sample source is configured.
    """
    args = emp.parsing_and_logging(emp.lc_marginalisation_parser, args)

    rng = np.random.default_rng(args.generation_seed)
    args.mag_error_scale = 0
    filters = utils.set_filters(args)
    if filters is None:
        filters = 'u,g,r,i,z,y,J,H,K'
    light_curve_model = model.create_light_curve_model_from_args(args.em_model, args, filters)
    conversion = conv.MultimessengerConversion.basic_bns(
        EoSConverter(args, method='tabulated'),
        light_curve_model.parameter_conversion,
    )

    # Read EOS and GW data. The GW loader may consume ``rng`` (coinc-file
    # distances), so it must run before the sample indices are drawn below.
    EOS_data, weights, Neos = load_tabulated_macro_eos_set_to_dict(args.eos_data, args.eos_weights)
    args.Neos = Neos

    data_out = get_all_gw_quantities(_load_gw_samples(args, rng))

    eos_indices = rng.choice(np.arange(len(weights)), args.Nmarg, p=weights)
    gw_indices = rng.choice(
        np.arange(len(data_out["m1"])),
        args.Nmarg,
        p=data_out["weight"] / np.sum(data_out["weight"]),
    )

    mag_ds, matter = [], []
    for ii in range(args.Nmarg):
        outdir = os.path.join(args.outdir, f"{ii}")
        os.makedirs(outdir, exist_ok=True)

        lightcurve_outfile = os.path.join(outdir, "lc.dat")
        matter_outfile = os.path.join(outdir, "matter.dat")
        if os.path.isfile(lightcurve_outfile) and os.path.isfile(matter_outfile):
            # Reuse the outputs already on disk for this draw.
            mag_ds.append(io.read_lc_from_csv(lightcurve_outfile, args, format='model'))
            matter.append(np.loadtxt(matter_outfile))
            continue

        idx, idy = int(eos_indices[ii]), int(gw_indices[ii])
        m1, m2 = data_out["m1"][idy], data_out["m2"][idy]
        m_max = np.max(EOS_data[idx]["M"])

        params = {
            "luminosity_distance": data_out["dist"][idy],
            "chirp_mass": data_out["mchirp"][idy],
            "ratio_epsilon": 1e-20,
            "theta_jn": data_out["theta_jn"][idy],
            "a_1": data_out["a1"][idy],
            # Secondary spin (previously the secondary *mass* was passed here).
            "a_2": data_out["a2"][idy],
            "mass_1": m1,
            "mass_2": m2,
            "EOS": idx,
            "cos_tilt_1": np.cos(data_out["tilt1"][idy]),
            "cos_tilt_2": np.cos(data_out["tilt2"][idy]),
            "KNphi": 30,
        }

        # zeta is drawn before alpha to keep the RNG draw order unchanged.
        log10zeta_min, log10zeta_max = -3, 0
        zeta = 10 ** rng.uniform(log10zeta_min, log10zeta_max)
        alpha = _draw_alpha(m1, m2, m_max, rng)
        params.update({"alpha": alpha, "ratio_zeta": zeta})

        complete_parameters = conversion.convert_to_multimessenger_parameters(params)

        log10_mej_dyn = complete_parameters["log10_mej_dyn"].item()
        log10_mej_wind = complete_parameters["log10_mej_wind"].item()
        with open(matter_outfile, "w") as fid:
            if np.isfinite(log10_mej_dyn):
                fid.write(f"1 {log10_mej_dyn:.5f} {log10_mej_wind:.5f}\n")
            else:
                fid.write("0 0 0\n")

        # Generate the light curve for this draw.
        complete_parameters['trigger_time'] = read_trigger_time(complete_parameters, args)
        sample_times = utils.setup_sample_times(args)
        data = create_light_curve_data(
            complete_parameters, args, light_curve_model, sample_times, rng=rng,
        )
        io.write_lc_to_csv(lightcurve_outfile, data, 'model')

        mag_ds.append(io.read_lc_from_csv(lightcurve_outfile, args, format='model'))
        matter.append(np.loadtxt(matter_outfile))

    # Nothing to plot when no draws were requested.
    if args.plot and mag_ds:
        _plot_matter_distribution(matter, args.outdir)
        _plot_lightcurves(mag_ds, args)


def _load_gw_samples(args, rng):
    """Read the raw GW samples from the first configured source.

    Sources are tried in the order ``args.template_file``, ``args.hdf5_file``,
    ``args.coinc_file``. For the HDF5 source, ``args.gps`` is set to the median
    of the ``t0`` column. For the coinc source, each sample's ``dist`` is drawn
    from a normal distribution using the sky map's marginal distance mean and
    standard deviation, via ``rng``.

    Returns:
        A ``gwpy.table.Table`` of raw samples that has not yet been completed
        by ``get_all_gw_quantities``.

    Raises:
        SystemExit: With status 1 if none of the three sources is set.
    """
    if args.template_file is not None:
        try:
            names = ["SNRdiff", "erf", "weight", "m1", "m2", "a1", "a2", "dist"]
            return Table.read(args.template_file, names=names, format="ascii")
        except Exception:
            # Fall back to the 6-column layout without a1/a2 spin columns.
            names = ["SNRdiff", "erf", "weight", "m1", "m2", "dist"]
            return Table.read(args.template_file, names=names, format="ascii")

    if args.hdf5_file is not None:
        # ``[()]`` reads the dataset into memory, so the file can be closed
        # immediately instead of being leaked.
        with h5py.File(args.hdf5_file, "r") as f:
            posterior = f["lalinference"]["lalinference_mcmc"]["posterior_samples"][()]
        data_out = Table(posterior)
        args.gps = np.median(data_out["t0"])
        return data_out

    if args.coinc_file is not None:
        from ligo.skymap import bayestar, distance, io as lio
        data_out = Table.read(args.coinc_file, format="ligolw", tablename="sngl_inspiral")
        data_out["m1"], data_out["m2"] = data_out["mass1"], data_out["mass2"]
        skymap = lio.fits.read_sky_map(args.skymap, moc=True, distances=True)
        skymap = bayestar.rasterize(skymap, order=9)
        dist_mean, dist_std = distance.parameters_to_marginal_moments(
            skymap["PROB"], skymap["DISTMU"], skymap["DISTSIGMA"]
        )
        data_out["dist"] = dist_mean + rng.standard_normal(len(data_out["m1"])) * dist_std
        return data_out

    print("Needs template_file, hdf5_file, or coinc_file")
    raise SystemExit(1)


def _draw_alpha(m1, m2, m_max, rng):
    """Draw the ``alpha`` conversion parameter for one sample.

    If both component masses lie below the EOS maximum mass ``m_max``, alpha is
    drawn uniformly from [1e-2, 2e-2]. In every other case, log10(alpha) is
    drawn uniformly from [-3, -1]. This covers m1 > m_max > m2 as well as
    configurations the two-branch form left unhandled (both masses above
    ``m_max`` or exact equality), so alpha is always defined and never reused
    from a previous sample.
    """
    if m1 < m_max and m2 < m_max:
        return rng.uniform(1e-2, 2e-2)
    return 10 ** rng.uniform(-3, -1)


def _plot_matter_distribution(matter, outdir):
    """Print the NS fraction and plot ejecta-mass histograms to ``outdir/matter.pdf``.

    Args:
        matter: Rows ``[flag, log10_mej_dyn, log10_mej_wind]`` as loaded from
            ``matter.dat``. ``flag`` is 1 when finite ejecta masses were
            written, and the row is ``0 0 0`` otherwise.
        outdir: Directory that receives ``matter.pdf``.
    """
    ns_flag = np.array([row[0] for row in matter])
    dyn = np.array([row[1] for row in matter])
    wind = np.array([row[2] for row in matter])

    print("Fraction of samples with NS: %.5f" % (np.sum(ns_flag) / len(ns_flag)))

    # The zeros in "0 0 0" rows are placeholders, not log10 masses. Including
    # them would pile those samples up at log10(M_ej) = 0 in the histograms.
    has_ejecta = ns_flag == 1
    if not np.any(has_ejecta):
        print("No samples with finite ejecta mass; skipping matter.pdf")
        return

    bins = np.linspace(-3, 0, 25)
    dyn_hist, bin_edges = np.histogram(dyn[has_ejecta], bins=bins, density=True)
    wind_hist, _ = np.histogram(wind[has_ejecta], bins=bins, density=True)
    bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2.0

    plt.figure(figsize=(10, 6))
    plt.step(bin_centers, dyn_hist, "k--", label="Dynamical")
    plt.step(bin_centers, wind_hist, "b-", label="Wind")
    plt.xlabel(r"log10(Ejecta Mass / $M_\odot$)")
    plt.ylabel("Probability Density Function")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "matter.pdf"), bbox_inches="tight")
    plt.close()


def _plot_lightcurves(mag_ds, args):
    """Plot all generated light curves with per-filter histograms to ``outdir/lc.pdf``.

    Assumes, like the rest of this module, that each entry of ``mag_ds`` maps
    filter name -> dict with at least ``'time'`` and ``'mag'`` arrays, and
    that all entries share the same time grid as the first one.
    """
    # Take the filter names from the light curves themselves. The ``filters``
    # value used to build the model can be a comma-separated string (the
    # default), which cannot be iterated filter by filter.
    plot_filters = list(mag_ds[0].keys())
    plot_dict = {
        filt: np.vstack([lc_data[filt]['mag'] for lc_data in mag_ds])
        for filt in plot_filters
    }
    times = next(iter(mag_ds[0].values()))['time']
    default_ylim = [-12, -18] if args.absolute else [24, 15]
    ylim = getattr(args, 'ylim', default_ylim)
    plotpath = os.path.join(args.outdir, "lc.pdf")
    lc_plot_with_histogram(plot_filters, plot_dict, times, plotpath, ylim=ylim, fontsize=30)
176

177
def get_all_gw_quantities(data_out):
    """Complete a table of GW samples with the derived quantities used downstream.

    Masses:
        If ``m1``/``m2`` are present, ``mchirp``, ``eta`` and ``q`` are derived
        from them with ``conv.component_masses_to_mass_quantities``. Otherwise
        ``eta`` is derived from ``q``, ``mchirp`` is copied from ``mc``, and
        ``m1``/``m2`` are reconstructed from ``mchirp`` and ``eta``.

    Weights:
        ``weight`` is set to the uniform value ``1 / N`` for N samples.

    Spins and angles:
        When both ``spin1z`` and ``spin2z`` are present, they are copied to
        ``a1``/``a2``. Any of ``a1``, ``a2``, ``theta_jn``, ``tilt1``,
        ``tilt2`` that is still missing is set to 0. Finally
        ``chi_eff = (m1 * a1 + m2 * a2) / (m1 + m2)`` is computed from the
        final spins.

    Args:
        data_out: Table-like mapping of column name -> array (for example a
            ``gwpy.table.Table``). It is modified in place.

    Returns:
        The same ``data_out`` object.
    """
    try:
        data_out["mchirp"], data_out["eta"], data_out["q"] = conv.component_masses_to_mass_quantities(
            data_out["m1"], data_out["m2"]
        )
    except KeyError:
        # No component masses in the input: work from chirp mass and mass ratio.
        data_out["eta"] = conv.mass_ratio_to_eta(data_out["q"])
        data_out["mchirp"] = data_out["mc"]
        data_out["m1"], data_out["m2"] = conv.chirp_mass_and_eta_to_component_masses(
            data_out["mchirp"], data_out["eta"]
        )

    # NOTE(review): this replaces any per-sample ``weight`` column read from
    # the input (e.g. the template-file layout has one). Confirm that uniform
    # weights are intended.
    data_out["weight"] = 1.0 / len(data_out["m1"])

    # Aligned spin components, when both are available, take precedence as a1/a2.
    try:
        data_out["a1"], data_out["a2"] = data_out["spin1z"], data_out["spin2z"]
    except KeyError:
        pass

    for key in ["a1", "a2", "theta_jn", "tilt1", "tilt2"]:
        if key not in data_out.keys():
            data_out[key] = 0.

    # Computed only after the spin defaults and overrides above are in place.
    # Otherwise inputs without a1/a2 raise KeyError, and spin1z/spin2z would be
    # ignored here.
    data_out["chi_eff"] = (
        data_out["m1"] * data_out["a1"] + data_out["m2"] * data_out["a2"]
    ) / (data_out["m1"] + data_out["m2"])

    return data_out
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc