• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 4244003410

pending completion
4244003410

Pull #30

github

GitHub
Merge 8f5dfabb6 into 455fe61b5
Pull Request #30: Initial steps towards a plotting API

433 of 433 new or added lines in 11 files covered. (100.0%)

813 of 1609 relevant lines covered (50.53%)

0.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv'
×
2

3

4
import copy
×
5
import numpy as np
×
6
import matplotlib.pyplot as plt
×
7
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
8

9
from coolest.api.analysis import Analysis
×
10
from coolest.api.light_model import CompositeLightModel
×
11
from coolest.api.util import read_json_param
×
12
from coolest.api.plot_util import nice_colorbar
×
13

14
# matplotlib global settings
15
# Module-level side effect: importing this module changes matplotlib's global
# 'image' rc group, so images drawn afterwards default to interpolation='none'
# and origin='lower' unless a call overrides them explicitly.
plt.rc('image', interpolation='none', origin='lower') # imshow settings
×
16

17

18
class ModelPlotter(object):
    """
    Creates pyplot panels from a lens model stored in the COOLEST format

    INPUT
    -----
    coolest_object: COOLEST model instance, handed as-is to `Analysis` and
        `CompositeLightModel`
    coolest_directory: optional, forwarded to `CompositeLightModel` as its
        second positional argument (presumably the directory that holds the
        COOLEST template files -- to be confirmed against that class)
    """

    def __init__(self, coolest_object, coolest_directory=None):
        self.coolest = coolest_object
        self.analysis = Analysis(self.coolest)
        # work on a copy so that set_bad() does not alter the colormap object
        # returned by plt.get_cmap()
        cmap_flux = copy.copy(plt.get_cmap('magma'))
        cmap_flux.set_bad('black')
        self.cmap_flux = cmap_flux
        self._directory = coolest_directory

    def plot_surface_brightness(self, ax, title=None, coordinates=None,
                                norm=None, cmap=None,
                                **kwargs_selection):
        """
        Draw the surface brightness of the selected light profiles on `ax`

        INPUT
        -----
        ax: matplotlib axes on which the image is drawn
        title: str, optional title of the panel
        coordinates: optional object exposing `pixel_coordinates` (unpacked
            as x, y) and `extent`; when given, the model is evaluated on
            these coordinates, otherwise `surface_brightness()` of the light
            model provides both the values and the extent
        norm: optional matplotlib normalization forwarded to imshow
        cmap: optional colormap; defaults to `self.cmap_flux`
        kwargs_selection: forwarded to `CompositeLightModel`

        OUTPUT
        ------
        the 2D image that has been drawn

        Raises NotImplementedError when the model values are not a 2D array
        (irregular grid), since the Voronoi plot is not available yet.
        """
        light_model = CompositeLightModel(self.coolest, self._directory, **kwargs_selection)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            # the user-supplied colormap must be honoured here
            self._plot_regular_image(ax, image, extent=coordinates.extent,
                                     cmap=cmap, norm=norm)
        else:
            values, extent = light_model.surface_brightness(return_extent=True)
            if isinstance(values, np.ndarray) and values.ndim == 2:
                image = values
                self._plot_regular_image(ax, image, extent=extent,
                                         cmap=cmap, norm=norm)
            else:  # irregular grid
                values = light_model.surface_brightness()
                self._plot_voronoi_image(values)
                image = None
        if title is not None:
            ax.set_title(title)
        return image

    @staticmethod
    def _plot_regular_image(ax, image, **imshow_kwargs):
        """Show `image` on `ax` with imshow, attach a colorbar, return the imshow artist."""
        im = ax.imshow(image, **imshow_kwargs)
        nice_colorbar(im)
        return im

    def _plot_voronoi_image(self, points):
        """
        Placeholder for the plot of models defined on an irregular grid

        This is a regular method: with the former `@staticmethod` decorator
        the call `self._plot_voronoi_image(values)` failed with a TypeError
        before reaching the NotImplementedError below.
        """
        # TODO: incorporate Giorgos' code here
        raise NotImplementedError()
69

70

71

72
class MultiModelPlotter(object):
    """
    Creates pyplot panels from several lens models

    INPUT
    -----
    coolest_objects: sequence of COOLEST model instances, one per model
    coolest_directories: optional sequence with one entry per model, each
        forwarded to the corresponding `ModelPlotter`; defaults to None for
        every model
    """

    def __init__(self, coolest_objects, coolest_directories=None):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        elif len(coolest_directories) < self.num_models:
            # zip() below would silently drop the models without a directory,
            # leaving num_models inconsistent with plotter_list
            raise ValueError("Fewer directories ({}) than models ({})".format(
                len(coolest_directories), self.num_models))
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, titles=None,
                                coordinates=None, norm=None, cmap=None,
                                **kwargs_selection_list):
        """
        Draw the surface brightness of every model, one model per axes

        INPUT
        -----
        axes: sequence of matplotlib axes, one per model
        titles: optional sequence of panel titles, one per model
        coordinates, norm, cmap: shared by all panels and forwarded to
            `ModelPlotter.plot_surface_brightness`
        kwargs_selection_list: each keyword value must be a sequence indexed
            by model; entry i is forwarded to the i-th plotter under the same
            keyword

        OUTPUT
        ------
        list with the value returned by each plotter, in model order
        """
        # NOTE: **kwargs_selection_list is always a dict (possibly empty), so
        # no None fallback is needed for it
        if titles is None:
            titles = self.num_models * [None]
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            # pick the i-th entry of every per-model selection
            kw_select = {key: val[i] for key, val in kwargs_selection_list.items()}
            image = plotter.plot_surface_brightness(ax, coordinates=coordinates,
                                                    title=titles[i],
                                                    norm=norm, cmap=cmap, **kw_select)
            image_list.append(image)
        return image_list
101

102

103

104

105
class Comparison_analytical(object):
    """
    Handles plot of analytical models in a comparative way

    INPUT
    -----
    coolest_file_list: forwarded to `read_json_param` as its first argument
    nickname_file_list: list of names, one per file; used as keys of the
        parameter dictionaries and as legend labels
    posterior_bool_list: list of booleans, one per file; True plots the
        median with 16th/84th percentile error bars, False plots the point
        estimate only
    """

    # number of panels per row of the comparison figure
    _NUM_COLS = 4
    # parameters whose error bars get the +180/-180 edge correction
    _ANGLE_KEYS = ('SHEAR_0_phi_ext', 'PEMD_0_phi')
    # may find a better way to define markers but right now, it is sufficient
    _MARKERS = ('*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D',
                '1', '2', '3', '4', '+')

    def __init__(self, coolest_file_list, nickname_file_list, posterior_bool_list):
        self.file_names = nickname_file_list
        self.posterior_bool_list = posterior_bool_list
        self.param_lens, self.param_source = read_json_param(coolest_file_list, self.file_names, lens_light=False)

    @staticmethod
    def _grid_layout(number_param, num_cols=4):
        """
        Return (num_lines, unused_figs) for `number_param` panels

        num_lines is the number of rows (at least one) and unused_figs lists
        the negative column indices of the empty panels at the end of the
        last row.
        """
        num_lines = max(1, -(-number_param // num_cols))  # ceiling division
        num_unused = num_lines * num_cols - number_param
        unused_figs = [-(k + 1) for k in range(num_unused)]
        return num_lines, unused_figs

    @classmethod
    def _percentile_bounds(cls, key, p):
        """
        Return the (16th, 84th) percentiles of parameter `p` to use for the
        error bars, without modifying `p`

        For the parameters listed in `_ANGLE_KEYS`, a bound lying on the
        wrong side of the median is shifted by 180 so that the error bars
        are correct close to the +180/-180 edge.
        """
        lower = p['percentile_16th']
        upper = p['percentile_84th']
        if key in cls._ANGLE_KEYS:
            median = p['median']
            if lower > median:
                lower = lower - 180.
            if upper < median:
                upper = upper + 180.
        return lower, upper

    def plotting_routine(self, param_dict, idx_file=0):
        """
        plot the parameters

        INPUT
        -----
        param_dict: dict, organized dictionary with all parameter results of the different files
        idx_file: int, chooses the file on which the choice of plotted parameters will be made
        (basically in file 0 you may have a sersic fit and in file 1 sersic+shapelets. If you choose
         idx_file=0, you will plot the sersic results of both files. If you choose idx_file=1, you will plot all the
         sersic and shapelets parameters when available)

        OUTPUT
        ------
        f, ax: the figure and the 2D array of axes (always 2D, even for a single row)
        """
        num_cols = self._NUM_COLS

        # find the number of parameters to plot and define a nice looking figure
        number_param = len(param_dict[self.file_names[idx_file]])
        num_lines, unused_figs = self._grid_layout(number_param, num_cols)

        # squeeze=False keeps `ax` two-dimensional when there is a single row
        f, ax = plt.subplots(num_lines, num_cols,
                             figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)

        labelled = set()  # panels that already have their y label
        for j, file_name in enumerate(self.file_names):
            # cycle through the markers when there are more files than markers
            m = self._MARKERS[j % len(self._MARKERS)]
            result = param_dict[file_name]
            # parameters are assigned to panels by their position in the file
            for i, key in enumerate(result):
                if i >= number_param:
                    # the grid is sized on the reference file: extra
                    # parameters of this file have no panel
                    break
                panel = ax[i // num_cols, i % num_cols]
                p = result[key]
                if self.posterior_bool_list[j]:
                    lower, upper = self._percentile_bounds(key, p)
                    panel.errorbar(j, p['median'],
                                   [[p['median'] - lower], [upper - p['median']]],
                                   marker=m, ls='', label=file_name)
                else:
                    panel.plot(j, p['point_estimate'], marker=m, ls='', label=file_name)

                if i not in labelled:
                    # the first file that fills a panel sets its label
                    panel.get_xaxis().set_visible(False)
                    panel.set_ylabel(p['latex_str'], fontsize=12)
                    panel.tick_params(axis='y', labelsize=12)
                    labelled.add(i)

        ax[0, 0].legend()
        for idx in unused_figs:
            ax[-1, idx].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def plot_source(self, idx_file=0):
        """Plot the source parameters; see `plotting_routine` for `idx_file` and the output."""
        return self.plotting_routine(self.param_source, idx_file)

    def plot_lens(self, idx_file=0):
        """Plot the lens parameters; see `plotting_routine` for `idx_file` and the output."""
        return self.plotting_routine(self.param_lens, idx_file)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc