• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / c60bc8c4-3c84-48a0-8de6-9eddb48baa31

17 Oct 2023 03:06PM UTC coverage: 78.624% (+78.6%) from 0.0%
c60bc8c4-3c84-48a0-8de6-9eddb48baa31

Pull #1283

circle-ci

americast
update with contextmanager
Pull Request #1283: Fix current issues with forecasting

37 of 37 new or added lines in 2 files covered. (100.0%)

9894 of 12584 relevant lines covered (78.62%)

1.42 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/evadb/functions/forecast.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16

17
import pickle
×
18

19
import pandas as pd
×
20

21
from evadb.functions.abstract.abstract_function import AbstractFunction
×
22
from evadb.functions.decorators.decorators import setup
×
23

24

25
class ForecastModel(AbstractFunction):
    """EvaDB function that returns forecasts from a previously trained model.

    The model is loaded once in ``setup`` (via ``NeuralForecast.load`` for
    neuralforecast models, otherwise from a pickle file), and ``forward``
    returns its predictions with the ``unique_id`` / ``ds`` / model columns
    renamed to the names supplied at setup time.
    """

    @property
    def name(self) -> str:
        return "ForecastModel"

    @setup(cacheable=False, function_type="Forecasting", batchable=True)
    def setup(
        self,
        model_name: str,
        model_path: str,
        predict_column_rename: str,
        time_column_rename: str,
        id_column_rename: str,
        horizon: int,
        library: str,
    ):
        """Load the trained model and record how its output should be shaped.

        Args:
            model_name: Name of the trained model; used (after the adjustment
                below for neuralforecast ``Auto*`` models) as the key of the
                prediction column in the forecast frame.
            model_path: Path handed to ``NeuralForecast.load`` when ``library``
                mentions neuralforecast; otherwise a pickle file to unpickle.
            predict_column_rename: Output name for the prediction column.
            time_column_rename: Output name for the ``ds`` column.
            id_column_rename: Output name for the ``unique_id`` column.
            horizon: Maximum number of forecast rows kept per series; coerced
                with ``int()``.
            library: Name of the forecasting library the model came from.
        """
        self.library = library
        if "neuralforecast" in self.library:
            from neuralforecast import NeuralForecast

            loaded_model = NeuralForecast.load(path=model_path)
            # Strip a leading "Auto" (e.g. "AutoNHITS" -> "NHITS"), presumably
            # because the prediction column of Auto* models is named after the
            # underlying model -- TODO confirm against neuralforecast. Only a
            # true prefix is stripped; a substring test would also chop the
            # first four characters off names that merely contain "Auto".
            prefix = "Auto"
            if model_name.startswith(prefix):
                self.model_name = model_name[len(prefix) :]
            else:
                self.model_name = model_name
        else:
            # NOTE(review): pickle.load can execute arbitrary code; model_path
            # must only ever point at a file EvaDB itself wrote, never at
            # externally supplied data.
            with open(model_path, "rb") as f:
                loaded_model = pickle.load(f)
            self.model_name = model_name
        self.model = loaded_model
        self.predict_column_rename = predict_column_rename
        self.time_column_rename = time_column_rename
        self.id_column_rename = id_column_rename
        self.horizon = int(horizon)

    def forward(self, data) -> pd.DataFrame:
        """Return the model's forecast, keeping at most ``horizon`` rows per series.

        ``data`` is not read: the forecast comes entirely from the model
        loaded in ``setup``.

        Returns:
            A DataFrame whose ``unique_id``, ``ds`` and model-name columns are
            renamed to ``id_column_rename``, ``time_column_rename`` and
            ``predict_column_rename`` respectively.
        """
        if self.library == "statsforecast":
            forecast_df = self.model.predict(h=self.horizon)
        else:
            # predict() is called without a horizon here, so the number of
            # rows per series is whatever the model produces; trimmed below.
            forecast_df = self.model.predict()
        # Turn the index (assumed to hold the series id -- TODO confirm per
        # library) into regular columns, without mutating the model's frame.
        forecast_df = forecast_df.reset_index()
        # Keep the first `horizon` rows of *each* series. A single global
        # slice of `horizon * n_series` rows is only correct when every series
        # has exactly `horizon` contiguous rows; otherwise early series would
        # consume the whole budget and later series would vanish. Trim before
        # renaming so the "unique_id" key is still valid.
        forecast_df = forecast_df.groupby("unique_id", sort=False).head(
            self.horizon
        )
        return forecast_df.rename(
            columns={
                "unique_id": self.id_column_rename,
                "ds": self.time_column_rename,
                self.model_name: self.predict_column_rename,
            }
        )
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc