• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / d00d1922-5b74-41a8-9002-6df6f052e0b4

10 Nov 2023 12:12AM UTC coverage: 76.812%. First build
d00d1922-5b74-41a8-9002-6df6f052e0b4

push

circleci

americast
plot as a new column

0 of 27 new or added lines in 2 files covered. (0.0%)

10100 of 13149 relevant lines covered (76.81%)

1.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

28.31
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import contextlib
2✔
16
import hashlib
2✔
17
import locale
2✔
18
import os
2✔
19
import pickle
2✔
20
import re
2✔
21
from pathlib import Path
2✔
22
from typing import Dict, List
2✔
23

24
import numpy as np
2✔
25
import pandas as pd
2✔
26
from sklearn.metrics import mean_squared_error
2✔
27

28
from evadb.catalog.catalog_utils import get_metadata_properties
2✔
29
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
2✔
30
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
2✔
31
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
2✔
32
from evadb.configuration.constants import (
2✔
33
    DEFAULT_TRAIN_REGRESSION_METRIC,
34
    DEFAULT_TRAIN_TIME_LIMIT,
35
    DEFAULT_XGBOOST_TASK,
36
    EvaDB_INSTALLATION_DIR,
37
)
38
from evadb.database import EvaDBDatabase
2✔
39
from evadb.executor.abstract_executor import AbstractExecutor
2✔
40
from evadb.functions.decorators.utils import load_io_from_function_decorators
2✔
41
from evadb.models.storage.batch import Batch
2✔
42
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
2✔
43
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
2✔
44
from evadb.utils.errors import FunctionIODefinitionError
2✔
45
from evadb.utils.generic_utils import (
2✔
46
    load_function_class_from_file,
47
    string_comparison_case_insensitive,
48
    try_to_import_ludwig,
49
    try_to_import_neuralforecast,
50
    try_to_import_sklearn,
51
    try_to_import_statsforecast,
52
    try_to_import_torch,
53
    try_to_import_ultralytics,
54
    try_to_import_xgboost,
55
)
56
from evadb.utils.logging_manager import logger
2✔
57

58

59
# Adapted from https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
    """
    Apply ``environ`` on top of ``os.environ`` for the duration of a ``with`` body.

    On exit (normal or through an exception) the whole process environment is
    reset to the snapshot taken on entry. Any variable added or changed inside
    the block is reverted, not only the ones passed here.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, str]
    :param environ: Environment variables to set
    """
    snapshot = os.environ.copy()
    for name, value in environ.items():
        os.environ[name] = value
    try:
        yield
    finally:
        # Reset wholesale: drop everything, then re-apply the entry snapshot.
        os.environ.clear()
        os.environ.update(snapshot)
82

83

84
class CreateFunctionExecutor(AbstractExecutor):
2✔
85
    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Built-in implementations (ludwig.py, sklearn.py, forecast.py, ...) are
        # resolved relative to <installation dir>/functions.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR).joinpath("functions")
88

89
    def handle_huggingface_function(self):
        """Describe a HuggingFace function for the catalog.

        HuggingFace functions are not loaded from a user file, so no setup call is
        made here. The implementation is the shared abstract HuggingFace wrapper,
        and the IO entries are generated from the node's name and metadata.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata)``.
        """
        # HuggingFace needs a deep-learning backend; torch availability is checked.
        try_to_import_torch()
        node = self.node
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(node.name, node.metadata)
        return node.name, impl_path, node.function_type, io_list, node.metadata
107

108
    def handle_ludwig_function(self):
        """Train a model with Ludwig's ``auto_train`` and describe it for the catalog.

        All batches from the single child plan are concatenated and passed to
        ``auto_train``, with the ``predict`` option as the target. The best model
        is saved to ``<model_dir>/<function name>``, and that path is appended to
        the node metadata as ``model_path``.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata)``.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        num_children = len(self.children)
        assert num_children == 1, "Create ludwig function expects 1 child, finds {}.".format(
            num_children
        )

        batches = [batch for batch in self.children[0].exec()]
        training_batch = Batch.concat(batches, copy=False)
        training_batch.drop_column_alias()

        options = {entry.key: entry.value for entry in self.node.metadata}
        target_column = options["predict"]
        results = auto_train(
            dataset=training_batch.frames,
            target=target_column,
            tune_for_memory=options.get("tune_for_memory", False),
            time_limit_s=options.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.catalog().get_configuration_catalog_value(
                "tmp_dir"
            ),
        )

        model_root = self.db.catalog().get_configuration_catalog_value("model_dir")
        model_path = os.path.join(model_root, self.node.name)
        results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        node = self.node
        return node.name, impl_path, node.function_type, io_list, node.metadata
157

158
    def handle_sklearn_function(self):
        """Train a scikit-learn ``LinearRegression`` model and describe it.

        All batches from the single child plan are concatenated. The model is
        fitted to predict the ``predict`` column from every remaining column. It
        is then pickled to ``<model_dir>/<function name>``, and ``model_path`` and
        ``predict_col`` are appended to the node metadata so that ``sklearn.py``
        can load it.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata)``.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = Batch.concat(list(self.children[0].exec()), copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        predict_col = arg_map["predict"]
        model = LinearRegression()
        Y = aggregated_batch.frames[predict_col]
        # Every column other than the target is used as a feature.
        aggregated_batch.frames.drop([predict_col], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)

        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        # The context manager flushes and closes the handle even if pickling fails.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", predict_col)
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
206

207
    def convert_to_numeric(self, x):
        """Coerce one cell value to ``int`` when integral, otherwise ``float``.

        Numeric inputs (Python or NumPy ints/floats) are converted directly. This
        keeps their sign and handles values whose text form uses scientific
        notation (``1e-05``), which the text path below would mangle.

        Any other value goes through the text path:
        - it is stringified;
        - every character except digits, ``.`` and ``,`` is removed;
        - the result is parsed with :func:`locale.atof` under the user's default
          locale;
        - a leading ``-`` on the original text is re-applied, so negative values
          keep their sign.

        Args:
            x: the raw cell value.

        Returns:
            ``int`` if the parsed value is integral, else ``float``.

        Raises:
            ValueError: the cleaned text is not a number in the active locale.
        """
        if isinstance(x, (int, float, np.integer, np.floating)):
            value = float(x)
        else:
            text = str(x).strip()
            digits = re.sub("[^0-9.,]", "", text)
            # NOTE: setlocale mutates process-wide state (not thread-safe); it is
            # only needed for locale-aware parsing of textual values.
            locale.setlocale(locale.LC_ALL, "")
            value = float(locale.atof(digits))
            if text.startswith("-"):
                # The character filter above drops the sign; restore it.
                value = -value
        if value.is_integer():
            return int(value)
        return value
215

216
    def handle_xgboost_function(self):
        """Train an XGBoost model through FLAML AutoML and describe it.

        All batches from the single child plan are concatenated. FLAML's
        ``AutoML`` is run with the estimator list restricted to ``xgboost`` to
        predict the ``predict`` column. The fitted AutoML object is pickled to
        ``<model_dir>/<function name>``, and ``model_path`` and ``predict_col``
        are appended to the node metadata for ``xgboost.py``.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata, best_score,
            train_time)``, where ``best_score`` is the AutoML ``best_loss`` and
            ``train_time`` its ``best_config_train_time``.
        """
        try_to_import_xgboost()

        assert (
            len(self.children) == 1
        ), "Create xgboost function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = Batch.concat(list(self.children[0].exec()), copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        # The context manager flushes and closes the handle even if pickling fails.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        best_score = model.best_loss
        train_time = model.best_config_train_time
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )
275

276
    def handle_ultralytics_function(self):
        """Register the built-in Ultralytics object detector.

        The implementation in ``functions/yolo_object_detector.py`` is initialised
        with the node's metadata properties as ``function_args``. The IO catalog
        entries are then resolved from the resulting instance.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata)``.
        """
        try_to_import_ultralytics()

        impl_path = Path(self.function_dir, "yolo_object_detector.py").absolute().as_posix()
        detector = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(detector)
        node = self.node
        return node.name, impl_path, node.function_type, io_list, node.metadata
294

295
    def handle_forecasting_function(self):
        """Train (or reuse) a time-series forecasting model and describe it.

        Steps:
        - Concatenate all batches from the child plan.
        - Rename the user's columns to the ``y`` / ``ds`` / ``unique_id`` schema
          that statsforecast and neuralforecast expect.
        - Build a model from the function options.
        - Train it, unless a model with horizon >= HORIZON already exists in the
          directory keyed by a SHA-256 of the data, under
          ``<model_dir>/tsforecasting``.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata)``. ``metadata``
            is a new list describing the model: path, column renames, horizon,
            library, confidence, and the last ``2 * horizon`` ds/y values.

        Raises:
            ValueError: HORIZON was not provided.
            FunctionIODefinitionError: HORIZON or CONF is not numeric, LIBRARY or
                MODEL is unsupported, or CONF > 100.
            RuntimeError: the frequency cannot be inferred, or statsforecast is
                requested with AUTO disabled.
        """

        def flag_enabled(value):
            # Boolean-like options are enabled when their text starts with "t"
            # (case-insensitive); str() tolerates non-string values and "".
            return str(value).lower().startswith("t")

        aggregated_batch = Batch.concat(list(self.children[0].exec()), copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        if "horizon" not in arg_map:
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except Exception as e:
            err_msg = f"{str(e)}. HORIZON must be integral."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        library = "statsforecast"
        supported_libraries = ["statsforecast", "neuralforecast"]
        if "library" in arg_map:
            requested_library = arg_map["library"]
            # Explicit check (not ``assert``) so validation survives ``python -O``.
            if (
                not isinstance(requested_library, str)
                or requested_library.lower() not in supported_libraries
            ):
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)
            library = requested_library.lower()

        # statsforecast / neuralforecast require these column names:
        # - unique_id (string, int or category): identifier of each series.
        # - ds (datestamp): a pandas-parsable date, ideally YYYY-MM-DD, or
        #   timestamp, ideally YYYY-MM-DD HH:MM:SS.
        # - y (numeric): the measurement to forecast.
        # Reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        if "conf" in arg_map:
            try:
                # float() first so textual values such as "95" are accepted too.
                conf = round(float(arg_map["conf"]))
            except Exception as e:
                err_msg = "Confidence must be a number."
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg) from e
        else:
            conf = 90

        if conf > 100:
            err_msg = "Confidence must <= 100."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        data = aggregated_batch.frames
        if "unique_id" not in data.columns:
            # Single-series input: every row belongs to series 1.
            data["unique_id"] = [1] * len(data)
        if "ds" not in data.columns:
            # No time column: use the 1-based row position as the time index.
            data["ds"] = list(range(1, len(data) + 1))

        # Set or infer the data frequency.
        if arg_map.get("frequency", "auto") == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }
        # Shorten anchored aliases such as "Q-DEC" to their base ("Q").
        new_freq = frequency.split("-")[0]
        season_length = season_dict.get(new_freq, 1)

        exogenous_columns = []
        if library == "neuralforecast":
            try_to_import_neuralforecast()
            from neuralforecast import NeuralForecast
            from neuralforecast.auto import (
                AutoDeepAR,
                AutoFEDformer,
                AutoInformer,
                AutoNBEATS,
                AutoNHITS,
                AutoPatchTST,
                AutoTFT,
                AutoTimesNet,
            )
            from neuralforecast.losses.pytorch import MQLoss
            from neuralforecast.models import (
                NBEATS,
                NHITS,
                TFT,
                DeepAR,
                FEDformer,
                Informer,
                PatchTST,
                TimesNet,
            )

            model_dict = {
                "AutoNBEATS": AutoNBEATS,
                "AutoNHITS": AutoNHITS,
                "NBEATS": NBEATS,
                "NHITS": NHITS,
                "PatchTST": PatchTST,
                "AutoPatchTST": AutoPatchTST,
                "DeepAR": DeepAR,
                "AutoDeepAR": AutoDeepAR,
                "FEDformer": FEDformer,
                "AutoFEDformer": AutoFEDformer,
                "Informer": Informer,
                "AutoInformer": AutoInformer,
                "TimesNet": TimesNet,
                "AutoTimesNet": AutoTimesNet,
                "TFT": TFT,
                "AutoTFT": AutoTFT,
            }

            if "model" not in arg_map:
                arg_map["model"] = "TFT"

            # AUTO defaults to true. Prefix "Auto" only when tuning is enabled AND
            # the model is not already an Auto* model (the old precedence turned
            # "AutoTFT" into the unknown "AutoAutoTFT" when AUTO was omitted).
            auto_enabled = "auto" not in arg_map or flag_enabled(arg_map["auto"])
            if auto_enabled and "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except KeyError as e:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg) from e

            hyperparams = {"input_size": 2 * horizon, "early_stop_patience_steps": 20}
            if len(data.columns) >= 4:
                # Columns beyond ds / y / unique_id are historical exogenous inputs.
                exogenous_columns = [
                    x for x in data.columns if x not in ["ds", "y", "unique_id"]
                ]
                hyperparams["hist_exog_list"] = exogenous_columns

            model_args = {}
            if "auto" in arg_map["model"].lower():
                # Auto* models receive the fixed settings through an Optuna config.
                def get_optuna_config(trial):
                    return hyperparams

                model_args["config"] = get_optuna_config
                model_args["backend"] = "optuna"
            else:
                model_args.update(hyperparams)

            model_args["h"] = horizon
            model_args["loss"] = MQLoss(level=[conf])

            model = NeuralForecast(
                [model_here(**model_args)],
                freq=new_freq,
            )
        else:
            # statsforecast only offers automatically tuned models.
            if "auto" in arg_map and not flag_enabled(arg_map["auto"]):
                raise RuntimeError(
                    "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
                )
            try_to_import_statsforecast()
            from statsforecast import StatsForecast
            from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

            model_dict = {
                "AutoARIMA": AutoARIMA,
                "AutoCES": AutoCES,
                "AutoETS": AutoETS,
                "AutoTheta": AutoTheta,
            }

            if "model" not in arg_map:
                arg_map["model"] = "ARIMA"

            if "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except KeyError as e:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg) from e

            model = StatsForecast(
                [model_here(season_length=season_length)], freq=new_freq, n_jobs=-1
            )

        data["ds"] = pd.to_datetime(data["ds"])

        if library == "statsforecast":
            model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
        else:
            model_save_dir_name = (
                library + "_" + str(conf) + "_" + arg_map["model"] + "_" + new_freq
            )
        if len(data.columns) >= 4 and library == "neuralforecast":
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        # Models are cached per (configuration, exact training data) pair.
        model_dir = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            "tsforecasting",
            model_save_dir_name,
            hashlib.sha256(data.to_string().encode()).hexdigest(),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_path = os.path.join(model_dir, "horizon" + str(horizon) + ".pkl")

        # Reusable models are entries named exactly "horizon<N>.pkl" with
        # N >= horizon. The strict pattern skips the "...pkl_rmse" metric files
        # stored next to them, which must never be used as a model path.
        model_file_pattern = re.compile(r"horizon(\d+)\.pkl")
        reusable_models = []
        for entry in os.listdir(model_dir):
            matched = model_file_pattern.fullmatch(entry)
            if matched and int(matched.group(1)) >= horizon:
                reusable_models.append((int(matched.group(1)), entry))
        existing_model_files = [entry for _, entry in sorted(reusable_models)]

        if not existing_model_files:
            logger.info("Training, please wait...")
            for column in data.columns:
                if column != "ds" and column != "unique_id":
                    data[column] = data[column].map(self.convert_to_numeric)
            rmses = []
            if library == "neuralforecast":
                # Train on the first visible CUDA device only (defaults to "0").
                cuda_devices_here = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(
                    ","
                )[0]

                with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
                    model.fit(df=data, val_size=horizon)
                    model.save(model_path, overwrite=True)
                    # neuralforecast metrics are opt-in (METRICS must be true).
                    if "metrics" in arg_map and flag_enabled(arg_map["metrics"]):
                        crossvalidation_df = model.cross_validation(
                            df=data, val_size=horizon
                        )
                        for uid in crossvalidation_df.unique_id.unique():
                            series_cv = crossvalidation_df[
                                crossvalidation_df.unique_id == uid
                            ]
                            # Normalised RMSE: RMSE divided by the series mean.
                            rmses.append(
                                np.sqrt(
                                    mean_squared_error(
                                        series_cv.y,
                                        series_cv[arg_map["model"] + "-median"],
                                    )
                                )
                                / np.mean(series_cv.y)
                            )
                        if rmses:
                            with open(model_path + "_rmse", "w") as f:
                                f.write(str(np.mean(rmses)) + "\n")
            else:
                # statsforecast hits a math error when a series has only one data
                # point, so duplicate such rows to give each series two points.
                id_counts = data["unique_id"].value_counts()
                single_point_ids = id_counts[id_counts == 1].index
                if len(single_point_ids) > 0:
                    data = pd.concat(
                        [data, data[data["unique_id"].isin(single_point_ids)]],
                        ignore_index=True,
                    )

                model.fit(df=data[["ds", "y", "unique_id"]])
                hypers = ""
                if "arima" in arg_map["model"].lower():
                    from statsforecast.arima import arima_string

                    hypers += arima_string(model.fitted_[0, 0].model_)
                with open(model_path, "wb") as f:
                    pickle.dump(model, f)
                # statsforecast metrics are opt-out (computed unless METRICS false).
                if "metrics" not in arg_map or flag_enabled(arg_map["metrics"]):
                    crossvalidation_df = model.cross_validation(
                        df=data[["ds", "y", "unique_id"]],
                        h=horizon,
                        step_size=24,
                        n_windows=1,
                    ).reset_index()
                    for uid in crossvalidation_df.unique_id.unique():
                        series_cv = crossvalidation_df[
                            crossvalidation_df.unique_id == uid
                        ]
                        # Normalised RMSE: RMSE divided by the series mean.
                        rmses.append(
                            np.sqrt(
                                mean_squared_error(
                                    series_cv.y, series_cv[arg_map["model"]]
                                )
                            )
                            / np.mean(series_cv.y)
                        )
                    with open(model_path + "_rmse", "w") as f:
                        f.write(str(np.mean(rmses)) + "\n")
                        f.write(hypers + "\n")
        elif not Path(model_path).exists():
            # Reuse the cached model with the largest horizon >= the requested one.
            model_path = os.path.join(model_dir, existing_model_files[-1])

        io_list = self._resolve_function_io(None)
        data["ds"] = data.ds.astype(str)
        last_ds = list(data["ds"])[-2 * horizon :]
        last_y = list(data["y"])[-2 * horizon :]
        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
            FunctionMetadataCatalogEntry("conf", conf),
            FunctionMetadataCatalogEntry("last_ds", last_ds),
            FunctionMetadataCatalogEntry("last_y", last_y),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )
677

678
    def handle_generic_function(self):
        """Register a user-provided function implemented in a Python file.

        The implementation at ``self.node.impl_path`` is initialised. Its IO
        definitions come from the inputs/outputs given in the CREATE statement,
        or, failing that, from the function's decorators.

        Returns:
            ``(name, impl_path, function_type, io_list, metadata)``.
        """
        node = self.node
        impl_path = node.impl_path.absolute().as_posix()
        io_list = self._resolve_function_io(self._try_initializing_function(impl_path))
        return node.name, impl_path, node.function_type, io_list, node.metadata
694

695
    def exec(self, *args, **kwargs):
2✔
696
        """Create function executor
697

698
        Calls the catalog to insert a function catalog entry.
699
        """
700
        assert (
2✔
701
            self.node.if_not_exists and self.node.or_replace
702
        ) is False, (
703
            "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."
704
        )
705

706
        overwrite = False
2✔
707
        best_score = False
2✔
708
        train_time = False
2✔
709
        # check catalog if it already has this function entry
710
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
2✔
711
            if self.node.if_not_exists:
1✔
712
                msg = f"Function {self.node.name} already exists, nothing added."
1✔
713
                yield Batch(pd.DataFrame([msg]))
1✔
714
                return
×
715
            elif self.node.or_replace:
1✔
716
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
717
                from evadb.executor.drop_object_executor import DropObjectExecutor
1✔
718

719
                drop_executor = DropObjectExecutor(self.db, None)
1✔
720
                try:
1✔
721
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
1✔
722
                except RuntimeError:
723
                    pass
724
                else:
725
                    overwrite = True
1✔
726
            else:
727
                msg = f"Function {self.node.name} already exists."
×
728
                logger.error(msg)
×
729
                raise RuntimeError(msg)
730

731
        # if it's a type of HuggingFaceModel, override the impl_path
732
        if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"):
2✔
733
            (
×
734
                name,
735
                impl_path,
736
                function_type,
737
                io_list,
738
                metadata,
739
            ) = self.handle_huggingface_function()
740
        elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"):
2✔
741
            (
2✔
742
                name,
743
                impl_path,
744
                function_type,
745
                io_list,
746
                metadata,
747
            ) = self.handle_ultralytics_function()
748
        elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"):
2✔
749
            (
×
750
                name,
751
                impl_path,
752
                function_type,
753
                io_list,
754
                metadata,
755
            ) = self.handle_ludwig_function()
756
        elif string_comparison_case_insensitive(self.node.function_type, "Sklearn"):
2✔
757
            (
×
758
                name,
759
                impl_path,
760
                function_type,
761
                io_list,
762
                metadata,
763
            ) = self.handle_sklearn_function()
764
        elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"):
2✔
765
            (
×
766
                name,
767
                impl_path,
768
                function_type,
769
                io_list,
770
                metadata,
771
                best_score,
772
                train_time,
773
            ) = self.handle_xgboost_function()
774
        elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
2✔
775
            (
×
776
                name,
777
                impl_path,
778
                function_type,
779
                io_list,
780
                metadata,
781
            ) = self.handle_forecasting_function()
782
        else:
783
            (
2✔
784
                name,
785
                impl_path,
786
                function_type,
787
                io_list,
788
                metadata,
789
            ) = self.handle_generic_function()
790

791
        self.catalog().insert_function_catalog_entry(
2✔
792
            name, impl_path, function_type, io_list, metadata
793
        )
794

795
        if overwrite:
2✔
796
            msg = f"Function {self.node.name} overwritten."
1✔
797
        else:
798
            msg = f"Function {self.node.name} added to the database."
2✔
799
        if best_score and train_time:
2✔
800
            yield Batch(
×
801
                pd.DataFrame(
802
                    [
803
                        msg,
804
                        "Validation Score: " + str(best_score),
805
                        "Training time: " + str(train_time),
806
                    ]
807
                )
808
            )
809
        else:
810
            yield Batch(pd.DataFrame([msg]))
2✔
811

812
    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = {}
    ) -> type:
        """Load the function class from ``impl_path`` and verify it can be instantiated.

        The class named ``self.node.name`` is loaded from the implementation
        file and instantiated once with ``function_args``. Instantiation runs
        the class's setup logic, so a broken implementation is reported here.
        The instance created for this check is discarded.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Keyword arguments passed to the
                function class constructor. Defaults to {}. The default dict
                is only unpacked with ``**`` and never mutated, so sharing it
                across calls is safe.

        Returns:
            type: The class object loaded from ``impl_path``. This is the class
            itself, not the throwaway instance used for validation.

        Raises:
            RuntimeError: If loading or instantiating the function fails. The
                original exception is chained as ``__cause__``.
        """
        try:
            # Load the function class from the implementation file.
            function = load_function_class_from_file(impl_path, self.node.name)
            # Initializing the function class calls its setup method internally;
            # this surfaces setup failures at creation time.
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            # Chain the original error so its traceback is preserved for debugging.
            raise RuntimeError(err_msg) from e

        return function
840

841
    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Resolve the input/output definitions for ``function``.

        Inputs and outputs are resolved independently. For each side, the
        definitions given in the CREATE FUNCTION statement
        (``self.node.inputs`` / ``self.node.outputs``) take precedence. When a
        side is not given there, its definitions are loaded from the
        function's decorators via ``load_io_from_function_decorators``.

        Args:
            function (FunctionCatalogEntry): The function whose input/output
                definitions are resolved. NOTE(review): this value is handed
                to ``load_io_from_function_decorators`` and is likely the
                loaded function class rather than a catalog entry. Confirm
                against callers.

        Returns:
            List[FunctionIOCatalogEntry]: The resolved input definitions
            followed by the resolved output definitions.

        Raises:
            RuntimeError: If loading definitions raises a
                ``FunctionIODefinitionError``. The original error is chained
                as ``__cause__``.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # Fall back to decorator-declared inputs; inputs from the
                # CREATE statement take precedence.
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # Fall back to decorator-declared outputs; outputs from the
                # CREATE statement take precedence.
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            # Chain the original error so the underlying IO definition problem
            # remains visible in the traceback.
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc