• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

desy-multimessenger / nuztf / 13725459265

07 Mar 2025 04:58PM UTC coverage: 71.901% (-1.7%) from 73.615%
13725459265

push

github

web-flow
Add Kowalski Backend (#490)

471 of 640 new or added lines in 26 files covered. (73.59%)

15 existing lines in 4 files now uncovered.

1978 of 2751 relevant lines covered (71.9%)

0.72 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.07
/nuztf/plot.py
1
#!/usr/bin/env python3
2
# License: BSD-3-Clause
3

4
import gzip
1✔
5
import io
1✔
6
from base64 import b64decode
1✔
7

8
import matplotlib.pyplot as plt
1✔
9
import numpy as np
1✔
10
import pandas as pd
1✔
11
from astropy import units as u
1✔
12
from astropy import visualization
1✔
13
from astropy.io import fits
1✔
14
from astropy.time import Time
1✔
15
from matplotlib.colors import Normalize
1✔
16
from matplotlib.ticker import MultipleLocator
1✔
17
from ztfquery.utils.stamps import get_ps_stamp
1✔
18

19
from nuztf.ampel.ampel_cutout import create_empty_cutout
1✔
20
from nuztf.api import ensure_cutouts
1✔
21
from nuztf.cat_match import get_cross_match_info
1✔
22
from nuztf.paths import CUTOUT_CACHE_DIR
1✔
23
from nuztf.utils import cosmo
1✔
24

25

26
def alert_to_pandas(alert):
    """
    Split an alert into a detections DataFrame and an upper-limits DataFrame.

    The alert's own ``candidate`` is processed first, followed by every entry of
    ``prv_candidates``. Provided the latest candidate is a detection, row 0 of
    the detections frame is therefore the alert's own detection; plotting code
    relies on this ordering.

    An entry counts as a detection only if it carries non-null ``magpsf`` and
    ``isdiffpos`` values. Entries where either field is missing *or* null are
    treated as upper limits (non-detections).

    :param alert: list whose first element is an alert dict with a
        ``candidate`` entry and (optionally null) ``prv_candidates`` entry
    :return: tuple ``(df_detections, df_ulims)``. Both frames gain an ``mjd``
        column derived from ``jd``; ``df_ulims`` is ``None`` if there are no
        upper limits.
    :raises ValueError: if the alert contains no detections at all
    """
    candidate = alert[0]["candidate"]
    # prv_candidates may be null for alerts without history.
    prv_candid = alert[0].get("prv_candidates") or []
    combined = [candidate, *prv_candid]

    df_detections_list = []
    df_ulims_list = []

    for cand in combined:
        # One-row frame with one column per alert field.
        _df = pd.DataFrame.from_dict(cand, orient="index").transpose()
        _df["mjd"] = _df["jd"] - 2400000.5
        # A key that is present but null is a non-detection, not a detection.
        is_detection = (
            cand.get("magpsf") is not None and cand.get("isdiffpos") is not None
        )
        if is_detection:
            df_detections_list.append(_df)
        else:
            df_ulims_list.append(_df)

    if not df_detections_list:
        raise ValueError(
            "Alert contains no detections (no entry with magpsf and isdiffpos)."
        )

    df_detections = pd.concat(df_detections_list)
    df_ulims = pd.concat(df_ulims_list) if df_ulims_list else None

    return df_detections, df_ulims
1✔
51

52

53
def lightcurve_from_alert(
    alert: list,
    figsize: tuple = (8, 5),
    title: str = None,
    include_ulims: bool = True,
    include_cutouts: bool = True,
    include_ps1: bool = True,
    include_crossmatch: bool = True,
    mag_range: list = None,
    z: float = None,
    legend: bool = False,
    grid_interval: int = None,
    t_0_mjd: float = None,
    logger=None,
):
    """
    Plot an AMPEL/ZTF alert as a lightcurve, optionally with cutouts.

    :param alert: list whose first element is the alert dict (``objectId``,
        ``candidate``, ``prv_candidates`` and, optionally, cutouts)
    :param figsize: matplotlib figure size (tuple or list)
    :param title: figure title, defaulting to the object name; only drawn when
        ``include_cutouts`` is False
    :param include_ulims: plot upper limits (non-detections) as triangles
    :param include_cutouts: draw Science/Template/Difference cutouts and an
        info box; missing cutouts are fetched via ``ensure_cutouts``
    :param include_ps1: also draw a PS1 colour stamp (only with
        ``include_cutouts``); the image is cached in ``CUTOUT_CACHE_DIR``
    :param include_crossmatch: print catalogue cross-match info and, if a TNS
        name is found, a link to the TNS object page
    :param mag_range: magnitude limits; by default the detected magnitudes
        padded by 0.3 are used
    :param z: redshift; if given (and not NaN), a right-hand absolute
        magnitude axis is added
    :param legend: draw a band legend on the lightcurve axis
    :param grid_interval: spacing of the major x ticks (MJD)
    :param t_0_mjd: reference time, drawn as a dotted vertical line; if not
        given, the mean detection MJD is used only for the axis limits
    :param logger: logger to use; defaults to this module's logger
    :return: ``(fig, axes)`` with ``axes = [lc_ax1, lc_ax2]`` (magnitude vs.
        MJD, and a twin x-axis spanning the same range as calendar dates),
        plus the absolute-magnitude axis appended when ``z`` is used
    """
    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    if z is not None and np.isnan(z):
        z = None
        logger.debug("Redshift is nan, will be ignored")

    # ZTF color and naming scheme
    BAND_NAMES = {1: "ZTF g", 2: "ZTF r", 3: "ZTF i"}
    BAND_COLORS = {1: "green", 2: "red", 3: "orange"}

    name = alert[0]["objectId"]
    candidate = alert[0]["candidate"]

    if include_cutouts:
        if "cutoutScience" not in alert[0]:
            logger.debug(
                "The alert dictionary does not contain cutouts. Will obtain them."
            )
            alert = ensure_cutouts(alert)
        # A null cutoutScience entry is treated like one without data.
        elif "stampData" in (alert[0]["cutoutScience"] or {}):
            logger.debug(f"{name}: Cutouts are present.")
        else:
            logger.debug(f"{name}: Cutouts are missing data. Will obtain them")
            alert = ensure_cutouts(alert)

    logger.debug(f"Plotting {name}")

    df, df_ulims = alert_to_pandas(alert)

    fig = plt.figure(figsize=figsize)

    if include_cutouts:
        lc_ax1 = fig.add_subplot(5, 4, (9, 19))
        cutoutsci = fig.add_subplot(5, 4, (1, 5))
        cutouttemp = fig.add_subplot(5, 4, (2, 6))
        cutoutdiff = fig.add_subplot(5, 4, (3, 7))
        cutoutps1 = fig.add_subplot(5, 4, (4, 8))
    else:
        lc_ax1 = fig.add_subplot(1, 1, 1)
        fig.subplots_adjust(top=0.8, bottom=0.15)

    fig.subplots_adjust(wspace=0.4, hspace=1.8)

    if include_cutouts:
        for ax_, type_ in zip(
            [cutoutsci, cutouttemp, cutoutdiff],
            ["Science", "Template", "Difference"],
        ):
            create_stamp_plot(alert=alert[0], ax=ax_, cutout_type=type_)

        if include_ps1:
            img_cache = CUTOUT_CACHE_DIR.joinpath(f"{name}_PS1.png")

            if img_cache.is_file():
                from PIL import Image

                # Close the cached file once its pixels are loaded.
                with Image.open(img_cache) as cached_img:
                    ps1_pixels = np.asarray(cached_img)
            else:
                img = get_ps_stamp(
                    candidate["ra"], candidate["dec"], size=240, color=["y", "g", "i"]
                )
                img.save(img_cache)
                ps1_pixels = np.asarray(img)

            cutoutps1.imshow(ps1_pixels)
            cutoutps1.set_title("PS1", fontdict={"fontsize": "small"})
            cutoutps1.set_xticks([])
            cutoutps1.set_yticks([])

    # If redshift is given, calculate absolute magnitude via luminosity distance
    # and plot as right axis
    lc_ax3 = None
    if z is not None:
        dist_l = cosmo.luminosity_distance(z).to(u.pc).value
        # Distance modulus m - M = 5 * (log10(d / pc) - 1)
        distance_modulus = 5 * (np.log10(dist_l) - 1)

        def mag_to_absmag(mag):
            return mag - distance_modulus

        def absmag_to_mag(absmag):
            return absmag + distance_modulus

        lc_ax3 = lc_ax1.secondary_yaxis(
            "right", functions=(mag_to_absmag, absmag_to_mag)
        )

        if not include_cutouts:
            lc_ax3.set_ylabel("Absolute Magnitude [AB]")

    # Give the figure a title (the info box plays this role with cutouts)
    if not include_cutouts:
        fig.suptitle(f"{name}" if title is None else title, fontweight="bold")

    if grid_interval is not None:
        lc_ax1.xaxis.set_major_locator(MultipleLocator(grid_interval))

    lc_ax1.grid(visible=True, axis="both", alpha=0.5)
    lc_ax1.set_ylabel("Magnitude [AB]")

    if not include_cutouts:
        lc_ax1.set_xlabel("MJD")

    # Determine magnitude limits (inverted: brighter is up)
    if mag_range is None:
        max_mag = np.max(df.magpsf.values) + 0.3
        min_mag = np.min(df.magpsf.values) - 0.3
        lc_ax1.set_ylim([max_mag, min_mag])
    else:
        lc_ax1.set_ylim([np.max(mag_range), np.min(mag_range)])

    for fid in BAND_NAMES.keys():
        # Plot older datapoints (row 0 is the alert's own detection)
        df_temp = df.iloc[1:].query("fid == @fid")
        lc_ax1.errorbar(
            df_temp["mjd"],
            df_temp["magpsf"],
            df_temp["sigmapsf"],
            color=BAND_COLORS[fid],
            fmt=".",
            label=BAND_NAMES[fid],
            mec="black",
            mew=0.5,
        )

        # Plot upper limits
        if df_ulims is not None and include_ulims:
            df_temp2 = df_ulims.query("fid == @fid")
            lc_ax1.scatter(
                df_temp2["mjd"],
                df_temp2["diffmaglim"],
                c=BAND_COLORS[fid],
                marker="v",
                s=1.3,
                alpha=0.5,
            )

    # Plot datapoint from alert, highlighted with a larger marker
    latest = df.iloc[0]
    latest_fid = latest["fid"]
    lc_ax1.errorbar(
        latest["mjd"],
        latest["magpsf"],
        latest["sigmapsf"],
        color=BAND_COLORS[latest_fid],
        fmt=".",
        label=BAND_NAMES[latest_fid],
        mec="black",
        mew=0.5,
        markersize=12,
    )

    if legend:
        # Draw on lc_ax1 explicitly: plt.legend() acts on the *current* axes,
        # which after the add_subplot calls above is the last-created panel
        # (the PS1 cutout when cutouts are shown). Each band is listed once,
        # since the highlighted latest point re-uses its band's label.
        handles, labels = lc_ax1.get_legend_handles_labels()
        by_label = {}
        for handle, label in zip(handles, labels):
            by_label.setdefault(label, handle)
        lc_ax1.legend(list(by_label.values()), list(by_label.keys()))

    # Now we create an infobox
    if include_cutouts:
        info = [
            name,
            "------------------------",
            f"RA: {candidate['ra']:.8f}",
            f"Dec: {candidate['dec']:.8f}",
        ]
        if "drb" in candidate.keys():
            info.append(f"drb: {candidate['drb']:.3f}")
        else:
            info.append(f"rb: {candidate['rb']:.3f}")
        info.append("------------------------")

        for entry in ["sgscore1", "distpsnr1", "srmag1"]:
            value = candidate.get(entry)
            # These crossmatch fields can be absent/null; don't crash on them.
            if value is None:
                info.append(f"{entry[:-1]}: N/A")
            else:
                info.append(f"{entry[:-1]}: {value:.3f}")

        if alert[0].get("kilonova_eval") is not None:
            info.append(
                f"------------------------\nAMPEL KN score: {alert[0]['kilonova_eval']['kilonovaness']}"
            )

        redshifts = alert[0].get("redshifts") or {}
        if (redshift := redshifts.get("ampel_z")) is not None:
            if redshifts["group_z_nbr"] in [1, 2]:
                info.append(f"spec z: {redshift:.3f}")
            else:
                info.append(f"photo z: {redshift:.3f}")

        fig.text(0.77, 0.55, "\n".join(info), va="top", fontsize="medium", alpha=0.5)

    # Add annotations
    lc_ax1.annotate(
        "See On Fritz",
        xy=(0.5, 1),
        xytext=(0.78, 0.10),
        xycoords="figure fraction",
        verticalalignment="top",
        color="royalblue",
        url=f"https://fritz.science/source/{name}",
        fontsize=12,
        bbox=dict(boxstyle="round", fc="cornflowerblue", ec="royalblue", alpha=0.4),
    )

    if include_crossmatch:
        xmatch_info = get_cross_match_info(
            raw=alert[0],
        )
        ypos = 0.975 if include_cutouts else 0.035

        if "[TNS NAME=" in xmatch_info:
            tns_label = xmatch_info.split("[TNS NAME=")[1].split("]")[0].strip()
            # The object URL uses the bare designation (e.g. "2020abc"), so drop
            # a leading uppercase type prefix such as "AT" or "SN" (str.strip
            # would remove a character *set* from both ends instead).
            tns_name = tns_label.lstrip("ABCDEFGHIJKLMNOPQRSTUVWXYZ").strip()
            lc_ax1.annotate(
                "See On TNS",
                xy=(0.5, 1),
                xytext=(0.78, 0.05),
                xycoords="figure fraction",
                verticalalignment="top",
                color="royalblue",
                url=f"https://www.wis-tns.org/object/{tns_name}",
                fontsize=12,
                bbox=dict(
                    boxstyle="round", fc="cornflowerblue", ec="royalblue", alpha=0.4
                ),
            )

        fig.text(
            0.5,
            ypos,
            xmatch_info,
            va="top",
            ha="center",
            fontsize="medium",
            alpha=0.5,
        )

    if t_0_mjd is not None:
        lc_ax1.axvline(t_0_mjd, linestyle=":")
    else:
        t_0_mjd = np.mean(df.mjd.values)

    # Ugly hack because secondary_axis does not work with astropy.time.Time datetime conversion:
    # a twin x-axis gets invisible points at the same MJD range, as datetimes.
    mjd_min = min(np.min(df.mjd.values), t_0_mjd)
    mjd_max = max(np.max(df.mjd.values), t_0_mjd)
    length = mjd_max - mjd_min

    lc_ax1.set_xlim([mjd_min - (length / 20), mjd_max + (length / 20)])

    lc_ax2 = lc_ax1.twiny()

    datetimes = [Time(x, format="mjd").datetime for x in [mjd_min, mjd_max]]
    lc_ax2.scatter(datetimes, [20, 20], alpha=0)

    lc_ax2.tick_params(axis="both", which="major", labelsize=6, rotation=45)
    lc_ax1.tick_params(axis="x", which="major", labelsize=6, rotation=45)
    lc_ax1.ticklabel_format(axis="x", style="plain")
    lc_ax1.tick_params(axis="y", which="major", labelsize=9)

    axes = [lc_ax1, lc_ax2]
    if lc_ax3 is not None:
        lc_ax3.tick_params(axis="both", which="major", labelsize=9)
        axes.append(lc_ax3)

    return fig, axes
1✔
350

351

352
def create_stamp_plot(alert: dict, ax, cutout_type: str):
    """
    Draw one alert cutout (Science, Template or Difference) on ``ax``.

    The stamp is read from ``alert["cutout<cutout_type>"]["stampData"]``. If
    that key is absent, the alternative key ``cutout<v3 name>`` (e.g.
    ``cutoutCutoutscience``) is tried instead. If no stamp data is found at all,
    ``create_empty_cutout()`` supplies the data to draw.

    The stamp data is decoded as base64 -> gzip -> FITS, asinh-stretched and
    shown with a 0.5-99.5 percentile normalisation.

    :param alert: a single alert dict (e.g. ``alert[0]`` of an alert list)
    :param ax: matplotlib axes to draw on
    :param cutout_type: one of "Science", "Template", "Difference"
    """
    v3_cutout_names = {
        "Science": "Cutoutscience",
        "Template": "Cutouttemplate",
        "Difference": "Cutoutdifference",
    }

    primary = alert.get(f"cutout{cutout_type}")
    if primary is not None:
        data = primary.get("stampData")
    else:
        fallback = alert.get(f"cutout{v3_cutout_names[cutout_type]}") or {}
        data = fallback.get("stampData")
        # Accept both a direct value and one nested in a second "stampData"
        # mapping; calling .get on a plain string would raise.
        if isinstance(data, dict):
            data = data.get("stampData")

    if data is None:
        data = create_empty_cutout()

    with gzip.open(io.BytesIO(b64decode(data)), "rb") as f:
        fits_bytes = f.read()
    with fits.open(io.BytesIO(fits_bytes), ignore_missing_simple=True) as hdul:
        image = hdul[0].data

    # image == image is False only for NaN, so this drops NaN pixels.
    vmin, vmax = np.percentile(image[image == image], [0, 100])
    span = vmax - vmin
    if span == 0:
        # Constant image: avoid 0/0, which would make every pixel NaN.
        span = 1
    stretched = visualization.AsinhStretch()((image - vmin) / span)
    ax.imshow(
        stretched,
        norm=Normalize(*np.percentile(stretched[stretched == stretched], [0.5, 99.5])),
        aspect="auto",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(cutout_type, fontdict={"fontsize": "small"})
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc