• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WenjieDu / PyPOTS / 8614163418

09 Apr 2024 10:24AM UTC coverage: 81.03% (+0.2%) from 80.813%
8614163418

Pull #343

github

web-flow
Merge 1fd684f5b into 93062a244
Pull Request #343: Apply SAITS embedding strategy to new added models

79 of 80 new or added lines in 10 files covered. (98.75%)

2 existing lines in 1 file now uncovered.

6847 of 8450 relevant lines covered (81.03%)

4.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pypots/utils/visual/clustering.py
1
"""
2
Utilities for clustering visualization.
3
"""
4

5
# Created by Bhargav Vemuri <vemuri.bhargav@gmail.com> and Wenjie Du <wenjay.du@gmail.com>
6
# License: BSD-3-Clause
7

8
from typing import Dict
×
9

10
import matplotlib.lines as mlines
×
11
import matplotlib.pyplot as plt
×
12
import numpy as np
×
13
import pandas as pd
×
14
import scipy.stats as st
×
15

16

17
def get_cluster_members(
    test_data: np.ndarray, class_predictions: np.ndarray
) -> Dict[int, np.ndarray]:
    """
    Split a time series array into groups according to predicted cluster labels.

    Parameters
    ----------
    test_data :
        Time series array that clusterer was run on; indexed along its first axis
        by the boolean mask built from ``class_predictions``.

    class_predictions:
        Clustering results returned by a clusterer, one label per sample.

    Returns
    -------
    cluster_members :
        Dictionary mapping every distinct label (in ``np.unique`` order) to the
        rows of ``test_data`` predicted to belong to that cluster.
    """
    return {
        label: test_data[class_predictions == label]
        for label in np.unique(class_predictions)
    }
×
40

41

42
def clusters_for_plotting(
    cluster_members: Dict[int, np.ndarray],
) -> Dict[int, dict]:
    """
    Reorganize clustered arrays into nested per-cluster, per-variable lists for plotting.

    Parameters
    ----------
    cluster_members :
        Output from get_cluster_members function.

    Returns
    -------
    dict_to_plot :
        Test data organized by predicted cluster and time series variable.
        Each member array is converted with ``pd.DataFrame(member)``, so variable
        keys are that frame's column labels and each entry is one member's values
        for that variable, as a list.

        structure = { 'cluster0': {'var0': [['ts0'],['ts1'],...,['tsN']],
                                   'var1': [['ts0'],['ts1'],...,['tsN']],
                                   ...
                                   'varY': [['ts0'],['ts1'],...,['tsN']]
                                  },
                      ...,
                      'clusterX': {'var0': [['ts0'],['ts1'],...,['tsM']],
                                   'var1': [['ts0'],['ts1'],...,['tsM']],
                                   ...
                                   'varY': [['ts0'],['ts1'],...,['tsM']]
                                  }
                    }

                    where
                        clusterX is number of clusters predicted (n_clusters in model)
                        varY is number of time series variables recorded
                        tsN is number of members in cluster0, tsM is number of members in clusterX, etc.

    """
    dict_to_plot = {}

    for cluster_id, members in cluster_members.items():
        per_var = {}
        dict_to_plot[cluster_id] = per_var
        for member in members:
            # column-oriented dict: {variable label: list of that variable's values over time}
            series_by_var = pd.DataFrame(member).to_dict(orient="list")
            for var_id, values in series_by_var.items():
                # first sighting of a variable opens its list; later members are appended
                per_var.setdefault(var_id, []).append(values)

    return dict_to_plot
×
95

96

97
def plot_clusters(dict_to_plot: Dict[int, dict]) -> None:
    """
    Generate line plots of all cluster members per time series variable per cluster.

    One figure is drawn and shown for every (cluster, variable) pair. Each cluster
    member becomes its own line, and the member count is written near the top-right
    corner of the axes.

    Parameters
    ----------
    dict_to_plot :
        Output from clusters_for_plotting function.
    """
    for cluster_id, vars_of_cluster in dict_to_plot.items():
        for var_id, member_series in vars_of_cluster.items():
            plt.figure(figsize=(16, 8))

            for values in member_series:
                timepoints = np.arange(len(values))
                as_float = np.array(values).astype(np.double)
                # keep only finite points so the line still joins the observed values
                # across any missing (NaN) time points
                observed = np.isfinite(as_float)
                plt.plot(timepoints[observed], as_float[observed], ".-")

            plt.title("Cluster %i" % cluster_id)
            plt.ylabel("Var %i" % var_id)
            plt.xticks(timepoints)  # time axis of the last member drawn
            plt.text(
                0.93,
                0.93,
                "n = %i" % len(member_series),
                horizontalalignment="center",
                verticalalignment="center",
                transform=plt.gca().transAxes,
                fontsize=15,
            )
            plt.show()
×
134

135

136
def get_cluster_means(dict_to_plot: Dict[int, dict]) -> Dict[int, dict]:
    """
    Get time series variables' mean values and 95% confidence intervals at each time point per cluster.

    Missing values (NaN) are skipped. At each time point, the mean, the standard error
    and the t-distribution degrees of freedom are all based only on the members that
    have an observed value there.

    Parameters
    ----------
    dict_to_plot :
        Output from clusters_for_plotting function.

    Returns
    -------
    cluster_means:
        Means and CI lower and upper bounds for each time series variable per cluster.
        'mean' is a list; 'CI_low' and 'CI_high' are NumPy arrays as returned by
        ``scipy.stats.t.interval``; 'n' is the number of members in the cluster.
        Time points with fewer than two observed members get NaN CI bounds.

        structure = { 'var0': {'cluster0': {'mean': [tp0,tp1,...,tpN],
                                            'CI_low': [tp0,tp1,...tpN],
                                            'CI_high': [tp0,tp1,...tpN],
                                            'n': n0
                                            },
                               ...
                               'clusterX': {'mean': [tp0,tp1,...,tpN],
                                            'CI_low': [tp0,tp1,...tpN],
                                            'CI_high': [tp0,tp1,...tpN],
                                            'n': nX
                                            }
                              },
                      ...,
                      'varY': {'cluster0': {'mean': [tp0,tp1,...,tpN],
                                            'CI_low': [tp0,tp1,...tpN],
                                            'CI_high': [tp0,tp1,...tpN],
                                            'n': n0
                                            },
                               ...
                               'clusterX': {'mean': [tp0,tp1,...,tpN],
                                            'CI_low': [tp0,tp1,...tpN],
                                            'CI_high': [tp0,tp1,...tpN],
                                            'n': nX
                                            }
                              }
                    }

                    where
                        varY is number of time series variables recorded
                        clusterX is number of clusters predicted (n_clusters in model)
                        tpN is number of time points in each time series
                        n0 is the size of cluster0, nX is the size of clusterX, etc.

    """
    cluster_means = {}

    for i in dict_to_plot:  # iterate clusters
        for j in dict_to_plot[i]:  # iterate time series vars
            # rows are cluster members, columns are time points
            members = pd.DataFrame(dict_to_plot[i][j])

            # clusters nested within vars (reverse structure to clusters_for_plotting)
            stats = {}
            cluster_means.setdefault(j, {})[i] = stats

            stats["mean"] = list(
                members.mean(axis=0, skipna=True)
            )  # cluster mean array of time series var

            # CI calculation, from https://stackoverflow.com/a/34474255
            # sem(skipna=True) divides by the per-time-point count of observed values,
            # so the degrees of freedom must use that same count rather than the
            # total cluster size. Counts below 2 give invalid df, so scipy returns NaN.
            dof = (members.count(axis=0) - 1).to_numpy()
            stats["CI_low"], stats["CI_high"] = st.t.interval(
                0.95,
                dof,
                loc=stats["mean"],
                scale=members.sem(axis=0, skipna=True).to_numpy(),
            )

            stats["n"] = len(
                dict_to_plot[i][j]
            )  # save cluster size for downstream tasks/plotting

    return cluster_means
×
216

217

218
def plot_cluster_means(cluster_means: Dict[int, dict]) -> None:
    """
    Generate line plots of cluster means and 95% confidence intervals for each time series variable.

    One figure is drawn and shown per time series variable. Each cluster's mean is a
    solid line with point markers, labelled with the cluster size. Its CI bounds are
    dashed lines in the same color.

    Parameters
    ----------
    cluster_means :
        Output from get_cluster_means function.
    """
    # Colors come from the active property cycle so each cluster keeps the same color
    # in every variable's figure. Indexing wraps with modulo, so more clusters than
    # colors (or negative labels) reuse colors instead of raising IndexError.
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    for i in cluster_means:  # iterate time series vars
        y = cluster_means[i]

        plt.figure(figsize=(16, 8))
        ax = plt.gca()

        for y_values in y:  # iterate clusters
            color = colors[y_values % len(colors)]
            for val in y[y_values]:  # iterate calculation (mean, CI_low, CI_high)
                if val not in ("mean", "CI_low", "CI_high"):
                    continue  # e.g. "n" is metadata, not a series

                x = np.arange(len(y[y_values][val]))
                series1 = np.array(y[y_values][val]).astype(np.double)
                s1mask = np.isfinite(series1)  # skip NaNs; line joins observed points

                if val == "mean":
                    ax.plot(
                        x[s1mask],
                        series1[s1mask],
                        ".-",  # mean as solid line
                        color=color,
                        label="Cluster %i mean (n = %d)"
                        % (
                            y_values,
                            y[y_values]["n"],
                        ),  # legend will include cluster size
                    )
                else:
                    ax.plot(
                        x[s1mask],
                        series1[s1mask],
                        "--",  # CI bounds as dashed lines
                        color=color,
                    )

        plt.title("Var %d" % i)
        plt.xlabel("Timepoint")
        plt.xticks(x)

        # add dashed line label to legend; a single legend is set on the axes
        # (no throwaway legend, no duplicate artist)
        line_dashed = mlines.Line2D(
            [], [], color="gray", linestyle="--", linewidth=1.5, label="95% CI"
        )
        handles, _ = ax.get_legend_handles_labels()
        handles.append(line_dashed)
        ax.legend(handles=handles)

        plt.show()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc