• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / 042c0bf0-bd18-4ca0-adb3-9a9fb37193b5

06 Sep 2023 06:49PM UTC coverage: 80.573% (+9.6%) from 70.955%
042c0bf0-bd18-4ca0-adb3-9a9fb37193b5

push

circle-ci

gaurav274
merge + minor fixes

768 of 768 new or added lines in 95 files covered. (100.0%)

9369 of 11628 relevant lines covered (80.57%)

1.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.68
/evadb/executor/create_function_executor.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import hashlib
2✔
16
import os
2✔
17
import pickle
2✔
18
from pathlib import Path
2✔
19
from typing import Dict, List
2✔
20

21
import pandas as pd
2✔
22

23
from evadb.catalog.catalog_utils import get_metadata_properties
2✔
24
from evadb.catalog.models.function_catalog import FunctionCatalogEntry
2✔
25
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
2✔
26
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
2✔
27
from evadb.configuration.constants import (
2✔
28
    DEFAULT_TRAIN_TIME_LIMIT,
29
    EvaDB_INSTALLATION_DIR,
30
)
31
from evadb.database import EvaDBDatabase
2✔
32
from evadb.executor.abstract_executor import AbstractExecutor
2✔
33
from evadb.functions.decorators.utils import load_io_from_function_decorators
2✔
34
from evadb.models.storage.batch import Batch
2✔
35
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
2✔
36
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
2✔
37
from evadb.utils.errors import FunctionIODefinitionError
2✔
38
from evadb.utils.generic_utils import (
2✔
39
    load_function_class_from_file,
40
    try_to_import_forecast,
41
    try_to_import_ludwig,
42
    try_to_import_torch,
43
    try_to_import_ultralytics,
44
)
45
from evadb.utils.logging_manager import logger
2✔
46

47

48
class CreateFunctionExecutor(AbstractExecutor):
2✔
49
    def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan):
        """Set up the executor and record where bundled function files live.

        Args:
            db (EvaDBDatabase): Database handle forwarded to the base executor.
            node (CreateFunctionPlan): Plan node describing the function to create.
        """
        super().__init__(db, node)
        # Bundled implementations (yolo, ludwig, forecast, ...) are looked up
        # under "<installation dir>/functions" by the handlers of this class.
        installation_root = Path(EvaDB_INSTALLATION_DIR)
        self.function_dir = installation_root.joinpath("functions")
52

53
    def handle_huggingface_function(self):
        """Build the catalog arguments for a HuggingFace function.

        HuggingFace functions are special: they are not loaded from a file, so
        the function class is not instantiated here (and therefore no setup
        call happens, unlike for the other function types). The implementation
        path always points at the shared abstract HuggingFace function file.

        Returns:
            A tuple ``(name, impl_path, function_type, io_list, metadata)`` in
            the order expected by the catalog insertion call in ``exec``.
        """
        # HuggingFace needs at least one deep learning framework
        # (Torch or Tensorflow); Torch is the one checked for here.
        try_to_import_torch()

        hf_impl_file = f"{self.function_dir}/abstract/hf_abstract_function.py"
        hf_io_entries = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)

        plan = self.node
        return (
            plan.name,
            hf_impl_file,
            plan.function_type,
            hf_io_entries,
            plan.metadata,
        )
71

72
    def handle_ludwig_function(self):
        """Train a Ludwig model and build the catalog arguments for it.

        Uses Ludwig's ``auto_train`` engine to train/tune a model on the data
        produced by the single child executor. The best model is saved under
        ``<model_dir>/<function name>`` and that location is appended to the
        plan node's metadata under the key ``model_path``.

        Metadata keys read from the plan node:
            predict: name of the target column (required).
            tune_for_memory: forwarded to ``auto_train``; defaults to False.
            time_limit: forwarded as ``time_limit_s``; defaults to
                ``DEFAULT_TRAIN_TIME_LIMIT``.

        Returns:
            A tuple ``(name, impl_path, function_type, io_list, metadata)``.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig function expects 1 child, finds {}.".format(
            len(self.children)
        )

        # Drain the child executor and merge everything into one training frame.
        training_batches = list(self.children[0].exec())
        training_batch = Batch.concat(training_batches, copy=False)
        training_batch.drop_column_alias()

        options = {entry.key: entry.value for entry in self.node.metadata}
        training_result = auto_train(
            dataset=training_batch.frames,
            target=options["predict"],
            tune_for_memory=options.get("tune_for_memory", False),
            time_limit_s=options.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )

        # Persist the best model and record its location in the metadata that
        # is returned to the caller.
        model_root = self.db.config.get_value("storage", "model_dir")
        saved_model_path = os.path.join(model_root, self.node.name)
        training_result.best_model.save(saved_model_path)
        self.node.metadata.append(
            FunctionMetadataCatalogEntry("model_path", saved_model_path)
        )

        ludwig_impl_file = Path(f"{self.function_dir}/ludwig.py").absolute().as_posix()
        # No function class is loaded for Ludwig, so None is passed through.
        io_entries = self._resolve_function_io(None)
        return (
            self.node.name,
            ludwig_impl_file,
            self.node.function_type,
            io_entries,
            self.node.metadata,
        )
118

119
    def handle_ultralytics_function(self):
        """Build the catalog arguments for an Ultralytics function.

        The bundled YOLO object detector implementation is loaded and
        instantiated with the properties derived from the plan node's metadata,
        and its input/output definitions are resolved.

        Returns:
            A tuple ``(name, impl_path, function_type, io_list, metadata)``.

        Raises:
            RuntimeError: If the detector cannot be initialized or its
                input/output definitions cannot be resolved.
        """
        try_to_import_ultralytics()

        detector_file = Path(f"{self.function_dir}/yolo_object_detector.py")
        yolo_impl_path = detector_file.absolute().as_posix()

        init_kwargs = get_metadata_properties(self.node)
        function_cls = self._try_initializing_function(
            yolo_impl_path, function_args=init_kwargs
        )
        io_entries = self._resolve_function_io(function_cls)

        return (
            self.node.name,
            yolo_impl_path,
            self.node.function_type,
            io_entries,
            self.node.metadata,
        )
137

138
    def handle_forecasting_function(self):
        """Train (or reuse) a statsforecast model and build its catalog arguments.

        All batches produced by the first child executor are concatenated into
        one training frame. The columns named by the ``predict``, ``time`` and
        ``id`` metadata keys are renamed to ``y``, ``ds`` and ``unique_id``
        respectively. When no ``unique_id`` / ``ds`` column is present after
        renaming, a constant id (``"test"``) and a 1-based running index are
        synthesized.

        The fitted model is pickled to
        ``<model_dir>/<function name>/<sha256 of the training data>.pkl``. If
        that file already exists the fit is skipped and the file is reused.

        Metadata keys read from the plan node:
            predict: column to forecast (required).
            time: column holding the time stamps (optional).
            id: column identifying the individual series (optional).
            model: one of ``AutoARIMA``, ``AutoCES``, ``AutoETS``,
                ``AutoTheta``; defaults to ``AutoARIMA``.
            frequency: pandas offset alias; defaults to ``M``.

        Returns:
            A tuple ``(name, impl_path, function_type, io_list, metadata)``
            where ``metadata`` holds the ``model_name`` and ``model_path``
            entries for the trained model.

        Raises:
            KeyError: If ``predict`` is missing from the metadata or ``model``
                is not one of the supported model names.
            RuntimeError: If the forecasting function cannot be initialized or
                its input/output definitions cannot be resolved.
        """
        # Drain the child executor and merge everything into one frame.
        child = self.children[0]
        aggregated_batch = Batch.concat(list(child.exec()), copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.function_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        model_name = arg_map.get("model", "AutoARIMA")
        frequency = arg_map.get("frequency", "M")

        # DataFrame.rename returns a new frame, so every rename has to end up
        # in `data`. Collect them and apply them in a single call; otherwise
        # the user supplied time/id columns would be ignored and replaced by
        # the synthetic defaults below.
        column_renames = {arg_map["predict"]: "y"}
        if "time" in arg_map:
            column_renames[arg_map["time"]] = "ds"
        if "id" in arg_map:
            column_renames[arg_map["id"]] = "unique_id"
        data = aggregated_batch.frames.rename(columns=column_renames)

        if "unique_id" not in data.columns:
            data["unique_id"] = ["test"] * len(data)

        if "ds" not in data.columns:
            data["ds"] = list(range(1, len(data) + 1))

        try_to_import_forecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        # Season length per frequency alias, see
        # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
        season_dict = {
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        # Shorten anchored frequencies such as "Q-DEC" to their base alias.
        new_freq = frequency.split("-")[0]
        season_length = season_dict.get(new_freq, 1)
        model = StatsForecast(
            [model_dict[model_name](season_length=season_length)], freq=new_freq
        )

        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        # The file name is derived from the training data, so identical data
        # maps to the same pickle and the fit can be skipped.
        data_digest = hashlib.sha256(data.to_string().encode()).hexdigest()
        model_path = os.path.join(model_dir, data_digest + ".pkl")

        if not Path(model_path).exists():
            model.fit(data)
            # The context manager closes the file even if pickling fails.
            with open(model_path, "wb") as model_file:
                pickle.dump(model, model_file)

        arg_map_here = {"model_name": model_name, "model_path": model_path}
        function = self._try_initializing_function(impl_path, arg_map_here)
        io_list = self._resolve_function_io(function)

        metadata_here = [
            FunctionMetadataCatalogEntry(
                key="model_name",
                value=model_name,
                function_id=None,
                function_name=None,
                row_id=None,
            ),
            FunctionMetadataCatalogEntry(
                key="model_path",
                value=model_path,
                function_id=None,
                function_name=None,
                row_id=None,
            ),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.function_type,
            io_list,
            metadata_here,
        )
250

251
    def handle_generic_function(self):
        """Build the catalog arguments for a generic, file-backed function.

        Generic functions are loaded from the implementation file given on the
        plan node. Input/output definitions passed by the user during CREATE
        are used when present; otherwise they are loaded from the function's
        decorators.

        Returns:
            A tuple ``(name, impl_path, function_type, io_list, metadata)``.

        Raises:
            RuntimeError: If the function cannot be initialized or its
                input/output definitions cannot be resolved.
        """
        resolved_impl_path = self.node.impl_path.absolute().as_posix()
        function_cls = self._try_initializing_function(resolved_impl_path)
        io_entries = self._resolve_function_io(function_cls)

        plan = self.node
        return (
            plan.name,
            resolved_impl_path,
            plan.function_type,
            io_entries,
            plan.metadata,
        )
267

268
    def exec(self, *args, **kwargs):
        """Create function executor.

        Checks the catalog for an existing function with the same name, asks
        the handler matching the plan node's function type for the catalog
        arguments, and calls the catalog to insert a function catalog entry.

        Yields:
            Batch: A single batch holding a one-cell status message.

        Raises:
            RuntimeError: If the function already exists and the plan node does
                not have ``if_not_exists`` set.
        """
        # check catalog if it already has this function entry
        if self.catalog().get_function_catalog_entry_by_name(self.node.name):
            if not self.node.if_not_exists:
                msg = f"Function {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)
            msg = f"Function {self.node.name} already exists, nothing added."
            yield Batch(pd.DataFrame([msg]))
            return

        # Function types with dedicated handling; every other type is treated
        # as a generic, file-backed function.
        specialized_handlers = {
            "HuggingFace": self.handle_huggingface_function,
            "ultralytics": self.handle_ultralytics_function,
            "Ludwig": self.handle_ludwig_function,
            "Forecasting": self.handle_forecasting_function,
        }
        build_catalog_args = specialized_handlers.get(
            self.node.function_type, self.handle_generic_function
        )
        name, impl_path, function_type, io_list, metadata = build_catalog_args()

        self.catalog().insert_function_catalog_entry(
            name, impl_path, function_type, io_list, metadata
        )
        status = f"Function {self.node.name} successfully added to the database."
        yield Batch(pd.DataFrame([status]))
334

335
    def _try_initializing_function(
        self, impl_path: str, function_args: Dict = None
    ) -> FunctionCatalogEntry:
        """Attempts to initialize function given the implementation file path and arguments.

        The function class is loaded from ``impl_path`` and instantiated once
        as a smoke test. The instance is discarded; the loaded class itself is
        what gets returned.

        Args:
            impl_path (str): The file path of the function implementation file.
            function_args (Dict, optional): Keyword arguments to pass to the
                function's constructor. Defaults to None, which means no
                arguments.

        Returns:
            The object returned by ``load_function_class_from_file`` for this
            function, i.e. the loaded function class (not an instance).

        Raises:
            RuntimeError: If an error occurs while loading or initializing the
                function. The original exception is attached as the cause.
        """
        # None stands in for "no arguments" so that no mutable default object
        # is shared between calls.
        init_kwargs = {} if function_args is None else function_args

        try:
            # loading the function class from the file
            function = load_function_class_from_file(impl_path, self.node.name)
            # initializing the function class calls the setup method internally
            function(**init_kwargs)
        except Exception as e:
            # The failure is not logged here (the logger call was disabled in
            # the original code); it surfaces through the raised exception.
            err_msg = f"Error creating function: {str(e)}"
            raise RuntimeError(err_msg) from e

        return function
363

364
    def _resolve_function_io(
        self, function: FunctionCatalogEntry
    ) -> List[FunctionIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given function.

        Definitions given in the CREATE statement take precedence. Whichever
        side (inputs or outputs) is missing there is loaded from the function's
        decorators instead. Inputs always come before outputs in the result.

        Args:
            function (FunctionCatalogEntry): The function for which to resolve
                input and output definitions; it is only handed to the
                decorator loader.

        Returns:
            A List of FunctionIOCatalogEntry objects that represent the resolved
            input and output definitions for the function.

        Raises:
            RuntimeError: If a FunctionIODefinitionError occurs while resolving
                the function input/output definitions.
        """
        try:
            # inputs first: CREATE statement, else decorators
            input_entries = self.node.inputs
            if not input_entries:
                input_entries = load_io_from_function_decorators(
                    function, is_input=True
                )
            resolved = list(input_entries)

            # then outputs, with the same precedence
            output_entries = self.node.outputs
            if not output_entries:
                output_entries = load_io_from_function_decorators(
                    function, is_input=False
                )
            resolved += output_entries
        except FunctionIODefinitionError as e:
            err_msg = (
                f"Error creating function, input/output definition incorrect: {str(e)}"
            )
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        return resolved
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc