• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / #820

01 Oct 2023 02:22AM UTC coverage: 0.0% (-73.7%) from 73.748%
#820

push

circle-ci

Jiashen Cao
fix lint error

0 of 12361 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
×
16
import os
×
17
import pickle
×
18
from pathlib import Path
×
19
from typing import Dict, List, Optional
×
20

21
import pandas as pd
×
22

23
from evadb.catalog.catalog_utils import get_metadata_properties
×
24
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
×
25
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
×
26
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
×
27
from evadb.configuration.constants import (
×
28
    DEFAULT_TRAIN_TIME_LIMIT,
29
    EvaDB_INSTALLATION_DIR,
30
)
31
from evadb.database import EvaDBDatabase
×
32
from evadb.executor.abstract_executor import AbstractExecutor
×
33
from evadb.functions.decorators.utils import load_io_from_function_decorators
×
34
from evadb.models.storage.batch import Batch
×
35
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
×
36
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
×
37
from evadb.utils.errors import FunctionIODefinitionError
×
38
from evadb.utils.generic_utils import (
×
39
    load_function_class_from_file,
40
    string_comparison_case_insensitive,
41
    try_to_import_forecast,
42
    try_to_import_ludwig,
43
    try_to_import_sklearn,
44
    try_to_import_torch,
45
    try_to_import_ultralytics,
46
)
47
from evadb.utils.logging_manager import logger
×
48

49

50
class CreateFunctionExecutor(AbstractExecutor):
    """Executor for CREATE FUNCTION statements.

    Depending on ``node.function_type`` it either registers an implementation
    file directly (HuggingFace, Ultralytics, generic) or first trains a model
    on the rows produced by its single child plan (Ludwig, Sklearn,
    Forecasting). It then inserts the resulting function catalog entry.
    """

    # Seasonal period passed to StatsForecast as ``season_length`` for each
    # pandas offset alias; aliases not listed here fall back to 1.
    # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
    _FORECAST_SEASON_LENGTHS = {
        "H": 24,
        "M": 12,
        "Q": 4,
        "SM": 24,
        "BM": 12,
        "BMS": 12,
        "BQ": 4,
        "BH": 24,
    }

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Directory holding the bundled implementation files
        # (e.g. ludwig.py, sklearn.py, forecast.py).
        self.function_dir = Path(EvaDB_INSTALLATION_DIR) / "functions"

    def _aggregate_child_batches(self, function_kind: str) -> Batch:
        """Run the single child plan and concatenate all of its output batches.

        Args:
            function_kind (str): Label used in the assertion message
                (e.g. "ludwig", "sklearn").

        Returns:
            Batch: One batch holding every row produced by the child, with
            column aliases dropped.
        """
        assert (
            len(self.children) == 1
        ), "Create {} function expects 1 child, finds {}.".format(
            function_kind, len(self.children)
        )
        child = self.children[0]
        aggregated_batch = Batch.concat(list(child.exec()), copy=False)
        aggregated_batch.drop_column_alias()
        return aggregated_batch

    def handle_huggingface_function(self):
        """Handle HuggingFace functions

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.

        Returns:
            tuple: (name, impl_path, function_type, io_list, metadata).
        """
        # We need at least one deep learning framework for HuggingFace
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_function(self):
        """Handle ludwig functions

        Uses Ludwig's ``auto_train`` on the child plan's rows. The best model is
        saved to ``<storage.model_dir>/<function name>``, and that path is
        appended to the node metadata as ``model_path``.

        Returns:
            tuple: (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        aggregated_batch = self._aggregate_child_batches("ludwig")

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_sklearn_function(self):
        """Handle sklearn functions

        Fits a ``LinearRegression`` with the ``predict`` column as the target
        and every other column as features. The fitted model is pickled to
        ``<storage.model_dir>/<function name>``, and that path is appended to
        the node metadata as ``model_path``.

        Returns:
            tuple: (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_sklearn()
        from sklearn.linear_model import LinearRegression

        aggregated_batch = self._aggregate_child_batches("sklearn")

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        model = LinearRegression()
        Y = aggregated_batch.frames[arg_map["predict"]]
        aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
        model.fit(X=aggregated_batch.frames, y=Y)
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        # Use a context manager so the file handle is always closed.
        with open(model_path, "wb") as model_file:
            pickle.dump(model, model_file)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_ultralytics_function(self):
        """Handle Ultralytics functions

        Registers the bundled ``yolo_object_detector.py`` implementation,
        initializing it once with the node's metadata properties.

        Returns:
            tuple: (name, impl_path, function_type, io_list, metadata).
        """
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.function_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_function_io(function)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_function(self):
        """Handle forecasting functions

        Fits a StatsForecast model on the child plan's rows. The pickled model
        is cached at ``<storage.model_dir>/<function name>/<sha256 of data>.pkl``,
        and fitting is skipped when that file already exists.

        Returns:
            tuple: (name, impl_path, function_type, io_list, metadata), where
            metadata is a fresh list and not ``self.node.metadata``.

        Raises:
            RuntimeError: If no frequency is given and none can be inferred.
        """
        aggregated_batch = self._aggregate_child_batches("forecasting")

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        model_name = arg_map.setdefault("model", "AutoARIMA")

        # statsforecast requires the following column names:
        # - unique_id (string, int or category): identifier for the series.
        # - ds (datestamp): a format Pandas understands, ideally YYYY-MM-DD for a
        #   date or YYYY-MM-DD HH:MM:SS for a timestamp.
        # - y (numeric): the measurement we wish to forecast.
        # See: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        if "unique_id" not in data.columns:
            # Treat all rows as a single series.
            data["unique_id"] = [1] * len(data)

        if "ds" not in data.columns:
            # Fall back to a 1-based row counter as the time axis.
            data["ds"] = list(range(1, len(data) + 1))

        if "frequency" not in arg_map:
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explictly set it."
            )

        try_to_import_forecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        # Shorten anchored aliases such as "Q-DEC" to their base ("Q").
        new_freq = frequency.split("-")[0]
        season_length = self._FORECAST_SEASON_LENGTHS.get(new_freq, 1)
        model = StatsForecast(
            [model_dict[model_name](season_length=season_length)], freq=new_freq
        )

        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        # Key the cached model on the training data so identical data reuses it.
        data_digest = hashlib.sha256(data.to_string().encode()).hexdigest()
        model_path = os.path.join(model_dir, data_digest + ".pkl")

        weight_file = Path(model_path)
        data["ds"] = pd.to_datetime(data["ds"])
        if not weight_file.exists():
            model.fit(data)
            with open(model_path, "wb") as model_file:
                pickle.dump(model, model_file)

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", model_name),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )

    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.

        Returns:
            tuple: (name, impl_path, function_type, io_list, metadata).
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        function = self._try_initializing_function(impl_path)
        io_list = self._resolve_function_io(function)

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry.

        Yields:
            Batch: A single-row batch carrying a status message.

        Raises:
            RuntimeError: If the function already exists and neither
                IF NOT EXISTS nor OR REPLACE was given.
        """
        # Works for any truthy/falsy flags (e.g. None), not only bools.
        assert not (
            self.node.if_not_exists and self.node.or_replace
        ), "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."

        overwrite = False
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError as e:
                    # Best-effort drop: keep going, but do not hide the failure.
                    logger.warning(
                        f"Failed to drop existing function {self.node.name}: {e}"
                    )
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Pick the handler for the (case-insensitive) function type; anything
        # unrecognized is treated as a generic file-backed function.
        handlers = (
            ("HuggingFace", self.handle_huggingface_function),
            ("ultralytics", self.handle_ultralytics_function),
            ("Ludwig", self.handle_ludwig_function),
            ("Sklearn", self.handle_sklearn_function),
            ("Forecasting", self.handle_forecasting_function),
        )
        handler = next(
            (
                handle
                for type_name, handle in handlers
                if string_comparison_case_insensitive(self.node.function_type, type_name)
            ),
            self.handle_generic_function,
        )
        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))

    def _try_initializing_function(
        self, impl_path: str, function_args: Optional[Dict] = None
    ) -> type:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Keyword arguments passed to the
                function's constructor. Defaults to no arguments.

        Returns:
            type: The function class loaded from ``impl_path``. It is the class
            itself, not the throwaway instance built to validate it.

        Raises:
            RuntimeError: If an error occurs while loading or initializing the function.
        """
        # Avoid a shared mutable default argument.
        if function_args is None:
            function_args = {}

        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function {self.node.name}: {str(e)}"
            raise RuntimeError(err_msg) from e

        return function

    def _resolve_function_io(
        self, function: Optional[type]
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.

        Inputs and outputs given in the CREATE statement take precedence. Each
        one that is missing is loaded from the function's decorators; if that
        also fails, an error is raised.

        Args:
            function (type, optional): The function class whose decorators are
                consulted. Handlers that expect the CREATE statement to supply
                IO pass ``None``.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function (inputs first, then outputs).

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        try:
            inputs = self.node.inputs or load_io_from_function_decorators(
                function, is_input=True
            )
            outputs = self.node.outputs or load_io_from_function_decorators(
                function, is_input=False
            )
        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg) from e

        # Build a fresh list so callers never alias the node's own lists.
        return [*inputs, *outputs]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc