• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / d88c9094-1a92-4944-91ea-5e7e8275cb11

23 Nov 2023 10:19PM UTC coverage: 67.117% (+67.1%) from 0.0%
d88c9094-1a92-4944-91ea-5e7e8275cb11

push

circleci

web-flow
Merge branch 'georgia-tech-db:staging' into cost_batching

342 of 692 new or added lines in 47 files covered. (49.42%)

12 existing lines in 4 files now uncovered.

9189 of 13691 relevant lines covered (67.12%)

0.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

24.92
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import contextlib
1✔
16
import hashlib
1✔
17
import locale
1✔
18
import os
1✔
19
import pickle
1✔
20
import re
1✔
21
import time
1✔
22
from pathlib import Path
1✔
23
from typing import Dict, List
1✔
24

25
import numpy as np
1✔
26
import pandas as pd
1✔
27

28
from evadb.catalog.catalog_utils import get_metadata_properties
1✔
29
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
1✔
30
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
1✔
31
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
1✔
32
from evadb.configuration.constants import (
1✔
33
    DEFAULT_SKLEARN_TRAIN_MODEL,
34
    DEFAULT_TRAIN_REGRESSION_METRIC,
35
    DEFAULT_TRAIN_TIME_LIMIT,
36
    DEFAULT_XGBOOST_TASK,
37
    SKLEARN_SUPPORTED_MODELS,
38
    EvaDB_INSTALLATION_DIR,
39
)
40
from evadb.database import EvaDBDatabase
1✔
41
from evadb.executor.abstract_executor import AbstractExecutor
1✔
42
from evadb.functions.decorators.utils import load_io_from_function_decorators
1✔
43
from evadb.models.storage.batch import Batch
1✔
44
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
1✔
45
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
1✔
46
from evadb.utils.errors import FunctionIODefinitionError
1✔
47
from evadb.utils.generic_utils import (
1✔
48
    load_function_class_from_file,
49
    string_comparison_case_insensitive,
50
    try_to_import_flaml_automl,
51
    try_to_import_ludwig,
52
    try_to_import_neuralforecast,
53
    try_to_import_statsforecast,
54
    try_to_import_torch,
55
    try_to_import_ultralytics,
56
)
57
from evadb.utils.logging_manager import logger
1✔
58

59

60
def root_mean_squared_error(y_true, y_pred):
    """Return the root-mean-squared error between predictions and targets.

    Args:
        y_true: observed values (any array-like supporting elementwise
            subtraction, e.g. a numpy array or pandas Series).
        y_pred: predicted values aligned with ``y_true``.

    Returns:
        ``sqrt(mean((y_pred - y_true) ** 2))``.
    """
    residuals = y_pred - y_true
    mean_squared_residual = np.mean(np.square(residuals))
    return np.sqrt(mean_squared_residual)
62

63

64
# From https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
    """
    Temporarily set the process environment variables.

    A snapshot of the whole ``os.environ`` mapping is taken on entry and
    restored on exit (including when the body raises), so every variable
    added, changed or removed inside the ``with`` block is reverted.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, str]
    :param environ: Environment variables to set
    :raises TypeError: if a value is not a ``str``; the original environment
        is restored before the error propagates.
    """
    old_environ = dict(os.environ)
    try:
        # The update happens inside ``try`` so the snapshot is restored even
        # if it fails part-way (e.g. a non-str value after some variables
        # have already been written).
        os.environ.update(environ)
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)
87

88

89
class CreateFunctionExecutor(AbstractExecutor):
1✔
90
    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        """Store the executor inputs and locate the built-in functions folder.

        Args:
            db: the EvaDB database handle passed to the base executor.
            node: the CREATE FUNCTION plan node being executed.
        """
        super().__init__(db, node)
        # Built-in function implementations live under <install dir>/functions.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR).joinpath("functions")
93

94
    def handle_huggingface_function(self):
        """Build the catalog definition for a HuggingFace function.

        HuggingFace functions are not loaded from a user file; they all share
        the abstract HF implementation, so (unlike other function types) no
        setup call is made here.

        Returns:
            (name, impl_path, function_type, io_list, metadata)
        """
        # HuggingFace needs at least one deep learning framework
        # (Torch or TensorFlow) to be importable.
        try_to_import_torch()

        plan = self.node
        hf_impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        hf_io_entries = gen_hf_io_catalog_entries(plan.name, plan.metadata)
        return plan.name, hf_impl_path, plan.function_type, hf_io_entries, plan.metadata
112

113
    def handle_ludwig_function(self):
        """Train a model with Ludwig's AutoML and register it as a function.

        All batches from the single child executor are concatenated into one
        training frame and passed to ``ludwig.automl.auto_train``. Recognised
        metadata keys: ``predict`` (required target column),
        ``tune_for_memory`` and ``time_limit``. The best model is saved to
        ``<model_dir>/<function name>`` and that path is appended to the node
        metadata.

        Returns:
            (name, impl_path, function_type, io_list, metadata, best_score,
            train_time), where best_score is the best trial's ``metric_score``
            and train_time is the wall-clock training time in whole seconds.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert len(self.children) == 1, (
            "Create ludwig function expects 1 child, finds {}.".format(
                len(self.children)
            )
        )

        source_executor = self.children[0]
        collected_batches = [batch for batch in source_executor.exec()]
        training_batch = Batch.concat(collected_batches, copy=False)
        training_batch.drop_column_alias()

        options = {entry.key: entry.value for entry in self.node.metadata}

        started_at = int(time.time())
        automl_outcome = auto_train(
            dataset=training_batch.frames,
            target=options["predict"],
            tune_for_memory=options.get("tune_for_memory", False),
            time_limit_s=options.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.catalog().get_configuration_catalog_value(
                "tmp_dir"
            ),
        )
        elapsed_seconds = int(time.time()) - started_at

        saved_model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        automl_outcome.best_model.save(saved_model_path)
        top_score = automl_outcome.experiment_analysis.best_result["metric_score"]
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", saved_model_path)
        )

        ludwig_impl = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        resolved_io = self._resolve_function_io(None)
        return (
            self.node.name,
            ludwig_impl,
            self.node.function_type,
            resolved_io,
            self.node.metadata,
            top_score,
            elapsed_seconds,
        )
167

168
    def handle_sklearn_function(self):
        """Train a scikit-learn model with FLAML AutoML and register it.

        All batches from the single child executor are concatenated into one
        training frame. Recognised metadata keys: ``predict`` (required label
        column), ``model`` (estimator name, must be one of
        SKLEARN_SUPPORTED_MODELS), ``time_limit``, ``metric`` and ``task``.
        The fitted AutoML object is pickled to ``<model_dir>/<function name>``.
        The model path and the prediction column name are appended to the node
        metadata so sklearn.py can use them at inference time.

        Returns:
            (name, impl_path, function_type, io_list, metadata, score,
            train_time), where score is FLAML's ``best_loss`` and train_time
            is the wall-clock fit time in whole seconds.

        Raises:
            ValueError: if the requested sklearn model is not supported.
        """
        try_to_import_flaml_automl()

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        sklearn_model = arg_map.get("model", DEFAULT_SKLEARN_TRAIN_MODEL)
        # Validate before executing the child plan, so an unsupported model
        # fails fast instead of after the whole input has been scanned.
        if sklearn_model not in SKLEARN_SUPPORTED_MODELS:
            raise ValueError(
                f"Sklearn Model {sklearn_model} provided as input is not supported."
            )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": [sklearn_model],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        start_time = int(time.time())
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        train_time = int(time.time()) - start_time
        score = model.best_loss
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        # Close the file deterministically so the pickle is fully flushed
        # before a catalog entry pointing at it is created.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            score,
            train_time,
        )
233

234
    def convert_to_numeric(self, x):
        """Coerce a raw cell value to an ``int`` or ``float``.

        Values that are already numeric (Python or numpy ints/floats) are
        used as-is. This avoids round-tripping through ``str``, which would
        mangle scientific notation such as ``1e-05``. Any other value is
        stringified and stripped of every character except digits, ``.``
        and ``,``. A leading ``-`` is honoured, so negative values keep
        their sign. The result is parsed with ``locale.atof`` under the
        user's default locale, so thousands separators (e.g. ``"1,234"``)
        are understood.

        Returns:
            ``int`` when the parsed value is integral, otherwise ``float``.

        Raises:
            ValueError: if no number can be parsed from a non-numeric value.
        """
        if isinstance(x, (int, float, np.integer, np.floating)) and not isinstance(
            x, bool
        ):
            value = float(x)
        else:
            text = str(x).strip()
            is_negative = text.startswith("-")
            digits = re.sub("[^0-9.,]", "", text)
            # Use the user's locale so atof understands its separators.
            locale.setlocale(locale.LC_ALL, "")
            value = float(locale.atof(digits))
            if is_negative:
                value = -value
        if value.is_integer():
            return int(value)
        return value
242

243
    def handle_xgboost_function(self):
        """Train an XGBoost model with FLAML AutoML and register it.

        All batches from the single child executor are concatenated into one
        training frame. FLAML is restricted to the ``xgboost`` estimator.
        Recognised metadata keys: ``predict`` (required label column),
        ``time_limit``, ``metric`` and ``task``. The fitted AutoML object is
        pickled to ``<model_dir>/<function name>``. The model path and the
        prediction column name are appended to the node metadata for
        xgboost.py.

        Returns:
            (name, impl_path, function_type, io_list, metadata, best_score,
            train_time), where best_score is FLAML's ``best_loss`` and
            train_time is the wall-clock fit time in whole seconds.
        """
        try_to_import_flaml_automl()

        assert (
            len(self.children) == 1
        ), "Create xgboost function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        start_time = int(time.time())
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        train_time = int(time.time()) - start_time
        best_score = model.best_loss
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        # Close the file deterministically so the pickle is fully flushed
        # before a catalog entry pointing at it is created.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )
303

304
    def handle_ultralytics_function(self):
        """Build the catalog definition for an Ultralytics (YOLO) function.

        The bundled ``yolo_object_detector.py`` implementation is initialised
        with the properties from the plan's metadata. Its IO is then resolved
        from the initialised function.

        Returns:
            (name, impl_path, function_type, io_list, metadata)
        """
        try_to_import_ultralytics()

        detector_file = Path(f"{self.function_dir}/yolo_object_detector.py")
        detector_impl = detector_file.absolute().as_posix()
        init_kwargs = get_metadata_properties(self.node)
        detector = self._try_initializing_function(
            detector_impl, function_args=init_kwargs
        )
        detector_io = self._resolve_function_io(detector)

        plan = self.node
        return plan.name, detector_impl, plan.function_type, detector_io, plan.metadata
322

323
    def handle_forecasting_function(self):
        """Handle forecasting functions.

        All batches from the single child executor are concatenated into one
        DataFrame. The user's columns are renamed to the ``unique_id`` /
        ``ds`` / ``y`` names that statsforecast and neuralforecast expect.
        The requested model is then built, and either trained and saved or
        replaced by a cached model. A cached model must have been trained on
        identical data and with a horizon at least as large as the one
        requested.

        Recognised metadata keys: ``predict`` and ``horizon`` (required),
        ``time``, ``id``, ``library``, ``model``, ``auto``, ``conf``,
        ``frequency`` and ``metrics``.

        Returns:
            (name, impl_path, function_type, io_list, metadata), where
            metadata is a fresh list of FunctionMetadataCatalogEntry
            describing the chosen model.

        Raises:
            ValueError: if HORIZON is missing.
            FunctionIODefinitionError: if HORIZON, LIBRARY, CONF or MODEL is
                invalid.
            RuntimeError: if the frequency cannot be inferred, or AUTO is
                disabled for statsforecast.
        """

        def is_true(value):
            # Flag values such as AUTO / METRICS count as true when they start
            # with "t" ("true", "True", "t"). str() keeps empty or non-string
            # values from raising.
            return str(value).lower().startswith("t")

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()
        library = "statsforecast"
        supported_libraries = ["statsforecast", "neuralforecast"]

        if "horizon" not in arg_map:
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except Exception as e:
            err_msg = f"{str(e)}. HORIZON must be integral."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        if "library" in arg_map:
            # Explicit check rather than ``assert`` so validation still runs
            # under ``python -O``.
            requested_library = str(arg_map["library"]).lower()
            if requested_library not in supported_libraries:
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)
            library = requested_library

        # statsforecast/neuralforecast require these column names:
        # - unique_id (string, int or category) identifies the series.
        # - ds (datestamp) should be in a format expected by Pandas, ideally
        #   YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
        # - y (numeric) is the measurement we wish to forecast.
        # Reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        if "conf" in arg_map:
            try:
                # float() first so numeric strings such as "95" are accepted.
                conf = round(float(arg_map["conf"]))
            except (TypeError, ValueError, OverflowError):
                err_msg = "Confidence must be a number."
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)
        else:
            conf = 90

        if conf < 0 or conf > 100:
            err_msg = "Confidence must be between 0 and 100."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        data = aggregated_batch.frames
        if "unique_id" not in data.columns:
            # No id column: treat all rows as a single series.
            data["unique_id"] = 1

        if "ds" not in data.columns:
            # No time column: use the 1-based row position as the datestamp.
            data["ds"] = range(1, len(data) + 1)

        # Set or infer the data frequency.
        if "frequency" not in arg_map or arg_map["frequency"] == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        # Shortens anchored frequencies such as "Q-DEC" to "Q".
        new_freq = frequency.split("-")[0]
        season_length = season_dict.get(new_freq, 1)

        exogenous_columns = []
        if library == "neuralforecast":
            try_to_import_neuralforecast()
            from neuralforecast import NeuralForecast
            from neuralforecast.auto import (
                AutoDeepAR,
                AutoFEDformer,
                AutoInformer,
                AutoNBEATS,
                AutoNHITS,
                AutoPatchTST,
                AutoTFT,
            )
            from neuralforecast.losses.pytorch import MQLoss
            from neuralforecast.models import (
                NBEATS,
                NHITS,
                TFT,
                DeepAR,
                FEDformer,
                Informer,
                PatchTST,
            )

            model_dict = {
                "AutoNBEATS": AutoNBEATS,
                "AutoNHITS": AutoNHITS,
                "NBEATS": NBEATS,
                "NHITS": NHITS,
                "PatchTST": PatchTST,
                "AutoPatchTST": AutoPatchTST,
                "DeepAR": DeepAR,
                "AutoDeepAR": AutoDeepAR,
                "FEDformer": FEDformer,
                "AutoFEDformer": AutoFEDformer,
                "Informer": Informer,
                "AutoInformer": AutoInformer,
                "TFT": TFT,
                "AutoTFT": AutoTFT,
            }

            if "model" not in arg_map:
                arg_map["model"] = "TFT"

            # AUTO defaults to true. Add the "Auto" prefix only when the user
            # has not already named an Auto* model; otherwise e.g.
            # "AutoNHITS" would become the unknown "AutoAutoNHITS".
            auto_requested = "auto" not in arg_map or is_true(arg_map["auto"])
            if auto_requested and "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except KeyError:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            is_auto_model = "auto" in arg_map["model"].lower()
            model_args = {}
            model_args_config = {}
            if is_auto_model:
                # Auto* models receive their hyperparameters via an optuna config.
                model_args_config = {
                    "input_size": 2 * horizon,
                    "early_stop_patience_steps": 20,
                }
            else:
                model_args["input_size"] = 2 * horizon
                model_args["early_stop_patience_steps"] = 20

            if len(data.columns) >= 4:
                # Every column other than the three reserved ones is an
                # exogenous (historical) feature.
                exogenous_columns = [
                    x for x in data.columns if x not in ["ds", "y", "unique_id"]
                ]
                if is_auto_model:
                    model_args_config["hist_exog_list"] = exogenous_columns
                else:
                    model_args["hist_exog_list"] = exogenous_columns

            if is_auto_model:

                def get_optuna_config(trial):
                    return model_args_config

                model_args["config"] = get_optuna_config
                model_args["backend"] = "optuna"

            model_args["h"] = horizon
            model_args["loss"] = MQLoss(level=[conf])

            model = NeuralForecast(
                [model_here(**model_args)],
                freq=new_freq,
            )

        else:
            # Statsforecast implementation.
            if "auto" in arg_map and not is_true(arg_map["auto"]):
                raise RuntimeError(
                    "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
                )
            try_to_import_statsforecast()
            from statsforecast import StatsForecast
            from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

            model_dict = {
                "AutoARIMA": AutoARIMA,
                "AutoCES": AutoCES,
                "AutoETS": AutoETS,
                "AutoTheta": AutoTheta,
            }

            if "model" not in arg_map:
                arg_map["model"] = "ARIMA"

            if "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except KeyError:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            model = StatsForecast(
                [model_here(season_length=season_length)], freq=new_freq
            )

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = (
            library + "_" + arg_map["model"] + "_" + new_freq
            if "statsforecast" in library
            else library + "_" + str(conf) + "_" + arg_map["model"] + "_" + new_freq
        )
        if len(data.columns) >= 4 and library == "neuralforecast":
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        # Models are cached per configuration and per exact training data.
        model_dir = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_save_name = "horizon" + str(horizon) + ".pkl"

        model_path = os.path.join(model_dir, model_save_name)

        # Only "horizon<N>.pkl" entries are saved models. Skip the
        # "horizon<N>.pkl_rmse" metric files written next to them (and any
        # other stray file), so a metrics file is never picked as a model.
        saved_model_pattern = re.compile(r"horizon(\d+)\.pkl")
        usable_models = []
        for file_name in os.listdir(model_dir):
            match = saved_model_pattern.fullmatch(file_name)
            if match and int(match.group(1)) >= horizon:
                usable_models.append((int(match.group(1)), file_name))
        # Sorted by ascending horizon.
        existing_model_files = [name for _, name in sorted(usable_models)]

        if not existing_model_files:
            logger.info("Training, please wait...")
            for column in data.columns:
                if column != "ds" and column != "unique_id":
                    data[column] = data[column].map(self.convert_to_numeric)
            rmses = []
            if library == "neuralforecast":
                # Train on a single GPU: the first visible device, or "0".
                cuda_devices_here = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(
                    ","
                )[0]

                with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
                    model.fit(df=data, val_size=horizon)
                    model.save(model_path, overwrite=True)
                    # NOTE(review): metrics default to off for neuralforecast
                    # but on for statsforecast; kept as-is, confirm intent.
                    if "metrics" in arg_map and is_true(arg_map["metrics"]):
                        crossvalidation_df = model.cross_validation(
                            df=data, val_size=horizon
                        )
                        for uid in crossvalidation_df.unique_id.unique():
                            crossvalidation_df_here = crossvalidation_df[
                                crossvalidation_df.unique_id == uid
                            ]
                            # RMSE normalised by the series mean.
                            rmses.append(
                                root_mean_squared_error(
                                    crossvalidation_df_here.y,
                                    crossvalidation_df_here[
                                        arg_map["model"] + "-median"
                                    ],
                                )
                                / np.mean(crossvalidation_df_here.y)
                            )
                        mean_rmse = np.mean(rmses)
                        with open(model_path + "_rmse", "w") as f:
                            f.write(str(mean_rmse) + "\n")
            else:
                # Duplicate single-point series to avoid the math error
                # statsforecast raises when a series has only one datapoint.
                for col in data["unique_id"].unique():
                    single_series = data[data["unique_id"] == col]
                    if len(single_series) == 1:
                        data = pd.concat([data, single_series], ignore_index=True)

                model.fit(df=data[["ds", "y", "unique_id"]])
                hypers = ""
                if "arima" in arg_map["model"].lower():
                    from statsforecast.arima import arima_string

                    hypers += arima_string(model.fitted_[0, 0].model_)
                with open(model_path, "wb") as f:
                    pickle.dump(model, f)
                if "metrics" not in arg_map or is_true(arg_map["metrics"]):
                    crossvalidation_df = model.cross_validation(
                        df=data[["ds", "y", "unique_id"]],
                        h=horizon,
                        step_size=24,
                        n_windows=1,
                    ).reset_index()
                    for uid in crossvalidation_df.unique_id.unique():
                        crossvalidation_df_here = crossvalidation_df[
                            crossvalidation_df.unique_id == uid
                        ]
                        # RMSE normalised by the series mean.
                        rmses.append(
                            root_mean_squared_error(
                                crossvalidation_df_here.y,
                                crossvalidation_df_here[arg_map["model"]],
                            )
                            / np.mean(crossvalidation_df_here.y)
                        )
                    mean_rmse = np.mean(rmses)
                    with open(model_path + "_rmse", "w") as f:
                        f.write(str(mean_rmse) + "\n")
                        f.write(hypers + "\n")
        elif not Path(model_path).exists():
            # Reuse the cached model with the largest horizon >= requested.
            model_path = os.path.join(model_dir, existing_model_files[-1])
        io_list = self._resolve_function_io(None)
        data["ds"] = data.ds.astype(str)
        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
            FunctionMetadataCatalogEntry("conf", conf),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )
695

696
    def handle_generic_function(self):
        """Build the catalog definition for a function loaded from a file.

        The implementation at ``self.node.impl_path`` is loaded and
        initialised. Its IO is then resolved: per the original design, from
        inputs given in the CREATE statement, otherwise from the function's
        decorators.

        Returns:
            (name, impl_path, function_type, io_list, metadata)
        """
        resolved_impl = self.node.impl_path.absolute().as_posix()
        loaded_function = self._try_initializing_function(resolved_impl)
        function_io = self._resolve_function_io(loaded_function)

        plan = self.node
        return plan.name, resolved_impl, plan.function_type, function_io, plan.metadata
712

713
    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry.

        Flow:
          1. Reject plans that set both OR REPLACE and IF NOT EXISTS.
          2. If a function with this name is already in the catalog:
             - IF NOT EXISTS: yield a "nothing added" message and stop.
             - OR REPLACE: drop the existing entry (best-effort) and continue.
             - otherwise: log and raise RuntimeError.
          3. Dispatch on ``function_type`` (case-insensitive) to the matching
             ``handle_*_function`` method, falling back to
             ``handle_generic_function``.
          4. Insert the resulting entry into the catalog and yield a single
             Batch with a status message. Training handlers (Ludwig, Sklearn,
             XGBoost) also add their validation score and training time.

        Raises:
            AssertionError: if both OR REPLACE and IF NOT EXISTS are set.
            RuntimeError: if the function already exists and neither flag is set.
        """
        # `not (...)` rather than `(...) is False`: the latter wrongly fails
        # when a flag is a falsy non-bool such as None.
        assert not (self.node.if_not_exists and self.node.or_replace), (
            "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."
        )

        overwrite = False
        # Only the training handlers report these; None means "not reported",
        # so a legitimate score or time of 0 is still shown to the user.
        best_score = None
        train_time = None

        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError as e:
                    # Best-effort drop: keep going, but do not hide the failure.
                    logger.warning(
                        f"Could not drop existing function {self.node.name} "
                        f"before replacing it: {e}"
                    )
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # (type name, handler, whether the handler also returns
        # best_score and train_time). Checked in order; first match wins.
        type_handlers = (
            ("HuggingFace", self.handle_huggingface_function, False),
            ("ultralytics", self.handle_ultralytics_function, False),
            ("Ludwig", self.handle_ludwig_function, True),
            ("Sklearn", self.handle_sklearn_function, True),
            ("XGBoost", self.handle_xgboost_function, True),
            ("Forecasting", self.handle_forecasting_function, False),
        )
        handler, reports_training_stats = self.handle_generic_function, False
        for type_name, candidate, reports_stats in type_handlers:
            if string_comparison_case_insensitive(self.node.function_type, type_name):
                handler, reports_training_stats = candidate, reports_stats
                break

        if reports_training_stats:
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = handler()
        else:
            name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."

        if best_score is not None and train_time is not None:
            yield Batch(
                pd.DataFrame(
                    [
                        msg,
                        "Validation Score: " + str(best_score),
                        "Training time: " + str(train_time) + " secs.",
                    ]
                )
            )
        else:
            yield Batch(pd.DataFrame([msg]))
833

834
    def _try_initializing_function(
1✔
835
        self, impl_path: str, function_args: Dict = {}
836
    ) -> FunctionCatalogEntry:
837
        """Attempts to initialize function given the implementation file path and arguments.
838

839
        Args:
840
            impl_path (str): The file path of the function implementation file.
841
            function_args (Dict, optional): Dictionary of arguments to pass to the function. Defaults to {}.
842

843
        Returns:
844
            FunctionCatalogEntry: A FunctionCatalogEntry object that represents the initialized function.
845

846
        Raises:
847
            RuntimeError: If an error occurs while initializing the function.
848
        """
849

850
        # load the function class from the file
851
        try:
1✔
852
            # loading the function class from the file
853
            function = load_function_class_from_file(impl_path, self.node.name)
1✔
854
            # initializing the function class calls the setup method internally
855
            function(**function_args)
1✔
856
        except Exception as e:
857
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
858
            # logger.error(err_msg)
859
            raise RuntimeError(err_msg)
860

861
        return function
1✔
862

863
    def _resolve_function_io(
        self, function: type
    ) -> List[FunctionIOCatalogEntry]:
        """Resolve the input/output definitions for a function class.

        Inputs and outputs are resolved independently. Those given in the CREATE
        statement (``self.node.inputs`` / ``self.node.outputs``) take precedence.
        When the statement gives none for a side, that side is loaded from the
        decorators on ``function``.

        Args:
            function (type): The loaded function class, as returned by
                ``_try_initializing_function``.

        Returns:
            List[FunctionIOCatalogEntry]: The input definitions followed by the
            output definitions.

        Raises:
            RuntimeError: If the decorator loader raises
                ``FunctionIODefinitionError``. The error is logged and chained as
                ``__cause__``.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators, the inputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators, the outputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            # Chain the cause so the decorator error's traceback is preserved.
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc