• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / #754

04 Sep 2023 09:54PM UTC coverage: 74.807% (-5.5%) from 80.336%
#754

push

circle-ci

jiashenC
update case

8727 of 11666 relevant lines covered (74.81%)

0.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

61.87
/evadb/optimizer/rules/rules.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
from __future__ import annotations
1✔
16

17
from typing import TYPE_CHECKING
1✔
18

19
from evadb.catalog.catalog_type import TableType
1✔
20
from evadb.catalog.catalog_utils import is_video_table
1✔
21
from evadb.constants import CACHEABLE_FUNCTIONS
1✔
22
from evadb.executor.execution_context import Context
1✔
23
from evadb.expression.expression_utils import (
1✔
24
    conjunction_list_to_expression_tree,
25
    to_conjunction_list,
26
)
27
from evadb.expression.function_expression import FunctionExpression
1✔
28
from evadb.expression.tuple_value_expression import TupleValueExpression
1✔
29
from evadb.optimizer.optimizer_utils import (
1✔
30
    check_expr_validity_for_cache,
31
    enable_cache,
32
    enable_cache_on_expression_tree,
33
    extract_equi_join_keys,
34
    extract_pushdown_predicate,
35
    extract_pushdown_predicate_for_alias,
36
    get_expression_execution_cost,
37
)
38
from evadb.optimizer.rules.pattern import Pattern
1✔
39
from evadb.optimizer.rules.rules_base import Promise, Rule, RuleType
1✔
40
from evadb.parser.types import JoinType, ParserOrderBySortType
1✔
41
from evadb.plan_nodes.apply_and_merge_plan import ApplyAndMergePlan
1✔
42
from evadb.plan_nodes.create_from_select_plan import CreateFromSelectPlan
1✔
43
from evadb.plan_nodes.exchange_plan import ExchangePlan
1✔
44
from evadb.plan_nodes.explain_plan import ExplainPlan
1✔
45
from evadb.plan_nodes.hash_join_build_plan import HashJoinBuildPlan
1✔
46
from evadb.plan_nodes.nested_loop_join_plan import NestedLoopJoinPlan
1✔
47
from evadb.plan_nodes.predicate_plan import PredicatePlan
1✔
48
from evadb.plan_nodes.project_plan import ProjectPlan
1✔
49
from evadb.plan_nodes.show_info_plan import ShowInfoPlan
1✔
50

51
if TYPE_CHECKING:
52
    from evadb.optimizer.optimizer_context import OptimizerContext
53

54
from evadb.optimizer.operators import (
1✔
55
    Dummy,
56
    LogicalApplyAndMerge,
57
    LogicalCreate,
58
    LogicalCreateFunction,
59
    LogicalCreateIndex,
60
    LogicalDelete,
61
    LogicalDropObject,
62
    LogicalExchange,
63
    LogicalExplain,
64
    LogicalExtractObject,
65
    LogicalFilter,
66
    LogicalFunctionScan,
67
    LogicalGet,
68
    LogicalGroupBy,
69
    LogicalInsert,
70
    LogicalJoin,
71
    LogicalLimit,
72
    LogicalLoadData,
73
    LogicalOrderBy,
74
    LogicalProject,
75
    LogicalQueryDerivedGet,
76
    LogicalRename,
77
    LogicalSample,
78
    LogicalShow,
79
    LogicalUnion,
80
    LogicalVectorIndexScan,
81
    Operator,
82
    OperatorType,
83
)
84
from evadb.plan_nodes.create_function_plan import CreateFunctionPlan
1✔
85
from evadb.plan_nodes.create_index_plan import CreateIndexPlan
1✔
86
from evadb.plan_nodes.create_plan import CreatePlan
1✔
87
from evadb.plan_nodes.delete_plan import DeletePlan
1✔
88
from evadb.plan_nodes.drop_object_plan import DropObjectPlan
1✔
89
from evadb.plan_nodes.function_scan_plan import FunctionScanPlan
1✔
90
from evadb.plan_nodes.groupby_plan import GroupByPlan
1✔
91
from evadb.plan_nodes.hash_join_probe_plan import HashJoinProbePlan
1✔
92
from evadb.plan_nodes.insert_plan import InsertPlan
1✔
93
from evadb.plan_nodes.lateral_join_plan import LateralJoinPlan
1✔
94
from evadb.plan_nodes.limit_plan import LimitPlan
1✔
95
from evadb.plan_nodes.load_data_plan import LoadDataPlan
1✔
96
from evadb.plan_nodes.orderby_plan import OrderByPlan
1✔
97
from evadb.plan_nodes.rename_plan import RenamePlan
1✔
98
from evadb.plan_nodes.seq_scan_plan import SeqScanPlan
1✔
99
from evadb.plan_nodes.storage_plan import StoragePlan
1✔
100
from evadb.plan_nodes.union_plan import UnionPlan
1✔
101
from evadb.plan_nodes.vector_index_scan_plan import VectorIndexScanPlan
1✔
102

103
##############################################
104
# REWRITE RULES START
105

106

107
class EmbedFilterIntoGet(Rule):
    """Rewrite rule that folds a filter predicate into the LogicalGet below it.

    Matches ``LogicalFilter -> LogicalGet``. Only the part of the predicate
    that ``extract_pushdown_predicate`` accepts for the ``<video alias>.id``
    column is embedded; any remainder stays in a LogicalFilter placed on top
    of the new LogicalGet.
    """

    def __init__(self):
        get_pattern = Pattern(OperatorType.LOGICALGET)
        filter_pattern = Pattern(OperatorType.LOGICALFILTER)
        filter_pattern.append_child(get_pattern)
        super().__init__(RuleType.EMBED_FILTER_INTO_GET, filter_pattern)

    def promise(self):
        return Promise.EMBED_FILTER_INTO_GET

    def check(self, before: LogicalFilter, context: OptimizerContext):
        """Return True only if some part of the predicate can be pushed down."""
        predicate = before.predicate
        lget: LogicalGet = before.children[0]
        # Predicate pushdown is supported only while reading video data.
        if not predicate or not is_video_table(lget.table_obj):
            return False
        # Only basic range predicates on the id column are pushable.
        id_column = f"{lget.video.alias}.id"
        pushable, _ = extract_pushdown_predicate(predicate, id_column)
        return bool(pushable)

    def apply(self, before: LogicalFilter, context: OptimizerContext):
        """Yield the rewritten subtree, or ``before`` when nothing is pushable."""
        lget = before.children[0]
        # Only basic range predicates on the id column are pushable.
        id_column = f"{lget.video.alias}.id"
        pushable, leftover = extract_pushdown_predicate(before.predicate, id_column)

        if not pushable:
            yield before
            return

        rewritten = LogicalGet(
            lget.video,
            lget.table_obj,
            alias=lget.alias,
            predicate=pushable,
            target_list=lget.target_list,
            sampling_rate=lget.sampling_rate,
            sampling_type=lget.sampling_type,
            children=lget.children,
        )
        if leftover:
            # Keep the non-pushable part as a filter above the new get.
            residual_filter = LogicalFilter(leftover)
            residual_filter.append_child(rewritten)
            rewritten = residual_filter
        yield rewritten
156

157

158
class EmbedSampleIntoGet(Rule):
    """Rewrite rule that folds a LogicalSample into the LogicalGet below it.

    Matches ``LogicalSample -> LogicalGet`` and replaces both with a single
    LogicalGet carrying the sampling rate and sampling type.
    """

    def __init__(self):
        get_pattern = Pattern(OperatorType.LOGICALGET)
        sample_pattern = Pattern(OperatorType.LOGICALSAMPLE)
        sample_pattern.append_child(get_pattern)
        super().__init__(RuleType.EMBED_SAMPLE_INTO_GET, sample_pattern)

    def promise(self):
        return Promise.EMBED_SAMPLE_INTO_GET

    def check(self, before: LogicalSample, context: OptimizerContext):
        """Sample pushdown is supported only while reading video data."""
        lget: LogicalGet = before.children[0]
        return lget.table_obj.table_type == TableType.VIDEO_DATA

    def apply(self, before: LogicalSample, context: OptimizerContext):
        """Yield a LogicalGet that carries the sample's rate and type."""
        rate = before.sample_freq.value
        if before.sample_type:
            kind = before.sample_type.value.value
        else:
            kind = None
        lget: LogicalGet = before.children[0]
        yield LogicalGet(
            lget.video,
            lget.table_obj,
            alias=lget.alias,
            predicate=lget.predicate,
            target_list=lget.target_list,
            sampling_rate=rate,
            sampling_type=kind,
            children=lget.children,
        )
189

190

191
class CacheFunctionExpressionInProject(Rule):
    """Rewrite rule that enables caching for function expressions in a projection.

    Matches ``LogicalProject -> Dummy``.
    """

    def __init__(self):
        project_pattern = Pattern(OperatorType.LOGICALPROJECT)
        project_pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(
            RuleType.CACHE_FUNCTION_EXPRESISON_IN_PROJECT, project_pattern
        )

    def promise(self):
        return Promise.CACHE_FUNCTION_EXPRESISON_IN_PROJECT

    def check(self, before: LogicalProject, context: OptimizerContext):
        """Return True if any target function expression is valid for caching."""
        cacheable = []
        for target in before.target_list:
            # Only targets that are themselves function expressions are searched.
            if not isinstance(target, FunctionExpression):
                continue
            for func_expr in list(target.find_all(FunctionExpression)):
                if check_expr_validity_for_cache(func_expr):
                    cacheable.append(func_expr)
        return bool(cacheable)

    def apply(self, before: LogicalProject, context: OptimizerContext):
        """Yield a LogicalProject whose copied targets have caching enabled."""
        # Copy every target first so the original expressions stay untouched.
        cached_targets = []
        for target in before.target_list:
            cached_targets.append(target.copy())
        for target in cached_targets:
            enable_cache_on_expression_tree(context, target)
        yield LogicalProject(target_list=cached_targets, children=before.children)
219

220

221
class CacheFunctionExpressionInFilter(Rule):
    """Rewrite rule that enables caching for function expressions in a filter.

    Matches ``LogicalFilter -> Dummy``.
    """

    def __init__(self):
        filter_pattern = Pattern(OperatorType.LOGICALFILTER)
        filter_pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.CACHE_FUNCTION_EXPRESISON_IN_FILTER, filter_pattern)

    def promise(self):
        return Promise.CACHE_FUNCTION_EXPRESISON_IN_FILTER

    def check(self, before: LogicalFilter, context: OptimizerContext):
        """Return True if the predicate holds a function expression valid for caching."""
        candidates = list(before.predicate.find_all(FunctionExpression))
        cacheable = [
            func_expr
            for func_expr in candidates
            if check_expr_validity_for_cache(func_expr)
        ]
        return bool(cacheable)

    def apply(self, before: LogicalFilter, context: OptimizerContext):
        """Yield a LogicalFilter whose copied predicate has caching enabled."""
        # With n function expressions there are 2^n enable/disable combinations.
        # Only one is considered for now: caching enabled for every eligible
        # function expression.
        cached_predicate = before.predicate.copy()
        enable_cache_on_expression_tree(context, cached_predicate)
        yield LogicalFilter(predicate=cached_predicate, children=before.children)
251

252

253
class CacheFunctionExpressionInApply(Rule):
    """Rewrite rule that enables caching for the function of a LogicalApplyAndMerge.

    Matches ``LogicalApplyAndMerge -> Dummy``.
    """

    def __init__(self):
        pattern = Pattern(OperatorType.LOGICAL_APPLY_AND_MERGE)
        pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.CACHE_FUNCTION_EXPRESISON_IN_APPLY, pattern)

    def promise(self):
        return Promise.CACHE_FUNCTION_EXPRESISON_IN_APPLY

    def check(self, before: LogicalApplyAndMerge, context: OptimizerContext):
        """Return True if the operator's function expression can be cached.

        The expression must not already have a cache, its name must be listed
        in ``CACHEABLE_FUNCTIONS``, and it must take exactly one argument that
        is a TupleValueExpression.
        """
        expr = before.func_expr
        # Skip if already cache enabled.
        # Replace the cacheable condition once the property is supported as
        # part of the function itself.
        if expr.has_cache() or expr.name not in CACHEABLE_FUNCTIONS:
            return False
        # Caching is not supported for function expressions with multiple
        # arguments or nested function expressions. Requiring exactly one
        # child also covers a function with no arguments, which would
        # otherwise fail on the children[0] lookup below.
        if len(expr.children) != 1:
            return False
        return isinstance(expr.children[0], TupleValueExpression)

    def apply(self, before: LogicalApplyAndMerge, context: OptimizerContext):
        """Yield a new LogicalApplyAndMerge whose function expression is cache enabled."""
        # todo: this will create a catalog entry even in the case of explain command
        # We should run this code conditionally
        new_func_expr = enable_cache(context, before.func_expr)
        after = LogicalApplyAndMerge(
            func_expr=new_func_expr, alias=before.alias, do_unnest=before.do_unnest
        )
        after.append_child(before.children[0])
        yield after
284

285

286
# Join Queries
287
class PushDownFilterThroughJoin(Rule):
    """Rewrite rule that pushes a filter's predicate below a join.

    Matches ``LogicalFilter -> LogicalJoin(Dummy, Dummy)``. The predicate is
    split by the aliases of the left and right groups; each side's part
    becomes a LogicalFilter over that child, and whatever remains is merged
    into the join predicate.
    """

    def __init__(self):
        join_pattern = Pattern(OperatorType.LOGICALJOIN)
        join_pattern.append_child(Pattern(OperatorType.DUMMY))
        join_pattern.append_child(Pattern(OperatorType.DUMMY))
        filter_pattern = Pattern(OperatorType.LOGICALFILTER)
        filter_pattern.append_child(join_pattern)
        super().__init__(RuleType.PUSHDOWN_FILTER_THROUGH_JOIN, filter_pattern)

    def promise(self):
        return Promise.PUSHDOWN_FILTER_THROUGH_JOIN

    def check(self, before: Operator, context: OptimizerContext):
        return True

    @staticmethod
    def _wrap_in_filter(child, predicate):
        """Return ``child`` under a LogicalFilter, or ``child`` itself if no predicate."""
        if not predicate:
            return child
        filter_opr = LogicalFilter(predicate=predicate)
        filter_opr.append_child(child)
        return filter_opr

    def apply(self, before: LogicalFilter, context: OptimizerContext):
        """Yield a new join with the pushable predicate parts moved to its children."""
        join: LogicalJoin = before.children[0]
        left: Dummy = join.children[0]
        right: Dummy = join.children[1]

        pushed_join = LogicalJoin(
            join.join_type,
            join.join_predicate,
            join.left_keys,
            join.right_keys,
        )
        memo = context.memo
        left_aliases = memo.get_group_by_id(left.group_id).aliases
        right_aliases = memo.get_group_by_id(right.group_id).aliases

        # Peel off the left-only part first, then the right-only part from
        # what is left over.
        left_pred, remaining = extract_pushdown_predicate_for_alias(
            before.predicate, left_aliases
        )
        right_pred, remaining = extract_pushdown_predicate_for_alias(
            remaining, right_aliases
        )

        pushed_join.append_child(self._wrap_in_filter(left, left_pred))
        pushed_join.append_child(self._wrap_in_filter(right, right_pred))

        if remaining:
            # The part that cannot go to either side joins the join predicate.
            pushed_join._join_predicate = conjunction_list_to_expression_tree(
                [remaining, pushed_join.join_predicate]
            )

        yield pushed_join
344

345

346
class XformLateralJoinToLinearFlow(Rule):
    """Turn a lateral join over a function scan into a linear ApplyAndMerge flow.

    When the inner node of a lateral join is a function-valued expression, the
    join node is dropped and the inner node becomes the parent of the outer
    node. This scenario is common in the system, and removing the join
    simplifies other optimizations such as function reuse and parallelized
    plans.
    """

    def __init__(self):
        join_pattern = Pattern(OperatorType.LOGICALJOIN)
        join_pattern.append_child(Pattern(OperatorType.DUMMY))
        join_pattern.append_child(Pattern(OperatorType.LOGICALFUNCTIONSCAN))
        super().__init__(RuleType.XFORM_LATERAL_JOIN_TO_LINEAR_FLOW, join_pattern)

    def promise(self):
        return Promise.XFORM_LATERAL_JOIN_TO_LINEAR_FLOW

    def check(self, before: LogicalJoin, context: OptimizerContext):
        """Accept only lateral joins with no join predicate and no join projection."""
        if before.join_type != JoinType.LATERAL_JOIN:
            return False
        return before.join_predicate is None and not before.join_project

    def apply(self, before: LogicalJoin, context: OptimizerContext):
        #     LogicalJoin(Lateral)              LogicalApplyAndMerge
        #     /           \                 ->       |
        #    A        LogicalFunctionScan            A

        outer: Dummy = before.children[0]
        func_scan: LogicalFunctionScan = before.children[1]
        linear_root = LogicalApplyAndMerge(
            func_scan.func_expr,
            func_scan.alias,
            func_scan.do_unnest,
        )
        linear_root.append_child(outer)
        yield linear_root
383

384

385
class PushDownFilterThroughApplyAndMerge(Rule):
    """If it is feasible to partially or fully push the predicate contained within the
    logical filter through the ApplyAndMerge operator, we should do so. This is often
    beneficial, for instance, in order to prevent decoding additional frames beyond
    those that satisfy the predicate.
    Eg:

    Filter(id < 10 and func.label = 'car')           Filter(func.label = 'car')
            |                                                   |
        ApplyAndMerge(func)                  ->          ApplyAndMerge(func)
            |                                                   |
            A                                            Filter(id < 10)
                                                                |
                                                                A

    """

    def __init__(self):
        apply_merge_pattern = Pattern(OperatorType.LOGICAL_APPLY_AND_MERGE)
        apply_merge_pattern.append_child(Pattern(OperatorType.DUMMY))
        pattern = Pattern(OperatorType.LOGICALFILTER)
        pattern.append_child(apply_merge_pattern)
        super().__init__(RuleType.PUSHDOWN_FILTER_THROUGH_APPLY_AND_MERGE, pattern)

    def promise(self):
        return Promise.PUSHDOWN_FILTER_THROUGH_APPLY_AND_MERGE

    def check(self, before: LogicalFilter, context: OptimizerContext):
        return True

    def apply(self, before: LogicalFilter, context: OptimizerContext):
        """Yield the rewritten subtree, or nothing when no predicate can be pushed.

        The part of the predicate that refers only to the aliases of the group
        below the ApplyAndMerge becomes a new LogicalFilter between the
        ApplyAndMerge and that group. Any remaining part becomes a
        LogicalFilter at the root of the yielded subtree.
        """
        apply_and_merge: LogicalApplyAndMerge = before.children[0]
        A: Dummy = apply_and_merge.children[0]
        aliases = context.memo.get_group_by_id(A.group_id).aliases
        pushdown_pred, rem_pred = extract_pushdown_predicate_for_alias(
            before.predicate, aliases
        )

        # we do not return a new plan if nothing can be pushed
        # this ensures we do not keep applying this optimization
        if not pushdown_pred:
            return

        # add a new filter node between ApplyAndMerge and Dummy
        pushdown_filter = LogicalFilter(predicate=pushdown_pred)
        pushdown_filter.append_child(A)

        # Build a fresh ApplyAndMerge rather than overwriting the children of
        # the matched operator, so the input tree is left unmodified.
        new_apply_and_merge = LogicalApplyAndMerge(
            func_expr=apply_and_merge.func_expr,
            alias=apply_and_merge.alias,
            do_unnest=apply_and_merge.do_unnest,
        )
        new_apply_and_merge.append_child(pushdown_filter)

        # If we have partial predicate make it the root
        root_node = new_apply_and_merge
        if rem_pred:
            root_node = LogicalFilter(predicate=rem_pred)
            root_node.append_child(new_apply_and_merge)

        yield root_node
443

444

445
class XformExtractObjectToLinearFlow(Rule):
    """Turn a lateral join over an Extract_Object node into a linear flow.

    The join node is dropped and replaced by two stacked ApplyAndMerge
    operators: the detector directly above the outer node, and the tracker
    above the detector.
    TODO: We need to add a sorting operation after detector to ensure we always provide tracker data in order.
    """

    #                                          LogicalApplyAndMerge(tracker)
    #     LogicalJoin(Lateral)                         |
    #     /           \                 ->    LogicalApplyAndMerge(detector)
    #    A        LogicalExtractObject                 |
    #                                                  A

    def __init__(self):
        join_pattern = Pattern(OperatorType.LOGICALJOIN)
        join_pattern.append_child(Pattern(OperatorType.DUMMY))
        join_pattern.append_child(Pattern(OperatorType.LOGICAL_EXTRACT_OBJECT))
        super().__init__(RuleType.XFORM_EXTRACT_OBJECT_TO_LINEAR_FLOW, join_pattern)

    def promise(self):
        return Promise.XFORM_EXTRACT_OBJECT_TO_LINEAR_FLOW

    def check(self, before: LogicalJoin, context: OptimizerContext):
        """Accept only lateral joins."""
        return before.join_type == JoinType.LATERAL_JOIN

    def apply(self, before: LogicalJoin, context: OptimizerContext):
        outer: Dummy = before.children[0]
        extract_obj: LogicalExtractObject = before.children[1]

        detector_opr = LogicalApplyAndMerge(
            extract_obj.detector, alias=extract_obj.detector.alias
        )
        tracker_opr = LogicalApplyAndMerge(
            extract_obj.tracker,
            alias=extract_obj.alias,
            do_unnest=extract_obj.do_unnest,
        )
        # Chain: tracker -> detector -> outer node.
        detector_opr.append_child(outer)
        tracker_opr.append_child(detector_opr)
        yield tracker_opr
487

488

489
class CombineSimilarityOrderByAndLimitToVectorIndexScan(Rule):
    """
    Rewrites Order By + Limit into a vector index scan.

    A vector index only serves similarity search, so the rule fires only when
    the Order By is on a Similarity expression. For simplicity it is also
    restricted to the case where the Similarity expression applies to the
    full table, because a predicated query would yield incorrect results
    through an index scan.

    Limit(10)
        |
    OrderBy(func)        ->        IndexScan(10)
        |                               |
        A                               A
    """

    def __init__(self):
        orderby_pattern = Pattern(OperatorType.LOGICALORDERBY)
        orderby_pattern.append_child(Pattern(OperatorType.DUMMY))
        limit_pattern = Pattern(OperatorType.LOGICALLIMIT)
        limit_pattern.append_child(orderby_pattern)
        super().__init__(
            RuleType.COMBINE_SIMILARITY_ORDERBY_AND_LIMIT_TO_VECTOR_INDEX_SCAN,
            limit_pattern,
        )

        # Entries populate after rule eligibility validation.
        self._index_catalog_entry = None
        self._query_func_expr = None

    def promise(self):
        return Promise.COMBINE_SIMILARITY_ORDERBY_AND_LIMIT_TO_VECTOR_INDEX_SCAN

    def check(self, before: LogicalLimit, context: OptimizerContext):
        return True

    def apply(self, before: LogicalLimit, context: OptimizerContext):
        """Yield a LogicalVectorIndexScan, or nothing when the rewrite does not apply."""
        catalog_manager = context.db.catalog

        # Get corresponding nodes.
        limit_node = before
        orderby_node = before.children[0]
        sub_tree_root = orderby_node.children[0]

        # Bail out if a predicate exists on the table: anything other than a
        # LogicalGet is treated as predicated, and so is a LogicalGet that
        # carries a predicate.
        root_opr = sub_tree_root.opr
        if not isinstance(root_opr, LogicalGet) or root_opr.predicate is not None:
            return

        # The order by must run on a Similarity expression; the last ascending
        # function expression in the order by list is the candidate.
        similarity_expr = None
        for sort_column, sort_order in orderby_node.orderby_list:
            is_asc_function = (
                isinstance(sort_column, FunctionExpression)
                and sort_order == ParserOrderBySortType.ASC
            )
            if is_asc_function:
                similarity_expr = sort_column
        if not similarity_expr or similarity_expr.name != "Similarity":
            return

        query_func_expr, base_func_expr = similarity_expr.children

        # Walk down first children to find the table column of the order by.
        tv_expr = base_func_expr
        while not isinstance(tv_expr, TupleValueExpression):
            tv_expr = tv_expr.children[0]

        # Column catalog entry, plus the function signature when the base
        # expression is a function applied on the column.
        column_catalog_entry = tv_expr.col_object
        if isinstance(base_func_expr, TupleValueExpression):
            function_signature = None
        else:
            function_signature = base_func_expr.signature()

        # An index must exist for the matching function signature and column.
        catalog = catalog_manager()
        index_catalog_entry = (
            catalog.get_index_catalog_entry_by_column_and_function_signature(
                column_catalog_entry, function_signature
            )
        )
        if not index_catalog_entry:
            return

        # Construct the vector index scan over the order by's children.
        index_scan = LogicalVectorIndexScan(
            index_catalog_entry.name,
            index_catalog_entry.type,
            limit_node.limit_count,
            query_func_expr,
        )
        for child in orderby_node.children:
            index_scan.append_child(child)
        yield index_scan
590

591

592
# REWRITE RULES END
593
##############################################
594

595
##############################################
596
# LOGICAL RULES START
597

598

599
class LogicalInnerJoinCommutativity(Rule):
    """Logical rule that swaps the two inputs of an inner join.

    Matches ``LogicalJoin(Dummy, Dummy)``.
    """

    def __init__(self):
        join_pattern = Pattern(OperatorType.LOGICALJOIN)
        for _ in range(2):
            join_pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_INNER_JOIN_COMMUTATIVITY, join_pattern)

    def promise(self):
        return Promise.LOGICAL_INNER_JOIN_COMMUTATIVITY

    def check(self, before: LogicalJoin, context: OptimizerContext):
        # has to be an inner join
        return before.join_type == JoinType.INNER_JOIN

    def apply(self, before: LogicalJoin, context: OptimizerContext):
        #     LogicalJoin(Inner)            LogicalJoin(Inner)
        #     /           \        ->       /               \
        #    A             B               B                A

        swapped = LogicalJoin(before.join_type, before.join_predicate)
        # The right side goes first, then the left side.
        swapped.append_child(before.rhs())
        swapped.append_child(before.lhs())
        yield swapped
622

623

624
class ReorderPredicates(Rule):
    """
    Orders the conjuncts of a filter predicate by their individual cost.

    OR clauses are not optimized yet. Predicates without user-defined
    functions are not reordered either, on the assumption that they will
    likely be pushed to the underlying relational database, which will handle
    the optimization process.
    """

    def __init__(self):
        filter_pattern = Pattern(OperatorType.LOGICALFILTER)
        filter_pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.REORDER_PREDICATES, filter_pattern)

    def promise(self):
        return Promise.REORDER_PREDICATES

    def check(self, before: LogicalFilter, context: OptimizerContext):
        """Require at least one function expression in the predicate."""
        return bool(list(before.predicate.find_all(FunctionExpression)))

    def apply(self, before: LogicalFilter, context: OptimizerContext):
        """Yield a reordered LogicalFilter, or nothing if the order is unchanged."""
        # Decompose the expression tree into a list of conjuncts.
        conjuncts = to_conjunction_list(before.predicate)

        # Split the conjuncts into those with and without function expressions.
        simple_exprs = []
        func_conjuncts = []
        for conjunct in conjuncts:
            has_function = list(conjunct.find_all(FunctionExpression))
            bucket = func_conjuncts if has_function else simple_exprs
            bucket.append(conjunct)

        # Sort the function conjuncts in ascending order of execution cost.
        by_cost = sorted(
            func_conjuncts,
            key=lambda expr: get_expression_execution_cost(context, expr),
        )

        # Simple conjuncts come first, then function conjuncts by cost.
        ordered_conjuncts = [*simple_exprs, *by_cost]

        # Emit nothing when the order is unchanged, so this optimization is
        # not applied over and over.
        if ordered_conjuncts != conjuncts:
            reordered_filter = LogicalFilter(
                predicate=conjunction_list_to_expression_tree(ordered_conjuncts)
            )
            reordered_filter.append_child(before.children[0])
            yield reordered_filter
681

682

683
# LOGICAL RULES END
684
##############################################
685

686

687
##############################################
688
# IMPLEMENTATION RULES START
689

690

691
class LogicalCreateToPhysical(Rule):
    """Implementation rule mapping a childless LogicalCreate to a CreatePlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_CREATE_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALCREATE),
        )

    def promise(self):
        return Promise.LOGICAL_CREATE_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        return True

    def apply(self, before: LogicalCreate, context: OptimizerContext):
        yield CreatePlan(before.video, before.column_list, before.if_not_exists)
705

706

707
class LogicalCreateFromSelectToPhysical(Rule):
    """Implementation rule mapping ``LogicalCreate -> Dummy`` to a CreateFromSelectPlan."""

    def __init__(self):
        create_pattern = Pattern(OperatorType.LOGICALCREATE)
        create_pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(
            RuleType.LOGICAL_CREATE_FROM_SELECT_TO_PHYSICAL, create_pattern
        )

    def promise(self):
        return Promise.LOGICAL_CREATE_FROM_SELECT_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        return True

    def apply(self, before: LogicalCreate, context: OptimizerContext):
        physical_plan = CreateFromSelectPlan(
            before.video, before.column_list, before.if_not_exists
        )
        # Carry over every child of the logical operator, in order.
        for select_child in before.children:
            physical_plan.append_child(select_child)
        yield physical_plan
726

727

728
class LogicalRenameToPhysical(Rule):
    """Implementation rule turning a LogicalRename operator into a RenamePlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_RENAME_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALRENAME),
        )

    def promise(self):
        return Promise.LOGICAL_RENAME_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalRename, context: OptimizerContext):
        """Yield a RenamePlan carrying the old table reference and new name."""
        yield RenamePlan(before.old_table_ref, before.new_name)
742

743

744
class LogicalCreateFunctionToPhysical(Rule):
    """Implementation rule: LogicalCreateFunction -> CreateFunctionPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_CREATE_FUNCTION_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALCREATEFUNCTION),
        )

    def promise(self):
        return Promise.LOGICAL_CREATE_FUNCTION_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalCreateFunction, context: OptimizerContext):
        """Yield a CreateFunctionPlan copied field-by-field from ``before``."""
        plan = CreateFunctionPlan(
            before.name,
            before.if_not_exists,
            before.inputs,
            before.outputs,
            before.impl_path,
            before.function_type,
            before.metadata,
        )
        yield plan
766

767

768
class LogicalCreateFunctionFromSelectToPhysical(Rule):
    """Implementation rule: LogicalCreateFunction with a child -> CreateFunctionPlan.

    Unlike the childless variant, the resulting plan keeps the children of
    the logical operator.
    """

    def __init__(self):
        # LOGICALCREATEFUNCTION with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALCREATEFUNCTION)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(
            RuleType.LOGICAL_CREATE_FUNCTION_FROM_SELECT_TO_PHYSICAL, root
        )

    def promise(self):
        return Promise.LOGICAL_CREATE_FUNCTION_FROM_SELECT_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalCreateFunction, context: OptimizerContext):
        """Yield a CreateFunctionPlan that adopts all children of ``before``."""
        plan = CreateFunctionPlan(
            before.name,
            before.if_not_exists,
            before.inputs,
            before.outputs,
            before.impl_path,
            before.function_type,
            before.metadata,
        )
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
795

796

797
class LogicalCreateIndexToVectorIndex(Rule):
    """Implementation rule: LogicalCreateIndex -> CreateIndexPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_CREATE_INDEX_TO_VECTOR_INDEX,
            Pattern(OperatorType.LOGICALCREATEINDEX),
        )

    def promise(self):
        return Promise.LOGICAL_CREATE_INDEX_TO_VECTOR_INDEX

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalCreateIndex, context: OptimizerContext):
        """Yield a CreateIndexPlan copied field-by-field from ``before``."""
        plan = CreateIndexPlan(
            before.name,
            before.table_ref,
            before.col_list,
            before.vector_store_type,
            before.function,
        )
        yield plan
817

818

819
class LogicalDropObjectToPhysical(Rule):
    """Implementation rule: LogicalDropObject -> DropObjectPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_DROP_OBJECT_TO_PHYSICAL,
            Pattern(OperatorType.LOGICAL_DROP_OBJECT),
        )

    def promise(self):
        return Promise.LOGICAL_DROP_OBJECT_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalDropObject, context: OptimizerContext):
        """Yield a DropObjectPlan for the object type, name and if_exists flag."""
        yield DropObjectPlan(before.object_type, before.name, before.if_exists)
833

834

835
class LogicalInsertToPhysical(Rule):
    """Implementation rule turning a LogicalInsert operator into an InsertPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_INSERT_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALINSERT),
        )

    def promise(self):
        return Promise.LOGICAL_INSERT_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalInsert, context: OptimizerContext):
        """Yield an InsertPlan for the target table, columns and values."""
        yield InsertPlan(before.table, before.column_list, before.value_list)
849

850

851
class LogicalDeleteToPhysical(Rule):
    """Implementation rule turning a LogicalDelete operator into a DeletePlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_DELETE_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALDELETE),
        )

    def promise(self):
        return Promise.LOGICAL_DELETE_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalDelete, context: OptimizerContext):
        """Yield a DeletePlan for the table reference and where clause."""
        yield DeletePlan(before.table_ref, before.where_clause)
865

866

867
class LogicalLoadToPhysical(Rule):
    """Implementation rule turning a LogicalLoadData operator into a LoadDataPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_LOAD_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALLOADDATA),
        )

    def promise(self):
        return Promise.LOGICAL_LOAD_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalLoadData, context: OptimizerContext):
        """Yield a LoadDataPlan copied field-by-field from ``before``."""
        plan = LoadDataPlan(
            before.table_info,
            before.path,
            before.column_list,
            before.file_options,
        )
        yield plan
886

887

888
class LogicalGetToSeqScan(Rule):
    """Implementation rule: LogicalGet -> SeqScanPlan over a StoragePlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_GET_TO_SEQSCAN,
            Pattern(OperatorType.LOGICALGET),
        )

    def promise(self):
        return Promise.LOGICAL_GET_TO_SEQSCAN

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalGet, context: OptimizerContext):
        """Yield a SeqScanPlan whose only child is a StoragePlan.

        The scan itself carries no predicate; the predicate, sampling and
        chunking settings of ``before`` are handed to the StoragePlan.
        """
        scan = SeqScanPlan(None, before.target_list, before.alias)
        # batch_mem_size decides the number of rows read in a batch from the
        # storage engine.
        # Todo: Experiment heuristics.
        batch_mem_size = context.db.config.get_value("executor", "batch_mem_size")
        storage = StoragePlan(
            before.table_obj,
            before.video,
            predicate=before.predicate,
            sampling_rate=before.sampling_rate,
            sampling_type=before.sampling_type,
            chunk_params=before.chunk_params,
            batch_mem_size=batch_mem_size,
        )
        scan.append_child(storage)
        yield scan
917

918

919
class LogicalDerivedGetToPhysical(Rule):
    """Implementation rule: LogicalQueryDerivedGet -> SeqScanPlan over its child."""

    def __init__(self):
        # LOGICALQUERYDERIVEDGET with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALQUERYDERIVEDGET)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_DERIVED_GET_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_DERIVED_GET_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalQueryDerivedGet, context: OptimizerContext):
        """Yield a SeqScanPlan with the predicate, reading from the first child."""
        scan = SeqScanPlan(before.predicate, before.target_list, before.alias)
        # Only the first child is attached, matching the one-child pattern.
        scan.append_child(before.children[0])
        yield scan
935

936

937
class LogicalUnionToPhysical(Rule):
    """Implementation rule turning a LogicalUnion operator into a UnionPlan."""

    def __init__(self):
        # LOGICALUNION with two DUMMY children (the two inputs).
        root = Pattern(OperatorType.LOGICALUNION)
        root.append_child(Pattern(OperatorType.DUMMY))
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_UNION_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_UNION_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalUnion, context: OptimizerContext):
        """Yield a UnionPlan that adopts all children of ``before``."""
        plan = UnionPlan(before.all)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
956

957

958
class LogicalGroupByToPhysical(Rule):
    """Implementation rule turning a LogicalGroupBy operator into a GroupByPlan."""

    def __init__(self):
        # LOGICALGROUPBY with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALGROUPBY)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_GROUPBY_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_GROUPBY_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalGroupBy, context: OptimizerContext):
        """Yield a GroupByPlan that adopts all children of ``before``."""
        plan = GroupByPlan(before.groupby_clause)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
975

976

977
class LogicalOrderByToPhysical(Rule):
    """Implementation rule turning a LogicalOrderBy operator into an OrderByPlan."""

    def __init__(self):
        # LOGICALORDERBY with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALORDERBY)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_ORDERBY_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_ORDERBY_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalOrderBy, context: OptimizerContext):
        """Yield an OrderByPlan that adopts all children of ``before``."""
        plan = OrderByPlan(before.orderby_list)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
994

995

996
class LogicalLimitToPhysical(Rule):
    """Implementation rule turning a LogicalLimit operator into a LimitPlan."""

    def __init__(self):
        # LOGICALLIMIT with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALLIMIT)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_LIMIT_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_LIMIT_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalLimit, context: OptimizerContext):
        """Yield a LimitPlan that adopts all children of ``before``."""
        plan = LimitPlan(before.limit_count)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1013

1014

1015
class LogicalFunctionScanToPhysical(Rule):
    """Implementation rule: LogicalFunctionScan -> FunctionScanPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_FUNCTION_SCAN_TO_PHYSICAL,
            Pattern(OperatorType.LOGICALFUNCTIONSCAN),
        )

    def promise(self):
        return Promise.LOGICAL_FUNCTION_SCAN_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalFunctionScan, context: OptimizerContext):
        """Yield a FunctionScanPlan for the function expression and unnest flag."""
        yield FunctionScanPlan(before.func_expr, before.do_unnest)
1029

1030

1031
class LogicalLateralJoinToPhysical(Rule):
    """Implementation rule: lateral LogicalJoin -> LateralJoinPlan."""

    def __init__(self):
        # LOGICALJOIN with two DUMMY children (left and right inputs).
        root = Pattern(OperatorType.LOGICALJOIN)
        root.append_child(Pattern(OperatorType.DUMMY))
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_LATERAL_JOIN_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_LATERAL_JOIN_TO_PHYSICAL

    def check(self, before: Operator, context: OptimizerContext):
        # Only lateral joins are handled by this rule.
        return before.join_type == JoinType.LATERAL_JOIN

    def apply(self, join_node: LogicalJoin, context: OptimizerContext):
        """Yield a LateralJoinPlan with the join's projection and both inputs."""
        plan = LateralJoinPlan(join_node.join_predicate)
        plan.join_project = join_node.join_project
        # Left input first, then right, preserving the logical join's order.
        plan.append_child(join_node.lhs())
        plan.append_child(join_node.rhs())
        yield plan
1050

1051

1052
class LogicalJoinToPhysicalHashJoin(Rule):
    """Implementation rule: inner LogicalJoin -> hash join (build + probe plans)."""

    def __init__(self):
        # LOGICALJOIN with two DUMMY children (left and right inputs).
        pattern = Pattern(OperatorType.LOGICALJOIN)
        pattern.append_child(Pattern(OperatorType.DUMMY))
        pattern.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_JOIN_TO_PHYSICAL_HASH_JOIN, pattern)

    def promise(self):
        return Promise.LOGICAL_JOIN_TO_PHYSICAL_HASH_JOIN

    def check(self, before: Operator, context: OptimizerContext):
        """
        We don't want to apply this rule to the join when FuzzDistance
        is being used, which implies that the join is a FuzzyJoin

        Returns:
            bool: True only for an inner join that has a join predicate whose
            first child is not a ``FuzzDistance`` function expression.
        """
        if before.join_predicate is None:
            return False
        j_child = before.join_predicate.children[0]

        # Previously a FunctionExpression that was not FuzzDistance fell off
        # the end of the method and implicitly returned None, and the
        # FuzzDistance branch evaluated an expression that could never be
        # truthy. Always return an explicit bool that matches the intent
        # stated above.
        is_fuzzy_join = isinstance(
            j_child, FunctionExpression
        ) and j_child.name.startswith("FuzzDistance")
        return before.join_type == JoinType.INNER_JOIN and not is_fuzzy_join

    def apply(self, join_node: LogicalJoin, context: OptimizerContext):
        """Yield a HashJoinProbePlan whose first child builds the hash table.

        The left input feeds a HashJoinBuildPlan; the probe plan takes that
        build plan and the right input as its children.
        """
        #          HashJoinPlan                       HashJoinProbePlan
        #          /           \     ->                  /               \
        #         A             B        HashJoinBuildPlan               B
        #                                              /
        #                                            A

        a: Dummy = join_node.lhs()
        b: Dummy = join_node.rhs()
        a_table_aliases = context.memo.get_group_by_id(a.group_id).aliases
        b_table_aliases = context.memo.get_group_by_id(b.group_id).aliases
        join_predicates = join_node.join_predicate
        a_join_keys, b_join_keys = extract_equi_join_keys(
            join_predicates, a_table_aliases, b_table_aliases
        )

        build_plan = HashJoinBuildPlan(join_node.join_type, a_join_keys)
        build_plan.append_child(a)
        probe_side = HashJoinProbePlan(
            join_node.join_type,
            b_join_keys,
            join_predicates,
            join_node.join_project,
        )
        probe_side.append_child(build_plan)
        probe_side.append_child(b)
        yield probe_side
1106

1107

1108
class LogicalJoinToPhysicalNestedLoopJoin(Rule):
    """Implementation rule: fuzzy inner LogicalJoin -> NestedLoopJoinPlan."""

    def __init__(self):
        # LOGICALJOIN with two DUMMY children (left and right inputs).
        root = Pattern(OperatorType.LOGICALJOIN)
        root.append_child(Pattern(OperatorType.DUMMY))
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_JOIN_TO_PHYSICAL_NESTED_LOOP_JOIN, root)

    def promise(self):
        return Promise.LOGICAL_JOIN_TO_PHYSICAL_NESTED_LOOP_JOIN

    def check(self, before: LogicalJoin, context: OptimizerContext):
        """
        Apply this rule only to a FuzzyJoin, i.e. an inner join whose
        predicate's first child is a FuzzDistance function expression.
        """
        predicate = before.join_predicate
        if predicate is None:
            return False
        first_child = predicate.children[0]
        if isinstance(first_child, FunctionExpression):
            return (
                before.join_type == JoinType.INNER_JOIN
                and first_child.name.startswith("FuzzDistance")
            )
        return False

    def apply(self, join_node: LogicalJoin, context: OptimizerContext):
        """Yield a NestedLoopJoinPlan over the join's left and right inputs."""
        plan = NestedLoopJoinPlan(join_node.join_type, join_node.join_predicate)
        plan.append_child(join_node.lhs())
        plan.append_child(join_node.rhs())
        yield plan
1139

1140

1141
class LogicalFilterToPhysical(Rule):
    """Implementation rule turning a LogicalFilter operator into a PredicatePlan."""

    def __init__(self):
        # LOGICALFILTER with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALFILTER)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_FILTER_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_FILTER_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalFilter, context: OptimizerContext):
        """Yield a PredicatePlan that adopts all children of ``before``."""
        plan = PredicatePlan(before.predicate)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1158

1159

1160
class LogicalProjectToPhysical(Rule):
    """Implementation rule turning a LogicalProject operator into a ProjectPlan."""

    def __init__(self):
        # LOGICALPROJECT with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALPROJECT)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_PROJECT_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_PROJECT_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalProject, context: OptimizerContext):
        """Yield a ProjectPlan that adopts all children of ``before``."""
        plan = ProjectPlan(before.target_list)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1177

1178

1179
class LogicalShowToPhysical(Rule):
    """Implementation rule turning a LogicalShow operator into a ShowInfoPlan."""

    def __init__(self):
        super().__init__(
            RuleType.LOGICAL_SHOW_TO_PHYSICAL,
            Pattern(OperatorType.LOGICAL_SHOW),
        )

    def promise(self):
        return Promise.LOGICAL_SHOW_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalShow, context: OptimizerContext):
        """Yield a ShowInfoPlan for the requested show type."""
        yield ShowInfoPlan(before.show_type)
1193

1194

1195
class LogicalExplainToPhysical(Rule):
    """Implementation rule turning a LogicalExplain operator into an ExplainPlan."""

    def __init__(self):
        # LOGICALEXPLAIN with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALEXPLAIN)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_EXPLAIN_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_EXPLAIN_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalExplain, context: OptimizerContext):
        """Yield an ExplainPlan that adopts all children of ``before``."""
        plan = ExplainPlan(before.explainable_opr)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1212

1213

1214
class LogicalApplyAndMergeToPhysical(Rule):
    """Implementation rule: LogicalApplyAndMerge -> ApplyAndMergePlan."""

    def __init__(self):
        # LOGICAL_APPLY_AND_MERGE with a single DUMMY child.
        root = Pattern(OperatorType.LOGICAL_APPLY_AND_MERGE)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_APPLY_AND_MERGE_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_APPLY_AND_MERGE_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalApplyAndMerge, context: OptimizerContext):
        """Yield an ApplyAndMergePlan that adopts all children of ``before``."""
        plan = ApplyAndMergePlan(before.func_expr, before.alias, before.do_unnest)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1231

1232

1233
class LogicalVectorIndexScanToPhysical(Rule):
    """Implementation rule: LogicalVectorIndexScan -> VectorIndexScanPlan."""

    def __init__(self):
        # LOGICAL_VECTOR_INDEX_SCAN with a single DUMMY child.
        root = Pattern(OperatorType.LOGICAL_VECTOR_INDEX_SCAN)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_VECTOR_INDEX_SCAN_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_VECTOR_INDEX_SCAN_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalVectorIndexScan, context: OptimizerContext):
        """Yield a VectorIndexScanPlan that adopts all children of ``before``."""
        plan = VectorIndexScanPlan(
            before.index_name,
            before.vector_store_type,
            before.limit_count,
            before.search_query_expr,
        )
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1255

1256

1257
"""
1✔
1258
Rules to optimize Ray.
1259
"""
1260

1261

1262
def get_ray_env_dict():
    """Build the environment-variable dict handed to Ray workers.

    Returns:
        dict: ``{"CUDA_VISIBLE_DEVICES": "0,1,...,max_id"}`` exposing every
        GPU id from 0 up to the highest id reported by ``Context().gpus``,
        or an empty dict when no GPUs are reported.
    """
    # Read the GPU list once instead of constructing Context twice.
    gpus = Context().gpus
    if len(gpus) > 0:
        # Get the highest GPU id and expose all GPUs that have id lower than
        # or equal to it.
        max_gpu_id = max(gpus) + 1
        return {"CUDA_VISIBLE_DEVICES": ",".join([str(n) for n in range(max_gpu_id)])}
    else:
        return {}
1270

1271

1272
class LogicalExchangeToPhysical(Rule):
    """Implementation rule turning a LogicalExchange operator into an ExchangePlan."""

    def __init__(self):
        # LOGICALEXCHANGE with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALEXCHANGE)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_EXCHANGE_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_EXCHANGE_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalExchange, context: OptimizerContext):
        """Yield an ExchangePlan that adopts all children of ``before``."""
        plan = ExchangePlan(before.view)
        for input_op in before.children:
            plan.append_child(input_op)
        yield plan
1289

1290

1291
class LogicalApplyAndMergeToRayPhysical(Rule):
    """Ray variant: wrap the ApplyAndMergePlan inside an ExchangePlan.

    Registers under the same RuleType and Promise
    (LOGICAL_APPLY_AND_MERGE_TO_PHYSICAL) as the non-Ray rule.
    """

    def __init__(self):
        # LOGICAL_APPLY_AND_MERGE with a single DUMMY child.
        root = Pattern(OperatorType.LOGICAL_APPLY_AND_MERGE)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_APPLY_AND_MERGE_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_APPLY_AND_MERGE_TO_PHYSICAL

    def check(self, grp_id: int, context: OptimizerContext):
        # No precondition beyond the pattern match; the first argument is unused.
        return True

    def apply(self, before: LogicalApplyAndMerge, context: OptimizerContext):
        """Yield an ExchangePlan whose inner plan is the ApplyAndMergePlan.

        The children of ``before`` are attached to the exchange, not to the
        inner plan.
        """
        inner = ApplyAndMergePlan(before.func_expr, before.alias, before.do_unnest)

        parallelism = 2

        env = get_ray_env_dict()
        # Every parallel worker shares the same env dict object.
        per_worker_env = [env] * parallelism

        exchange = ExchangePlan(
            inner_plan=inner,
            parallelism=parallelism,
            ray_pull_env_conf_dict=env,
            ray_parallel_env_conf_dict=per_worker_env,
        )
        for input_op in before.children:
            exchange.append_child(input_op)

        yield exchange
1321

1322

1323
class LogicalProjectToRayPhysical(Rule):
    """Ray variant: projections containing a function run inside an ExchangePlan.

    Registers under the same RuleType and Promise
    (LOGICAL_PROJECT_TO_PHYSICAL) as the non-Ray rule.
    """

    def __init__(self):
        # LOGICALPROJECT with a single DUMMY child.
        root = Pattern(OperatorType.LOGICALPROJECT)
        root.append_child(Pattern(OperatorType.DUMMY))
        super().__init__(RuleType.LOGICAL_PROJECT_TO_PHYSICAL, root)

    def promise(self):
        return Promise.LOGICAL_PROJECT_TO_PHYSICAL

    def check(self, before: LogicalProject, context: OptimizerContext):
        # No precondition beyond the pattern match.
        return True

    def apply(self, before: LogicalProject, context: OptimizerContext):
        """Yield a plain ProjectPlan, or an ExchangePlan wrapping it.

        The exchange is used only when the target list contains at least one
        FunctionExpression; otherwise the ProjectPlan is yielded directly.
        """
        projection = ProjectPlan(before.target_list)

        # Check whether the projection contains a Function
        has_function = before.target_list is not None and any(
            isinstance(expr, FunctionExpression) for expr in before.target_list
        )

        if not has_function:
            for input_op in before.children:
                projection.append_child(input_op)
            yield projection
            return

        parallelism = 2

        env = get_ray_env_dict()
        # Every parallel worker shares the same env dict object.
        per_worker_env = [env] * parallelism

        exchange = ExchangePlan(
            inner_plan=projection,
            parallelism=parallelism,
            ray_pull_env_conf_dict=env,
            ray_parallel_env_conf_dict=per_worker_env,
        )
        for input_op in before.children:
            exchange.append_child(input_op)
        yield exchange
1361

1362

1363
# IMPLEMENTATION RULES END
1364
##############################################
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc