• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WillianFuks / tfcausalimpact / 9453713053

10 Jun 2024 06:34PM UTC coverage: 100.0%. Remained the same
9453713053

Pull #99

github

web-flow
Merge 19c223fba into f3012c361
Pull Request #99: Support pandas 2.1

129 of 129 branches covered (100.0%)

Branch coverage included in aggregate %.

2 of 2 new or added lines in 1 file covered. (100.0%)

452 of 452 relevant lines covered (100.0%)

1.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/causalimpact/plot.py
1
# Copyright WillianFuks
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15

16
"""
1✔
17
Plots the analysis obtained in causal impact algorithm.
18
"""
19

20

21
import numpy as np
1✔
22
import pandas as pd
1✔
23

24

25
def plot(
    inferences: pd.DataFrame,
    pre_data: pd.DataFrame,
    post_data: pd.DataFrame,
    panels=('original', 'pointwise', 'cumulative'),
    figsize=(10, 7),
    show=True
) -> None:
    """Plots inferences results related to causal impact analysis.

    Args
    ----
      inferences: pd.DataFrame.
        Inferred series. Each requested panel reads the `_means`, `_lower` and
        `_upper` columns of one prefix: `complete_preds` ("original"),
        `point_effects` ("pointwise") and `post_cum_effects` ("cumulative").
        Its rows are plotted against the union of the `pre_data` and
        `post_data` indices, so it is expected to have one row per point of
        that union (after `build_data` drops NaN pre-period rows).
      pre_data: pd.DataFrame.
        Pre-intervention data; its first column is plotted as the observed `y`.
      post_data: pd.DataFrame.
        Post-intervention data; its first column is plotted as the observed
        `y`. The vertical intervention marker is drawn at the last index
        preceding `post_data`'s first index.
      panels: list or tuple.
        Indicates which plots should be drawn. Panels are always stacked top
        to bottom as original, pointwise, cumulative regardless of the order
        given here, and repeated entries are drawn only once.
      figsize: tuple.
        Changes the size of the graphics plotted.
      show: bool.
        If true, runs plt.show(), i.e., displays the figure.
        If false, it gives access to the axis, i.e., the figure can be saved
        and the style of the plot can be modified by getting the axis with
        `ax = plt.gca()` or the figure with `fig = plt.gcf()`.
        Defaults to True.

    Raises
    ------
      ValueError: if `panels` is empty or contains an unknown panel name.
    """
    valid_panels = ['original', 'pointwise', 'cumulative']
    # Validate before touching pyplot so that a bad argument does not leave an
    # empty, dangling figure behind.
    for panel in panels:
        if panel not in valid_panels:
            raise ValueError(
                '"{}" is not a valid panel. Valid panels are: {}.'.format(
                    panel, ', '.join(['"{}"'.format(e) for e in valid_panels])
                )
            )
    # Fixed drawing order, each panel at most once: the number of subplot rows
    # must count distinct panels rather than entries of `panels`.
    selected_panels = [panel for panel in valid_panels if panel in panels]
    if not selected_panels:
        raise ValueError(
            'At least one panel must be chosen. Valid panels are: {}.'.format(
                ', '.join(['"{}"'.format(e) for e in valid_panels])
            )
        )
    # panel -> (column prefix in `inferences`, legend label, kwargs for the
    # horizontal zero line, or None when the panel has no zero line).
    panel_specs = {
        'original': ('complete_preds', 'Predicted', None),
        'pointwise': ('point_effects', 'Point Effects', {}),
        'cumulative': ('post_cum_effects', 'Cumulative Effect', {'linestyle': '--'}),
    }

    plt = get_plotter()
    plt.figure(figsize=figsize)
    pre_data, post_data, inferences = build_data(pre_data, post_data, inferences)
    pre_post_index = pre_data.index.union(post_data.index)
    intervention_idx = pre_post_index.get_loc(post_data.index[0])
    intervention_line = pre_post_index[intervention_idx - 1]
    color = (1.0, 0.4981, 0.0549)
    # The operation `iloc[1:]` is used mainly to remove the uncertainty associated to the
    # predictions of the first points. As the predictions follow
    # `P(z[t] | y[1...t-1], z[1...t-1])` the very first point ends up being quite noisy
    # as there's no previous point observed.
    plot_index = pre_post_index[1:]
    n_panels = len(selected_panels)
    ax = None
    for idx, panel in enumerate(selected_panels, start=1):
        # Exactly one axes per row; rows after the first share the x-axis of the
        # previous one (the first row gets sharex=None, i.e. no sharing).
        ax = plt.subplot(n_panels, 1, idx, sharex=ax)
        prefix, label, zero_line_kwargs = panel_specs[panel]
        if panel == 'original':
            ax.plot(
                pre_post_index,
                pd.concat([pre_data.iloc[:, 0], post_data.iloc[:, 0]]),
                'k',
                label='y'
            )
        ax.plot(
            plot_index,
            inferences[prefix + '_means'].iloc[1:],
            color='orangered',
            ls='dashed',
            label=label
        )
        ax.axvline(intervention_line, c='gray', linestyle='--')
        ax.fill_between(
            plot_index,
            inferences[prefix + '_lower'].iloc[1:],
            inferences[prefix + '_upper'].iloc[1:],
            color=color,
            alpha=0.4
        )
        if zero_line_kwargs is not None:
            ax.axhline(y=0, color='gray', **zero_line_kwargs)
        ax.legend()
        ax.grid(True, color='gainsboro')
        # The x-axis is shared, so only the bottom panel keeps its tick labels.
        if idx != n_panels:
            plt.setp(ax.get_xticklabels(), visible=False)
    if show:
        plt.show()
146

147

148
def build_data(
    pre_data: pd.DataFrame,
    post_data: pd.DataFrame,
    inferences: pd.DataFrame
) -> 'tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]':
    """Aligns and cleans the frames consumed by the plotting function.

    Input `pre_data` may contain NaN points due to the TFP requirement for a valid
    frequency set. As they may break the plotting API, this function removes the
    rows whose first column is NaN from both `pre_data` and `inferences`.

    `post_data` has its potential `NaN` values already removed in main function.

    When `inferences` is indexed by a `pd.RangeIndex`, `pre_data` and `post_data`
    are re-indexed with consecutive integers (pre-period first, then post-period)
    so that their labels line up with the rows of `inferences`.

    Args
    ----
      pre_data: pd.DataFrame.
        Pre-intervention data; NaN detection uses its first column.
      post_data: pd.DataFrame.
        Post-intervention data.
      inferences: pd.DataFrame.
        Inferred series; presumably labelled consistently with `pre_data` and
        `post_data` (or with a RangeIndex over both) — confirm against caller.

    Returns
    -------
      Tuple `(pre_data, post_data, inferences)`, each cast to `np.float64`.
    """
    if isinstance(inferences.index, pd.RangeIndex):
        n_pre = len(pre_data)
        pre_data = pre_data.set_index(pd.RangeIndex(start=0, stop=n_pre))
        post_data = post_data.set_index(
            pd.RangeIndex(start=n_pre, stop=n_pre + len(post_data))
        )
    # Labels of the pre-period rows whose response is missing; the same labels are
    # dropped from `inferences`, which must therefore contain them.
    pre_data_null_index = pre_data[pre_data.iloc[:, 0].isnull()].index
    pre_data = pre_data.drop(pre_data_null_index).astype(np.float64)
    post_data = post_data.astype(np.float64)
    inferences = inferences.drop(pre_data_null_index).astype(np.float64)
    return pre_data, post_data, inferences
168

169

170
def get_plotter():  # pragma: no cover
    """Return the `matplotlib.pyplot` module, importing it only on demand.

    matplotlib is imported inside this function rather than at module level, so
    environments without matplotlib can still import this module; the import
    error, if any, surfaces only when plotting is actually requested.

    Returns
    -------
      plotter: `matplotlib.pyplot`.
    """
    from matplotlib import pyplot
    return pyplot
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc