• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SPF-OST / pytrnsys_process / 13563079957

27 Feb 2025 09:25AM UTC coverage: 98.126% (-0.06%) from 98.182%
13563079957

push

github

ahobeost
CI adjustments

5 of 5 new or added lines in 1 file covered. (100.0%)

4 existing lines in 1 file now uncovered.

1204 of 1227 relevant lines covered (98.13%)

1.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.37
/pytrnsys_process/plotting/plotters.py
1
import typing as _tp
2✔
2
from abc import abstractmethod
2✔
3
from dataclasses import dataclass
2✔
4

5
import matplotlib.pyplot as _plt
2✔
6
import numpy as _np
2✔
7
import pandas as _pd
2✔
8

9
import pytrnsys_process.constants as const
2✔
10
import pytrnsys_process.headers as h
2✔
11
from pytrnsys_process import settings as sett
2✔
12

13
# TODO: provide A4 and half A4 plots to test sizes in latex # pylint: disable=fixme
14
# TODO: provide height as input for plot?  # pylint: disable=fixme
15
# TODO: deal with legends (curve names, fonts, colors, linestyles) # pylint: disable=fixme
16
# TODO: clean up old stuff by refactoring # pylint: disable=fixme
17
# TODO: make issue for docstrings of plotting # pylint: disable=fixme
18
# TODO: Add colormap support # pylint: disable=fixme
19

20

21
# TODO: find a better place for this to live in  # pylint: disable=fixme
# Plot settings (date format, font sizes, marker list, ...) read by the chart
# classes in this module.
plot_settings = sett.settings.plot
23

24

25
class ChartBase(h.HeaderValidationMixin):
    """Shared entry points for the chart types in this module.

    Subclasses implement :meth:`_do_plot`. This base class provides the public
    ``plot`` / ``plot_with_column_validation`` wrappers and the colormap
    helpers used by the subclasses.
    """

    # Default colormap for the chart type; ``None`` leaves the choice to the
    # plotting backend.
    cmap: str | None = None

    def plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns`` of ``df`` without validating them against headers.

        Parameters
        ----------
        df:
            DataFrame containing the data to plot
        columns:
            List of column names to plot
        **kwargs:
            Forwarded unchanged to :meth:`_do_plot`

        Returns
        -------
        tuple[Figure, Axes]
            The figure and axes produced by the subclass.
        """
        return self._do_plot(df, columns, **kwargs)

    # TODO: Test validation # pylint: disable=fixme
    def plot_with_column_validation(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        headers: h.Headers,
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Base plot method with header validation.

        Parameters
        ----------
        df:
            DataFrame containing the data to plot
        columns:
            List of column names to plot
        headers:
            Headers instance for validation
        **kwargs:
            Additional plotting arguments, forwarded to :meth:`_do_plot`

        Returns
        -------
        tuple[Figure, Axes]
            The figure and axes produced by the subclass.

        Raises
        ------
        ValueError
            If any columns are missing from the headers index
        """
        # TODO: Might live somewhere else in the future # pylint: disable=fixme
        is_valid, missing = self.validate_headers(headers, columns)
        if not is_valid:
            raise ValueError(
                "The following columns are not available in the headers index:\n"
                + "\n".join(missing)
            )

        return self._do_plot(df, columns, **kwargs)

    @abstractmethod
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Implement actual plotting logic in subclasses"""

    def check_for_cmap(
        self,
        kwargs: dict[str, _tp.Any],
        plot_kwargs: dict[str, _tp.Any],
    ) -> dict[str, _tp.Any]:
        """Fill in the class default colormap unless the caller supplied one.

        ``plot_kwargs`` is modified in place (and also returned). The default is
        only applied when neither ``cmap`` nor ``colormap`` is present in
        ``kwargs``, so an explicit choice is never overridden and both keys are
        never set together.
        """
        if "cmap" not in kwargs and "colormap" not in kwargs:
            plot_kwargs["cmap"] = self.cmap
        return plot_kwargs

    def get_cmap(self, kwargs: dict[str, _tp.Any]) -> str | None:
        """Return the colormap to use for this call.

        Precedence: ``kwargs["cmap"]``, then ``kwargs["colormap"]``, then the
        class default :attr:`cmap`.
        """
        if "cmap" in kwargs:
            return kwargs["cmap"]
        if "colormap" in kwargs:
            return kwargs["colormap"]
        return self.cmap
106

107

108
class StackedBarChart(ChartBase):
    """Stacked bar chart of the selected columns, one bar per index entry.

    Tick labels are the frame's index passed through ``pd.to_datetime`` and
    formatted with ``plot_settings.date_format``.
    """

    cmap = "inferno_r"

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        figure, target_axes = _plt.subplots(layout="constrained", figsize=size)

        # Caller-supplied kwargs take precedence over these defaults.
        bar_options: dict[str, _tp.Any] = {
            "stacked": True,
            "legend": use_legend,
            "ax": target_axes,
        }
        bar_options.update(kwargs)
        self.check_for_cmap(kwargs, bar_options)

        selection = df[columns]
        drawn_axes = selection.plot.bar(**bar_options)

        date_labels = _pd.to_datetime(df.index).strftime(plot_settings.date_format)
        drawn_axes.set_xticklabels(date_labels)

        return figure, drawn_axes
136

137

138
class BarChart(ChartBase):
    """Grouped bar chart: one bar per column, side by side for each index entry.

    Tick labels are the frame's index passed through ``pd.to_datetime`` and
    formatted with ``plot_settings.date_format``, rotated by 90 degrees.
    """

    cmap = None

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw grouped bars for ``columns`` of ``df``.

        Only ``cmap``/``colormap`` is read from ``kwargs``; other keyword
        arguments are currently ignored.

        Raises
        ------
        ValueError
            If ``columns`` is empty.
        """
        # TODO: deal with colors  # pylint: disable=fixme
        n_columns = len(columns)
        if n_columns == 0:
            # Checked before creating the figure, so no empty figure is left
            # behind (previously this surfaced as a ZeroDivisionError).
            raise ValueError("BarChart requires at least one column to plot")

        fig, ax = _plt.subplots(
            figsize=size,
            layout="constrained",
        )
        x = _np.arange(len(df.index))
        # All bars of one group share a total width of 0.8, leaving a gap
        # between neighbouring groups.
        width = 0.8 / n_columns

        cmap = self.get_cmap(kwargs)
        if cmap:
            # ``matplotlib.cm.get_cmap`` (reached via ``_plt.cm``) was removed
            # in matplotlib 3.9; ``pyplot.get_cmap`` is still supported.
            colors = _plt.get_cmap(cmap)(_np.linspace(0, 1, n_columns))
        else:
            colors = [None] * n_columns

        for i, (col, color) in enumerate(zip(columns, colors)):
            ax.bar(x + i * width, df[col], width, label=col, color=color)

        if use_legend:
            ax.legend()

        # Centre each tick under its group of bars (bars are centre-aligned
        # at x + i * width).
        ax.set_xticks(x + width * (n_columns - 1) / 2)
        ax.set_xticklabels(
            _pd.to_datetime(df.index).strftime(plot_settings.date_format)
        )
        ax.tick_params(axis="x", labelrotation=90)
        return fig, ax
176

177

178
class LinePlot(ChartBase):
    """Line plot of the selected columns against the frame's index."""

    cmap: str | None = None

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        figure, axes = _plt.subplots(layout="constrained", figsize=size)

        # Defaults first; anything the caller passes overrides them.
        line_options: dict[str, _tp.Any] = {"legend": use_legend, "ax": axes}
        line_options.update(kwargs)
        self.check_for_cmap(kwargs, line_options)

        selected = df[columns]
        selected.plot.line(**line_options)
        return figure, axes
202

203

204
@dataclass
class Histogram(ChartBase):
    """Histogram of the selected columns with a configurable number of bins."""

    # Default bin count; an explicit ``bins`` keyword overrides it per call.
    bins: int = 50

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        figure, axes = _plt.subplots(layout="constrained", figsize=size)

        hist_options: dict[str, _tp.Any] = {
            "legend": use_legend,
            "ax": axes,
            "bins": self.bins,
        }
        hist_options.update(kwargs)
        self.check_for_cmap(kwargs, hist_options)

        selected = df[columns]
        selected.plot.hist(**hist_options)
        return figure, axes
229

230

231
@dataclass
class ScatterPlot(ChartBase):
    """Handles comparative scatter plots with dual grouping by color and markers."""

    # Only used when grouping by color; ignored for ungrouped or marker-only plots.
    cmap = "Paired"

    # pylint: disable=too-many-arguments,too-many-locals
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        group_by_color: str | None = None,
        group_by_marker: str | None = None,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns[1]`` against ``columns[0]``, optionally grouped.

        Without any grouping, a plain pandas scatter plot is drawn on a single
        axes and ``kwargs`` are forwarded to it (``use_legend`` is not used).
        With grouping, each group is drawn as a line coloured by
        ``group_by_color`` plus markers shaped by ``group_by_marker``. The
        legends are placed in a separate axes to the right of the plot, and
        only ``cmap``/``colormap`` is read from ``kwargs``.

        Returns
        -------
        tuple[Figure, Axes]
            The figure and the main (data) axes.

        Raises
        ------
        ValueError
            If ``columns`` does not contain exactly two names.
        """
        self._validate_inputs(columns)
        x_column, y_column = columns

        if not group_by_color and not group_by_marker:
            fig, ax = _plt.subplots(
                figsize=size,
                layout="constrained",
            )
            df.plot.scatter(x=x_column, y=y_column, ax=ax, **kwargs)
            return fig, ax
        # See: https://stackoverflow.com/questions/4700614/
        # how-to-put-the-legend-outside-the-plot
        # This is required to place the legend in a dedicated subplot
        fig, (ax, lax) = _plt.subplots(
            layout="constrained",
            figsize=size,
            ncols=2,
            gridspec_kw={"width_ratios": [4, 1]},
        )
        df_grouped, group_values = self._prepare_grouping(
            df, group_by_color, group_by_marker
        )
        cmap = self.get_cmap(kwargs)
        color_map, marker_map = self._create_style_mappings(
            *group_values, cmap=cmap
        )

        self._plot_groups(
            df_grouped,
            x_column,
            y_column,
            color_map,
            marker_map,
            ax,
        )

        if use_legend:
            self._create_legends(
                lax, color_map, marker_map, group_by_color, group_by_marker
            )

        return fig, ax

    def _validate_inputs(
        self,
        columns: list[str],
    ) -> None:
        """Raise ``ValueError`` unless exactly two columns (x, y) are given."""
        if len(columns) != 2:
            raise ValueError("ScatterPlot requires exactly 2 columns (x and y)")

    def _prepare_grouping(
        self,
        df: _pd.DataFrame,
        color: str | None,
        marker: str | None,
    ) -> tuple[
        _pd.core.groupby.generic.DataFrameGroupBy, tuple[list[str], list[str]]
    ]:
        """Group ``df`` by the color column and/or marker column.

        The grouping keys are ordered color first, marker last. The second
        return value holds the sorted unique values of each grouping column
        (an empty list for a column that is not used).
        """
        group_by = []
        if color:
            group_by.append(color)
        if marker:
            group_by.append(marker)

        df_grouped = df.groupby(group_by)

        color_values = sorted(df[color].unique()) if color else []
        marker_values = sorted(df[marker].unique()) if marker else []

        return df_grouped, (color_values, marker_values)

    def _create_style_mappings(
        self,
        color_values: list[str],
        marker_values: list[str],
        cmap: str | None,
    ) -> tuple[dict[str, _tp.Any], dict[str, str]]:
        """Map each color value to a colormap entry and each marker value to a marker.

        The colormap is resampled to exactly ``len(color_values)`` entries.
        Markers come from ``plot_settings.markers`` and are reused cyclically
        when there are more marker values than configured markers.
        """
        if color_values:
            cm = _plt.get_cmap(cmap, len(color_values))
            color_map = {val: cm(i) for i, val in enumerate(color_values)}
        else:
            color_map = {}

        markers = list(plot_settings.markers)
        if marker_values and markers:
            # Cycle rather than ``zip``: zip silently dropped surplus marker
            # values, and _plot_groups then failed with a KeyError on them.
            marker_map = {
                val: markers[i % len(markers)]
                for i, val in enumerate(marker_values)
            }
        else:
            marker_map = {}

        return color_map, marker_map

    # pylint: disable=too-many-arguments
    def _plot_groups(
        self,
        df_grouped: _pd.core.groupby.generic.DataFrameGroupBy,
        x_column: str,
        y_column: str,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        ax: _plt.Axes,
    ) -> None:
        """Draw one line (colour) plus one set of markers per group on ``ax``.

        Each group is sorted by ``x_column`` so the connecting line is
        monotonic in x.
        """
        ax.set_xlabel(x_column, fontsize=plot_settings.label_font_size)
        ax.set_ylabel(y_column, fontsize=plot_settings.label_font_size)
        for val, group in df_grouped:
            # NOTE(review): relies on group keys being tuples (color first,
            # marker last), which pandas >= 2.0 yields for list groupers.
            sorted_group = group.sort_values(x_column)
            x = sorted_group[x_column]
            y = sorted_group[y_column]
            plot_args = {"color": "black"}
            # Marker "None" draws no markers unless a marker grouping is active.
            scatter_args = {"marker": "None", "color": "black", "alpha": 0.5}
            if color_map:
                plot_args["color"] = color_map[val[0]]
            if marker_map:
                scatter_args["marker"] = marker_map[val[-1]]
            ax.plot(x, y, **plot_args)  # type: ignore
            ax.scatter(x, y, **scatter_args)  # type: ignore

    def _create_legends(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        color_legend_title: str | None,
        marker_legend_title: str | None,
    ) -> None:
        """Hide the legend axes' frame and draw the color and/or marker legends in it."""
        lax.axis("off")

        if color_map:
            self._create_color_legend(
                lax, color_map, color_legend_title, bool(marker_map)
            )
        if marker_map:
            self._create_marker_legend(
                lax, marker_map, marker_legend_title, bool(color_map)
            )

    def _create_color_legend(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        color_legend_title: str | None,
        has_markers: bool,
    ) -> None:
        """Draw the color legend at the top of ``lax``."""
        color_handles = [
            _plt.Line2D([], [], color=color, linestyle="-", label=label)
            for label, color in color_map.items()
        ]

        legend = lax.legend(
            handles=color_handles,
            title=color_legend_title,
            bbox_to_anchor=(0, 0, 1, 1),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )

        if has_markers:
            # A later lax.legend() call replaces the axes' legend, so keep
            # this one alive as a separate artist.
            lax.add_artist(legend)

    def _create_marker_legend(
        self,
        lax: _plt.Axes,
        marker_map: dict[str, str],
        marker_legend_title: str | None,
        has_colors: bool,
    ) -> None:
        """Draw the marker legend, shifted down when a color legend is present."""
        # Anchor below the color legend (0.7 of the axes height) if one exists.
        marker_position = 0.7 if has_colors else 1
        marker_handles = [
            _plt.Line2D(
                [],
                [],
                color="black",
                marker=marker,
                linestyle="None",
                label=label,
            )
            for label, marker in marker_map.items()
            if label is not None
        ]

        lax.legend(
            handles=marker_handles,
            title=marker_legend_title,
            bbox_to_anchor=(0, 0, 1, marker_position),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc