• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SPF-OST / pytrnsys_process / 13658028785

04 Mar 2025 04:13PM UTC coverage: 97.522% (+0.002%) from 97.52%
13658028785

push

github

sebastian-swob
fixing doc build

1102 of 1130 relevant lines covered (97.52%)

1.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.18
/pytrnsys_process/plot/plotters.py
1
import typing as _tp
2✔
2
from abc import abstractmethod
2✔
3
from dataclasses import dataclass
2✔
4

5
import matplotlib.pyplot as _plt
2✔
6
import numpy as _np
2✔
7
import pandas as _pd
2✔
8

9
from pytrnsys_process import config as conf
2✔
10

11
# TODO: provide A4 and half A4 plots to test sizes in latex # pylint: disable=fixme
12
# TODO: provide height as input for plot?  # pylint: disable=fixme
13
# TODO: deal with legends (curve names, fonts, colors, linestyles) # pylint: disable=fixme
14
# TODO: clean up old stuff by refactoring # pylint: disable=fixme
15
# TODO: make issue for docstrings of plotting # pylint: disable=fixme
16
# TODO: Add colormap support # pylint: disable=fixme
17

18

19
# TODO find a better place for this to live in # pylint: disable=fixme
plot_settings = conf.global_settings.plot
"Settings shared by all plots"
22

23

24
class ChartBase:
    """Common entry point shared by all chart types.

    Callers use :meth:`plot`; subclasses provide the drawing logic in
    :meth:`_do_plot`. ``cmap`` is the class-level default colormap, used when
    the caller passes neither ``cmap`` nor ``colormap``.
    """

    cmap: str | None = None

    def plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        **kwargs,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns`` of ``df`` and return the figure and its axes.

        All extra keyword arguments are forwarded unchanged to ``_do_plot``.
        """
        fig, ax = self._do_plot(df, columns, **kwargs)
        return fig, ax

    @abstractmethod
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = conf.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Implement actual plotting logic in subclasses"""
        # ChartBase does not use ABCMeta, so @abstractmethod is not enforced at
        # instantiation time. Fail loudly here instead of returning None, which
        # would otherwise surface as an obscure unpacking error in plot().
        raise NotImplementedError(
            f"{type(self).__name__} must implement _do_plot"
        )

    def check_for_cmap(
        self, kwargs: dict[str, _tp.Any], plot_kwargs: dict[str, _tp.Any]
    ) -> dict[str, _tp.Any]:
        """Inject the class default colormap into ``plot_kwargs`` if needed.

        The default is only set when the caller supplied neither ``cmap`` nor
        ``colormap`` in ``kwargs``. ``plot_kwargs`` is modified in place and
        also returned for convenience.
        """
        if "cmap" not in kwargs and "colormap" not in kwargs:
            plot_kwargs["cmap"] = self.cmap
        return plot_kwargs

    def get_cmap(self, kwargs: dict[str, _tp.Any]) -> str | None:
        """Return the colormap requested in ``kwargs``, else the class default.

        ``cmap`` takes precedence over ``colormap`` when both are present.
        """
        if "cmap" in kwargs:
            return kwargs["cmap"]
        if "colormap" in kwargs:
            return kwargs["colormap"]
        return self.cmap
63

64

65
class StackedBarChart(ChartBase):
    """Stacked bar chart of the selected columns (pandas ``plot.bar``)."""

    cmap: str | None = "inferno_r"

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = conf.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        figure, axes = _plt.subplots(figsize=size, layout="constrained")

        # Caller kwargs are applied last so they override the defaults.
        bar_kwargs: dict[str, _tp.Any] = {
            "stacked": True,
            "legend": use_legend,
            "ax": axes,
        }
        bar_kwargs.update(kwargs)
        self.check_for_cmap(kwargs, bar_kwargs)

        axes = df[columns].plot.bar(**bar_kwargs)
        # NOTE(review): assumes df.index is convertible to datetimes.
        date_labels = _pd.to_datetime(df.index).strftime(
            plot_settings.date_format
        )
        axes.set_xticklabels(date_labels)

        return figure, axes
93

94

95
class BarChart(ChartBase):
    """Grouped bar chart: for each index entry, one bar per column side by side."""

    cmap = None

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = conf.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Draw ``columns`` as grouped bars.

        Each index entry owns a slot of width 0.8 that is shared equally by
        the columns. Colors are sampled evenly from the requested colormap
        (``cmap``/``colormap`` kwarg or the class default). With no colormap,
        Matplotlib's default color cycle is used. Other kwargs are ignored.
        """
        # TODO: deal with colors  # pylint: disable=fixme
        fig, ax = _plt.subplots(
            figsize=size,
            layout="constrained",
        )
        x = _np.arange(len(df.index))
        width = 0.8 / len(columns)

        cmap = self.get_cmap(kwargs)
        if cmap:
            # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was
            # deprecated in Matplotlib 3.7 and removed in 3.9.
            cm = _plt.get_cmap(cmap)
            colors = cm(_np.linspace(0, 1, len(columns)))
        else:
            colors = [None] * len(columns)

        for i, col in enumerate(columns):
            ax.bar(x + i * width, df[col], width, label=col, color=colors[i])

        if use_legend:
            ax.legend()

        # Centre each tick under its group of bars.
        ax.set_xticks(x + width * (len(columns) - 1) / 2)
        # NOTE(review): assumes df.index is convertible to datetimes.
        ax.set_xticklabels(
            _pd.to_datetime(df.index).strftime(plot_settings.date_format)
        )
        ax.tick_params(axis="x", labelrotation=90)
        return fig, ax
133

134

135
class LinePlot(ChartBase):
    """Line plot of the selected columns (pandas ``plot.line``)."""

    cmap: str | None = None

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = conf.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        figure, axes = _plt.subplots(figsize=size, layout="constrained")

        # Defaults first; anything the caller passes overrides them.
        line_kwargs: dict[str, _tp.Any] = {"legend": use_legend, "ax": axes}
        line_kwargs.update(kwargs)
        self.check_for_cmap(kwargs, line_kwargs)

        df[columns].plot.line(**line_kwargs)
        return figure, axes
159

160

161
@dataclass
class Histogram(ChartBase):
    """Histogram of the selected columns (pandas ``plot.hist``).

    ``bins`` is the default bin count; a ``bins`` kwarg overrides it.
    """

    bins: int = 50

    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = conf.PlotSizes.A4.value,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        figure, axes = _plt.subplots(figsize=size, layout="constrained")

        hist_kwargs: dict[str, _tp.Any] = {
            "legend": use_legend,
            "ax": axes,
            "bins": self.bins,
        }
        hist_kwargs.update(kwargs)
        self.check_for_cmap(kwargs, hist_kwargs)

        df[columns].plot.hist(**hist_kwargs)
        return figure, axes
186

187

188
class ScatterPlot(ChartBase):
    """Handles comparative scatter plots with dual grouping by color and markers.

    Without grouping, the data is drawn with ``DataFrame.plot.scatter``. With
    ``group_by_color`` and/or ``group_by_marker``, each group is drawn as a
    line sorted by x (colored by its color group) overlaid with black markers
    (shaped by its marker group). The legends go into a separate axes to the
    right of the data axes.
    """

    cmap = "Paired"  # This is ignored when no categorical groupings are used.

    # pylint: disable=too-many-arguments,too-many-locals
    def _do_plot(
        self,
        df: _pd.DataFrame,
        columns: list[str],
        use_legend: bool = True,
        size: tuple[float, float] = conf.PlotSizes.A4.value,
        group_by_color: str | None = None,
        group_by_marker: str | None = None,
        **kwargs: _tp.Any,
    ) -> tuple[_plt.Figure, _plt.Axes]:
        """Plot ``columns[0]`` (x) against ``columns[1]`` (y).

        Args:
            df: Data to plot.
            columns: Exactly two column names, ``[x_column, y_column]``.
            use_legend: Whether to draw the group legends (grouped mode only).
            size: Figure size passed to ``plt.subplots``.
            group_by_color: Column whose distinct values get distinct colors.
            group_by_marker: Column whose distinct values get distinct markers.
            **kwargs: Forwarded to ``DataFrame.plot.scatter`` when ungrouped;
                in grouped mode only ``cmap``/``colormap`` are read.

        Returns:
            The figure and the data axes.

        Raises:
            ValueError: If ``columns`` does not contain exactly two names.
        """
        self._validate_inputs(columns)
        x_column, y_column = columns

        if not group_by_color and not group_by_marker:
            fig, ax = _plt.subplots(
                figsize=size,
                layout="constrained",
            )
            df.plot.scatter(x=x_column, y=y_column, ax=ax, **kwargs)
            return fig, ax
        # See: https://stackoverflow.com/questions/4700614/
        # how-to-put-the-legend-outside-the-plot
        # This is required to place the legend in a dedicated subplot
        fig, (ax, lax) = _plt.subplots(
            layout="constrained",
            figsize=size,
            ncols=2,
            gridspec_kw={"width_ratios": [4, 1]},
        )
        df_grouped, group_values = self._prepare_grouping(
            df, group_by_color, group_by_marker
        )
        cmap = self.get_cmap(kwargs)
        color_map, marker_map = self._create_style_mappings(
            *group_values, cmap=cmap
        )

        self._plot_groups(
            df_grouped,
            x_column,
            y_column,
            color_map,
            marker_map,
            ax,
        )

        if use_legend:
            self._create_legends(
                lax, color_map, marker_map, group_by_color, group_by_marker
            )
        else:
            # Keep the reserved legend column blank instead of showing an
            # empty axes with ticks and a frame.
            lax.axis("off")

        return fig, ax

    def _validate_inputs(
        self,
        columns: list[str],
    ) -> None:
        """Raise ``ValueError`` unless exactly two columns (x, y) are given."""
        if len(columns) != 2:
            raise ValueError(
                "ScatterPlot requires exactly 2 columns (x and y)"
            )

    def _prepare_grouping(
        self,
        df: _pd.DataFrame,
        color: str | None,
        marker: str | None,
    ) -> tuple[
        _pd.core.groupby.generic.DataFrameGroupBy, tuple[list[str], list[str]]
    ]:
        """Group ``df`` by the color column and/or the marker column.

        The grouping keys are ordered ``[color, marker]``, skipping any that
        are ``None``. Also returns the sorted distinct values of each grouping
        column (an empty list when that grouping is not used).
        """
        group_by = []
        if color:
            group_by.append(color)
        if marker:
            group_by.append(marker)

        df_grouped = df.groupby(group_by)

        color_values = sorted(df[color].unique()) if color else []
        marker_values = sorted(df[marker].unique()) if marker else []

        return df_grouped, (color_values, marker_values)

    def _create_style_mappings(
        self,
        color_values: list[str],
        marker_values: list[str],
        cmap: str | None,
    ) -> tuple[dict[str, _tp.Any], dict[str, str]]:
        """Map each color value to a color and each marker value to a marker.

        Colors come from ``cmap`` resampled to ``len(color_values)`` entries.
        Markers are taken from ``plot_settings.markers`` in order and reused
        cyclically when there are more values than configured markers.
        """
        color_map: dict[str, _tp.Any] = {}
        if color_values:
            cm = _plt.get_cmap(cmap, len(color_values))
            color_map = {val: cm(i) for i, val in enumerate(color_values)}

        marker_map: dict[str, str] = {}
        markers = list(plot_settings.markers) if marker_values else []
        if markers:
            # A plain zip() would silently drop the surplus values, and
            # _plot_groups would then fail with a KeyError for those groups.
            marker_map = {
                val: markers[i % len(markers)]
                for i, val in enumerate(marker_values)
            }

        return color_map, marker_map

    # pylint: disable=too-many-arguments
    def _plot_groups(
        self,
        df_grouped: _pd.core.groupby.generic.DataFrameGroupBy,
        x_column: str,
        y_column: str,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        ax: _plt.Axes,
    ) -> None:
        """Draw each group as a colored line plus semi-transparent markers."""
        ax.set_xlabel(x_column, fontsize=plot_settings.label_font_size)
        ax.set_ylabel(y_column, fontsize=plot_settings.label_font_size)
        for val, group in df_grouped:
            # Keys are ordered (color, marker) as built in _prepare_grouping.
            # Grouping by a list yields tuple keys in recent pandas; normalize
            # defensively so a scalar key is never indexed character-wise.
            key = val if isinstance(val, tuple) else (val,)
            sorted_group = group.sort_values(x_column)
            x = sorted_group[x_column]
            y = sorted_group[y_column]
            plot_args = {"color": "black"}
            scatter_args = {"marker": "None", "color": "black", "alpha": 0.5}
            if color_map:
                plot_args["color"] = color_map[key[0]]
            if marker_map:
                scatter_args["marker"] = marker_map[key[-1]]
            ax.plot(x, y, **plot_args)  # type: ignore
            ax.scatter(x, y, **scatter_args)  # type: ignore

    def _create_legends(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        marker_map: dict[str, str],
        color_legend_title: str | None,
        marker_legend_title: str | None,
    ) -> None:
        """Hide ``lax``'s own axes and draw the color and/or marker legends."""
        lax.axis("off")

        if color_map:
            self._create_color_legend(
                lax, color_map, color_legend_title, bool(marker_map)
            )
        if marker_map:
            self._create_marker_legend(
                lax, marker_map, marker_legend_title, bool(color_map)
            )

    def _create_color_legend(
        self,
        lax: _plt.Axes,
        color_map: dict[str, _tp.Any],
        color_legend_title: str | None,
        has_markers: bool,
    ) -> None:
        """Draw a line-swatch legend for the color groups at the top of ``lax``."""
        color_handles = [
            _plt.Line2D([], [], color=color, linestyle="-", label=label)
            for label, color in color_map.items()
        ]

        legend = lax.legend(
            handles=color_handles,
            title=color_legend_title,
            bbox_to_anchor=(0, 0, 1, 1),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )

        if has_markers:
            # A later lax.legend() call replaces the current legend; adding it
            # as an artist keeps it alongside the marker legend.
            lax.add_artist(legend)

    def _create_marker_legend(
        self,
        lax: _plt.Axes,
        marker_map: dict[str, str],
        marker_legend_title: str | None,
        has_colors: bool,
    ) -> None:
        """Draw the marker legend, lowered when a color legend sits above it."""
        marker_position = 0.7 if has_colors else 1
        marker_handles = [
            _plt.Line2D(
                [],
                [],
                color="black",
                marker=marker,
                linestyle="None",
                label=label,
            )
            for label, marker in marker_map.items()
            if label is not None
        ]

        lax.legend(
            handles=marker_handles,
            title=marker_legend_title,
            bbox_to_anchor=(0, 0, 1, marker_position),
            loc="upper left",
            alignment="left",
            fontsize=plot_settings.legend_font_size,
            borderaxespad=0,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc