• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Open-MSS / MSS / 10831517141

12 Sep 2024 01:10PM UTC coverage: 70.815% (+0.6%) from 70.166%
10831517141

Pull #2525

github

web-flow
Merge c485163aa into ab1a72f1d
Pull Request #2525: merge develop nach GSOC24

871 of 1084 new or added lines in 36 files covered. (80.35%)

24 existing lines in 10 files now uncovered.

14362 of 20281 relevant lines covered (70.82%)

0.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.47
/mslib/mswms/mss_plot_driver.py
1
# -*- coding: utf-8 -*-
2
"""
3

4
    mslib.mswms.mss_plot_driver
5
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~
6

7
    Driver classes to create plots from ECMWF NetCDF data.
8

9
    This file is part of MSS.
10

11
    :copyright: Copyright 2008-2014 Deutsches Zentrum fuer Luft- und Raumfahrt e.V.
12
    :copyright: Copyright 2011-2014 Marc Rautenhaus (mr)
13
    :copyright: Copyright 2016-2024 by the MSS team, see AUTHORS.
14
    :license: APACHE-2.0, see LICENSE for details.
15

16
    Licensed under the Apache License, Version 2.0 (the "License");
17
    you may not use this file except in compliance with the License.
18
    You may obtain a copy of the License at
19

20
       http://www.apache.org/licenses/LICENSE-2.0
21

22
    Unless required by applicable law or agreed to in writing, software
23
    distributed under the License is distributed on an "AS IS" BASIS,
24
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25
    See the License for the specific language governing permissions and
26
    limitations under the License.
27
"""
28

29
from datetime import datetime
1✔
30

31
import logging
1✔
32
import os
1✔
33
from abc import ABCMeta, abstractmethod
1✔
34

35
import numpy as np
1✔
36

37
from mslib.utils import netCDF4tools
1✔
38
import mslib.utils.coordinate as coordinate
1✔
39
from mslib.utils.units import convert_to, units
1✔
40

41

42
class MSSPlotDriver(metaclass=ABCMeta):
    """
    Abstract super class for implementing driver classes that provide
    access to the MSS data server.

    The idea of a driver class is to encapsulate all methods related to
    loading data fields into memory. A driver can control objects from
    plotting classes that provide (a) a list of required variables and
    (b) a plotting method that only accepts data fields already loaded into
    memory.

    MSSPlotDriver implements methods that determine, given a list of required
    variables from a plotting instance <plot_object> and a forecast time
    specified by initialisation and valid time, the corresponding data files.
    The files are opened and the NetCDF variable objects are determined.

    Classes that derive from this class need to implement the three abstract
    methods set_plot_parameters(), update_plot_parameters() and plot().
    """

    def __init__(self, data_access_object):
        """
        Requires an instance of a data access object from the MSS
        configuration (i.e. an NWPDataAccess instance).
        """
        self.data_access = data_access_object
        self.dataset = None
        self.plot_object = None
        self.filenames = []

    def __del__(self):
        """
        Closes the open NetCDF dataset, if existing.
        """
        # getattr() keeps the finaliser silent for instances on which
        # __init__ never assigned <dataset>.
        dataset = getattr(self, "dataset", None)
        if dataset is not None:
            dataset.close()

    def _set_time(self, init_time, fc_time):
        """
        Open the dataset that corresponds to a forecast field specified
        by an initialisation and a valid time.

        This method
          determines the files that correspond to an init time and forecast step
          checks if an open NetCDF dataset exists
            if yes, checks whether it contains the requested valid time
              if not, closes the dataset and opens the corresponding one
          loads dimension data if required.

        Raises ValueError if the valid time is earlier than the init time
        (only checked when the init time dimension is in use), if no input
        files are found, or if the valid time is not contained in the opened
        dataset. Whenever an exception is raised after the files have been
        opened, the dataset is closed again and self.dataset is None.
        """
        if len(self.plot_object.required_datafields) == 0:
            logging.debug("no datasets required.")
            # Do not leave a previously opened dataset dangling.
            if self.dataset is not None:
                self.dataset.close()
            self.dataset = None
            self.filenames = []
            self.init_time = None
            self.fc_time = None
            self.times = np.array([])
            self.lat_data = np.array([])
            self.lon_data = np.array([])
            self.lat_order = 1
            self.vert_data = None
            self.vert_order = None
            self.vert_units = None
            return

        if self.uses_inittime_dimension():
            logging.debug("\trequested initialisation time %s", init_time)
            if fc_time < init_time:
                msg = "Forecast valid time cannot be earlier than " \
                      "initialisation time."
                logging.error(msg)
                raise ValueError(msg)
        self.fc_time = fc_time
        logging.debug("\trequested forecast valid time %s", fc_time)

        # Check if a dataset is open and if it contains the requested times.
        # (a dataset will only be open if the used layer has not changed,
        # i.e. the required variables have not changed as well).
        if (self.dataset is not None) and (self.init_time == init_time) and (fc_time in self.times):
            logging.debug("\tinit time correct and forecast valid time contained (%s).", fc_time)
            if not self.data_access.is_reload_required(self.filenames):
                return
            logging.debug("need to re-open input files.")
            self.dataset.close()
            self.dataset = None

        # Determine the input files from the required variables and the
        # requested time:

        # Create the names of the files containing the required parameters.
        self.filenames = []
        for vartype, var, _ in self.plot_object.required_datafields:
            filename = self.data_access.get_filename(
                var, vartype, init_time, fc_time, fullpath=True)
            if filename not in self.filenames:
                self.filenames.append(filename)
            logging.debug("\tvariable '%s' requires input file '%s'",
                          var, os.path.basename(filename))

        if not self.filenames:
            raise ValueError("no files found that correspond to the specified "
                             "datafields. Aborting..")

        self.init_time = init_time

        # Open NetCDF files as one dataset with common dimensions.
        logging.debug("opening datasets.")
        ds_kwargs = self.data_access.mfDatasetArgs()
        dataset = netCDF4tools.MFDatasetCommonDims(self.filenames, **ds_kwargs)

        # Everything up to the assignment of self.dataset works on the local
        # <dataset>; any failure closes it so that no file handle leaks and
        # self.dataset remains None.
        try:
            # Load and check time dimension.
            _, timevar = netCDF4tools.identify_CF_time(dataset)
            times = netCDF4tools.num2date(timevar[:], timevar.units)
            # NOTE: the initialisation time encoded in the file is deliberately
            # not compared with <init_time>; that check was removed after
            # discussion, see
            # https://mss-devel.slack.com/archives/emerge/p1486658769000007

            if fc_time not in times:
                msg = f"Forecast valid time '{fc_time}' is not available."
                logging.error(msg)
                raise ValueError(msg)

            # Load lat/lon dimensions.
            try:
                lat_data, lon_data, lat_order = netCDF4tools.get_latlon_data(dataset)
            except Exception as ex:
                logging.error("ERROR: %s %s", type(ex), ex)
                raise

            _, vert_data, vert_orientation, vert_units, _ = netCDF4tools.identify_vertical_axis(dataset)
            vert_values = vert_data[:] if vert_data is not None else None
        except Exception:
            dataset.close()
            raise

        self.vert_data = vert_values
        self.vert_order = vert_orientation
        self.vert_units = vert_units

        self.dataset = dataset
        self.times = times
        self.lat_data = lat_data
        self.lon_data = lon_data
        self.lat_order = lat_order

        # Identify the variable objects from the NetCDF file that correspond
        # to the data fields required by the plot object. If this fails, the
        # dataset is dropped again; otherwise a later call with the same times
        # would take the early return above and reuse an incomplete
        # <data_vars> dictionary.
        try:
            self._find_data_vars()
        except Exception:
            self.dataset.close()
            self.dataset = None
            raise

    def _find_data_vars(self):
        """
        Find NetCDF variables of required data fields.

        A dictionary data_vars is created. Its keys are the names of the
        data fields provided by the plot object. The values are the NetCDF
        variable objects returned by netCDF4tools.identify_variable().
        A second dictionary data_units maps the same keys to the "units"
        attribute of the variable (None if the variable has none).

        <data_vars> can be accessed as <self.data_vars>.
        """
        self.data_vars = {}
        self.data_units = {}
        for _, df_name, _ in self.plot_object.required_datafields:
            varname, var = netCDF4tools.identify_variable(self.dataset, df_name, check=True)
            logging.debug("\tidentified variable <%s> for field <%s>", varname, df_name)
            self.data_vars[df_name] = var
            self.data_units[df_name] = getattr(var, "units", None)

    def have_data(self, plot_object, init_time, valid_time):
        """
        Checks if this driver has the required data to do the plot

        This inquires the contained data access class if data is available for
        all required data fields for the specified times.

        Returns True if the data access object reports data for every entry
        of plot_object.required_datafields, False otherwise.
        """
        # required_datafields entries are unpacked as three-element tuples
        # everywhere else in this class; "*_" accepts those (and tolerates
        # plain two-element entries) instead of failing on unpacking.
        return all(
            self.data_access.have_data(var, vartype, init_time, valid_time)
            for vartype, var, *_ in plot_object.required_datafields)

    @abstractmethod
    def set_plot_parameters(self, plot_object, init_time=None, valid_time=None,
                            style=None, bbox=None, figsize=(800, 600),
                            noframe=False, require_reload=False, transparent=False,
                            mime_type="image/png"):
        """
        Set parameters controlling the plot.

        Parameters not passed as arguments are reset to standard values.

        THIS METHOD NEEDS TO BE REIMPLEMENTED IN ANY CLASS DERIVING FROM
        MSSPlotDriver!

        Derived methods need to call the super method before all other
        statements.
        """
        logging.debug("using plot object '%s'", plot_object.name)
        logging.debug("\tfigure size %s in pixels", figsize)

        # If the plot object has been changed, the dataset needs to be reloaded
        # (the required variables could have changed).
        if self.plot_object is not None:
            require_reload = require_reload or (self.plot_object != plot_object)
        if require_reload and self.dataset is not None:
            self.dataset.close()
            self.dataset = None

        self.plot_object = plot_object
        self.figsize = figsize
        self.noframe = noframe
        self.style = style
        self.bbox = bbox
        self.transparent = transparent
        self.mime_type = mime_type

        self._set_time(init_time, valid_time)

    @abstractmethod
    def update_plot_parameters(self, plot_object=None, figsize=None, style=None,
                               bbox=None, init_time=None, valid_time=None,
                               noframe=None, transparent=None, mime_type=None):
        """
        Update parameters controlling the plot.

        Similar to set_plot_parameters(), but keeps all parameters already
        set except the ones that are specified.

        THIS METHOD NEEDS TO BE REIMPLEMENTED IN ANY CLASS DERIVING FROM
        MSSPlotDriver!

        Derived methods need to call the super method before all other
        statements.
        """
        plot_object = plot_object if plot_object is not None else self.plot_object
        figsize = figsize if figsize is not None else self.figsize
        noframe = noframe if noframe is not None else self.noframe
        init_time = init_time if init_time is not None else self.init_time
        valid_time = valid_time if valid_time is not None else self.fc_time
        style = style if style is not None else self.style
        bbox = bbox if bbox is not None else self.bbox
        transparent = transparent if transparent is not None else self.transparent
        mime_type = mime_type if mime_type is not None else self.mime_type
        # Explicitly call MSSPlotDriver's set_plot_parameters(). A "self.--"
        # call would call the derived class's method and thus reset
        # parameters specific to the derived class.
        MSSPlotDriver.set_plot_parameters(self, plot_object,
                                          init_time=init_time,
                                          valid_time=valid_time,
                                          figsize=figsize,
                                          style=style,
                                          bbox=bbox,
                                          noframe=noframe,
                                          transparent=transparent,
                                          mime_type=mime_type)

    @abstractmethod
    def plot(self):
        """
        Plot the figure (i.e. load the data fields and call the
        corresponding plotting routines of the plot object).

        THIS METHOD NEEDS TO BE REIMPLEMENTED IN ANY CLASS DERIVING FROM
        MSSPlotDriver!
        """
        pass

    def get_init_times(self):
        """
        Returns the forecast init times (base times) reported by the data
        access object.
        """
        return self.data_access.get_init_times()

    def get_elevations(self, vert_type):
        """
        Returns the elevations reported by the data access object for the
        given vertical axis type (delegates to data_access.get_elevations()).
        """
        return self.data_access.get_elevations(vert_type)

    def get_elevation_units(self, vert_type):
        """
        Returns the elevation units reported by the data access object for
        the given vertical axis type (delegates to
        data_access.get_elevation_units()).
        """
        return self.data_access.get_elevation_units(vert_type)

    def get_all_valid_times(self, variable, vartype):
        """
        Delegates to data_access.get_all_valid_times() for the given variable
        and variable type.
        """
        return self.data_access.get_all_valid_times(variable, vartype)

    def get_valid_times(self, variable, vartype, init_time):
        """
        Delegates to data_access.get_valid_times() for the given variable,
        variable type and initialisation time.
        """
        return self.data_access.get_valid_times(variable, vartype, init_time)

    def uses_inittime_dimension(self):
        """
        Returns whether this driver uses the WMS inittime dimensions.
        """
        return self.data_access.uses_inittime_dimension()

    def uses_validtime_dimension(self):
        """
        Returns whether this layer uses the WMS time dimensions.
        """
        return self.data_access.uses_validtime_dimension()
347

348

349
class VerticalSectionDriver(MSSPlotDriver):
    """
    The vertical section driver is responsible for loading the data that
    is to be plotted and for calling the plotting routines (that have
    to be registered).
    """

    def set_plot_parameters(self, plot_object=None, vsec_path=None,
                            vsec_numpoints=101, vsec_path_connection='linear',
                            vsec_numlabels=10,
                            init_time=None, valid_time=None, style=None,
                            bbox=None, figsize=(800, 600), noframe=False, draw_verticals=False,
                            show=False, transparent=False,
                            mime_type="image/png"):
        """
        Set all parameters controlling the vertical section plot.

        Forwards the common parameters to MSSPlotDriver.set_plot_parameters()
        (which also opens the matching dataset), then computes the
        interpolation points along <vsec_path> and stores the remaining
        vertical-section specific settings on the instance.
        """
        MSSPlotDriver.set_plot_parameters(self, plot_object,
                                          init_time=init_time,
                                          valid_time=valid_time,
                                          style=style,
                                          bbox=bbox,
                                          figsize=figsize, noframe=noframe,
                                          transparent=transparent,
                                          mime_type=mime_type)
        self._set_vertical_section_path(vsec_path, vsec_numpoints,
                                        vsec_path_connection)
        self.show = show
        self.vsec_numlabels = vsec_numlabels
        self.draw_verticals = draw_verticals

    def update_plot_parameters(self, plot_object=None, vsec_path=None,
                               vsec_numpoints=None, vsec_path_connection=None,
                               vsec_numlabels=None,
                               init_time=None, valid_time=None, style=None,
                               bbox=None, figsize=None, noframe=None, draw_verticals=None, show=None,
                               transparent=None, mime_type=None):
        """
        Update the given parameters and keep the current value of every
        parameter that is passed as None, then re-apply them all through
        set_plot_parameters().
        """
        plot_object = plot_object if plot_object is not None else self.plot_object
        figsize = figsize if figsize is not None else self.figsize
        noframe = noframe if noframe is not None else self.noframe
        # "is not None" (rather than truthiness) so that an explicit False
        # switches the option off instead of being replaced by the old value.
        draw_verticals = draw_verticals if draw_verticals is not None else self.draw_verticals
        init_time = init_time if init_time is not None else self.init_time
        valid_time = valid_time if valid_time is not None else self.fc_time
        style = style if style is not None else self.style
        bbox = bbox if bbox is not None else self.bbox
        vsec_path = vsec_path if vsec_path is not None else self.vsec_path
        vsec_numpoints = vsec_numpoints if vsec_numpoints is not None else self.vsec_numpoints
        vsec_numlabels = vsec_numlabels if vsec_numlabels is not None else self.vsec_numlabels
        if vsec_path_connection is None:
            vsec_path_connection = self.vsec_path_connection
        show = show if show is not None else self.show
        transparent = transparent if transparent is not None else self.transparent
        mime_type = mime_type if mime_type is not None else self.mime_type
        self.set_plot_parameters(plot_object=plot_object,
                                 vsec_path=vsec_path,
                                 vsec_numpoints=vsec_numpoints,
                                 vsec_path_connection=vsec_path_connection,
                                 vsec_numlabels=vsec_numlabels,
                                 init_time=init_time,
                                 valid_time=valid_time,
                                 style=style,
                                 bbox=bbox,
                                 figsize=figsize,
                                 noframe=noframe,
                                 draw_verticals=draw_verticals,
                                 show=show,
                                 transparent=transparent,
                                 mime_type=mime_type)

    def _set_vertical_section_path(self, vsec_path, vsec_numpoints=101,
                                   vsec_path_connection='linear'):
        """
        Compute the interpolation points of the section path.

        The first element of every entry of <vsec_path> is passed to
        coordinate.path_points() as latitude, the second as longitude. The
        resulting points are stored as arrays in self.lats / self.lons; the
        path, number of points and connection type are remembered for
        update_plot_parameters().
        """
        logging.debug("computing %i interpolation points, connection: %s",
                      vsec_numpoints, vsec_path_connection)
        self.lats, self.lons = coordinate.path_points(
            [_x[0] for _x in vsec_path],
            [_x[1] for _x in vsec_path],
            numpoints=vsec_numpoints, connection=vsec_path_connection)
        self.lats, self.lons = np.asarray(self.lats), np.asarray(self.lons)
        self.vsec_path = vsec_path
        self.vsec_numpoints = vsec_numpoints
        self.vsec_path_connection = vsec_path_connection

    def _load_interpolate_timestep(self):
        """
        Load and interpolate the data fields as required by the vertical
        section style instance. Only data of time <fc_time> is processed.

        Shifts the data fields such that the longitudes are in the range
        left_longitude .. left_longitude+360, where left_longitude is the
        westmost longitude appearing in the list of waypoints minus one
        gridpoint (to include all waypoint longitudes).

        Necessary to prevent data cut-offs in situations where the requested
        cross section crosses the data longitude boundaries (e.g. data is
        stored on a 0..360 grid, but the path is in the range -10..+20).

        Returns a dictionary mapping each key of self.data_vars to the result
        of coordinate.interpolate_vertsec(); an empty dictionary if no
        dataset is open.
        """
        if self.dataset is None:
            return {}
        data = {}

        timestep = self.times.searchsorted(self.fc_time)
        logging.debug("loading data for time step %s (%s)", timestep, self.fc_time)

        # Determine the westmost longitude in the cross-section path. Subtract
        # one gridbox size to obtain "left_longitude".
        dlon = self.lon_data[1] - self.lon_data[0]
        left_longitude = np.unwrap(self.lons, period=360).min() - dlon
        logging.debug("shifting data grid to gridpoint west of westmost "
                      "longitude in path: %.2f (path %.2f).",
                      left_longitude, self.lons.min())

        # Shift the longitude field such that the data is in the range
        # left_longitude .. left_longitude+360.
        # NOTE: This does not overwrite self.lon_data (which is required
        # in its original form in case other data is loaded while this
        # file is open).
        lon_data = ((self.lon_data - left_longitude) % 360) + left_longitude
        lon_indices = lon_data.argsort()
        lon_data = lon_data[lon_indices]
        # Identify jump in longitudes due to non-global dataset
        dlon_data = np.diff(lon_data)
        jump = np.where(dlon_data > 2 * dlon)[0]

        lons = ((self.lons - left_longitude) % 360) + left_longitude

        for name, var in self.data_vars.items():
            if len(var.shape) == 4:
                var_data = var[timestep, ::-self.vert_order, ::self.lat_order, :]
            else:
                var_data = var[:][timestep, np.newaxis, ::self.lat_order, :]
            logging.debug("\tLoaded %.2f Mbytes from data field <%s> at timestep %s.",
                          var_data.nbytes / 1048576., name, timestep)
            logging.debug("\tVertical dimension direction is %s.",
                          "up" if self.vert_order == 1 else "down")
            logging.debug("\tInterpolating to cross-section path.")
            # Re-arange longitude dimension in the data field.
            var_data = var_data[:, :, lon_indices]
            # <jump> is an index array: test its size. Truth-testing the
            # array itself is ambiguous for several jumps and evaluates to
            # False for a single jump located at index 0.
            if jump.size > 0:
                logging.debug("\tsetting jump data to NaN at %s", jump)
                var_data = var_data.copy()
                var_data[:, :, jump] = np.nan
            data[name] = coordinate.interpolate_vertsec(var_data, self.lat_data, lon_data, self.lats, lons)
            # Free memory.
            del var_data

        return data

    def shift_data(self):
        """
        Shift the data fields such that the longitudes are in the range
        left_longitude .. left_longitude+360, where left_longitude is the
        westmost longitude appearing in the list of waypoints.

        NOTE(review): this method reads and writes self.data, which is not
        assigned anywhere in the part of the class visible here; confirm that
        callers set it before relying on this method.
        """
        # Determine the leftmost longitude in the plot.
        left_longitude = self.lons.min()
        logging.debug("shifting data grid to leftmost longitude in path "
                      "(%.2f)..", left_longitude)

        # Shift the longitude field such that the data is in the range
        # left_longitude .. left_longitude+360.
        self.lons = ((self.lons - left_longitude) % 360) + left_longitude
        lon_indices = self.lons.argsort()
        self.lons = self.lons[lon_indices]

        # Shift data fields correspondingly.
        for key in self.data:
            self.data[key] = self.data[key][:, lon_indices]

    def plot(self):
        """
        Load the data of the current time step, interpolate it to the section
        path and hand it to plot_vsection() of the plot object.

        Returns whatever plot_vsection() returns. Raises RuntimeError if
        self.mime_type is neither "image/png" nor "text/xml".
        """
        d1 = datetime.now()

        # Load and interpolate the data fields as required by the vertical
        # section style instance. <data> is a dictionary containing the
        # interpolated curtains of the variables required by the plot object.
        data = self._load_interpolate_timestep()

        d2 = datetime.now()
        logging.debug("Loaded and interpolated data (required time %s).", d2 - d1)
        logging.debug("Plotting interpolated curtain.")

        # Grid spacing (lon, lat); (-1, -1) if it cannot be derived because
        # an axis has fewer than two points.
        if len(self.lat_data) > 1 and len(self.lon_data) > 1:
            resolution = (self.lon_data[1] - self.lon_data[0],
                          self.lat_data[1] - self.lat_data[0])
        else:
            resolution = (-1, -1)

        if self.mime_type not in ("image/png", "text/xml"):
            raise RuntimeError(f"Unexpected format for vertical sections '{self.mime_type}'.")

        # Call the plotting method of the vertical section style instance.
        image = self.plot_object.plot_vsection(data, self.lats, self.lons,
                                               valid_time=self.fc_time,
                                               init_time=self.init_time,
                                               resolution=resolution,
                                               bbox=self.bbox,
                                               style=self.style,
                                               show=self.show,
                                               highlight=self.vsec_path,
                                               noframe=self.noframe,
                                               figsize=self.figsize,
                                               draw_verticals=self.draw_verticals,
                                               transparent=self.transparent,
                                               numlabels=self.vsec_numlabels,
                                               mime_type=self.mime_type)
        # Free memory.
        del data

        d3 = datetime.now()
        logging.debug("Finished plotting (required time %s; total "
                      "time %s).\n", d3 - d2, d3 - d1)

        return image
572

573

574
class HorizontalSectionDriver(MSSPlotDriver):
    """
    The horizontal section driver is responsible for loading the data that
    is to be plotted and for calling the plotting routines (that have
    to be registered).
    """

    def set_plot_parameters(self, plot_object=None, bbox=None, level=None, crs=None, init_time=None, valid_time=None,
                            style=None, figsize=(800, 600), noframe=False, show=False, transparent=False,
                            mime_type="image/png"):
        """
        Set all parameters controlling the horizontal section plot.

        Forwards the common parameters to MSSPlotDriver.set_plot_parameters()
        (which also opens the matching dataset) and stores the requested
        level, the CRS and the show flag. self.actual_level is reset to None
        until data has been loaded.
        """
        MSSPlotDriver.set_plot_parameters(self, plot_object,
                                          init_time=init_time,
                                          valid_time=valid_time,
                                          style=style,
                                          bbox=bbox,
                                          figsize=figsize, noframe=noframe,
                                          transparent=transparent,
                                          mime_type=mime_type)
        self.level = level
        self.actual_level = None
        self.crs = crs
        self.show = show

    def update_plot_parameters(self, plot_object=None, bbox=None, level=None, crs=None, init_time=None, valid_time=None,
                               style=None, figsize=None, noframe=None, show=None, transparent=None, mime_type=None):
        """
        Update the given parameters and keep the current value of every
        parameter that is passed as None, then re-apply them all through
        set_plot_parameters().
        """
        plot_object = plot_object if plot_object is not None else self.plot_object
        figsize = figsize if figsize is not None else self.figsize
        noframe = noframe if noframe is not None else self.noframe
        init_time = init_time if init_time is not None else self.init_time
        valid_time = valid_time if valid_time is not None else self.fc_time
        style = style if style is not None else self.style
        bbox = bbox if bbox is not None else self.bbox
        level = level if level is not None else self.level
        crs = crs if crs is not None else self.crs
        show = show if show is not None else self.show
        transparent = transparent if transparent is not None else self.transparent
        mime_type = mime_type if mime_type is not None else self.mime_type
        self.set_plot_parameters(plot_object=plot_object, bbox=bbox, level=level, crs=crs, init_time=init_time,
                                 valid_time=valid_time, style=style, figsize=figsize, noframe=noframe, show=show,
                                 transparent=transparent, mime_type=mime_type)

    def _load_timestep(self):
        """
        Load the data fields as required by the horizontal section style
        instance at the current timestep.

        Returns a dictionary mapping each key of self.data_vars to the loaded
        field; an empty dictionary if no dataset is open. Raises ValueError
        if a level was requested that does not match the nearest available
        level within the tolerance.
        """
        if self.dataset is None:
            return {}
        data = {}
        timestep = self.times.searchsorted(self.fc_time)
        level = None
        if self.level is not None:
            # select the nearest level available
            level = np.abs(self.vert_data - self.level).argmin()
            if len(self.vert_data) > 1:
                # Tolerance: one per mille of the mean level spacing.
                tolerance = 1e-3 * np.abs(np.diff(self.vert_data).mean())
            else:
                # A single level has no spacing; the mean of the empty diff
                # would be NaN and the comparison below would then accept
                # any requested level. Scale by the level value instead.
                tolerance = 1e-3 * abs(self.vert_data[level])
            if abs(self.vert_data[level] - self.level) > tolerance:
                raise ValueError("Requested elevation not available.")
            self.actual_level = self.vert_data[level]
        logging.debug("loading data for time step %s (%s), level index %s (level %s)",
                      timestep, self.fc_time, level, self.actual_level)
        for name, var in self.data_vars.items():
            if level is None or len(var.shape) == 3:
                # 2D fields: time, lat, lon.
                var_data = var[timestep, ::self.lat_order, :]
            else:
                # 3D fields: time, level, lat, lon.
                var_data = var[timestep, level, ::self.lat_order, :]
            logging.debug("\tLoaded %.2f Mbytes from data field <%s>.",
                          var_data.nbytes / 1048576., name)
            data[name] = var_data
            # Free memory.
            del var_data

        return data

    def plot(self):
        """
        Load the data of the current time step and hand it to
        plot_hsection() of the plot object.

        Returns whatever plot_hsection() returns. Raises RuntimeError if
        self.mime_type is not "image/png".
        """
        d1 = datetime.now()

        # Load the data fields as required by the horizontal section style
        # instance. <data> is a dictionary containing the horizontal sections
        # of the variables required by the plot object.
        data = self._load_timestep()

        d2 = datetime.now()
        logging.debug("Loaded data (required time %s).", (d2 - d1))
        logging.debug("Plotting horizontal section.")

        # Latitude grid spacing; 0 if it cannot be derived.
        if len(self.lat_data) > 1:
            resolution = (self.lat_data[1] - self.lat_data[0])
        else:
            resolution = 0

        if self.mime_type != "image/png":
            raise RuntimeError(f"Unexpected format for horizontal sections '{self.mime_type}'.")

        # Call the plotting method of the horizontal section style instance.
        image = self.plot_object.plot_hsection(data,
                                               self.lat_data,
                                               self.lon_data,
                                               self.bbox,
                                               level=self.actual_level,
                                               valid_time=self.fc_time,
                                               init_time=self.init_time,
                                               resolution=resolution,
                                               show=self.show,
                                               crs=self.crs,
                                               style=self.style,
                                               noframe=self.noframe,
                                               figsize=self.figsize,
                                               transparent=self.transparent)
        # Free memory.
        del data

        d3 = datetime.now()
        logging.debug("Finished plotting (required time %s; total "
                      "time %s).\n", d3 - d2, d3 - d1)

        return image
698

699

700
class LinearSectionDriver(VerticalSectionDriver):
    """
    Plot driver for linear sections.

    The driver loads the data fields required by the registered plot
    object, interpolates them horizontally onto the section path and
    vertically onto the altitude of each path point, and passes the
    resulting one-dimensional series to the plot object's
    ``plot_lsection`` method.
    """

    def set_plot_parameters(self, plot_object=None, lsec_path=None,
                            lsec_numpoints=101, lsec_path_connection='linear',
                            init_time=None, valid_time=None, bbox=None, mime_type=None):
        """
        Set the parameters of the next plot and compute the section path.

        plot_object, init_time, valid_time, bbox and mime_type are forwarded
        unchanged to MSSPlotDriver.set_plot_parameters (deliberately bypassing
        the parent class' implementation). lsec_path, lsec_numpoints and
        lsec_path_connection are handed to _set_linear_section_path.

        lsec_path must be a sequence of waypoints that can be indexed with
        0, 1 and 2 (presumably latitude, longitude and altitude -- confirm
        against coordinate.path_points).
        """
        MSSPlotDriver.set_plot_parameters(self, plot_object,
                                          init_time=init_time,
                                          valid_time=valid_time,
                                          bbox=bbox, mime_type=mime_type)
        self._set_linear_section_path(lsec_path, lsec_numpoints, lsec_path_connection)

    def update_plot_parameters(self, plot_object=None, lsec_path=None,
                               lsec_numpoints=None, lsec_path_connection=None,
                               init_time=None, valid_time=None, bbox=None, mime_type=None):
        """
        Update a subset of the plot parameters.

        Every argument that is None keeps the value currently stored on the
        driver; the complete parameter set is then re-applied through
        set_plot_parameters (which also recomputes the section path).
        """
        if plot_object is None:
            plot_object = self.plot_object
        if init_time is None:
            init_time = self.init_time
        if valid_time is None:
            valid_time = self.fc_time
        if bbox is None:
            bbox = self.bbox
        if lsec_path is None:
            lsec_path = self.lsec_path
        if lsec_numpoints is None:
            lsec_numpoints = self.lsec_numpoints
        if lsec_path_connection is None:
            lsec_path_connection = self.lsec_path_connection
        if mime_type is None:
            mime_type = self.mime_type
        self.set_plot_parameters(plot_object=plot_object,
                                 lsec_path=lsec_path,
                                 lsec_numpoints=lsec_numpoints,
                                 lsec_path_connection=lsec_path_connection,
                                 init_time=init_time,
                                 valid_time=valid_time,
                                 bbox=bbox,
                                 mime_type=mime_type)

    def _set_linear_section_path(self, lsec_path, lsec_numpoints=101, lsec_path_connection='linear'):
        """
        Compute the interpolation points along the section path.

        The three components of every waypoint in lsec_path are passed to
        coordinate.path_points together with the number of points and the
        connection type. The results are stored as numpy arrays in
        self.lats, self.lons and self.alts; the input parameters are kept in
        self.lsec_path, self.lsec_numpoints and self.lsec_path_connection so
        that update_plot_parameters can fall back to them.
        """
        logging.debug("computing %i interpolation points, connection: %s",
                      lsec_numpoints, lsec_path_connection)
        lats, lons, alts = coordinate.path_points(
            [_x[0] for _x in lsec_path],
            [_x[1] for _x in lsec_path],
            alts=[_x[2] for _x in lsec_path],
            numpoints=lsec_numpoints, connection=lsec_path_connection)
        self.lats, self.lons, self.alts = [
            np.asarray(_x) for _x in (lats, lons, alts)]
        self.lsec_path = lsec_path
        self.lsec_numpoints = lsec_numpoints
        self.lsec_path_connection = lsec_path_connection

    @staticmethod
    def _vertical_interpolation_factors(pressures, log_alts):
        """
        Compute linear vertical interpolation weights for each path point.

        pressures is indexed as [level, path point] and is compared directly
        against the entries of log_alts (the caller passes log-pressures for
        both). For every path point the first pair of neighbouring levels
        that brackets the target value is searched, regardless of whether
        the column is ascending or descending.

        Returns a list with one ((idx0, w0), (idx1, w1)) tuple per path
        point. If no bracketing pair exists, ((0, nan), (0, nan)) is stored
        so that the interpolated value becomes NaN.
        """
        factors = []
        num_levels = len(pressures)
        for index_lonlat, alt in enumerate(log_alts):
            column = pressures[:, index_lonlat]
            idx0 = None
            for index_altitude in range(num_levels - 1):
                lower = column[index_altitude]
                upper = column[index_altitude + 1]
                if (lower <= alt <= upper) or (lower >= alt >= upper):
                    idx0 = index_altitude
                    break
            if idx0 is None:
                factors.append(((0, np.nan), (0, np.nan)))
                continue

            idx1 = idx0 + 1
            thickness = column[idx0] - column[idx1]
            if thickness == 0:
                # Both levels coincide with the target value; any weighting
                # is valid. Avoid the 0/0 division that would yield NaN.
                fac1 = 0.0
            else:
                fac1 = (column[idx0] - alt) / thickness
            fac0 = 1 - fac1
            assert 0 <= fac0 <= 1, fac0
            factors.append(((idx0, fac0), (idx1, fac1)))
        return factors

    def _load_interpolate_timestep(self):
        """
        Load and interpolate the data fields as required by the linear
        section style instance. Only data of time <fc_time> is processed.

        Shifts the data fields such that the longitudes are in the range
        left_longitude .. left_longitude+360, where left_longitude is the
        westmost longitude appearing in the list of waypoints minus one
        gridpoint (to include all waypoint longitudes).

        Necessary to prevent data cut-offs in situations where the requested
        cross section crosses the data longitude boundaries (e.g. data is
        stored on a 0..360 grid, but the path is in the range -10..+20).

        Returns a dictionary mapping each variable name in self.data_vars to
        a 1D numpy array with one value per path point, or an empty
        dictionary if no dataset is open.

        Raises ValueError if "air_pressure" is not among the loaded variables
        and the vertical axis does not have pressure units.
        """
        if self.dataset is None:
            return {}
        data = {}

        timestep = self.times.searchsorted(self.fc_time)
        logging.debug("loading data for time step %s (%s)", timestep, self.fc_time)

        # Determine the westmost longitude in the cross-section path. Subtract
        # one gridbox size to obtain "left_longitude".
        dlon = self.lon_data[1] - self.lon_data[0]
        left_longitude = np.unwrap(self.lons, period=360).min() - dlon
        logging.debug("shifting data grid to gridpoint west of westmost "
                      "longitude in path: %.2f (path %.2f).",
                      left_longitude, self.lons.min())

        # Shift the longitude field such that the data is in the range
        # left_longitude .. left_longitude+360.
        # NOTE: This does not overwrite self.lon_data (which is required
        # in its original form in case other data is loaded while this
        # file is open).
        lon_data = ((self.lon_data - left_longitude) % 360) + left_longitude
        lon_indices = lon_data.argsort()
        lon_data = lon_data[lon_indices]
        # Identify jump in longitudes due to non-global dataset
        dlon_data = np.diff(lon_data)
        jump = np.where(dlon_data > 2 * dlon)[0]
        # "jump" is an index array: test its size explicitly. Relying on its
        # truth value fails for several jumps, ignores a jump at index 0 and
        # is an error for empty arrays in recent NumPy versions.
        has_jump = jump.size > 0

        lons = ((self.lons - left_longitude) % 360) + left_longitude
        factors = []

        pressures = None
        if "air_pressure" not in self.data_vars:
            if units(self.vert_units).check("[pressure]"):
                # The vertical axis itself is a pressure axis: use it (in
                # log-space) for every path point.
                pressures = np.log(convert_to(
                    self.vert_data[::-self.vert_order, np.newaxis],
                    self.vert_units, "Pa").repeat(len(self.lats), axis=1))
            else:
                raise ValueError(
                    "air_pressure must be available for linear plotting layers "
                    "with non-pressure axis. Please add to required_datafields.")

        # Make sure air_pressure is the first to be evaluated if needed
        variables = list(self.data_vars)
        if "air_pressure" in self.data_vars:
            if variables[0] != "air_pressure":
                variables.insert(0, variables.pop(variables.index("air_pressure")))

        for name in variables:
            var = self.data_vars[name]
            if len(var.shape) == 4:
                var_data = var[:][timestep, ::-self.vert_order, ::self.lat_order, :]
            else:
                var_data = var[:][timestep, np.newaxis, ::self.lat_order, :]
            logging.debug("\tLoaded %.2f Mbytes from data field <%s> at timestep %s.",
                          var_data.nbytes / 1048576., name, timestep)
            logging.debug("\tVertical dimension direction is %s.",
                          "up" if self.vert_order == 1 else "down")
            logging.debug("\tInterpolating to cross-section path.")
            # Re-arrange longitude dimension in the data field.
            var_data = var_data[:, :, lon_indices]
            if has_jump:
                logging.debug("\tsetting jump data to NaN at %s", jump)
                var_data = var_data.copy()
                var_data[:, :, jump] = np.nan

            cross_section = coordinate.interpolate_vertsec(var_data, self.lat_data, lon_data, self.lats, lons)
            # Create vertical interpolation factors and indices for subsequent variables
            # TODO: Improve performance for this interpolation in general
            if len(factors) == 0:
                if name == "air_pressure":
                    pressures = np.log(convert_to(cross_section, self.data_units[name], "Pa"))
                factors = self._vertical_interpolation_factors(pressures, np.log(self.alts))

            # Interpolate with the previously calculated pressure indices and factors
            values = [
                cross_section[idx0, index] * w0 + cross_section[idx1, index] * w1
                for index, ((idx0, w0), (idx1, w1)) in enumerate(factors)]

            # Free memory.
            del var_data
            data[name] = np.array(values)

        return data

    def plot(self):
        """
        Load the data, run the registered linear section plot and return
        its result.

        Raises RuntimeError if the requested mime type is not "text/xml".
        """
        d1 = datetime.now()

        # Load and interpolate the data fields as required by the linear
        # section style instance. <data> is a dictionary containing the
        # interpolated series of the variables identified through CF
        # standard names as specified by the plot object.
        data = self._load_interpolate_timestep()
        d2 = datetime.now()

        if self.mime_type != "text/xml":
            raise RuntimeError(f"Unexpected format for linear sections '{self.mime_type}'.")

        # Call the plotting method of the linear section style instance.
        image = self.plot_object.plot_lsection(data, self.lats, self.lons,
                                               valid_time=self.fc_time,
                                               init_time=self.init_time)
        # Free memory.
        del data

        d3 = datetime.now()
        logging.debug("Finished plotting (required time %s; total "
                      "time %s).\n", d3 - d2, d3 - d1)

        return image
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc