• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 12727109829

11 Jan 2025 07:46PM UTC coverage: 97.014%. First build
12727109829

Pull #358

github

web-flow
Merge a22b4a63d into 9925f10fa
Pull Request #358: SQLModel Code generation fixes

12 of 13 new or added lines in 1 file covered. (92.31%)

1332 of 1373 relevant lines covered (97.01%)

4.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.94
/src/sqlacodegen/generators.py
1
from __future__ import annotations
5✔
2

3
import inspect
5✔
4
import re
5✔
5
import sys
5✔
6
from abc import ABCMeta, abstractmethod
5✔
7
from collections import defaultdict
5✔
8
from collections.abc import Collection, Iterable, Sequence
5✔
9
from dataclasses import dataclass
5✔
10
from importlib import import_module
5✔
11
from inspect import Parameter
5✔
12
from itertools import count
5✔
13
from keyword import iskeyword
5✔
14
from pprint import pformat
5✔
15
from textwrap import indent
5✔
16
from typing import Any, ClassVar
5✔
17

18
import inflect
5✔
19
import sqlalchemy
5✔
20
from sqlalchemy import (
5✔
21
    ARRAY,
22
    Boolean,
23
    CheckConstraint,
24
    Column,
25
    Computed,
26
    Constraint,
27
    DefaultClause,
28
    Enum,
29
    Float,
30
    ForeignKey,
31
    ForeignKeyConstraint,
32
    Identity,
33
    Index,
34
    MetaData,
35
    PrimaryKeyConstraint,
36
    String,
37
    Table,
38
    Text,
39
    UniqueConstraint,
40
)
41
from sqlalchemy.dialects.postgresql import JSONB
5✔
42
from sqlalchemy.engine import Connection, Engine
5✔
43
from sqlalchemy.exc import CompileError
5✔
44
from sqlalchemy.sql.elements import TextClause
5✔
45

46
from .models import (
5✔
47
    ColumnAttribute,
48
    JoinType,
49
    Model,
50
    ModelClass,
51
    RelationshipAttribute,
52
    RelationshipType,
53
)
54
from .utils import (
5✔
55
    decode_postgresql_sequence,
56
    get_column_names,
57
    get_common_fk_constraints,
58
    get_compiled_expression,
59
    get_constraint_sort_key,
60
    qualified_table_name,
61
    render_callable,
62
    uses_default_name,
63
)
64

65
if sys.version_info < (3, 10):
5✔
66
    pass
2✔
67
else:
68
    pass
3✔
69

70
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
5✔
71
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
5✔
72
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
5✔
73
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
5✔
74
_re_invalid_identifier = re.compile(r"(?u)\W")
5✔
75

76

77
@dataclass
class LiteralImport:
    """One ``from <pkgname> import <name>`` statement needed by the generated code."""

    # Module the name is imported from, e.g. "sqlalchemy"
    pkgname: str
    # Name imported from that module, e.g. "MetaData"
    name: str
81

82

83
@dataclass
class Base:
    """
    Describes the module level object that generated tables or classes attach to:
    a plain ``MetaData`` instance for Core tables, or a base class for ORM models.
    """

    # Imports required by the declarations below
    literal_imports: list[LiteralImport]
    # Source lines rendered at module level, ahead of the models
    declarations: list[str]
    # Expression that refers to the MetaData object, e.g. "metadata"
    metadata_ref: str
    # NOTE(review): not read in the visible generators; presumably a decorator line
    decorator: str | None = None
    # Extra declaration emitted when some models are rendered as plain tables
    table_metadata_declaration: str | None = None
92

93

94
class CodeGenerator(metaclass=ABCMeta):
    """Abstract base class for generators that turn ``MetaData`` into source code."""

    # Option names accepted by this generator; subclasses extend this set
    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        :param metadata: the metadata to generate code for
        :param bind: the connection or engine associated with the metadata
        :param options: generator options; each must be in :attr:`valid_options`
        :raises ValueError: if any option is not in :attr:`valid_options`; the
            offending options are listed in sorted order in the message
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Validate options
        invalid_options = {opt for opt in options if opt not in self.valid_options}
        if invalid_options:
            # Sort so that the error message does not depend on set iteration order
            raise ValueError(
                "Unrecognized options: " + ", ".join(sorted(invalid_options))
            )

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
115

116

117
@dataclass(eq=False)
class TablesGenerator(CodeGenerator):
    """
    Generate SQLAlchemy Core code: a module level ``metadata = MetaData()`` followed
    by one ``t_<tablename> = Table(...)`` assignment per table in the metadata.

    Subclasses customize the output by overriding the ``generate_*``, ``render_*``
    and ``collect_imports*`` hooks.
    """

    valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
    # Packages that group_imports() always places in the standard library group
    builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | {
        "dataclasses"
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        :param metadata: the metadata to generate code for (modified by generate())
        :param bind: connection or engine whose dialect is used to compile
            expressions and column types
        :param options: generator options; must be a subset of :attr:`valid_options`
        :param indentation: the string used for one level of indentation
        :raises ValueError: if ``options`` contains unrecognized options
        """
        super().__init__(metadata, bind, options)
        self.indentation: str = indentation
        # Package name -> names rendered as "from <package> import <names>"
        self.imports: dict[str, set[str]] = defaultdict(set)
        # Modules rendered as "import <module>"
        self.module_imports: set[str] = set()

    @property
    def mapped_columns_supported(self) -> bool:
        """Whether columns may be rendered as ``mapped_column()`` (never for tables)."""
        return False

    def generate_base(self) -> None:
        """Set :attr:`base` to a module level ``metadata = MetaData()`` declaration."""
        self.base = Base(
            literal_imports=[LiteralImport("sqlalchemy", "MetaData")],
            declarations=["metadata = MetaData()"],
            metadata_ref="metadata",
        )

    def generate(self) -> str:
        """
        Generate the complete module source for the metadata.

        Tables rejected by :meth:`should_ignore_table` are removed from the metadata,
        and indexes, constraints and comments are stripped according to the
        ``noindexes``, ``noconstraints`` and ``nocomments`` options.

        .. note:: Modifies the metadata in place.

        :return: the generated source code, terminated by a newline
        """
        self.generate_base()

        sections: list[str] = []

        # Remove unwanted elements from the metadata
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if "noindexes" in self.options:
                table.indexes.clear()

            if "noconstraints" in self.options:
                table.constraints.clear()

            if "nocomments" in self.options:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended column
        # types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        # Generate the models
        models: list[Model] = self.generate_models()

        # Render module level variables
        variables = self.render_module_variables(models)
        if variables:
            sections.append(variables + "\n")

        # Render models
        rendered_models = self.render_models(models)
        if rendered_models:
            sections.append(rendered_models)

        # Render collected imports last, so that imports added while rendering the
        # models (e.g. Column) are included, but place them at the top
        groups = self.group_imports()
        imports = "\n\n".join("\n".join(group) for group in groups)
        if imports:
            sections.insert(0, imports)

        return "\n\n".join(sections) + "\n"

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Record the imports required by :attr:`base` and by every model."""
        for literal_import in self.base.literal_imports:
            self.add_literal_import(literal_import.pkgname, literal_import.name)

        for model in models:
            self.collect_imports_for_model(model)

    def collect_imports_for_model(self, model: Model) -> None:
        """Record the imports required by one model's table, columns and constraints."""
        # Only plain models are rendered as Table(...); subclasses render differently
        if model.__class__ is Model:
            self.add_import(Table)

        for column in model.table.c:
            self.collect_imports_for_column(column)

        for constraint in model.table.constraints:
            self.collect_imports_for_constraint(constraint)

        for index in model.table.indexes:
            self.collect_imports_for_constraint(index)

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Record the imports for a column's type, default and server default."""
        self.add_import(column.type)

        if isinstance(column.type, ARRAY):
            self.add_import(column.type.item_type.__class__)
        elif isinstance(column.type, JSONB):
            # astext_type is only rendered when it differs from a plain Text()
            if (
                not isinstance(column.type.astext_type, Text)
                or column.type.astext_type.length is not None
            ):
                self.add_import(column.type.astext_type)

        if column.default:
            self.add_import(column.default)

        if column.server_default:
            if isinstance(column.server_default, (Computed, Identity)):
                self.add_import(column.server_default)
            elif isinstance(column.server_default, DefaultClause):
                self.add_literal_import("sqlalchemy", "text")

    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Record the imports for a constraint or index.

        Single-column, default-named indexes, unique constraints and foreign keys
        (and default-named primary keys) are rendered on the column instead, so
        their classes are not imported; such a foreign key needs ``ForeignKey``.
        """
        if isinstance(constraint, Index):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "Index")
        elif isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
        elif isinstance(constraint, UniqueConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
        elif isinstance(constraint, ForeignKeyConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
        else:
            self.add_import(constraint)

    def add_import(self, obj: Any) -> None:
        """
        Record the import needed to reference ``obj`` (a class or an instance of one).

        Objects whose ``__module__`` is ``builtins``, or that have no ``__module__``,
        need no import and are ignored.
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        pkgname = type_.__module__

        # The column types have already been adapted towards generic types if possible,
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
        # use that rather than the generic sqlalchemy type as it might have different
        # constructor parameters.
        if pkgname.startswith("sqlalchemy.dialects."):
            dialect_pkgname = ".".join(pkgname.split(".")[0:3])
            dialect_pkg = import_module(dialect_pkgname)

            if type_.__name__ in dialect_pkg.__all__:
                pkgname = dialect_pkgname
        elif type_.__name__ in dir(sqlalchemy):
            # Prefer the short top-level "from sqlalchemy import X" form
            pkgname = "sqlalchemy"

        self.add_literal_import(pkgname, type_.__name__)

    def add_literal_import(self, pkgname: str, name: str) -> None:
        """Record ``from pkgname import name``."""
        names = self.imports.setdefault(pkgname, set())
        names.add(name)

    def remove_literal_import(self, pkgname: str, name: str) -> None:
        """
        Forget a previously recorded ``from pkgname import name`` (no-op if absent).

        The package entry is dropped once it has no names left, so that an empty
        ``from pkgname import`` line is never rendered.
        """
        names = self.imports.get(pkgname)
        if names is None:
            return

        names.discard(name)
        if not names:
            del self.imports[pkgname]

    def add_module_import(self, pgkname: str) -> None:
        """Record ``import <pgkname>`` (parameter name kept for compatibility)."""
        self.module_imports.add(pgkname)

    def group_imports(self) -> list[list[str]]:
        """
        Render the recorded imports as lines, grouped into ``__future__``, standard
        library and third party groups (empty groups are omitted).
        """
        future_imports: list[str] = []
        stdlib_imports: list[str] = []
        thirdparty_imports: list[str] = []

        for package in sorted(self.imports):
            names = self.imports[package]
            if not names:
                # Nothing left to import from this package; avoid invalid syntax
                continue

            imports = ", ".join(sorted(names))
            collection = thirdparty_imports
            if package == "__future__":
                collection = future_imports
            elif package in self.builtin_module_names:
                collection = stdlib_imports
            elif package in sys.modules:
                # Loaded modules not located under site-packages go into the standard
                # library group; some modules have no __file__ attribute at all
                module_file = getattr(sys.modules[package], "__file__", None) or ""
                if "site-packages" not in module_file:
                    collection = stdlib_imports

            collection.append(f"from {package} import {imports}")

        for module in sorted(self.module_imports):
            thirdparty_imports.append(f"import {module}")

        return [
            group
            for group in (future_imports, stdlib_imports, thirdparty_imports)
            if group
        ]

    def generate_models(self) -> list[Model]:
        """
        Create one model per table (in dependency order), collect their imports and
        assign each a name that does not clash with imported names or other models.
        """
        models = [Model(table) for table in self.metadata.sorted_tables]

        # Collect the imports
        self.collect_imports(models)

        # Generate names for models
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models:
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return models

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Name a table model ``t_<tablename>``, made unique against ``global_names``."""
        preferred_name = f"t_{model.table.name}"
        model.name = self.find_free_name(preferred_name, global_names)

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module level declarations of :attr:`base`, plus its
        ``table_metadata_declaration`` (if any) when some model is not a class.
        """
        # Copy so that repeated calls don't keep appending to the shared Base object
        declarations = list(self.base.declarations)

        if any(not isinstance(model, ModelClass) for model in models):
            if self.base.table_metadata_declaration is not None:
                declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_models(self, models: list[Model]) -> str:
        """Render every model as ``<name> = Table(...)``, separated by blank lines."""
        rendered: list[str] = []
        for model in models:
            rendered_table = self.render_table(model.table)
            rendered.append(f"{model.name} = {rendered_table}")

        return "\n\n".join(rendered)

    def render_table(self, table: Table) -> str:
        """Render ``table`` as a ``Table(...)`` call bound to the base's metadata."""
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        kwargs: dict[str, object] = {}
        for column in table.columns:
            args.append(self.render_column(column, True, is_table=True))

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            # Default-named primary keys and single-column foreign key / unique
            # constraints are rendered on the columns by render_column()
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)):
                    if len(constraint.columns) == 1:
                        continue

            args.append(self.render_constraint(constraint))

        for index in sorted(table.indexes, key=lambda i: i.name):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        return render_callable("Table", *args, kwargs=kwargs, indentation="    ")

    def render_index(self, index: Index) -> str:
        """Render ``Index(name, *column_names)``, adding ``unique=True`` if unique."""
        extra_args = [repr(col.name) for col in index.columns]
        kwargs = {}
        if index.unique:
            kwargs["unique"] = True

        return render_callable("Index", repr(index.name), *extra_args, kwargs=kwargs)

    # TODO find better solution for is_table
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render ``column`` as ``Column(...)``, or as ``mapped_column(...)`` when
        ``is_table`` is false and :attr:`mapped_columns_supported` is true.

        Single-column foreign keys, unique constraints and indexes that use default
        names are folded into the column definition.

        .. note:: Sets ``column.unique`` / ``column.index`` when those are folded in.

        :param column: the column to render
        :param show_name: whether to pass the column name as the first argument
        :param is_table: true when rendering inside a ``Table(...)`` call
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            # repr() like the other string-valued keyword arguments, so that the
            # generated code contains a quoted string literal
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if not column.nullable and not is_sole_pk and is_table:
            kwargs["nullable"] = False

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(column.server_default.arg.text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        if is_table or not self.mapped_columns_supported:
            self.add_import(Column)
            return render_callable("Column", *args, kwargs=kwargs)
        else:
            return render_callable("mapped_column", *args, kwargs=kwargs)

    def render_column_type(self, coltype: object) -> str:
        """
        Render a column type instance as source code, e.g. ``String(50)``.

        Constructor arguments are recovered by matching the parameters of the type's
        ``__init__`` against same-named attributes of the instance. Arguments equal to
        their defaults are omitted, and every argument after the first omitted one is
        rendered as a keyword argument.
        """
        args = []
        kwargs: dict[str, Any] = {}
        sig = inspect.signature(coltype.__class__.__init__)
        defaults = {param.name: param.default for param in sig.parameters.values()}
        missing = object()
        use_kwargs = False
        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                continue

            value = getattr(coltype, param.name, missing)
            default = defaults.get(param.name, missing)
            if value is missing or value == default:
                # Once an argument is skipped, later ones must be passed by keyword
                use_kwargs = True
            elif use_kwargs:
                kwargs[param.name] = repr(value)
            else:
                args.append(repr(value))

        # If __init__ takes *args, render the same-named attribute's items positionally
        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(coltype, vararg):
            varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)]
            args.extend(varargs_repr)

        if isinstance(coltype, Enum) and coltype.name is not None:
            kwargs["name"] = repr(coltype.name)

        if isinstance(coltype, JSONB):
            # Remove astext_type if it's the default
            if (
                isinstance(coltype.astext_type, Text)
                and coltype.astext_type.length is None
            ):
                kwargs.pop("astext_type", None)

        if args or kwargs:
            return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
        else:
            return coltype.__class__.__name__

    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a ForeignKey, ForeignKeyConstraint, CheckConstraint, UniqueConstraint
        or PrimaryKeyConstraint as a constructor call.

        :raises TypeError: for any other constraint type
        """

        def add_fk_options(*opts: Any) -> None:
            # Positional arguments first, then whichever FK options are set
            args.extend(repr(opt) for opt in opts)
            for attr in "ondelete", "onupdate", "deferrable", "initially", "match":
                value = getattr(constraint, attr, None)
                if value:
                    kwargs[attr] = repr(value)

        args: list[str] = []
        kwargs: dict[str, Any] = {}
        if isinstance(constraint, ForeignKey):
            remote_column = (
                f"{constraint.column.table.fullname}.{constraint.column.name}"
            )
            add_fk_options(remote_column)
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            add_fk_options(local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)

    def should_ignore_table(self, table: Table) -> bool:
        """Return True for tables that should not appear in the generated code."""
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
        # tables
        return table.name in ("alembic_version", "migrate_version")

    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Generate an attribute name that does not clash with other local or global names.

        The name is first sanitized: invalid identifier characters become ``_``, a
        leading digit gets a ``_`` prefix, and keywords or ``metadata`` get a ``_``
        suffix. On a clash, ``_`` is appended, then ``1``, ``2``, ... until it is free.

        :raises AssertionError: if ``name`` is empty or only whitespace
        """
        name = name.strip()
        assert name, "Identifier cannot be empty"
        name = _re_invalid_identifier.sub("_", name)
        if name[0].isdigit():
            name = "_" + name
        elif iskeyword(name) or name == "metadata":
            name += "_"

        original = name
        for i in count():
            if name not in global_names and name not in local_names:
                break

            name = original + (str(i) if i else "_")

        return name

    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types.

        Columns with a ``col IN (0, 1)`` CHECK constraint become ``Boolean``; string
        columns with a ``col IN ('a', ...)`` CHECK constraint become a non-native
        ``Enum`` (the constraint is dropped in both cases). All column types are then
        adapted towards generic types, and on PostgreSQL sequence-based server
        defaults are replaced by an explicit ``Sequence`` when non-default.
        """
        # Detect check constraints for boolean and enum columns
        for constraint in table.constraints.copy():
            if isinstance(constraint, CheckConstraint):
                sqltext = get_compiled_expression(constraint.sqltext, self.bind)

                # Turn any column with a CheckConstraint like "column IN (0, 1)" into
                # a Boolean (the column's current type is not checked)
                match = _re_boolean_check_constraint.match(sqltext)
                if match:
                    colname_match = _re_column_name.match(match.group(1))
                    if colname_match:
                        colname = colname_match.group(3)
                        table.constraints.remove(constraint)
                        table.c[colname].type = Boolean()
                        continue

                # Turn any string-type column with a CheckConstraint like
                # "column IN (...)" into an Enum
                match = _re_enum_check_constraint.match(sqltext)
                if match:
                    colname_match = _re_column_name.match(match.group(1))
                    if colname_match:
                        colname = colname_match.group(3)
                        items = match.group(2)
                        if isinstance(table.c[colname].type, String):
                            table.constraints.remove(constraint)
                            if not isinstance(table.c[colname].type, Enum):
                                options = _re_enum_item.findall(items)
                                table.c[colname].type = Enum(
                                    *options, native_enum=False
                                )

                            continue

        for column in table.c:
            try:
                column.type = self.get_adapted_type(column.type)
            except CompileError:
                # Best effort: keep the reflected type if it cannot be compiled
                pass

            # PostgreSQL specific fix: detect sequences from server_default
            if column.server_default and self.bind.dialect.name == "postgresql":
                if isinstance(column.server_default, DefaultClause) and isinstance(
                    column.server_default.arg, TextClause
                ):
                    schema, seqname = decode_postgresql_sequence(
                        column.server_default.arg
                    )
                    if seqname:
                        # Add an explicit sequence
                        if seqname != f"{column.table.name}_{column.name}_seq":
                            column.default = sqlalchemy.Sequence(seqname, schema=schema)

                        column.server_default = None

    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Adapt ``coltype`` to the most generic class in its MRO that still compiles to
        the same SQL on the bind's dialect (``Float`` and arrays of ``Float`` are
        accepted even when they compile differently).

        Walking stops at the first class whose name is not all uppercase, or as soon
        as adaptation or compilation fails.
        """
        compiled_type = coltype.compile(self.bind.engine.dialect)
        for supercls in coltype.__class__.__mro__:
            if not supercls.__name__.startswith("_") and hasattr(
                supercls, "__visit_name__"
            ):
                # Hack to fix adaptation of the Enum class which is broken since
                # SQLAlchemy 1.2
                kw = {}
                if supercls is Enum:
                    kw["name"] = coltype.name

                try:
                    new_coltype = coltype.adapt(supercls)
                except TypeError:
                    # If the adaptation fails, don't try again
                    break

                for key, value in kw.items():
                    setattr(new_coltype, key, value)

                if isinstance(coltype, ARRAY):
                    new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

                try:
                    # If the adapted column type does not render the same as the
                    # original, don't substitute it
                    if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
                        # Make an exception to the rule for Float and arrays of Float,
                        # since at least on PostgreSQL, Float can accurately represent
                        # both REAL and DOUBLE_PRECISION
                        if not isinstance(new_coltype, Float) and not (
                            isinstance(new_coltype, ARRAY)
                            and isinstance(new_coltype.item_type, Float)
                        ):
                            break
                except CompileError:
                    # If the adapted column type can't be compiled, don't substitute it
                    break

                # Stop on the first valid non-uppercase column type class
                coltype = new_coltype
                if supercls.__name__ != supercls.__name__.upper():
                    break

        return coltype
713

714

715
class DeclarativeGenerator(TablesGenerator):
    """Generator that renders tables as SQLAlchemy declarative model classes.

    Tables that have a primary key (and are not detected as association tables)
    become ``ModelClass`` instances rendered as ``class X(Base):`` with
    ``Mapped[...]`` column annotations and ``relationship()`` attributes. All
    other tables stay plain ``Model`` instances and are rendered with
    ``render_table()``.

    Options accepted on top of those of ``TablesGenerator``:

    * ``use_inflect``: singularize class names and pluralize/singularize
      relationship names with the ``inflect`` engine
    * ``nojoined``: don't treat a foreign key covering the whole primary key as
      joined-table inheritance (``parent_class``)
    * ``nobidi``: don't generate the reverse side of relationships
    """

    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
        "use_inflect",
        "nojoined",
        "nobidi",
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ):
        """Create the generator.

        :param metadata: metadata holding the tables to generate code for
        :param bind: connection or engine, forwarded to ``TablesGenerator``
            (its dialect is used when adapting column types)
        :param options: names of the enabled generator options
        :param indentation: string used for one level of indentation
        :param base_class_name: name of the declarative base class to render
        """
        super().__init__(metadata, bind, options, indentation=indentation)
        self.base_class_name: str = base_class_name
        self.inflect_engine = inflect.engine()

    @property
    def mapped_columns_supported(self) -> bool:
        """Whether ``Mapped`` and ``mapped_column`` are imported for model classes.

        Subclasses that render column attributes differently return False.
        """
        return True

    def generate_base(self) -> None:
        """Declare ``class <base_class_name>(DeclarativeBase)`` as the base."""
        self.base = Base(
            literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")],
            declarations=[
                f"class {self.base_class_name}(DeclarativeBase):",
                f"{self.indentation}pass",
            ],
            metadata_ref=f"{self.base_class_name}.metadata",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect the imports needed by *models*.

        ``Mapped`` and ``mapped_column`` are added when at least one model class
        is present and :attr:`mapped_columns_supported` is true.
        """
        # Materialize once: the superclass may iterate the models, which would
        # exhaust a one-shot iterator before the check below runs
        models = list(models)
        super().collect_imports(models)
        if (
            any(isinstance(model, ModelClass) for model in models)
            and self.mapped_columns_supported
        ):
            self.add_literal_import("sqlalchemy.orm", "Mapped")
            self.add_literal_import("sqlalchemy.orm", "mapped_column")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect per-model imports, adding ``relationship`` when needed."""
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")

    def generate_models(self) -> list[Model]:
        """Build the models (classes and plain tables) to render.

        Classifies every table (association table, plain table or model class),
        adds relationships, links joined-inheritance subclasses to their
        parents (unless ``nojoined``), collects imports and finally assigns
        non-conflicting names.

        :return: the models in ``metadata.sorted_tables`` order
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't
        # process them normally. They are keyed by the (unqualified) name of
        # the table referenced by their first foreign key constraint.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all
            # columns are involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                tablename = fk_constraints[0].elements[0].column.table.name
                links[tablename].append(model)
                continue

            # Only form model classes for tables that have a primary key and
            # are not association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model, models_by_table_name, links[model.table.name]
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            # generate_relationships() already links child and
                            # parent for such a foreign key; compare by identity
                            # so the child is not listed twice
                            if all(child is not model for child in target.children):
                                target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())

    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """Create the relationship attributes involving *source*.

        Many-to-one / one-to-one relationships come from *source*'s own foreign
        key constraints; many-to-many relationships come from
        *association_tables*. Unless ``nobidi`` is set, the reverse side is
        added to the target model as well. Unless ``nojoined`` is set, a
        foreign key covering *source*'s whole primary key is recorded as
        joined-table inheritance instead of a relationship.

        :param source: model class whose relationships are generated
        :param models_by_table_name: all models keyed by qualified table name
        :param association_tables: association tables linked to *source*
        :return: every relationship attribute created by this call (source
            side and reverse side)
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            # A foreign key covering the whole primary key means joined-table
            # inheritance: the referred model becomes the parent class
            if (
                "nojoined" not in self.options
                and set(get_column_names(constraint)) == pk_column_names
            ):
                source.parent_class = target
                target.children.append(source)
                continue

            # Add uselist=False to One-to-One relationships
            column_names = get_column_names(constraint)
            if any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == set(column_names)
                for c in constraint.table.constraints
            ):
                r_type = RelationshipType.ONE_TO_ONE
            else:
                r_type = RelationshipType.MANY_TO_ONE

            relationship = RelationshipAttribute(r_type, source, target, constraint)
            source.relationships.append(relationship)
            relationships.append(relationship)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                relationship.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # If the two tables share more than one foreign key constraint,
            # SQLAlchemy needs an explicit primaryjoin to figure out which
            # column(s) it needs
            common_fk_constraints = get_common_fk_constraints(
                source.table, target.table
            )
            if len(common_fk_constraints) > 1:
                relationship.foreign_keys = [
                    source.get_column_attribute(key)
                    for key in constraint.column_keys
                ]

            # Generate the opposite end of the relationship in the target class
            if "nobidi" not in self.options:
                if r_type is RelationshipType.MANY_TO_ONE:
                    r_type = RelationshipType.ONE_TO_MANY

                reverse_relationship = RelationshipAttribute(
                    r_type,
                    target,
                    source,
                    constraint,
                    foreign_keys=relationship.foreign_keys,
                    backref=relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    reverse_relationship.remote_side = [
                        source.get_column_attribute(colname)
                        for colname in constraint.column_keys
                    ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            relationship = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            source.relationships.append(relationship)
            relationships.append(relationship)

            # Generate the opposite end of the relationship in the target class
            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

            # Add a primary/secondary join for self-referential many-to-many
            # relationships
            if source is target:
                both_relationships = [relationship]
                reverse_flags = [False, True]
                if reverse_relationship is not None:
                    both_relationships.append(reverse_relationship)

                for rel, reverse in zip(both_relationships, reverse_flags):
                    # Narrows the optional attributes; both are always set for
                    # relationships built above
                    if not rel.association_table or not rel.constraint:
                        continue

                    # The reverse side swaps which foreign key constraint of the
                    # association table forms the primary vs. secondary join
                    constraints = sorted(
                        rel.constraint.table.foreign_key_constraints,
                        key=get_constraint_sort_key,
                        reverse=reverse,
                    )
                    pri_pairs = zip(
                        get_column_names(constraints[0]), constraints[0].elements
                    )
                    sec_pairs = zip(
                        get_column_names(constraints[1]), constraints[1].elements
                    )
                    rel.primaryjoin = [
                        (
                            rel.source,
                            elem.column.name,
                            rel.association_table,
                            col,
                        )
                        for col, elem in pri_pairs
                    ]
                    rel.secondaryjoin = [
                        (
                            rel.target,
                            elem.column.name,
                            rel.association_table,
                            col,
                        )
                        for col, elem in sec_pairs
                    ]

        return relationships

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Assign names to a model class and its column/relationship attributes.

        The class name is the CamelCased table name (singularized with
        ``use_inflect``), made unique against *global_names*. Plain tables are
        named by ``TablesGenerator``.
        """
        if isinstance(model, ModelClass):
            preferred_name = _re_invalid_identifier.sub("_", model.table.name)
            preferred_name = "".join(
                part[:1].upper() + part[1:] for part in preferred_name.split("_")
            )
            if "use_inflect" in self.options:
                singular_name = self.inflect_engine.singular_noun(preferred_name)
                if singular_name:
                    preferred_name = singular_name

            model.name = self.find_free_name(preferred_name, global_names)

            # Fill in the names for column attributes
            local_names: set[str] = set()
            for column_attr in model.columns:
                self.generate_column_attr_name(column_attr, global_names, local_names)
                local_names.add(column_attr.name)

            # Fill in the names for relationship attributes
            for relationship in model.relationships:
                self.generate_relationship_name(relationship, global_names, local_names)
                local_names.add(relationship.name)
        else:
            super().generate_model_name(model, global_names)

    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Name a column attribute after its column, avoiding conflicts."""
        column_attr.name = self.find_free_name(
            column_attr.column.name, global_names, local_names
        )

    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Pick a conflict-free attribute name for *relationship*.

        Self-referential reverse relationships get ``<backref name>_reverse``.
        Otherwise the target table name is used, or the foreign key column name
        minus its ``_id`` suffix for single-column constraints. With
        ``use_inflect``, to-many names are pluralized and others singularized.
        """
        # Self referential reverse relationships
        preferred_name: str
        if (
            relationship.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        ):
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id",
            # use the preceding part as the relationship name
            if relationship.constraint:
                is_source = relationship.source.table is relationship.constraint.table
                if is_source or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    column_names = [c.name for c in relationship.constraint.columns]
                    if len(column_names) == 1 and column_names[0].endswith("_id"):
                        preferred_name = column_names[0][:-3]

            if "use_inflect" in self.options:
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    inflected_name = self.inflect_engine.plural_noun(preferred_name)
                    if inflected_name:
                        preferred_name = inflected_name
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)
                    if inflected_name:
                        preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )

    def render_models(self, models: list[Model]) -> str:
        """Render model classes as classes and plain tables as ``name = Table(...)``."""
        rendered: list[str] = []
        for model in models:
            if isinstance(model, ModelClass):
                rendered.append(self.render_class(model))
            else:
                rendered.append(f"{model.name} = {self.render_table(model.table)}")

        return "\n\n\n".join(rendered)

    def render_class(self, model: ModelClass) -> str:
        """Render the full class body of *model*.

        Sections (class variables, columns, relationships) are separated by a
        blank line; non-nullable columns are rendered before nullable ones.
        """
        sections: list[str] = []

        # Render class variables / special declarations
        class_vars: str = self.render_class_variables(model)
        if class_vars:
            sections.append(class_vars)

        # Render column attributes
        rendered_column_attributes: list[str] = []
        for nullable in (False, True):
            for column_attr in model.columns:
                if column_attr.column.nullable is nullable:
                    rendered_column_attributes.append(
                        self.render_column_attribute(column_attr)
                    )

        if rendered_column_attributes:
            sections.append("\n".join(rendered_column_attributes))

        # Render relationship attributes
        rendered_relationship_attributes: list[str] = [
            self.render_relationship(relationship)
            for relationship in model.relationships
        ]

        if rendered_relationship_attributes:
            sections.append("\n".join(rendered_relationship_attributes))

        declaration = self.render_class_declaration(model)
        rendered_sections = "\n\n".join(
            indent(section, self.indentation) for section in sections
        )
        return f"{declaration}\n{rendered_sections}"

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render ``class Name(Parent):``; the parent defaults to the base class."""
        parent_class_name = (
            model.parent_class.name if model.parent_class else self.base_class_name
        )
        return f"class {model.name}({parent_class_name}):"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` and, if needed, ``__table_args__``."""
        variables = [f"__tablename__ = {model.table.name!r}"]

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_table_args(self, table: Table) -> str:
        """Render the ``__table_args__`` value for *table*, or "" if empty.

        Default-named primary key constraints, default-named single-column
        foreign key / unique constraints and default-named single-column
        indexes are skipped. Schema and comment become a trailing kwargs dict.
        """
        args: list[str] = []
        kwargs: dict[str, str] = {}

        # Render constraints
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                if (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                ):
                    continue

            args.append(self.render_constraint(constraint))

        # Render indexes
        for index in sorted(table.indexes, key=lambda i: i.name):
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = table.schema

        if table.comment:
            kwargs["comment"] = table.comment

        if kwargs:
            formatted_kwargs = pformat(kwargs)
            if not args:
                return formatted_kwargs
            else:
                args.append(formatted_kwargs)

        if args:
            rendered_args = f",\n{self.indentation}".join(args)
            # A trailing comma keeps a single argument a tuple
            if len(args) == 1:
                rendered_args += ","

            return f"(\n{self.indentation}{rendered_args}\n)"
        else:
            return ""

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render ``name: Mapped[<python type>] = <column>``.

        The annotation falls back to ``Any`` when the column type doesn't
        implement ``python_type`` and is wrapped in ``Optional`` for nullable
        columns. Non-builtin types are referenced through their module.
        """
        column = column_attr.column
        rendered_column = self.render_column(column, column_attr.name != column.name)

        try:
            python_type = column.type.python_type
            python_type_name = python_type.__name__
            if python_type.__module__ == "builtins":
                column_python_type = python_type_name
            else:
                python_type_module = python_type.__module__
                column_python_type = f"{python_type_module}.{python_type_name}"
                self.add_module_import(python_type_module)
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
            column_python_type = "Any"

        if column.nullable:
            self.add_literal_import("typing", "Optional")
            column_python_type = f"Optional[{column_python_type}]"
        return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render ``name: Mapped[...] = relationship(...)`` for *relationship*."""

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Source-class attributes by bare name, others as quoted strings
            rendered = []
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(repr(f"{attr.model.name}.{attr.name}"))

            return "[" + ", ".join(rendered) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            rendered = []
            render_as_string = False
            # Assume that column_attrs are all in relationship.source or none
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            if render_as_string:
                return "'[" + ", ".join(rendered) + "]'"
            else:
                return "[" + ", ".join(rendered) + "]"

        def render_join(terms: list[JoinType]) -> str:
            # Each term is (source model, source attribute, target model, target
            # column); plain tables (exact Model instances) need ".c." access
            comparisons = []
            for source, source_col, target, target_col in terms:
                accessor = "c." if target.__class__ is Model else ""
                comparisons.append(
                    f"{source.name}.{source_col} == {target.name}.{accessor}{target_col}"
                )

            if len(comparisons) > 1:
                # Put the whole conjunction inside ONE lambda so relationship()
                # receives a single deferred callable, as in the single-term
                # case, and make sure and_ is imported in the generated module
                self.add_literal_import("sqlalchemy", "and_")
                return f"lambda: and_({', '.join(comparisons)})"

            return f"lambda: {comparisons[0]}"

        # Render keyword arguments (insertion order is the rendered order)
        kwargs: dict[str, Any] = {}
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
            if relationship.constraint.referred_table is relationship.source.table:
                kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        relationship_type: str
        if relationship.type == RelationshipType.ONE_TO_MANY:
            self.add_literal_import("typing", "List")
            relationship_type = f"List['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            relationship_type = f"'{relationship.target.name}'"
        elif relationship.type == RelationshipType.MANY_TO_MANY:
            self.add_literal_import("typing", "List")
            relationship_type = f"List['{relationship.target.name}']"
        else:
            self.add_literal_import("typing", "Any")
            relationship_type = "Any"

        return (
            f"{relationship.name}: Mapped[{relationship_type}] "
            f"= {rendered_relationship}"
        )
1313

1314

1315
class DataclassGenerator(DeclarativeGenerator):
    """Declarative generator whose base class also mixes in ``MappedAsDataclass``.

    Only the base class declaration differs from ``DeclarativeGenerator``;
    model generation and rendering are inherited unchanged.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """Create the generator.

        :param quote_annotations: stored as ``self.quote_annotations``
        :param metadata_key: stored as ``self.metadata_key``

        Remaining parameters are forwarded to ``DeclarativeGenerator``.
        """
        super().__init__(
            metadata, bind, options, indentation=indentation, base_class_name=base_class_name
        )
        # NOTE(review): neither attribute is read inside this class; presumably
        # consumed by callers or other code paths - confirm before removing
        self.quote_annotations: bool = quote_annotations
        self.metadata_key: str = metadata_key

    def generate_base(self) -> None:
        """Declare ``class <name>(MappedAsDataclass, DeclarativeBase)`` as the base."""
        name = self.base_class_name
        header = f"class {name}(MappedAsDataclass, DeclarativeBase):"
        self.base = Base(
            literal_imports=[
                LiteralImport("sqlalchemy.orm", imported)
                for imported in ("DeclarativeBase", "MappedAsDataclass")
            ],
            declarations=[header, self.indentation + "pass"],
            metadata_ref=name + ".metadata",
        )
1352

1353

1354
class SQLModelGenerator(DeclarativeGenerator):
    """Generator that renders mapped classes as SQLModel ``table=True`` models.

    Columns are rendered as ``Field(sa_column=Column(...))`` and relationships as
    ``Relationship(...)``. Models that are not :class:`ModelClass` instances
    (plain tables) are rendered against a module-level ``metadata = MetaData()``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    @property
    def mapped_columns_supported(self) -> bool:
        # Columns are emitted through Field(sa_column=...) instead
        return False

    def generate_base(self) -> None:
        # No base class declaration is rendered: classes subclass
        # ``base_class_name`` (``SQLModel`` by default), which collect_imports()
        # imports from sqlmodel. Plain tables use a module-level MetaData.
        self.base = Base(
            literal_imports=[],
            declarations=[],
            metadata_ref="metadata",
            table_metadata_declaration="metadata = MetaData()",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect the imports required by *models*.

        Skips :class:`DeclarativeGenerator`'s implementation and delegates to the
        next class in the MRO, then adds the SQLModel-specific imports.
        """
        # Materialize once: the models are scanned several times below, and a
        # one-shot iterator would otherwise be exhausted after the first pass.
        model_list = list(models)
        super(DeclarativeGenerator, self).collect_imports(model_list)

        # MetaData is only referenced by the ``metadata = MetaData()`` module
        # variable, which render_module_variables() emits only when a plain
        # (non-ModelClass) table is present; keep both conditions in sync.
        if any(not isinstance(model, ModelClass) for model in model_list):
            self.add_literal_import("sqlalchemy", "MetaData")

        if any(isinstance(model, ModelClass) for model in model_list):
            self.add_literal_import("sqlmodel", "SQLModel")
            self.add_literal_import("sqlmodel", "Field")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect per-model imports, adding typing/sqlmodel names for classes."""
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
        if not isinstance(model, ModelClass):
            return

        # Nullable columns are annotated as Optional[...]
        if any(column_attr.column.nullable for column_attr in model.columns):
            self.add_literal_import("typing", "Optional")

        if model.relationships:
            self.add_literal_import("sqlmodel", "Relationship")

        # Collection-valued relationships are annotated as List[...]
        if any(
            relationship_attr.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.MANY_TO_MANY)
            for relationship_attr in model.relationships
        ):
            self.add_literal_import("typing", "List")

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Collect column imports plus the column's Python type for annotations."""
        super().collect_imports_for_column(column)
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            # The type has no Python equivalent; it is annotated as Any
            self.add_literal_import("typing", "Any")
        else:
            self.add_import(python_type)

    def render_module_variables(self, models: list[Model]) -> str:
        """Render the ``metadata = MetaData()`` variable when plain tables exist."""
        declarations: list[str] = []
        if any(not isinstance(model, ModelClass) for model in models):
            if self.base.table_metadata_declaration is not None:
                declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render ``class <Name>(<Parent>, table=True):``."""
        if model.parent_class:
            parent = model.parent_class.name
        else:
            parent = self.base_class_name

        superclass_part = f"({parent}, table=True)"
        return f"class {model.name}{superclass_part}:"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` (if needed) and ``__table_args__``."""
        variables = []

        # Only spell out the table name when it differs from the lowercased
        # class name (presumably SQLModel's implicit default).
        if model.table.name != model.name.lower():
            variables.append(f"__tablename__ = {model.table.name!r}")

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render ``name: <type> = Field(..., sa_column=Column(...))``."""
        column = column_attr.column
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            python_type_name = "Any"
        else:
            python_type_name = python_type.__name__

        kwargs: dict[str, Any] = {}
        # Nullable columns and autoincrementing primary key columns are optional
        # on construction, so they get an Optional annotation and default None.
        if (
            column.autoincrement and column.name in column.table.primary_key
        ) or column.nullable:
            self.add_literal_import("typing", "Optional")
            kwargs["default"] = None
            python_type_name = f"Optional[{python_type_name}]"

        rendered_column = self.render_column(column, True)
        kwargs["sa_column"] = rendered_column
        rendered_field = render_callable("Field", kwargs=kwargs)
        return f"{column_attr.name}: {python_type_name} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render ``name: <annotation> = Relationship(...)``.

        The arguments are derived from the declarative ``relationship(...)``
        rendering of the parent class (the text after ``" = "``).
        """
        rendered = super().render_relationship(relationship).partition(" = ")[2]
        args = self.render_relationship_args(rendered)
        kwargs: dict[str, Any] = {}
        annotation = repr(relationship.target.name)

        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            self.add_literal_import("typing", "List")
            annotation = f"List[{annotation}]"
        else:
            self.add_literal_import("typing", "Optional")
            annotation = f"Optional[{annotation}]"

        rendered_field = render_callable("Relationship", *args, kwargs=kwargs)
        return f"{relationship.name}: {annotation} = {rendered_field}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """Translate a rendered ``relationship(...)`` call into Relationship args.

        Only ``back_populates`` is kept as-is; ``uselist=False`` is moved into
        ``sa_relationship_kwargs``. All other arguments are dropped.
        NOTE(review): splitting on "," assumes ", "-separated arguments on one
        line; values containing commas are split too, but only the two argument
        kinds above are inspected.
        """
        argument_list = arguments.split(",")
        # delete the trailing ')' and the leading ' ' of each argument
        argument_list[-1] = argument_list[-1][:-1]
        argument_list = [argument[1:] for argument in argument_list]

        rendered_args: list[str] = []
        for arg in argument_list:
            if "back_populates" in arg:
                rendered_args.append(arg)
            if "uselist=False" in arg:
                rendered_args.append("sa_relationship_kwargs={'uselist': False}")

        return rendered_args
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc