• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / f09e3b44-ef34-4638-a136-6dc1a5893c0a

26 Oct 2023 11:49PM UTC coverage: 92.368% (-0.8%) from 93.203%
f09e3b44-ef34-4638-a136-6dc1a5893c0a

push

circle-ci

web-flow
Merge branch 'georgia-tech-db:master' into master

1152 of 1152 new or added lines in 80 files covered. (100.0%)

11836 of 12814 relevant lines covered (92.37%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.81
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
1✔
16
import os
1✔
17
import pickle
1✔
18
from pathlib import Path
1✔
19
from typing import Dict, List
1✔
20

21
import pandas as pd
1✔
22

23
from evadb.catalog.catalog_utils import get_metadata_properties
1✔
24
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
1✔
25
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
1✔
26
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
1✔
27
from evadb.configuration.constants import (
1✔
28
    DEFAULT_TRAIN_REGRESSION_METRIC,
29
    DEFAULT_TRAIN_TIME_LIMIT,
30
    EvaDB_INSTALLATION_DIR,
31
)
32
from evadb.database import EvaDBDatabase
1✔
33
from evadb.executor.abstract_executor import AbstractExecutor
1✔
34
from evadb.functions.decorators.utils import load_io_from_function_decorators
1✔
35
from evadb.models.storage.batch import Batch
1✔
36
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
1✔
37
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
1✔
38
from evadb.utils.errors import FunctionIODefinitionError
1✔
39
from evadb.utils.generic_utils import (
1✔
40
    load_function_class_from_file,
41
    string_comparison_case_insensitive,
42
    try_to_import_ludwig,
43
    try_to_import_neuralforecast,
44
    try_to_import_sklearn,
45
    try_to_import_statsforecast,
46
    try_to_import_torch,
47
    try_to_import_ultralytics,
48
    try_to_import_xgboost,
49
)
50
from evadb.utils.logging_manager import logger
1✔
51

52

53
class CreateFunctionExecutor(AbstractExecutor):
    """Executor for ``CREATE FUNCTION`` plans.

    ``exec`` dispatches on the plan's function type:

    * HuggingFace, Ultralytics and generic functions are registered directly
      from an implementation file.
    * Ludwig, Sklearn, XGBoost and Forecasting functions first train a model
      on the rows produced by the (single) child executor, save it, and
      register a bundled implementation file plus the saved model's path as
      metadata.

    The resulting entry is inserted into the function catalog.
    """

    # Seasonal period assumed for each pandas frequency alias (e.g. 24 for
    # hourly data, 12 for monthly data); aliases not listed here use 1. See
    # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
    _FORECAST_SEASON_LENGTHS = {
        "H": 24,
        "M": 12,
        "Q": 4,
        "SM": 24,
        "BM": 12,
        "BMS": 12,
        "BQ": 4,
        "BH": 24,
    }

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Directory containing the implementation files shipped with EvaDB.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    def _assert_single_child(self, function_kind: str) -> None:
        """Check that exactly one child executor feeds the training data."""
        assert (
            len(self.children) == 1
        ), "Create {} function expects 1 child, finds {}.".format(
            function_kind, len(self.children)
        )

    def _aggregate_child_batches(self) -> Batch:
        """Concatenate every batch produced by the first child into one Batch.

        Column aliases are dropped afterwards so columns can be looked up by
        the names given in the CREATE statement's metadata.
        """
        child = self.children[0]
        aggregated_batch = Batch.concat(list(child.exec()), copy=False)
        aggregated_batch.drop_column_alias()
        return aggregated_batch

    def _default_model_path(self) -> str:
        """Return ``<storage.model_dir>/<function name>``."""
        return os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )

    @staticmethod
    def _dump_pickle(obj, path: str) -> None:
        """Pickle ``obj`` to ``path``, closing the file even if dumping fails."""
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    # ------------------------------------------------------------------ #
    # Per-type handlers; each returns
    # (name, impl_path, function_type, io_list, metadata)
    # ------------------------------------------------------------------ #

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.
        """
        # We need at least one deep learning framework for HuggingFace
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle ludwig functions

        Use Ludwig's auto_train engine to train/tune models. The best model is
        saved under the model directory and its path is appended to the node
        metadata as ``model_path``.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        self._assert_single_child("ludwig")
        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = self._default_model_path()
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Fit a sklearn ``LinearRegression`` on all non-PREDICT columns, pickle
        it, and record ``model_path`` and ``predict_col`` in the node metadata.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        self._assert_single_child("sklearn")
        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = self._default_model_path()
        self._dump_pickle(model, model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_xgboost_function(self):
        """Handle xgboost functions

        We use the Flaml AutoML model (restricted to the xgboost estimator,
        regression task) for training. The fitted model is pickled and
        ``model_path`` / ``predict_col`` are recorded in the node metadata.
        """
        try_to_import_xgboost()

        self._assert_single_child("xgboost")
        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": "regression",
        }
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        model_path = self._default_model_path()
        self._dump_pickle(model, model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions

        Registers the bundled YOLO object detector, instantiating it once with
        the node's metadata properties to validate the setup.
        """
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions

        Trains (or reuses a cached) statsforecast / neuralforecast model on the
        child's rows. ``CUDA_VISIBLE_DEVICES`` is forced to ``""`` (hiding CUDA
        devices) while this runs; its previous value is restored afterwards,
        including when training raises.
        """
        previous_cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        try:
            return self._create_forecasting_function()
        finally:
            if previous_cuda_devices is None:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = previous_cuda_devices

    def _create_forecasting_function(self):
        """Body of ``handle_forecasting_function`` (environment already set).

        Raises:
            ValueError: HORIZON is missing.
            FunctionIODefinitionError: HORIZON is not integral, LIBRARY or
                MODEL is unsupported.
            RuntimeError: the frequency cannot be inferred, or statsforecast
                is used with AUTO set to false.
        """
        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        horizon = self._parse_forecast_horizon(arg_map)
        library = self._parse_forecast_library(arg_map)

        # statsforecast/neuralforecast require these column names:
        # - unique_id (string, int or category): identifier of the series.
        # - ds (datestamp): a format pandas understands, ideally YYYY-MM-DD
        #   for dates or YYYY-MM-DD HH:MM:SS for timestamps.
        # - y (numeric): the measurement we wish to forecast.
        # Reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        # Without ID / TIME columns, treat the data as a single series with a
        # 1-based integer time index.
        if "unique_id" not in data.columns:
            data["unique_id"] = [1] * len(data)
        if "ds" not in data.columns:
            data["ds"] = list(range(1, len(data) + 1))

        # Set or infer the data frequency.
        if "frequency" not in arg_map or arg_map["frequency"] == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        # Shorten anchored frequencies such as "Q-DEC" to their base alias.
        new_freq = frequency.split("-")[0]
        season_length = self._FORECAST_SEASON_LENGTHS.get(new_freq, 1)

        # Only the neuralforecast branch uses exogenous columns; None means
        # "no exogenous columns considered".
        exogenous_columns = None
        if library == "neuralforecast":
            model, exogenous_columns = self._build_neuralforecast_model(
                arg_map, data, horizon, new_freq
            )
        else:
            model = self._build_statsforecast_model(arg_map, season_length, new_freq)

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
        if exogenous_columns is not None:
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        # Models are cached per (library, model, frequency, exogenous set,
        # content hash of the training data).
        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_path = os.path.join(model_dir, "horizon" + str(horizon) + ".pkl")

        # Cached "horizon<N>.pkl" files with N >= the requested horizon,
        # ordered by N; unrelated files in the directory are ignored.
        cached_models = []
        for file_name in os.listdir(model_dir):
            cached_horizon = self._horizon_from_model_file(file_name)
            if cached_horizon is not None and cached_horizon >= horizon:
                cached_models.append((cached_horizon, file_name))
        cached_models.sort()

        if not cached_models:
            logger.info("Training, please wait...")
            if library == "neuralforecast":
                model.fit(df=data, val_size=horizon)
            else:
                model.fit(df=data[["ds", "y", "unique_id"]])
            self._dump_pickle(model, model_path)
        elif not Path(model_path).exists():
            # Reuse the cached model with the largest sufficient horizon.
            model_path = os.path.join(model_dir, cached_models[-1][1])

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    @staticmethod
    def _parse_forecast_horizon(arg_map: Dict) -> int:
        """Return the mandatory integer HORIZON from the CREATE metadata."""
        if "horizon" not in arg_map:
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            return int(arg_map["horizon"])
        except (TypeError, ValueError, OverflowError) as e:
            err_msg = f"{str(e)}. HORIZON must be integral."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

    @staticmethod
    def _parse_forecast_library(arg_map: Dict) -> str:
        """Return the lower-cased LIBRARY (default ``statsforecast``).

        Uses an explicit check rather than ``assert`` so validation still
        happens when Python runs with ``-O``.
        """
        supported_libraries = ["statsforecast", "neuralforecast"]
        if "library" not in arg_map:
            return "statsforecast"
        library = str(arg_map["library"]).lower()
        if library not in supported_libraries:
            err_msg = (
                "EvaDB currently supports " + str(supported_libraries) + " only."
            )
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)
        return library

    @staticmethod
    def _is_truthy_flag(value) -> bool:
        """Interpret a metadata flag as true when it starts with 't' (e.g. 'True')."""
        return str(value).lower().startswith("t")

    @staticmethod
    def _horizon_from_model_file(file_name: str):
        """Return N for a file named ``horizon<N>.pkl``, otherwise None."""
        prefix, suffix = "horizon", ".pkl"
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            return None
        try:
            return int(file_name[len(prefix) : -len(suffix)])
        except ValueError:
            return None

    @classmethod
    def _build_neuralforecast_model(cls, arg_map: Dict, data, horizon: int, freq: str):
        """Build an unfitted ``NeuralForecast`` model.

        Mutates ``arg_map["model"]`` to the resolved model name (default
        ``NBEATS``; prefixed with ``Auto`` when AUTO is absent or true and the
        name is not already an Auto* variant).

        Returns:
            (model, exogenous_columns) where ``exogenous_columns`` is the list
            of columns other than ds/y/unique_id when the data has at least 4
            columns, else None.
        """
        try_to_import_neuralforecast()
        from neuralforecast import NeuralForecast
        from neuralforecast.auto import AutoNBEATS, AutoNHITS
        from neuralforecast.models import NBEATS, NHITS

        model_dict = {
            "AutoNBEATS": AutoNBEATS,
            "AutoNHITS": AutoNHITS,
            "NBEATS": NBEATS,
            "NHITS": NHITS,
        }

        if "model" not in arg_map:
            arg_map["model"] = "NBEATS"

        # AUTO defaults to true. Never double-prefix a name that is already
        # an Auto* model (e.g. MODEL 'AutoNBEATS' without AUTO).
        wants_auto = "auto" not in arg_map or cls._is_truthy_flag(arg_map["auto"])
        if wants_auto and "auto" not in arg_map["model"].lower():
            arg_map["model"] = "Auto" + arg_map["model"]

        try:
            model_here = model_dict[arg_map["model"]]
        except KeyError as e:
            err_msg = "Supported models: " + str(model_dict.keys())
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        hyperparams = {"input_size": 2 * horizon, "early_stop_patience_steps": 20}

        exogenous_columns = None
        if len(data.columns) >= 4:
            # Extra columns are passed as historical exogenous features.
            exogenous_columns = [
                x for x in data.columns if x not in ["ds", "y", "unique_id"]
            ]
            hyperparams["hist_exog_list"] = exogenous_columns

        # Auto* models receive these settings nested under "config"; plain
        # models receive them directly as keyword arguments.
        if "auto" in arg_map["model"].lower():
            model_args = {"config": hyperparams}
        else:
            model_args = dict(hyperparams)
        model_args["h"] = horizon

        model = NeuralForecast([model_here(**model_args)], freq=freq)
        return model, exogenous_columns

    @classmethod
    def _build_statsforecast_model(cls, arg_map: Dict, season_length: int, freq: str):
        """Build an unfitted ``StatsForecast`` model.

        Mutates ``arg_map["model"]`` to the resolved Auto* model name
        (default ``AutoARIMA``).
        """
        if "auto" in arg_map and not cls._is_truthy_flag(arg_map["auto"]):
            raise RuntimeError(
                "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
            )
        try_to_import_statsforecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        if "model" not in arg_map:
            arg_map["model"] = "ARIMA"

        if "auto" not in arg_map["model"].lower():
            arg_map["model"] = "Auto" + arg_map["model"]

        try:
            model_here = model_dict[arg_map["model"]]
        except KeyError as e:
            err_msg = "Supported models: " + str(model_dict.keys())
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        return StatsForecast([model_here(season_length=season_length)], freq=freq)

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor

        Builds the function entry for the plan's function type and inserts it
        into the catalog, yielding one Batch with a status message.

        If a function with the same name already exists:
        * IF NOT EXISTS: yield an "already exists" message and add nothing;
        * OR REPLACE: drop the existing entry first (best effort);
        * otherwise: raise RuntimeError.
        """
        # Written as ``not (...)`` so that falsy non-bool values (e.g. None)
        # do not trip the check.
        assert not (
            self.node.if_not_exists and self.node.or_replace
        ), "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."

        overwrite = False
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError as e:
                    # Best effort: continue with the insert, but leave a trace.
                    logger.warning(
                        f"Failed to drop existing function {self.node.name}: {str(e)}"
                    )
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Checked in order; the first case-insensitive match wins and any
        # unlisted type falls back to the generic file-based handler.
        handlers = [
            ("HuggingFace", self.handle_huggingface_function),
            ("ultralytics", self.handle_ultralytics_function),
            ("Ludwig", self.handle_ludwig_function),
            ("Sklearn", self.handle_sklearn_function),
            ("XGBoost", self.handle_xgboost_function),
            ("Forecasting", self.handle_forecasting_function),
        ]
        handler = next(
            (
                handle
                for type_name, handle in handlers
                if string_comparison_case_insensitive(
                    self.node.function_type, type_name
                )
            ),
            self.handle_generic_function,
        )
        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = None
    ) -> FunctionCatalogEntry:
        """Load the function class from ``impl_path`` and instantiate it once.

        Instantiation runs the class's setup, so configuration errors surface
        at CREATE time. The instance itself is discarded.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Keyword arguments passed to the
                constructor. Defaults to no arguments.

        Returns:
            The object returned by ``load_function_class_from_file`` (called
            like a class here). NOTE(review): the return annotation says
            FunctionCatalogEntry, which does not match this value — confirm
            and fix the annotation.

        Raises:
            RuntimeError: If loading or initializing the function fails.
        """
        if function_args is None:
            function_args = {}
        try:
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            raise RuntimeError(err_msg) from e

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Resolve the input/output definitions for a function.

        Inputs and outputs declared in the CREATE statement take precedence;
        whichever side is missing is loaded from the function's decorators.

        Args:
            function: The loaded function (as returned by
                ``_try_initializing_function``), or None for trained-model
                handlers, which pass None here.

        Returns:
            A list of FunctionIOCatalogEntry objects: inputs first, then outputs.

        Raises:
            RuntimeError: If the input/output definitions are invalid.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators, the inputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators, the outputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc