georgia-tech-db / eva, build #846 (push, circle-ci)
02 Nov 2023 03:37AM UTC, coverage: 72.419% (+0.05%) from 72.365%
Jineet Desai: Convert variables to str
9421 of 13009 relevant lines covered (72.42%), 0.72 hits per line

Source File: /evadb/executor/create_function_executor.py (31.25% of lines covered)
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import hashlib
import locale
import os
import pickle
import re
from pathlib import Path
from typing import Dict, List

import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
    DEFAULT_TRAIN_REGRESSION_METRIC,
    DEFAULT_TRAIN_TIME_LIMIT,
    DEFAULT_XGBOOST_TASK,
    EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.functions.decorators.utils import load_io_from_function_decorators
from evadb.models.storage.batch import Batch
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
from evadb.utils.errors import FunctionIODefinitionError
from evadb.utils.generic_utils import (
    load_function_class_from_file,
    string_comparison_case_insensitive,
    try_to_import_ludwig,
    try_to_import_neuralforecast,
    try_to_import_sklearn,
    try_to_import_statsforecast,
    try_to_import_torch,
    try_to_import_ultralytics,
    try_to_import_xgboost,
)
from evadb.utils.logging_manager import logger


# From https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
    """
    Temporarily set the process environment variables.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, unicode]
    :param environ: Environment variables to set
    """
    old_environ = dict(os.environ)
    os.environ.update(environ)
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(old_environ)


class CreateFunctionExecutor(AbstractExecutor):
    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.
        """
        # We need at least one deep learning framework for HuggingFace,
        # Torch or Tensorflow.
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
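
    # Example (a sketch, not part of this module): a CREATE FUNCTION statement
    # that routes through handle_huggingface_function. The function name and
    # model are placeholders, and the exact TASK/MODEL key-value surface syntax
    # has varied across EvaDB releases.
    #
    #   CREATE FUNCTION IF NOT EXISTS HFObjectDetector
    #   TYPE HuggingFace
    #   TASK 'object-detection'
    #   MODEL 'facebook/detr-resnet-50';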

    def handle_ludwig_function(self):
        """Handle Ludwig functions

        Use Ludwig's auto_train engine to train/tune models.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, found {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.catalog().get_configuration_catalog_value(
                "tmp_dir"
            ),
        )
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
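
    # Example (a sketch; the table and column names are placeholders): a
    # statement that reaches handle_ludwig_function. PREDICT supplies
    # arg_map["predict"] and TIME_LIMIT overrides DEFAULT_TRAIN_TIME_LIMIT.
    #
    #   CREATE FUNCTION IF NOT EXISTS PredictHouseRent FROM
    #       (SELECT * FROM HomeRentals)
    #   TYPE Ludwig
    #   PREDICT 'rental_price'
    #   TIME_LIMIT 120;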

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Use sklearn's LinearRegression to train models.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, found {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
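
    # Example (a sketch; the table and column names are placeholders): a
    # statement that reaches handle_sklearn_function. All columns other than
    # the PREDICT column are used as features.
    #
    #   CREATE FUNCTION IF NOT EXISTS PredictRentSklearn FROM
    #       (SELECT sqft, rental_price FROM HomeRentals)
    #   TYPE Sklearn
    #   PREDICT 'rental_price';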

    def convert_to_numeric(self, x):
        """Strip non-numeric characters from x and parse the remainder as a
        number using the current locale, returning an int when the value is
        integral and a float otherwise."""
        x = re.sub("[^0-9.,]", "", str(x))
        locale.setlocale(locale.LC_ALL, "")
        x = float(locale.atof(x))
        if x.is_integer():
            return int(x)
        else:
            return x
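
    # Sample behavior of convert_to_numeric (an assumption: an en_US-style
    # locale, where "," is the thousands separator):
    #   convert_to_numeric("$1,234.00") -> 1234 (int, since the value is integral)
    #   convert_to_numeric("3.5s")      -> 3.5  (float)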

    def handle_xgboost_function(self):
        """Handle XGBoost functions

        We use Flaml's AutoML engine to train XGBoost models.
        """
        try_to_import_xgboost()

        assert (
            len(self.children) == 1
        ), "Create xgboost function expects 1 child, found {}.".format(
            len(self.children)
        )

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        best_score = model.best_loss
        train_time = model.best_config_train_time
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
            best_score,
            train_time,
        )
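
    # Example (a sketch; names and key spellings are placeholders): a statement
    # handled by handle_xgboost_function. METRIC and TASK fall back to
    # DEFAULT_TRAIN_REGRESSION_METRIC and DEFAULT_XGBOOST_TASK when omitted.
    #
    #   CREATE FUNCTION IF NOT EXISTS PredictRent FROM
    #       (SELECT sqft, rental_price FROM HomeRentals)
    #   TYPE XGBoost
    #   PREDICT 'rental_price'
    #   TIME_LIMIT 180
    #   METRIC 'r2'
    #   TASK 'regression';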

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions"""
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
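
    # Example (a sketch; the weight file name is a placeholder): ultralytics
    # functions are initialized eagerly above, so a bad MODEL value fails at
    # CREATE time rather than at query time.
    #
    #   CREATE FUNCTION IF NOT EXISTS Yolo
    #   TYPE ultralytics
    #   MODEL 'yolov8m.pt';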

    def handle_forecasting_function(self):
        """Handle forecasting functions"""
        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()
        library = "statsforecast"
        supported_libraries = ["statsforecast", "neuralforecast"]

        if "horizon" not in arg_map.keys():
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except Exception as e:
            err_msg = f"{str(e)}. HORIZON must be an integer."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        if "library" in arg_map.keys():
            try:
                assert arg_map["library"].lower() in supported_libraries
            except Exception:
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            library = arg_map["library"].lower()

        # The following renames are needed for statsforecast/neuralforecast,
        # which require the columns to be named as follows:
        # - unique_id (string, int or category): an identifier for the series.
        # - ds (datestamp): a column in a format expected by Pandas, ideally
        #   YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
        # - y (numeric): the measurement we wish to forecast.
        # For reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map.keys():
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map.keys():
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        if "unique_id" not in list(data.columns):
            data["unique_id"] = [1 for x in range(len(data))]

        if "ds" not in list(data.columns):
            data["ds"] = [x + 1 for x in range(len(data))]

        # Set or infer the data frequency.
        if "frequency" not in arg_map.keys() or arg_map["frequency"] == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Cannot infer the frequency for {self.node.name}. Please set it explicitly."
            )

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        new_freq = (
            frequency.split("-")[0] if "-" in frequency else frequency
        )  # shortens longer frequencies like Q-DEC
        season_length = season_dict[new_freq] if new_freq in season_dict else 1
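        # For example, an inferred frequency of "Q-DEC" shortens to "Q" above,
        # which maps to season_length = 4 (quarterly seasonality); any alias
        # missing from season_dict falls back to a season_length of 1.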

        # Neuralforecast implementation
        if library == "neuralforecast":
            try_to_import_neuralforecast()
            from neuralforecast import NeuralForecast
            from neuralforecast.auto import AutoNBEATS, AutoNHITS
            from neuralforecast.models import NBEATS, NHITS

            model_dict = {
                "AutoNBEATS": AutoNBEATS,
                "AutoNHITS": AutoNHITS,
                "NBEATS": NBEATS,
                "NHITS": NHITS,
            }

            if "model" not in arg_map.keys():
                arg_map["model"] = "NBEATS"

            if "auto" not in arg_map.keys() or (
                arg_map["auto"].lower()[0] == "t"
                and "auto" not in arg_map["model"].lower()
            ):
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except Exception:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)
            model_args = {}

            if "auto" not in arg_map["model"].lower():
                model_args["input_size"] = 2 * horizon
                model_args["early_stop_patience_steps"] = 20
            else:
                model_args_config = {
                    "input_size": 2 * horizon,
                    "early_stop_patience_steps": 20,
                }

            if len(data.columns) >= 4:
                exogenous_columns = [
                    x for x in list(data.columns) if x not in ["ds", "y", "unique_id"]
                ]
                if "auto" not in arg_map["model"].lower():
                    model_args["hist_exog_list"] = exogenous_columns
                else:
                    model_args_config["hist_exog_list"] = exogenous_columns

                    def get_optuna_config(trial):
                        return model_args_config

                    model_args["config"] = get_optuna_config
                    model_args["backend"] = "optuna"

            model_args["h"] = horizon

            model = NeuralForecast(
                [model_here(**model_args)],
                freq=new_freq,
            )

        # Statsforecast implementation
        else:
            if "auto" in arg_map.keys() and arg_map["auto"].lower()[0] != "t":
                raise RuntimeError(
                    "The statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
                )
            try_to_import_statsforecast()
            from statsforecast import StatsForecast
            from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

            model_dict = {
                "AutoARIMA": AutoARIMA,
                "AutoCES": AutoCES,
                "AutoETS": AutoETS,
                "AutoTheta": AutoTheta,
            }

            if "model" not in arg_map.keys():
                arg_map["model"] = "ARIMA"

            if "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except Exception:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

            model = StatsForecast(
                [model_here(season_length=season_length)], freq=new_freq
            )

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
        if len(data.columns) >= 4 and library == "neuralforecast":
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        model_dir = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_save_name = "horizon" + str(horizon) + ".pkl"

        model_path = os.path.join(model_dir, model_save_name)

        existing_model_files = sorted(
            os.listdir(model_dir),
            key=lambda x: int(x.split("horizon")[1].split(".pkl")[0]),
        )
        existing_model_files = [
            x
            for x in existing_model_files
            if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon
        ]
        if len(existing_model_files) == 0:
            logger.info("Training, please wait...")
            for column in data.columns:
                if column != "ds" and column != "unique_id":
                    data[column] = data.apply(
                        lambda x: self.convert_to_numeric(x[column]), axis=1
                    )
            if library == "neuralforecast":
                cuda_devices_here = "0"
                if "CUDA_VISIBLE_DEVICES" in os.environ:
                    cuda_devices_here = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]

                with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
                    model.fit(df=data, val_size=horizon)
                    model.save(model_path, overwrite=True)
            else:
                # The following lines help eliminate the math error encountered
                # in statsforecast when only one data point is available in a
                # time series: duplicate any single-row series.
                for col in data["unique_id"].unique():
                    if len(data[data["unique_id"] == col]) == 1:
                        data = data._append(
                            [data[data["unique_id"] == col]], ignore_index=True
                        )

                model.fit(df=data[["ds", "y", "unique_id"]])
                with open(model_path, "wb") as f:
                    pickle.dump(model, f)
        elif not Path(model_path).exists():
            model_path = os.path.join(model_dir, existing_model_files[-1])

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )
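
    # Example (a sketch; table and column names are placeholders): a statement
    # handled by handle_forecasting_function. HORIZON is required; LIBRARY and
    # MODEL default to statsforecast and (Auto)ARIMA as implemented above.
    #
    #   CREATE FUNCTION IF NOT EXISTS AirForecast FROM
    #       (SELECT unique_id, ds, y FROM AirData)
    #   TYPE Forecasting
    #   HORIZON 12
    #   PREDICT 'y';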

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by
        the user during CREATE or try to load the io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
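
    # Example (a sketch; the path is a placeholder): generic functions are the
    # default branch in exec() below and must supply an IMPL file.
    #
    #   CREATE FUNCTION IF NOT EXISTS MnistImageClassifier
    #   IMPL 'evadb/functions/mnist_image_classifier.py';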

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS cannot both be set for CREATE FUNCTION."
        )

        overwrite = False
        best_score = False
        train_time = False
        # Check whether the catalog already has this function entry.
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid duplicating the bookkeeping
                # code. The drop function should be moved to the catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError:
                    pass
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # If it is a HuggingFace model, override the impl_path.
        if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_huggingface_function()
        elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_ultralytics_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_ludwig_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Sklearn"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_sklearn_function()
        elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
                best_score,
                train_time,
            ) = self.handle_xgboost_function()
        elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_forecasting_function()
        else:
            (
                name,
                impl_path,
                function_type,
                io_list,
                metadata,
            ) = self.handle_generic_function()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        if best_score and train_time:
            yield Batch(
                pd.DataFrame(
                    [
                        msg,
                        "Validation Score: " + str(best_score),
                        "Training time: " + str(train_time),
                    ]
                )
            )
        else:
            yield Batch(pd.DataFrame([msg]))
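
    # A minimal end-to-end sketch of driving this executor through EvaDB's
    # Python API (assuming the evadb package is installed; the data source and
    # function names are placeholders):
    #
    #   import evadb
    #
    #   cursor = evadb.connect().cursor()
    #   result = cursor.query(
    #       """
    #       CREATE OR REPLACE FUNCTION PredictRent FROM
    #           (SELECT sqft, rental_price FROM HomeRentals)
    #       TYPE XGBoost
    #       PREDICT 'rental_price';
    #       """
    #   ).df()
    #   print(result)  # one-row DataFrame with the status message yielded above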

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = {}
    ) -> FunctionCatalogEntry:
        """Attempts to initialize the function, given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Dictionary of arguments to pass to the function. Defaults to {}.

        Returns:
            FunctionCatalogEntry: A FunctionCatalogEntry object that represents the initialized function.

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        try:
            # Load the function class from the implementation file.
            function = load_function_class_from_file(impl_path, self.node.name)
            # Initializing the function class calls the setup method internally.
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            raise RuntimeError(err_msg)

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.

        It first searches for the input/output definitions in the CREATE statement.
        If not found, it resolves them using decorators. If not found there either,
        it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve input and output definitions.

        Returns:
            A list of FunctionIOCatalogEntry objects that represent the resolved input
            and output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # Try to load the inputs from decorators; inputs from the CREATE
                # statement take precedence.
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # Try to load the outputs from decorators; outputs from the CREATE
                # statement take precedence.
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        return io_list