• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SPF-OST / pytrnsys_process / 13560850870

27 Feb 2025 07:08AM UTC coverage: 98.182% (+0.02%) from 98.165%
13560850870

push

github

ahobeost
CI adjustments

7 of 8 new or added lines in 3 files covered. (87.5%)

4 existing lines in 2 files now uncovered.

1188 of 1210 relevant lines covered (98.18%)

1.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.78
/pytrnsys_process/plotting/plotters.py
1
import typing as _tp
2✔
2
from abc import abstractmethod
2✔
3
from dataclasses import dataclass
2✔
4

5
import matplotlib.pyplot as _plt
2✔
6
import numpy as _np
2✔
7
import pandas as _pd
2✔
8

9
import pytrnsys_process.constants as const
2✔
10
import pytrnsys_process.headers as h
2✔
11
from pytrnsys_process import settings as sett
2✔
12

13
# TODO: provide A4 and half A4 plots to test sizes in latex # pylint: disable=fixme
14
# TODO: provide height as input for plot?  # pylint: disable=fixme
15
# TODO: deal with legends (curve names, fonts, colors, linestyles) # pylint: disable=fixme
16
# TODO: clean up old stuff by refactoring # pylint: disable=fixme
17
# TODO: make issue for docstrings of plotting # pylint: disable=fixme
18
# TODO: Add colormap support # pylint: disable=fixme
19

20

21
# TODO: find a better place for this to live # pylint: disable=fixme
22
# Plot section of the package settings, read by the chart classes in this
# module (date format, colormap name, markers and font sizes).
plot_settings = sett.settings.plot
2✔
23

24

25
class ChartBase(h.HeaderValidationMixin):
    """Common entry points shared by every chart type in this module.

    Subclasses implement ``_do_plot``; callers use ``plot`` (no validation)
    or ``plot_with_column_validation`` (columns checked against headers).
    """

    # Default colormap that ``check_for_cmap`` injects into the pandas
    # plotting call; subclasses override it.
    cmap: str | None = None

    def plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns`` of ``df`` without validating them.

        Parameters
        ----------
        df:
            DataFrame containing the data to plot.
        columns:
            List of column names to plot.
        **kwargs:
            Forwarded unchanged to the subclass's ``_do_plot``.

        Returns
        -------
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
            Whatever the subclass's ``_do_plot`` returns.
        """
        fig, ax = self._do_plot(df, columns, **kwargs)
        return fig, ax

    # TODO: Test validation # pylint: disable=fixme
    def plot_with_column_validation(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        headers: h.Headers,
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Base plot method with header validation.

        Parameters
        ----------
        df:
            DataFrame containing the data to plot.
        columns:
            List of column names to plot.
        headers:
            Headers instance for validation.
        **kwargs:
            Additional plotting arguments, forwarded to ``_do_plot``.

        Returns
        -------
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
            Whatever the subclass's ``_do_plot`` returns.

        Raises
        ------
        ValueError
            If any columns are missing from the headers index. The message
            lists every missing column, one per line.
        """
        # TODO: Might live somewhere else in the future # pylint: disable=fixme
        is_valid, missing = self.validate_headers(headers, columns)
        if not is_valid:
            raise ValueError(
                "The following columns are not available in the headers index:\n"
                + "\n".join(missing)
            )

        return self._do_plot(df, columns, **kwargs)

    # NOTE(review): ChartBase does not list abc.ABC among its bases, so this
    # decorator is only enforced if HeaderValidationMixin supplies ABCMeta.
    # Confirm.
    @abstractmethod
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Implement actual plotting logic in subclasses.

        Implementations create their own figure of ``size`` and return
        ``(figure, axes)``.
        """

    def check_for_cmap(
        self,
        kwargs: dict[str, _tp.Any],
        plot_kwargs: dict[str, _tp.Any],
    ) -> None:
        """Add the class default colormap to ``plot_kwargs`` when appropriate.

        ``plot_kwargs`` is modified in place. ``self.cmap`` is only injected
        when the caller made no colour choice of their own. That means none
        of ``cmap``, ``colormap`` or a non-``None`` ``color`` appears in
        ``kwargs``.

        pandas does not accept ``cmap`` together with ``colormap``. When a
        colormap is combined with ``color``, pandas warns and uses
        ``color``. In both cases the injected default would be useless or
        harmful.
        """
        caller_chose_colors = (
            "cmap" in kwargs
            or "colormap" in kwargs
            or kwargs.get("color") is not None
        )
        if not caller_chose_colors:
            plot_kwargs["cmap"] = self.cmap
93

94

95
class StackedBarChart(ChartBase):
    """Stacked bar chart: one bar per index entry, columns stacked on top."""

    cmap = "inferno_r"

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw ``df[columns]`` with ``DataFrame.plot.bar(stacked=True)``.

        ``kwargs`` override the defaults (``stacked``, ``legend``, ``ax``)
        and are forwarded to pandas. The class colormap is added via
        ``check_for_cmap``. X tick labels are the index converted with
        ``pd.to_datetime`` and formatted with ``plot_settings.date_format``
        (assumes a datetime-like index; TODO confirm for other callers).

        Returns the created figure and the axes returned by pandas.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")

        bar_options: dict[str, _tp.Any] = {
            "stacked": True,
            "legend": use_legend,
            "ax": axes,
        }
        # Caller-supplied options win over the defaults above.
        bar_options.update(kwargs)
        self.check_for_cmap(kwargs, bar_options)

        plotted_axes = df[columns].plot.bar(**bar_options)
        tick_labels = _pd.to_datetime(df.index).strftime(
            plot_settings.date_format
        )
        plotted_axes.set_xticklabels(tick_labels)

        return figure, plotted_axes
123

124

125
class BarChart(ChartBase):
    """Grouped (side-by-side) bar chart: one cluster of bars per index entry."""

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw one bar per column next to each other for every index entry.

        Parameters
        ----------
        df:
            DataFrame whose index provides the x positions and tick labels.
        columns:
            Non-empty list of columns; each becomes one bar in every cluster.
        use_legend:
            Add a legend labelled with the column names.
        size:
            Figure size passed to ``plt.subplots(figsize=...)``.
        **kwargs:
            Accepted for interface compatibility with the other charts.
            NOTE(review): currently unused, including any cmap/colormap.

        Returns
        -------
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
            The created figure and its single axes.

        Raises
        ------
        ValueError
            If ``columns`` is empty. The check runs before any figure is
            created, so no orphan figure is left registered with pyplot.
        """
        # TODO: deal with cmap  # pylint: disable=fixme
        if not columns:
            raise ValueError("BarChart requires at least one column to plot")

        fig, ax = _plt.subplots(
            figsize=size,
            layout="constrained",
        )
        x = _np.arange(len(df.index))
        # A whole cluster spans 0.8 x-units, leaving a gap between clusters.
        width = 0.8 / len(columns)

        for i, col in enumerate(columns):
            ax.bar(x + i * width, df[col], width, label=col)

        if use_legend:
            ax.legend()

        # Bars are centre-aligned at x + i * width; put each tick under the
        # middle of its cluster.
        ax.set_xticks(x + width * (len(columns) - 1) / 2)
        # Assumes a datetime-like index (converted via pd.to_datetime).
        # TODO confirm.
        ax.set_xticklabels(
            _pd.to_datetime(df.index).strftime(plot_settings.date_format)
        )
        ax.tick_params(axis="x", labelrotation=90)
        return fig, ax
155

156

157
class LinePlot(ChartBase):
    """Line plot of one or more columns against the DataFrame index."""

    cmap: str | None = None

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw ``df[columns]`` via ``DataFrame.plot.line``.

        ``kwargs`` override the ``legend``/``ax`` defaults and go to pandas.
        The class colormap is added via ``check_for_cmap``. Returns the
        figure and axes created here; the axes pandas returns is discarded.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")

        line_options: dict[str, _tp.Any] = {"legend": use_legend, "ax": axes}
        line_options.update(kwargs)
        self.check_for_cmap(kwargs, line_options)

        df[columns].plot.line(**line_options)
        return figure, axes
181

182

183
@dataclass
class Histogram(ChartBase):
    """Histogram of one or more columns drawn with ``DataFrame.plot.hist``."""

    # Bin count handed to pandas; a ``bins=`` keyword argument overrides it.
    bins: int = 50

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw ``df[columns]`` as histograms on a freshly created figure.

        ``kwargs`` override the ``legend``/``ax``/``bins`` defaults and go to
        pandas. The class colormap is added via ``check_for_cmap``. Returns
        the figure and axes created here.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")

        hist_options: dict[str, _tp.Any] = {
            "legend": use_legend,
            "ax": axes,
            "bins": self.bins,
        }
        hist_options.update(kwargs)
        self.check_for_cmap(kwargs, hist_options)

        df[columns].plot.hist(**hist_options)
        return figure, axes
208

209

210
@dataclass
class ScatterPlot(ChartBase):
    """Handles comparative scatter plots with dual grouping by color and markers.

    Without grouping, the two columns are drawn with
    ``DataFrame.plot.scatter``. When ``group_by_color`` and/or
    ``group_by_marker`` name a column of ``df``, the rows are split into
    groups. Each group is drawn as a line sorted by x, coloured by its
    colour-group value, plus black markers shaped by its marker-group value.
    The legends go into a narrow side panel.
    """

    # pylint: disable=too-many-arguments,too-many-locals
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        group_by_color: str | None = None,
        group_by_marker: str | None = None,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns[1]`` (y) against ``columns[0]`` (x).

        Parameters
        ----------
        df:
            DataFrame holding the x/y columns and any grouping columns.
        columns:
            Exactly two column names, ``[x_column, y_column]``.
        use_legend:
            Draw the colour/marker legends (grouped mode only).
        size:
            Figure size passed to ``plt.subplots(figsize=...)``.
        group_by_color:
            Column whose values select the line colour of each group.
        group_by_marker:
            Column whose values select the marker of each group.
        **kwargs:
            Forwarded to ``DataFrame.plot.scatter`` in the ungrouped case.
            NOTE(review): ignored in the grouped case.

        Returns
        -------
        tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
            The figure and the data axes (never the legend panel).

        Raises
        ------
        ValueError
            If ``columns`` does not contain exactly two entries.
        """
        self._validate_inputs(columns)
        x_column, y_column = columns

        if not group_by_color and not group_by_marker:
            fig, ax = _plt.subplots(
                figsize=size,
                layout="constrained",
            )
            df.plot.scatter(x=x_column, y=y_column, ax=ax, **kwargs)
            return fig, ax

        # See: https://stackoverflow.com/questions/4700614/
        # how-to-put-the-legend-outside-the-plot
        # This is required to place the legend in a dedicated subplot
        fig, (ax, lax) = _plt.subplots(
            layout="constrained",
            figsize=size,
            ncols=2,
            gridspec_kw={"width_ratios": [4, 1]},
        )
        # The side panel only ever hosts legends. Hide its frame and ticks
        # up front so use_legend=False does not leave an empty axes next to
        # the plot.
        lax.axis("off")

        df_grouped, group_values = self._prepare_grouping(
            df, group_by_color, group_by_marker
        )
        # TODO: deal with cmap  # pylint: disable=fixme
        color_map, marker_map = self._create_style_mappings(*group_values)

        self._plot_groups(
            df_grouped,
            x_column,
            y_column,
            color_map,
            marker_map,
            ax,
        )

        if use_legend:
            self._create_legends(
                lax, color_map, marker_map, group_by_color, group_by_marker
            )

        return fig, ax

    def _validate_inputs(
        self,
        columns: list[str],
    ) -> None:
        """Raise ``ValueError`` unless exactly two columns (x, y) are given."""
        if len(columns) != 2:
            raise ValueError(
                "ScatterPlot requires exactly 2 columns (x and y)"
            )

    def _prepare_grouping(
        self,
        df: _pd.DataFrame,
        color: str | None,
        marker: str | None,
    ) -> tuple[
        _pd.core.groupby.generic.DataFrameGroupBy, tuple[list[str], list[str]]
    ]:
        """Group ``df`` by the colour column, then the marker column.

        Returns the groupby object and the sorted unique values of each
        grouping column. A value list is empty when that column is not used.
        """
        # Colour first, marker last: _plot_groups reads key[0] / key[-1].
        group_by = []
        if color:
            group_by.append(color)
        if marker:
            group_by.append(marker)

        df_grouped = df.groupby(group_by)

        color_values = sorted(df[color].unique()) if color else []
        marker_values = sorted(df[marker].unique()) if marker else []

        return df_grouped, (color_values, marker_values)

    def _create_style_mappings(
        self, color_values: list[str], marker_values: list[str]
    ) -> tuple[dict[str, _tp.Any], dict[str, str]]:
        """Map colour-group values to colours and marker-group values to markers.

        Colours come from ``plot_settings.color_map`` resampled to
        ``len(color_values)`` entries. Markers are taken in order from
        ``plot_settings.markers`` and reused cyclically when there are more
        marker groups than markers, so every group gets a marker.
        """
        color_map: dict[str, _tp.Any] = {}
        if color_values:
            cmap = _plt.get_cmap(plot_settings.color_map, len(color_values))
            color_map = {val: cmap(i) for i, val in enumerate(color_values)}

        marker_map: dict[str, str] = {}
        markers = list(plot_settings.markers) if marker_values else []
        if markers:
            marker_map = {
                val: markers[i % len(markers)]
                for i, val in enumerate(marker_values)
            }

        return color_map, marker_map

    # pylint: disable=too-many-arguments
    def _plot_groups(
        self,
        df_grouped: _pd.core.groupby.generic.DataFrameGroupBy,
        x_column: str,
        y_column: str,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        ax: _plt.Axes,
    ) -> None:
        """Draw every group as a coloured line plus semi-transparent black markers.

        Without a colour mapping, lines are black. Without a marker mapping,
        the scatter uses marker ``"None"``, so no markers are drawn.
        """
        ax.set_xlabel(x_column, fontsize=plot_settings.label_font_size)
        ax.set_ylabel(y_column, fontsize=plot_settings.label_font_size)
        for val, group in df_grouped:
            # Grouping by a list yields tuple keys on pandas >= 2. Wrap
            # scalar keys (older pandas) so key[0] never indexes into a
            # string.
            key = val if isinstance(val, tuple) else (val,)
            sorted_group = group.sort_values(x_column)
            x = sorted_group[x_column]
            y = sorted_group[y_column]
            plot_args = {"color": "black"}
            scatter_args = {"marker": "None", "color": "black", "alpha": 0.5}
            if color_map:
                plot_args["color"] = color_map[key[0]]
            if marker_map:
                scatter_args["marker"] = marker_map[key[-1]]
            ax.plot(x, y, **plot_args)  # type: ignore
            ax.scatter(x, y, **scatter_args)  # type: ignore

    def _create_legends(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        color_legend_title: str | None,
        marker_legend_title: str | None,
    ) -> None:
        """Fill the side panel ``lax`` with the colour and/or marker legend."""
        lax.axis("off")

        if color_map:
            self._create_color_legend(
                lax, color_map, color_legend_title, bool(marker_map)
            )
        if marker_map:
            self._create_marker_legend(
                lax, marker_map, marker_legend_title, bool(color_map)
            )

    def _create_color_legend(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        color_legend_title: str | None,
        has_markers: bool,
    ) -> None:
        """Add a legend of coloured line handles at the top of ``lax``.

        When a marker legend will follow, this legend is re-added as an
        artist, because a later ``Axes.legend`` call replaces the current
        legend.
        """
        color_handles = [
            _plt.Line2D([], [], color=color, linestyle="-", label=label)
            for label, color in color_map.items()
        ]

        legend = lax.legend(
            handles=color_handles,
            title=color_legend_title,
            bbox_to_anchor=(0, 0, 1, 1),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )

        if has_markers:
            lax.add_artist(legend)

    def _create_marker_legend(
        self,
        lax: _plt.Axes,
        marker_map: dict[str, str],
        marker_legend_title: str | None,
        has_colors: bool,
    ) -> None:
        """Add a legend of black marker handles to ``lax``.

        With a colour legend present, the top of this legend is anchored at
        70% of the panel height. NOTE(review): it may overlap the colour
        legend when there are many colour entries.
        """
        marker_position = 0.7 if has_colors else 1
        marker_handles = [
            _plt.Line2D(
                [],
                [],
                color="black",
                marker=marker,
                linestyle="None",
                label=label,
            )
            for label, marker in marker_map.items()
            if label is not None
        ]

        lax.legend(
            handles=marker_handles,
            title=marker_legend_title,
            bbox_to_anchor=(0, 0, 1, marker_position),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc