• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

California-Planet-Search / radvel / #505

16 Oct 2023 04:51PM UTC coverage: 90.038% (-0.03%) from 90.071%
#505

push

coveralls-python

web-flow
Merge pull request #372 from California-Planet-Search/multipanel_remove_whitespace

Multipanel remove whitespace

4 of 4 new or added lines in 1 file covered. (100.0%)

3299 of 3664 relevant lines covered (90.04%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.35
/radvel/plot/orbit_plots.py
1
import numpy as np
2
import pandas as pd
1✔
3
from matplotlib import rcParams, gridspec
1✔
4
from matplotlib import pyplot as pl
1✔
5
from matplotlib.ticker import MaxNLocator
1✔
6
from astropy.time import Time
1✔
7
import copy
1✔
8

9
import radvel
1✔
10
from radvel import plot
1✔
11
from radvel.plot import mcmc_plots
1✔
12
from radvel.utils import t_to_phase, fastbin, sigfig
1✔
13

14

15
class MultipanelPlot(object):
    """
    Class to handle the creation of RV multipanel plots.

    Args:
        post (radvel.Posterior): radvel.Posterior object. The model
            plotted will be generated from `post.params`
        saveplot (str, optional): path to save the figure to. When given,
            the model curves are also sampled more finely. Default: None
            (the figure is not saved).
        epoch (int, optional): epoch to subtract off of all time measurements
        yscale_auto (bool, optional): Use matplotlib auto y-axis
             scaling (default: False)
        yscale_sigma (float, optional): Scale y-axis limits for all panels to be +/-
             yscale_sigma*(standard deviation of data plotted) if yscale_auto==False
        phase_nrows (int, optional): number of rows in the phase
            folded plots. Default is ceil(nplanets / phase_ncols).
        phase_ncols (int, optional): number of columns in the phase
            folded plots. Default is 1.
        uparams (dict, optional): parameter uncertainties, must
           contain 'per<n>', 'k<n>', and 'e<n>' keys for every planet n
           (e.g. 'per1', 'k1', 'e1').
        telfmts (dict, optional): dictionary of dictionaries mapping
            instrument suffix to plotting format code. Example:
                telfmts = {
                     'hires': dict(fmt='o',label='HIRES'),
                     'harps-n': dict(fmt='s')
                }
        legend (bool, optional): include legend on plot? Default: True.
        phase_limits (list, optional): two element list specifying
            pyplot.xlim bounds for phase-folded plots. Useful for
            partial orbits. Default: (-0.5, 0.5).
        nobin (bool, optional): If True do not show binned data on
            phase plots. Will default to True if total number of
            measurements is fewer than 20.
        phasetext_size (string, optional): fontsize for text in phase plots.
            Choice of {'xx-small', 'x-small', 'small', 'medium', 'large',
            'x-large', 'xx-large'}. Default: 'large'.
        rv_phase_space (float, optional): amount of space to leave between orbit/residual plot
            and phase plots, as a fraction of the figure height. Default: 0.08.
        figwidth (float, optional): width of the figures to be produced.
            Default: 7.5 (spans a page with 0.5 in margins)
        fit_linewidth (float, optional): linewidth to use for orbit model lines in the
            timeseries, phase-folded and residuals plots. Default: 2.0.
        set_xlim (list of float): limits to use for x-axes of the timeseries and residuals plots, in
            JD - `epoch`. Ex: [7000., 7005.]
        text_size (int): set matplotlib.rcParams['font.size'] (default: 9)
        highlight_last (bool): make the most recent measurement much larger in all panels
        show_rms (bool): show RMS of the residuals by instrument in the legend
        legend_kwargs (dict): dict of options to pass to legend (plotted in top panel).
            Default: dict(loc='best').
        status (ConfigParser): (optional) result of radvel.driver.load_status on the .stat status file.
            Its ``['derive']['chainfile']`` entry is needed to annotate derived
            parameters in the phase plots; without it those annotations are skipped.
    """
    def __init__(self, post, saveplot=None, epoch=2450000, yscale_auto=False, yscale_sigma=3.0,
                 phase_nrows=None, phase_ncols=None, uparams=None, telfmts=None, legend=True,
                 phase_limits=None, nobin=False, phasetext_size='large', rv_phase_space=0.08,
                 figwidth=7.5, fit_linewidth=2.0, set_xlim=None, text_size=9, highlight_last=False,
                 show_rms=False, legend_kwargs=None, status=None):

        # Build mutable defaults per call so instances never share (and
        # accidentally mutate) one default object.
        if telfmts is None:
            telfmts = {}
        if phase_limits is None:
            phase_limits = []
        if legend_kwargs is None:
            legend_kwargs = dict(loc='best')

        # Work on a copy: the params are converted to the synth basis below.
        self.post = copy.deepcopy(post)
        self.saveplot = saveplot
        self.epoch = epoch
        self.yscale_auto = yscale_auto
        self.yscale_sigma = yscale_sigma
        if phase_ncols is None:
            self.phase_ncols = 1
        else:
            self.phase_ncols = phase_ncols
        if phase_nrows is None:
            self.phase_nrows = int(np.ceil(self.post.likelihood.model.num_planets / self.phase_ncols))
        else:
            self.phase_nrows = phase_nrows
        self.uparams = uparams
        self.rv_phase_space = rv_phase_space
        self.telfmts = telfmts
        self.legend = legend
        self.phase_limits = phase_limits
        self.nobin = nobin
        self.phasetext_size = phasetext_size
        self.figwidth = figwidth
        self.fit_linewidth = fit_linewidth
        self.set_xlim = set_xlim
        self.highlight_last = highlight_last
        self.show_rms = show_rms
        self.legend_kwargs = legend_kwargs
        rcParams['font.size'] = text_size

        # Always define the attribute; plot_phasefold checks it before it
        # reads the derived-parameter chain file.
        self.status = status

        if isinstance(self.post.likelihood, radvel.likelihood.CompositeLikelihood):
            self.like_list = self.post.likelihood.like_list
        else:
            self.like_list = [self.post.likelihood]

        # FIGURE PROVISIONING
        self.ax_rv_height = self.figwidth * 0.6
        self.ax_phase_height = self.ax_rv_height / 1.4

        # convert params to synth basis
        synthparams = self.post.params.basis.to_synth(self.post.params)
        self.post.params = synthparams
        self.post.vector.dict_to_vector()

        self.model = self.post.likelihood.model
        self.rvtimes = self.post.likelihood.x
        self.rverr = self.post.likelihood.errorbars()
        self.num_planets = self.model.num_planets

        self.rawresid = self.post.likelihood.residuals()

        # residuals with the linear/quadratic trend added back in
        self.resid = (
            self.rawresid + self.post.params['dvdt'].value*(self.rvtimes-self.model.time_base)
            + self.post.params['curv'].value*(self.rvtimes-self.model.time_base)**2
        )

        # sample the model more finely when the figure is written to disk
        resolution = 10000 if self.saveplot is not None else 2000

        periods = [synthparams['per%d' % (i+1)].value for i in range(self.num_planets)]
        if periods:
            longp = max(periods)
        else:
            longp = max(self.post.likelihood.x) - min(self.post.likelihood.x)

        # The model grid extends one longest period past the data so the
        # phase-folded model curve covers a full orbit.
        if self.set_xlim is not None:
            self.dt = self.set_xlim[1] - self.set_xlim[0]
            self.rvmodt = np.linspace(
                (self.set_xlim[0]+self.epoch) - 0.05 * self.dt, (self.set_xlim[1]+self.epoch) + 0.05 * self.dt + longp,
                int(resolution)
            )
        else:
            self.dt = max(self.rvtimes) - min(self.rvtimes)
            self.rvmodt = np.linspace(
                min(self.rvtimes) - 0.05 * self.dt, max(self.rvtimes) + 0.05 * self.dt + longp,
                int(resolution)
            )

        self.orbit_model = self.model(self.rvmodt)
        self.rvmod = self.model(self.rvtimes)

        # Times far below `epoch` are taken to be already epoch-subtracted and
        # are plotted as-is; an epoch of 0 is replaced by 2450000.
        if ((self.rvtimes - self.epoch) < -2.4e6).any():
            self.plttimes = self.rvtimes
            self.mplttimes = self.rvmodt
        elif self.epoch == 0:
            self.epoch = 2450000
            self.plttimes = self.rvtimes - self.epoch
            self.mplttimes = self.rvmodt - self.epoch
        else:
            self.plttimes = self.rvtimes - self.epoch
            self.mplttimes = self.rvmodt - self.epoch

        # trend (dvdt + curv) evaluated on the model grid and at the data times
        self.slope = (
            self.post.params['dvdt'].value * (self.rvmodt-self.model.time_base)
            + self.post.params['curv'].value * (self.rvmodt-self.model.time_base)**2
        )
        self.slope_low = (
            self.post.params['dvdt'].value * (self.rvtimes-self.model.time_base)
            + self.post.params['curv'].value * (self.rvtimes-self.model.time_base)**2
        )

        # list for Axes objects
        self.ax_list = []

    def plot_timeseries(self):
        """
        Make a plot of the RV data and model in the current Axes.
        """

        ax = pl.gca()

        ax.axhline(0, color='0.5', linestyle='--')

        if self.show_rms:
            rms_values = {like.suffix: np.std(like.residuals()) for like in self.like_list}
        else:
            rms_values = False

        # plot orbit model
        ax.plot(self.mplttimes, self.orbit_model, 'b-', rasterized=False, lw=self.fit_linewidth)

        # plot data (data = residuals + model)
        vels = self.rawresid+self.rvmod
        plot.mtelplot(
            self.plttimes, vels, self.rverr, self.post.likelihood.telvec, ax, telfmts=self.telfmts,
            rms_values=rms_values
        )

        if self.set_xlim is not None:
            ax.set_xlim(self.set_xlim)
        else:
            ax.set_xlim(min(self.plttimes)-0.01*self.dt, max(self.plttimes)+0.01*self.dt)
        pl.setp(ax.get_xticklabels(), visible=False)

        if self.highlight_last:
            ind = np.argmax(self.plttimes)
            pl.plot(self.plttimes[ind], vels[ind], **plot.highlight_format)

        # legend
        if self.legend:
            ax.legend(numpoints=1, **self.legend_kwargs)

        # Years on the upper axis. Nothing is plotted on the twin axes (it
        # shares the RV y-axis); only its x-limits are set to match.
        axyrs = ax.twiny()
        xl = np.array(list(ax.get_xlim())) + self.epoch
        decimalyear = Time(xl, format='jd', scale='utc').decimalyear
        axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
        axyrs.set_xlim(*decimalyear)
        axyrs.set_xlabel('Year', fontweight='bold')
        pl.locator_params(axis='x', nbins=5)

        if not self.yscale_auto:
            scale = np.std(self.rawresid+self.rvmod)
            ax.set_ylim(-self.yscale_sigma * scale, self.yscale_sigma * scale)

        ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
        # drop the lowest tick so it does not collide with the residuals panel
        ticks = ax.yaxis.get_majorticklocs()
        ax.yaxis.set_ticks(ticks[1:])

    def plot_residuals(self):
        """
        Make a plot of residuals in the current Axes.
        """

        ax = pl.gca()

        # zero line drawn in the model style
        ax.plot(self.mplttimes, self.slope-self.slope, 'b-', lw=self.fit_linewidth)

        plot.mtelplot(self.plttimes, self.rawresid, self.rverr, self.post.likelihood.telvec, ax, telfmts=self.telfmts)
        if not self.yscale_auto:
            scale = np.std(self.rawresid)
            ax.set_ylim(-self.yscale_sigma * scale, self.yscale_sigma * scale)

        if self.highlight_last:
            ind = np.argmax(self.plttimes)
            pl.plot(self.plttimes[ind], self.rawresid[ind], **plot.highlight_format)

        if self.set_xlim is not None:
            ax.set_xlim(self.set_xlim)
        else:
            ax.set_xlim(min(self.plttimes)-0.01*self.dt, max(self.plttimes)+0.01*self.dt)
        ticks = ax.yaxis.get_majorticklocs()
        ax.yaxis.set_ticks([ticks[0], 0.0, ticks[-1]])
        pl.xlabel('JD - {:d}'.format(int(np.round(self.epoch))), weight='bold')
        ax.set_ylabel('Residuals', weight='bold')
        ax.yaxis.set_major_locator(MaxNLocator(5, prune='both'))

    @staticmethod
    def _mass_conversion_factor(unit):
        """
        Return the factor converting a derived planet mass to `unit`.

        The Jupiter factor (0.00315 = M_earth / M_Jup) means the derived
        masses are taken to be in Earth masses. The solar-mass factor must
        therefore chain both ratios; 0.000954265748 alone is M_Jup / M_sun.
        Any other unit is returned unscaled (factor 1).
        """
        if unit == "M$_{\\rm Jup}$":
            return 0.00315
        if unit == "M$_{\\odot}$":
            return 0.00315 * 0.000954265748
        return 1

    def _derived_annotations(self, pnum):
        """
        Build the phase-plot annotation lines for derived parameters of planet `pnum`.

        Reads the derived chain file named in ``self.status['derive']['chainfile']``
        and uses mcmc_plots.DerivedPlot for labels and units. Values are the
        0.500 quantile of ``self.post.derived``; the uncertainty is the mean of
        the distances to the 0.159 and 0.841 quantiles.

        Returns:
            list of str: one annotation line per derived parameter present.
        """
        chains = pd.read_csv(self.status['derive']['chainfile'])
        self.post.nplanets = self.num_planets
        dp = mcmc_plots.DerivedPlot(chains, self.post)
        labels = dp.labels
        texlabels = dp.texlabels
        units = dp.units

        lines = []
        for par in ['mpsini']:
            par_label = par + str(pnum)
            if par_label not in self.post.derived.columns:
                continue
            index = np.where(np.array(labels) == par_label)[0][0]
            conversion_fac = self._mass_conversion_factor(units[index])

            quantiles = self.post.derived["%s%d" % (par, pnum)]
            val = quantiles.loc[0.500] * conversion_fac
            low = quantiles.loc[0.159] * conversion_fac
            high = quantiles.loc[0.841] * conversion_fac
            err = np.mean([val - low, high - val])
            err = radvel.utils.round_sig(err)

            texlabel = texlabels[index].replace("$", "")
            if err > 1e-15:
                val, err, _ = sigfig(val, err)
                lines.append(r'$\mathregular{%s}$ = %s $\mathregular{\pm}$ %s %s'
                             % (texlabel, val, err, units[index]))
            else:
                lines.append(r'$\mathregular{%s}$ = %4.2f %s' % (texlabel, val, units[index]))
        return lines

    def plot_phasefold(self, pltletter, pnum):
        """
        Plot the phase-folded orbit of a single planet in the current Axes.

        Args:
            pltletter (int): integer representation of
                letter to be printed in the corner of the
                phase plot.
                Ex: ord("a") gives 97, so the input should be 97.
            pnum (int): the number of the planet to be plotted. Must be
                the same as the number used to define a planet's
                Parameter objects (e.g. 'per1' is for planet #1)

        """

        ax = pl.gca()

        # binned points are suppressed for small datasets
        if len(self.post.likelihood.x) < 20:
            self.nobin = True

        bin_fac = 1.75
        bin_markersize = bin_fac * rcParams['lines.markersize']
        bin_markeredgewidth = bin_fac * rcParams['lines.markeredgewidth']

        # This planet's signal only: single-planet model minus the trend for
        # the curve, and residuals + single-planet model minus trend for the data.
        rvmod2 = self.model(self.rvmodt, planet_num=pnum) - self.slope
        modph = t_to_phase(self.post.params, self.rvmodt, pnum, cat=True) - 1
        rvdat = self.rawresid + self.model(self.rvtimes, planet_num=pnum) - self.slope_low
        phase = t_to_phase(self.post.params, self.rvtimes, pnum, cat=True) - 1
        # cat=True yields a doubled phase array, so per-point arrays are doubled to match
        rvdatcat = np.concatenate((rvdat, rvdat))
        rverrcat = np.concatenate((self.rverr, self.rverr))
        rvmod2cat = np.concatenate((rvmod2, rvmod2))
        bint, bindat, binerr = fastbin(phase+1, rvdatcat, nbins=25)
        bint -= 1.0

        ax.axhline(0, color='0.5', linestyle='--', )
        ax.plot(sorted(modph), rvmod2cat[np.argsort(modph)], 'b-', linewidth=self.fit_linewidth)
        plot.labelfig(pltletter)

        telcat = np.concatenate((self.post.likelihood.telvec, self.post.likelihood.telvec))

        if self.highlight_last:
            ind = np.argmax(self.rvtimes)
            hphase = t_to_phase(self.post.params, self.rvtimes[ind], pnum, cat=False)
            # wrap into the default [-0.5, 0.5] plotting window
            if hphase > 0.5:
                hphase -= 1
            pl.plot(hphase, rvdatcat[ind], **plot.highlight_format)

        plot.mtelplot(phase, rvdatcat, rverrcat, telcat, ax, telfmts=self.telfmts)
        if not self.nobin and len(rvdat) > 10:
            ax.errorbar(
                bint, bindat, yerr=binerr, fmt='ro', mec='w', ms=bin_markersize,
                mew=bin_markeredgewidth
            )

        if self.phase_limits:
            ax.set_xlim(self.phase_limits[0], self.phase_limits[1])
        else:
            ax.set_xlim(-0.5, 0.5)

        if not self.yscale_auto:
            scale = np.std(rvdatcat)
            ax.set_ylim(-self.yscale_sigma*scale, self.yscale_sigma*scale)

        print_params = ['per', 'k', 'e']
        keys = [p + str(pnum) for p in print_params]
        labels = [self.post.params.tex_labels().get(k, k) for k in keys]
        # Trim the outer y ticks on all but the last planet so stacked panels
        # do not show overlapping tick labels.
        if pnum < self.num_planets:
            ticks = ax.yaxis.get_majorticklocs()
            ax.yaxis.set_ticks(ticks[1:-1])

        ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
        ax.set_xlabel('Phase', weight='bold')

        units = {'per': 'days', 'k': plot.latex['ms'], 'e': ''}

        anotext = []
        for label, p in zip(labels, print_params):
            key = "%s%d" % (p, pnum)
            texlabel = label.replace("$", "")
            val = self.post.params[key].value

            if self.uparams is None:
                _anotext = r'$\mathregular{%s}$ = %4.2f %s' % (texlabel, val, units[p])
            else:
                if hasattr(self.post, 'medparams'):
                    val = self.post.medparams[key]
                else:
                    print("WARNING: medparams attribute not found in " +
                          "posterior object will annotate with " +
                          "max-likelihood values and reported uncertainties " +
                          "may not be appropriate.")
                err = self.uparams[key]
                if err > 1e-15:
                    val, err, _ = sigfig(val, err)
                    _anotext = r'$\mathregular{%s}$ = %s $\mathregular{\pm}$ %s %s' \
                               % (texlabel, val, err, units[p])
                else:
                    _anotext = r'$\mathregular{%s}$ = %4.2f %s' % (texlabel, val, units[p])

            anotext.append(_anotext)

        if hasattr(self.post, 'derived'):
            # the derived chain file path comes from the status file
            if getattr(self, 'status', None) is None:
                print("WARNING: posterior has derived parameters but no status "
                      "was provided; derived parameters will not be annotated.")
            else:
                anotext += self._derived_annotations(pnum)

        anotext = '\n'.join(anotext)
        plot.add_anchored(
            anotext, loc=1, frameon=True, prop=dict(size=self.phasetext_size, weight='bold'),
            bbox=dict(ec='none', fc='w', alpha=0.8)
        )

    def plot_multipanel(self, nophase=False, letter_labels=True):
        """
        Provision and plot an RV multipanel plot

        Args:
            nophase (bool, optional): if True, don't
                include phase plots. Default: False.
            letter_labels (bool, optional): if True, include
                letter labels on orbit and residual plots.
                Default: True. When False those two panels are
                unlabeled and the phase plots are lettered from 'a'.

        Returns:
            tuple containing:
                - current matplotlib Figure object
                - list of Axes objects
        """

        if nophase:
            scalefactor = 1
        else:
            scalefactor = self.phase_nrows

        figheight = self.ax_rv_height + self.ax_phase_height * scalefactor

        # provision figure
        fig = pl.figure(figsize=(self.figwidth, figheight))

        fig.subplots_adjust(left=0.12, right=0.95)
        gs_rv = gridspec.GridSpec(2, 1, height_ratios=[1., 0.5])

        # figure-fraction boundary between the RV/residual panels and the phase panels
        divide = 1 - self.ax_rv_height / figheight
        gs_rv.update(left=0.12, right=0.93, top=0.93,
                     bottom=divide+self.rv_phase_space*0.5, hspace=0.)

        # Defined unconditionally: the phase plots below always need a letter,
        # even when the orbit and residual panels are not labeled.
        pltletter = ord('a')

        # orbit plot
        ax_rv = pl.subplot(gs_rv[0, 0])
        self.ax_list += [ax_rv]

        pl.sca(ax_rv)
        self.plot_timeseries()
        if letter_labels:
            plot.labelfig(pltletter)
            pltletter += 1

        # residuals
        ax_resid = pl.subplot(gs_rv[1, 0])
        self.ax_list += [ax_resid]

        pl.sca(ax_resid)
        self.plot_residuals()
        if letter_labels:
            plot.labelfig(pltletter)
            pltletter += 1

        # phase-folded plots
        if not nophase:
            gs_phase = gridspec.GridSpec(max([1, self.phase_nrows]), max([1, self.phase_ncols]))

            if self.phase_ncols == 1:
                gs_phase.update(left=0.12, right=0.93,
                                top=divide - self.rv_phase_space * 0.5,
                                bottom=0.07, hspace=0.003)
            else:
                gs_phase.update(left=0.12, right=0.93,
                                top=divide - self.rv_phase_space * 0.5,
                                bottom=0.07, hspace=0.25, wspace=0.25)

            for i in range(self.num_planets):
                i_row = int(i / self.phase_ncols)
                i_col = int(i - i_row * self.phase_ncols)
                ax_phase = pl.subplot(gs_phase[i_row, i_col])
                self.ax_list += [ax_phase]

                pl.sca(ax_phase)
                self.plot_phasefold(pltletter, i+1)
                pltletter += 1

        fig.tight_layout()
        if self.saveplot is not None:
            pl.savefig(self.saveplot, dpi=150)
            print("RV multi-panel plot saved to %s" % self.saveplot)

        return fig, self.ax_list
494

495

496
class GPMultipanelPlot(MultipanelPlot):
1✔
497
    """
498
    Class to handle the creation of RV multipanel plots for posteriors fitted
499
    using Gaussian Processes. 
500

501
    Takes the same args as MultipanelPlot, with a few additional bells and whistles...
502
    
503
    Args:
504
        subtract_gp_mean_model (bool, optional): if True, subtract the Gaussian
505
            process mean max likelihood model from the data and the
506
            model when plotting the results. Default: False.
507
        plot_likelihoods_separately (bool, optional): if True, plot a separate
508
            panel for each Likelihood object. Default: False
509
        subtract_orbit_model (bool, optional): if True, subtract the best-fit
510
            orbit model from the data and the model when plotting 
511
            the results. Useful for seeing the structure of correlated
512
            noise in the data. Default: False.
513
        status (ConfigParser): (optional) result of radvel.driver.load_status on the .stat status file
514

515
    """
516
    def __init__(self, post, saveplot=None, epoch=2450000, yscale_auto=False, yscale_sigma=3.0,
                 phase_nrows=None, phase_ncols=None, uparams=None, rv_phase_space=0.08, telfmts=None,
                 legend=True,
                 phase_limits=None, nobin=False, phasetext_size='large',  figwidth=7.5, fit_linewidth=2.0,
                 set_xlim=None, text_size=9, legend_kwargs=None, subtract_gp_mean_model=False,
                 plot_likelihoods_separately=False, subtract_orbit_model=False, status=None, separate_orbit_gp=False):
        """
        Set up a GP multipanel plot. Arguments shared with MultipanelPlot are
        forwarded to it unchanged.

        Additional args:
            separate_orbit_gp (bool, optional): if True, draw the GP mean, the
                orbit model and their sum as three separate labeled lines in
                the timeseries panel. Default: False.

        Raises:
            AssertionError: if the Posterior contains no GPLikelihood.
        """
        # Fresh containers per call; the parent class stores these objects
        # directly, so a shared default would leak between instances.
        if telfmts is None:
            telfmts = {}
        if phase_limits is None:
            phase_limits = []
        if legend_kwargs is None:
            legend_kwargs = dict(loc='best')

        super(GPMultipanelPlot, self).__init__(
            post, saveplot=saveplot, epoch=epoch, yscale_auto=yscale_auto,
            yscale_sigma=yscale_sigma, phase_nrows=phase_nrows, phase_ncols=phase_ncols,
            uparams=uparams, rv_phase_space=rv_phase_space, telfmts=telfmts, legend=legend,
            phase_limits=phase_limits, nobin=nobin, phasetext_size=phasetext_size,
            figwidth=figwidth, fit_linewidth=fit_linewidth, set_xlim=set_xlim, text_size=text_size,
            legend_kwargs=legend_kwargs
        )

        self.subtract_gp_mean_model = subtract_gp_mean_model
        self.plot_likelihoods_separately = plot_likelihoods_separately
        self.subtract_orbit_model = subtract_orbit_model
        self.separate_orbit_gp = separate_orbit_gp
        if status is not None:
            self.status = status

        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept so callers see the same exception type.
        is_gp = any(isinstance(like, radvel.likelihood.GPLikelihood) for like in self.like_list)
        if not is_gp:
            raise AssertionError("This class requires at least one GPLikelihood object in the Posterior.")
547

548
    def plot_gp_like(self, like, orbit_model4data, ci):
        """
        Plot a single Gaussian Process Likelihood object in the current Axes,
        including Gaussian Process uncertainty bands.

        Args:
            like (radvel.GPLikelihood): radvel.GPLikelihood object. The model
                plotted will be generated from `like.params`.
            orbit_model4data (numpy array): model values at the data times of
                the full Posterior (one entry per measurement). Unless
                `subtract_gp_mean_model` is set, the GP mean predicted at
                `like.x` is added IN PLACE to the entries whose telescope
                matches `like.suffix`, so the caller can plot the data
                against orbit + GP.
            ci (int): index to use when choosing a color to plot from
                radvel.plot.default_colors. This is only used if no color is
                configured for the Likelihood object's telescope (neither in
                `telfmts` nor in the defaults). Increments by 1 if it is used.

        Returns: current (possibly changed) value of the input `ci`

        """
        ax = pl.gca()

        if isinstance(like, radvel.likelihood.GPLikelihood):

            if self.set_xlim is not None:
                xpred = np.linspace(self.set_xlim[0]+self.epoch, self.set_xlim[1]+self.epoch, num=int(3e3))
            else:
                xpred = np.linspace(np.min(like.x), np.max(like.x), num=int(3e3))

            gpmu, stddev = like.predict(xpred)
            if self.subtract_orbit_model:
                gp_orbit_model = np.zeros(xpred.shape)
            else:
                gp_orbit_model = self.model(xpred)

            # Convert to plotting times the same way the data times are:
            # times far below `epoch` are already reduced and are left alone.
            if not ((xpred - self.epoch) < -2.4e6).any():
                if self.epoch == 0:
                    self.epoch = 2450000
                xpred = xpred - self.epoch

            if self.subtract_gp_mean_model:
                # An array rather than the scalar 0., so the separate 'GP'
                # line below has the same length as xpred.
                gpmu = np.zeros(xpred.shape)
            else:
                gp_mean4data, _ = like.predict(like.x)
                indx = np.where(self.post.likelihood.telvec == like.suffix)
                orbit_model4data[indx] += gp_mean4data

            # Pick the band color. A user telfmts entry without a 'color' key
            # (e.g. dict(fmt='o', label='HIRES')) falls back to the defaults
            # instead of raising KeyError.
            user_fmt = self.telfmts.get(like.suffix, {})
            if 'color' in user_fmt:
                color = user_fmt['color']
            elif like.suffix in plot.telfmts_default:
                color = plot.telfmts_default[like.suffix]['color']
            else:
                color = plot.default_colors[ci]
                ci += 1

            ax.fill_between(xpred, gpmu+gp_orbit_model-stddev, gpmu+gp_orbit_model+stddev,
                            color=color, alpha=0.5, lw=0
                            )
            if self.separate_orbit_gp:
                ax.plot(xpred, gpmu, '-', color='orange', rasterized=False, lw=0.2, label='GP')
                ax.plot(xpred, gp_orbit_model, 'g-', rasterized=False, lw=0.2, label="Orbit")
                ax.plot(xpred, gpmu+gp_orbit_model, 'b-', rasterized=False, lw=0.4, label="Orbit+GP")
            else:
                ax.plot(xpred, gpmu+gp_orbit_model, 'b-', rasterized=False, lw=0.4)
        else:
            # plot orbit model
            ax.plot(self.mplttimes, self.orbit_model, 'b-', rasterized=False, lw=0.1)

        if not self.yscale_auto:
            scale = np.std(self.rawresid+self.rvmod)
            ax.set_ylim(-self.yscale_sigma * scale, self.yscale_sigma * scale)

        ax.set_ylabel('RV [{ms:}]'.format(**plot.latex), weight='bold')
        ticks = ax.yaxis.get_majorticklocs()
        ax.yaxis.set_ticks(ticks[1:])
        ax.xaxis.set_ticks([])

        return ci
626

627
    def plot_timeseries(self):
        """
        Make a plot of the RV data and Gaussian Process + orbit model in the current Axes.

        Each likelihood's GP band and model are drawn by plot_gp_like. The data
        are then plotted as residuals plus the model accumulated in
        `orbit_model4data`: the orbit model, or zeros when
        `subtract_orbit_model` is set, plus the GP means added by plot_gp_like.
        """

        ax = pl.gca()

        ax.axhline(0, color='0.5', linestyle='--')

        if self.subtract_orbit_model:
            orbit_model4data = np.zeros(self.rvmod.shape)
        else:
            # NOTE(review): this aliases self.rvmod, so plot_gp_like's in-place
            # GP-mean additions persist in self.rvmod (and feed its y-limit
            # calculation); repeated calls accumulate them. Confirm intent
            # before switching to a copy.
            orbit_model4data = self.rvmod

        ci = 0
        for like in self.like_list:
            ci = self.plot_gp_like(like, orbit_model4data, ci)

        # plot data (data = residuals + model)
        plot.mtelplot(
            self.plttimes, self.rawresid+orbit_model4data, self.rverr,
            self.post.likelihood.telvec, ax, telfmts=self.telfmts
        )
        if self.set_xlim is not None:
            ax.set_xlim(self.set_xlim)
        else:
            ax.set_xlim(min(self.plttimes)-0.01*self.dt, max(self.plttimes)+0.01*self.dt)
        pl.setp(ax.get_xticklabels(), visible=False)

        # legend
        if self.legend:
            ax.legend(numpoints=1, **self.legend_kwargs)

        # Years on the upper axis. Nothing is plotted on the twin axes: it
        # shares the RV y-axis, so plotting (year, year) points would drag
        # y-autoscaling out to ~2000 when yscale_auto is True. Only the
        # x-limits are set, matching the base class.
        axyrs = ax.twiny()
        xl = np.array(list(ax.get_xlim())) + self.epoch
        decimalyear = Time(xl, format='jd', scale='utc').decimalyear
        axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
        axyrs.set_xlim(*decimalyear)
        pl.locator_params(axis='x', nbins=5)
        axyrs.set_xlabel('Year', fontweight='bold')
670

671

672
    def plot_multipanel(self, nophase=False):
        """
        Provision and plot an RV multipanel plot for a Posterior object containing
        one or more Gaussian Process Likelihood objects.

        If ``self.plot_likelihoods_separately`` is False, this defers entirely
        to the parent class's ``plot_multipanel``. Otherwise each likelihood
        gets its own time-series panel, followed by a shared residuals panel
        and, unless ``nophase`` is True, one phase-folded panel per planet.
        ``self.set_xlim`` is honored for the time-series panels, as in
        ``plot_timeseries``.

        Args:
            nophase (bool, optional): if True, don't
                include phase plots. Default: False.
        Returns:
            tuple containing:
                - current matplotlib Figure object
                - list of Axes objects
        """
        if not self.plot_likelihoods_separately:
            # Forward ``nophase`` and propagate the parent's (fig, axes) result;
            # both were previously dropped, so this path returned None.
            return super(GPMultipanelPlot, self).plot_multipanel(nophase=nophase)

        scalefactor = 1 if nophase else self.phase_nrows

        n_likes = len(self.like_list)
        # NOTE(review): the RV-panel term divides the number of likelihood
        # panels by phase_ncols, which only governs the phase-plot grid;
        # confirm this is intended.
        figheight = (self.ax_rv_height * (n_likes // self.phase_ncols + 1.5)
                     + self.ax_phase_height * scalefactor)

        # provision figure
        fig = pl.figure(figsize=(self.figwidth, figheight))
        fig.subplots_adjust(left=0.12, right=0.95)

        # One full-height row per likelihood plus a half-height residuals row.
        hrs = np.zeros(n_likes + 1) + 1.
        hrs[-1] = 0.5
        gs_rv = gridspec.GridSpec(n_likes + 1, 1, height_ratios=hrs)

        # Figure-fraction boundary between the time-series grid (above) and
        # the phase-folded grid (below).
        divide = 1 - self.ax_rv_height * n_likes / figheight
        gs_rv.update(left=0.12, right=0.93, top=0.93,
                     bottom=divide + self.rv_phase_space * 0.5, hspace=0.0)

        # Model the residuals are added back onto; identical for every panel.
        if self.subtract_orbit_model:
            orbit_model4data = np.zeros(self.rvmod.shape)
        else:
            orbit_model4data = self.rvmod

        # Shared x-limits for all time-series panels (same rule as plot_timeseries).
        if self.set_xlim is not None:
            xlim = self.set_xlim
        else:
            xlim = (min(self.plttimes) - 0.01 * self.dt,
                    max(self.plttimes) + 0.01 * self.dt)

        # orbit plot for each likelihood
        pltletter = ord('a')
        ci = 0
        for i, like in enumerate(self.like_list):
            ax = pl.subplot(gs_rv[i, 0])
            self.ax_list += [ax]
            pl.sca(ax)

            ax.axhline(0, color='0.5', linestyle='--')

            # plot_gp_like returns the possibly-advanced colour index; keep it
            # so successive likelihoods do not all reuse the first colour.
            ci = self.plot_gp_like(like, orbit_model4data, ci)

            # plot data: residuals + model
            plot.mtelplot(
                self.plttimes, self.rawresid + orbit_model4data, self.rverr,
                self.post.likelihood.telvec, ax, telfmts=self.telfmts
            )

            ax.set_xlim(xlim)
            pl.setp(ax.get_xticklabels(), visible=False)

            # Legend and calendar-year axis go on the top panel only.
            if i == 0:
                if self.legend:
                    ax.legend(numpoints=1, **self.legend_kwargs)

                axyrs = ax.twiny()
                xl = np.array(list(ax.get_xlim())) + self.epoch
                decimalyear = Time(xl, format='jd', scale='utc').decimalyear
                axyrs.plot(decimalyear, decimalyear)
                axyrs.get_xaxis().get_major_formatter().set_useOffset(False)
                axyrs.set_xlim(*decimalyear)
                axyrs.set_xlabel('Year', fontweight='bold')

            plot.labelfig(pltletter)
            pltletter += 1

        # residuals
        ax_resid = pl.subplot(gs_rv[-1, 0])
        self.ax_list += [ax_resid]

        pl.sca(ax_resid)
        self.plot_residuals()
        plot.labelfig(pltletter)
        pltletter += 1

        # phase-folded plots
        if not nophase:
            gs_phase = gridspec.GridSpec(self.phase_nrows, self.phase_ncols)

            if self.phase_ncols == 1:
                gs_phase.update(left=0.12, right=0.93,
                                top=divide - self.rv_phase_space * 0.5,
                                bottom=0.07, hspace=0.003)
            else:
                gs_phase.update(left=0.12, right=0.93,
                                top=divide - self.rv_phase_space * 0.5,
                                bottom=0.07, hspace=0.25, wspace=0.25)

            for ipl in range(self.num_planets):
                # Fill the phase grid row by row.
                i_row, i_col = divmod(ipl, self.phase_ncols)
                ax_phase = pl.subplot(gs_phase[i_row, i_col])
                self.ax_list += [ax_phase]

                pl.sca(ax_phase)
                self.plot_phasefold(pltletter, ipl + 1)
                pltletter += 1

        if self.saveplot is not None:
            pl.savefig(self.saveplot, dpi=150)
            print("RV multi-panel plot saved to %s" % self.saveplot)

        return fig, self.ax_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc