georgia-tech-db / eva / b7e09b55-9333-4c49-b273-87a69e5c463f

05 Sep 2023 11:42PM UTC coverage: 74.515% (-19.0%) from 93.55%

Pull Request #1050: feat: sync master staging (circle-ci)
jiashenC: fix: add missing needed file (#1046)

768 of 768 new or added lines in 96 files covered (100.0%)
8757 of 11752 relevant lines covered (74.51%)
0.75 hits per line

Source File: /evadb/executor/create_udf_executor.py (0.0% covered)
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import pickle
from pathlib import Path
from typing import Dict, List

import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
from evadb.catalog.models.udf_catalog import UdfCatalogEntry
from evadb.catalog.models.udf_io_catalog import UdfIOCatalogEntry
from evadb.catalog.models.udf_metadata_catalog import UdfMetadataCatalogEntry
from evadb.configuration.constants import (
    DEFAULT_TRAIN_TIME_LIMIT,
    EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.models.storage.batch import Batch
from evadb.plan_nodes.create_udf_plan import CreateUDFPlan
from evadb.third_party.huggingface.create import gen_hf_io_catalog_entries
from evadb.udfs.decorators.utils import load_io_from_udf_decorators
from evadb.utils.errors import UDFIODefinitionError
from evadb.utils.generic_utils import (
    load_udf_class_from_file,
    try_to_import_forecast,
    try_to_import_ludwig,
    try_to_import_torch,
    try_to_import_ultralytics,
)
from evadb.utils.logging_manager import logger


class CreateUDFExecutor(AbstractExecutor):
    def __init__(self, db: EvaDBDatabase, node: CreateUDFPlan):
        super().__init__(db, node)
        self.udf_dir = Path(EvaDB_INSTALLATION_DIR) / "udfs"

    def handle_huggingface_udf(self):
        """Handle HuggingFace UDFs

        HuggingFace UDFs are special UDFs that are not loaded from a file.
        So we do not need to call the setup method on them like we do for other UDFs.
        """
        # We need at least one deep learning framework for HuggingFace
        # Torch or Tensorflow
        try_to_import_torch()
        impl_path = f"{self.udf_dir}/abstract/hf_abstract_udf.py"
        io_list = gen_hf_io_catalog_entries(self.node.name, self.node.metadata)
        return (
            self.node.name,
            impl_path,
            self.node.udf_type,
            io_list,
            self.node.metadata,
        )

    def handle_ludwig_udf(self):
        """Handle ludwig UDFs

        Use Ludwig's auto_train engine to train/tune models.
        """
        try_to_import_ludwig()
        from ludwig.automl import auto_train

        assert (
            len(self.children) == 1
        ), "Create ludwig UDF expects 1 child, finds {}.".format(len(self.children))

        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        auto_train_results = auto_train(
            dataset=aggregated_batch.frames,
            target=arg_map["predict"],
            tune_for_memory=arg_map.get("tune_for_memory", False),
            time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
            output_directory=self.db.config.get_value("storage", "tmp_dir"),
        )
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        auto_train_results.best_model.save(model_path)
        self.node.metadata.append(UdfMetadataCatalogEntry("model_path", model_path))

        impl_path = Path(f"{self.udf_dir}/ludwig.py").absolute().as_posix()
        io_list = self._resolve_udf_io(None)
        return (
            self.node.name,
            impl_path,
            self.node.udf_type,
            io_list,
            self.node.metadata,
        )

    def handle_ultralytics_udf(self):
        """Handle Ultralytics UDFs"""
        try_to_import_ultralytics()

        impl_path = (
            Path(f"{self.udf_dir}/yolo_object_detector.py").absolute().as_posix()
        )
        udf = self._try_initializing_udf(
            impl_path, udf_args=get_metadata_properties(self.node)
        )
        io_list = self._resolve_udf_io(udf)
        return (
            self.node.name,
            impl_path,
            self.node.udf_type,
            io_list,
            self.node.metadata,
        )

    def handle_forecasting_udf(self):
        """Handle forecasting UDFs"""
        aggregated_batch_list = []
        child = self.children[0]
        for batch in child.exec():
            aggregated_batch_list.append(batch)
        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
        aggregated_batch.drop_column_alias()

        arg_map = {arg.key: arg.value for arg in self.node.metadata}
        if not self.node.impl_path:
            impl_path = Path(f"{self.udf_dir}/forecast.py").absolute().as_posix()
        else:
            impl_path = self.node.impl_path.absolute().as_posix()

        if "model" not in arg_map:
            arg_map["model"] = "AutoARIMA"
        if "frequency" not in arg_map:
            arg_map["frequency"] = "M"

        model_name = arg_map["model"]
        frequency = arg_map["frequency"]

        # statsforecast expects columns named y, ds, and unique_id
        data = aggregated_batch.frames.rename(columns={arg_map["predict"]: "y"})
        if "time" in arg_map:
            data = data.rename(columns={arg_map["time"]: "ds"})
        if "id" in arg_map:
            data = data.rename(columns={arg_map["id"]: "unique_id"})

        if "unique_id" not in list(data.columns):
            data["unique_id"] = ["test" for x in range(len(data))]

        if "ds" not in list(data.columns):
            data["ds"] = [x + 1 for x in range(len(data))]

        try_to_import_forecast()
        from statsforecast import StatsForecast
        from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta

        model_dict = {
            "AutoARIMA": AutoARIMA,
            "AutoCES": AutoCES,
            "AutoETS": AutoETS,
            "AutoTheta": AutoTheta,
        }

        season_dict = {  # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases
            "H": 24,
            "M": 12,
            "Q": 4,
            "SM": 24,
            "BM": 12,
            "BMS": 12,
            "BQ": 4,
            "BH": 24,
        }

        new_freq = (
            frequency.split("-")[0] if "-" in frequency else frequency
        )  # shortens longer frequencies like Q-DEC
        season_length = season_dict[new_freq] if new_freq in season_dict else 1
        model = StatsForecast(
            [model_dict[model_name](season_length=season_length)], freq=new_freq
        )

        model_dir = os.path.join(
            self.db.config.get_value("storage", "model_dir"), self.node.name
        )
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        model_path = os.path.join(
            self.db.config.get_value("storage", "model_dir"),
            self.node.name,
            str(hashlib.sha256(data.to_string().encode()).hexdigest()) + ".pkl",
        )

        weight_file = Path(model_path)

        if not weight_file.exists():
            model.fit(data)
            with open(model_path, "wb") as f:
                pickle.dump(model, f)

        arg_map_here = {"model_name": model_name, "model_path": model_path}
        udf = self._try_initializing_udf(impl_path, arg_map_here)
        io_list = self._resolve_udf_io(udf)

        metadata_here = [
            UdfMetadataCatalogEntry(
                key="model_name",
                value=model_name,
                udf_id=None,
                udf_name=None,
                row_id=None,
            ),
            UdfMetadataCatalogEntry(
                key="model_path",
                value=model_path,
                udf_id=None,
                udf_name=None,
                row_id=None,
            ),
        ]

        return (
            self.node.name,
            impl_path,
            self.node.udf_type,
            io_list,
            metadata_here,
        )

    def handle_generic_udf(self):
        """Handle generic UDFs

        Generic UDFs are loaded from a file. We check for inputs passed by the user during CREATE or try to load io from decorators.
        """
        impl_path = self.node.impl_path.absolute().as_posix()
        udf = self._try_initializing_udf(impl_path)
        io_list = self._resolve_udf_io(udf)

        return (
            self.node.name,
            impl_path,
            self.node.udf_type,
            io_list,
            self.node.metadata,
        )

    def exec(self, *args, **kwargs):
        """Create udf executor

        Calls the catalog to insert a udf catalog entry.
        """
        # check catalog if it already has this udf entry
        if self.catalog().get_udf_catalog_entry_by_name(self.node.name):
            if self.node.if_not_exists:
                msg = f"UDF {self.node.name} already exists, nothing added."
                yield Batch(pd.DataFrame([msg]))
                return
            else:
                msg = f"UDF {self.node.name} already exists."
                logger.error(msg)
                raise RuntimeError(msg)

        # if it's a type of HuggingFaceModel, override the impl_path
        if self.node.udf_type == "HuggingFace":
            name, impl_path, udf_type, io_list, metadata = self.handle_huggingface_udf()
        elif self.node.udf_type == "ultralytics":
            name, impl_path, udf_type, io_list, metadata = self.handle_ultralytics_udf()
        elif self.node.udf_type == "Ludwig":
            name, impl_path, udf_type, io_list, metadata = self.handle_ludwig_udf()
        elif self.node.udf_type == "Forecasting":
            name, impl_path, udf_type, io_list, metadata = self.handle_forecasting_udf()
        else:
            name, impl_path, udf_type, io_list, metadata = self.handle_generic_udf()

        self.catalog().insert_udf_catalog_entry(
            name, impl_path, udf_type, io_list, metadata
        )
        yield Batch(
            pd.DataFrame([f"UDF {self.node.name} successfully added to the database."])
        )

    def _try_initializing_udf(
        self, impl_path: str, udf_args: Dict = {}
    ) -> UdfCatalogEntry:
        """Attempts to initialize UDF given the implementation file path and arguments.

        Args:
            impl_path (str): The file path of the UDF implementation file.
            udf_args (Dict, optional): Dictionary of arguments to pass to the UDF. Defaults to {}.

        Returns:
            UdfCatalogEntry: A UdfCatalogEntry object that represents the initialized UDF.

        Raises:
            RuntimeError: If an error occurs while initializing the UDF.
        """

        # load the udf class from the file
        try:
            # loading the udf class from the file
            udf = load_udf_class_from_file(impl_path, self.node.name)
            # initializing the udf class calls the setup method internally
            udf(**udf_args)
        except Exception as e:
            err_msg = f"Error creating UDF: {str(e)}"
            # logger.error(err_msg)
            raise RuntimeError(err_msg)

        return udf

    def _resolve_udf_io(self, udf: UdfCatalogEntry) -> List[UdfIOCatalogEntry]:
        """Private method that resolves the input/output definitions for a given UDF.
        It first searches for the input/outputs in the CREATE statement. If not found, it resolves them using decorators. If not found there as well, it raises an error.

        Args:
            udf (UdfCatalogEntry): The UDF for which to resolve input and output definitions.

        Returns:
            A List of UdfIOCatalogEntry objects that represent the resolved input and
            output definitions for the UDF.

        Raises:
            RuntimeError: If an error occurs while resolving the UDF input/output
            definitions.
        """
        io_list = []
        try:
            if self.node.inputs:
                io_list.extend(self.node.inputs)
            else:
                # try to load the inputs from decorators, the inputs from CREATE statement take precedence
                io_list.extend(load_io_from_udf_decorators(udf, is_input=True))

            if self.node.outputs:
                io_list.extend(self.node.outputs)
            else:
                # try to load the outputs from decorators, the outputs from CREATE statement take precedence
                io_list.extend(load_io_from_udf_decorators(udf, is_input=False))

        except UDFIODefinitionError as e:
            err_msg = f"Error creating UDF, input/output definition incorrect: {str(e)}"
            logger.error(err_msg)
            raise RuntimeError(err_msg)

        return io_list
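
For reference, the caching step in handle_forecasting_udf (hash the training frame, pickle the fitted StatsForecast model, and reuse the pickle when the same data is seen again) can be distilled into a small standalone sketch. The helper name cached_fit and its fit_fn callback below are illustrative only and are not part of EvaDB; the sketch assumes any zero-argument callable that returns a fitted, picklable model.

# Illustrative sketch, not EvaDB's API: cache a fitted model keyed by a hash of its data.
import hashlib
import pickle
from pathlib import Path
from typing import Any, Callable


def cached_fit(fit_fn: Callable[[], Any], data_repr: str, model_dir: str) -> str:
    """Return the path of a pickled model, fitting and saving it only on a cache miss."""
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Same keying idea as handle_forecasting_udf: SHA-256 of the data's string form.
    digest = hashlib.sha256(data_repr.encode()).hexdigest()
    model_path = Path(model_dir) / (digest + ".pkl")
    if not model_path.exists():
        model = fit_fn()  # e.g. a closure that builds and fits a StatsForecast model
        with open(model_path, "wb") as f:
            pickle.dump(model, f)
    return str(model_path)

Because the cache key is derived only from the training data, identical data always maps to the same weight file; in the executor itself, model_dir already includes the UDF name, so each UDF keeps its own cache directory.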