georgia-tech-db / eva, build #848 (push, via circleci)
09 Nov 2023 05:17AM UTC. Coverage: 71.393% (down 0.07% from 71.463%)
Commit by Jineet Desai: Add changes to training times. Add scores and training times for sklearn LR as well.
9294 of 13018 relevant lines covered (71.39%), 0.71 hits per line
Source file: /evadb/executor/create_function_executor.py (20.61% of lines covered)

# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import hashlib
import locale
import os
import pickle
import re
import time
from pathlib import Path
from typing import Dict, List

import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
    DEFAULT_TRAIN_REGRESSION_METRIC,
    DEFAULT_TRAIN_TIME_LIMIT,
    DEFAULT_XGBOOST_TASK,
    EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.functions.decorators.utils import load_io_from_function_decorators
from evadb.models.storage.batch import Batch
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
from evadb.utils.errors import FunctionIODefinitionError
from evadb.utils.generic_utils import (
    load_function_class_from_file,
    string_comparison_case_insensitive,
    try_to_import_ludwig,
    try_to_import_neuralforecast,
    try_to_import_sklearn,
    try_to_import_statsforecast,
    try_to_import_torch,
    try_to_import_ultralytics,
    try_to_import_xgboost,
)
from evadb.utils.logging_manager import logger

# From https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
    """
    Temporarily set the process environment variables.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, unicode]
    :param environ: Environment variables to set
    """
    old_environ = dict(os.environ)
    os.environ.update(environ)
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)
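
# NOTE: set_env mutates the process-wide os.environ, so it is not safe to use
# from multiple threads at once. It is used further below to pin
# neuralforecast training to a single GPU via CUDA_VISIBLE_DEVICES.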

class CreateFunctionExecutor(AbstractExecutor):
    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a
        file, so we do not need to call the setup method on them as we do for
        other functions.
        """
        # HuggingFace requires at least one deep learning framework
        # (Torch or TensorFlow).
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
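
    # Illustrative only -- a statement that would take this path (the task and
    # model values are examples; see the EvaDB docs for the supported
    # HuggingFace tasks):
    #
    #   CREATE FUNCTION IF NOT EXISTS HFObjectDetector
    #   TYPE HuggingFace
    #   TASK 'object-detection'
    #   MODEL 'facebook/detr-resnet-50';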

    def handle_ludwig_function(self):
        """Handle Ludwig functions

        Use Ludwig's auto_train engine to train/tune models.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        start_time = int(time.time())
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.catalog().get_configuration_catalog_value(
                "tmp_dir"
            ),
        )
        train_time = int(time.time()) - start_time
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        auto_train_results.best_model.save(model_path)
        best_score = auto_train_results.experiment_analysis.best_result["metric_score"]
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )
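
    # Illustrative only -- a statement that would take this path (table and
    # column names are hypothetical):
    #
    #   CREATE FUNCTION IF NOT EXISTS PredictHouseRent FROM
    #       ( SELECT * FROM HomeRentals )
    #   TYPE Ludwig
    #   PREDICT 'rental_price'
    #   TIME_LIMIT 120;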

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Use sklearn's linear regression to train models.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        start_time = int(time.time())
        model.fit(X=aggregated_batch.frames, y=Y)
        train_time = int(time.time()) - start_time
        score = model.score(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            score,
            train_time,
        )
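
    # Note that model.score above is evaluated on the same frames the model
    # was fit on, so the reported value is a training R^2, not a held-out
    # validation score. Illustrative usage (hypothetical table/column):
    #
    #   CREATE FUNCTION IF NOT EXISTS PredictRent FROM
    #       ( SELECT * FROM HomeRentals )
    #   TYPE Sklearn
    #   PREDICT 'rental_price';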

    def convert_to_numeric(self, x):
        # Strip everything except digits, decimal points, and thousands
        # separators, parse the remainder with the current locale, and
        # downcast whole-number floats to int.
        x = re.sub("[^0-9.,]", "", str(x))
        locale.setlocale(locale.LC_ALL, "")
        x = float(locale.atof(x))
        if x.is_integer():
            return int(x)
        else:
            return x
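
    # For example, under an en_US locale convert_to_numeric("$1,234.50")
    # strips the "$", parses "1,234.50" via locale.atof to 1234.5, and returns
    # the float; convert_to_numeric("2,000") parses to 2000.0 and is returned
    # as the int 2000. (The result depends on the process locale.)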

    def handle_xgboost_function(self):
        """Handle xgboost functions

        We use the Flaml AutoML model for training xgboost models.
        """
        try_to_import_xgboost()

        assert (
            len(self.children) == 1
        ), "Create xgboost function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        start_time = int(time.time())
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        train_time = int(time.time()) - start_time
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        best_score = model.best_loss
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )
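
    # Illustrative only -- a statement that would take this path; METRIC and
    # TASK fall back to DEFAULT_TRAIN_REGRESSION_METRIC / DEFAULT_XGBOOST_TASK
    # when omitted (table and column names are hypothetical):
    #
    #   CREATE FUNCTION IF NOT EXISTS PredictPrice FROM
    #       ( SELECT * FROM HomeSales )
    #   TYPE XGBoost
    #   PREDICT 'price'
    #   TIME_LIMIT 120
    #   METRIC 'r2';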

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions"""
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
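
    # Illustrative only -- a statement that would take this path (the MODEL
    # property reaches the function via get_metadata_properties):
    #
    #   CREATE FUNCTION IF NOT EXISTS Yolo
    #   TYPE ultralytics
    #   MODEL 'yolov8m.pt';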

    def handle_forecasting_function(self):
        """Handle forecasting functions"""
        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()
        library = "statsforecast"
        supported_libraries = ["statsforecast", "neuralforecast"]

        if "horizon" not in arg_map.keys():
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except Exception as e:
            err_msg = f"{str(e)}. HORIZON must be an integer."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        if "library" in arg_map.keys():
            try:
                assert arg_map["library"].lower() in supported_libraries
            except Exception:
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            library = arg_map["library"].lower()
        """
×
346
        The following rename is needed for statsforecast/neuralforecast, which requires the column name to be the following:
347
        - The unique_id (string, int or category) represents an identifier for the series.
348
        - The ds (datestamp) column should be of a format expected by Pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
349
        - The y (numeric) represents the measurement we wish to forecast.
350
        For reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
351
        """
352
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
×
353
        if "time" in arg_map.keys():
×
354
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
×
355
        if "id" in arg_map.keys():
×
356
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})
×
357

358
        data = aggregated_batch.frames
×
359
        if "unique_id" not in list(data.columns):
×
360
            data["unique_id"] = [1 for x in range(len(data))]
×
361

362
        if "ds" not in list(data.columns):
×
363
            data["ds"] = [x + 1 for x in range(len(data))]
×
364

365
        """
×
366
            Set or infer data frequency
367
        """
368

369
        if "frequency" not in arg_map.keys() or arg_map["frequency"] == "auto":
×
370
            arg_map["frequency"] = pd.infer_freq(data["ds"])
×
371
        frequency = arg_map["frequency"]
×
372
        if frequency is None:
×
373
            raise RuntimeError(
374
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
375
            )
376

377
        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
×
378
            "H": 24,
379
            "M": 12,
380
            "Q": 4,
381
            "SM": 24,
382
            "BM": 12,
383
            "BMS": 12,
384
            "BQ": 4,
385
            "BH": 24,
386
        }
387

388
        new_freq = (
×
389
            frequency.split("-")[0] if "-" in frequency else frequency
390
        )  # shortens longer frequencies like Q-DEC
391
        season_length = season_dict[new_freq] if new_freq in season_dict else 1
×
392

393
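
        # Example: an inferred frequency of "Q-DEC" (quarterly, December
        # year-end) is shortened to "Q", giving season_length = 4; aliases
        # not in season_dict fall back to season_length = 1.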
        """
×
394
            Neuralforecast implementation
395
        """
396
        if library == "neuralforecast":
×
397
            try_to_import_neuralforecast()
×
398
            from neuralforecast import NeuralForecast
×
399
            from neuralforecast.auto import AutoNBEATS, AutoNHITS
×
400
            from neuralforecast.models import NBEATS, NHITS
×
401

402
            model_dict = {
×
403
                "AutoNBEATS": AutoNBEATS,
404
                "AutoNHITS": AutoNHITS,
405
                "NBEATS": NBEATS,
406
                "NHITS": NHITS,
407
            }
408

409
            if "model" not in arg_map.keys():
×
410
                arg_map["model"] = "NBEATS"
×
411

412
            if "auto" not in arg_map.keys() or (
×
413
                arg_map["auto"].lower()[0] == "t"
414
                and "auto" not in arg_map["model"].lower()
415
            ):
416
                arg_map["model"] = "Auto" + arg_map["model"]
×
417

418
            try:
×
419
                model_here = model_dict[arg_map["model"]]
×
420
            except Exception:
421
                err_msg = "Supported models: " + str(model_dict.keys())
422
                logger.error(err_msg)
423
                raise FunctionIODefinitionError(err_msg)
424
            model_args = {}
×
425

426
            if "auto" not in arg_map["model"].lower():
×
427
                model_args["input_size"] = 2 * horizon
×
428
                model_args["early_stop_patience_steps"] = 20
×
429
            else:
430
                model_args_config = {
×
431
                    "input_size": 2 * horizon,
432
                    "early_stop_patience_steps": 20,
433
                }
434

435
            if len(data.columns) >= 4:
×
436
                exogenous_columns = [
×
437
                    x for x in list(data.columns) if x not in ["ds", "y", "unique_id"]
438
                ]
439
                if "auto" not in arg_map["model"].lower():
×
440
                    model_args["hist_exog_list"] = exogenous_columns
×
441
                else:
442
                    model_args_config["hist_exog_list"] = exogenous_columns
×
443

444
                    def get_optuna_config(trial):
×
445
                        return model_args_config
×
446

447
                    model_args["config"] = get_optuna_config
×
448
                    model_args["backend"] = "optuna"
×
449

450
            model_args["h"] = horizon
×
451

452
            model = NeuralForecast(
×
453
                [model_here(**model_args)],
454
                freq=new_freq,
455
            )
456

457
        # """
458
        #     Statsforecast implementation
459
        # """
460
        else:
461
            if "auto" in arg_map.keys() and arg_map["auto"].lower()[0] != "t":
×
462
                raise RuntimeError(
463
                    "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
464
                )
465
            try_to_import_statsforecast()
×
466
            from statsforecast import StatsForecast
×
467
            from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta
×
468

469
            model_dict = {
×
470
                "AutoARIMA": AutoARIMA,
471
                "AutoCES": AutoCES,
472
                "AutoETS": AutoETS,
473
                "AutoTheta": AutoTheta,
474
            }
475

476
            if "model" not in arg_map.keys():
×
477
                arg_map["model"] = "ARIMA"
×
478

479
            if "auto" not in arg_map["model"].lower():
×
480
                arg_map["model"] = "Auto" + arg_map["model"]
×
481

482
            try:
×
483
                model_here = model_dict[arg_map["model"]]
×
484
            except Exception:
485
                err_msg = "Supported models: " + str(model_dict.keys())
486
                logger.error(err_msg)
487
                raise FunctionIODefinitionError(err_msg)
488

489
            model = StatsForecast(
×
490
                [model_here(season_length=season_length)], freq=new_freq
491
            )
492

493
        data["ds"] = pd.to_datetime(data["ds"])
×
494

495
        model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
×
496
        if len(data.columns) >= 4 and library == "neuralforecast":
×
497
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))
×
498

499
        model_dir = os.path.join(
×
500
            self.db.catalog().get_configuration_catalog_value("model_dir"),
501
            "tsforecasting",
502
            model_save_dir_name,
503
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
504
        )
505
        Path(model_dir).mkdir(parents=True, exist_ok=True)
×
506

507
        model_save_name = "horizon" + str(horizon) + ".pkl"
×
508

509
        model_path = os.path.join(model_dir, model_save_name)
×
510

511
        existing_model_files = sorted(
×
512
            os.listdir(model_dir),
513
            key=lambda x: int(x.split("horizon")[1].split(".pkl")[0]),
514
        )
515
        existing_model_files = [
×
516
            x
517
            for x in existing_model_files
518
            if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon
519
        ]
520
        if len(existing_model_files) == 0:
×
521
            logger.info("Training, please wait...")
×
522
            for column in data.columns:
×
523
                if column != "ds" and column != "unique_id":
×
524
                    data[column] = data.apply(
×
525
                        lambda x: self.convert_to_numeric(x[column]), axis=1
526
                    )
527
            if library == "neuralforecast":
×
528
                cuda_devices_here = "0"
×
529
                if "CUDA_VISIBLE_DEVICES" in os.environ:
×
530
                    cuda_devices_here = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]
×
531

532
                with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
×
533
                    model.fit(df=data, val_size=horizon)
×
534
                    model.save(model_path, overwrite=True)
×
535
            else:
536
                # The following lines of code helps eliminate the math error encountered in statsforecast when only one datapoint is available in a time series
537
                for col in data["unique_id"].unique():
×
538
                    if len(data[data["unique_id"] == col]) == 1:
×
539
                        data = data._append(
×
540
                            [data[data["unique_id"] == col]], ignore_index=True
541
                        )
542

543
                model.fit(df=data[["ds", "y", "unique_id"]])
×
544
                f = open(model_path, "wb")
×
545
                pickle.dump(model, f)
×
546
                f.close()
×
547
        elif not Path(model_path).exists():
×
548
            model_path = os.path.join(model_dir, existing_model_files[-1])
×
549

550
        io_list = self._resolve_function_io(None)
×
551

552

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )
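
    # Illustrative only -- a statement that would take this path (table and
    # column names are hypothetical; LIBRARY defaults to statsforecast):
    #
    #   CREATE FUNCTION IF NOT EXISTS AirForecast FROM
    #       ( SELECT saledate, ma FROM AirData )
    #   TYPE Forecasting
    #   PREDICT 'ma'
    #   TIME 'saledate'
    #   HORIZON 12;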

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed
        by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
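
    # Illustrative only -- a generic function is created from a user-supplied
    # implementation file (the path here is hypothetical):
    #
    #   CREATE FUNCTION IF NOT EXISTS MyUpper
    #   IMPL './my_upper.py';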

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS cannot both be set for CREATE FUNCTION."
        )

        overwrite = False
        best_score = False
        train_time = False
        # Check whether the catalog already has this function entry.
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid duplicating the
                # bookkeeping code. The drop logic should eventually move into
                # the catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError:
                    pass
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # If it is a HuggingFace model, override the impl_path.
        if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_huggingface_function()
        elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_ultralytics_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_ludwig_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Sklearn"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_sklearn_function()
        elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_xgboost_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_forecasting_function()
        else:
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_generic_function()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        if best_score and train_time:
            yield Batch(
                pd.DataFrame(
                    [
                        msg,
                        "Validation Score: " + str(best_score),
                        "Training time: " + str(train_time) + " secs.",
                    ]
                )
            )
        else:
            yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = {}
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Dictionary of arguments to pass to the function. Defaults to {}.

        Returns:
            FunctionCatalogEntry: A FunctionCatalogEntry object that represents the initialized function.

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """

        try:
            # Load the function class from the implementation file.
            function = load_function_class_from_file(impl_path, self.node.name)
            # Initializing the function class calls the setup method internally.
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.
        It first searches for the input/output definitions in the CREATE statement. If not found, it resolves them using decorators. If they are not found there either, it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve input and output definitions.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # Try to load the inputs from decorators; the inputs from the
                # CREATE statement take precedence.
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # Try to load the outputs from decorators; the outputs from the
                # CREATE statement take precedence.
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        return io_list