• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / #758

04 Sep 2023 08:37PM UTC coverage: 0.0% (-78.3%) from 78.333%
#758

push

circle-ci

hershd23
Increased underline length at line 75 in text_summarization.rst
	modified:   docs/source/benchmarks/text_summarization.rst

0 of 11303 relevant lines covered (0.0%)

0.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/evadb/binder/statement_binder.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from functools import singledispatchmethod
×
16
from pathlib import Path
×
17
from typing import Callable
×
18

19
from evadb.binder.binder_utils import (
×
20
    BinderError,
21
    bind_table_info,
22
    check_column_name_is_string,
23
    check_groupby_pattern,
24
    check_table_object_is_groupable,
25
    drop_row_id_from_target_list,
26
    extend_star,
27
    get_column_definition_from_select_target_list,
28
    handle_bind_extract_object_function,
29
    resolve_alias_table_value_expression,
30
)
31
from evadb.binder.statement_binder_context import StatementBinderContext
×
32
from evadb.catalog.catalog_type import NdArrayType, TableType, VideoColumnName
×
33
from evadb.catalog.catalog_utils import get_metadata_properties, is_document_table
×
34
from evadb.configuration.constants import EvaDB_INSTALLATION_DIR
×
35
from evadb.expression.abstract_expression import AbstractExpression, ExpressionType
×
36
from evadb.expression.function_expression import FunctionExpression
×
37
from evadb.expression.tuple_value_expression import TupleValueExpression
×
38
from evadb.parser.create_index_statement import CreateIndexStatement
×
39
from evadb.parser.create_statement import CreateTableStatement
×
40
from evadb.parser.create_udf_statement import CreateUDFStatement
×
41
from evadb.parser.delete_statement import DeleteTableStatement
×
42
from evadb.parser.explain_statement import ExplainStatement
×
43
from evadb.parser.rename_statement import RenameTableStatement
×
44
from evadb.parser.select_statement import SelectStatement
×
45
from evadb.parser.statement import AbstractStatement
×
46
from evadb.parser.table_ref import TableRef
×
47
from evadb.parser.types import UDFType
×
48
from evadb.third_party.huggingface.binder import assign_hf_udf
×
49
from evadb.utils.generic_utils import load_udf_class_from_file
×
50
from evadb.utils.logging_manager import logger
×
51

52

53
class StatementBinder:
×
54
    def __init__(self, binder_context: StatementBinderContext):
        """Create a binder that resolves names through ``binder_context``.

        Args:
            binder_context: context in which table aliases are registered and
                columns are looked up while binding. Its ``_catalog`` callable
                is reused by this binder.
        """
        self._binder_context = binder_context
        # The catalog accessor is stored as a callable and invoked as
        # ``self._catalog()`` wherever a catalog lookup is needed.
        self._catalog: Callable = binder_context._catalog
×
57

58
    @singledispatchmethod
×
59
    def bind(self, node):
        """Fallback used when no binder is registered for ``type(node)``.

        Raises:
            NotImplementedError: always; the message names the node's type.
        """
        node_type = type(node)
        raise NotImplementedError(f"Cannot bind {node_type}")
61

62
    @bind.register(AbstractStatement)
×
63
    def _bind_abstract_statement(self, node: AbstractStatement):
        """Accept statements that have no dedicated binder; nothing is bound."""
        return None
×
65

66
    @bind.register(AbstractExpression)
×
67
    def _bind_abstract_expr(self, node: AbstractExpression):
        """Bind a generic expression by binding each of its children in order."""
        for sub_expr in node.children:
            self.bind(sub_expr)
×
70

71
    @bind.register(ExplainStatement)
×
72
    def _bind_explain_statement(self, node: ExplainStatement):
        """Bind the statement wrapped by an EXPLAIN node."""
        explained_stmt = node.explainable_stmt
        self.bind(explained_stmt)
×
74

75
    @bind.register(CreateUDFStatement)
×
76
    def _bind_create_udf_statement(self, node: CreateUDFStatement):
        """Bind a CREATE UDF statement that is defined over a query.

        When ``node.query`` is set, the query is bound and the UDF's inputs
        and outputs are derived from the query's target list: the column named
        by the ``predict`` metadata entry becomes the output (its name gets a
        ``_predictions`` suffix) and every other column becomes an input.
        Statements without a query are left untouched.

        Raises:
            AssertionError: if the ``predict`` metadata entry is missing, or
                if the statement already declares inputs or outputs.
        """
        if node.query is None:
            return

        self.bind(node.query)
        # Drop the automatically generated _row_id column
        node.query.target_list = drop_row_id_from_target_list(node.query.target_list)
        column_defs = get_column_definition_from_select_target_list(
            node.query.target_list
        )

        metadata = {name: setting for name, setting in node.metadata}
        assert (
            "predict" in metadata
        ), f"Creating {node.udf_type} UDFs expects 'predict' metadata."

        # We only support a single predict column for now
        predict_columns = {metadata["predict"]}
        inputs, outputs = [], []
        for column in column_defs:
            if column.name not in predict_columns:
                inputs.append(column)
                continue
            column.name = column.name + "_predictions"
            outputs.append(column)

        assert (
            len(node.inputs) == 0 and len(node.outputs) == 0
        ), f"{node.udf_type} UDFs' input and output are auto assigned"
        node.inputs, node.outputs = inputs, outputs
×
103

104
    @bind.register(CreateIndexStatement)
×
105
    def _bind_create_index_statement(self, node: CreateIndexStatement):
        """Bind a CREATE INDEX statement and validate the indexed feature.

        The table reference (and the UDF expression, when present) is bound
        first. The index must target exactly one column of a table atom.
        Without a UDF, the named table column must exist, be ``FLOAT32`` and
        have two array dimensions. With a UDF, every output of the UDF's
        catalog entry must satisfy the same type and dimension constraints.

        Raises:
            AssertionError: if any of the constraints above is violated.
        """
        self.bind(node.table_ref)
        if node.udf_func:
            self.bind(node.udf_func)

        # TODO: create index currently only supports single numpy column.
        assert len(node.col_list) == 1, "Index cannot be created on more than 1 column"

        # TODO: create index currently only works on TableInfo, but will extend later.
        assert node.table_ref.is_table_atom(), "Index can only be created on Tableinfo"

        if not node.udf_func:
            # Feature table type needs to be float32 numpy array.
            # The single-column invariant is already enforced above, so
            # indexing the first element is safe here.
            col_def = node.col_list[0]

            table_ref_obj = node.table_ref.table.table_obj
            col_list = [
                col for col in table_ref_obj.columns if col.name == col_def.name
            ]
            assert (
                len(col_list) == 1
            ), f"Index is created on non-existent column {col_def.name}"

            col = col_list[0]
            assert (
                col.array_type == NdArrayType.FLOAT32
            ), "Index input needs to be float32."
            assert (
                len(col.array_dimensions) == 2
            ), "Index input needs to be 2 dimensional."
        else:
            # Output of the UDF should be 2 dimension and float32 type.
            udf_obj = self._catalog().get_udf_catalog_entry_by_name(node.udf_func.name)
            for output in udf_obj.outputs:
                assert (
                    output.array_type == NdArrayType.FLOAT32
                ), "Index input needs to be float32."
                assert (
                    len(output.array_dimensions) == 2
                ), "Index input needs to be 2 dimensional."
145

146
    @bind.register(SelectStatement)
×
147
    def _bind_select_statement(self, node: SelectStatement):
        """Bind every clause of a SELECT statement.

        Binding proceeds in this order: FROM table, WHERE clause, target list
        (expanding a lone ``*``), GROUP BY clause, ORDER BY expressions and
        finally any UNION-linked statement, which is bound in a fresh
        ``StatementBinderContext``. Afterwards the ``get_audio`` /
        ``get_video`` flags on ``node.from_table`` are set according to what
        the binder context reports.

        Raises:
            AssertionError: if chunk parameters are used on a table that is
                not a document table, or if both audio and video retrieval
                are requested.
        """
        self.bind(node.from_table)
        if node.where_clause:
            self.bind(node.where_clause)
            if node.where_clause.etype == ExpressionType.COMPARE_LIKE:
                check_column_name_is_string(node.where_clause.children[0])

        if node.target_list:
            # SELECT * support
            if (
                len(node.target_list) == 1
                and isinstance(node.target_list[0], TupleValueExpression)
                and node.target_list[0].name == "*"
            ):
                node.target_list = extend_star(self._binder_context)
            for expr in node.target_list:
                self.bind(expr)
        if node.groupby_clause:
            self.bind(node.groupby_clause)
            check_table_object_is_groupable(node.from_table)
            check_groupby_pattern(node.from_table, node.groupby_clause.value)
        if node.orderby_list:
            for expr in node.orderby_list:
                self.bind(expr[0])
        if node.union_link:
            current_context = self._binder_context
            self._binder_context = StatementBinderContext(self._catalog)
            try:
                self.bind(node.union_link)
            finally:
                # Always restore the outer context, even when binding the
                # linked statement fails, so this binder is not left pointing
                # at the temporary context.
                self._binder_context = current_context

        # chunk_params only supported for DOCUMENT TYPE
        if node.from_table.chunk_params:
            assert is_document_table(
                node.from_table.table.table_obj
            ), "CHUNK related parameters only supported for DOCUMENT tables."

        assert not (
            self._binder_context.is_retrieve_audio()
            and self._binder_context.is_retrieve_video()
        ), "Cannot query over both audio and video streams"

        if self._binder_context.is_retrieve_audio():
            node.from_table.get_audio = True
        if self._binder_context.is_retrieve_video():
            node.from_table.get_video = True
×
192

193
    @bind.register(DeleteTableStatement)
×
194
    def _bind_delete_statement(self, node: DeleteTableStatement):
        """Bind the target table of a DELETE, then its WHERE clause if present."""
        self.bind(node.table_ref)
        predicate = node.where_clause
        if predicate:
            self.bind(predicate)
×
198

199
    @bind.register(CreateTableStatement)
×
200
    def _bind_create_statement(self, node: CreateTableStatement):
        """Bind a CREATE TABLE statement that is populated from a query.

        When ``node.query`` is set, the query is bound and ``node.column_list``
        is replaced with the column definitions derived from the query's
        target list. Statements without a query are left untouched.
        """
        source_query = node.query
        if source_query is None:
            return

        self.bind(source_query)
        node.column_list = get_column_definition_from_select_target_list(
            source_query.target_list
        )
207

208
    @bind.register(RenameTableStatement)
×
209
    def _bind_rename_table_statement(self, node: RenameTableStatement):
        """Bind the table being renamed and reject structured-data tables.

        Raises:
            AssertionError: if the bound table's type is
                ``TableType.STRUCTURED_DATA``.
        """
        old_ref = node.old_table_ref
        self.bind(old_ref)
        bound_table_type = old_ref.table.table_obj.table_type
        assert (
            bound_table_type != TableType.STRUCTURED_DATA
        ), "Rename not yet supported on structured data"
214

215
    @bind.register(TableRef)
×
216
    def _bind_tableref(self, node: TableRef):
        """Bind a table reference according to its kind.

        * Table atom: registers the alias in the binder context and binds the
          table info against the catalog.
        * Sub-select: binds the nested statement in a fresh
          ``StatementBinderContext`` and then registers its target list as a
          derived table under the alias in the outer context.
        * Join: binds the left side, the right side and the optional
          predicate.
        * Table-valued expression: binds the function expression and
          registers one ``TupleValueExpression`` per (output object, alias
          column name) pair as a derived table.

        Raises:
            BinderError: if the node is none of the kinds above.
        """
        if node.is_table_atom():
            # Table
            self._binder_context.add_table_alias(
                node.alias.alias_name, node.table.database_name, node.table.table_name
            )
            bind_table_info(self._catalog(), node.table)
        elif node.is_select():
            current_context = self._binder_context
            self._binder_context = StatementBinderContext(self._catalog)
            try:
                self.bind(node.select_statement)
            finally:
                # Restore the outer context even if the nested statement fails
                # to bind, so this binder is not left in the temporary one.
                self._binder_context = current_context
            self._binder_context.add_derived_table_alias(
                node.alias.alias_name, node.select_statement.target_list
            )
        elif node.is_join():
            self.bind(node.join_node.left)
            self.bind(node.join_node.right)
            if node.join_node.predicate:
                self.bind(node.join_node.predicate)
        elif node.is_table_valued_expr():
            func_expr = node.table_valued_expr.func_expr
            func_expr.alias = node.alias
            self.bind(func_expr)
            output_cols = []
            for obj, alias in zip(func_expr.output_objs, func_expr.alias.col_names):
                col_alias = "{}.{}".format(func_expr.alias.alias_name, alias)
                alias_obj = TupleValueExpression(
                    name=alias,
                    table_alias=func_expr.alias.alias_name,
                    col_object=obj,
                    col_alias=col_alias,
                )
                output_cols.append(alias_obj)
            self._binder_context.add_derived_table_alias(
                func_expr.alias.alias_name, output_cols
            )
        else:
            raise BinderError(f"Unsupported node {type(node)}")
255

256
    @bind.register(TupleValueExpression)
×
257
    def _bind_tuple_expr(self, node: TupleValueExpression):
        """Resolve a column reference against the binder context.

        Looks up the column by name and table alias, stores the resolved
        alias, column object and ``<table_alias>.<lowercased name>`` column
        alias on the node, and enables audio or video retrieval on the
        context when the column name equals ``VideoColumnName.audio`` or
        ``VideoColumnName.data``.
        """
        resolved_alias, column_obj = self._binder_context.get_binded_column(
            node.name, node.table_alias
        )
        node.table_alias = resolved_alias

        # Referencing these columns switches on the matching stream retrieval.
        if node.name == VideoColumnName.audio:
            self._binder_context.enable_audio_retrieval()
        if node.name == VideoColumnName.data:
            self._binder_context.enable_video_retrieval()

        node.col_alias = f"{resolved_alias}.{node.name.lower()}"
        node.col_object = column_obj
×
268

269
    @bind.register(FunctionExpression)
×
270
    def _bind_func_expr(self, node: FunctionExpression):
        """Bind a function expression to its UDF catalog entry.

        Steps:
          1. ``extract_object`` calls are delegated to
             ``handle_bind_extract_object_function`` and return early.
          2. A lone ``*`` argument is expanded, then all children are bound.
          3. The UDF is looked up in the catalog by name and ``node.function``
             is set depending on the entry's type (``HuggingFace``, ``Ludwig``
             or a class loaded from the entry's implementation file).
          4. ``node.output_objs`` and ``node.projection_columns`` are set
             either to the single requested output or to all catalog outputs.
          5. ``resolve_alias_table_value_expression`` is applied to the node.

        Raises:
            BinderError: if the UDF is not in the catalog, its class cannot be
                loaded from the implementation file, or the requested output
                does not exist for the UDF.
            AssertionError: if a Ludwig UDF lacks ``model_path`` metadata.
        """
        # handle the special case of "extract_object"
        if node.name.upper() == str(UDFType.EXTRACT_OBJECT):
            handle_bind_extract_object_function(node, self)
            return

        # Handle Func(*)
        if (
            len(node.children) == 1
            and isinstance(node.children[0], TupleValueExpression)
            and node.children[0].name == "*"
        ):
            node.children = extend_star(self._binder_context)
        # bind all the children
        for child in node.children:
            self.bind(child)

        udf_obj = self._catalog().get_udf_catalog_entry_by_name(node.name)
        if udf_obj is None:
            err_msg = (
                f"Function '{node.name}' does not exist in the catalog. "
                "Please create the function using CREATE UDF command."
            )
            logger.error(err_msg)
            raise BinderError(err_msg)

        if udf_obj.type == "HuggingFace":
            node.function = assign_hf_udf(udf_obj)

        elif udf_obj.type == "Ludwig":
            udf_class = load_udf_class_from_file(
                udf_obj.impl_file_path,
                "GenericLudwigModel",
            )
            udf_metadata = get_metadata_properties(udf_obj)
            assert "model_path" in udf_metadata, "Ludwig models expect 'model_path'."
            node.function = lambda: udf_class(model_path=udf_metadata["model_path"])

        else:
            if udf_obj.type == "ultralytics":
                # manually set the impl_path for yolo udfs we only handle object
                # detection for now, hopefully this can be generalized
                udf_dir = Path(EvaDB_INSTALLATION_DIR) / "udfs"
                udf_obj.impl_file_path = (
                    Path(f"{udf_dir}/yolo_object_detector.py").absolute().as_posix()
                )

            # Verify the consistency of the UDF. If the checksum of the UDF does not
            # match the one stored in the catalog, an error will be thrown and the user
            # will be asked to register the UDF again.
            # assert (
            #     get_file_checksum(udf_obj.impl_file_path) == udf_obj.checksum
            # ), f"""UDF file {udf_obj.impl_file_path} has been modified from the
            #     registration. Please use DROP UDF to drop it and re-create it # using CREATE UDF."""

            try:
                udf_class = load_udf_class_from_file(
                    udf_obj.impl_file_path,
                    udf_obj.name,
                )
            except Exception as e:
                err_msg = (
                    f"{str(e)}. Please verify that the UDF class name in the "
                    "implementation file matches the UDF name."
                )
                logger.error(err_msg)
                # Chain explicitly so the original load failure stays attached.
                raise BinderError(err_msg) from e

            # certain udfs take additional inputs like yolo needs the model_name
            # these arguments are passed by the user as part of metadata
            node.function = lambda: udf_class(**get_metadata_properties(udf_obj))

        node.udf_obj = udf_obj
        output_objs = self._catalog().get_udf_io_catalog_output_entries(udf_obj)
        if node.output:
            # Decide from the outputs matched in this call rather than from
            # whatever ``node.output_objs`` already held, so a missing output
            # is always reported.
            matching_objs = [
                obj for obj in output_objs if obj.name.lower() == node.output
            ]
            if not matching_objs:
                err_msg = f"Output {node.output} does not exist for {udf_obj.name}."
                logger.error(err_msg)
                raise BinderError(err_msg)
            # When several outputs share the name, the last one wins.
            node.output_objs = [matching_objs[-1]]
            node.projection_columns = [node.output]
        else:
            node.output_objs = output_objs
            node.projection_columns = [obj.name.lower() for obj in output_objs]

        resolve_alias_table_value_expression(node)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc