• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

desy-multimessenger / nuztf / 3748198822

pending completion
3748198822

push

github-actions

simeonreusch
repair test (error due to prettier date in GCN candidate details)

1832 of 2223 relevant lines covered (82.41%)

0.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.31
/nuztf/plot.py
1
#!/usr/bin/env python3
2
# License: BSD-3-Clause
3

4
import gzip
1✔
5
import io
1✔
6
import numpy as np
1✔
7
import pandas as pd
1✔
8

9
import matplotlib.pyplot as plt
1✔
10
from matplotlib.ticker import MultipleLocator
1✔
11
from matplotlib.colors import Normalize
1✔
12
from base64 import b64decode
1✔
13

14
from astropy.time import Time
1✔
15
from astropy import units as u
1✔
16
from astropy.io import fits
1✔
17
from astropy import visualization
1✔
18
from ztfquery.utils.stamps import get_ps_stamp
1✔
19
from nuztf.utils import cosmo
1✔
20

21
from nuztf.cat_match import get_cross_match_info
1✔
22
from nuztf.ampel_api import ensure_cutouts
1✔
23

24

25
def alert_to_pandas(alert):
    """Split the photometry of an alert into detections and upper limits.

    The alert's own ``candidate`` is placed first, followed by every entry of
    ``prv_candidates``. Each entry becomes a one-row DataFrame with an added
    ``mjd`` column (``jd - 2400000.5``).

    An entry counts as a detection when it has non-null ``magpsf`` and
    ``isdiffpos`` values. Everything else is treated as an upper limit.

    :param alert: List whose first element is the alert dictionary with a
        ``candidate`` dict and an optional ``prv_candidates`` list.
    :return: ``(df_detections, df_ulims)``. The first row of ``df_detections``
        is the alert's own candidate. ``df_ulims`` is ``None`` when there are
        no upper limits.
    :raises ValueError: (from ``pd.concat``) if no entry qualifies as a detection.
    """
    candidate = alert[0]["candidate"]
    # prv_candidates may be missing or null when the object has no history.
    prv_candid = alert[0].get("prv_candidates") or []
    combined = [candidate, *prv_candid]

    df_detections_list = []
    df_ulims_list = []

    for cand in combined:
        _df = pd.DataFrame.from_dict(cand, orient="index").transpose()
        _df["mjd"] = _df["jd"] - 2400000.5
        # The keys can be present with null values (non-detections), so check
        # the values rather than mere key presence.
        if cand.get("magpsf") is not None and cand.get("isdiffpos") is not None:
            df_detections_list.append(_df)
        else:
            df_ulims_list.append(_df)

    df_detections = pd.concat(df_detections_list)
    df_ulims = pd.concat(df_ulims_list) if df_ulims_list else None

    return df_detections, df_ulims
51

52

53
def _format_value(value, fmt=".3f"):
    """Format a candidate field for the infobox, tolerating null values."""
    return "None" if value is None else format(value, fmt)


def lightcurve_from_alert(
    alert: list,
    figsize: tuple = (8, 5),
    title: str = None,
    include_ulims: bool = True,
    include_cutouts: bool = True,
    include_crossmatch: bool = True,
    mag_range: list = None,
    z: float = None,
    legend: bool = False,
    grid_interval: int = None,
    t_0_mjd: float = None,
    logger=None,
):
    """Plot an AMPEL alert as a lightcurve, optionally with cutouts and an infobox.

    :param alert: List whose first element is the alert dictionary. It must
        contain ``objectId``, ``candidate`` and ``prv_candidates``.
    :param figsize: Figure size in inches, passed to ``plt.figure``.
    :param title: Figure title. Only used when ``include_cutouts`` is False;
        defaults to the object name.
    :param include_ulims: Plot upper limits (``diffmaglim``) as triangles.
    :param include_cutouts: Add science/template/difference/PS1 stamps and a
        text infobox. Missing cutouts are fetched with ``ensure_cutouts``.
    :param include_crossmatch: Add the text returned by ``get_cross_match_info``.
    :param mag_range: Explicit y-axis magnitude range. By default the detection
        range is used, padded by 0.3 mag.
    :param z: Redshift. If given and not NaN, a secondary y-axis with absolute
        magnitude is added.
    :param legend: Draw a legend on the lightcurve axis.
    :param grid_interval: Major x-tick spacing, in the MJD units of the x-axis.
    :param t_0_mjd: Reference time, drawn as a dotted vertical line.
    :param logger: Logger to use; defaults to this module's logger.
    :return: ``(fig, axes)``. ``axes`` is ``[lc_ax1, lc_ax2]`` (MJD axis and
        calendar-date twin axis), with the absolute-magnitude axis appended
        when ``z`` is used.
    """
    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    if z is not None and np.isnan(z):
        z = None
        logger.debug("Redshift is nan, will be ignored")

    # ZTF color and naming scheme, keyed by filter id (fid)
    BAND_NAMES = {1: "ZTF g", 2: "ZTF r", 3: "ZTF i"}
    BAND_COLORS = {1: "green", 2: "red", 3: "orange"}

    name = alert[0]["objectId"]
    candidate = alert[0]["candidate"]

    if include_cutouts:
        if "cutoutScience" in alert[0]:
            cutout_science = alert[0]["cutoutScience"]
            # The entry may be present but null, which must not crash here.
            if isinstance(cutout_science, dict) and "stampData" in cutout_science:
                logger.debug(f"{name}: Cutouts are present.")
            else:
                logger.debug(f"{name}: Cutouts are missing data. Will obtain them")
                alert = ensure_cutouts(alert, logger=logger)
        else:
            logger.debug(
                "The alert dictionary does not contain cutouts. Will obtain them."
            )
            alert = ensure_cutouts(alert, logger=logger)

    logger.debug(f"Plotting {name}")

    df, df_ulims = alert_to_pandas(alert)

    fig = plt.figure(figsize=figsize)

    if include_cutouts:
        # 5x4 grid: rows 1-2 hold the four stamps; the lightcurve spans rows
        # 3-5 of columns 1-3, leaving column 4 for the infobox text.
        lc_ax1 = fig.add_subplot(5, 4, (9, 19))
        cutoutsci = fig.add_subplot(5, 4, (1, 5))
        cutouttemp = fig.add_subplot(5, 4, (2, 6))
        cutoutdiff = fig.add_subplot(5, 4, (3, 7))
        cutoutps1 = fig.add_subplot(5, 4, (4, 8))
    else:
        lc_ax1 = fig.add_subplot(1, 1, 1)
        fig.subplots_adjust(top=0.8, bottom=0.15)

    fig.subplots_adjust(wspace=0.4, hspace=1.8)

    if include_cutouts:
        for ax_, type_ in zip(
            [cutoutsci, cutouttemp, cutoutdiff],
            ["Science", "Template", "Difference"],
        ):
            create_stamp_plot(alert=alert[0], ax=ax_, cutout_type=type_)

        # PS1 stamp centred on the alert position
        img = get_ps_stamp(
            candidate["ra"], candidate["dec"], size=240, color=["y", "g", "i"]
        )
        cutoutps1.imshow(np.asarray(img))
        cutoutps1.set_title("PS1", fontdict={"fontsize": "small"})
        cutoutps1.set_xticks([])
        cutoutps1.set_yticks([])

    # If redshift is given, calculate absolute magnitude via luminosity distance
    # and plot as right axis
    if z is not None:
        dist_l = cosmo.luminosity_distance(z).to(u.pc).value
        # Distance modulus: m - M = 5 * (log10(d_L / pc) - 1)
        dist_mod = 5 * (np.log10(dist_l) - 1)

        def mag_to_absmag(mag):
            return mag - dist_mod

        def absmag_to_mag(absmag):
            return absmag + dist_mod

        lc_ax3 = lc_ax1.secondary_yaxis(
            "right", functions=(mag_to_absmag, absmag_to_mag)
        )

        if not include_cutouts:
            lc_ax3.set_ylabel("Absolute Magnitude [AB]")

    # Give the figure a title (with cutouts, the infobox shows the name instead)
    if not include_cutouts:
        fig.suptitle(f"{name}" if title is None else title, fontweight="bold")

    if grid_interval is not None:
        lc_ax1.xaxis.set_major_locator(MultipleLocator(grid_interval))

    lc_ax1.grid(visible=True, axis="both", alpha=0.5)
    lc_ax1.set_ylabel("Magnitude [AB]")

    if not include_cutouts:
        lc_ax1.set_xlabel("MJD")

    # Determine magnitude limits (inverted: brighter is up)
    if mag_range is None:
        max_mag = np.max(df.magpsf.values) + 0.3
        min_mag = np.min(df.magpsf.values) - 0.3
        lc_ax1.set_ylim([max_mag, min_mag])
    else:
        lc_ax1.set_ylim([np.max(mag_range), np.min(mag_range)])

    for fid in BAND_NAMES:
        # Older detections; row 0 is the alert's own candidate, drawn below
        df_temp = df.iloc[1:].query("fid == @fid")
        lc_ax1.errorbar(
            df_temp["mjd"],
            df_temp["magpsf"],
            df_temp["sigmapsf"],
            color=BAND_COLORS[fid],
            fmt=".",
            label=BAND_NAMES[fid],
            mec="black",
            mew=0.5,
        )

        # Plot upper limits
        if df_ulims is not None and include_ulims:
            df_temp2 = df_ulims.query("fid == @fid")
            lc_ax1.scatter(
                df_temp2["mjd"],
                df_temp2["diffmaglim"],
                c=BAND_COLORS[fid],
                marker="v",
                s=1.3,
                alpha=0.5,
            )

    # Plot datapoint from alert, highlighted with a larger marker
    df_temp = df.iloc[0]
    fid = df_temp["fid"]
    lc_ax1.errorbar(
        df_temp["mjd"],
        df_temp["magpsf"],
        df_temp["sigmapsf"],
        color=BAND_COLORS[fid],
        fmt=".",
        label=BAND_NAMES[fid],
        mec="black",
        mew=0.5,
        markersize=12,
    )

    if legend:
        # Draw on the lightcurve axis explicitly: plt.legend() would target the
        # most recently added axes (the PS1 stamp when cutouts are shown).
        # Each band is labelled by its errorbar and again by the alert point,
        # so keep only the first handle per label.
        handles, labels = lc_ax1.get_legend_handles_labels()
        unique = {}
        for handle, label in zip(handles, labels):
            unique.setdefault(label, handle)
        lc_ax1.legend(list(unique.values()), list(unique.keys()))

    # Now we create an infobox
    if include_cutouts:
        info = [
            name,
            "------------------------",
            f"RA: {candidate['ra']:.8f}",
            f"Dec: {candidate['dec']:.8f}",
        ]
        if "drb" in candidate:
            info.append(f"drb: {_format_value(candidate['drb'])}")
        else:
            info.append(f"rb: {_format_value(candidate['rb'])}")
        info.append("------------------------")

        for entry in ["sgscore1", "distpsnr1", "srmag1"]:
            # Label drops the trailing "1" of the field name
            info.append(f"{entry[:-1]}: {_format_value(candidate.get(entry))}")

        fig.text(0.77, 0.55, "\n".join(info), va="top", fontsize="medium", alpha=0.5)

    if include_crossmatch:
        xmatch_info = get_cross_match_info(alert[0])
        ypos = 0.975 if include_cutouts else 0.035

        fig.text(
            0.5,
            ypos,
            xmatch_info,
            va="top",
            ha="center",
            fontsize="medium",
            alpha=0.5,
        )

    if t_0_mjd is not None:
        lc_ax1.axvline(t_0_mjd, linestyle=":")
    else:
        t_0_mjd = np.mean(df.mjd.values)

    mjd_min = min(np.min(df.mjd.values), t_0_mjd)
    mjd_max = max(np.max(df.mjd.values), t_0_mjd)
    length = mjd_max - mjd_min

    # 5% padding on each side; for a single epoch fall back to half a day so
    # the limits are never identical (singular).
    pad = length / 20 if length > 0 else 0.5
    lc_ax1.set_xlim([mjd_min - pad, mjd_max + pad])

    # Ugly hack because secondary_axis does not work with astropy.time.Time
    # datetime conversion: an invisible datetime scatter on a twin x-axis
    # provides calendar-date ticks, and its limits are locked to the MJD axis
    # so both scales line up.
    lc_ax2 = lc_ax1.twiny()
    lc_ax2.scatter(
        [Time(x, format="mjd").datetime for x in [mjd_min, mjd_max]], [20, 20], alpha=0
    )
    lc_ax2.set_xlim([Time(x, format="mjd").datetime for x in lc_ax1.get_xlim()])

    lc_ax2.tick_params(axis="both", which="major", labelsize=6, rotation=45)
    lc_ax1.tick_params(axis="x", which="major", labelsize=6, rotation=45)
    lc_ax1.ticklabel_format(axis="x", style="plain")
    lc_ax1.tick_params(axis="y", which="major", labelsize=9)

    axes = [lc_ax1, lc_ax2]
    if z is not None:
        lc_ax3.tick_params(axis="both", which="major", labelsize=9)
        axes.append(lc_ax3)

    return fig, axes
298

299

300
def create_stamp_plot(alert: dict, ax, cutout_type: str):
    """Helper function to draw one alert cutout onto a matplotlib axis.

    The stamp is base64-encoded, gzip-compressed FITS data. It is read from
    ``alert[f"cutout{cutout_type}"]["stampData"]``. If that key is absent or
    null, the v3 layout ``alert[f"cutoutCutout<type>"]["stampData"]["stampData"]``
    is used instead.

    Pixels are rescaled to [0, 1] using the non-NaN min/max, asinh-stretched,
    and shown with a 0.5-99.5 percentile normalisation.

    :param alert: A single alert dictionary (not the wrapping list).
    :param ax: Matplotlib axis to draw on. Ticks are removed and the title is
        set to ``cutout_type``.
    :param cutout_type: One of "Science", "Template" or "Difference".
    """
    v3_cutout_names = {
        "Science": "Cutoutscience",
        "Template": "Cutouttemplate",
        "Difference": "Cutoutdifference",
    }

    # Keep cutout_type untouched so it can still be used as the panel title.
    if alert.get(f"cutout{cutout_type}") is None:
        v3_key = f"cutout{v3_cutout_names[cutout_type]}"
        stamp = alert[v3_key]["stampData"]["stampData"]
    else:
        stamp = alert[f"cutout{cutout_type}"]["stampData"]

    with gzip.open(io.BytesIO(b64decode(stamp)), "rb") as f:
        raw = f.read()
    # Close the HDU list; copy the pixels so they stay valid afterwards.
    with fits.open(io.BytesIO(raw), ignore_missing_simple=True) as hdul:
        data = hdul[0].data.copy()

    # `data == data` drops NaN pixels
    vmin, vmax = np.percentile(data[data == data], [0, 100])
    span = vmax - vmin
    if span == 0:
        # Constant stamp: avoid 0/0, which would fill the image with NaN and
        # make the normalisation percentile below fail on an empty array.
        span = 1.0
    data_ = visualization.AsinhStretch()((data - vmin) / span)
    ax.imshow(
        data_,
        norm=Normalize(*np.percentile(data_[data_ == data_], [0.5, 99.5])),
        aspect="auto",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(cutout_type, fontdict={"fontsize": "small"})
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc