• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / #788

18 Sep 2023 07:27AM UTC coverage: 67.606% (-12.5%) from 80.067%
#788

push

circle-ci

xzdandy
Add integration test

8089 of 11965 relevant lines covered (67.61%)

0.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

38.56
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
import os
import pickle
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
    DEFAULT_TRAIN_TIME_LIMIT,
    EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.functions.decorators.utils import load_io_from_function_decorators
from evadb.models.storage.batch import Batch
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
from evadb.utils.errors import FunctionIODefinitionError
from evadb.utils.generic_utils import (
    load_function_class_from_file,
    string_comparison_case_insensitive,
    try_to_import_forecast,
    try_to_import_ludwig,
    try_to_import_torch,
    try_to_import_ultralytics,
)
from evadb.utils.logging_manager import logger
47

48

49
class CreateFunctionExecutor(AbstractExecutor):
    """Executor that materializes a CREATE FUNCTION plan into a catalog entry."""

    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        super().__init__(db, node)
        # Built-in function implementations ship under <installation>/functions.
        self.function_dir = Path(EvaDB_INSTALLATION_DIR, "functions")

54
    def handle_huggingface_function(self):
        """Prepare catalog information for a HuggingFace function.

        HuggingFace functions are special functions that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other functions.
        """
        # HuggingFace needs at least one deep learning backend (Torch or Tensorflow).
        try_to_import_torch()
        node = self.node
        entry_path = f"{self.function_dir}/abstract/hf_abstract_function.py"
        catalog_ios = gen_hf_io_catalog_entries(node.name, node.metadata)
        return (node.name, entry_path, node.function_type, catalog_ios, node.metadata)
72

73
    def handle_ludwig_function(self):
        """Train a model via Ludwig's auto_train engine and register it.

        Use Ludwig's auto_train engine to train/tune models.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        n_children = len(self.children)
        assert (
            n_children == 1
        ), f"Create ludwig function expects 1 child, finds {n_children}."

        # Materialize the child's output into a single training frame.
        training_batch = Batch.concat(list(self.children[0].exec()), copy=False)
        training_batch.drop_column_alias()

        options = {arg.key: arg.value for arg in self.node.metadata}
        train_results = auto_train(
            dataset=training_batch.frames,
            target=options["predict"],
            tune_for_memory=options.get("tune_for_memory", False),
            time_limit_s=options.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        train_results.best_model.save(model_path)
        # Persist the model location so the runtime function can find it later.
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", model_path)
        )

        impl_path = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_function_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            self.node.metadata,
        )
119

120
    def handle_ultralytics_function(self):
        """Handle Ultralytics functions"""
        try_to_import_ultralytics()

        yolo_file = Path(f"{self.function_dir}/yolo_object_detector.py")
        impl_path = yolo_file.absolute().as_posix()
        # Instantiating the function validates the metadata arguments early.
        function = self._try_initializing_function(
            impl_path, function_args=get_metadata_properties(self.node)
        )
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            self._resolve_function_io(function),
            self.node.metadata,
        )
138

139
    def handle_forecasting_function(self):
        """Handle forecasting functions.

        Aggregates the child plan's output into a training frame, renames the
        columns to the schema statsforecast expects, fits a model (or reuses a
        previously pickled one keyed by a hash of the training data), and
        returns the (name, impl_path, function_type, io_list, metadata) tuple
        for catalog insertion.

        Raises:
            RuntimeError: If the time-series frequency is neither given nor
                inferable from the data.
        """
        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        # Default to AutoARIMA when the user does not pick a model.
        model_name = arg_map.setdefault("model", "AutoARIMA")

        # The following rename is needed for statsforecast, which requires the
        # column names to be the following:
        # - The unique_id (string, int or category) represents an identifier for the series.
        # - The ds (datestamp) column should be of a format expected by Pandas, ideally
        #   YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp.
        # - The y (numeric) represents the measurement we wish to forecast.
        # For reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html
        aggregated_batch.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            aggregated_batch.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})

        data = aggregated_batch.frames
        if "unique_id" not in list(data.columns):
            # Single series: use a constant identifier for every row.
            data["unique_id"] = [1 for x in range(len(data))]

        if "ds" not in list(data.columns):
            # No timestamp column supplied: fall back to a 1-based ordinal index.
            data["ds"] = [x + 1 for x in range(len(data))]

        if "frequency" not in arg_map:
            # NOTE(review): pd.infer_freq expects datetime-like input; presumably
            # "frequency" is always supplied when ds is the ordinal fallback — confirm.
            arg_map["frequency"] = pd.infer_freq(data["ds"])
        frequency = arg_map["frequency"]
        if frequency is None:
            raise RuntimeError(
                f"Can not infer the frequency for {self.node.name}. Please explicitly set it."
            )

        try_to_import_forecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        new_freq = (
            frequency.split("-")[0] if "-" in frequency else frequency
        )  # shortens longer frequencies like Q-DEC
        season_length = season_dict.get(new_freq, 1)
        model = StatsForecast(
            [model_dict[model_name](season_length=season_length)], freq=new_freq
        )

        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        # Cache key is a hash of the training data, so identical data reuses
        # the previously fitted model.
        model_path = os.path.join(
            model_dir,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()) + ".pkl",
        )

        weight_file = Path(model_path)
        data["ds"] = pd.to_datetime(data["ds"])
        if not weight_file.exists():
            model.fit(data)
            # Context manager guarantees the handle is closed even if pickling fails.
            with open(model_path, "wb") as f:
                pickle.dump(model, f)

        io_list = self._resolve_function_io(None)

        metadata_here = [
            FunctionMetadataCatalogEntry("model_name", model_name),
            FunctionMetadataCatalogEntry("model_path", model_path),
            FunctionMetadataCatalogEntry(
                "predict_column_rename", arg_map.get("predict", "y")
            ),
            FunctionMetadataCatalogEntry(
                "time_column_rename", arg_map.get("time", "ds")
            ),
            FunctionMetadataCatalogEntry(
                "id_column_rename", arg_map.get("id", "unique_id")
            ),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )
258

259
    def handle_generic_function(self):
        """Handle generic functions

        Generic functions are loaded from a file. We check for inputs passed
        by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        # Instantiation validates the implementation before IO resolution.
        loaded = self._try_initializing_function(impl_path)
        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            self._resolve_function_io(loaded),
            self.node.metadata,
        )
275

276
    def exec(self, *args, **kwargs):
        """Create function executor

        Calls the catalog to insert a function catalog entry.
        """
        assert (
            self.node.if_not_exists and self.node.or_replace
        ) is False, (
            "OR REPLACE and IF NOT EXISTS can not be both set for CREATE FUNCTION."
        )

        overwrite = False
        # Check the catalog for a pre-existing function with the same name.
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            elif self.node.or_replace:
                # We use DropObjectExecutor to avoid bookkeeping the code. The drop function should be moved to catalog.
                from evadb.executor.drop_object_executor import DropObjectExecutor

                drop_executor = DropObjectExecutor(self.db, None)
                try:
                    drop_executor._handle_drop_function(self.node.name, if_exists=False)
                except RuntimeError:
                    pass
                else:
                    overwrite = True
            else:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # Route to the type-specific handler; generic functions are the default.
        # HuggingFace/Ultralytics/Ludwig/Forecasting handlers also override impl_path.
        typed_handlers = (
            ("HuggingFace", self.handle_huggingface_function),
            ("ultralytics", self.handle_ultralytics_function),
            ("Ludwig", self.handle_ludwig_function),
            ("Forecasting", self.handle_forecasting_function),
        )
        handler = self.handle_generic_function
        for type_name, candidate in typed_handlers:
            if string_comparison_case_insensitive(self.node.function_type, type_name):
                handler = candidate
                break
        name, impl_path, function_type, io_list, metadata = handler()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )

        if overwrite:
            msg = f"Function {self.node.name} overwritten."
        else:
            msg = f"Function {self.node.name} added to the database."
        yield Batch(pd.DataFrame([msg]))
361

362
    def _try_initializing_function(
        self, impl_path: str, function_args: Optional[Dict] = None
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Dictionary of arguments to pass to the function. Defaults to None, treated as no arguments.

        Returns:
            FunctionCatalogEntry: A FunctionCatalogEntry object that represents the initialized function.

        Raises:
            RuntimeError: If an error occurs while initializing the function.
        """
        # Avoid the shared-mutable-default pitfall: build a fresh dict per call.
        # Backward-compatible: callers that passed nothing get the same behavior.
        if function_args is None:
            function_args = {}

        # load the function class from the file
        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**function_args)
        except Exception as e:
            err_msg = f"Error creating function: {str(e)}"
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(err_msg) from e

        return function
390

391
    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.

        It first searches for the input/outputs in the CREATE statement. If not
        found, it resolves them using decorators. If not found there as well,
        it raises an error.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve input and output definitions.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved input and
            output definitions for the function.

        Raises:
            RuntimeError: If an error occurs while resolving the function input/output
            definitions.
        """
        io_list = []
        try:
            # Inputs/outputs from the CREATE statement take precedence over
            # those declared with decorators on the function class.
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=True)
                )

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                io_list.extend(
                    load_io_from_function_decorators(function, is_input=False)
                )

        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            # Chain the original error so the definition failure stays visible.
            raise RuntimeError(err_msg) from e

        return io_list
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc