georgia-tech-db / eva, build eabb8981-5cfa-4600-97fe-7ac70b4899e2 (circleci)

16 Nov 2023 09:54AM UTC. Coverage: 77.091% (+77.1%) from 0.0%.

Pull Request #1258: Add feedback for forecasting (author: xzdandy, latest commit: "Fix xgboost")

3 of 72 new or added lines in 3 files covered. (4.17%)

10408 of 13501 relevant lines covered (77.09%)

1.39 hits per line

Source File (28.01% covered)
/evadb/executor/create_function_executor.py
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import hashlib
import locale
import os
import pickle
import re
import time
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
    DEFAULT_TRAIN_REGRESSION_METRIC,
    DEFAULT_TRAIN_TIME_LIMIT,
    DEFAULT_XGBOOST_TASK,
    EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.functions.decorators.utils import load_io_from_function_decorators
from evadb.models.storage.batch import Batch
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
from evadb.utils.errors import FunctionIODefinitionError
from evadb.utils.generic_utils import (
    load_function_class_from_file,
    string_comparison_case_insensitive,
    try_to_import_ludwig,
    try_to_import_neuralforecast,
    try_to_import_sklearn,
    try_to_import_statsforecast,
    try_to_import_torch,
    try_to_import_ultralytics,
    try_to_import_xgboost,
)
from evadb.utils.logging_manager import logger


def root_mean_squared_error(y_true, y_pred):
    return np.sqrt(np.mean(np.square(y_pred - y_true)))
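
# Worked example (illustrative): root_mean_squared_error(np.array([1, 2, 3]),
# np.array([1, 2, 5])) computes sqrt((0 + 0 + 4) / 3) ≈ 1.155; both arguments
# must be array-likes that numpy can subtract elementwise (arrays or Series).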


# From https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
    """
    Temporarily set the process environment variables.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, unicode]
    :param environ: Environment variables to set
    """
    old_environ = dict(os.environ)
    os.environ.update(environ)
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)


class CreateFunctionExecutor(AbstractExecutor):
    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file,
        so we do not need to call the setup method on them as we do for other functions.
        """
        # We need at least one deep learning framework for HuggingFace,
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle Ludwig functions

        Use Ludwig's auto_train engine to train/tune models.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        start_time = int(time.time())
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.catalog().get_configuration_catalog_value(
                "tmp_dir"
            ),
        )
        train_time = int(time.time()) - start_time
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        auto_train_results.best_model.save(model_path)
        best_score = auto_train_results.experiment_analysis.best_result["metric_score"]
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Use sklearn's LinearRegression to train models.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        start_time = int(time.time())
        model.fit(X=aggregated_batch.frames, y=Y)
        train_time = int(time.time()) - start_time
        score = model.score(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        pickle.dump(model, open(model_path, "wb"))
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            score,
            train_time,
        )

    def convert_to_numeric(self, x):
        # Strip everything except digits, decimal points, and thousands
        # separators, then parse with the current locale.
        x = re.sub("[^0-9.,]", "", str(x))
        locale.setlocale(locale.LC_ALL, "")
        x = float(locale.atof(x))
        if x.is_integer():
            return int(x)
        else:
            return x
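
    # Illustrative example (assumes an en_US-style locale):
    # convert_to_numeric("$1,234.50") strips the "$", parses "1,234.50" under
    # the locale, and returns 1234.5; convert_to_numeric("42 items") returns 42.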

    def handle_xgboost_function(self):
        """Handle xgboost functions

        We use the Flaml AutoML model for training xgboost models.
        """
        try_to_import_xgboost()

        assert (
            len(self.children) == 1
        ), "Create xgboost function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        start_time = int(time.time())
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        train_time = int(time.time()) - start_time
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        pickle.dump(model, open(model_path, "wb"))
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        best_score = model.best_loss
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions"""
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions"""
        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()
        library = "statsforecast"
        supported_libraries = ["statsforecast", "neuralforecast"]

        if "horizon" not in arg_map.keys():
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except Exception as e:
            err_msg = f"{str(e)}. HORIZON must be an integer."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        if "library" in arg_map.keys():
            try:
                assert arg_map["library"].lower() in supported_libraries
            except Exception:
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            library = arg_map["library"].lower()

        """
        The following rename is needed for statsforecast/neuralforecast, which
        require the columns to be named as follows:
        - unique_id (string, int or category): an identifier for the series.
        - ds (datestamp): a column in a format expected by Pandas, ideally
          YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
        - y (numeric): the measurement we wish to forecast.
        For reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        """
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map.keys():
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map.keys():
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})
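        # Illustration (hypothetical arguments): with arg_map = {"predict": "sales",
        # "time": "saledate", "id": "store"}, a frame with columns
        # [store, saledate, sales] is renamed to [unique_id, ds, y].
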
        if "conf" in arg_map.keys():
            try:
                conf = round(arg_map["conf"])
            except Exception:
                err_msg = "Confidence must be a number."
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)
        else:
            conf = 90

        if conf > 100:
            err_msg = "Confidence must be <= 100."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        data = aggregated_batch.frames
        if "unique_id" not in list(data.columns):
            data["unique_id"] = [1 for x in range(len(data))]

        if "ds" not in list(data.columns):
            data["ds"] = [x + 1 for x in range(len(data))]

        """
            Set or infer data frequency
        """

        if "frequency" not in arg_map.keys() or arg_map["frequency"] == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Cannot infer the frequency for {self.node.name}. Please set it explicitly."
            )

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        new_freq = (
            frequency.split("-")[0] if "-" in frequency else frequency
        )  # shortens longer frequencies like Q-DEC
        season_length = season_dict[new_freq] if new_freq in season_dict else 1
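        # Example: an inferred frequency of "Q-DEC" (quarterly with December
        # year-end) shortens to "Q" and maps to season_length 4; a frequency
        # missing from season_dict, such as "D", falls back to season_length 1.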

        """
            Neuralforecast implementation
        """
        if library == "neuralforecast":
            try_to_import_neuralforecast()
            from neuralforecast import NeuralForecast
            from neuralforecast.auto import (
                AutoDeepAR,
                AutoFEDformer,
                AutoInformer,
                AutoNBEATS,
                AutoNHITS,
                AutoPatchTST,
                AutoTFT,
            )

            # from neuralforecast.auto import AutoAutoformer as AutoAFormer
            from neuralforecast.losses.pytorch import MQLoss
            from neuralforecast.models import (
                NBEATS,
                NHITS,
                TFT,
                DeepAR,
                FEDformer,
                Informer,
                PatchTST,
            )

            # from neuralforecast.models import Autoformer as AFormer

            model_dict = {
                "AutoNBEATS": AutoNBEATS,
                "AutoNHITS": AutoNHITS,
                "NBEATS": NBEATS,
                "NHITS": NHITS,
                "PatchTST": PatchTST,
                "AutoPatchTST": AutoPatchTST,
                "DeepAR": DeepAR,
                "AutoDeepAR": AutoDeepAR,
                "FEDformer": FEDformer,
                "AutoFEDformer": AutoFEDformer,
                # "AFormer": AFormer,
                # "AutoAFormer": AutoAFormer,
                "Informer": Informer,
                "AutoInformer": AutoInformer,
                "TFT": TFT,
                "AutoTFT": AutoTFT,
            }

            if "model" not in arg_map.keys():
                arg_map["model"] = "TFT"

            if "auto" not in arg_map.keys() or (
                arg_map["auto"].lower()[0] == "t"
                and "auto" not in arg_map["model"].lower()
            ):
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except Exception:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)
            model_args = {}

            if "auto" not in arg_map["model"].lower():
                model_args["input_size"] = 2 * horizon
                model_args["early_stop_patience_steps"] = 20
            else:
                model_args_config = {
                    "input_size": 2 * horizon,
                    "early_stop_patience_steps": 20,
                }

            if len(data.columns) >= 4:
                exogenous_columns = [
                    x for x in list(data.columns) if x not in ["ds", "y", "unique_id"]
                ]
                if "auto" not in arg_map["model"].lower():
                    model_args["hist_exog_list"] = exogenous_columns
                else:
                    model_args_config["hist_exog_list"] = exogenous_columns

            if "auto" in arg_map["model"].lower():

                def get_optuna_config(trial):
                    return model_args_config

                model_args["config"] = get_optuna_config
                model_args["backend"] = "optuna"

            model_args["h"] = horizon
            model_args["loss"] = MQLoss(level=[conf])

            model = NeuralForecast(
                [model_here(**model_args)],
                freq=new_freq,
            )
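            # Illustration (hypothetical values): with AUTO disabled, horizon=12,
            # conf=90, and model "TFT", this builds NeuralForecast([TFT(
            # input_size=24, early_stop_patience_steps=20, h=12,
            # loss=MQLoss(level=[90]))], freq=new_freq).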

        # """
        #     Statsforecast implementation
        # """
        else:
            if "auto" in arg_map.keys() and arg_map["auto"].lower()[0] != "t":
                raise RuntimeError(
                    "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
                )
            try_to_import_statsforecast()
            from statsforecast import StatsForecast
            from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

            model_dict = {
                "AutoARIMA": AutoARIMA,
                "AutoCES": AutoCES,
                "AutoETS": AutoETS,
                "AutoTheta": AutoTheta,
            }

            if "model" not in arg_map.keys():
                arg_map["model"] = "ARIMA"

            if "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except Exception:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            model = StatsForecast(
                [model_here(season_length=season_length)], freq=new_freq
            )

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = (
            library + "_" + arg_map["model"] + "_" + new_freq
            if "statsforecast" in library
            else library + "_" + str(conf) + "_" + arg_map["model"] + "_" + new_freq
        )
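        # Example (hypothetical): library "neuralforecast" with conf 90, model
        # "AutoTFT", and new_freq "Q" yields "neuralforecast_90_AutoTFT_Q".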
        if len(data.columns) >= 4 and library == "neuralforecast":
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        model_dir = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_save_name = "horizon" + str(horizon) + ".pkl"

        model_path = os.path.join(model_dir, model_save_name)

        existing_model_files = sorted(
            os.listdir(model_dir),
            key=lambda x: int(x.split("horizon")[1].split(".pkl")[0]),
        )
        existing_model_files = [
            x
            for x in existing_model_files
            if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon
        ]
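        # A saved model is reusable when it was trained with a horizon at least
        # as large as the requested one. Example (hypothetical): with files
        # ["horizon7.pkl", "horizon14.pkl"] and horizon=10, only "horizon14.pkl"
        # survives the filter and is reused below instead of retraining.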
579
        if len(existing_model_files) == 0:
×
580
            logger.info("Training, please wait...")
×
581
            for column in data.columns:
×
582
                if column != "ds" and column != "unique_id":
×
583
                    data[column] = data.apply(
×
584
                        lambda x: self.convert_to_numeric(x[column]), axis=1
585
                    )
NEW
586
            rmses = []
×
587
            if library == "neuralforecast":
×
588
                cuda_devices_here = "0"
×
589
                if "CUDA_VISIBLE_DEVICES" in os.environ:
×
590
                    cuda_devices_here = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
×
591

592
                with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
×
593
                    model.fit(df=data, val_size=horizon)
×
594
                    model.save(model_path, overwrite=True)
×
NEW
595
                    if "metrics" in arg_map and arg_map["metrics"].lower()[0] == "t":
×
NEW
596
                        crossvalidation_df = model.cross_validation(
×
597
                            df=data, val_size=horizon
598
                        )
NEW
599
                        for uid in crossvalidation_df.unique_id.unique():
×
NEW
600
                            crossvalidation_df_here = crossvalidation_df[
×
601
                                crossvalidation_df.unique_id == uid
602
                            ]
NEW
603
                            rmses.append(
×
604
                                root_mean_squared_error(
605
                                    crossvalidation_df_here.y,
606
                                    crossvalidation_df_here[
607
                                        arg_map["model"] + "-median"
608
                                    ],
609
                                )
610
                                / np.mean(crossvalidation_df_here.y)
611
                            )
NEW
612
                            mean_rmse = np.mean(rmses)
×
NEW
613
                            with open(model_path + "_rmse", "w") as f:
×
NEW
614
                                f.write(str(mean_rmse) + "\n")
×
615
            else:
616
                # The following lines of code helps eliminate the math error encountered in statsforecast when only one datapoint is available in a time series
617
                for col in data["unique_id"].unique():
×
618
                    if len(data[data["unique_id"] == col]) == 1:
×
619
                        data = data._append(
×
620
                            [data[data["unique_id"] == col]], ignore_index=True
621
                        )
622

623
                model.fit(df=data[["ds", "y", "unique_id"]])
×
NEW
624
                hypers = ""
×
NEW
625
                if "arima" in arg_map["model"].lower():
×
NEW
626
                    from statsforecast.arima import arima_string
×
627

NEW
628
                    hypers += arima_string(model.fitted_[0, 0].model_)
×
629
                f = open(model_path, "wb")
×
630
                pickle.dump(model, f)
×
631
                f.close()
×
NEW
632
                if "metrics" not in arg_map or arg_map["metrics"].lower()[0] == "t":
×
NEW
633
                    crossvalidation_df = model.cross_validation(
×
634
                        df=data[["ds", "y", "unique_id"]],
635
                        h=horizon,
636
                        step_size=24,
637
                        n_windows=1,
638
                    ).reset_index()
NEW
639
                    for uid in crossvalidation_df.unique_id.unique():
×
NEW
640
                        crossvalidation_df_here = crossvalidation_df[
×
641
                            crossvalidation_df.unique_id == uid
642
                        ]
NEW
643
                        rmses.append(
×
644
                            root_mean_squared_error(
645
                                crossvalidation_df_here.y,
646
                                crossvalidation_df_here[arg_map["model"]],
647
                            )
648
                            / np.mean(crossvalidation_df_here.y)
649
                        )
NEW
650
                    mean_rmse = np.mean(rmses)
×
NEW
651
                    with open(model_path + "_rmse", "w") as f:
×
NEW
652
                        f.write(str(mean_rmse) + "\n")
×
NEW
653
                        f.write(hypers + "\n")
×
        elif not Path(model_path).exists():
            model_path = os.path.join(model_dir, existing_model_files[-1])
        io_list = self._resolve_function_io(None)
        data["ds"] = data.ds.astype(str)
        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
            FunctionMetadataCatalogEntry("conf", conf),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS cannot both be set for CREATE FUNCTION."
        )

        overwrite = False
        best_score = False
        train_time = False
        # check whether the catalog already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid duplicating the bookkeeping
                # code. The drop logic should eventually move into the catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError:
                    pass
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # if it's a type of HuggingFaceModel, override the impl_path
        if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_huggingface_function()
        elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_ultralytics_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_ludwig_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Sklearn"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_sklearn_function()
        elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_xgboost_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_forecasting_function()
        else:
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_generic_function()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        if best_score and train_time:
            yield Batch(
                pd.DataFrame(
                    [
                        msg,
                        "Validation Score: " + str(best_score),
                        "Training time: " + str(train_time) + " secs.",
                    ]
                )
            )
        else:
            yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = {}
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Dictionary of arguments to pass to the function. Defaults to {}.

        Returns:
            FunctionCatalogEntry: A FunctionCatalogEntry object that represents the initialized function.

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        try:
            # load the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            # logger.error(err_msg)
            raise RuntimeError(err_msg)

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.
        It first searches for the input/outputs in the CREATE statement. If not found, it resolves them using decorators. If not found there as well, it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve input and output definitions.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators; the inputs from the CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators; the outputs from the CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        return io_list