• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WenjieDu / PyPOTS / 12015168636

25 Nov 2024 05:10PM UTC coverage: 84.286% (-0.02%) from 84.307%
12015168636

push

github

web-flow
Update docs for CSAI (#549)

12047 of 14293 relevant lines covered (84.29%)

4.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pypots/utils/visual/clustering.py
1
"""
2
Utilities for clustering visualization.
3
"""
4

5
# Created by Bhargav Vemuri <vemuri.bhargav@gmail.com> and Wenjie Du <wenjay.du@gmail.com>
6
# License: BSD-3-Clause
7

8
from typing import Dict
×
9

10
import matplotlib.lines as mlines
×
11
import matplotlib.pyplot as plt
×
12
import numpy as np
×
13
import pandas as pd
×
14
import scipy.stats as st
×
15

16

17
def get_cluster_members(test_data: np.ndarray, class_predictions: np.ndarray) -> Dict[int, np.ndarray]:
    """
    Split the time series array into groups according to predicted cluster labels.

    Parameters
    __________
    test_data :
        Time series array that clusterer was run on. Its first axis must line up with
        ``class_predictions``.

    class_predictions:
        Clustering results returned by a clusterer (one label per sample).

    Returns
    _______
    cluster_members :
        Mapping from each distinct predicted label (in ascending order, as produced by
        ``np.unique``) to the rows of ``test_data`` assigned to that label.
    """
    # boolean-mask selection keeps the original sample order inside each cluster
    return {label: test_data[class_predictions == label] for label in np.unique(class_predictions)}
38

39

40
def clusters_for_plotting(
    cluster_members: Dict[int, np.ndarray],
) -> Dict[int, dict]:
    """
    Regroup clustered arrays by cluster and then by time series variable, ready for plotting.

    Parameters
    __________
    cluster_members :
        Output from get_cluster_members function.

    Returns
    _______
    dict_to_plot :
        Nested dictionary ``{cluster_label: {var_index: [member_series, ...]}}``.

        Each member is converted with ``pd.DataFrame(member)``, so the variable keys are that
        DataFrame's column labels (0, 1, ... by default). Each ``member_series`` is the list
        of values in that column for one cluster member. This assumes each member is laid out
        as (time steps, variables).

        structure = { 0: {0: [[ts0], [ts1], ..., [tsN]],
                          1: [[ts0], [ts1], ..., [tsN]],
                          ...
                          Y: [[ts0], [ts1], ..., [tsN]]
                         },
                      ...,
                      X: {0: [[ts0], [ts1], ..., [tsM]],
                          ...
                          Y: [[ts0], [ts1], ..., [tsM]]
                         }
                    }

                    where
                        X is the last cluster label
                        Y is the last time series variable index
                        tsN / tsM are the member series of each cluster
    """
    dict_to_plot = {}

    for cluster_label, members in cluster_members.items():
        series_by_var = {}
        for member in members:
            # one list of values per column (time series variable) of this member
            member_columns = pd.DataFrame(member).to_dict(orient="list")
            for var_key, column_values in member_columns.items():
                series_by_var.setdefault(var_key, []).append(column_values)
        dict_to_plot[cluster_label] = series_by_var

    return dict_to_plot
87

88

89
def plot_clusters(dict_to_plot: Dict[int, dict]) -> None:
    """
    Draw one line plot per (cluster, time series variable) pair showing every cluster member.

    Each figure is titled with the cluster label, has the variable index on the y-axis label,
    and shows the number of plotted members ("n = ...") in its upper-right corner.
    Non-finite values are dropped before plotting, so each member's line connects its
    remaining observed points.

    Parameters
    __________
    dict_to_plot :
        Output from clusters_for_plotting function.
    """
    for cluster_label, series_by_var in dict_to_plot.items():
        for var_key, member_series in series_by_var.items():
            plt.figure(figsize=(16, 8))

            for values in member_series:
                timepoints = np.arange(len(values))
                as_float = np.array(values).astype(np.double)
                # keep only finite points so the line bridges over missing intermediate values
                observed = np.isfinite(as_float)
                plt.plot(timepoints[observed], as_float[observed], ".-")

            plt.title("Cluster %i" % cluster_label)
            plt.ylabel("Var %i" % var_key)
            # ticks follow the length of the last plotted member
            plt.xticks(timepoints)
            plt.text(
                0.93,
                0.93,
                "n = %i" % len(member_series),
                horizontalalignment="center",
                verticalalignment="center",
                transform=plt.gca().transAxes,
                fontsize=15,
            )
            plt.show()
126

127

128
def get_cluster_means(dict_to_plot: Dict[int, dict]) -> Dict[int, dict]:
    """
    Get time series variables' mean values and 95% confidence intervals at each time point per cluster.

    Missing values (NaN) are skipped. At every time point, the mean, the standard error and the
    Student-t degrees of freedom are all based on the number of members actually observed at
    that time point.

    Parameters
    __________
    dict_to_plot :
        Output from clusters_for_plotting function.

    Returns
    _______
    cluster_means:
        Means and CI lower and upper bounds for each time series variable per cluster.
        Variables are the outer keys and clusters are nested inside them, the reverse of
        clusters_for_plotting. "mean" is a list, while "CI_low"/"CI_high" are the arrays
        returned by ``scipy.stats.t.interval``.

        structure = { var0: {cluster0: {'mean': [tp0,tp1,...,tpN],
                                        'CI_low': [tp0,tp1,...tpN],
                                        'CI_high': [tp0,tp1,...tpN],
                                        'n': n0
                                        },
                             ...
                             clusterX: {...}
                            },
                      ...,
                      varY: {...}
                    }

                    where
                        varY is the last time series variable key
                        clusterX is the last cluster key
                        tpN is the last time point
                        n0 is the size of cluster0, nX is the size of clusterX, etc.
                        (cluster size counts all members, observed or not)
    """
    cluster_means = {}

    for i in dict_to_plot:  # iterate clusters
        for j in dict_to_plot[i]:  # iterate time series vars
            # rows are cluster members, columns are time points (ragged series padded with NaN)
            members = pd.DataFrame(dict_to_plot[i][j])

            stats = {}
            stats["mean"] = list(members.mean(axis=0, skipna=True))  # cluster mean array of time series var

            # Degrees of freedom must match the per-time-point sample size used by the
            # skipna mean/SEM; using the total member count understates the CI width
            # whenever values are missing. A count of 0 or 1 yields NaN bounds.
            observed_counts = members.count(axis=0).to_numpy()
            # CI calculation, from https://stackoverflow.com/a/34474255
            stats["CI_low"], stats["CI_high"] = st.t.interval(
                0.95,
                observed_counts - 1,  # per-time-point degrees of freedom
                loc=stats["mean"],
                scale=members.sem(axis=0, skipna=True).to_numpy(),
            )
            stats["n"] = len(dict_to_plot[i][j])  # save cluster size for downstream tasks/plotting

            # clusters nested within vars (reverse structure to clusters_for_plotting)
            cluster_means.setdefault(j, {})[i] = stats

    return cluster_means
199

200

201
def plot_cluster_means(cluster_means: Dict[int, dict]) -> None:
    """
    Generate line plots of cluster means and 95% confidence intervals for each time series variable.

    One figure is drawn per time series variable.
    - Each cluster's mean is a solid line with point markers, labelled with the cluster size.
    - Its CI bounds are dashed lines in the same color.
    - Non-finite points are dropped, so lines connect the remaining values.
    - The legend lists every cluster mean plus a single gray dashed "95% CI" entry.

    Parameters
    __________
    cluster_means :
        Output from get_cluster_means function.
    """
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]  # to keep cluster colors consistent

    for i in cluster_means:  # iterate time series vars
        y = cluster_means[i]

        plt.figure(figsize=(16, 8))

        for y_values in y:  # iterate clusters
            # Wrap around the color cycle; indexing directly raised IndexError once
            # cluster labels reached the cycle length (10 with default rcParams).
            color = colors[y_values % len(colors)]

            for val in y[y_values]:  # iterate calculation (mean, CI_low, CI_high)
                if val not in ("mean", "CI_low", "CI_high"):
                    continue  # e.g. "n" is the cluster size, not a series to plot

                x = np.arange(len(y[y_values][val]))
                series1 = np.array(y[y_values][val]).astype(np.double)
                s1mask = np.isfinite(series1)

                if val == "mean":
                    plt.plot(
                        x[s1mask],
                        series1[s1mask],
                        ".-",  # mean as solid line
                        color=color,
                        label="Cluster %i mean (n = %d)" % (y_values, y[y_values]["n"]),  # legend includes cluster size
                    )
                else:
                    plt.plot(
                        x[s1mask],
                        series1[s1mask],
                        "--",  # CI bounds as dashed lines
                        color=color,
                    )

        plt.title("Var %d" % i)
        plt.xlabel("Timepoint")
        plt.xticks(x)

        # add dashed line label to legend; build the legend once from the axes' labelled
        # artists instead of creating a throwaway legend and then drawing the real one twice
        line_dashed = mlines.Line2D([], [], color="gray", linestyle="--", linewidth=1.5, label="95% CI")
        handles, _ = plt.gca().get_legend_handles_labels()
        handles.append(line_dashed)
        plt.legend(handles=handles)

        plt.show()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc