• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 27097745110

07 Jun 2026 04:07PM UTC coverage: 97.794% (-0.05%) from 97.839%
27097745110

Pull #480

github

web-flow
Merge 32c2eed01 into 42b3b39a3
Pull Request #480: Preserve dialect-specific ARRAY types instead of adapting to generic ARRAY

7 of 7 new or added lines in 2 files covered. (100.0%)

1 existing line in 1 file now uncovered.

1862 of 1904 relevant lines covered (97.79%)

3.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.91
/src/sqlacodegen/generators.py
1
from __future__ import annotations
4✔
2

3
import inspect
4✔
4
import re
4✔
5
import sys
4✔
6
from abc import ABCMeta, abstractmethod
4✔
7
from collections import defaultdict
4✔
8
from collections.abc import Collection, Iterable, Mapping, Sequence
4✔
9
from dataclasses import dataclass
4✔
10
from decimal import Decimal
4✔
11
from importlib import import_module
4✔
12
from inspect import Parameter
4✔
13
from itertools import count
4✔
14
from keyword import iskeyword
4✔
15
from pprint import pformat
4✔
16
from textwrap import indent
4✔
17
from typing import Any, ClassVar, Literal, cast
4✔
18

19
import inflect
4✔
20
import sqlalchemy
4✔
21
from sqlalchemy import (
4✔
22
    ARRAY,
23
    Boolean,
24
    CheckConstraint,
25
    Column,
26
    Computed,
27
    Constraint,
28
    DefaultClause,
29
    Enum,
30
    ForeignKey,
31
    ForeignKeyConstraint,
32
    Identity,
33
    Index,
34
    MetaData,
35
    PrimaryKeyConstraint,
36
    String,
37
    Table,
38
    Text,
39
    TypeDecorator,
40
    UniqueConstraint,
41
)
42
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
4✔
43
from sqlalchemy.engine import Connection, Engine
4✔
44
from sqlalchemy.exc import CompileError
4✔
45
from sqlalchemy.sql.elements import TextClause
4✔
46
from sqlalchemy.sql.type_api import UserDefinedType
4✔
47
from sqlalchemy.types import TypeEngine
4✔
48

49
from .models import (
4✔
50
    ColumnAttribute,
51
    JoinType,
52
    Model,
53
    ModelClass,
54
    RelationshipAttribute,
55
    RelationshipType,
56
)
57
from .utils import (
4✔
58
    decode_postgresql_sequence,
59
    get_column_names,
60
    get_common_fk_constraints,
61
    get_compiled_expression,
62
    get_constraint_sort_key,
63
    get_stdlib_module_names,
64
    qualified_table_name,
65
    render_callable,
66
    uses_default_name,
67
)
68

69
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
4✔
70
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
4✔
71
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
4✔
72
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
4✔
73
_re_invalid_identifier = re.compile(r"(?u)\W")
4✔
74

75

76
@dataclass
class LiteralImport:
    """A single ``from <pkgname> import <name>`` entry requested by a base."""

    # Package (dotted module path) the name is imported from
    pkgname: str
    # Name imported from that package
    name: str
80

81

82
@dataclass
class Base:
    """
    Describes the object generated code hangs off: MetaData for tables, or the
    declarative base for classes.
    """

    # Imports that must be registered for the declarations below to work
    literal_imports: list[LiteralImport]
    # Module-level source lines, joined with newlines when rendered
    declarations: list[str]
    # Expression used to refer to the MetaData object in rendered tables
    metadata_ref: str
    # NOTE(review): not read in the visible part of this file -- presumably a
    # decorator applied to generated classes; confirm against subclasses
    decorator: str | None = None
    # Extra declaration emitted only when at least one model is a plain table
    table_metadata_declaration: str | None = None
91

92

93
class CodeGenerator(metaclass=ABCMeta):
    """
    Abstract base class for code generators.

    Subclasses declare the option names they understand in ``valid_options``
    and implement :attr:`views_supported` and :meth:`generate`.
    """

    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        Store the generator inputs and validate the requested options.

        :param metadata: the metadata to generate code for
        :param bind: connection or engine the metadata is bound to
        :param options: option names; each must be in ``valid_options``
        :raises ValueError: if any option is not in ``valid_options``
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Validate options
        invalid_options = {
            opt for opt in self.options if opt not in self.valid_options
        }
        if invalid_options:
            # Sort so the message is deterministic (set order is arbitrary)
            raise ValueError(
                "Unrecognized options: " + ", ".join(sorted(invalid_options))
            )

    @property
    @abstractmethod
    def views_supported(self) -> bool:
        """Return whether this generator supports views."""

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
119

120

121
@dataclass(eq=False)
4✔
122
class TablesGenerator(CodeGenerator):
4✔
123
    valid_options: ClassVar[set[str]] = {
4✔
124
        "noindexes",
125
        "noconstraints",
126
        "nocomments",
127
        "nonativeenums",
128
        "nosyntheticenums",
129
        "include_dialect_options",
130
        "keep_dialect_types",
131
    }
132
    stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
4✔
133

134
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        Set up the generator state on top of the validated base options.

        :param metadata: the metadata to generate code for
        :param bind: connection or engine the metadata is bound to
        :param options: option names, validated by the parent initializer
        :param indentation: string stored as ``self.indentation``
        """
        super().__init__(metadata, bind, options)
        enabled = self.options

        self.indentation: str = indentation

        # "from <package> import <names>" entries, keyed by package
        self.imports: dict[str, set[str]] = defaultdict(set)
        # plain "import <module>" entries
        self.module_imports: set[str] = set()

        # Opt-in: render SchemaItem.info and dialect kwargs (Table/Column)
        self.include_dialect_options_and_info: bool = (
            "include_dialect_options" in enabled
        )
        # Opt-in: keep dialect-specific types instead of adapting them to
        # generic SQLAlchemy types
        self.keep_dialect_types: bool = "keep_dialect_types" in enabled

        # (table_name, column_name) -> name of the Python enum class
        self.enum_classes: dict[tuple[str, str], str] = {}
        # enum class name -> list of its values
        self.enum_values: dict[str, list[str]] = {}
158

159
    @property
4✔
160
    def views_supported(self) -> bool:
4✔
161
        return True
×
162

163
    def generate_base(self) -> None:
        """Set ``self.base`` to a module-level ``metadata = MetaData()`` setup."""
        metadata_import = LiteralImport("sqlalchemy", "MetaData")
        self.base = Base(
            literal_imports=[metadata_import],
            declarations=["metadata = MetaData()"],
            metadata_ref="metadata",
        )
169

170
    def generate(self) -> str:
        """
        Render the whole module: imports, module variables, enum classes and
        models, separated by blank lines and ending with a newline.

        Modifies the metadata: ignored tables are removed and, depending on
        the options, indexes, constraints and comments are stripped.
        """
        self.generate_base()

        drop_indexes = "noindexes" in self.options
        drop_constraints = "noconstraints" in self.options
        drop_comments = "nocomments" in self.options

        # Remove unwanted elements from the metadata. Iterate over a copy
        # because tables may be removed along the way.
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if drop_indexes:
                table.indexes.clear()

            if drop_constraints:
                table.constraints.clear()

            if drop_comments:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended
        # column types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        models: list[Model] = self.generate_models()

        # The body is rendered before the imports are grouped; the imports
        # section is then placed first.
        body: list[str] = []

        variables = self.render_module_variables(models)
        if variables:
            body.append(variables + "\n")

        enum_classes = self.render_enum_classes()
        if enum_classes:
            body.append(enum_classes + "\n")

        rendered_models = self.render_models(models)
        if rendered_models:
            body.append(rendered_models)

        import_blocks = ["\n".join(group) for group in self.group_imports()]
        imports = "\n\n".join(import_blocks)
        sections = [imports, *body] if imports else body

        return "\n\n".join(sections) + "\n"
222

223
    def collect_imports(self, models: Iterable[Model]) -> None:
        """Register the base's literal imports, then each model's imports."""
        for entry in self.base.literal_imports:
            self.add_literal_import(entry.pkgname, entry.name)

        for model in models:
            self.collect_imports_for_model(model)
229

230
    def collect_imports_for_model(self, model: Model) -> None:
        """Register the imports needed by a model's columns, constraints and indexes."""
        # Only exact Model instances (not subclasses) are rendered as Table
        if model.__class__ is Model:
            self.add_import(Table)

        table = model.table
        for column in table.c:
            self.collect_imports_for_column(column)

        # Constraints and indexes share the same import collection routine
        for schema_item in (*table.constraints, *table.indexes):
            self.collect_imports_for_constraint(schema_item)
242

243
    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the imports needed by a column's type and defaults."""
        column_type = column.type
        self.add_import(column_type)

        # Types that wrap another type need that inner type imported as well
        if isinstance(column_type, ARRAY):
            self.add_import(column_type.item_type.__class__)
        elif isinstance(column_type, (JSONB, JSON)):
            astext_type = column_type.astext_type
            is_plain_text = isinstance(astext_type, Text) and astext_type.length is None
            # An unbounded Text astext_type is not imported
            if not is_plain_text:
                self.add_import(astext_type)
        elif isinstance(column_type, DOMAIN):
            self.add_import(column_type.data_type.__class__)

        if column.default:
            self.add_import(column.default)

        server_default = column.server_default
        if not server_default:
            return

        if isinstance(server_default, (Computed, Identity)):
            self.add_import(server_default)
        elif isinstance(server_default, DefaultClause):
            # DefaultClause values are rendered through text(...)
            self.add_literal_import("sqlalchemy", "text")
265

266
    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Register the import needed to render a constraint or index.

        For the index, primary key, unique and foreign key cases an import is
        added only under the conditions checked below; any other constraint
        type is imported directly.
        """

        def is_compound_or_named() -> bool:
            # Spans several columns or carries a non-default name
            return len(constraint.columns) > 1 or not uses_default_name(constraint)

        if isinstance(constraint, Index):
            if is_compound_or_named():
                self.add_literal_import("sqlalchemy", "Index")
            return

        if isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
            return

        if isinstance(constraint, UniqueConstraint):
            if is_compound_or_named():
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
            return

        if isinstance(constraint, ForeignKeyConstraint):
            if is_compound_or_named():
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
            return

        self.add_import(constraint)
283

284
    def add_import(self, obj: Any) -> None:
        """
        Register an import for ``obj`` (a class, or an instance whose class is
        imported). Objects from ``builtins`` are ignored.
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        type_name = type_.__name__
        pkgname = type_.__module__

        # The column types have already been adapted towards generic types if
        # possible, so if this is still a vendor specific type (e.g., MySQL
        # INTEGER) be sure to use that rather than the generic sqlalchemy type
        # as it might have different constructor parameters.
        if pkgname.startswith("sqlalchemy.dialects."):
            # Shorten to "sqlalchemy.dialects.<dialect>" when the dialect
            # package exports the name itself
            dialect_pkgname = ".".join(pkgname.split(".")[:3])
            if type_name in import_module(dialect_pkgname).__all__:
                pkgname = dialect_pkgname
        elif type_ is getattr(sqlalchemy, type_name, None):
            pkgname = "sqlalchemy"

        self.add_literal_import(pkgname, type_name)
308

309
    def add_literal_import(self, pkgname: str, name: str) -> None:
        """Record that ``name`` must be imported from ``pkgname``."""
        self.imports.setdefault(pkgname, set()).add(name)
312

313
    def remove_literal_import(self, pkgname: str, name: str) -> None:
        """
        Forget that ``name`` must be imported from ``pkgname``.

        Does nothing if the package or the name was never registered. A
        package left without names is dropped, so that no empty
        ``from <pkgname> import`` entry remains in ``self.imports``.
        """
        names = self.imports.get(pkgname)
        if names is None:
            return

        names.discard(name)
        if not names:
            del self.imports[pkgname]
317

318
    def add_module_import(self, pgkname: str) -> None:
        """Record that the module ``pgkname`` needs a plain ``import`` line."""
        self.module_imports.add(pgkname)
320

321
    def group_imports(self) -> list[list[str]]:
        """
        Render the collected imports as lines, grouped into ``__future__``,
        standard library and third party imports (in that order).

        Within a group, ``from ... import ...`` lines come first (sorted by
        package), followed by plain ``import ...`` lines (sorted by module).
        Empty groups are omitted, and so are packages with no names left to
        import.

        :return: the non-empty groups, each a list of import lines
        """
        future_imports: list[str] = []
        stdlib_imports: list[str] = []
        thirdparty_imports: list[str] = []

        def get_collection(package: str) -> list[str]:
            if package == "__future__":
                return future_imports

            if package in self.stdlib_module_names:
                return stdlib_imports

            module = sys.modules.get(package)
            if module is not None:
                # Built-in modules have no __file__ attribute at all, and
                # namespace packages have it set to None
                module_file = getattr(module, "__file__", None) or ""
                if "site-packages" not in module_file:
                    return stdlib_imports

            return thirdparty_imports

        for package in sorted(self.imports):
            names = self.imports[package]
            if not names:
                # Would otherwise render the invalid "from <package> import "
                continue

            imports = ", ".join(sorted(names))
            get_collection(package).append(f"from {package} import {imports}")

        for module in sorted(self.module_imports):
            get_collection(module).append(f"import {module}")

        return [
            group
            for group in (future_imports, stdlib_imports, thirdparty_imports)
            if group
        ]
352

353
    def generate_models(self) -> list[Model]:
        """
        Wrap every table (in ``sorted_tables`` order) in a Model, collect the
        imports they need and give each model a name.
        """
        models = [Model(table) for table in self.metadata.sorted_tables]

        # Collect the imports
        self.collect_imports(models)

        # Names already taken at module level: every imported name, plus each
        # model name as it gets assigned
        global_names: set[str] = set().union(*self.imports.values())
        for model in models:
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return models
368

369
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Set ``model.name`` to a free name based on ``t_<table name>``."""
        model.name = self.find_free_name(f"t_{model.table.name}", global_names)
372

373
    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module-level declarations of the base, one per line.

        The base's ``table_metadata_declaration`` (if set) is added when at
        least one model is not a ModelClass.
        """
        # Work on a copy: appending to self.base.declarations directly would
        # add the table metadata declaration again on every call
        declarations = list(self.base.declarations)

        table_declaration = self.base.table_metadata_declaration
        if table_declaration is not None and any(
            not isinstance(model, ModelClass) for model in models
        ):
            declarations.append(table_declaration)

        return "\n".join(declarations)
381

382
    def render_models(self, models: list[Model]) -> str:
        """Render each model as ``<name> = <table>``, separated by blank lines."""
        return "\n\n".join(
            f"{model.name} = {self.render_table(model.table)}" for model in models
        )
389

390
    def render_table(self, table: Table) -> str:
        """
        Render a ``Table(...)`` call for the given table.

        The arguments are, in order: the table name and metadata reference,
        the columns, the constraints and indexes that cannot be expressed on a
        single column, then the keyword arguments (schema, comment and,
        opt-in, info and dialect options).
        """
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        kwargs: dict[str, object] = {}

        args.extend(
            self.render_column(column, True, is_table=True)
            for column in table.columns
        )

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                # Skipped here: default-named primary keys and default-named
                # single-column foreign key / unique constraints
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)):
                    if len(constraint.columns) == 1:
                        continue

            args.append(self.render_constraint(constraint))

        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        # add info + dialect kwargs for callable context (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(table, kwargs, values_for_dict=False)

        # Honor the configured indentation rather than a fixed four spaces
        return render_callable(
            "Table", *args, kwargs=kwargs, indentation=self.indentation
        )
425

426
    def render_index(self, index: Index) -> str:
        """
        Render an ``Index(...)`` call: the index name, its column names, then
        its keyword arguments sorted by key. Keyword values equal to an empty
        list or dict are left out.
        """
        column_names = [repr(col.name) for col in index.columns]

        kwargs = {}
        for key in sorted(index.kwargs):
            value = index.kwargs[key]
            if value in ([], {}):
                continue

            # Strings need quoting; other values are rendered as they are
            kwargs[key] = repr(value) if isinstance(value, str) else value

        if index.unique:
            kwargs["unique"] = True

        return render_callable("Index", repr(index.name), *column_names, kwargs=kwargs)
437

438
    # TODO find better solution for is_table
439
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render a column definition.

        :param column: the column to render
        :param show_name: whether to render the column name as first argument
        :param is_table: render as ``Column(...)`` if true, otherwise as
            ``mapped_column(...)``
        :return: the rendered call

        .. note:: Sets ``column.unique`` / ``column.index`` to ``True`` when
            the column is rendered with ``unique=True`` / ``index=True``.
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}
        is_part_of_composite_pk = (
            column.primary_key and len(column.table.primary_key) > 1
        )
        # Single-column, default-named foreign keys are rendered inline
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        # Only an explicit autoincrement=True counts, not "auto" and the like
        is_autoincrement = (
            isinstance(column.autoincrement, bool) and column.autoincrement
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            # Keyword values are rendered source text, so the key must be
            # quoted like every other string value
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if is_autoincrement and is_primary:
            kwargs["autoincrement"] = True
        if not column.nullable and not column.primary_key:
            kwargs["nullable"] = False
        if column.nullable and is_part_of_composite_pk:
            kwargs["nullable"] = True

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(cast(TextClause, column.server_default.arg).text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            identity = column.server_default
            identity_kwargs: dict[str, Any] = {}

            # Render only the Identity parameters that are set and differ
            # from the defaults in its signature
            for name, param in inspect.signature(Identity).parameters.items():
                if name == "self" or param.kind in (
                    Parameter.VAR_POSITIONAL,
                    Parameter.VAR_KEYWORD,
                ):
                    continue

                value = getattr(identity, name, None)
                if value is None:
                    continue

                if isinstance(value, Decimal):
                    value = int(value)

                if param.default is not Parameter.empty and value == param.default:
                    continue

                identity_kwargs[name] = value

            args.append(render_callable("Identity", kwargs=identity_kwargs))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        # add column info + dialect kwargs for callable context (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(column, kwargs, values_for_dict=False)

        return self.render_column_callable(is_table, *args, **kwargs)
565

566
    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render the column call: ``Column(...)`` (also registering the Column
        import) when ``is_table`` is true, ``mapped_column(...)`` otherwise.
        """
        if not is_table:
            return render_callable("mapped_column", *args, kwargs=kwargs)

        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)
572

573
    def _render_column_type_value(self, value: Any) -> str:
        """
        Render one constructor argument value of a column type.

        Text clauses are rendered as ``text(...)``; anything else through
        ``repr()``. For JSON/JSONB values, an unbounded Text ``astext_type``
        is reset to ``None`` on the value first.
        """
        if isinstance(value, TextClause):
            self.add_literal_import("sqlalchemy", "text")
            return render_callable("text", repr(value.text))

        if isinstance(value, (JSONB, JSON)):
            astext_type = value.astext_type
            if isinstance(astext_type, Text) and astext_type.length is None:
                # Remove astext_type if it's the default
                value.astext_type = None  # type: ignore[assignment]
            else:
                self.add_import(Text)

        return repr(value)
586

587
    def _collect_inherited_init_kwargs(
4✔
588
        self,
589
        column_type: Any,
590
        init_sig: inspect.Signature,
591
        seen_param_names: set[str],
592
        missing: object,
593
    ) -> dict[str, str]:
594
        has_var_keyword = any(
4✔
595
            param.kind is Parameter.VAR_KEYWORD
596
            for param in init_sig.parameters.values()
597
        )
598
        has_var_positional = any(
4✔
599
            param.kind is Parameter.VAR_POSITIONAL
600
            for param in init_sig.parameters.values()
601
        )
602
        if not has_var_keyword or has_var_positional:
4✔
603
            return {}
4✔
604

605
        inherited_kwargs: dict[str, str] = {}
4✔
606
        for supercls in column_type.__class__.__mro__[1:]:
4✔
607
            if supercls is object:
4✔
608
                break
4✔
609

610
            try:
4✔
611
                super_sig = inspect.signature(supercls.__init__)
4✔
612
            except (TypeError, ValueError):
×
613
                continue
×
614

615
            for super_param in list(super_sig.parameters.values())[1:]:
4✔
616
                if super_param.name.startswith("_"):
4✔
617
                    continue
4✔
618

619
                if super_param.kind in (
4✔
620
                    Parameter.POSITIONAL_ONLY,
621
                    Parameter.VAR_POSITIONAL,
622
                    Parameter.VAR_KEYWORD,
623
                ):
624
                    continue
4✔
625

626
                if super_param.name in seen_param_names:
4✔
627
                    continue
4✔
628

629
                seen_param_names.add(super_param.name)
4✔
630
                value = getattr(column_type, super_param.name, missing)
4✔
631
                if value is missing:
4✔
632
                    continue
4✔
633

634
                default = super_param.default
4✔
635
                if default is not Parameter.empty and value == default:
4✔
636
                    continue
4✔
637

638
                inherited_kwargs[super_param.name] = self._render_column_type_value(
4✔
639
                    value
640
                )
641

642
        return inherited_kwargs
4✔
643

644
    def render_column_type(self, column: Column[Any]) -> str:
        """
        Render the type of a column, e.g. ``String(50)`` or ``Integer``.

        Enum columns (and ARRAY columns of Enum items) that have a Python enum
        class registered in ``self.enum_classes`` are rendered against that
        class. For other types the constructor arguments are read from the
        attributes on the type instance that are named like the parameters of
        its ``__init__``, leaving out values equal to the parameter default.
        """
        column_type = column.type

        def render_python_enum(enum_type: Enum, enum_class_name: str) -> str:
            # Enum(...) referring to a generated Python enum class
            extra_kwargs = ""
            if enum_type.name is not None:
                extra_kwargs += f", name={enum_type.name!r}"

            if enum_type.schema is not None:
                extra_kwargs += f", schema={enum_type.schema!r}"

            return (
                f"Enum({enum_class_name}, values_callable=lambda cls: "
                f"[member.value for member in cls]{extra_kwargs})"
            )

        # Check if this is an enum column with a Python enum class
        if isinstance(column_type, Enum):
            if enum_class_name := self.enum_classes.get(
                (column.table.name, column.name)
            ):
                self.add_import(Enum)
                return render_python_enum(column_type, enum_class_name)

        args: list[str] = []
        kwargs: dict[str, Any] = {}

        # Check if this is an ARRAY column with an Enum item type mapped to a
        # Python enum class
        if isinstance(column_type, ARRAY) and isinstance(column_type.item_type, Enum):
            if enum_class_name := self.enum_classes.get(
                (column.table.name, column.name)
            ):
                self.add_import(ARRAY)
                self.add_import(Enum)
                rendered_enum = render_python_enum(
                    column_type.item_type, enum_class_name
                )
                if column_type.dimensions is not None:
                    kwargs["dimensions"] = repr(column_type.dimensions)

                return render_callable("ARRAY", rendered_enum, kwargs=kwargs)

        sig = inspect.signature(column_type.__class__.__init__)
        defaults = {param.name: param.default for param in sig.parameters.values()}
        missing = object()
        # Once a parameter has been skipped, positional rendering would shift
        # the remaining values, so switch to keyword arguments from then on
        use_kwargs = False
        seen_param_names: set[str] = set()
        is_json = isinstance(column_type, (JSONB, JSON))

        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                use_kwargs = True
                continue

            seen_param_names.add(param.name)
            value = getattr(column_type, param.name, missing)
            default = defaults.get(param.name, missing)
            if value is missing or value == default:
                use_kwargs = True
                continue

            # An unbounded Text astext_type is the default for JSON types.
            # Skip it here so it is never rendered, not even positionally.
            if (
                is_json
                and param.name == "astext_type"
                and isinstance(value, Text)
                and value.length is None
            ):
                use_kwargs = True
                continue

            rendered_value = self._render_column_type_value(value)
            if use_kwargs:
                kwargs[param.name] = rendered_value
            else:
                args.append(rendered_value)

        kwargs.update(
            self._collect_inherited_init_kwargs(
                column_type, sig, seen_param_names, missing
            )
        )

        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(column_type, vararg):
            varargs_repr = [repr(arg) for arg in getattr(column_type, vararg)]
            args.extend(varargs_repr)

        # These arguments cannot be autodetected from the Enum initializer
        if isinstance(column_type, Enum):
            for colname in "name", "schema":
                if (value := getattr(column_type, colname)) is not None:
                    kwargs[colname] = repr(value)

        if (
            is_json
            and isinstance(column_type.astext_type, Text)
            and column_type.astext_type.length is None
        ):
            # Remove astext_type if it's the default; it can still get here
            # through the inherited keyword arguments
            kwargs.pop("astext_type", None)

        if args or kwargs:
            return render_callable(column_type.__class__.__name__, *args, kwargs=kwargs)
        else:
            return column_type.__class__.__name__
748

749
    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a constraint (or a single foreign key) as a constructor call.

        :raises TypeError: if the constraint is not a foreign key, foreign key
            constraint, check, unique or primary key constraint
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}

        # Positional values of foreign keys; stays None for everything else
        fk_targets: tuple[Any, ...] | None = None

        if isinstance(constraint, ForeignKey):
            target = constraint.column
            fk_targets = (f"{target.table.fullname}.{target.name}",)
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            fk_targets = (local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        if fk_targets is not None:
            args.extend(repr(target) for target in fk_targets)
            # Foreign key options are rendered only when set to a truthy value
            for attr in ("ondelete", "onupdate", "deferrable", "initially", "match"):
                option = getattr(constraint, attr, None)
                if option:
                    kwargs[attr] = repr(option)

        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)
784

785
    def _add_dialect_kwargs_and_info(
4✔
786
        self, obj: Any, target_kwargs: dict[str, object], *, values_for_dict: bool
787
    ) -> None:
788
        """
789
        Merge SchemaItem-like object's .info and .dialect_kwargs into target_kwargs.
790
        - values_for_dict=True: keep raw values so pretty-printer emits repr() (for __table_args__ dict)
791
        - values_for_dict=False: set values to repr() strings (for callable kwargs)
792
        """
793
        info_dict = getattr(obj, "info", None)
4✔
794
        if info_dict:
4✔
795
            target_kwargs["info"] = info_dict if values_for_dict else repr(info_dict)
4✔
796

797
        dialect_keys: list[str]
798
        try:
4✔
799
            dialect_keys = sorted(getattr(obj, "dialect_kwargs"))
4✔
800
        except Exception:
×
801
            return
×
802

803
        dialect_kwargs = getattr(obj, "dialect_kwargs", {})
4✔
804
        for key in dialect_keys:
4✔
805
            try:
4✔
806
                value = dialect_kwargs[key]
4✔
807
            except Exception:
×
808
                continue
×
809

810
            if isinstance(value, list | dict) and not value:
4✔
811
                continue
4✔
812

813
            # Render values:
814
            # - callable context (values_for_dict=False): produce a string expression.
815
            #   primitives use repr(value); custom objects stringify then repr().
816
            # - dict context (values_for_dict=True): pass raw primitives / str;
817
            #   custom objects become str(value) so pformat quotes them.
818
            if values_for_dict:
4✔
819
                if isinstance(value, type(None) | bool | int | float):
4✔
820
                    target_kwargs[key] = value
×
821
                elif isinstance(value, str | dict | list):
4✔
822
                    target_kwargs[key] = value
4✔
823
                else:
824
                    target_kwargs[key] = str(value)
4✔
825
            else:
826
                if isinstance(
4✔
827
                    value, type(None) | bool | int | float | str | dict | list
828
                ):
829
                    target_kwargs[key] = repr(value)
4✔
830
                else:
831
                    target_kwargs[key] = repr(str(value))
4✔
832

833
    def should_ignore_table(self, table: Table) -> bool:
        """Return ``True`` if the table is a migration tool's version table."""
        # Alembic and sqlalchemy-migrate each keep their schema version in a table
        # of their own; those are never exposed in the generated code
        version_tables = ("alembic_version", "migrate_version")
        return table.name in version_tables
4✔
837

838
    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Turn ``name`` into an identifier that is not yet taken.

        :param name: the preferred name; must not be empty after stripping
        :param global_names: names already taken in the module namespace
        :param local_names: names already taken in the local (class) namespace
        :return: the sanitized name, with ``_`` or a number appended if the
            sanitized name itself was already taken
        """
        stripped = name.strip()
        assert stripped, "Identifier cannot be empty"
        base = _re_invalid_identifier.sub("_", stripped)
        if base[0].isdigit():
            base = f"_{base}"
        elif iskeyword(base) or base == "metadata":
            base = f"{base}_"

        # Try the base name first, then "<base>_", "<base>1", "<base>2", ...
        candidate = base
        attempts = count()
        while candidate in global_names or candidate in local_names:
            attempt = next(attempts)
            candidate = base + (str(attempt) if attempt else "_")

        return candidate
4✔
860

861
    def _enum_name_to_class_name(self, enum_name: str) -> str:
4✔
862
        """Convert a database enum name to a Python class name (PascalCase)."""
863
        return "".join(part.capitalize() for part in enum_name.split("_") if part)
4✔
864

865
    def _create_enum_class(
4✔
866
        self, table_name: str, column_name: str, values: list[str]
867
    ) -> str:
868
        """
869
        Create a Python enum class name and register it.
870

871
        Returns the enum class name to use in generated code.
872
        """
873
        # Generate enum class name from table and column names
874
        # Convert to PascalCase: user_status -> UserStatus
875
        base_name = "".join(
4✔
876
            part.capitalize()
877
            for part in table_name.split("_") + column_name.split("_")
878
            if part
879
        )
880

881
        # Ensure uniqueness
882
        enum_class_name = base_name
4✔
883
        for counter in count(1):
4✔
884
            if enum_class_name not in self.enum_values:
4✔
885
                break
4✔
886

887
            # Check if it's the same enum (same values)
888
            if self.enum_values[enum_class_name] == values:
4✔
889
                # Reuse existing enum class
890
                return enum_class_name
4✔
891

892
            enum_class_name = f"{base_name}{counter}"
4✔
893

894
        # Register the new enum class
895
        self.enum_values[enum_class_name] = values
4✔
896
        return enum_class_name
4✔
897

898
    def render_enum_classes(self) -> str:
        """
        Render the definitions of all registered Python enum classes.

        Every entry of ``self.enum_values`` is rendered as a ``(str, enum.Enum)``
        subclass with one member per value. Member names are derived from the
        values; if two values of the same enum map to the same member name (e.g.
        ``'a'`` and ``'A'``), the later ones get a numeric suffix, because a
        repeated member name would make the generated class fail at import time.

        :return: the class definitions separated by two blank lines, or an empty
            string if no enum classes have been registered
        """
        if not self.enum_values:
            return ""

        self.add_module_import("enum")

        enum_defs = []
        for enum_class_name, values in sorted(self.enum_values.items()):
            # Create enum members with valid, unique Python identifiers
            members = []
            used_member_names: set[str] = set()
            for value in values:
                # Unescape SQL escape sequences (e.g., \' -> ')
                # The value from the CHECK constraint has SQL escaping
                unescaped_value = value.replace("\\'", "'").replace("\\\\", "\\")

                # Create a valid identifier from the enum value
                member_name = _re_invalid_identifier.sub("_", unescaped_value).upper()
                if not member_name:
                    member_name = "EMPTY"
                elif member_name[0].isdigit():
                    member_name = "_" + member_name
                elif iskeyword(member_name):
                    # Uppercased names cannot match a Python keyword today (every
                    # keyword contains lowercase letters); kept as a cheap guard
                    member_name += "_"

                # Distinct values can collapse to the same identifier; enum
                # classes reject repeated member names, so disambiguate them
                base_member_name = member_name
                for suffix in count(2):
                    if member_name not in used_member_names:
                        break

                    member_name = f"{base_member_name}_{suffix}"

                used_member_names.add(member_name)
                members.append(f"    {member_name} = {unescaped_value!r}")

            enum_def = f"class {enum_class_name}(str, enum.Enum):\n" + "\n".join(
                members
            )
            enum_defs.append(enum_def)

        return "\n\n\n".join(enum_defs)
4✔
935

936
    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types of ``table`` in place.

        In order:

        * a CHECK constraint matching the boolean pattern turns its column into
          ``Boolean`` and is removed from the table
        * unless the ``nosyntheticenums`` option is set, a CHECK constraint matching
          the enum pattern on a non-enum string column turns the column into a
          non-native ``Enum`` (the constraint is kept) and registers an enum class
        * unless the ``nonativeenums`` option is set, enum columns (and ARRAY
          columns of enums) that have values get an enum class registered
        * unless ``keep_dialect_types`` is set, every column type is passed through
          :meth:`get_adapted_type`
        * on PostgreSQL, a textual server default that decodes to a sequence is
          removed, and replaced with an explicit ``Sequence`` when the sequence
          name differs from ``<table>_<column>_seq``
        """

        def fix_enum_column(col_name: str, enum_type: Enum) -> None:
            key = (table.name, col_name)
            if key in self.enum_classes:
                return

            if enum_type.name:
                enum_class_name = self._enum_name_to_class_name(enum_type.name)
                # A class of this name that is already assigned to another column
                # is shared; otherwise its values are registered on first use
                shared = bool(enum_class_name) and (
                    enum_class_name in self.enum_classes.values()
                )
                if not shared and enum_class_name not in self.enum_values:
                    self.enum_values[enum_class_name] = list(enum_type.enums)
            else:
                enum_class_name = self._create_enum_class(
                    table.name, col_name, list(enum_type.enums)
                )

            self.enum_classes[key] = enum_class_name

        # Detect check constraints for boolean and enum columns
        for constraint in table.constraints.copy():
            if not isinstance(constraint, CheckConstraint):
                continue

            sqltext = get_compiled_expression(constraint.sqltext, self.bind)

            # Turn any integer-like column with a CheckConstraint like
            # "column IN (0, 1)" into a Boolean
            boolean_match = _re_boolean_check_constraint.match(sqltext)
            if boolean_match:
                name_match = _re_column_name.match(boolean_match.group(1))
                if name_match:
                    table.constraints.remove(constraint)
                    table.c[name_match.group(3)].type = Boolean()
                    continue

            # Turn VARCHAR columns with CHECK constraints like "column IN ('a', 'b')"
            # into synthetic Enum types with Python enum classes
            if "nosyntheticenums" in self.options:
                continue

            enum_match = _re_enum_check_constraint.match(sqltext)
            if not enum_match:
                continue

            name_match = _re_column_name.match(enum_match.group(1))
            if not name_match:
                continue

            colname = name_match.group(3)
            items = enum_match.group(2)
            current_type = table.c[colname].type
            if isinstance(current_type, String) and not isinstance(
                current_type, Enum
            ):
                enum_items = _re_enum_item.findall(items)
                # Create Python enum class
                self.enum_classes[(table.name, colname)] = self._create_enum_class(
                    table.name, colname, enum_items
                )
                # Convert to Enum type but KEEP the constraint
                table.c[colname].type = Enum(*enum_items, native_enum=False)

        for column in table.c:
            # Find the native database Enum type of the column, either directly
            # (e.g., PostgreSQL ENUM) or as an ARRAY item type (ARRAY(ENUM))
            native_enum = None
            if "nonativeenums" not in self.options:
                column_type = column.type
                if isinstance(column_type, Enum) and column_type.enums:
                    native_enum = column_type
                elif (
                    isinstance(column_type, ARRAY)
                    and isinstance(column_type.item_type, Enum)
                    and column_type.item_type.enums
                ):
                    native_enum = column_type.item_type

            if native_enum is not None:
                fix_enum_column(column.name, native_enum)

            if not self.keep_dialect_types:
                try:
                    column.type = self.get_adapted_type(column.type)
                except CompileError:
                    continue

            # PostgreSQL specific fix: detect sequences from server_default
            server_default = column.server_default
            if not server_default or self.bind.dialect.name != "postgresql":
                continue

            if not isinstance(server_default, DefaultClause) or not isinstance(
                server_default.arg, TextClause
            ):
                continue

            schema, seqname = decode_postgresql_sequence(server_default.arg)
            if seqname:
                # Add an explicit sequence
                if seqname != f"{column.table.name}_{column.name}_seq":
                    column.default = sqlalchemy.Sequence(seqname, schema=schema)

                column.server_default = None
4✔
1037

1038
    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Try to replace a column type with a more generic equivalent.

        The classes in the MRO of the type's class are tried in order. Classes
        whose name starts with an underscore or that have no ``__visit_name__`` are
        skipped. An adapted type is only accepted if it compiles to the same string
        as the original type on the bound engine's dialect.

        :param coltype: the column type instance to adapt; a dialect-specific ARRAY
            is modified in place (its ``item_type`` is replaced) and returned as is
        :return: the adapted type, or the original one if no acceptable adaptation
            was found
        """
        # Keep dialect-specific ARRAY subclasses; the generic sqlalchemy.ARRAY
        # is missing operators like .contains() (GH-441).
        if isinstance(coltype, ARRAY) and type(coltype) is not ARRAY:
            coltype.item_type = self.get_adapted_type(coltype.item_type)
            return coltype

        # Reference rendering that every adaptation candidate is compared against.
        # This call is not guarded, so a compile failure of the original type
        # propagates to the caller.
        compiled_type = coltype.compile(self.bind.engine.dialect)
        for supercls in coltype.__class__.__mro__:
            if not supercls.__name__.startswith("_") and hasattr(
                supercls, "__visit_name__"
            ):
                # Don't try to adapt UserDefinedType as it's not a proper column type
                if supercls is UserDefinedType or issubclass(supercls, TypeDecorator):
                    return coltype

                # Attributes collected here are copied onto the adapted type below

                # Hack to fix adaptation of the Enum class which is broken since
                # SQLAlchemy 1.2
                kw = {}
                if supercls is Enum:
                    kw["name"] = coltype.name
                    if coltype.schema:
                        kw["schema"] = coltype.schema

                # Hack to fix Postgres DOMAIN type adaptation, broken as of SQLAlchemy 2.0.42
                # For additional information - https://github.com/agronholm/sqlacodegen/issues/416#issuecomment-3417480599
                if supercls is DOMAIN:
                    if coltype.default:
                        kw["default"] = coltype.default
                    if coltype.constraint_name is not None:
                        kw["constraint_name"] = coltype.constraint_name
                    if coltype.not_null:
                        kw["not_null"] = coltype.not_null
                    if coltype.check is not None:
                        kw["check"] = coltype.check
                    if coltype.create_type:
                        kw["create_type"] = coltype.create_type

                try:
                    new_coltype = coltype.adapt(supercls)
                except TypeError:
                    # If the adaptation fails, don't try again
                    break

                for key, value in kw.items():
                    setattr(new_coltype, key, value)

                # Only reached for the generic ARRAY itself; its subclasses were
                # already returned at the top of the method
                if isinstance(coltype, ARRAY):
                    new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

                try:
                    # If the adapted column type does not render the same as the
                    # original, don't substitute it
                    if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
                        break
                except CompileError:
                    # If the adapted column type can't be compiled, don't substitute it
                    break

                # Stop on the first valid non-uppercase column type class
                coltype = new_coltype
                if supercls.__name__ != supercls.__name__.upper():
                    break

        return coltype
4✔
1103

1104

1105
class DeclarativeGenerator(TablesGenerator):
4✔
1106
    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
4✔
1107
        "use_inflect",
1108
        "nojoined",
1109
        "nobidi",
1110
        "noidsuffix",
1111
        "nofknames",
1112
    }
1113

1114
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        explicit_foreign_keys: bool = False,
    ):
        """
        Set up the generator.

        :param metadata: passed on to the superclass
        :param bind: passed on to the superclass
        :param options: passed on to the superclass
        :param indentation: passed on to the superclass
        :param base_class_name: name of the generated declarative base class
        :param explicit_foreign_keys: stored on the instance as
            ``explicit_foreign_keys``
        """
        super().__init__(metadata, bind, options, indentation=indentation)
        self.explicit_foreign_keys = explicit_foreign_keys
        self.base_class_name: str = base_class_name
        # Used for singularizing model names when the "use_inflect" option is set
        self.inflect_engine = inflect.engine()
4✔
1128

1129
    def generate_base(self) -> None:
        """Set ``self.base`` to a ``DeclarativeBase`` subclass named ``base_class_name``."""
        class_name = self.base_class_name
        declarative_base_import = LiteralImport("sqlalchemy.orm", "DeclarativeBase")
        class_source = [
            f"class {class_name}(DeclarativeBase):",
            f"{self.indentation}pass",
        ]
        self.base = Base(
            literal_imports=[declarative_base_import],
            declarations=class_source,
            metadata_ref=f"{class_name}.metadata",
        )
1138

1139
    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports needed by the given models.

        On top of what the superclass collects, ``Mapped`` and ``mapped_column``
        are imported as soon as at least one model is a model class.
        """
        super().collect_imports(models)
        for model in models:
            if isinstance(model, ModelClass):
                self.add_literal_import("sqlalchemy.orm", "Mapped")
                self.add_literal_import("sqlalchemy.orm", "mapped_column")
                break
4✔
1144

1145
    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports needed by a single model.

        On top of what the superclass collects, ``relationship`` is imported for
        model classes that have at least one relationship.
        """
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")
4✔
1150

1151
    def generate_models(self) -> list[Model]:
        """
        Build the models for all tables in the metadata.

        Tables with exactly two foreign key constraints whose columns are all part
        of a foreign key are treated as association tables and become plain
        :class:`Model` objects, as do tables without a primary key. All other tables
        become :class:`ModelClass` objects with column attributes, relationships
        and (unless the ``nojoined`` option is set) parent/child links for joined
        table inheritance. Finally the imports are collected and names are
        assigned to the models and their attributes.

        :return: the models in the order of ``metadata.sorted_tables``
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't process
        # them normally. They are grouped by the *qualified* name of the table that
        # their first foreign key constraint points to, so that same-named tables in
        # different schemas don't receive each other's association tables.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                linked_table = fk_constraints[0].elements[0].column.table
                links[qualified_table_name(linked_table)].append(model)
                continue

            # Only form model classes for tables that have a primary key and are not
            # association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model,
                    models_by_table_name,
                    links[qualified_table_name(model.table)],
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())
4✔
1228

1229
    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """
        Create the relationships of ``source`` and attach them to the models.

        Relationships are appended to ``source.relationships`` and, unless the
        ``nobidi`` option is set, their reverse sides to the target model's
        ``relationships``.

        :param source: the model class whose foreign keys are examined
        :param models_by_table_name: all models, keyed by qualified table name
        :param association_tables: association table models to turn into
            many-to-many relationships of ``source``
        :return: the local ``relationships`` list, which nothing in this method
            appends to, so it is always empty
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            # Foreign keys pointing to tables without a model class are skipped
            if isinstance(target, ModelClass):
                if "nojoined" not in self.options:
                    # A foreign key covering exactly the primary key columns is
                    # treated as joined table inheritance: the target becomes the
                    # parent class and no relationship is generated for it
                    if set(get_column_names(constraint)) == pk_column_names:
                        parent = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(parent, ModelClass):
                            source.parent_class = parent
                            parent.children.append(source)
                            continue

                # Add uselist=False to One-to-One relationships
                column_names = get_column_names(constraint)
                if any(
                    isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                    and {col.name for col in c.columns} == set(column_names)
                    for c in constraint.table.constraints
                ):
                    r_type = RelationshipType.ONE_TO_ONE
                else:
                    r_type = RelationshipType.MANY_TO_ONE

                relationship = RelationshipAttribute(r_type, source, target, constraint)
                source.relationships.append(relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    relationship.remote_side = [
                        source.get_column_attribute(col.name)
                        for col in constraint.referred_table.primary_key
                    ]

                # If the two tables share more than one foreign key constraint,
                # SQLAlchemy needs an explicit primaryjoin to figure out which column(s)
                # it needs
                common_fk_constraints = get_common_fk_constraints(
                    source.table, target.table
                )
                if len(common_fk_constraints) > 1:
                    relationship.foreign_keys = [
                        source.get_column_attribute(key)
                        for key in constraint.column_keys
                    ]

                # Generate the opposite end of the relationship in the target class
                if "nobidi" not in self.options:
                    # The reverse of many-to-one is one-to-many; one-to-one stays
                    # one-to-one
                    if r_type is RelationshipType.MANY_TO_ONE:
                        r_type = RelationshipType.ONE_TO_MANY

                    reverse_relationship = RelationshipAttribute(
                        r_type,
                        target,
                        source,
                        constraint,
                        foreign_keys=relationship.foreign_keys,
                        backref=relationship,
                    )
                    relationship.backref = reverse_relationship
                    target.relationships.append(reverse_relationship)

                    # For self referential relationships, remote_side needs to be set
                    if source is target:
                        reverse_relationship.remote_side = [
                            source.get_column_attribute(colname)
                            for colname in constraint.column_keys
                        ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            # The second constraint (in sort order) determines the target model
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if isinstance(target, ModelClass):
                relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    source,
                    target,
                    fk_constraints[1],
                    association_table,
                )
                source.relationships.append(relationship)

                # Generate the opposite end of the relationship in the target class
                reverse_relationship = None
                if "nobidi" not in self.options:
                    reverse_relationship = RelationshipAttribute(
                        RelationshipType.MANY_TO_MANY,
                        target,
                        source,
                        fk_constraints[0],
                        association_table,
                        relationship,
                    )
                    relationship.backref = reverse_relationship
                    target.relationships.append(reverse_relationship)

                # Add a primary/secondary join for self-referential many-to-many
                # relationships
                if source is target:
                    both_relationships = [relationship]
                    reverse_flags = [False, True]
                    if reverse_relationship:
                        both_relationships.append(reverse_relationship)

                    # zip() stops at the shorter list, so the second flag is only
                    # used when a reverse relationship exists. Note that the loop
                    # rebinds the outer ``relationship`` variable.
                    for relationship, reverse in zip(both_relationships, reverse_flags):
                        if (
                            not relationship.association_table
                            or not relationship.constraint
                        ):
                            continue

                        # The reverse side uses the constraints in the opposite
                        # order, which swaps its primary and secondary joins
                        constraints = sorted(
                            relationship.constraint.table.foreign_key_constraints,
                            key=get_constraint_sort_key,
                            reverse=reverse,
                        )
                        pri_pairs = zip(
                            get_column_names(constraints[0]), constraints[0].elements
                        )
                        sec_pairs = zip(
                            get_column_names(constraints[1]), constraints[1].elements
                        )
                        # Each join entry is (model, referred column name,
                        # association table model, association table column name)
                        relationship.primaryjoin = [
                            (
                                relationship.source,
                                elem.column.name,
                                relationship.association_table,
                                col,
                            )
                            for col, elem in pri_pairs
                        ]
                        relationship.secondaryjoin = [
                            (
                                relationship.target,
                                elem.column.name,
                                relationship.association_table,
                                col,
                            )
                            for col, elem in sec_pairs
                        ]

        return relationships
4✔
1392

1393
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """
        Assign names to a model and, for model classes, to its attributes.

        Plain models are handed to the superclass. For model classes, the class name
        is derived from the table name by upper-casing the first letter of each
        underscore-separated part (singularized if the ``use_inflect`` option is
        set), after which the column and relationship attributes are named.

        :param model: the model to name
        :param global_names: names already taken in the module namespace
        """
        if not isinstance(model, ModelClass):
            super().generate_model_name(model, global_names)
            return

        sanitized = _re_invalid_identifier.sub("_", model.table.name)
        # Only the first letter of each part is changed; the rest keeps its case
        class_name = "".join(
            word[:1].upper() + word[1:] for word in sanitized.split("_")
        )
        if "use_inflect" in self.options:
            singular = self.inflect_engine.singular_noun(class_name)
            if singular:
                class_name = singular

        model.name = self.find_free_name(class_name, global_names)

        # Column attributes are named first, so relationship attributes have to
        # avoid their names and not the other way around
        taken_names: set[str] = set()
        for column_attr in model.columns:
            self.generate_column_attr_name(column_attr, global_names, taken_names)
            taken_names.add(column_attr.name)

        for relationship_attr in model.relationships:
            self.generate_relationship_name(
                relationship_attr, global_names, taken_names
            )
            taken_names.add(relationship_attr.name)
4✔
1418

1419
    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Name a column attribute after its column, avoiding collisions.

        :param column_attr: the attribute to name (``name`` is set in place)
        :param global_names: names already taken at module level
        :param local_names: names already taken inside the owning class
        """
        wanted_name = column_attr.column.name
        free_name = self.find_free_name(wanted_name, global_names, local_names)
        column_attr.name = free_name
1428

1429
    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Choose an attribute name for ``relationship`` and store it in place.

        The preferred name starts as the target table name and is refined from
        foreign key column names or the association table name, depending on
        the relationship type and on the ``nofknames``, ``noidsuffix`` and
        ``use_inflect`` options. The reverse side of a self-referential
        one-to-many/one-to-one relationship is named ``<backref name>_reverse``.
        The final name is made unique through ``find_free_name``.

        :param relationship: the relationship to name (``name`` is set in place)
        :param global_names: names already taken at module level
        :param local_names: names already taken inside the owning class
        """
        reverse_fk_types = (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)

        def strip_id_suffix(name: str) -> str:
            # Drop "_id" only at the end or right before an underscore:
            # "course_id" -> "course", "course_id_1" -> "course_1",
            # while "parent_id1" is left alone
            return re.sub(r"_id(?=_|$)", "", name)

        def stripped_if_changed(name: str) -> str | None:
            """Return the name without its ``_id`` part, or None if nothing was stripped."""
            stripped = strip_id_suffix(name)
            return None if stripped == name else stripped

        def columns_qualifier(constraint: ForeignKeyConstraint) -> str:
            """Join the ``_id``-stripped names of all constrained columns with "_"."""
            return "_".join(strip_id_suffix(col.name) for col in constraint.columns)

        def get_m2m_qualified_name(default_name: str) -> str:
            """Generate qualified name for many-to-many relationship when multiple junction tables exist."""
            sibling_count = sum(
                1
                for other in relationship.source.relationships
                if other.target is relationship.target
                and other.type == RelationshipType.MANY_TO_MANY
            )

            if sibling_count <= 1:
                # Single M2M: use simple name from junction table FK column
                # (e.g., "right_id" -> "right" instead of "right_table")
                if relationship.constraint and "noidsuffix" not in self.options:
                    names = [col.name for col in relationship.constraint.columns]
                    if len(names) == 1:
                        alias = stripped_if_changed(names[0])
                        if alias is not None:
                            return alias

                return default_name

            # Several M2M relationships point at the same target
            if relationship.source is relationship.target:
                # Self-referential: use FK column name(s) from junction table
                # (e.g., "parent_id" -> "parent", "child_id" -> "child")
                if relationship.constraint:
                    return columns_qualifier(relationship.constraint)
            elif relationship.association_table:
                # Normal: use junction table name as qualifier
                junction = strip_id_suffix(relationship.association_table.table.name)
                return f"{relationship.target.table.name}_{junction}"

            return default_name

        def get_fk_qualified_name(constraint: ForeignKeyConstraint) -> str:
            """Generate qualified name for one-to-many/one-to-one relationship using FK column names."""
            qualifier = columns_qualifier(constraint)

            # For self-referential relationships, don't prepend the table name
            if relationship.source is relationship.target:
                return qualifier

            return f"{relationship.target.table.name}_{qualifier}"

        def wants_fk_based_naming() -> bool:
            """Tell whether the name should be derived from FK columns / junction table."""
            if "nofknames" in self.options:
                return False

            # Reverse relationships that carry explicit foreign keys
            if (
                relationship.constraint
                and relationship.type in reverse_fk_types
                and relationship.foreign_keys
            ):
                return True

            # Many-to-many relationships going through an association table
            return bool(
                relationship.type == RelationshipType.MANY_TO_MANY
                and relationship.association_table
            )

        def resolve_preferred_name() -> str:
            resolved = relationship.target.table.name

            if wants_fk_based_naming():
                if relationship.type == RelationshipType.MANY_TO_MANY:
                    resolved = get_m2m_qualified_name(resolved)
                elif relationship.constraint:
                    resolved = get_fk_qualified_name(relationship.constraint)
            elif relationship.constraint and "noidsuffix" not in self.options:
                owns_constraint = (
                    relationship.source.table is relationship.constraint.table
                )
                if owns_constraint or relationship.type not in reverse_fk_types:
                    names = [col.name for col in relationship.constraint.columns]
                    if len(names) == 1:
                        # Only use the stripped name if it actually changed
                        # (i.e. the column name had "_id" in it)
                        alias = stripped_if_changed(names[0])
                        if alias is not None:
                            resolved = alias
                    else:
                        # Composite FK: only qualify when the source has several
                        # relationships of this type to the same target
                        same_target_count = sum(
                            1
                            for other in relationship.source.relationships
                            if other.target is relationship.target
                            and other.type == relationship.type
                        )
                        if same_target_count > 1:
                            resolved = "_".join(
                                strip_id_suffix(col_name) for col_name in names
                            )

            if "use_inflect" in self.options:
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    # Collection side: pluralize unless already plural
                    if not self.inflect_engine.singular_noun(resolved):
                        resolved = self.inflect_engine.plural_noun(resolved)
                else:
                    # Scalar side: singularize when a singular form exists
                    singular: str | Literal[False]
                    singular = self.inflect_engine.singular_noun(resolved)
                    if singular:
                        resolved = singular

            return resolved

        is_self_referential_reverse = (
            relationship.type in reverse_fk_types
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        )
        if is_self_referential_reverse:
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = resolve_preferred_name()

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )
1582

1583
    def render_models(self, models: list[Model]) -> str:
        """Render all models, separated by two blank lines.

        Class models are rendered through ``render_class``; any other model
        becomes a ``<name> = <rendered table>`` assignment.

        :param models: the models to render, in output order
        :return: the rendered source text
        """
        chunks = [
            self.render_class(model)
            if isinstance(model, ModelClass)
            else f"{model.name} = {self.render_table(model.table)}"
            for model in models
        ]
        return "\n\n\n".join(chunks)
1592

1593
    def render_class(self, model: ModelClass) -> str:
        """Render the full source of one mapped class.

        The body consists of up to three sections separated by a blank line:
        class variables, column attributes (non-nullable columns before
        nullable ones) and relationship attributes. Each section is indented
        with ``self.indentation`` and placed under the class declaration.

        :param model: the class model to render
        :return: the rendered class source
        """
        sections: list[str] = []

        # Class variables / special declarations
        header_vars: str = self.render_class_variables(model)
        if header_vars:
            sections.append(header_vars)

        # Column attributes: one pass for nullable=False, one for nullable=True
        column_lines: list[str] = [
            self.render_column_attribute(attr)
            for wanted_nullability in (False, True)
            for attr in model.columns
            if attr.column.nullable is wanted_nullability
        ]

        # Relationship attributes
        relationship_lines: list[str] = [
            self.render_relationship(rel) for rel in model.relationships
        ]

        for group in (column_lines, relationship_lines):
            if group:
                sections.append("\n".join(group))

        declaration = self.render_class_declaration(model)
        body = "\n\n".join(indent(section, self.indentation) for section in sections)
        return f"{declaration}\n{body}"
1627

1628
    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class Name(Parent):`` line for a class model.

        The parent is the model's parent class when it has one, otherwise the
        generator's configured base class name.
        """
        if model.parent_class:
            superclass = model.parent_class.name
        else:
            superclass = self.base_class_name

        return f"class {model.name}({superclass}):"
1633

1634
    def render_class_variables(self, model: ModelClass) -> str:
        """Render the ``__tablename__`` and optional ``__table_args__`` lines.

        ``__table_args__`` (constraints, indexes and table keyword arguments)
        is only emitted when ``render_table_args`` returns a non-empty string.
        """
        tablename_line = f"__tablename__ = {model.table.name!r}"

        table_args = self.render_table_args(model.table)
        if not table_args:
            return tablename_line

        return f"{tablename_line}\n__table_args__ = {table_args}"
1643

1644
    def render_table_args(self, table: Table) -> str:
        """Render the value of ``__table_args__`` for ``table``.

        Positional entries are the rendered constraints followed by the
        rendered indexes; keyword entries are the schema, the comment and,
        when enabled, dialect options and info.

        :return: a dict literal when there are only keyword entries, a
            parenthesized tuple when there are positional entries (with the
            keyword dict as last element, if any), or an empty string when
            there is nothing to render
        """

        def is_left_out(constraint: Constraint) -> bool:
            """Tell whether a default-named constraint is omitted from the arguments."""
            if not uses_default_name(constraint):
                return False

            if isinstance(constraint, PrimaryKeyConstraint):
                return True

            # Single-column foreign key / unique constraints are omitted too
            return (
                isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                and len(constraint.columns) == 1
            )

        # Constraints
        positional: list[str] = [
            self.render_constraint(constraint)
            for constraint in sorted(table.constraints, key=get_constraint_sort_key)
            if not is_left_out(constraint)
        ]

        # Indexes: multi-column ones, or those with a non-default name
        positional.extend(
            self.render_index(index)
            for index in sorted(table.indexes, key=lambda ix: cast(str, ix.name))
            if len(index.columns) > 1 or not uses_default_name(index)
        )

        keyword: dict[str, object] = {}
        if table.schema:
            keyword["schema"] = table.schema

        if table.comment:
            keyword["comment"] = table.comment

        # add info + dialect kwargs for dict context (__table_args__) (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(table, keyword, values_for_dict=True)

        if keyword:
            keyword_literal = pformat(keyword)
            if not positional:
                return keyword_literal

            positional.append(keyword_literal)

        if not positional:
            return ""

        separator = f",\n{self.indentation}"
        joined = separator.join(positional)
        if len(positional) == 1:
            # A one-element tuple needs its trailing comma
            joined += ","

        return f"(\n{self.indentation}{joined}\n)"
1691

1692
    def render_column_python_type(self, column: Column[Any]) -> str:
        """Render the Python type annotation matching a column.

        Nullable columns are wrapped in ``Optional[...]`` and ``ARRAY`` columns
        in one ``list[...]`` per dimension (at least one) around the item type.
        Imports needed by the annotation are registered as a side effect.

        :param column: the column to describe
        :return: the annotation source, e.g. ``Optional[list[int]]``
        """

        def python_type_name(sa_type: TypeEngine[Any]) -> str:
            """Render the Python type of a single (non-array) SQLAlchemy type."""
            # Enum columns with a generated Python enum class use that class
            if isinstance(sa_type, Enum):
                enum_key = (column.table.name, column.name)
                if enum_key in self.enum_classes:
                    return self.enum_classes[enum_key]

            # A DOMAIN is described by its underlying data type
            if isinstance(sa_type, DOMAIN):
                sa_type = sa_type.data_type

            try:
                py_type = sa_type.python_type
                module_name = py_type.__module__
                type_name = py_type.__name__
            except NotImplementedError:
                # The type does not declare a Python type
                self.add_literal_import("typing", "Any")
                return "Any"

            if module_name == "builtins":
                return type_name

            self.add_module_import(module_name)
            return f"{module_name}.{type_name}"

        opening = ""
        closing_count = 0
        inner_type = column.type

        if column.nullable:
            self.add_literal_import("typing", "Optional")
            opening += "Optional["
            closing_count += 1

        if isinstance(inner_type, ARRAY):
            depth = getattr(inner_type, "dimensions", None) or 1
            opening += "list[" * depth
            closing_count += depth
            inner_type = inner_type.item_type

        return f"{opening}{python_type_name(inner_type)}{']' * closing_count}"
1740

1741
    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render one ``name: Mapped[type] = <column>`` class attribute line.

        The second argument given to ``render_column`` is true when the
        attribute name differs from the column name.
        """
        column = column_attr.column
        name_differs = column_attr.name != column.name
        column_source = self.render_column(column, name_differs)
        annotation = self.render_column_python_type(column)

        return f"{column_attr.name}: Mapped[{annotation}] = {column_source}"
1747

1748
    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render one ``name: Mapped[annotation] = relationship(...)`` line.

        The target class name is passed to ``relationship()`` as a quoted
        string, followed by the rendered keyword arguments.
        """
        arguments = self.render_relationship_arguments(relationship)
        annotation = self.render_relationship_annotation(relationship)
        target_ref = repr(relationship.target.name)
        call_source = render_callable("relationship", target_ref, kwargs=arguments)
        return f"{relationship.name}: Mapped[{annotation}] = {call_source}"
1755

1756
    def render_relationship_annotation(
        self, relationship: RelationshipAttribute
    ) -> str:
        """Render the type placed inside ``Mapped[...]`` for a relationship.

        Collection relationships (one-to-many, many-to-many) become
        ``list['Target']``. Scalar relationships (one-to-one, many-to-one)
        become ``'Target'``, or ``Optional['Target']`` when any column of the
        relationship's constraint is nullable.
        """
        kind = relationship.type

        if (
            kind == RelationshipType.ONE_TO_MANY
            or kind == RelationshipType.MANY_TO_MANY
        ):
            return f"list[{relationship.target.name!r}]"

        if (
            kind == RelationshipType.ONE_TO_ONE
            or kind == RelationshipType.MANY_TO_ONE
        ):
            constraint = relationship.constraint
            if constraint and any(col.nullable for col in constraint.columns):
                self.add_literal_import("typing", "Optional")
                return f"Optional[{relationship.target.name!r}]"

            return f"'{relationship.target.name}'"

        # No annotation is produced for any other relationship type (implicit None)
1772

1773
    def render_relationship_arguments(
        self, relationship: RelationshipAttribute
    ) -> Mapping[str, Any]:
        """Build the keyword arguments rendered into a ``relationship()`` call.

        Values are source snippets (already quoted where needed), except for
        ``uselist`` which is the literal ``False``. Depending on the
        relationship the result may contain ``uselist``, ``secondary``,
        ``remote_side``, ``foreign_keys``, ``primaryjoin``, ``secondaryjoin``
        and ``back_populates``, inserted in that order.

        :param relationship: the relationship being rendered
        :return: mapping of keyword argument name to rendered value
        """

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            """Render a list of column attributes as ``[a, b]`` or ``'[Cls.a, Cls.b]'``.

            Attributes of the relationship's own source class are rendered as
            bare names unless ``explicit_foreign_keys`` is set. As soon as one
            attribute needs a class-qualified name, the whole list is rendered
            as a quoted string. Used for both ``remote_side`` and
            ``foreign_keys``.
            """
            rendered = []
            render_as_string = False
            for attr in column_attrs:
                if not self.explicit_foreign_keys and attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            joined = "[" + ", ".join(rendered) + "]"
            return repr(joined) if render_as_string else joined

        def render_join(terms: list[JoinType]) -> str:
            """Render a join condition as a single deferred ``lambda: ...`` expression.

            Multiple terms are combined as ``lambda: and_(t1, t2, ...)`` so the
            whole condition stays inside one callable, and the ``and_`` import
            is registered for the generated module.
            """
            conditions = []
            for source, source_col, target, target_col in terms:
                # Exactly Model (not a subclass): its columns are reached
                # through ".c" (presumably because it is rendered as a Table)
                target_ref = target.name
                if target.__class__ is Model:
                    target_ref += ".c"

                conditions.append(
                    f"{source.name}.{source_col} == {target_ref}.{target_col!s}"
                )

            if len(conditions) > 1:
                self.add_literal_import("sqlalchemy", "and_")
                return f"lambda: and_({', '.join(conditions)})"

            return f"lambda: {conditions[0]}"

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
            # Only the side whose table is referred to by the constraint gets
            # an explicit uselist=False
            if relationship.constraint.referred_table is relationship.source.table:
                kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_column_attrs(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        return kwargs
1851

1852

1853
class DataclassGenerator(DeclarativeGenerator):
    """Declarative generator whose base class also derives from ``MappedAsDataclass``."""

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """Initialize the generator.

        ``quote_annotations`` and ``metadata_key`` are stored on the instance;
        all other arguments are forwarded to the parent generator.
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )
        self.quote_annotations: bool = quote_annotations
        self.metadata_key: str = metadata_key

    def generate_base(self) -> None:
        """Set ``self.base`` to a ``MappedAsDataclass`` + ``DeclarativeBase`` base class."""
        needed_imports = [
            LiteralImport("sqlalchemy.orm", "DeclarativeBase"),
            LiteralImport("sqlalchemy.orm", "MappedAsDataclass"),
        ]
        class_source = [
            f"class {self.base_class_name}(MappedAsDataclass, DeclarativeBase):",
            f"{self.indentation}pass",
        ]
        self.base = Base(
            literal_imports=needed_imports,
            declarations=class_source,
            metadata_ref=f"{self.base_class_name}.metadata",
        )
1887

1888

1889
class SQLModelGenerator(DeclarativeGenerator):
    """Declarative generator that renders ``sqlmodel`` classes.

    Columns are rendered as ``Field(sa_column=Column(...))`` and relationships
    as ``Relationship(...)``; foreign keys are always rendered explicitly.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        """Initialize the parent generator with ``explicit_foreign_keys=True``."""
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
            explicit_foreign_keys=True,
        )

    @property
    def views_supported(self) -> bool:
        """Always ``False`` for this generator."""
        return False

    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """Render a ``Column(...)`` call regardless of ``is_table``, importing ``Column``."""
        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)

    def render_table(self, table: Table) -> str:
        """Render a table bound to ``SQLModel.metadata``."""
        # Hack to fix #465 without breaking backwards compatibility
        self.base.metadata_ref = "SQLModel.metadata"

        return super().render_table(table)

    def generate_base(self) -> None:
        """Set ``self.base`` to an empty base that refers to ``SQLModel.metadata``."""
        self.base = Base(
            literal_imports=[],
            declarations=[],
            metadata_ref="SQLModel.metadata",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect module-level imports, adding the ``sqlmodel`` names in use.

        The implementation of ``DeclarativeGenerator`` itself is bypassed: the
        call goes to the next class after it in the MRO.
        """
        super(DeclarativeGenerator, self).collect_imports(models)

        # "Field" is needed as soon as one class model exists
        for model in models:
            if isinstance(model, ModelClass):
                self.add_literal_import("sqlmodel", "Field")
                break

        if models:
            self.remove_literal_import("sqlalchemy", "MetaData")
            self.add_literal_import("sqlmodel", "SQLModel")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect the imports one model needs (``Optional``, ``Relationship``).

        As in ``collect_imports``, ``DeclarativeGenerator``'s own
        implementation is bypassed.
        """
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
        if not isinstance(model, ModelClass):
            return

        if any(column_attr.column.nullable for column_attr in model.columns):
            self.add_literal_import("typing", "Optional")

        if model.relationships:
            self.add_literal_import("sqlmodel", "Relationship")

    def render_module_variables(self, models: list[Model]) -> str:
        """Render the table metadata declaration, if one is set and needed.

        It is only emitted when at least one model is not a class model and
        the base carries a ``table_metadata_declaration``.
        """
        if all(isinstance(model, ModelClass) for model in models):
            return ""

        declaration = self.base.table_metadata_declaration
        if declaration is None:
            return ""

        return declaration

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class Name(Parent, table=True):`` line."""
        parent = model.parent_class.name if model.parent_class else self.base_class_name
        return f"class {model.name}({parent}, table=True):"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` (only if needed) and ``__table_args__``.

        ``__tablename__`` is left out when the table name equals the
        lowercased class name.
        """
        lines: list[str] = []

        if model.name.lower() != model.table.name:
            lines.append(f"__tablename__ = {model.table.name!r}")

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            lines.append(f"__table_args__ = {table_args}")

        return "\n".join(lines)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render one ``name: type = Field(..., sa_column=Column(...))`` line.

        Nullable columns additionally get ``default=None``.
        """
        column = column_attr.column
        column_source = self.render_column(column, True)
        annotation = self.render_column_python_type(column)

        field_kwargs: dict[str, Any] = {"default": None} if column.nullable else {}
        field_kwargs["sa_column"] = f"{column_source}"

        field_source = render_callable("Field", kwargs=field_kwargs)
        return f"{column_attr.name}: {annotation} = {field_source}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render one ``name: annotation = Relationship(...)`` line.

        Keyword arguments natively supported by ``Relationship`` are passed
        directly; all others are bundled into ``sa_relationship_kwargs``.
        """
        all_kwargs = self.render_relationship_arguments(relationship)
        annotation = self.render_relationship_annotation(relationship)

        # The following keyword arguments are natively supported in Relationship
        native_names = ("back_populates", "cascade_delete", "passive_deletes")
        native_kwargs: dict[str, Any] = {
            key: value for key, value in all_kwargs.items() if key in native_names
        }
        extra_kwargs: dict[str, Any] = {
            key: value for key, value in all_kwargs.items() if key not in native_names
        }

        if extra_kwargs:
            entries = ", ".join(
                f"{key!r}: {value}" for key, value in extra_kwargs.items()
            )
            native_kwargs["sa_relationship_kwargs"] = "{" + entries + "}"

        field_source = render_callable("Relationship", kwargs=native_kwargs)
        return f"{relationship.name}: {annotation} = {field_source}"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc