• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

desy-multimessenger / nuztf / 4628402692

pending completion
4628402692

push

github-actions

Simeon Reusch
comment out irsa light curve test for now

1684 of 2306 relevant lines covered (73.03%)

0.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.95
/nuztf/plot.py
1
#!/usr/bin/env python3
2
# License: BSD-3-Clause
3

4
import gzip
1✔
5
import io
1✔
6
from base64 import b64decode
1✔
7

8
import matplotlib.pyplot as plt
1✔
9
import numpy as np
1✔
10
import pandas as pd
1✔
11
from astropy import units as u
1✔
12
from astropy import visualization
1✔
13
from astropy.io import fits
1✔
14
from astropy.time import Time
1✔
15
from matplotlib.colors import Normalize
1✔
16
from matplotlib.ticker import MultipleLocator
1✔
17
from ztfquery.utils.stamps import get_ps_stamp
1✔
18

19
from nuztf.ampel_api import create_empty_cutout, ensure_cutouts
1✔
20
from nuztf.cat_match import get_cross_match_info
1✔
21
from nuztf.utils import cosmo
1✔
22

23

24
def alert_to_pandas(alert):
    """Split an alert's datapoints into detections and upper limits.

    The alert's own ``candidate`` is placed first, followed by its
    ``prv_candidates``; each datapoint becomes a one-row DataFrame with an
    added ``mjd`` column (``jd - 2400000.5``).

    A datapoint counts as a detection only if both ``magpsf`` and
    ``isdiffpos`` are present *and* not None. Everything else (datapoints
    lacking these keys, or carrying them with null values) is treated as an
    upper limit.

    :param alert: list whose first element is the alert dict with key
        ``candidate`` and, optionally, ``prv_candidates``
    :return: ``(df_detections, df_ulims)``; ``df_ulims`` is None when there
        are no upper limits
    :raises ValueError: if no datapoint qualifies as a detection
        (``pd.concat`` refuses an empty list)
    """
    candidate = alert[0]["candidate"]
    # prv_candidates may be absent or null for alerts without history.
    prv_candid = alert[0].get("prv_candidates") or []
    # The alert candidate goes first so callers can take row 0 as the newest point.
    combined = [candidate, *prv_candid]

    df_detections_list = []
    df_ulims_list = []

    for cand in combined:
        _df = pd.DataFrame.from_dict(cand, orient="index").transpose()
        _df["mjd"] = _df["jd"] - 2400000.5
        # Null-valued keys must count as missing; otherwise non-detections
        # end up among the detections with a None magnitude.
        if cand.get("magpsf") is not None and cand.get("isdiffpos") is not None:
            df_detections_list.append(_df)
        else:
            df_ulims_list.append(_df)

    df_detections = pd.concat(df_detections_list)
    df_ulims = pd.concat(df_ulims_list) if df_ulims_list else None

    return df_detections, df_ulims
49

50

51
def lightcurve_from_alert(
    alert: list,
    figsize: list = [8, 5],
    title: str = None,
    include_ulims: bool = True,
    include_cutouts: bool = True,
    include_crossmatch: bool = True,
    mag_range: list = None,
    z: float = None,
    legend: bool = False,
    grid_interval: int = None,
    t_0_mjd: float = None,
    logger=None,
):
    """Plot an AMPEL alert as a light curve.

    The alert's own detection (row 0 of the detections table) is drawn with
    a large marker. Earlier detections and, optionally, upper limits are
    drawn per ZTF band. With ``include_cutouts`` the science, template and
    difference cutouts, a PS1 colour stamp and an info box are added.

    :param alert: list whose first element is the alert dict (must contain
        ``objectId`` and ``candidate``)
    :param figsize: figure size in inches, passed to ``plt.figure``
    :param title: figure title, defaulting to the object ID; only drawn
        when ``include_cutouts`` is False
    :param include_ulims: plot upper limits
    :param include_cutouts: add image cutouts, PS1 stamp and info box
    :param include_crossmatch: add cross-match text from
        ``get_cross_match_info``
    :param mag_range: explicit magnitude range; otherwise derived from the
        detections with 0.3 mag padding
    :param z: redshift; if given and not NaN, an absolute-magnitude axis is
        added on the right
    :param legend: draw a legend on the light-curve axis
    :param grid_interval: spacing of major x ticks, in MJD days
    :param t_0_mjd: reference time, drawn as a dotted vertical line
    :param logger: logger to use; defaults to ``logging.getLogger(__name__)``
    :return: ``(fig, axes)`` with ``axes = [lc_ax1, lc_ax2]`` (MJD axis and
        calendar-date twin axis), plus the absolute-magnitude axis when
        ``z`` is given
    """
    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    if z is not None and np.isnan(z):
        z = None
        logger.debug("Redshift is nan, will be ignored")

    # ZTF color and naming scheme
    BAND_NAMES = {1: "ZTF g", 2: "ZTF r", 3: "ZTF i"}
    BAND_COLORS = {1: "green", 2: "red", 3: "orange"}

    name = alert[0]["objectId"]
    candidate = alert[0]["candidate"]

    if include_cutouts:
        if "cutoutScience" not in alert[0]:
            logger.debug(
                "The alert dictionary does not contain cutouts. Will obtain them."
            )
            alert = ensure_cutouts(alert, logger=logger)
        elif "stampData" in (alert[0]["cutoutScience"] or {}):
            logger.debug(f"{name}: Cutouts are present.")
        else:
            # The key is present but empty or None: fetch the stamps instead of crashing.
            logger.debug(f"{name}: Cutouts are missing data. Will obtain them")
            alert = ensure_cutouts(alert, logger=logger)

    logger.debug(f"Plotting {name}")

    df, df_ulims = alert_to_pandas(alert)

    fig = plt.figure(figsize=figsize)

    if include_cutouts:
        # 5x4 grid: the stamps fill the first two rows, one per column;
        # the light curve spans rows 3-5 of the first three columns.
        lc_ax1 = fig.add_subplot(5, 4, (9, 19))
        cutoutsci = fig.add_subplot(5, 4, (1, 5))
        cutouttemp = fig.add_subplot(5, 4, (2, 6))
        cutoutdiff = fig.add_subplot(5, 4, (3, 7))
        cutoutps1 = fig.add_subplot(5, 4, (4, 8))
    else:
        lc_ax1 = fig.add_subplot(1, 1, 1)
        fig.subplots_adjust(top=0.8, bottom=0.15)

    fig.subplots_adjust(wspace=0.4, hspace=1.8)

    if include_cutouts:
        for ax_, type_ in zip(
            [cutoutsci, cutouttemp, cutoutdiff],
            ["Science", "Template", "Difference"],
        ):
            create_stamp_plot(alert=alert[0], ax=ax_, cutout_type=type_)

        img = get_ps_stamp(
            candidate["ra"], candidate["dec"], size=240, color=["y", "g", "i"]
        )
        cutoutps1.imshow(np.asarray(img))
        cutoutps1.set_title("PS1", fontdict={"fontsize": "small"})
        cutoutps1.set_xticks([])
        cutoutps1.set_yticks([])

    # If a redshift is given, show absolute magnitude on a right-hand axis,
    # using the luminosity distance: m - M = 5 * (log10(d_L / pc) - 1)
    lc_ax3 = None
    if z is not None:
        dist_l = cosmo.luminosity_distance(z).to(u.pc).value
        distmod = 5 * (np.log10(dist_l) - 1)

        def mag_to_absmag(mag):
            return mag - distmod

        def absmag_to_mag(absmag):
            return absmag + distmod

        lc_ax3 = lc_ax1.secondary_yaxis(
            "right", functions=(mag_to_absmag, absmag_to_mag)
        )

        if not include_cutouts:
            lc_ax3.set_ylabel("Absolute Magnitude [AB]")

    # Give the figure a title
    if not include_cutouts:
        fig.suptitle(f"{name}" if title is None else title, fontweight="bold")

    if grid_interval is not None:
        lc_ax1.xaxis.set_major_locator(MultipleLocator(grid_interval))

    lc_ax1.grid(visible=True, axis="both", alpha=0.5)
    lc_ax1.set_ylabel("Magnitude [AB]")

    if not include_cutouts:
        lc_ax1.set_xlabel("MJD")

    # Magnitude limits are given as (faint, bright) so the y axis is inverted
    if mag_range is None:
        max_mag = np.max(df.magpsf.values) + 0.3
        min_mag = np.min(df.magpsf.values) - 0.3
        lc_ax1.set_ylim([max_mag, min_mag])
    else:
        lc_ax1.set_ylim([np.max(mag_range), np.min(mag_range)])

    for fid in BAND_NAMES:
        # Plot older datapoints (row 0 is the alert itself, drawn below)
        df_temp = df.iloc[1:].query("fid == @fid")
        lc_ax1.errorbar(
            df_temp["mjd"],
            df_temp["magpsf"],
            df_temp["sigmapsf"],
            color=BAND_COLORS[fid],
            fmt=".",
            label=BAND_NAMES[fid],
            mec="black",
            mew=0.5,
        )

        # Plot upper limits
        if df_ulims is not None and include_ulims:
            df_temp2 = df_ulims.query("fid == @fid")
            lc_ax1.scatter(
                df_temp2["mjd"],
                df_temp2["diffmaglim"],
                c=BAND_COLORS[fid],
                marker="v",
                s=1.3,
                alpha=0.5,
            )

    # Plot datapoint from alert
    df_temp = df.iloc[0]
    fid = df_temp["fid"]
    lc_ax1.errorbar(
        df_temp["mjd"],
        df_temp["magpsf"],
        df_temp["sigmapsf"],
        color=BAND_COLORS[fid],
        fmt=".",
        label=BAND_NAMES[fid],
        mec="black",
        mew=0.5,
        markersize=12,
    )

    if legend:
        # Draw on the light-curve axis explicitly: plt.legend() would target
        # the current axes, which is the last-added cutout subplot when
        # cutouts are shown. The alert point reuses its band's label, so
        # keep only the first entry per label.
        handles, labels = lc_ax1.get_legend_handles_labels()
        unique = {}
        for handle, label in zip(handles, labels):
            unique.setdefault(label, handle)
        lc_ax1.legend(list(unique.values()), list(unique.keys()))

    # Now we create an infobox
    if include_cutouts:
        info = [
            name,
            "------------------------",
            f"RA: {candidate['ra']:.8f}",
            f"Dec: {candidate['dec']:.8f}",
        ]
        if "drb" in candidate:
            info.append(f"drb: {candidate['drb']:.3f}")
        else:
            info.append(f"rb: {candidate['rb']:.3f}")
        info.append("------------------------")

        for entry in ["sgscore1", "distpsnr1", "srmag1"]:
            # Drop the trailing "1" from the key for display
            info.append(f"{entry[:-1]}: {candidate[entry]:.3f}")

        fig.text(0.77, 0.55, "\n".join(info), va="top", fontsize="medium", alpha=0.5)

    if include_crossmatch:
        xmatch_info = get_cross_match_info(alert[0])
        ypos = 0.975 if include_cutouts else 0.035

        fig.text(
            0.5,
            ypos,
            xmatch_info,
            va="top",
            ha="center",
            fontsize="medium",
            alpha=0.5,
        )

    if t_0_mjd is not None:
        lc_ax1.axvline(t_0_mjd, linestyle=":")
    else:
        t_0_mjd = np.mean(df.mjd.values)

    # Ugly hack because secondary_axis does not work with astropy.time.Time
    # datetime conversion: overlay a twin x-axis with calendar dates and pin
    # it to the same limits as the MJD axis.
    mjd_min = min(np.min(df.mjd.values), t_0_mjd)
    mjd_max = max(np.max(df.mjd.values), t_0_mjd)
    length = mjd_max - mjd_min
    # Pad by 5% on each side. When every point (and t_0) falls on a single
    # epoch the span is zero, so pad by one day instead of passing identical limits.
    padding = length / 20 if length > 0 else 1.0
    x_lo, x_hi = mjd_min - padding, mjd_max + padding
    lc_ax1.set_xlim([x_lo, x_hi])

    lc_ax2 = lc_ax1.twiny()

    datetimes = [Time(x, format="mjd").datetime for x in (x_lo, x_hi)]
    # Invisible points give the twin axis datetime data; the explicit limits
    # keep it aligned with the MJD axis.
    lc_ax2.scatter(datetimes, [20, 20], alpha=0)
    lc_ax2.set_xlim(datetimes)

    lc_ax2.tick_params(axis="both", which="major", labelsize=6, rotation=45)
    lc_ax1.tick_params(axis="x", which="major", labelsize=6, rotation=45)
    lc_ax1.ticklabel_format(axis="x", style="plain")
    lc_ax1.tick_params(axis="y", which="major", labelsize=9)

    axes = [lc_ax1, lc_ax2]
    if lc_ax3 is not None:
        lc_ax3.tick_params(axis="both", which="major", labelsize=9)
        axes.append(lc_ax3)

    return fig, axes
294

295

296
def create_stamp_plot(alert: dict, ax, cutout_type: str):
    """Draw one image cutout of an alert onto ``ax``.

    The stamp is looked up under ``cutout<cutout_type>`` (e.g.
    ``cutoutScience``). If that key is absent or None, the alternative key
    built from ``v3_cutout_names`` is tried, where the payload is nested as
    ``["stampData"]["stampData"]``. If neither yields data, the result of
    ``create_empty_cutout()`` is used.

    The payload is base64-decoded, gunzipped and read as FITS. The image is
    scaled by its NaN-free min/max range, asinh-stretched, and displayed
    with a 0.5-99.5 percentile colour normalisation.

    :param alert: a single alert dict (not the list wrapping it)
    :param ax: matplotlib axes to draw into
    :param cutout_type: expected to be ``"Science"``, ``"Template"`` or
        ``"Difference"`` (the keys of ``v3_cutout_names``)
    """
    v3_cutout_names = {
        "Science": "Cutoutscience",
        "Template": "Cutouttemplate",
        "Difference": "Cutoutdifference",
    }

    cutout = alert.get(f"cutout{cutout_type}")
    if cutout is not None:
        data = cutout["stampData"]
    else:
        # `or {}` treats a present-but-None entry like a missing one,
        # instead of raising AttributeError on `.get`.
        v3_cutout = alert.get(f"cutout{v3_cutout_names[cutout_type]}") or {}
        data = (v3_cutout.get("stampData") or {}).get("stampData")
        if data is None:
            data = create_empty_cutout()

    with gzip.open(io.BytesIO(b64decode(data)), "rb") as f:
        data = fits.open(io.BytesIO(f.read()), ignore_missing_simple=True)[0].data

    # Scale by the NaN-free pixel range (`x == x` is False only for NaN).
    # Guard constant or all-NaN images: these would divide by zero and then
    # hand np.percentile an empty array.
    valid = data[data == data]
    vmin, vmax = np.percentile(valid, [0, 100]) if valid.size else (0.0, 1.0)
    span = vmax - vmin
    if span == 0:
        span = 1.0
    data_ = visualization.AsinhStretch()((data - vmin) / span)

    valid_ = data_[data_ == data_]
    norm = Normalize(*np.percentile(valid_, [0.5, 99.5])) if valid_.size else None
    ax.imshow(
        data_,
        norm=norm,
        aspect="auto",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(cutout_type, fontdict={"fontsize": "small"})
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc