• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / c60bc8c4-3c84-48a0-8de6-9eddb48baa31

17 Oct 2023 03:06PM UTC coverage: 78.624% (+78.6%) from 0.0%
c60bc8c4-3c84-48a0-8de6-9eddb48baa31

Pull #1283

circle-ci

americast
update with contextmanager
Pull Request #1283: Fix current issues with forecasting

37 of 37 new or added lines in 2 files covered. (100.0%)

9894 of 12584 relevant lines covered (78.62%)

1.42 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.53
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
2✔
16
import os
2✔
17
import pickle
2✔
18
import re
2✔
19
from pathlib import Path
2✔
20
from typing import Dict, List
2✔
21

22
import pandas as pd
2✔
23

24
from evadb.catalog.catalog_utils import get_metadata_properties
2✔
25
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
2✔
26
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
2✔
27
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
2✔
28
from evadb.configuration.constants import (
2✔
29
    DEFAULT_TRAIN_TIME_LIMIT,
30
    EvaDB_INSTALLATION_DIR,
31
)
32
from evadb.database import EvaDBDatabase
2✔
33
from evadb.executor.abstract_executor import AbstractExecutor
2✔
34
from evadb.functions.decorators.utils import load_io_from_function_decorators
2✔
35
from evadb.models.storage.batch import Batch
2✔
36
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
2✔
37
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
2✔
38
from evadb.utils.errors import FunctionIODefinitionError
2✔
39
from evadb.utils.generic_utils import (
2✔
40
    load_function_class_from_file,
41
    string_comparison_case_insensitive,
42
    try_to_import_ludwig,
43
    try_to_import_neuralforecast,
44
    try_to_import_sklearn,
45
    try_to_import_statsforecast,
46
    try_to_import_torch,
47
    try_to_import_ultralytics,
48
)
49
from evadb.utils.logging_manager import logger
2✔
50

51

52
class CreateFunctionExecutor(AbstractExecutor):
    """Executor for CREATE FUNCTION statements.

    ``exec`` dispatches on the plan node's function type to a type-specific
    handler. Every handler returns the tuple
    ``(name, impl_path, function_type, io_list, metadata)`` which ``exec`` then
    inserts into the catalog.
    """

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Directory holding EvaDB's built-in function implementations.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def _aggregate_child_batches(self) -> Batch:
        """Run the first child executor and concatenate all of its batches.

        Column aliases are dropped on the result, presumably so that columns
        can be looked up by the plain names used in the CREATE statement
        (e.g. the PREDICT column). TODO confirm against Batch.drop_column_alias.

        Returns:
            Batch: a single batch holding every row produced by the child.
        """
        child = self.children[0]
        aggregated_batch = Batch.concat(list(child.exec()), copy=False)
        aggregated_batch.drop_column_alias()
        return aggregated_batch

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.
        """
        # We need at least one deep learning framework for HuggingFace
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle ludwig functions

        Use Ludwig's auto_train engine to train/tune models. The best model is
        saved under ``<model_dir>/<function name>`` and its path is appended to
        the node metadata as ``model_path``.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Use Sklearn's regression to train models. The fitted LinearRegression
        is pickled to ``<model_dir>/<function name>`` and its path is appended
        to the node metadata as ``model_path``.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        # Every remaining column is used as a feature.
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        # Use a context manager so the file handle is always flushed and closed.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def convert_to_numeric(self, x):
        """Best-effort conversion of a raw cell value to ``int`` or ``float``.

        Every character other than digits and ``.`` is stripped (so
        ``"$1,200"`` becomes ``1200``). A leading ``-`` on the original value
        is preserved so negative numbers keep their sign. If the cleaned text
        parses as neither int nor float, the cleaned text itself is returned.
        """
        text = str(x).strip()
        cleaned = re.sub("[^0-9.]", "", text)
        # The character filter above also removes the sign; restore it.
        if cleaned and text.startswith("-"):
            cleaned = "-" + cleaned
        try:
            return int(cleaned)
        except ValueError:
            pass
        try:
            return float(cleaned)
        except ValueError:
            return cleaned

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions"""
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions

        Training is restricted to a single device by narrowing
        ``CUDA_VISIBLE_DEVICES`` to its first entry (or ``"0"`` when unset).
        The previous value is restored afterwards, even when validation or
        training raises.
        """
        save_old_cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES")
        if save_old_cuda_env is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = save_old_cuda_env.split(",")[0]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        try:
            return self._train_forecasting_function()
        finally:
            if save_old_cuda_env is not None:
                os.environ["CUDA_VISIBLE_DEVICES"] = save_old_cuda_env
            else:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)

    def _train_forecasting_function(self):
        """Validate arguments, train (or reuse) a forecasting model, build the entry.

        Trained models are cached under
        ``<model_dir>/tsforecasting/<library>_<model>_<freq>[...]/<sha256 of data>/``
        as ``horizon<N>.pkl``. A cached model trained for a horizon greater than
        or equal to the requested one is reused instead of retraining.

        Raises:
            ValueError: HORIZON is missing.
            FunctionIODefinitionError: HORIZON is not integral, or LIBRARY/MODEL
                is unsupported.
            RuntimeError: the frequency cannot be inferred, or statsforecast is
                requested with AUTO disabled.
        """
        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        # Validate the cheap-to-check arguments before executing the child plan.
        if "horizon" not in arg_map:
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except (TypeError, ValueError, OverflowError) as e:
            err_msg = f"{str(e)}. HORIZON must be integral."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        supported_libraries = ["statsforecast", "neuralforecast"]
        library = str(arg_map.get("library", "statsforecast")).lower()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if library not in supported_libraries:
            err_msg = "EvaDB currently supports " + str(supported_libraries) + " only."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)

        aggregated_batch = self._aggregate_child_batches()

        # statsforecast/neuralforecast require these column names:
        # - unique_id (string, int or category): identifier of the series.
        # - ds (datestamp): a format understood by pandas, ideally YYYY-MM-DD
        #   or YYYY-MM-DD HH:MM:SS.
        # - y (numeric): the measurement to forecast.
        # Reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        if "predict" in arg_map:
            aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        if "unique_id" not in data.columns:
            # Treat the whole input as a single series.
            data["unique_id"] = [1] * len(data)
        if "ds" not in data.columns:
            # Synthesize a 1-based positional time index.
            data["ds"] = [x + 1 for x in range(len(data))]

        # Set or infer data frequency.
        frequency = arg_map.get("frequency", "auto")
        if frequency == "auto":
            try:
                frequency = pd.infer_freq(data["ds"])
            except (TypeError, ValueError):
                # pandas raises for non datetime-like values or < 3 timestamps.
                frequency = None
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
        season_dict = {
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }
        new_freq = frequency.split("-")[0]  # shortens anchored aliases like Q-DEC
        season_length = season_dict.get(new_freq, 1)

        exogenous_columns = [
            x for x in data.columns if x not in ["ds", "y", "unique_id"]
        ]

        if library == "neuralforecast":
            model_name, model = self._build_neuralforecast_model(
                arg_map, horizon, new_freq, exogenous_columns
            )
        else:
            model_name, model = self._build_statsforecast_model(
                arg_map, season_length, new_freq
            )

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = library + "_" + model_name + "_" + new_freq
        if library == "neuralforecast" and exogenous_columns:
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        # The data hash keys the cache, so retraining happens only on new data.
        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_save_name = "horizon" + str(horizon) + ".pkl"
        model_path = os.path.join(model_dir, model_save_name)

        existing_model_files = self._cached_forecast_models(model_dir, horizon)
        if not existing_model_files:
            logger.info("Training, please wait...")
            for column in data.columns:
                if column != "ds" and column != "unique_id":
                    data[column] = data[column].apply(self.convert_to_numeric)
            if library == "neuralforecast":
                model.fit(df=data, val_size=horizon)
                model.save(model_path, overwrite=True)
            else:
                # Duplicate series with a single observation; statsforecast
                # otherwise hits a math error on them.
                id_counts = data["unique_id"].value_counts()
                singletons = data[
                    data["unique_id"].isin(id_counts.index[id_counts == 1])
                ]
                if not singletons.empty:
                    data = pd.concat([data, singletons], ignore_index=True)

                model.fit(df=data[["ds", "y", "unique_id"]])
                with open(model_path, "wb") as model_file:
                    pickle.dump(model, model_file)
        elif not Path(model_path).exists():
            # Reuse the cached model with the largest horizon >= the requested one.
            model_path = os.path.join(model_dir, existing_model_files[-1])

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", model_name),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    def _build_neuralforecast_model(self, arg_map, horizon, freq, exogenous_columns):
        """Instantiate an (untrained) NeuralForecast model from CREATE arguments.

        MODEL defaults to ``NBEATS``. Unless AUTO is given and does not start
        with ``t``/``T``, the ``Auto`` variant is used; the prefix is only added
        when the name does not already contain "auto".

        Returns:
            tuple: (resolved model name, NeuralForecast instance)
        """
        try_to_import_neuralforecast()
        from neuralforecast import NeuralForecast
        from neuralforecast.auto import AutoNBEATS, AutoNHITS
        from neuralforecast.models import NBEATS, NHITS

        model_dict = {
            "AutoNBEATS": AutoNBEATS,
            "AutoNHITS": AutoNHITS,
            "NBEATS": NBEATS,
            "NHITS": NHITS,
        }

        model_name = arg_map.get("model", "NBEATS")
        use_auto = "auto" not in arg_map or str(arg_map["auto"]).lower().startswith(
            "t"
        )
        if use_auto and "auto" not in model_name.lower():
            model_name = "Auto" + model_name

        try:
            model_class = model_dict[model_name]
        except KeyError:
            err_msg = "Supported models: " + str(model_dict.keys())
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from None

        model_args = {}
        if "auto" not in model_name.lower():
            model_args["input_size"] = 2 * horizon
            model_args["early_stop_patience_steps"] = 20
            if exogenous_columns:
                model_args["hist_exog_list"] = exogenous_columns
        elif exogenous_columns:
            model_args_config = {
                "input_size": 2 * horizon,
                "early_stop_patience_steps": 20,
                "hist_exog_list": exogenous_columns,
            }

            # A fixed optuna config: the exogenous columns must be passed to
            # Auto models through their config rather than as a constructor arg.
            def get_optuna_config(trial):
                return model_args_config

            model_args["config"] = get_optuna_config
            model_args["backend"] = "optuna"

        model_args["h"] = horizon

        model = NeuralForecast(
            [model_class(**model_args)],
            freq=freq,
        )
        return model_name, model

    def _build_statsforecast_model(self, arg_map, season_length, freq):
        """Instantiate an (untrained) StatsForecast model from CREATE arguments.

        MODEL defaults to ``ARIMA`` and is always mapped to its ``Auto`` variant.

        Returns:
            tuple: (resolved model name, StatsForecast instance)

        Raises:
            RuntimeError: AUTO is given and does not start with ``t``/``T``.
        """
        if "auto" in arg_map and not str(arg_map["auto"]).lower().startswith("t"):
            raise RuntimeError(
                "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
            )
        try_to_import_statsforecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        model_name = arg_map.get("model", "ARIMA")
        if "auto" not in model_name.lower():
            model_name = "Auto" + model_name

        try:
            model_class = model_dict[model_name]
        except KeyError:
            err_msg = "Supported models: " + str(model_dict.keys())
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from None

        model = StatsForecast([model_class(season_length=season_length)], freq=freq)
        return model_name, model

    @staticmethod
    def _cached_forecast_models(model_dir, horizon):
        """List cached ``horizon<N>.pkl`` entries in *model_dir* with N >= *horizon*.

        Entries whose names do not follow that pattern are ignored.

        Returns:
            list: entry names sorted by ascending N.
        """
        candidates = []
        for entry in os.listdir(model_dir):
            match = re.fullmatch(r"horizon(-?\d+)\.pkl", entry)
            if match is not None and int(match.group(1)) >= horizon:
                candidates.append((int(match.group(1)), entry))
        return [entry for _, entry in sorted(candidates)]

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry. Honors
        IF NOT EXISTS (no-op when the function exists) and OR REPLACE (drop the
        existing function first). Yields one Batch containing a status message.
        """
        assert not (
            self.node.if_not_exists and self.node.or_replace
        ), "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."

        overwrite = False
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError as e:
                    # Best effort: continue with the insert, but leave a trace.
                    logger.warning(
                        f"Could not drop existing function {self.node.name}: {e}"
                    )
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Function types are matched case-insensitively, in this order; anything
        # else is treated as a generic function loaded from impl_path.
        type_handlers = (
            ("HuggingFace", self.handle_huggingface_function),
            ("ultralytics", self.handle_ultralytics_function),
            ("Ludwig", self.handle_ludwig_function),
            ("Sklearn", self.handle_sklearn_function),
            ("Forecasting", self.handle_forecasting_function),
        )
        handler = next(
            (
                type_handler
                for type_name, type_handler in type_handlers
                if string_comparison_case_insensitive(
                    self.node.function_type, type_name
                )
            ),
            self.handle_generic_function,
        )
        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = None
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Keyword arguments passed to the
                function's constructor. Defaults to no arguments.

        Returns:
            The function class loaded from ``impl_path`` (note: the class
            itself, not the instance created to validate it).

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        if function_args is None:
            function_args = {}

        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            raise RuntimeError(err_msg) from e

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.
        It first searches for the input/outputs in the CREATE statement. If not found, it resolves them using decorators. If not found there as well, it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve
                input and output definitions (the loaded function class, or
                None when only the CREATE statement is expected to define them).

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators, the inputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators, the outputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc