• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

desy-multimessenger / nuztf / 4084330156

pending completion
4084330156

push

github-actions

simeonreusch
bump ztfquery to fix unittest errors

1846 of 2266 relevant lines covered (81.47%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.52
/nuztf/plot.py
1
#!/usr/bin/env python3
2
# License: BSD-3-Clause
3

4
import gzip
1✔
5
import io
1✔
6
import numpy as np
1✔
7
import pandas as pd
1✔
8

9
import matplotlib.pyplot as plt
1✔
10
from matplotlib.ticker import MultipleLocator
1✔
11
from matplotlib.colors import Normalize
1✔
12
from base64 import b64decode
1✔
13

14
from astropy.time import Time
1✔
15
from astropy import units as u
1✔
16
from astropy.io import fits
1✔
17
from astropy import visualization
1✔
18
from ztfquery.utils.stamps import get_ps_stamp
1✔
19
from nuztf.utils import cosmo
1✔
20

21
from nuztf.cat_match import get_cross_match_info
1✔
22
from nuztf.ampel_api import ensure_cutouts
1✔
23

24

25
def alert_to_pandas(alert):
    """Split an alert into detection and upper-limit DataFrames.

    The latest ``candidate`` and every entry of ``prv_candidates`` become one
    row each, and an ``mjd`` column is derived from ``jd``. Entries with
    non-null ``magpsf`` and ``isdiffpos`` are treated as detections; all other
    entries are treated as upper limits.

    :param alert: list whose first element is the alert dictionary, holding a
        ``candidate`` dictionary and optionally a ``prv_candidates`` list
    :return: tuple ``(df_detections, df_ulims)``; ``df_ulims`` is ``None`` if
        there are no upper limits. When the latest candidate is a detection it
        is the first row of ``df_detections``.
    """
    alert_dict = alert[0]
    # NOTE(review): prv_candidates may be absent or None for objects without
    # history; treat both as "no previous candidates".
    previous = alert_dict.get("prv_candidates") or []
    combined = [alert_dict["candidate"], *previous]

    detections = []
    ulims = []

    for cand in combined:
        row = pd.DataFrame.from_dict(cand, orient="index").transpose()
        row["mjd"] = row["jd"] - 2400000.5
        # Upper limits may carry magpsf/isdiffpos keys with null values, so
        # test the values rather than just the presence of the keys.
        if cand.get("magpsf") is not None and cand.get("isdiffpos") is not None:
            detections.append(row)
        else:
            ulims.append(row)

    df_detections = pd.concat(detections)
    df_ulims = pd.concat(ulims) if ulims else None

    return df_detections, df_ulims
50

51

52
def lightcurve_from_alert(
    alert: list,
    figsize: tuple = (8, 5),
    title: str = None,
    include_ulims: bool = True,
    include_cutouts: bool = True,
    include_crossmatch: bool = True,
    mag_range: list = None,
    z: float = None,
    legend: bool = False,
    grid_interval: int = None,
    t_0_mjd: float = None,
    logger=None,
):
    """Plot an AMPEL alert as a lightcurve.

    :param alert: list whose first element is the alert dictionary
        (``objectId``, ``candidate``, ``prv_candidates`` and optionally the
        cutouts)
    :param figsize: figure size passed to ``plt.figure``
    :param title: figure title; only drawn when ``include_cutouts`` is False
        (defaults to the object name). With cutouts the name is shown in the
        info box instead.
    :param include_ulims: also plot upper limits (``diffmaglim``) as triangles
    :param include_cutouts: add science/template/difference cutouts, a PS1
        stamp and an info box; missing cutouts are fetched via
        ``ensure_cutouts``
    :param include_crossmatch: add the text returned by
        ``get_cross_match_info``
    :param mag_range: explicit magnitude range for the y-axis; by default the
        detected ``magpsf`` range padded by 0.3 mag
    :param z: redshift; if given (and not NaN) a right-hand axis with absolute
        magnitude, derived from the luminosity distance, is added
    :param legend: draw a legend on the lightcurve axis, one entry per
        distinct label
    :param grid_interval: spacing of the major ticks on the MJD axis
    :param t_0_mjd: reference time drawn as a dotted vertical line; it is
        always included in the x-range
    :param logger: logger to use; defaults to ``logging.getLogger(__name__)``
    :return: ``(fig, axes)`` with ``axes = [lc_ax1, lc_ax2]`` (MJD axis and
        its calendar-date twin), plus the absolute-magnitude axis when a
        redshift is used
    """
    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    if z is not None and np.isnan(z):
        z = None
        logger.debug("Redshift is nan, will be ignored")

    # ZTF color and naming scheme
    BAND_NAMES = {1: "ZTF g", 2: "ZTF r", 3: "ZTF i"}
    BAND_COLORS = {1: "green", 2: "red", 3: "orange"}

    name = alert[0]["objectId"]
    candidate = alert[0]["candidate"]

    if include_cutouts:
        if "cutoutScience" in alert[0]:
            if "stampData" in alert[0]["cutoutScience"]:
                logger.debug(f"{name}: Cutouts are present.")
            else:
                logger.debug(f"{name}: Cutouts are missing data. Will obtain them")
                alert = ensure_cutouts(alert, logger=logger)
        else:
            logger.debug(
                "The alert dictionary does not contain cutouts. Will obtain them."
            )
            alert = ensure_cutouts(alert, logger=logger)

    logger.debug(f"Plotting {name}")

    df, df_ulims = alert_to_pandas(alert)

    fig = plt.figure(figsize=figsize)

    if include_cutouts:
        lc_ax1 = fig.add_subplot(5, 4, (9, 19))
        cutoutsci = fig.add_subplot(5, 4, (1, 5))
        cutouttemp = fig.add_subplot(5, 4, (2, 6))
        cutoutdiff = fig.add_subplot(5, 4, (3, 7))
        cutoutps1 = fig.add_subplot(5, 4, (4, 8))
    else:
        lc_ax1 = fig.add_subplot(1, 1, 1)
        fig.subplots_adjust(top=0.8, bottom=0.15)

    fig.subplots_adjust(wspace=0.4, hspace=1.8)

    if include_cutouts:
        for ax_, type_ in zip(
            [cutoutsci, cutouttemp, cutoutdiff],
            ["Science", "Template", "Difference"],
        ):
            create_stamp_plot(alert=alert[0], ax=ax_, cutout_type=type_)

        img = get_ps_stamp(
            candidate["ra"], candidate["dec"], size=240, color=["y", "g", "i"]
        )
        cutoutps1.imshow(np.asarray(img))
        cutoutps1.set_title("PS1", fontdict={"fontsize": "small"})
        cutoutps1.set_xticks([])
        cutoutps1.set_yticks([])

    # If redshift is given, calculate absolute magnitude via luminosity distance
    # and plot as right axis
    if z is not None:
        dist_l = cosmo.luminosity_distance(z).to(u.pc).value
        # Distance modulus m - M = 5 * (log10(d / pc) - 1)
        dist_mod = 5 * (np.log10(dist_l) - 1)

        def mag_to_absmag(mag):
            return mag - dist_mod

        def absmag_to_mag(absmag):
            return absmag + dist_mod

        lc_ax3 = lc_ax1.secondary_yaxis(
            "right", functions=(mag_to_absmag, absmag_to_mag)
        )

        if not include_cutouts:
            lc_ax3.set_ylabel("Absolute Magnitude [AB]")

    # Give the figure a title (with cutouts, the name is in the info box)
    if not include_cutouts:
        fig.suptitle(f"{name}" if title is None else title, fontweight="bold")

    if grid_interval is not None:
        lc_ax1.xaxis.set_major_locator(MultipleLocator(grid_interval))

    lc_ax1.grid(visible=True, axis="both", alpha=0.5)
    lc_ax1.set_ylabel("Magnitude [AB]")

    if not include_cutouts:
        lc_ax1.set_xlabel("MJD")

    # Determine magnitude limits (inverted y-axis: brighter is up)
    if mag_range is None:
        max_mag = np.max(df.magpsf.values) + 0.3
        min_mag = np.min(df.magpsf.values) - 0.3
        lc_ax1.set_ylim([max_mag, min_mag])
    else:
        lc_ax1.set_ylim([np.max(mag_range), np.min(mag_range)])

    for fid in BAND_NAMES:
        # Plot older datapoints; row 0 is the alert itself, plotted below
        df_temp = df.iloc[1:].query("fid == @fid")
        lc_ax1.errorbar(
            df_temp["mjd"],
            df_temp["magpsf"],
            df_temp["sigmapsf"],
            color=BAND_COLORS[fid],
            fmt=".",
            label=BAND_NAMES[fid],
            mec="black",
            mew=0.5,
        )

        # Plot upper limits
        if df_ulims is not None and include_ulims:
            df_temp2 = df_ulims.query("fid == @fid")
            lc_ax1.scatter(
                df_temp2["mjd"],
                df_temp2["diffmaglim"],
                c=BAND_COLORS[fid],
                marker="v",
                s=1.3,
                alpha=0.5,
            )

    # Plot datapoint from alert, highlighted with a larger marker
    latest = df.iloc[0]
    fid = latest["fid"]
    lc_ax1.errorbar(
        latest["mjd"],
        latest["magpsf"],
        latest["sigmapsf"],
        color=BAND_COLORS[fid],
        fmt=".",
        label=BAND_NAMES[fid],
        mec="black",
        mew=0.5,
        markersize=12,
    )

    if legend:
        # Attach the legend to the lightcurve axis explicitly: pyplot's
        # "current" axes is the last subplot added (the PS1 cutout when
        # cutouts are drawn). Each band label is used twice (history and
        # latest point), so keep only the first handle per label.
        handles, labels = lc_ax1.get_legend_handles_labels()
        unique = {}
        for handle, label in zip(handles, labels):
            unique.setdefault(label, handle)
        lc_ax1.legend(list(unique.values()), list(unique.keys()))

    # Now we create an infobox
    if include_cutouts:
        fig.text(
            0.77,
            0.55,
            _infobox_text(name, candidate),
            va="top",
            fontsize="medium",
            alpha=0.5,
        )

    if include_crossmatch:
        xmatch_info = get_cross_match_info(alert[0])
        ypos = 0.975 if include_cutouts else 0.035

        fig.text(
            0.5,
            ypos,
            xmatch_info,
            va="top",
            ha="center",
            fontsize="medium",
            alpha=0.5,
        )

    if t_0_mjd is not None:
        lc_ax1.axvline(t_0_mjd, linestyle=":")
    else:
        t_0_mjd = np.mean(df.mjd.values)

    # Ugly hack because secondary_axis does not work with astropy.time.Time
    # datetime conversion: use a twin x-axis holding invisible points at the
    # MJD extremes, expressed as datetimes.
    mjd_min = min(np.min(df.mjd.values), t_0_mjd)
    mjd_max = max(np.max(df.mjd.values), t_0_mjd)
    length = mjd_max - mjd_min

    lc_ax1.set_xlim([mjd_min - (length / 20), mjd_max + (length / 20)])

    lc_ax2 = lc_ax1.twiny()

    datetimes = [Time(x, format="mjd").datetime for x in [mjd_min, mjd_max]]
    lc_ax2.scatter(datetimes, [20, 20], alpha=0)

    lc_ax2.tick_params(axis="both", which="major", labelsize=6, rotation=45)
    lc_ax1.tick_params(axis="x", which="major", labelsize=6, rotation=45)
    lc_ax1.ticklabel_format(axis="x", style="plain")
    lc_ax1.tick_params(axis="y", which="major", labelsize=9)

    axes = [lc_ax1, lc_ax2]
    if z is not None:
        lc_ax3.tick_params(axis="both", which="major", labelsize=9)
        axes.append(lc_ax3)

    return fig, axes


def _infobox_text(name: str, candidate: dict) -> str:
    """Build the newline-separated text of the info box next to the lightcurve.

    :param name: object name shown on the first line
    :param candidate: alert candidate dictionary; ``ra``, ``dec`` and either
        ``drb`` or ``rb`` are required, while ``sgscore1``, ``distpsnr1`` and
        ``srmag1`` may be missing or ``None`` (shown as ``N/A``)
    :return: info box text
    """
    separator = "------------------------"
    info = [
        name,
        separator,
        f"RA: {candidate['ra']:.8f}",
        f"Dec: {candidate['dec']:.8f}",
    ]
    if candidate.get("drb") is not None:
        info.append(f"drb: {candidate['drb']:.3f}")
    else:
        info.append(f"rb: {candidate['rb']:.3f}")
    info.append(separator)

    for entry in ["sgscore1", "distpsnr1", "srmag1"]:
        value = candidate.get(entry)
        # A missing/None value would make the :.3f format raise a TypeError
        formatted = "N/A" if value is None else f"{value:.3f}"
        info.append(f"{entry[:-1]}: {formatted}")

    return "\n".join(info)
295

296

297
def create_stamp_plot(alert: dict, ax, cutout_type: str):
    """Helper function to create a cutout subplot.

    Decodes a base64-encoded, gzipped FITS stamp from the alert, applies an
    asinh stretch and draws it on ``ax`` without ticks.

    :param alert: alert dictionary holding the stamp either under
        ``alert[f"cutout{cutout_type}"]["stampData"]`` or, if that key is
        missing/None, under the v3-style key
        ``alert[f"cutoutCutout{...}"]["stampData"]["stampData"]``
    :param ax: matplotlib axes to draw into
    :param cutout_type: ``"Science"``, ``"Template"`` or ``"Difference"``;
        also used as the subplot title
    """
    v3_cutout_names = {
        "Science": "Cutoutscience",
        "Template": "Cutouttemplate",
        "Difference": "Cutoutdifference",
    }

    # Keep the human-readable name for the title; the key below may be remapped
    title = cutout_type

    if alert.get(f"cutout{cutout_type}") is None:
        v3_key = f"cutout{v3_cutout_names[cutout_type]}"
        encoded = alert[v3_key]["stampData"]["stampData"]
    else:
        encoded = alert[f"cutout{cutout_type}"]["stampData"]

    with gzip.open(io.BytesIO(b64decode(encoded)), "rb") as f:
        fits_bytes = f.read()

    # Close the HDUList; the data array is loaded while the file is open
    with fits.open(io.BytesIO(fits_bytes), ignore_missing_simple=True) as hdul:
        data = hdul[0].data

    # data == data drops NaN pixels before computing the value range
    vmin, vmax = np.percentile(data[data == data], [0, 100])
    span = vmax - vmin
    if span == 0:
        # Flat stamp: avoid 0/0, which would turn every pixel into NaN
        span = 1.0
    data_ = visualization.AsinhStretch()((data - vmin) / span)
    ax.imshow(
        data_,
        norm=Normalize(*np.percentile(data_[data_ == data_], [0.5, 99.5])),
        aspect="auto",
    )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontdict={"fontsize": "small"})
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc