• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / #811

25 Sep 2023 03:38AM UTC coverage: 92.737% (-0.1%) from 92.866%
#811

push

circle-ci

Jiashen Cao
make code more modular

11390 of 12282 relevant lines covered (92.74%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.07
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
import os
import pickle
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
    DEFAULT_TRAIN_TIME_LIMIT,
    EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.functions.decorators.utils import load_io_from_function_decorators
from evadb.models.storage.batch import Batch
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
from evadb.utils.errors import FunctionIODefinitionError
from evadb.utils.generic_utils import (
    load_function_class_from_file,
    string_comparison_case_insensitive,
    try_to_import_forecast,
    try_to_import_ludwig,
    try_to_import_sklearn,
    try_to_import_torch,
    try_to_import_ultralytics,
)
from evadb.utils.logging_manager import logger
48

49

50
class CreateFunctionExecutor(AbstractExecutor):
    """Executor for CREATE FUNCTION statements.

    Dispatches on the declared function type (HuggingFace, Ultralytics,
    Ludwig, Sklearn, Forecasting, or a generic user-defined function),
    resolves the implementation path / IO signature / metadata for the
    function, and inserts the resulting entry into the catalog.
    """

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Directory shipping EvaDB's built-in function implementations.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def _aggregate_child_batches(self) -> Batch:
        """Execute the single child plan and concatenate all of its batches.

        Returns:
            Batch: one concatenated batch with column aliases dropped,
            suitable for use as a training dataset.
        """
        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()
        return aggregated_batch

    def handle_huggingface_function(self):
        """Handle HuggingFace functions.

        HuggingFace functions are special functions that are not loaded from
        a file, so we do not need to call the setup method on them like we do
        for other functions.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        # We need at least one deep learning framework for HuggingFace
        # (Torch or Tensorflow).
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle Ludwig functions.

        Uses Ludwig's auto_train engine to train/tune a model on the data
        produced by the child plan, then records the saved model path in
        the function metadata.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        auto_train_results.best_model.save(model_path)
        # Record where the trained model lives so the function can load it
        # at call time.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions.

        Trains a sklearn LinearRegression model on the data produced by the
        child plan, pickles it to the model directory, and records the model
        path in the function metadata.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        assert (
            len(self.children) == 1
        ), "Create sklearn function expects 1 child, finds {}.".format(
            len(self.children)
        )

        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        # Use a context manager so the file handle is closed even on error
        # (the original `pickle.dump(model, open(...))` leaked the handle).
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions.

        Fits (or reuses a cached) statsforecast model on the child plan's
        data and records the model name/path plus the column renames needed
        to map the user's schema onto statsforecast's expected columns.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).

        Raises:
            RuntimeError: if the time-series frequency is neither given nor
                inferable from the data.
        """
        aggregated_batch = self._aggregate_child_batches()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        # Default to AutoARIMA unless the user picked a model explicitly.
        arg_map.setdefault("model", "AutoARIMA")
        model_name = arg_map["model"]

        # The following rename is needed for statsforecast, which requires
        # the columns to be named as follows:
        # - unique_id (string, int or category): identifier for the series.
        # - ds (datestamp): a format expected by Pandas, ideally YYYY-MM-DD
        #   for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
        # - y (numeric): the measurement we wish to forecast.
        # Reference:
        # https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        if "unique_id" not in list(data.columns):
            # Single series: use a constant identifier.
            data["unique_id"] = [1] * len(data)

        if "ds" not in list(data.columns):
            # No timestamp column: fall back to a 1-based integer index.
            data["ds"] = list(range(1, len(data) + 1))

        if "frequency" not in arg_map:
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        try_to_import_forecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        # Season length per frequency alias; see
        # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
        season_dict = {
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        # Shorten anchored frequencies like "Q-DEC" to their base alias.
        new_freq = frequency.split("-")[0] if "-" in frequency else frequency
        season_length = season_dict.get(new_freq, 1)
        model = StatsForecast(
            [model_dict[model_name](season_length=season_length)], freq=new_freq
        )

        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        # Cache the fitted model keyed by a hash of the training data so we
        # only re-train when the data actually changes.
        model_path = os.path.join(
            model_dir,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()) + ".pkl",
        )

        weight_file = Path(model_path)
        data["ds"] = pd.to_datetime(data["ds"])
        if not weight_file.exists():
            model.fit(data)
            # Context manager guarantees the handle is closed after dumping.
            with open(model_path, "wb") as model_file:
                pickle.dump(model, model_file)

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", model_name),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    def handle_generic_function(self):
        """Handle generic functions.

        Generic functions are loaded from a file. We check for inputs passed
        by the user during CREATE or try to load io from decorators.

        Returns:
            Tuple of (name, impl_path, function_type, io_list, metadata).
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor.

        Calls the catalog to insert a function catalog entry.

        Yields:
            Batch: a single-row batch with a human-readable status message.

        Raises:
            RuntimeError: if the function already exists and neither
                IF NOT EXISTS nor OR REPLACE was specified.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."
        )

        overwrite = False
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code.
                # The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError:
                    # Best-effort drop: if it fails we still create the
                    # function, but do not report it as an overwrite.
                    pass
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Dispatch to the specialized handler based on the declared function
        # type; generic user-defined functions are the fallback. For
        # HuggingFace-style types the handler overrides the impl_path.
        if string_comparison_case_insensitive(self.node.function_type, "HuggingFace"):
            handler = self.handle_huggingface_function
        elif string_comparison_case_insensitive(self.node.function_type, "ultralytics"):
            handler = self.handle_ultralytics_function
        elif string_comparison_case_insensitive(self.node.function_type, "Ludwig"):
            handler = self.handle_ludwig_function
        elif string_comparison_case_insensitive(self.node.function_type, "Sklearn"):
            handler = self.handle_sklearn_function
        elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
            handler = self.handle_forecasting_function
        else:
            handler = self.handle_generic_function

        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Optional[Dict] = None
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Optional[Dict]): Dictionary of arguments to pass to
                the function. Defaults to None, which is treated as empty.

        Returns:
            FunctionCatalogEntry: A FunctionCatalogEntry object that represents the initialized function.

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        # Avoid the mutable-default-argument pitfall; None means "no args".
        if function_args is None:
            function_args = {}

        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            # Chain the original exception to preserve the traceback.
            raise RuntimeError(err_msg) from e

        return function

    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.
        It first searches for the input/outputs in the CREATE statement. If not found, it resolves them using decorators. If not found there as well, it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve input and output definitions.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # the inputs from the CREATE statement take precedence over
                # the ones declared via decorators
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # the outputs from the CREATE statement take precedence over
                # the ones declared via decorators
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            # Chain the original exception to preserve the traceback.
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc