• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SPF-OST / pytrnsys_process / 13394719474

18 Feb 2025 03:55PM UTC coverage: 96.919% (-1.0%) from 97.93%
13394719474

push

github

sebastian-swob
increased positional arguments to 7

1164 of 1201 relevant lines covered (96.92%)

1.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.64
/pytrnsys_process/plotting/plotters.py
1
import typing as _tp
2✔
2
from abc import abstractmethod
2✔
3
from dataclasses import dataclass
2✔
4

5
import matplotlib.pyplot as _plt
2✔
6
import numpy as _np
2✔
7
import pandas as _pd
2✔
8

9
import pytrnsys_process.constants as const
2✔
10
import pytrnsys_process.headers as h
2✔
11
from pytrnsys_process import settings as sett
2✔
12

13
# TODO: provide A4 and half A4 plots to test sizes in latex # pylint: disable=fixme
14
# TODO: provide height as input for plot?  # pylint: disable=fixme
15
# TODO: deal with legends (curve names, fonts, colors, linestyles) # pylint: disable=fixme
16
# TODO: clean up old stuff by refactoring # pylint: disable=fixme
17
# TODO: make issue for docstrings of plotting # pylint: disable=fixme
18
# TODO: Add colormap support # pylint: disable=fixme
19

20

21
# TODO: find a better place for this to live in  # pylint: disable=fixme
# Module-wide plot settings. The chart classes below read ``color_map``,
# ``date_format``, ``label_font_size``, ``legend_font_size`` and ``markers``
# from this object.
plot_settings = sett.settings.plot
23

24

25
@dataclass
class ChartBase(h.HeaderValidationMixin):
    """Common plotting entry points shared by every chart type in this module.

    Subclasses implement :meth:`_do_plot`. Callers use :meth:`plot` directly,
    or :meth:`plot_with_column_validation` to check the requested columns
    against a ``Headers`` index first.
    """

    def plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns`` of ``df`` without any header validation.

        Args:
            df: DataFrame containing the data to plot
            columns: List of column names to plot
            **kwargs: Forwarded unchanged to ``_do_plot`` (e.g. ``use_legend``,
                ``size`` or chart-specific options)

        Returns:
            The figure and axes produced by the subclass' ``_do_plot``.
        """
        fig, ax = self._do_plot(df, columns, **kwargs)
        return fig, ax

    # TODO: Test validation # pylint: disable=fixme
    def plot_with_column_validation(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        headers: h.Headers,
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Base plot method with header validation.

        Args:
            df: DataFrame containing the data to plot
            columns: List of column names to plot
            headers: Headers instance for validation
            **kwargs: Additional plotting arguments, forwarded to ``_do_plot``

        Returns:
            The figure and axes produced by the subclass' ``_do_plot``.

        Raises:
            ValueError: If any columns are missing from the headers index
        """
        # TODO: Might live somewhere else in the future # pylint: disable=fixme
        is_valid, missing = self.validate_headers(headers, columns)
        if not is_valid:
            # One missing column name per line in the error message.
            raise ValueError(
                "The following columns are not available in the headers index:\n"
                + "\n".join(missing)
            )

        return self._do_plot(df, columns, **kwargs)

    @abstractmethod
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Implement actual plotting logic in subclasses.

        Args:
            df: DataFrame containing the data to plot
            columns: List of column names to plot
            use_legend: Whether the chart should draw a legend
            size: Figure size (width, height) for the new figure
            **kwargs: Chart-specific plotting options

        Returns:
            The created figure and the axes holding the plotted data.
        """
79

80

81
class StackedBarChart(ChartBase):
    """Bar chart in which the requested columns are stacked per index entry."""

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw ``df[columns]`` as a stacked bar chart.

        Args:
            df: DataFrame to plot; its index is converted with
                ``pandas.to_datetime`` to build the x tick labels
            columns: Names of the columns to stack
            use_legend: Passed to pandas as ``legend``
            size: Figure size for ``matplotlib.pyplot.subplots``
            **kwargs: Extra options for ``DataFrame.plot.bar``; they override
                the defaults set here (``stacked``, ``colormap``, ``legend``,
                ``ax``)

        Returns:
            The new figure and the axes returned by pandas' bar plot.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")
        defaults = {
            "stacked": True,
            "colormap": plot_settings.color_map,
            "legend": use_legend,
            "ax": axes,
        }
        # Caller-supplied kwargs win over the defaults above.
        axes = df[columns].plot.bar(**(defaults | kwargs))

        date_labels = _pd.to_datetime(df.index).strftime(
            plot_settings.date_format
        )
        axes.set_xticklabels(date_labels)
        return figure, axes
108

109

110
class BarChart(ChartBase):
    """Grouped bar chart: one bar per column, placed side by side per index entry."""

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw each column in ``columns`` as its own bar series.

        The bars for one index entry share a slot of total width 0.8.
        Tick marks are centred under each group and labelled with the
        index formatted through ``plot_settings.date_format``, rotated 90°.

        Args:
            df: DataFrame to plot; its index is converted with
                ``pandas.to_datetime`` for the tick labels
            columns: Column names, one bar series each
            use_legend: Whether to call ``Axes.legend``
            size: Figure size for ``matplotlib.pyplot.subplots``
            **kwargs: Accepted for interface compatibility but not used

        Returns:
            The new figure and its axes.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")
        n_columns = len(columns)
        positions = _np.arange(len(df.index))
        bar_width = 0.8 / n_columns

        for offset, column in enumerate(columns):
            axes.bar(
                positions + offset * bar_width,
                df[column],
                bar_width,
                label=column,
            )

        if use_legend:
            axes.legend()

        # Shift the ticks to the middle of each group of bars.
        centre_shift = bar_width * (n_columns - 1) / 2
        axes.set_xticks(positions + centre_shift)
        axes.set_xticklabels(
            _pd.to_datetime(df.index).strftime(plot_settings.date_format)
        )
        axes.tick_params(axis="x", labelrotation=90)
        return figure, axes
139

140

141
class LinePlot(ChartBase):
    """Line chart with one line per requested column."""

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw ``df[columns]`` as lines on a freshly created figure.

        Args:
            df: DataFrame to plot
            columns: Names of the columns to draw
            use_legend: Passed to pandas as ``legend``
            size: Figure size for ``matplotlib.pyplot.subplots``
            **kwargs: Extra options for ``DataFrame.plot.line``; they
                override ``colormap``, ``legend`` and ``ax`` set here

        Returns:
            The new figure and the axes created alongside it.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")
        line_options = {
            "colormap": plot_settings.color_map,
            "legend": use_legend,
            "ax": axes,
        } | kwargs
        df[columns].plot.line(**line_options)
        return figure, axes
163

164

165
@dataclass
class Histogram(ChartBase):
    """Histogram of one or more columns."""

    # Number of bins handed to ``DataFrame.plot.hist``; a ``bins`` keyword
    # passed at plot time takes precedence.
    bins: int = 50

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw a histogram of ``df[columns]`` using ``self.bins`` bins.

        Args:
            df: DataFrame to plot
            columns: Names of the columns to bin
            use_legend: Passed to pandas as ``legend``
            size: Figure size for ``matplotlib.pyplot.subplots``
            **kwargs: Extra options for ``DataFrame.plot.hist``; they
                override ``colormap``, ``legend``, ``ax`` and ``bins``

        Returns:
            The new figure and the axes created alongside it.
        """
        figure, axes = _plt.subplots(figsize=size, layout="constrained")
        hist_options = {
            "colormap": plot_settings.color_map,
            "legend": use_legend,
            "ax": axes,
            "bins": self.bins,
        } | kwargs
        df[columns].plot.hist(**hist_options)
        return figure, axes
190

191

192
@dataclass
class ScatterPlot(ChartBase):
    """Handles comparative scatter plots with dual grouping by color and markers.

    Without grouping, the two requested columns are drawn with
    ``DataFrame.plot.scatter``. With ``group_by_color`` and/or
    ``group_by_marker``:

    * the rows are grouped by those columns;
    * each group is drawn as a line sorted by x, coloured per colour group;
    * the line is overlaid with black markers, shaped per marker group;
    * the legends go into a dedicated axes to the right of the data axes.
    """

    # pylint: disable=too-many-arguments,too-many-locals
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = const.PlotSizes.A4.value,
        group_by_color: str | None = None,
        group_by_marker: str | None = None,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns[1]`` (y) against ``columns[0]`` (x).

        Args:
            df: DataFrame containing the data to plot
            columns: Exactly two column names, ``[x_column, y_column]``
            use_legend: Whether to draw the colour/marker legends (grouped
                mode only)
            size: Figure size for ``matplotlib.pyplot.subplots``
            group_by_color: Column whose values select each group's line colour
            group_by_marker: Column whose values select each group's marker
            **kwargs: Forwarded to ``DataFrame.plot.scatter`` in ungrouped
                mode; not used in grouped mode

        Returns:
            The figure and the axes holding the data (not the legend axes).

        Raises:
            ValueError: If ``columns`` does not contain exactly two names
        """
        self._validate_inputs(columns)
        x_column, y_column = columns

        if not group_by_color and not group_by_marker:
            fig, ax = _plt.subplots(
                figsize=size,
                layout="constrained",
            )
            df.plot.scatter(x=x_column, y=y_column, ax=ax, **kwargs)
            return fig, ax

        # See: https://stackoverflow.com/questions/4700614/
        # how-to-put-the-legend-outside-the-plot
        # This is required to place the legend in a dedicated subplot
        fig, (ax, lax) = _plt.subplots(
            layout="constrained",
            figsize=size,
            ncols=2,
            gridspec_kw={"width_ratios": [4, 1]},
        )
        # Hide the legend axes regardless of use_legend. Otherwise, when no
        # legend is drawn, an empty axes with frame and ticks would show up
        # next to the plot.
        lax.axis("off")

        df_grouped, group_values = self._prepare_grouping(
            df, group_by_color, group_by_marker
        )
        color_map, marker_map = self._create_style_mappings(*group_values)

        self._plot_groups(
            df_grouped,
            x_column,
            y_column,
            color_map,
            marker_map,
            ax,
        )

        if use_legend:
            self._create_legends(
                lax, color_map, marker_map, group_by_color, group_by_marker
            )

        return fig, ax

    def _validate_inputs(
        self,
        columns: list[str],
    ) -> None:
        """Raise ValueError unless exactly two (x and y) columns are given."""
        if len(columns) != 2:
            raise ValueError("ScatterPlot requires exactly 2 columns (x and y)")

    def _prepare_grouping(
        self,
        df: _pd.DataFrame,
        color: str | None,
        marker: str | None,
    ) -> tuple[
        _pd.core.groupby.generic.DataFrameGroupBy, tuple[list[str], list[str]]
    ]:
        """Group ``df`` by the colour and/or marker column.

        Returns:
            The grouped frame and ``(color_values, marker_values)``: the
            sorted unique values of each grouping column, or an empty list
            when that grouping is not requested. The colour column (if any)
            is always the first grouping key, so ``_plot_groups`` can rely
            on key ``[0]`` for colour and key ``[-1]`` for marker.
        """
        group_by = []
        if color:
            group_by.append(color)
        if marker:
            group_by.append(marker)

        df_grouped = df.groupby(group_by)

        color_values = sorted(df[color].unique()) if color else []
        marker_values = sorted(df[marker].unique()) if marker else []

        return df_grouped, (color_values, marker_values)

    def _create_style_mappings(
        self, color_values: list[str], marker_values: list[str]
    ) -> tuple[dict[str, _tp.Any], dict[str, str]]:
        """Map each colour-group value to a colour and each marker value to a marker.

        Colours are sampled evenly from ``plot_settings.color_map``. Markers
        are taken from ``plot_settings.markers`` in order, wrapping around
        when there are more marker groups than configured markers.
        """
        import itertools  # pylint: disable=import-outside-toplevel

        if color_values:
            cmap = _plt.get_cmap(plot_settings.color_map, len(color_values))
            color_map = {val: cmap(i) for i, val in enumerate(color_values)}
        else:
            color_map = {}
        if marker_values:
            # Cycling keeps every marker value mapped. A plain zip would
            # drop the surplus values, and _plot_groups would then raise
            # KeyError for those groups.
            marker_map = dict(
                zip(marker_values, itertools.cycle(plot_settings.markers))
            )
        else:
            marker_map = {}

        return color_map, marker_map

    # pylint: disable=too-many-arguments
    def _plot_groups(
        self,
        df_grouped: _pd.core.groupby.generic.DataFrameGroupBy,
        x_column: str,
        y_column: str,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        ax: _plt.Axes,
    ) -> None:
        """Draw every group as a line plus semi-transparent black markers on ``ax``.

        Without colour grouping the lines are black. Without marker grouping
        the scatter uses marker ``"None"``, so only the lines are visible.
        """
        ax.set_xlabel(x_column, fontsize=plot_settings.label_font_size)
        ax.set_ylabel(y_column, fontsize=plot_settings.label_font_size)
        for val, group in df_grouped:
            # Grouping by a list normally yields tuple keys. Normalise a
            # scalar key (as older pandas versions may yield for a
            # single-column list grouper) so [0]/[-1] never index into
            # the value itself.
            key = val if isinstance(val, tuple) else (val,)
            sorted_group = group.sort_values(x_column)
            x = sorted_group[x_column]
            y = sorted_group[y_column]
            plot_args = {"color": "black"}
            scatter_args = {"marker": "None", "color": "black", "alpha": 0.5}
            if color_map:
                # Colour column is the first grouping key.
                plot_args["color"] = color_map[key[0]]
            if marker_map:
                # Marker column is the last grouping key.
                scatter_args["marker"] = marker_map[key[-1]]
            ax.plot(x, y, **plot_args)  # type: ignore
            ax.scatter(x, y, **scatter_args)  # type: ignore

    def _create_legends(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        color_legend_title: str | None,
        marker_legend_title: str | None,
    ) -> None:
        """Hide ``lax`` and draw the colour and/or marker legend into it."""
        lax.axis("off")

        if color_map:
            self._create_color_legend(
                lax, color_map, color_legend_title, bool(marker_map)
            )
        if marker_map:
            self._create_marker_legend(
                lax, marker_map, marker_legend_title, bool(color_map)
            )

    def _create_color_legend(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        color_legend_title: str | None,
        has_markers: bool,
    ) -> None:
        """Draw a line-style legend entry per colour group, anchored at the top of ``lax``."""
        color_handles = [
            _plt.Line2D([], [], color=color, linestyle="-", label=label)
            for label, color in color_map.items()
        ]

        legend = lax.legend(
            handles=color_handles,
            title=color_legend_title,
            bbox_to_anchor=(0, 0, 1, 1),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )

        if has_markers:
            # A second ``lax.legend`` call replaces the current legend;
            # re-adding it as an artist keeps it visible.
            lax.add_artist(legend)

    def _create_marker_legend(
        self,
        lax: _plt.Axes,
        marker_map: dict[str, str],
        marker_legend_title: str | None,
        has_colors: bool,
    ) -> None:
        """Draw a black-marker legend entry per marker group into ``lax``.

        When a colour legend is also present, this legend's top is anchored
        at 70% of the legend axes height, below the colour legend, which is
        anchored at the top.
        """
        marker_position = 0.7 if has_colors else 1
        marker_handles = [
            _plt.Line2D(
                [],
                [],
                color="black",
                marker=marker,
                linestyle="None",
                label=label,
            )
            for label, marker in marker_map.items()
            if label is not None
        ]

        lax.legend(
            handles=marker_handles,
            title=marker_legend_title,
            bbox_to_anchor=(0, 0, 1, marker_position),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc