• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / #836

18 Oct 2023 06:28PM UTC coverage: 0.0% (-78.6%) from 78.602%
#836

push

circle-ci

Andy Xu
Skip the limit test

0 of 12321 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
×
16
import os
×
17
import pickle
×
18
from pathlib import Path
×
19
from typing import Dict, List
×
20

21
import pandas as pd
×
22

23
from evadb.catalog.catalog_utils import get_metadata_properties
×
24
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
×
25
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
×
26
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
×
27
from evadb.configuration.constants import (
×
28
    DEFAULT_TRAIN_TIME_LIMIT,
29
    EvaDB_INSTALLATION_DIR,
30
)
31
from evadb.database import EvaDBDatabase
×
32
from evadb.executor.abstract_executor import AbstractExecutor
×
33
from evadb.functions.decorators.utils import load_io_from_function_decorators
×
34
from evadb.models.storage.batch import Batch
×
35
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
×
36
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
×
37
from evadb.utils.errors import FunctionIODefinitionError
×
38
from evadb.utils.generic_utils import (
×
39
    load_function_class_from_file,
40
    string_comparison_case_insensitive,
41
    try_to_import_ludwig,
42
    try_to_import_neuralforecast,
43
    try_to_import_sklearn,
44
    try_to_import_statsforecast,
45
    try_to_import_torch,
46
    try_to_import_ultralytics,
47
)
48
from evadb.utils.logging_manager import logger
×
49

50

51
class CreateFunctionExecutor(AbstractExecutor):
    """Executor for ``CREATE FUNCTION`` statements.

    Depending on ``node.function_type`` it either trains a model from the
    child plan's output (Ludwig, Sklearn, Forecasting), wraps a built-in
    implementation (HuggingFace, Ultralytics), or loads a user-provided
    implementation file (any other type), and then inserts the resulting
    function entry into the catalog.
    """

    # Season length used by statsforecast for each pandas offset alias.
    # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
    # Aliases not listed here fall back to a season length of 1.
    _SEASON_LENGTHS = {
        "H": 24,
        "M": 12,
        "Q": 4,
        "SM": 24,
        "BM": 12,
        "BMS": 12,
        "BQ": 4,
        "BH": 24,
    }

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Directory holding the built-in function implementations.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _aggregate_child_batches(self) -> Batch:
        """Execute the first child plan and concatenate all batches it yields.

        Column aliases are dropped so training code can refer to the plain
        column names given in the CREATE statement.
        """
        child = self.children[0]
        aggregated_batch = Batch.concat(list(child.exec()), copy=False)
        aggregated_batch.drop_column_alias()
        return aggregated_batch

    def _metadata_arg_map(self) -> Dict:
        """Return the CREATE statement's metadata entries as a ``{key: value}`` dict."""
        return {arg.key: arg.value for arg in self.node.metadata}

    # ------------------------------------------------------------------
    # Per-type handlers; each returns
    # (name, impl_path, function_type, io_list, metadata)
    # ------------------------------------------------------------------

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.
        """
        # We need at least one deep learning framework for HuggingFace
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle ludwig functions

        Use Ludwig's auto_train engine to train/tune models on the child plan's
        output, save the best model under the storage ``model_dir`` and record
        its location as ``model_path`` metadata.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = self._aggregate_child_batches()
        arg_map = self._metadata_arg_map()
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Fit a ``LinearRegression`` on the child plan's output, predicting the
        ``predict`` column from all remaining columns, pickle the model under
        the storage ``model_dir`` and record it as ``model_path`` metadata.
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = self._aggregate_child_batches()
        arg_map = self._metadata_arg_map()
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        # Use a context manager so the file handle is always closed.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions

        Initializes the built-in YOLO object detector implementation with the
        statement's metadata properties to validate it before registration.
        """
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions

        GPUs are hidden (``CUDA_VISIBLE_DEVICES=""``) while the forecasting
        model is built and trained; the caller's previous value of the variable
        is restored afterwards, even if training fails.
        """
        previous_cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        try:
            return self._create_forecasting_function()
        finally:
            if previous_cuda_devices is None:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = previous_cuda_devices

    def _create_forecasting_function(self):
        """Train (or reuse a cached) forecasting model and build the catalog entry.

        Raises:
            ValueError: if HORIZON is missing.
            FunctionIODefinitionError: if HORIZON is not integral, LIBRARY is
                unsupported, or MODEL is unknown for the chosen library.
            RuntimeError: if the data frequency cannot be inferred, or AUTO is
                disabled for statsforecast.
        """
        aggregated_batch = self._aggregate_child_batches()
        arg_map = self._metadata_arg_map()

        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        if "horizon" not in arg_map:
            raise ValueError(
                "Horizon must be provided while creating function of type FORECASTING"
            )
        try:
            horizon = int(arg_map["horizon"])
        except (TypeError, ValueError) as e:
            err_msg = f"{str(e)}. HORIZON must be integral."
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg) from e

        supported_libraries = ["statsforecast", "neuralforecast"]
        library = "statsforecast"
        if "library" in arg_map:
            library = str(arg_map["library"]).lower()
            # Explicit check (not ``assert``) so validation survives ``python -O``.
            if library not in supported_libraries:
                err_msg = (
                    "EvaDB currently supports " + str(supported_libraries) + " only."
                )
                logger.error(err_msg)
                raise FunctionIODefinitionError(err_msg)

        # statsforecast/neuralforecast require these column names:
        # - unique_id (string, int or category): identifier of the series.
        # - ds (datestamp): a Pandas-parsable date/timestamp.
        # - y (numeric): the measurement to forecast.
        # See https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        # Missing id column: treat all rows as a single series.
        if "unique_id" not in data.columns:
            data["unique_id"] = [1] * len(data)
        # Missing time column: number the rows 1..n.
        if "ds" not in data.columns:
            data["ds"] = list(range(1, len(data) + 1))

        # Set or infer data frequency.
        if arg_map.get("frequency", "auto") == "auto":
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        # Shorten anchored aliases such as "Q-DEC" to their base alias ("Q").
        new_freq = frequency.split("-")[0]
        season_length = self._SEASON_LENGTHS.get(new_freq, 1)

        exogenous_columns = []
        if library == "neuralforecast":
            model, exogenous_columns = self._build_neuralforecast_model(
                arg_map, data, horizon, new_freq
            )
        else:
            model = self._build_statsforecast_model(arg_map, season_length, new_freq)

        data["ds"] = pd.to_datetime(data["ds"])

        model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
        if len(data.columns) >= 4 and library == "neuralforecast":
            model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

        # Models are cached per (library, model, frequency, exogenous columns, data hash).
        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"),
            "tsforecasting",
            model_save_dir_name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()),
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)

        model_path = os.path.join(model_dir, f"horizon{horizon}.pkl")

        # Cached models trained for at least the requested horizon, sorted by
        # horizon. Files not named ``horizon<N>.pkl`` are ignored.
        cached_models = []
        for file_name in os.listdir(model_dir):
            file_horizon = self._horizon_from_model_file(file_name)
            if file_horizon is not None and file_horizon >= horizon:
                cached_models.append((file_horizon, file_name))
        cached_models.sort()

        if not cached_models:
            logger.info("Training, please wait...")
            if library == "neuralforecast":
                model.fit(df=data, val_size=horizon)
            else:
                model.fit(df=data[["ds", "y", "unique_id"]])
            with open(model_path, "wb") as model_file:
                pickle.dump(model, model_file)
        elif not Path(model_path).exists():
            # Reuse the cached model with the largest horizon.
            model_path = os.path.join(model_dir, cached_models[-1][1])

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
            FunctionMetadataCatalogEntry("horizon", horizon),
            FunctionMetadataCatalogEntry("library", library),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    @staticmethod
    def _auto_enabled(arg_map: Dict) -> bool:
        """Interpret the optional AUTO argument.

        Absent means enabled; otherwise any value whose lowercase text starts
        with ``"t"`` (e.g. ``"true"``) enables it.
        """
        return str(arg_map.get("auto", "true")).lower().startswith("t")

    @staticmethod
    def _lookup_forecast_model(model_dict: Dict, model_name: str):
        """Return ``model_dict[model_name]`` or raise FunctionIODefinitionError."""
        if model_name not in model_dict:
            err_msg = "Supported models: " + str(model_dict.keys())
            logger.error(err_msg)
            raise FunctionIODefinitionError(err_msg)
        return model_dict[model_name]

    @staticmethod
    def _horizon_from_model_file(file_name: str):
        """Return N for a file named ``horizon<N>.pkl``; None for any other name."""
        prefix, suffix = "horizon", ".pkl"
        if not (file_name.startswith(prefix) and file_name.endswith(suffix)):
            return None
        try:
            return int(file_name[len(prefix) : -len(suffix)])
        except ValueError:
            return None

    def _build_neuralforecast_model(self, arg_map, data, horizon, frequency):
        """Build an untrained NeuralForecast model.

        Resolves ``arg_map["model"]`` in place (default NBEATS; prefixed with
        "Auto" when AUTO is enabled and the name is not already an Auto model).
        Columns other than ds/y/unique_id are passed as historical exogenous
        variables.

        Returns:
            (model, exogenous_columns)
        """
        try_to_import_neuralforecast()
        from neuralforecast import NeuralForecast
        from neuralforecast.auto import AutoNBEATS, AutoNHITS
        from neuralforecast.models import NBEATS, NHITS

        model_dict = {
            "AutoNBEATS": AutoNBEATS,
            "AutoNHITS": AutoNHITS,
            "NBEATS": NBEATS,
            "NHITS": NHITS,
        }

        arg_map.setdefault("model", "NBEATS")
        # Only prepend "Auto" when the name does not already carry it, so
        # MODEL 'AutoNHITS' without AUTO is not turned into "AutoAutoNHITS".
        if self._auto_enabled(arg_map) and "auto" not in arg_map["model"].lower():
            arg_map["model"] = "Auto" + arg_map["model"]

        model_here = self._lookup_forecast_model(model_dict, arg_map["model"])

        base_args = {
            "input_size": 2 * horizon,
            "early_stop_patience_steps": 20,
        }
        exogenous_columns = []
        if len(data.columns) >= 4:
            exogenous_columns = [
                x for x in list(data.columns) if x not in ["ds", "y", "unique_id"]
            ]
            base_args["hist_exog_list"] = exogenous_columns

        # Auto* models take their hyperparameters via a ``config`` dict.
        if "auto" in arg_map["model"].lower():
            model_args = {"config": base_args}
        else:
            model_args = dict(base_args)
        model_args["h"] = horizon

        model = NeuralForecast(
            [model_here(**model_args)],
            freq=frequency,
        )
        return model, exogenous_columns

    def _build_statsforecast_model(self, arg_map, season_length, frequency):
        """Build an untrained StatsForecast model.

        Resolves ``arg_map["model"]`` in place (default ARIMA, always mapped to
        its Auto* variant).

        Raises:
            RuntimeError: if AUTO is explicitly disabled.
        """
        if not self._auto_enabled(arg_map):
            raise RuntimeError(
                "Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
            )
        try_to_import_statsforecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        arg_map.setdefault("model", "ARIMA")
        if "auto" not in arg_map["model"].lower():
            arg_map["model"] = "Auto" + arg_map["model"]

        model_here = self._lookup_forecast_model(model_dict, arg_map["model"])
        return StatsForecast(
            [model_here(season_length=season_length)], freq=frequency
        )

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry and yields a
        single-row Batch with a status message.

        Raises:
            RuntimeError: if the function already exists and neither
                IF NOT EXISTS nor OR REPLACE was given.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."
        )

        overwrite = False
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError as e:
                    # Best effort: continue with creation, but do not hide the failure.
                    logger.warning(
                        f"Could not drop existing function {self.node.name}: {str(e)}"
                    )
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Dispatch on function type (case-insensitive); anything unmatched is
        # treated as a generic, file-backed function.
        type_handlers = [
            ("HuggingFace", self.handle_huggingface_function),
            ("ultralytics", self.handle_ultralytics_function),
            ("Ludwig", self.handle_ludwig_function),
            ("Sklearn", self.handle_sklearn_function),
            ("Forecasting", self.handle_forecasting_function),
        ]
        handler = self.handle_generic_function
        for type_name, candidate in type_handlers:
            if string_comparison_case_insensitive(self.node.function_type, type_name):
                handler = candidate
                break

        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = None
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Keyword arguments passed to the
                function's constructor. Defaults to no arguments.

        Returns:
            The function class loaded from ``impl_path`` by
            ``load_function_class_from_file`` (note: despite the annotation,
            this is the loaded class, not a catalog entry).

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        # ``None`` default avoids sharing one mutable dict across calls.
        if function_args is None:
            function_args = {}

        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            raise RuntimeError(err_msg) from e

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.
        It first searches for the input/outputs in the CREATE statement. If not found, it resolves them using decorators. If not found there as well, it raises an error.

        Args:
            function: The loaded function class whose decorators are consulted
                when the CREATE statement gives no inputs/outputs (callers pass
                None for trained-model types).

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators, the inputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators, the outputs from CREATE statement take precedence
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc