• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 20555306382

28 Dec 2025 02:44PM UTC coverage: 96.088% (-1.3%) from 97.36%
20555306382

Pull #446

github

web-flow
Merge b7952664d into 90831a745
Pull Request #446: Support native python enum generation

89 of 112 new or added lines in 4 files covered. (79.46%)

32 existing lines in 3 files now uncovered.

1572 of 1636 relevant lines covered (96.09%)

4.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.44
/src/sqlacodegen/utils.py
1
from __future__ import annotations
5✔
2

3
import re
5✔
4
import sys
5✔
5
from collections.abc import Mapping
5✔
6
from typing import Any, Literal, cast
5✔
7

8
from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
5✔
9
from sqlalchemy.engine import Connection, Engine
5✔
10
from sqlalchemy.sql import ClauseElement
5✔
11
from sqlalchemy.sql.elements import TextClause
5✔
12
from sqlalchemy.sql.schema import (
5✔
13
    CheckConstraint,
14
    ColumnCollectionConstraint,
15
    Constraint,
16
    ForeignKeyConstraint,
17
    Index,
18
    Table,
19
)
20

21
# Matches a ``nextval('<name>'::regclass)`` expression at the start of the text and
# captures everything between the single quotes (the possibly schema-qualified and
# possibly double-quoted sequence name).
_re_postgresql_nextval_sequence = re.compile(r"nextval\('(.+)'::regclass\)")
# Splits a sequence name into chunks: group 1 is the (lazily matched) text up to the
# next delimiter, group 2 is that delimiter -- a dot, a double quote, or the empty
# string at the end of the input.
_re_postgresql_sequence_delimiter = re.compile(r'(.*?)([."]|$)')
5✔
23

24

25
def get_column_names(constraint: ColumnCollectionConstraint) -> list[str]:
    """Return the keys of the constraint's column collection as a new list."""
    return [*constraint.columns.keys()]
5✔
27

28

29
def get_constraint_sort_key(constraint: Constraint) -> str:
    """
    Build a string key for the given constraint.

    :param constraint: the constraint to build the key for
    :return: ``C`` followed by the SQL text for check constraints; the first letter
        of the class name followed by the ``repr()`` of the column name list for
        column collection constraints; ``str(constraint)`` for anything else

    """
    if isinstance(constraint, CheckConstraint):
        return "C" + str(constraint.sqltext)

    if isinstance(constraint, ColumnCollectionConstraint):
        initial = constraint.__class__.__name__[0]
        return f"{initial}{get_column_names(constraint)!r}"

    return str(constraint)
×
36

37

38
def get_compiled_expression(statement: ClauseElement, bind: Engine | Connection) -> str:
    """Compile the statement against ``bind`` with literal binds and return it as text."""
    compiled = statement.compile(bind, compile_kwargs={"literal_binds": True})
    return str(compiled)
5✔
41

42

43
def get_common_fk_constraints(
    table1: Table, table2: Table
) -> set[ForeignKeyConstraint]:
    """
    Collect the foreign key constraints that link the two tables in either direction.

    :param table1: the first table
    :param table2: the second table
    :return: the foreign key constraints of ``table1`` whose first element refers to
        a column of ``table2``, together with those of ``table2`` whose first element
        refers to a column of ``table1``

    """

    def pointing_at(source: Table, target: Table) -> set[ForeignKeyConstraint]:
        # Only the first element is inspected to decide which table is referenced
        return {
            candidate
            for candidate in source.constraints
            if isinstance(candidate, ForeignKeyConstraint)
            and candidate.elements[0].column.table == target
        }

    forward = pointing_at(table1, table2)
    backward = pointing_at(table2, table1)
    return forward | backward
5✔
61

62

63
def uses_default_name(constraint: Constraint | Index) -> bool:
    """
    Check whether a constraint or index is named according to the naming convention.

    The convention template for the constraint type is looked up from
    ``table.metadata.naming_convention`` and rendered with the tokens computed here;
    the result is then compared against the actual name.

    :param constraint: the constraint or index to inspect
    :return: ``True`` if the constraint has no name, is not attached to a table, or
        its name equals the rendered convention; ``False`` if the name differs, if
        there is no convention for this constraint type, or if the template uses a
        token that is not provided here
    :raises TypeError: if the constraint is not an index, check, unique, primary key
        or foreign key constraint

    """
    if not constraint.name or constraint.table is None:
        return True

    table = constraint.table
    values: dict[str, Any] = {
        "table_name": table.name,
        "constraint_name": constraint.name,
    }
    if isinstance(constraint, (Index, ColumnCollectionConstraint)):
        # Materialize the columns so the emptiness check below is made on a plain
        # list and does not depend on the truthiness of the collection object
        columns = list(constraint.columns)
        names = [col.name for col in columns]
        labels = [col.label(col.name).name for col in columns]
        keys = [col.key for col in columns if col.key]
        values.update(
            {
                "column_0N_name": "".join(names),
                "column_0_N_name": "_".join(names),
                "column_0N_label": "".join(labels),
                "column_0_N_label": "_".join(labels),
                "column_0N_key": "".join(keys),
                "column_0_N_key": "_".join(keys),
            }
        )
        if columns:
            values.update(
                {
                    "column_0_name": names[0],
                    "column_0_label": labels[0],
                    "column_0_key": columns[0].key,
                }
            )

    key: Literal["fk", "pk", "ix", "ck", "uq"]
    if isinstance(constraint, Index):
        key = "ix"
    elif isinstance(constraint, CheckConstraint):
        key = "ck"
    elif isinstance(constraint, UniqueConstraint):
        key = "uq"
    elif isinstance(constraint, PrimaryKeyConstraint):
        key = "pk"
    elif isinstance(constraint, ForeignKeyConstraint):
        key = "fk"
        targets = [fk.column for fk in constraint.elements]
        target_names = [col.name for col in targets]
        target_labels = [col.label(col.name).name for col in targets]
        target_keys = [col.key for col in targets if col.key]
        values.update(
            {
                # The token is the bare table name, so pass the name rather than
                # the table object (whose string form may carry a schema prefix)
                "referred_table_name": constraint.referred_table.name,
                "referred_column_0_name": target_names[0],
                "referred_column_0N_name": "".join(target_names),
                "referred_column_0_N_name": "_".join(target_names),
                "referred_column_0_label": target_labels[0],
                "referred_column_0N_label": "".join(target_labels),
                "referred_column_0_N_label": "_".join(target_labels),
                "referred_column_0_key": targets[0].key,
                "referred_column_0N_key": "".join(target_keys),
                "referred_column_0_N_key": "_".join(target_keys),
            }
        )
    else:
        raise TypeError(f"Unknown constraint type: {constraint.__class__.__qualname__}")

    try:
        convention = cast(
            Mapping[str, str],
            table.metadata.naming_convention,
        )[key]
        return constraint.name == (convention % values)
    except KeyError:
        # Either no convention is configured for this constraint type, or the
        # template references a token that is missing from ``values``
        return False
5✔
151

152

153
def render_callable(
    name: str,
    *args: object,
    kwargs: Mapping[str, object] | None = None,
    indentation: str = "",
) -> str:
    """
    Produce the source text of a call expression.

    :param name: name of the callable
    :param args: positional arguments, rendered with ``str()``
    :param kwargs: keyword arguments, rendered as ``key=value`` after the positional
        arguments
    :param indentation: if non-empty, every argument is placed on its own line,
        prefixed with this string, and the closing parenthesis goes on a new line
    :return: the rendered call

    """
    parts = [str(arg) for arg in args]
    if kwargs:
        parts.extend(f"{key}={value}" for key, value in kwargs.items())

    if not indentation:
        return f"{name}({', '.join(parts)})"

    separator = f",\n{indentation}"
    body = separator.join(parts)
    return f"{name}(\n{indentation}{body}\n)"
5✔
182

183

184
def qualified_table_name(table: Table) -> str:
    """Return ``schema.name`` if the table has a schema, otherwise just its name."""
    name = str(table.name)
    return f"{table.schema}.{name}" if table.schema else name
5✔
189

190

191
def decode_postgresql_sequence(clause: TextClause) -> tuple[str | None, str | None]:
    """
    Extract the schema and sequence name from a ``nextval('...'::regclass)`` clause.

    :param clause: the text clause to decode
    :return: a ``(schema, sequence)`` tuple; ``(None, None)`` if the clause text does
        not match the ``nextval`` pattern, and ``schema`` is ``None`` if no unquoted
        dot was found in the sequence name

    """
    nextval = _re_postgresql_nextval_sequence.match(clause.text)
    if nextval is None:
        return None, None

    schema: str | None = None
    sequence: str = ""
    quoted = False
    for chunk in _re_postgresql_sequence_delimiter.finditer(nextval.group(1)):
        text, delimiter = chunk.groups()
        sequence += text
        if delimiter == '"':
            # Double quotes are dropped from the result and only toggle the state
            quoted = not quoted
        elif delimiter == ".":
            if quoted:
                # A dot inside quotes is part of the identifier itself
                sequence += "."
            else:
                # An unquoted dot separates the schema from the sequence name
                schema, sequence = sequence, ""

    return schema, sequence
5✔
210

211

212
def get_stdlib_module_names() -> set[str]:
5✔
213
    major, minor = sys.version_info.major, sys.version_info.minor
5✔
214
    if (major, minor) > (3, 9):
5✔
215
        return set(sys.builtin_module_names) | set(sys.stdlib_module_names)
5✔
216
    else:
UNCOV
217
        from stdlib_list import stdlib_list
×
218

UNCOV
219
        return set(sys.builtin_module_names) | set(stdlib_list(f"{major}.{minor}"))
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc