• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

amoffat / HeimdaLLM / 5566041373

pending completion
5566041373

push

github

web-flow
Merge pull request #10 from amoffat/feature/fix-unqualified-columns

Feature/fix unqualified columns

335 of 386 branches covered (86.79%)

Branch coverage included in aggregate %.

225 of 225 new or added lines in 25 files covered. (100.0%)

1634 of 1719 relevant lines covered (95.06%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.6
/heimdallm/bifrosts/sql/reconstruct.py
1
from typing import Generator, Iterable, TypeAlias, cast
1✔
2

3
from lark import Discard, Token
1✔
4
from lark import Transformer as _Transformer
1✔
5
from lark import Tree
1✔
6

7
from . import exc
1✔
8
from .common import FqColumn
1✔
9
from .utils.identifier import get_identifier, is_count_function
1✔
10
from .validator import ConstraintValidator
1✔
11
from .visitors.aliases import AliasCollector
1✔
12

13

14
def _build_limit_tree(limit, offset=None):
    """Build a ``limit_clause`` tree shaped like ``LIMIT <limit> [OFFSET <offset>]``.

    The OFFSET part is only added when ``offset`` is truthy, so both ``None`` and
    ``0`` produce a bare ``LIMIT`` clause.
    """
    parts = [Token("LIMIT", "LIMIT"), Tree("limit", [Token("NUMBER", limit)])]
    if not offset:
        return Tree("limit_clause", parts)

    parts += [Token("OFFSET", "OFFSET"), Tree("offset", [Token("NUMBER", offset)])]
    return Tree("limit_clause", parts)
33

34

35
def add_limit(limit_placeholder, max_limit: int):
    """Ensure ``limit_placeholder`` carries a LIMIT that does not exceed ``max_limit``.

    - Empty placeholder: a ``LIMIT <max_limit>`` clause is appended.
    - Existing limit above ``max_limit``: the first child clause is replaced by one
      capped at ``max_limit``, keeping the query's OFFSET (if any).
    - Existing limit at or below ``max_limit``: nothing changes.

    The placeholder is mutated in place; the function returns ``None``.
    """
    # nothing there yet: attach a fresh limit at the maximum
    if not limit_placeholder.children:
        limit_placeholder.children.append(_build_limit_tree(max_limit))
        return

    limit_node = next(limit_placeholder.find_data("limit"))
    requested_limit = int(cast(Token, limit_node.children[0]).value)

    offset_node = next(limit_placeholder.find_data("offset"), None)
    if offset_node is None:
        requested_offset = 0
    else:
        requested_offset = int(cast(Token, offset_node.children[0]).value)

    # only clamp; a smaller user-supplied limit is respected as-is
    if requested_limit > max_limit:
        limit_placeholder.children[0] = _build_limit_tree(max_limit, requested_offset)
60

61

62
def qualify_column(fq_column: FqColumn) -> Tree:
    """Build an ``fq_column`` node naming ``fq_column.table`` and ``fq_column.column``.

    The returned tree has the shape::

        fq_column(
            table_name(unquoted_identifier(IDENTIFIER <table>)),
            column_name(unquoted_identifier(IDENTIFIER <column>)),
        )

    Callers use it in place of an alias node that has been resolved to a concrete
    column. NOTE(review): both names are always emitted as unquoted identifiers;
    confirm callers never pass names that would require quoting.
    """

    def _wrap_identifier(rule: str, name: str) -> Tree:
        # ``<rule>(unquoted_identifier(IDENTIFIER name))``
        identifier = Tree(
            Token("RULE", "unquoted_identifier"), [Token("IDENTIFIER", name)]
        )
        return Tree(Token("RULE", rule), [identifier])

    table_part = _wrap_identifier("table_name", fq_column.table)
    column_part = _wrap_identifier("column_name", fq_column.column)
    return Tree("fq_column", [table_part, column_part])
88

89

90
class ReconstructTransformer(_Transformer):
    """Alters a parsed query that fails some basic validation constraints but could
    satisfy them after a few safe rewrites. Currently those rewrites are:

        - adding a row limit, or lowering an existing one above the maximum
        - removing selected columns that the validator does not allow
        - replacing column aliases with fully qualified columns when the alias
          resolves to a single column (or implicitly to the selected table)
    """

    def __init__(self, validator: ConstraintValidator, reserved_keywords: set[str]):
        self._validator = validator
        # alias information for the whole tree, gathered in ``transform`` before
        # the callbacks below run, so they can look aliases up
        self._collector = AliasCollector(reserved_keywords=reserved_keywords)
        # most recently rejected column; reported if every selected column is dropped
        self._last_discarded_column: FqColumn | None = None
        self._reserved_keywords = reserved_keywords
        super().__init__()

    def transform(self, tree):
        """Collect table/column aliases from ``tree``, then transform it."""
        self._collector.visit(tree)
        return super().transform(tree)

    def select_statement(self, children):
        """Add a limit, or lower an existing one, if the validator sets a maximum.

        Only the first ``limit_placeholder`` child is adjusted.
        """
        max_limit = self._validator.max_limit()

        if max_limit is not None:
            for child in children:
                if not isinstance(child, Tree):
                    continue

                if child.data == "limit_placeholder":
                    add_limit(child, max_limit)
                    break

        return Tree("select_statement", children)

    def selected_columns(self, children):
        """Rebuild the column list, failing if every selected column was dropped.

        :raises exc.IllegalSelectedColumn: when no selected column survived, naming
            the last column that ``selected_column`` rejected.
        """
        # an empty list means every selected column was discarded as illegal. A
        # query can't proceed without a column, so report the illegal column.
        if not children:
            raise exc.IllegalSelectedColumn(column=self._last_discarded_column.name)
        return Tree("selected_columns", children)

    def column_alias(self, children: list[Tree | Token]):
        """Fully qualify a column alias when it unambiguously maps to one column.

        The alias's collected columns decide the outcome:

        - ``None``: the alias is not based on a column (an expression); kept as-is.
        - no columns: the alias is qualified against the selected table.
        - exactly one column: the alias is replaced by that column.
        - several columns (composite alias): kept as-is.
        """
        alias_name = get_identifier(children[0], self._reserved_keywords)
        fq_columns = self._collector._aliased_columns[alias_name]

        # not derived from any column, so there is nothing to qualify
        if fq_columns is None:
            return Tree("column_alias", children)

        column_count = len(fq_columns)

        # no known source column: the query implicitly uses the selected table,
        # so qualify the alias against it
        if column_count == 0:
            return qualify_column(
                FqColumn(
                    table=cast(str, self._collector._selected_table),
                    column=alias_name,
                )
            )

        # a single source column means this is not a composite alias
        if column_count == 1:
            return qualify_column(next(iter(fq_columns)))

        # composite alias (several source columns): it can't be qualified, so
        # leave it alone. This is the only remaining case, since the count is
        # non-negative.
        return Tree("column_alias", children)

    def selected_column(self, children: list[Tree | Token]):
        """Discard this selected column if it references a disallowed column.

        Selections recognised by ``is_count_function`` are kept without checking.
        Otherwise every ``fq_column`` inside the selection is resolved through the
        collected table aliases and checked with the validator. If any is
        rejected, the whole selected column is discarded and the rejected column
        is remembered for the error raised by ``selected_columns``.
        """
        selected = children[0]
        if is_count_function(selected):
            pass

        elif isinstance(selected, Tree):
            for fq_column_node in selected.find_data("fq_column"):
                table_node, column_node = fq_column_node.children

                maybe_table_alias = get_identifier(table_node, self._reserved_keywords)
                column_name = get_identifier(column_node, self._reserved_keywords)

                # map a table alias back to its real table name; unknown names
                # are assumed to already be real table names
                table_name = self._collector._aliased_tables.get(
                    maybe_table_alias, maybe_table_alias
                )
                column = FqColumn(table=table_name, column=column_name)
                if not self._validator.select_column_allowed(column):
                    self._last_discarded_column = column
                    return Discard

        return Tree("selected_column", children)
192

193

194
# Items flowing through ``postproc``: lark tokens or plain strings. Declared as an
# explicit type alias (PEP 613) so type checkers never read it as a variable.
PostProcToken: TypeAlias = Token | str
195

196

197
def postproc(items: Iterable[PostProcToken]) -> Generator[PostProcToken, None, None]:
    """Lazily rewrite a stream of reconstructed items.

    Each item that compares equal to the string ``"_WS"`` is replaced by a single
    space. Every other item passes through unchanged, in its original order.
    NOTE(review): ``"_WS"`` presumably stands for a discarded whitespace terminal
    emitted during reconstruction; confirm against the reconstructor's setup.
    """
    for item in items:
        yield " " if item == "_WS" else item
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc