• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / 7c1c5cb7-cb1e-4c69-985c-b642250357b6

28 Oct 2023 10:44PM UTC coverage: 76.956% (-1.7%) from 78.704%
7c1c5cb7-cb1e-4c69-985c-b642250357b6

push

circle-ci

xzdandy
Fix list

8990 of 11682 relevant lines covered (76.96%)

1.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

30.85
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import contextlib
2✔
16
import hashlib
2✔
17
import locale
2✔
18
import os
2✔
19
import pickle
2✔
20
import re
2✔
21
from pathlib import Path
2✔
22
from typing import Dict, List
2✔
23

24
import pandas as pd
2✔
25

26
from evadb.catalog.catalog_utils import get_metadata_properties
2✔
27
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
2✔
28
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
2✔
29
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
2✔
30
from evadb.configuration.constants import (
2✔
31
    DEFAULT_TRAIN_REGRESSION_METRIC,
32
    DEFAULT_TRAIN_TIME_LIMIT,
33
    DEFAULT_XGBOOST_TASK,
34
    EvaDB_INSTALLATION_DIR,
35
)
36
from evadb.database import EvaDBDatabase
2✔
37
from evadb.executor.abstract_executor import AbstractExecutor
2✔
38
from evadb.functions.decorators.utils import load_io_from_function_decorators
2✔
39
from evadb.models.storage.batch import Batch
2✔
40
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
2✔
41
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
2✔
42
from evadb.utils.errors import FunctionIODefinitionError
2✔
43
from evadb.utils.generic_utils import (
2✔
44
    load_function_class_from_file,
45
    string_comparison_case_insensitive,
46
    try_to_import_ludwig,
47
    try_to_import_neuralforecast,
48
    try_to_import_sklearn,
49
    try_to_import_statsforecast,
50
    try_to_import_torch,
51
    try_to_import_ultralytics,
52
    try_to_import_xgboost,
53
)
54
from evadb.utils.logging_manager import logger
2✔
55

56

57
# From https://stackoverflow.com/a/34333710
58
@contextlib.contextmanager
def set_env(**environ):
    """
    Context manager that overrides process environment variables for the
    duration of a ``with`` block.

    On exit — normal or via an exception — the environment is reset to an
    exact snapshot taken on entry: overridden variables regain their old
    values and variables that did not exist before are removed again.

    >>> with set_env(PLUGINS_DIR='test/plugins'):
    ...   "PLUGINS_DIR" in os.environ
    True

    >>> "PLUGINS_DIR" in os.environ
    False

    :type environ: dict[str, unicode]
    :param environ: Environment variables to set
    """
    snapshot = os.environ.copy()
    os.environ.update(environ)
    try:
        yield
    finally:
        # Wipe everything first so variables added inside the block vanish,
        # then restore the snapshot wholesale.
        os.environ.clear()
        os.environ.update(snapshot)
80

81

82
class CreateFunctionExecutor(AbstractExecutor):
    """Executor for ``CREATE FUNCTION`` statements.

    Depending on ``node.function_type`` it either trains a model from the rows
    produced by its child plan (Ludwig, Sklearn, XGBoost, Forecasting) or loads
    a function implementation (HuggingFace, Ultralytics, a generic file), and
    then inserts the resulting function entry into the catalog.
    """

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Directory holding the function implementations bundled with EvaDB.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        # We need at least one deep learning framework for HuggingFace
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle ludwig functions

        Use Ludwig's auto_train engine to train/tune models on the rows produced
        by the single child plan, save the best model under the configured
        ``model_dir`` and record its path in the node metadata.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = list(self.children[0].exec())
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.catalog().get_configuration_catalog_value(
                "tmp_dir"
            ),
        )
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Use Sklearn's regression to train models. A ``LinearRegression`` is fit
        on the child plan's rows (target column = metadata ``predict``), pickled
        under the configured ``model_dir``, and its path and the prediction
        column are appended to the node metadata.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = list(self.children[0].exec())
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        # Every remaining column is used as a feature.
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        # Use a context manager so the file handle is always closed.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column name to sklearn.py
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def convert_to_numeric(self, x):
        """Convert a loosely formatted value (e.g. ``"$1,234"``) to a number.

        Every character except digits, ``.`` and ``,`` is stripped. A minus
        sign appearing before the first digit is kept, so negative values keep
        their sign. The result is parsed with ``locale.atof`` under the user's
        default locale.

        Returns:
            ``int`` when the parsed value is integral, otherwise ``float``.

        Raises:
            ValueError: if no parseable number remains after stripping.
        """
        text = str(x)
        first_digit = re.search(r"[0-9]", text)
        # The character filter below would otherwise turn "-5" into "5".
        negative = first_digit is not None and "-" in text[: first_digit.start()]
        cleaned = re.sub("[^0-9.,]", "", text)
        if negative:
            cleaned = "-" + cleaned
        # NOTE(review): this changes the process-wide locale on every call
        # (it is invoked once per cell during training) — consider setting it
        # once at startup instead.
        locale.setlocale(locale.LC_ALL, "")
        value = float(locale.atof(cleaned))
        if value.is_integer():
            return int(value)
        return value

    def handle_xgboost_function(self):
        """Handle xgboost functions

        We use the Flaml AutoML model for training xgboost models. The fitted
        AutoML object is pickled under the configured ``model_dir`` and its path
        and the prediction column are appended to the node metadata.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_xgboost()

        assert (
            len(self.children) == 1
        ), "Create xgboost function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch_list = list(self.children[0].exec())
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        from flaml import AutoML

        model = AutoML()
        settings = {
            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
            "estimator_list": ["xgboost"],
            "task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
        }
        model.fit(
            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
        )
        model_path = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            self.node.name,
        )
        # Use a context manager so the file handle is always closed.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )
        # Pass the prediction column to xgboost.py.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
        )

        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions

        Loads and initializes the bundled YOLO object detector with the node's
        metadata properties as constructor arguments.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions

        Trains (or reuses a cached) statsforecast / neuralforecast model on the
        rows produced by the first child plan.

        Metadata keys read from ``self.node.metadata``:
            predict (required): column to forecast; renamed to ``y``.
            horizon (required): number of steps to forecast; must be integral.
            time / id: optional columns renamed to ``ds`` / ``unique_id``.
            library: ``statsforecast`` (default) or ``neuralforecast``.
            model: model name, with or without the ``Auto`` prefix.
            auto: values starting with "t" select the ``Auto`` model variant;
                an absent AUTO is treated as true.
            frequency: pandas offset alias, or ``auto`` / absent to infer it.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).

        Raises:
            ValueError: if HORIZON is not provided.
            FunctionIODefinitionError: for a non-integral HORIZON, an
                unsupported LIBRARY, or an unknown MODEL.
            RuntimeError: if the frequency cannot be inferred, or AUTO is false
                with the statsforecast library.
        """
        aggregated_batch_list = list(self.children[0].exec())
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()
        library = "statsforecast"
        supported_libraries = ["statsforecast", "neuralforecast"]

        if "horizon" not in arg_map:
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except (TypeError, ValueError, OverflowError) as e:
            err_msg = f"{str(e)}. HORIZON must be integral."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        if "library" in arg_map:
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O``, which would silently accept any library.
            library = str(arg_map["library"]).lower()
            if library not in supported_libraries:
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

        # statsforecast/neuralforecast require the columns to be named:
        # - unique_id (string, int or category): identifier of the series.
        # - ds (datestamp): a format expected by Pandas, ideally YYYY-MM-DD for
        #   a date or YYYY-MM-DD HH:MM:SS for a timestamp.
        # - y (numeric): the measurement we wish to forecast.
        # Reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        if "unique_id" not in data.columns:
            # No id column: treat all rows as one series.
            data["unique_id"] = [1] * len(data)

        if "ds" not in data.columns:
            # No time column: use a 1-based row index as the timestamp.
            data["ds"] = list(range(1, len(data) + 1))

        # Set or infer the data frequency.
        if "frequency" not in arg_map or arg_map["frequency"] == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        # Shorten anchored frequencies such as "Q-DEC" to their base alias.
        new_freq = frequency.split("-")[0]
        season_length = season_dict.get(new_freq, 1)

        if library == "neuralforecast":
            # Neuralforecast implementation
            try_to_import_neuralforecast()
            from neuralforecast import NeuralForecast
            from neuralforecast.auto import AutoNBEATS, AutoNHITS
            from neuralforecast.models import NBEATS, NHITS

            model_dict = {
                "AutoNBEATS": AutoNBEATS,
                "AutoNHITS": AutoNHITS,
                "NBEATS": NBEATS,
                "NHITS": NHITS,
            }

            if "model" not in arg_map:
                arg_map["model"] = "NBEATS"

            # AUTO defaults to true. Only add the prefix when the model name does
            # not already carry it, so e.g. MODEL 'AutoNHITS' without AUTO is not
            # turned into the unknown 'AutoAutoNHITS'.
            use_auto = "auto" not in arg_map or str(arg_map["auto"]).lower().startswith(
                "t"
            )
            if use_auto and "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except KeyError as e:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg) from e
            model_args = {}

            if "auto" not in arg_map["model"].lower():
                model_args["input_size"] = 2 * horizon
                model_args["early_stop_patience_steps"] = 20
            else:
                # NOTE(review): this config is only handed to the Auto model in
                # the exogenous branch below; without exogenous columns the Auto
                # model runs with its own default config — confirm intended.
                model_args_config = {
                    "input_size": 2 * horizon,
                    "early_stop_patience_steps": 20,
                }

            if len(data.columns) >= 4:
                # Every column besides ds/y/unique_id is used as a historical
                # exogenous variable.
                exogenous_columns = [
                    x for x in list(data.columns) if x not in ["ds", "y", "unique_id"]
                ]
                if "auto" not in arg_map["model"].lower():
                    model_args["hist_exog_list"] = exogenous_columns
                else:
                    model_args_config["hist_exog_list"] = exogenous_columns

                    def get_optuna_config(trial):
                        return model_args_config

                    model_args["config"] = get_optuna_config
                    model_args["backend"] = "optuna"

            model_args["h"] = horizon

            model = NeuralForecast(
                [model_here(**model_args)],
                freq=new_freq,
            )

        else:
            # Statsforecast implementation
            if "auto" in arg_map and not str(arg_map["auto"]).lower().startswith("t"):
                raise RuntimeError(
                    "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
                )
            try_to_import_statsforecast()
            from statsforecast import StatsForecast
            from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

            model_dict = {
                "AutoARIMA": AutoARIMA,
                "AutoCES": AutoCES,
                "AutoETS": AutoETS,
                "AutoTheta": AutoTheta,
            }

            if "model" not in arg_map:
                arg_map["model"] = "ARIMA"

            if "auto" not in arg_map["model"].lower():
                arg_map["model"] = "Auto" + arg_map["model"]

            try:
                model_here = model_dict[arg_map["model"]]
            except KeyError as e:
                err_msg = "Supported models: " + str(model_dict.keys())
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg) from e

            model = StatsForecast(
                [model_here(season_length=season_length)], freq=new_freq
            )

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
        if len(data.columns) >= 4 and library == "neuralforecast":
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        # The cache directory is keyed by library/model/frequency (and the
        # exogenous columns) plus a hash of the training data.
        model_dir = os.path.join(
            self.db.catalog().get_configuration_catalog_value("model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_save_name = "horizon" + str(horizon) + ".pkl"

        model_path = os.path.join(model_dir, model_save_name)

        # Cached entries are named "horizon<N>.pkl"; any entry trained with a
        # horizon >= the requested one is considered reusable.
        existing_model_files = sorted(
            os.listdir(model_dir),
            key=lambda x: int(x.split("horizon")[1].split(".pkl")[0]),
        )
        existing_model_files = [
            x
            for x in existing_model_files
            if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon
        ]
        if len(existing_model_files) == 0:
            logger.info("Training, please wait...")
            for column in data.columns:
                if column != "ds" and column != "unique_id":
                    data[column] = data.apply(
                        lambda x: self.convert_to_numeric(x[column]), axis=1
                    )
            if library == "neuralforecast":
                # Train on a single GPU: the first visible device, or "0".
                cuda_devices_here = "0"
                if "CUDA_VISIBLE_DEVICES" in os.environ:
                    cuda_devices_here = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]

                with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
                    model.fit(df=data, val_size=horizon)
                    model.save(model_path, overwrite=True)
            else:
                # The following lines of code helps eliminate the math error encountered in statsforecast when only one datapoint is available in a time series
                for col in data["unique_id"].unique():
                    if len(data[data["unique_id"] == col]) == 1:
                        data = pd.concat(
                            [data, data[data["unique_id"] == col]], ignore_index=True
                        )

                model.fit(df=data[["ds", "y", "unique_id"]])
                with open(model_path, "wb") as model_file:
                    pickle.dump(model, model_file)
        elif not Path(model_path).exists():
            # Reuse the cached model with the largest horizon.
            model_path = os.path.join(model_dir, existing_model_files[-1])

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry and yields a
        single-row Batch with a status message.

        If the function already exists: with IF NOT EXISTS nothing is added;
        with OR REPLACE the existing entry is dropped first; otherwise an
        error is raised.

        Raises:
            RuntimeError: if the function exists and neither IF NOT EXISTS nor
                OR REPLACE was specified.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."
        )

        overwrite = False
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError as e:
                    # Best effort: continue with the insert, but do not hide why
                    # the drop failed.
                    logger.warning(
                        f"Could not drop existing function {self.node.name}: {str(e)}"
                    )
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Dispatch on the (case-insensitive) function type. The first match
        # wins; anything unrecognised is a generic function loaded from a file.
        typed_handlers = [
            ("HuggingFace", self.handle_huggingface_function),
            ("ultralytics", self.handle_ultralytics_function),
            ("Ludwig", self.handle_ludwig_function),
            ("Sklearn", self.handle_sklearn_function),
            ("XGBoost", self.handle_xgboost_function),
            ("Forecasting", self.handle_forecasting_function),
        ]
        handler = self.handle_generic_function
        for type_name, candidate in typed_handlers:
            if string_comparison_case_insensitive(self.node.function_type, type_name):
                handler = candidate
                break

        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = {}
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Dictionary of arguments to pass to the function. Defaults to {}.
                It is only unpacked, never mutated, so the shared default is safe.

        Returns:
            The function class loaded from ``impl_path`` (the instance created
            to run its setup is discarded).

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            # Chain the original exception so its traceback is not lost.
            raise RuntimeError(err_msg) from e

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.
        It first searches for the input/outputs in the CREATE statement. If not found, it resolves them using decorators. If not found there as well, it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve input and output definitions.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators, the inputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators, the outputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc