• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 22219280309

20 Feb 2026 09:47AM UTC coverage: 97.661% (-0.06%) from 97.718%
22219280309

Pull #464

github

web-flow
Merge fc2ce00ad into de5adccaa
Pull Request #464: Fix inherited kwargs rendering for MySQL CHAR collation

50 of 52 new or added lines in 3 files covered. (96.15%)

23 existing lines in 1 file now uncovered.

1795 of 1838 relevant lines covered (97.66%)

4.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.71
/src/sqlacodegen/generators.py
1
from __future__ import annotations
5✔
2

3
import inspect
5✔
4
import re
5✔
5
import sys
5✔
6
from abc import ABCMeta, abstractmethod
5✔
7
from collections import defaultdict
5✔
8
from collections.abc import Collection, Iterable, Mapping, Sequence
5✔
9
from dataclasses import dataclass
5✔
10
from importlib import import_module
5✔
11
from inspect import Parameter
5✔
12
from itertools import count
5✔
13
from keyword import iskeyword
5✔
14
from pprint import pformat
5✔
15
from textwrap import indent
5✔
16
from typing import Any, ClassVar, Literal, cast
5✔
17

18
import inflect
5✔
19
import sqlalchemy
5✔
20
from sqlalchemy import (
5✔
21
    ARRAY,
22
    Boolean,
23
    CheckConstraint,
24
    Column,
25
    Computed,
26
    Constraint,
27
    DefaultClause,
28
    Enum,
29
    ForeignKey,
30
    ForeignKeyConstraint,
31
    Identity,
32
    Index,
33
    MetaData,
34
    PrimaryKeyConstraint,
35
    String,
36
    Table,
37
    Text,
38
    TypeDecorator,
39
    UniqueConstraint,
40
)
41
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
5✔
42
from sqlalchemy.engine import Connection, Engine
5✔
43
from sqlalchemy.exc import CompileError
5✔
44
from sqlalchemy.sql.elements import TextClause
5✔
45
from sqlalchemy.sql.type_api import UserDefinedType
5✔
46
from sqlalchemy.types import TypeEngine
5✔
47

48
from .models import (
5✔
49
    ColumnAttribute,
50
    JoinType,
51
    Model,
52
    ModelClass,
53
    RelationshipAttribute,
54
    RelationshipType,
55
)
56
from .utils import (
5✔
57
    decode_postgresql_sequence,
58
    get_column_names,
59
    get_common_fk_constraints,
60
    get_compiled_expression,
61
    get_constraint_sort_key,
62
    get_stdlib_module_names,
63
    qualified_table_name,
64
    render_callable,
65
    uses_default_name,
66
)
67

68
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
5✔
69
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
5✔
70
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
5✔
71
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
5✔
72
_re_invalid_identifier = re.compile(r"(?u)\W")
5✔
73

74

75
@dataclass
5✔
76
class LiteralImport:
5✔
77
    pkgname: str
5✔
78
    name: str
5✔
79

80

81
@dataclass
5✔
82
class Base:
5✔
83
    """Representation of MetaData for Tables, respectively Base for classes"""
84

85
    literal_imports: list[LiteralImport]
5✔
86
    declarations: list[str]
5✔
87
    metadata_ref: str
5✔
88
    decorator: str | None = None
5✔
89
    table_metadata_declaration: str | None = None
5✔
90

91

92
class CodeGenerator(metaclass=ABCMeta):
5✔
93
    valid_options: ClassVar[set[str]] = set()
5✔
94

95
    def __init__(
5✔
96
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
97
    ):
98
        self.metadata: MetaData = metadata
5✔
99
        self.bind: Connection | Engine = bind
5✔
100
        self.options: set[str] = set(options)
5✔
101

102
        # Validate options
103
        invalid_options = {opt for opt in options if opt not in self.valid_options}
5✔
104
        if invalid_options:
5✔
105
            raise ValueError("Unrecognized options: " + ", ".join(invalid_options))
×
106

107
    @property
5✔
108
    @abstractmethod
5✔
109
    def views_supported(self) -> bool:
5✔
110
        pass
×
111

112
    @abstractmethod
5✔
113
    def generate(self) -> str:
5✔
114
        """
115
        Generate the code for the given metadata.
116
        .. note:: May modify the metadata.
117
        """
118

119

120
@dataclass(eq=False)
5✔
121
class TablesGenerator(CodeGenerator):
5✔
122
    valid_options: ClassVar[set[str]] = {
5✔
123
        "noindexes",
124
        "noconstraints",
125
        "nocomments",
126
        "nonativeenums",
127
        "nosyntheticenums",
128
        "include_dialect_options",
129
        "keep_dialect_types",
130
    }
131
    stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
5✔
132

133
    def __init__(
5✔
134
        self,
135
        metadata: MetaData,
136
        bind: Connection | Engine,
137
        options: Sequence[str],
138
        *,
139
        indentation: str = "    ",
140
    ):
141
        super().__init__(metadata, bind, options)
5✔
142
        self.indentation: str = indentation
5✔
143
        self.imports: dict[str, set[str]] = defaultdict(set)
5✔
144
        self.module_imports: set[str] = set()
5✔
145

146
        # Render SchemaItem.info and dialect kwargs (Table/Column) into output
147
        self.include_dialect_options_and_info: bool = (
5✔
148
            "include_dialect_options" in self.options
149
        )
150
        # Keep dialect-specific types instead of adapting to generic SQLAlchemy types
151
        self.keep_dialect_types: bool = "keep_dialect_types" in self.options
5✔
152

153
        # Track Python enum classes: maps (table_name, column_name) -> enum_class_name
154
        self.enum_classes: dict[tuple[str, str], str] = {}
5✔
155
        # Track enum values: maps enum_class_name -> list of values
156
        self.enum_values: dict[str, list[str]] = {}
5✔
157

158
    @property
5✔
159
    def views_supported(self) -> bool:
5✔
160
        return True
×
161

162
    def generate_base(self) -> None:
5✔
163
        self.base = Base(
5✔
164
            literal_imports=[LiteralImport("sqlalchemy", "MetaData")],
165
            declarations=["metadata = MetaData()"],
166
            metadata_ref="metadata",
167
        )
168

169
    def generate(self) -> str:
5✔
170
        self.generate_base()
5✔
171

172
        sections: list[str] = []
5✔
173

174
        # Remove unwanted elements from the metadata
175
        for table in list(self.metadata.tables.values()):
5✔
176
            if self.should_ignore_table(table):
5✔
177
                self.metadata.remove(table)
×
178
                continue
×
179

180
            if "noindexes" in self.options:
5✔
181
                table.indexes.clear()
5✔
182

183
            if "noconstraints" in self.options:
5✔
184
                table.constraints.clear()
5✔
185

186
            if "nocomments" in self.options:
5✔
187
                table.comment = None
5✔
188

189
            for column in table.columns:
5✔
190
                if "nocomments" in self.options:
5✔
191
                    column.comment = None
5✔
192

193
        # Use information from column constraints to figure out the intended column
194
        # types
195
        for table in self.metadata.tables.values():
5✔
196
            self.fix_column_types(table)
5✔
197

198
        # Generate the models
199
        models: list[Model] = self.generate_models()
5✔
200

201
        # Render module level variables
202
        if variables := self.render_module_variables(models):
5✔
203
            sections.append(variables + "\n")
5✔
204

205
        # Render enum classes
206
        if enum_classes := self.render_enum_classes():
5✔
207
            sections.append(enum_classes + "\n")
5✔
208

209
        # Render models
210
        if rendered_models := self.render_models(models):
5✔
211
            sections.append(rendered_models)
5✔
212

213
        # Render collected imports
214
        groups = self.group_imports()
5✔
215
        if imports := "\n\n".join(
5✔
216
            "\n".join(line for line in group) for group in groups
217
        ):
218
            sections.insert(0, imports)
5✔
219

220
        return "\n\n".join(sections) + "\n"
5✔
221

222
    def collect_imports(self, models: Iterable[Model]) -> None:
5✔
223
        for literal_import in self.base.literal_imports:
5✔
224
            self.add_literal_import(literal_import.pkgname, literal_import.name)
5✔
225

226
        for model in models:
5✔
227
            self.collect_imports_for_model(model)
5✔
228

229
    def collect_imports_for_model(self, model: Model) -> None:
5✔
230
        if model.__class__ is Model:
5✔
231
            self.add_import(Table)
5✔
232

233
        for column in model.table.c:
5✔
234
            self.collect_imports_for_column(column)
5✔
235

236
        for constraint in model.table.constraints:
5✔
237
            self.collect_imports_for_constraint(constraint)
5✔
238

239
        for index in model.table.indexes:
5✔
240
            self.collect_imports_for_constraint(index)
5✔
241

242
    def collect_imports_for_column(self, column: Column[Any]) -> None:
5✔
243
        self.add_import(column.type)
5✔
244

245
        if isinstance(column.type, ARRAY):
5✔
246
            self.add_import(column.type.item_type.__class__)
5✔
247
        elif isinstance(column.type, (JSONB, JSON)):
5✔
248
            if (
5✔
249
                not isinstance(column.type.astext_type, Text)
250
                or column.type.astext_type.length is not None
251
            ):
252
                self.add_import(column.type.astext_type)
5✔
253
        elif isinstance(column.type, DOMAIN):
5✔
254
            self.add_import(column.type.data_type.__class__)
5✔
255

256
        if column.default:
5✔
257
            self.add_import(column.default)
5✔
258

259
        if column.server_default:
5✔
260
            if isinstance(column.server_default, (Computed, Identity)):
5✔
261
                self.add_import(column.server_default)
5✔
262
            elif isinstance(column.server_default, DefaultClause):
5✔
263
                self.add_literal_import("sqlalchemy", "text")
5✔
264

265
    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
5✔
266
        if isinstance(constraint, Index):
5✔
267
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
5✔
268
                self.add_literal_import("sqlalchemy", "Index")
5✔
269
        elif isinstance(constraint, PrimaryKeyConstraint):
5✔
270
            if not uses_default_name(constraint):
5✔
271
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
5✔
272
        elif isinstance(constraint, UniqueConstraint):
5✔
273
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
5✔
274
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
5✔
275
        elif isinstance(constraint, ForeignKeyConstraint):
5✔
276
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
5✔
277
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
5✔
278
            else:
279
                self.add_import(ForeignKey)
5✔
280
        else:
281
            self.add_import(constraint)
5✔
282

283
    def add_import(self, obj: Any) -> None:
5✔
284
        # Don't store builtin imports
285
        if getattr(obj, "__module__", "builtins") == "builtins":
5✔
286
            return
×
287

288
        type_ = type(obj) if not isinstance(obj, type) else obj
5✔
289
        pkgname = type_.__module__
5✔
290

291
        # The column types have already been adapted towards generic types if possible,
292
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
293
        # use that rather than the generic sqlalchemy type as it might have different
294
        # constructor parameters.
295
        if pkgname.startswith("sqlalchemy.dialects."):
5✔
296
            dialect_pkgname = ".".join(pkgname.split(".")[0:3])
5✔
297
            dialect_pkg = import_module(dialect_pkgname)
5✔
298

299
            if type_.__name__ in dialect_pkg.__all__:
5✔
300
                pkgname = dialect_pkgname
5✔
301
        elif type_ is getattr(sqlalchemy, type_.__name__, None):
5✔
302
            pkgname = "sqlalchemy"
5✔
303
        else:
304
            pkgname = type_.__module__
5✔
305

306
        self.add_literal_import(pkgname, type_.__name__)
5✔
307

308
    def add_literal_import(self, pkgname: str, name: str) -> None:
5✔
309
        names = self.imports.setdefault(pkgname, set())
5✔
310
        names.add(name)
5✔
311

312
    def remove_literal_import(self, pkgname: str, name: str) -> None:
5✔
313
        names = self.imports.setdefault(pkgname, set())
5✔
314
        if name in names:
5✔
315
            names.remove(name)
×
316

317
    def add_module_import(self, pgkname: str) -> None:
5✔
318
        self.module_imports.add(pgkname)
5✔
319

320
    def group_imports(self) -> list[list[str]]:
5✔
321
        future_imports: list[str] = []
5✔
322
        stdlib_imports: list[str] = []
5✔
323
        thirdparty_imports: list[str] = []
5✔
324

325
        def get_collection(package: str) -> list[str]:
5✔
326
            collection = thirdparty_imports
5✔
327
            if package == "__future__":
5✔
328
                collection = future_imports
×
329
            elif package in self.stdlib_module_names:
5✔
330
                collection = stdlib_imports
5✔
331
            elif package in sys.modules:
5✔
332
                if "site-packages" not in (sys.modules[package].__file__ or ""):
5✔
333
                    collection = stdlib_imports
5✔
334
            return collection
5✔
335

336
        for package in sorted(self.imports):
5✔
337
            imports = ", ".join(sorted(self.imports[package]))
5✔
338

339
            collection = get_collection(package)
5✔
340
            collection.append(f"from {package} import {imports}")
5✔
341

342
        for module in sorted(self.module_imports):
5✔
343
            collection = get_collection(module)
5✔
344
            collection.append(f"import {module}")
5✔
345

346
        return [
5✔
347
            group
348
            for group in (future_imports, stdlib_imports, thirdparty_imports)
349
            if group
350
        ]
351

352
    def generate_models(self) -> list[Model]:
5✔
353
        models = [Model(table) for table in self.metadata.sorted_tables]
5✔
354

355
        # Collect the imports
356
        self.collect_imports(models)
5✔
357

358
        # Generate names for models
359
        global_names = {
5✔
360
            name for namespace in self.imports.values() for name in namespace
361
        }
362
        for model in models:
5✔
363
            self.generate_model_name(model, global_names)
5✔
364
            global_names.add(model.name)
5✔
365

366
        return models
5✔
367

368
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
5✔
369
        preferred_name = f"t_{model.table.name}"
5✔
370
        model.name = self.find_free_name(preferred_name, global_names)
5✔
371

372
    def render_module_variables(self, models: list[Model]) -> str:
5✔
373
        declarations = self.base.declarations
5✔
374

375
        if any(not isinstance(model, ModelClass) for model in models):
5✔
376
            if self.base.table_metadata_declaration is not None:
5✔
377
                declarations.append(self.base.table_metadata_declaration)
×
378

379
        return "\n".join(declarations)
5✔
380

381
    def render_models(self, models: list[Model]) -> str:
5✔
382
        rendered: list[str] = []
5✔
383
        for model in models:
5✔
384
            rendered_table = self.render_table(model.table)
5✔
385
            rendered.append(f"{model.name} = {rendered_table}")
5✔
386

387
        return "\n\n".join(rendered)
5✔
388

389
    def render_table(self, table: Table) -> str:
5✔
390
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
5✔
391
        kwargs: dict[str, object] = {}
5✔
392
        for column in table.columns:
5✔
393
            # Cast is required because of a bug in the SQLAlchemy stubs regarding
394
            # Table.columns
395
            args.append(self.render_column(column, True, is_table=True))
5✔
396

397
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
5✔
398
            if uses_default_name(constraint):
5✔
399
                if isinstance(constraint, PrimaryKeyConstraint):
5✔
400
                    continue
5✔
401
                elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)):
5✔
402
                    if len(constraint.columns) == 1:
5✔
403
                        continue
5✔
404

405
            args.append(self.render_constraint(constraint))
5✔
406

407
        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
5✔
408
            # One-column indexes should be rendered as index=True on columns
409
            if len(index.columns) > 1 or not uses_default_name(index):
5✔
410
                args.append(self.render_index(index))
5✔
411

412
        if table.schema:
5✔
413
            kwargs["schema"] = repr(table.schema)
5✔
414

415
        table_comment = getattr(table, "comment", None)
5✔
416
        if table_comment:
5✔
417
            kwargs["comment"] = repr(table.comment)
5✔
418

419
        # add info + dialect kwargs for callable context (opt-in)
420
        if self.include_dialect_options_and_info:
5✔
421
            self._add_dialect_kwargs_and_info(table, kwargs, values_for_dict=False)
5✔
422

423
        return render_callable("Table", *args, kwargs=kwargs, indentation="    ")
5✔
424

425
    def render_index(self, index: Index) -> str:
5✔
426
        extra_args = [repr(col.name) for col in index.columns]
5✔
427
        kwargs = {
5✔
428
            key: repr(value) if isinstance(value, str) else value
429
            for key, value in sorted(index.kwargs.items(), key=lambda item: item[0])
430
        }
431
        if index.unique:
5✔
432
            kwargs["unique"] = True
5✔
433

434
        return render_callable("Index", repr(index.name), *extra_args, kwargs=kwargs)
5✔
435

436
    # TODO find better solution for is_table
437
    def render_column(
5✔
438
        self, column: Column[Any], show_name: bool, is_table: bool = False
439
    ) -> str:
440
        args = []
5✔
441
        kwargs: dict[str, Any] = {}
5✔
442
        kwarg = []
5✔
443
        is_part_of_composite_pk = (
5✔
444
            column.primary_key and len(column.table.primary_key) > 1
445
        )
446
        dedicated_fks = [
5✔
447
            c
448
            for c in column.foreign_keys
449
            if c.constraint
450
            and len(c.constraint.columns) == 1
451
            and uses_default_name(c.constraint)
452
        ]
453
        is_unique = any(
5✔
454
            isinstance(c, UniqueConstraint)
455
            and set(c.columns) == {column}
456
            and uses_default_name(c)
457
            for c in column.table.constraints
458
        )
459
        is_unique = is_unique or any(
5✔
460
            i.unique and set(i.columns) == {column} and uses_default_name(i)
461
            for i in column.table.indexes
462
        )
463
        is_primary = (
5✔
464
            any(
465
                isinstance(c, PrimaryKeyConstraint)
466
                and column.name in c.columns
467
                and uses_default_name(c)
468
                for c in column.table.constraints
469
            )
470
            or column.primary_key
471
        )
472
        has_index = any(
5✔
473
            set(i.columns) == {column} and uses_default_name(i)
474
            for i in column.table.indexes
475
        )
476

477
        if show_name:
5✔
478
            args.append(repr(column.name))
5✔
479

480
        # Render the column type if there are no foreign keys on it or any of them
481
        # points back to itself
482
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
5✔
483
            args.append(self.render_column_type(column))
5✔
484

485
        for fk in dedicated_fks:
5✔
486
            args.append(self.render_constraint(fk))
5✔
487

488
        if column.default:
5✔
489
            args.append(repr(column.default))
5✔
490

491
        if column.key != column.name:
5✔
492
            kwargs["key"] = column.key
×
493
        if is_primary:
5✔
494
            kwargs["primary_key"] = True
5✔
495
        if not column.nullable and not column.primary_key:
5✔
496
            kwargs["nullable"] = False
5✔
497
        if column.nullable and is_part_of_composite_pk:
5✔
498
            kwargs["nullable"] = True
5✔
499

500
        if is_unique:
5✔
501
            column.unique = True
5✔
502
            kwargs["unique"] = True
5✔
503
        if has_index:
5✔
504
            column.index = True
5✔
505
            kwarg.append("index")
5✔
506
            kwargs["index"] = True
5✔
507

508
        if isinstance(column.server_default, DefaultClause):
5✔
509
            kwargs["server_default"] = render_callable(
5✔
510
                "text", repr(cast(TextClause, column.server_default.arg).text)
511
            )
512
        elif isinstance(column.server_default, Computed):
5✔
513
            expression = str(column.server_default.sqltext)
5✔
514

515
            computed_kwargs = {}
5✔
516
            if column.server_default.persisted is not None:
5✔
517
                computed_kwargs["persisted"] = column.server_default.persisted
5✔
518

519
            args.append(
5✔
520
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
521
            )
522
        elif isinstance(column.server_default, Identity):
5✔
523
            args.append(repr(column.server_default))
5✔
524
        elif column.server_default:
5✔
525
            kwargs["server_default"] = repr(column.server_default)
×
526

527
        comment = getattr(column, "comment", None)
5✔
528
        if comment:
5✔
529
            kwargs["comment"] = repr(comment)
5✔
530

531
        # add column info + dialect kwargs for callable context (opt-in)
532
        if self.include_dialect_options_and_info:
5✔
533
            self._add_dialect_kwargs_and_info(column, kwargs, values_for_dict=False)
5✔
534

535
        return self.render_column_callable(is_table, *args, **kwargs)
5✔
536

537
    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
5✔
538
        if is_table:
5✔
539
            self.add_import(Column)
5✔
540
            return render_callable("Column", *args, kwargs=kwargs)
5✔
541
        else:
542
            return render_callable("mapped_column", *args, kwargs=kwargs)
5✔
543

544
    def render_column_type(self, column: Column[Any]) -> str:
5✔
545
        column_type = column.type
5✔
546
        # Check if this is an enum column with a Python enum class
547
        if isinstance(column_type, Enum) and column is not None:
5✔
548
            if enum_class_name := self.enum_classes.get(
5✔
549
                (column.table.name, column.name)
550
            ):
551
                # Import SQLAlchemy Enum (will be handled in collect_imports)
552
                self.add_import(Enum)
5✔
553
                extra_kwargs = ""
5✔
554
                if column_type.name is not None:
5✔
555
                    extra_kwargs += f", name={column_type.name!r}"
5✔
556

557
                if column_type.schema is not None:
5✔
558
                    extra_kwargs += f", schema={column_type.schema!r}"
5✔
559

560
                return f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls]{extra_kwargs})"
5✔
561

562
        args = []
5✔
563
        kwargs: dict[str, Any] = {}
5✔
564

565
        # Check if this is an ARRAY column with an Enum item type mapped to a Python enum class
566
        if isinstance(column_type, ARRAY) and isinstance(column_type.item_type, Enum):
5✔
567
            if enum_class_name := self.enum_classes.get(
5✔
568
                (column.table.name, column.name)
569
            ):
570
                self.add_import(ARRAY)
5✔
571
                self.add_import(Enum)
5✔
572
                extra_kwargs = ""
5✔
573
                if column_type.item_type.name is not None:
5✔
574
                    extra_kwargs += f", name={column_type.item_type.name!r}"
5✔
575

576
                if column_type.item_type.schema is not None:
5✔
577
                    extra_kwargs += f", schema={column_type.item_type.schema!r}"
5✔
578

579
                rendered_enum = f"Enum({enum_class_name}, values_callable=lambda cls: [member.value for member in cls]{extra_kwargs})"
5✔
580
                if column_type.dimensions is not None:
5✔
581
                    kwargs["dimensions"] = repr(column_type.dimensions)
5✔
582

583
                return render_callable("ARRAY", rendered_enum, kwargs=kwargs)
5✔
584

585
        sig = inspect.signature(column_type.__class__.__init__)
5✔
586
        defaults = {param.name: param.default for param in sig.parameters.values()}
5✔
587
        missing = object()
5✔
588
        use_kwargs = False
5✔
589
        seen_param_names = {
5✔
590
            param.name
591
            for param in list(sig.parameters.values())[1:]
592
            if not param.name.startswith("_")
593
            and param.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
594
        }
595

596
        def render_param_value(value: Any) -> str:
5✔
597
            if isinstance(value, (JSONB, JSON)):
5✔
598
                # Remove astext_type if it's the default
599
                if (
5✔
600
                    isinstance(value.astext_type, Text)
601
                    and value.astext_type.length is None
602
                ):
603
                    value.astext_type = None  # type: ignore[assignment]
5✔
604
                else:
605
                    self.add_import(Text)
5✔
606

607
            if isinstance(value, TextClause):
5✔
608
                self.add_literal_import("sqlalchemy", "text")
5✔
609
                return render_callable("text", repr(value.text))
5✔
610

611
            return repr(value)
5✔
612

613
        for param in list(sig.parameters.values())[1:]:
5✔
614
            # Remove annoyances like _warn_on_bytestring
615
            if param.name.startswith("_"):
5✔
616
                continue
5✔
617
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
5✔
618
                use_kwargs = True
5✔
619
                continue
5✔
620

621
            value = getattr(column_type, param.name, missing)
5✔
622
            default = defaults.get(param.name, missing)
5✔
623
            if value is missing or value == default:
5✔
624
                use_kwargs = True
5✔
625
                continue
5✔
626

627
            rendered_value = render_param_value(value)
5✔
628
            if use_kwargs:
5✔
629
                kwargs[param.name] = rendered_value
5✔
630
            else:
631
                args.append(rendered_value)
5✔
632

633
        has_var_keyword = any(
5✔
634
            param.kind is Parameter.VAR_KEYWORD for param in sig.parameters.values()
635
        )
636
        has_var_positional = any(
5✔
637
            param.kind is Parameter.VAR_POSITIONAL for param in sig.parameters.values()
638
        )
639
        if has_var_keyword and not has_var_positional:
5✔
640
            for supercls in column_type.__class__.__mro__[1:]:
5✔
641
                if supercls is object:
5✔
642
                    break
5✔
643

644
                try:
5✔
645
                    super_sig = inspect.signature(supercls.__init__)
5✔
NEW
UNCOV
646
                except (TypeError, ValueError):
×
NEW
UNCOV
647
                    continue
×
648

649
                for super_param in list(super_sig.parameters.values())[1:]:
5✔
650
                    if super_param.name.startswith("_"):
5✔
651
                        continue
5✔
652

653
                    if super_param.kind in (
5✔
654
                        Parameter.POSITIONAL_ONLY,
655
                        Parameter.VAR_POSITIONAL,
656
                        Parameter.VAR_KEYWORD,
657
                    ):
658
                        continue
5✔
659

660
                    if super_param.name in seen_param_names:
5✔
661
                        continue
5✔
662

663
                    seen_param_names.add(super_param.name)
5✔
664
                    value = getattr(column_type, super_param.name, missing)
5✔
665
                    if value is missing:
5✔
666
                        continue
5✔
667

668
                    default = super_param.default
5✔
669
                    if default is not Parameter.empty and value == default:
5✔
670
                        continue
5✔
671

672
                    kwargs[super_param.name] = render_param_value(value)
5✔
673

674
        vararg = next(
5✔
675
            (
676
                param.name
677
                for param in sig.parameters.values()
678
                if param.kind is Parameter.VAR_POSITIONAL
679
            ),
680
            None,
681
        )
682
        if vararg and hasattr(column_type, vararg):
5✔
683
            varargs_repr = [repr(arg) for arg in getattr(column_type, vararg)]
5✔
684
            args.extend(varargs_repr)
5✔
685

686
        # These arguments cannot be autodetected from the Enum initializer
687
        if isinstance(column_type, Enum):
5✔
688
            for colname in "name", "schema":
5✔
689
                if (value := getattr(column_type, colname)) is not None:
5✔
690
                    kwargs[colname] = repr(value)
5✔
691

692
        if isinstance(column_type, (JSONB, JSON)):
5✔
693
            # Remove astext_type if it's the default
694
            if (
5✔
695
                isinstance(column_type.astext_type, Text)
696
                and column_type.astext_type.length is None
697
            ):
698
                del kwargs["astext_type"]
5✔
699

700
        if args or kwargs:
5✔
701
            return render_callable(column_type.__class__.__name__, *args, kwargs=kwargs)
5✔
702
        else:
703
            return column_type.__class__.__name__
5✔
704

705
    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
5✔
706
        def add_fk_options(*opts: Any) -> None:
5✔
707
            args.extend(repr(opt) for opt in opts)
5✔
708
            for attr in "ondelete", "onupdate", "deferrable", "initially", "match":
5✔
709
                value = getattr(constraint, attr, None)
5✔
710
                if value:
5✔
711
                    kwargs[attr] = repr(value)
5✔
712

713
        args: list[str] = []
5✔
714
        kwargs: dict[str, Any] = {}
5✔
715
        if isinstance(constraint, ForeignKey):
5✔
716
            remote_column = (
5✔
717
                f"{constraint.column.table.fullname}.{constraint.column.name}"
718
            )
719
            add_fk_options(remote_column)
5✔
720
        elif isinstance(constraint, ForeignKeyConstraint):
5✔
721
            local_columns = get_column_names(constraint)
5✔
722
            remote_columns = [
5✔
723
                f"{fk.column.table.fullname}.{fk.column.name}"
724
                for fk in constraint.elements
725
            ]
726
            add_fk_options(local_columns, remote_columns)
5✔
727
        elif isinstance(constraint, CheckConstraint):
5✔
728
            args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
5✔
729
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
5✔
730
            args.extend(repr(col.name) for col in constraint.columns)
5✔
731
        else:
UNCOV
732
            raise TypeError(
×
733
                f"Cannot render constraint of type {constraint.__class__.__name__}"
734
            )
735

736
        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
5✔
737
            kwargs["name"] = repr(constraint.name)
5✔
738

739
        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)
5✔
740

741
    def _add_dialect_kwargs_and_info(
5✔
742
        self, obj: Any, target_kwargs: dict[str, object], *, values_for_dict: bool
743
    ) -> None:
744
        """
745
        Merge SchemaItem-like object's .info and .dialect_kwargs into target_kwargs.
746
        - values_for_dict=True: keep raw values so pretty-printer emits repr() (for __table_args__ dict)
747
        - values_for_dict=False: set values to repr() strings (for callable kwargs)
748
        """
749
        info_dict = getattr(obj, "info", None)
5✔
750
        if info_dict:
5✔
751
            target_kwargs["info"] = info_dict if values_for_dict else repr(info_dict)
5✔
752

753
        dialect_keys: list[str]
754
        try:
5✔
755
            dialect_keys = sorted(getattr(obj, "dialect_kwargs"))
5✔
UNCOV
756
        except Exception:
×
UNCOV
757
            return
×
758

759
        dialect_kwargs = getattr(obj, "dialect_kwargs", {})
5✔
760
        for key in dialect_keys:
5✔
761
            try:
5✔
762
                value = dialect_kwargs[key]
5✔
UNCOV
763
            except Exception:
×
UNCOV
764
                continue
×
765

766
            # Render values:
767
            # - callable context (values_for_dict=False): produce a string expression.
768
            #   primitives use repr(value); custom objects stringify then repr().
769
            # - dict context (values_for_dict=True): pass raw primitives / str;
770
            #   custom objects become str(value) so pformat quotes them.
771
            if values_for_dict:
5✔
772
                if isinstance(value, type(None) | bool | int | float):
5✔
UNCOV
773
                    target_kwargs[key] = value
×
774
                elif isinstance(value, str | dict | list):
5✔
775
                    target_kwargs[key] = value
5✔
776
                else:
777
                    target_kwargs[key] = str(value)
5✔
778
            else:
779
                if isinstance(
5✔
780
                    value, type(None) | bool | int | float | str | dict | list
781
                ):
782
                    target_kwargs[key] = repr(value)
5✔
783
                else:
784
                    target_kwargs[key] = repr(str(value))
5✔
785

786
    def should_ignore_table(self, table: Table) -> bool:
5✔
787
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
788
        # tables
789
        return table.name in ("alembic_version", "migrate_version")
5✔
790

791
    def find_free_name(
5✔
792
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
793
    ) -> str:
794
        """
795
        Generate an attribute name that does not clash with other local or global names.
796
        """
797
        name = name.strip()
5✔
798
        assert name, "Identifier cannot be empty"
5✔
799
        name = _re_invalid_identifier.sub("_", name)
5✔
800
        if name[0].isdigit():
5✔
801
            name = "_" + name
5✔
802
        elif iskeyword(name) or name == "metadata":
5✔
803
            name += "_"
5✔
804

805
        original = name
5✔
806
        for i in count():
5✔
807
            if name not in global_names and name not in local_names:
5✔
808
                break
5✔
809

810
            name = original + (str(i) if i else "_")
5✔
811

812
        return name
5✔
813

814
    def _enum_name_to_class_name(self, enum_name: str) -> str:
5✔
815
        """Convert a database enum name to a Python class name (PascalCase)."""
816
        return "".join(part.capitalize() for part in enum_name.split("_") if part)
5✔
817

818
    def _create_enum_class(
5✔
819
        self, table_name: str, column_name: str, values: list[str]
820
    ) -> str:
821
        """
822
        Create a Python enum class name and register it.
823

824
        Returns the enum class name to use in generated code.
825
        """
826
        # Generate enum class name from table and column names
827
        # Convert to PascalCase: user_status -> UserStatus
828
        base_name = "".join(
5✔
829
            part.capitalize()
830
            for part in table_name.split("_") + column_name.split("_")
831
            if part
832
        )
833

834
        # Ensure uniqueness
835
        enum_class_name = base_name
5✔
836
        for counter in count(1):
5✔
837
            if enum_class_name not in self.enum_values:
5✔
838
                break
5✔
839

840
            # Check if it's the same enum (same values)
841
            if self.enum_values[enum_class_name] == values:
5✔
842
                # Reuse existing enum class
843
                return enum_class_name
5✔
844

845
            enum_class_name = f"{base_name}{counter}"
5✔
846

847
        # Register the new enum class
848
        self.enum_values[enum_class_name] = values
5✔
849
        return enum_class_name
5✔
850

851
    def render_enum_classes(self) -> str:
5✔
852
        """Render Python enum class definitions."""
853
        if not self.enum_values:
5✔
854
            return ""
5✔
855

856
        self.add_module_import("enum")
5✔
857

858
        enum_defs = []
5✔
859
        for enum_class_name, values in sorted(self.enum_values.items()):
5✔
860
            # Create enum members with valid Python identifiers
861
            members = []
5✔
862
            for value in values:
5✔
863
                # Unescape SQL escape sequences (e.g., \' -> ')
864
                # The value from the CHECK constraint has SQL escaping
865
                unescaped_value = value.replace("\\'", "'").replace("\\\\", "\\")
5✔
866

867
                # Create a valid identifier from the enum value
868
                member_name = _re_invalid_identifier.sub("_", unescaped_value).upper()
5✔
869
                if not member_name:
5✔
UNCOV
870
                    member_name = "EMPTY"
×
871
                elif member_name[0].isdigit():
5✔
UNCOV
872
                    member_name = "_" + member_name
×
873
                elif iskeyword(member_name):
5✔
UNCOV
874
                    member_name += "_"
×
875
                #
876
                # # Re-escape for Python string literal
877
                # python_escaped = unescaped_value.replace("\\", "\\\\").replace(
878
                #     "'", "\\'"
879
                # )
880
                members.append(f"    {member_name} = {unescaped_value!r}")
5✔
881

882
            enum_def = f"class {enum_class_name}(str, enum.Enum):\n" + "\n".join(
5✔
883
                members
884
            )
885
            enum_defs.append(enum_def)
5✔
886

887
        return "\n\n\n".join(enum_defs)
5✔
888

889
    def fix_column_types(self, table: Table) -> None:
5✔
890
        """Adjust the reflected column types."""
891

892
        def fix_enum_column(col_name: str, enum_type: Enum) -> None:
5✔
893
            if (table.name, col_name) in self.enum_classes:
5✔
894
                return
5✔
895

896
            if enum_type.name:
5✔
897
                existing_class = None
5✔
898
                for (_, _), cls in self.enum_classes.items():
5✔
899
                    if cls == self._enum_name_to_class_name(enum_type.name):
5✔
900
                        existing_class = cls
5✔
901
                        break
5✔
902

903
                if existing_class:
5✔
904
                    enum_class_name = existing_class
5✔
905
                else:
906
                    enum_class_name = self._enum_name_to_class_name(enum_type.name)
5✔
907
                    if enum_class_name not in self.enum_values:
5✔
908
                        self.enum_values[enum_class_name] = list(enum_type.enums)
5✔
909
            else:
910
                enum_class_name = self._create_enum_class(
5✔
911
                    table.name, col_name, list(enum_type.enums)
912
                )
913

914
            self.enum_classes[(table.name, col_name)] = enum_class_name
5✔
915

916
        # Detect check constraints for boolean and enum columns
917
        for constraint in table.constraints.copy():
5✔
918
            if isinstance(constraint, CheckConstraint):
5✔
919
                sqltext = get_compiled_expression(constraint.sqltext, self.bind)
5✔
920

921
                # Turn any integer-like column with a CheckConstraint like
922
                # "column IN (0, 1)" into a Boolean
923
                if match := _re_boolean_check_constraint.match(sqltext):
5✔
924
                    if colname_match := _re_column_name.match(match.group(1)):
5✔
925
                        colname = colname_match.group(3)
5✔
926
                        table.constraints.remove(constraint)
5✔
927
                        table.c[colname].type = Boolean()
5✔
928
                        continue
5✔
929

930
                # Turn VARCHAR columns with CHECK constraints like "column IN ('a', 'b')"
931
                # into synthetic Enum types with Python enum classes
932
                if (
5✔
933
                    "nosyntheticenums" not in self.options
934
                    and (match := _re_enum_check_constraint.match(sqltext))
935
                    and (colname_match := _re_column_name.match(match.group(1)))
936
                ):
937
                    colname = colname_match.group(3)
5✔
938
                    items = match.group(2)
5✔
939
                    if isinstance(table.c[colname].type, String) and not isinstance(
5✔
940
                        table.c[colname].type, Enum
941
                    ):
942
                        options = _re_enum_item.findall(items)
5✔
943
                        # Create Python enum class
944
                        enum_class_name = self._create_enum_class(
5✔
945
                            table.name, colname, options
946
                        )
947
                        self.enum_classes[(table.name, colname)] = enum_class_name
5✔
948
                        # Convert to Enum type but KEEP the constraint
949
                        table.c[colname].type = Enum(*options, native_enum=False)
5✔
950
                        continue
5✔
951

952
        for column in table.c:
5✔
953
            # Handle native database Enum types (e.g., PostgreSQL ENUM)
954
            if (
5✔
955
                "nonativeenums" not in self.options
956
                and isinstance(column.type, Enum)
957
                and column.type.enums
958
            ):
959
                fix_enum_column(column.name, column.type)
5✔
960

961
            # Handle ARRAY columns with Enum item types (e.g., PostgreSQL ARRAY(ENUM))
962
            elif (
5✔
963
                "nonativeenums" not in self.options
964
                and isinstance(column.type, ARRAY)
965
                and isinstance(column.type.item_type, Enum)
966
                and column.type.item_type.enums
967
            ):
968
                fix_enum_column(column.name, column.type.item_type)
5✔
969

970
            if not self.keep_dialect_types:
5✔
971
                try:
5✔
972
                    column.type = self.get_adapted_type(column.type)
5✔
973
                except CompileError:
5✔
974
                    continue
5✔
975

976
            # PostgreSQL specific fix: detect sequences from server_default
977
            if column.server_default and self.bind.dialect.name == "postgresql":
5✔
978
                if isinstance(column.server_default, DefaultClause) and isinstance(
5✔
979
                    column.server_default.arg, TextClause
980
                ):
981
                    schema, seqname = decode_postgresql_sequence(
5✔
982
                        column.server_default.arg
983
                    )
984
                    if seqname:
5✔
985
                        # Add an explicit sequence
986
                        if seqname != f"{column.table.name}_{column.name}_seq":
5✔
987
                            column.default = sqlalchemy.Sequence(seqname, schema=schema)
5✔
988

989
                        column.server_default = None
5✔
990

991
    def get_adapted_type(self, coltype: Any) -> Any:
5✔
992
        compiled_type = coltype.compile(self.bind.engine.dialect)
5✔
993
        for supercls in coltype.__class__.__mro__:
5✔
994
            if not supercls.__name__.startswith("_") and hasattr(
5✔
995
                supercls, "__visit_name__"
996
            ):
997
                # Don't try to adapt UserDefinedType as it's not a proper column type
998
                if supercls is UserDefinedType or issubclass(supercls, TypeDecorator):
5✔
999
                    return coltype
5✔
1000

1001
                # Hack to fix adaptation of the Enum class which is broken since
1002
                # SQLAlchemy 1.2
1003
                kw = {}
5✔
1004
                if supercls is Enum:
5✔
1005
                    kw["name"] = coltype.name
5✔
1006
                    if coltype.schema:
5✔
1007
                        kw["schema"] = coltype.schema
5✔
1008

1009
                # Hack to fix Postgres DOMAIN type adaptation, broken as of SQLAlchemy 2.0.42
1010
                # For additional information - https://github.com/agronholm/sqlacodegen/issues/416#issuecomment-3417480599
1011
                if supercls is DOMAIN:
5✔
1012
                    if coltype.default:
5✔
UNCOV
1013
                        kw["default"] = coltype.default
×
1014
                    if coltype.constraint_name is not None:
5✔
1015
                        kw["constraint_name"] = coltype.constraint_name
5✔
1016
                    if coltype.not_null:
5✔
UNCOV
1017
                        kw["not_null"] = coltype.not_null
×
1018
                    if coltype.check is not None:
5✔
1019
                        kw["check"] = coltype.check
5✔
1020
                    if coltype.create_type:
5✔
1021
                        kw["create_type"] = coltype.create_type
5✔
1022

1023
                try:
5✔
1024
                    new_coltype = coltype.adapt(supercls)
5✔
1025
                except TypeError:
5✔
1026
                    # If the adaptation fails, don't try again
1027
                    break
5✔
1028

1029
                for key, value in kw.items():
5✔
1030
                    setattr(new_coltype, key, value)
5✔
1031

1032
                if isinstance(coltype, ARRAY):
5✔
1033
                    new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)
5✔
1034

1035
                try:
5✔
1036
                    # If the adapted column type does not render the same as the
1037
                    # original, don't substitute it
1038
                    if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
5✔
1039
                        break
5✔
1040
                except CompileError:
5✔
1041
                    # If the adapted column type can't be compiled, don't substitute it
1042
                    break
5✔
1043

1044
                # Stop on the first valid non-uppercase column type class
1045
                coltype = new_coltype
5✔
1046
                if supercls.__name__ != supercls.__name__.upper():
5✔
1047
                    break
5✔
1048

1049
        return coltype
5✔
1050

1051

1052
class DeclarativeGenerator(TablesGenerator):
5✔
1053
    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
5✔
1054
        "use_inflect",
1055
        "nojoined",
1056
        "nobidi",
1057
        "noidsuffix",
1058
        "nofknames",
1059
    }
1060

1061
    def __init__(
5✔
1062
        self,
1063
        metadata: MetaData,
1064
        bind: Connection | Engine,
1065
        options: Sequence[str],
1066
        *,
1067
        indentation: str = "    ",
1068
        base_class_name: str = "Base",
1069
        explicit_foreign_keys: bool = False,
1070
    ):
1071
        super().__init__(metadata, bind, options, indentation=indentation)
5✔
1072
        self.base_class_name: str = base_class_name
5✔
1073
        self.inflect_engine = inflect.engine()
5✔
1074
        self.explicit_foreign_keys = explicit_foreign_keys
5✔
1075

1076
    def generate_base(self) -> None:
5✔
1077
        self.base = Base(
5✔
1078
            literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")],
1079
            declarations=[
1080
                f"class {self.base_class_name}(DeclarativeBase):",
1081
                f"{self.indentation}pass",
1082
            ],
1083
            metadata_ref=f"{self.base_class_name}.metadata",
1084
        )
1085

1086
    def collect_imports(self, models: Iterable[Model]) -> None:
5✔
1087
        super().collect_imports(models)
5✔
1088
        if any(isinstance(model, ModelClass) for model in models):
5✔
1089
            self.add_literal_import("sqlalchemy.orm", "Mapped")
5✔
1090
            self.add_literal_import("sqlalchemy.orm", "mapped_column")
5✔
1091

1092
    def collect_imports_for_model(self, model: Model) -> None:
5✔
1093
        super().collect_imports_for_model(model)
5✔
1094
        if isinstance(model, ModelClass):
5✔
1095
            if model.relationships:
5✔
1096
                self.add_literal_import("sqlalchemy.orm", "relationship")
5✔
1097

1098
    def generate_models(self) -> list[Model]:
5✔
1099
        models_by_table_name: dict[str, Model] = {}
5✔
1100

1101
        # Pick association tables from the metadata into their own set, don't process
1102
        # them normally
1103
        links: defaultdict[str, list[Model]] = defaultdict(lambda: [])
5✔
1104
        for table in self.metadata.sorted_tables:
5✔
1105
            qualified_name = qualified_table_name(table)
5✔
1106

1107
            # Link tables have exactly two foreign key constraints and all columns are
1108
            # involved in them
1109
            fk_constraints = sorted(
5✔
1110
                table.foreign_key_constraints, key=get_constraint_sort_key
1111
            )
1112
            if len(fk_constraints) == 2 and all(
5✔
1113
                col.foreign_keys for col in table.columns
1114
            ):
1115
                model = models_by_table_name[qualified_name] = Model(table)
5✔
1116
                tablename = fk_constraints[0].elements[0].column.table.name
5✔
1117
                links[tablename].append(model)
5✔
1118
                continue
5✔
1119

1120
            # Only form model classes for tables that have a primary key and are not
1121
            # association tables
1122
            if not table.primary_key:
5✔
1123
                models_by_table_name[qualified_name] = Model(table)
5✔
1124
            else:
1125
                model = ModelClass(table)
5✔
1126
                models_by_table_name[qualified_name] = model
5✔
1127

1128
                # Fill in the columns
1129
                for column in table.c:
5✔
1130
                    column_attr = ColumnAttribute(model, column)
5✔
1131
                    model.columns.append(column_attr)
5✔
1132

1133
        # Add relationships
1134
        for model in models_by_table_name.values():
5✔
1135
            if isinstance(model, ModelClass):
5✔
1136
                self.generate_relationships(
5✔
1137
                    model, models_by_table_name, links[model.table.name]
1138
                )
1139

1140
        # Nest inherited classes in their superclasses to ensure proper ordering
1141
        if "nojoined" not in self.options:
5✔
1142
            for model in list(models_by_table_name.values()):
5✔
1143
                if not isinstance(model, ModelClass):
5✔
1144
                    continue
5✔
1145

1146
                pk_column_names = {col.name for col in model.table.primary_key.columns}
5✔
1147
                for constraint in model.table.foreign_key_constraints:
5✔
1148
                    if set(get_column_names(constraint)) == pk_column_names:
5✔
1149
                        target = models_by_table_name[
5✔
1150
                            qualified_table_name(constraint.elements[0].column.table)
1151
                        ]
1152
                        if isinstance(target, ModelClass):
5✔
1153
                            model.parent_class = target
5✔
1154
                            target.children.append(model)
5✔
1155

1156
        # Change base if we only have tables
1157
        if not any(
5✔
1158
            isinstance(model, ModelClass) for model in models_by_table_name.values()
1159
        ):
1160
            super().generate_base()
5✔
1161

1162
        # Collect the imports
1163
        self.collect_imports(models_by_table_name.values())
5✔
1164

1165
        # Rename models and their attributes that conflict with imports or other
1166
        # attributes
1167
        global_names = {
5✔
1168
            name for namespace in self.imports.values() for name in namespace
1169
        }
1170
        for model in models_by_table_name.values():
5✔
1171
            self.generate_model_name(model, global_names)
5✔
1172
            global_names.add(model.name)
5✔
1173

1174
        return list(models_by_table_name.values())
5✔
1175

1176
    def generate_relationships(
5✔
1177
        self,
1178
        source: ModelClass,
1179
        models_by_table_name: dict[str, Model],
1180
        association_tables: list[Model],
1181
    ) -> list[RelationshipAttribute]:
1182
        relationships: list[RelationshipAttribute] = []
5✔
1183
        reverse_relationship: RelationshipAttribute | None
1184

1185
        # Add many-to-one (and one-to-many) relationships
1186
        pk_column_names = {col.name for col in source.table.primary_key.columns}
5✔
1187
        for constraint in sorted(
5✔
1188
            source.table.foreign_key_constraints, key=get_constraint_sort_key
1189
        ):
1190
            target = models_by_table_name[
5✔
1191
                qualified_table_name(constraint.elements[0].column.table)
1192
            ]
1193
            if isinstance(target, ModelClass):
5✔
1194
                if "nojoined" not in self.options:
5✔
1195
                    if set(get_column_names(constraint)) == pk_column_names:
5✔
1196
                        parent = models_by_table_name[
5✔
1197
                            qualified_table_name(constraint.elements[0].column.table)
1198
                        ]
1199
                        if isinstance(parent, ModelClass):
5✔
1200
                            source.parent_class = parent
5✔
1201
                            parent.children.append(source)
5✔
1202
                            continue
5✔
1203

1204
                # Add uselist=False to One-to-One relationships
1205
                column_names = get_column_names(constraint)
5✔
1206
                if any(
5✔
1207
                    isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
1208
                    and {col.name for col in c.columns} == set(column_names)
1209
                    for c in constraint.table.constraints
1210
                ):
1211
                    r_type = RelationshipType.ONE_TO_ONE
5✔
1212
                else:
1213
                    r_type = RelationshipType.MANY_TO_ONE
5✔
1214

1215
                relationship = RelationshipAttribute(r_type, source, target, constraint)
5✔
1216
                source.relationships.append(relationship)
5✔
1217

1218
                # For self referential relationships, remote_side needs to be set
1219
                if source is target:
5✔
1220
                    relationship.remote_side = [
5✔
1221
                        source.get_column_attribute(col.name)
1222
                        for col in constraint.referred_table.primary_key
1223
                    ]
1224

1225
                # If the two tables share more than one foreign key constraint,
1226
                # SQLAlchemy needs an explicit primaryjoin to figure out which column(s)
1227
                # it needs
1228
                common_fk_constraints = get_common_fk_constraints(
5✔
1229
                    source.table, target.table
1230
                )
1231
                if len(common_fk_constraints) > 1:
5✔
1232
                    relationship.foreign_keys = [
5✔
1233
                        source.get_column_attribute(key)
1234
                        for key in constraint.column_keys
1235
                    ]
1236

1237
                # Generate the opposite end of the relationship in the target class
1238
                if "nobidi" not in self.options:
5✔
1239
                    if r_type is RelationshipType.MANY_TO_ONE:
5✔
1240
                        r_type = RelationshipType.ONE_TO_MANY
5✔
1241

1242
                    reverse_relationship = RelationshipAttribute(
5✔
1243
                        r_type,
1244
                        target,
1245
                        source,
1246
                        constraint,
1247
                        foreign_keys=relationship.foreign_keys,
1248
                        backref=relationship,
1249
                    )
1250
                    relationship.backref = reverse_relationship
5✔
1251
                    target.relationships.append(reverse_relationship)
5✔
1252

1253
                    # For self referential relationships, remote_side needs to be set
1254
                    if source is target:
5✔
1255
                        reverse_relationship.remote_side = [
5✔
1256
                            source.get_column_attribute(colname)
1257
                            for colname in constraint.column_keys
1258
                        ]
1259

1260
        # Add many-to-many relationships
1261
        for association_table in association_tables:
5✔
1262
            fk_constraints = sorted(
5✔
1263
                association_table.table.foreign_key_constraints,
1264
                key=get_constraint_sort_key,
1265
            )
1266
            target = models_by_table_name[
5✔
1267
                qualified_table_name(fk_constraints[1].elements[0].column.table)
1268
            ]
1269
            if isinstance(target, ModelClass):
5✔
1270
                relationship = RelationshipAttribute(
5✔
1271
                    RelationshipType.MANY_TO_MANY,
1272
                    source,
1273
                    target,
1274
                    fk_constraints[1],
1275
                    association_table,
1276
                )
1277
                source.relationships.append(relationship)
5✔
1278

1279
                # Generate the opposite end of the relationship in the target class
1280
                reverse_relationship = None
5✔
1281
                if "nobidi" not in self.options:
5✔
1282
                    reverse_relationship = RelationshipAttribute(
5✔
1283
                        RelationshipType.MANY_TO_MANY,
1284
                        target,
1285
                        source,
1286
                        fk_constraints[0],
1287
                        association_table,
1288
                        relationship,
1289
                    )
1290
                    relationship.backref = reverse_relationship
5✔
1291
                    target.relationships.append(reverse_relationship)
5✔
1292

1293
                # Add a primary/secondary join for self-referential many-to-many
1294
                # relationships
1295
                if source is target:
5✔
1296
                    both_relationships = [relationship]
5✔
1297
                    reverse_flags = [False, True]
5✔
1298
                    if reverse_relationship:
5✔
1299
                        both_relationships.append(reverse_relationship)
5✔
1300

1301
                    for relationship, reverse in zip(both_relationships, reverse_flags):
5✔
1302
                        if (
5✔
1303
                            not relationship.association_table
1304
                            or not relationship.constraint
1305
                        ):
UNCOV
1306
                            continue
×
1307

1308
                        constraints = sorted(
5✔
1309
                            relationship.constraint.table.foreign_key_constraints,
1310
                            key=get_constraint_sort_key,
1311
                            reverse=reverse,
1312
                        )
1313
                        pri_pairs = zip(
5✔
1314
                            get_column_names(constraints[0]), constraints[0].elements
1315
                        )
1316
                        sec_pairs = zip(
5✔
1317
                            get_column_names(constraints[1]), constraints[1].elements
1318
                        )
1319
                        relationship.primaryjoin = [
5✔
1320
                            (
1321
                                relationship.source,
1322
                                elem.column.name,
1323
                                relationship.association_table,
1324
                                col,
1325
                            )
1326
                            for col, elem in pri_pairs
1327
                        ]
1328
                        relationship.secondaryjoin = [
5✔
1329
                            (
1330
                                relationship.target,
1331
                                elem.column.name,
1332
                                relationship.association_table,
1333
                                col,
1334
                            )
1335
                            for col, elem in sec_pairs
1336
                        ]
1337

1338
        return relationships
5✔
1339

1340
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
5✔
1341
        if isinstance(model, ModelClass):
5✔
1342
            preferred_name = _re_invalid_identifier.sub("_", model.table.name)
5✔
1343
            preferred_name = "".join(
5✔
1344
                part[:1].upper() + part[1:] for part in preferred_name.split("_")
1345
            )
1346
            if "use_inflect" in self.options:
5✔
1347
                singular_name = self.inflect_engine.singular_noun(preferred_name)
5✔
1348
                if singular_name:
5✔
1349
                    preferred_name = singular_name
5✔
1350

1351
            model.name = self.find_free_name(preferred_name, global_names)
5✔
1352

1353
            # Fill in the names for column attributes
1354
            local_names: set[str] = set()
5✔
1355
            for column_attr in model.columns:
5✔
1356
                self.generate_column_attr_name(column_attr, global_names, local_names)
5✔
1357
                local_names.add(column_attr.name)
5✔
1358

1359
            # Fill in the names for relationship attributes
1360
            for relationship in model.relationships:
5✔
1361
                self.generate_relationship_name(relationship, global_names, local_names)
5✔
1362
                local_names.add(relationship.name)
5✔
1363
        else:
1364
            super().generate_model_name(model, global_names)
5✔
1365

1366
    def generate_column_attr_name(
5✔
1367
        self,
1368
        column_attr: ColumnAttribute,
1369
        global_names: set[str],
1370
        local_names: set[str],
1371
    ) -> None:
1372
        column_attr.name = self.find_free_name(
5✔
1373
            column_attr.column.name, global_names, local_names
1374
        )
1375

1376
    def generate_relationship_name(
5✔
1377
        self,
1378
        relationship: RelationshipAttribute,
1379
        global_names: set[str],
1380
        local_names: set[str],
1381
    ) -> None:
1382
        def strip_id_suffix(name: str) -> str:
5✔
1383
            # Strip _id only if at the end or followed by underscore (e.g., "course_id" -> "course", "course_id_1" -> "course_1")
1384
            # But don't strip from "parent_id1" (where id is followed by a digit without underscore)
1385
            return re.sub(r"_id(?=_|$)", "", name)
5✔
1386

1387
        def get_m2m_qualified_name(default_name: str) -> str:
5✔
1388
            """Generate qualified name for many-to-many relationship when multiple junction tables exist."""
1389
            # Check if there are multiple M2M relationships to the same target
1390
            target_m2m_relationships = [
5✔
1391
                r
1392
                for r in relationship.source.relationships
1393
                if r.target is relationship.target
1394
                and r.type == RelationshipType.MANY_TO_MANY
1395
            ]
1396

1397
            # Only use junction-based naming when there are multiple M2M to same target
1398
            if len(target_m2m_relationships) > 1:
5✔
1399
                if relationship.source is relationship.target:
5✔
1400
                    # Self-referential: use FK column name from junction table
1401
                    # (e.g., "parent_id" -> "parent", "child_id" -> "child")
1402
                    if relationship.constraint:
5✔
1403
                        column_names = [c.name for c in relationship.constraint.columns]
5✔
1404
                        if len(column_names) == 1:
5✔
1405
                            fk_qualifier = strip_id_suffix(column_names[0])
5✔
1406
                        else:
UNCOV
1407
                            fk_qualifier = "_".join(
×
1408
                                strip_id_suffix(col_name) for col_name in column_names
1409
                            )
1410
                        return fk_qualifier
5✔
1411
                elif relationship.association_table:
5✔
1412
                    # Normal: use junction table name as qualifier
1413
                    junction_name = relationship.association_table.table.name
5✔
1414
                    fk_qualifier = strip_id_suffix(junction_name)
5✔
1415
                    return f"{relationship.target.table.name}_{fk_qualifier}"
5✔
1416
            else:
1417
                # Single M2M: use simple name from junction table FK column
1418
                # (e.g., "right_id" -> "right" instead of "right_table")
1419
                if relationship.constraint and "noidsuffix" not in self.options:
5✔
1420
                    column_names = [c.name for c in relationship.constraint.columns]
5✔
1421
                    if len(column_names) == 1:
5✔
1422
                        stripped_name = strip_id_suffix(column_names[0])
5✔
1423
                        if stripped_name != column_names[0]:
5✔
1424
                            return stripped_name
5✔
1425

1426
            return default_name
5✔
1427

1428
        def get_fk_qualified_name(constraint: ForeignKeyConstraint) -> str:
5✔
1429
            """Generate qualified name for one-to-many/one-to-one relationship using FK column names."""
1430
            column_names = [c.name for c in constraint.columns]
5✔
1431

1432
            if len(column_names) == 1:
5✔
1433
                # Single column FK: strip _id suffix if present
1434
                fk_qualifier = strip_id_suffix(column_names[0])
5✔
1435
            else:
1436
                # Multi-column FK: concatenate all column names (strip _id from each)
1437
                fk_qualifier = "_".join(
5✔
1438
                    strip_id_suffix(col_name) for col_name in column_names
1439
                )
1440

1441
            # For self-referential relationships, don't prepend the table name
1442
            if relationship.source is relationship.target:
5✔
UNCOV
1443
                return fk_qualifier
×
1444
            else:
1445
                return f"{relationship.target.table.name}_{fk_qualifier}"
5✔
1446

1447
        def resolve_preferred_name() -> str:
5✔
1448
            resolved_name = relationship.target.table.name
5✔
1449

1450
            # For reverse relationships with multiple FKs to the same table, use the FK
1451
            # column name to create a more descriptive relationship name
1452
            # For M2M relationships with multiple junction tables, use the junction table name
1453
            use_fk_based_naming = "nofknames" not in self.options and (
5✔
1454
                (
1455
                    relationship.constraint
1456
                    and relationship.type
1457
                    in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
1458
                    and relationship.foreign_keys
1459
                )
1460
                or (
1461
                    relationship.type == RelationshipType.MANY_TO_MANY
1462
                    and relationship.association_table
1463
                )
1464
            )
1465

1466
            if use_fk_based_naming:
5✔
1467
                if relationship.type == RelationshipType.MANY_TO_MANY:
5✔
1468
                    resolved_name = get_m2m_qualified_name(resolved_name)
5✔
1469
                elif relationship.constraint:
5✔
1470
                    resolved_name = get_fk_qualified_name(relationship.constraint)
5✔
1471

1472
            # If there's a constraint with a single column that contains "_id", use the
1473
            # stripped version as the relationship name
1474
            elif relationship.constraint and "noidsuffix" not in self.options:
5✔
1475
                is_source = relationship.source.table is relationship.constraint.table
5✔
1476
                if is_source or relationship.type not in (
5✔
1477
                    RelationshipType.ONE_TO_ONE,
1478
                    RelationshipType.ONE_TO_MANY,
1479
                ):
1480
                    column_names = [c.name for c in relationship.constraint.columns]
5✔
1481
                    if len(column_names) == 1:
5✔
1482
                        stripped_name = strip_id_suffix(column_names[0])
5✔
1483
                        # Only use the stripped name if it actually changed (had _id in it)
1484
                        if stripped_name != column_names[0]:
5✔
1485
                            resolved_name = stripped_name
5✔
1486
                    else:
1487
                        # For composite FKs, check if there are multiple FKs to the same target
1488
                        target_relationships = [
5✔
1489
                            r
1490
                            for r in relationship.source.relationships
1491
                            if r.target is relationship.target
1492
                            and r.type == relationship.type
1493
                        ]
1494
                        if len(target_relationships) > 1:
5✔
1495
                            # Multiple FKs to same table - use concatenated column names
1496
                            resolved_name = "_".join(
5✔
1497
                                strip_id_suffix(col_name) for col_name in column_names
1498
                            )
1499

1500
            if "use_inflect" in self.options:
5✔
1501
                inflected_name: str | Literal[False]
1502
                if relationship.type in (
5✔
1503
                    RelationshipType.ONE_TO_MANY,
1504
                    RelationshipType.MANY_TO_MANY,
1505
                ):
1506
                    if not self.inflect_engine.singular_noun(resolved_name):
5✔
1507
                        resolved_name = self.inflect_engine.plural_noun(resolved_name)
5✔
1508
                else:
1509
                    inflected_name = self.inflect_engine.singular_noun(resolved_name)
5✔
1510
                    if inflected_name:
5✔
1511
                        resolved_name = inflected_name
5✔
1512

1513
            return resolved_name
5✔
1514

1515
        if (
5✔
1516
            relationship.type
1517
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
1518
            and relationship.source is relationship.target
1519
            and relationship.backref
1520
            and relationship.backref.name
1521
        ):
1522
            preferred_name = relationship.backref.name + "_reverse"
5✔
1523
        else:
1524
            preferred_name = resolve_preferred_name()
5✔
1525

1526
        relationship.name = self.find_free_name(
5✔
1527
            preferred_name, global_names, local_names
1528
        )
1529

1530
    def render_models(self, models: list[Model]) -> str:
5✔
1531
        rendered: list[str] = []
5✔
1532
        for model in models:
5✔
1533
            if isinstance(model, ModelClass):
5✔
1534
                rendered.append(self.render_class(model))
5✔
1535
            else:
1536
                rendered.append(f"{model.name} = {self.render_table(model.table)}")
5✔
1537

1538
        return "\n\n\n".join(rendered)
5✔
1539

1540
    def render_class(self, model: ModelClass) -> str:
5✔
1541
        sections: list[str] = []
5✔
1542

1543
        # Render class variables / special declarations
1544
        class_vars: str = self.render_class_variables(model)
5✔
1545
        if class_vars:
5✔
1546
            sections.append(class_vars)
5✔
1547

1548
        # Render column attributes
1549
        rendered_column_attributes: list[str] = []
5✔
1550
        for nullable in (False, True):
5✔
1551
            for column_attr in model.columns:
5✔
1552
                if column_attr.column.nullable is nullable:
5✔
1553
                    rendered_column_attributes.append(
5✔
1554
                        self.render_column_attribute(column_attr)
1555
                    )
1556

1557
        if rendered_column_attributes:
5✔
1558
            sections.append("\n".join(rendered_column_attributes))
5✔
1559

1560
        # Render relationship attributes
1561
        rendered_relationship_attributes: list[str] = [
5✔
1562
            self.render_relationship(relationship)
1563
            for relationship in model.relationships
1564
        ]
1565

1566
        if rendered_relationship_attributes:
5✔
1567
            sections.append("\n".join(rendered_relationship_attributes))
5✔
1568

1569
        declaration = self.render_class_declaration(model)
5✔
1570
        rendered_sections = "\n\n".join(
5✔
1571
            indent(section, self.indentation) for section in sections
1572
        )
1573
        return f"{declaration}\n{rendered_sections}"
5✔
1574

1575
    def render_class_declaration(self, model: ModelClass) -> str:
5✔
1576
        parent_class_name = (
5✔
1577
            model.parent_class.name if model.parent_class else self.base_class_name
1578
        )
1579
        return f"class {model.name}({parent_class_name}):"
5✔
1580

1581
    def render_class_variables(self, model: ModelClass) -> str:
5✔
1582
        variables = [f"__tablename__ = {model.table.name!r}"]
5✔
1583

1584
        # Render constraints and indexes as __table_args__
1585
        table_args = self.render_table_args(model.table)
5✔
1586
        if table_args:
5✔
1587
            variables.append(f"__table_args__ = {table_args}")
5✔
1588

1589
        return "\n".join(variables)
5✔
1590

1591
    def render_table_args(self, table: Table) -> str:
5✔
1592
        args: list[str] = []
5✔
1593
        kwargs: dict[str, object] = {}
5✔
1594

1595
        # Render constraints
1596
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
5✔
1597
            if uses_default_name(constraint):
5✔
1598
                if isinstance(constraint, PrimaryKeyConstraint):
5✔
1599
                    continue
5✔
1600
                if (
5✔
1601
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
1602
                    and len(constraint.columns) == 1
1603
                ):
1604
                    continue
5✔
1605

1606
            args.append(self.render_constraint(constraint))
5✔
1607

1608
        # Render indexes
1609
        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
5✔
1610
            if len(index.columns) > 1 or not uses_default_name(index):
5✔
1611
                args.append(self.render_index(index))
5✔
1612

1613
        if table.schema:
5✔
1614
            kwargs["schema"] = table.schema
5✔
1615

1616
        if table.comment:
5✔
1617
            kwargs["comment"] = table.comment
5✔
1618

1619
        # add info + dialect kwargs for dict context (__table_args__) (opt-in)
1620
        if self.include_dialect_options_and_info:
5✔
1621
            self._add_dialect_kwargs_and_info(table, kwargs, values_for_dict=True)
5✔
1622

1623
        if kwargs:
5✔
1624
            formatted_kwargs = pformat(kwargs)
5✔
1625
            if not args:
5✔
1626
                return formatted_kwargs
5✔
1627
            else:
1628
                args.append(formatted_kwargs)
5✔
1629

1630
        if args:
5✔
1631
            rendered_args = f",\n{self.indentation}".join(args)
5✔
1632
            if len(args) == 1:
5✔
1633
                rendered_args += ","
5✔
1634

1635
            return f"(\n{self.indentation}{rendered_args}\n)"
5✔
1636
        else:
1637
            return ""
5✔
1638

1639
    def render_column_python_type(self, column: Column[Any]) -> str:
5✔
1640
        def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
5✔
1641
            column_type = column.type
5✔
1642
            pre: list[str] = []
5✔
1643
            post_size = 0
5✔
1644
            if column.nullable:
5✔
1645
                self.add_literal_import("typing", "Optional")
5✔
1646
                pre.append("Optional[")
5✔
1647
                post_size += 1
5✔
1648

1649
            if isinstance(column_type, ARRAY):
5✔
1650
                dim = getattr(column_type, "dimensions", None) or 1
5✔
1651
                pre.extend("list[" for _ in range(dim))
5✔
1652
                post_size += dim
5✔
1653

1654
                column_type = column_type.item_type
5✔
1655

1656
            return "".join(pre), column_type, "]" * post_size
5✔
1657

1658
        def render_python_type(column_type: TypeEngine[Any]) -> str:
5✔
1659
            # Check if this is an enum column with a Python enum class
1660
            if isinstance(column_type, Enum):
5✔
1661
                table_name = column.table.name
5✔
1662
                column_name = column.name
5✔
1663
                if (table_name, column_name) in self.enum_classes:
5✔
1664
                    enum_class_name = self.enum_classes[(table_name, column_name)]
5✔
1665
                    return enum_class_name
5✔
1666

1667
            if isinstance(column_type, DOMAIN):
5✔
1668
                column_type = column_type.data_type
5✔
1669

1670
            try:
5✔
1671
                python_type = column_type.python_type
5✔
1672
                python_type_module = python_type.__module__
5✔
1673
                python_type_name = python_type.__name__
5✔
1674
            except NotImplementedError:
5✔
1675
                self.add_literal_import("typing", "Any")
5✔
1676
                return "Any"
5✔
1677

1678
            if python_type_module == "builtins":
5✔
1679
                return python_type_name
5✔
1680

1681
            self.add_module_import(python_type_module)
5✔
1682
            return f"{python_type_module}.{python_type_name}"
5✔
1683

1684
        pre, col_type, post = get_type_qualifiers()
5✔
1685
        column_python_type = f"{pre}{render_python_type(col_type)}{post}"
5✔
1686
        return column_python_type
5✔
1687

1688
    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
5✔
1689
        column = column_attr.column
5✔
1690
        rendered_column = self.render_column(column, column_attr.name != column.name)
5✔
1691
        rendered_column_python_type = self.render_column_python_type(column)
5✔
1692

1693
        return f"{column_attr.name}: Mapped[{rendered_column_python_type}] = {rendered_column}"
5✔
1694

1695
    def render_relationship(self, relationship: RelationshipAttribute) -> str:
5✔
1696
        kwargs = self.render_relationship_arguments(relationship)
5✔
1697
        annotation = self.render_relationship_annotation(relationship)
5✔
1698
        rendered_relationship = render_callable(
5✔
1699
            "relationship", repr(relationship.target.name), kwargs=kwargs
1700
        )
1701
        return f"{relationship.name}: Mapped[{annotation}] = {rendered_relationship}"
5✔
1702

1703
    def render_relationship_annotation(
5✔
1704
        self, relationship: RelationshipAttribute
1705
    ) -> str:
1706
        match relationship.type:
5✔
1707
            case RelationshipType.ONE_TO_MANY:
5✔
1708
                return f"list[{relationship.target.name!r}]"
5✔
1709
            case RelationshipType.ONE_TO_ONE | RelationshipType.MANY_TO_ONE:
5✔
1710
                if relationship.constraint and any(
5✔
1711
                    col.nullable for col in relationship.constraint.columns
1712
                ):
1713
                    self.add_literal_import("typing", "Optional")
5✔
1714
                    return f"Optional[{relationship.target.name!r}]"
5✔
1715
                else:
1716
                    return f"'{relationship.target.name}'"
5✔
1717
            case RelationshipType.MANY_TO_MANY:
5✔
1718
                return f"list[{relationship.target.name!r}]"
5✔
1719

1720
    def render_relationship_arguments(
5✔
1721
        self, relationship: RelationshipAttribute
1722
    ) -> Mapping[str, Any]:
1723
        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
5✔
1724
            rendered = []
5✔
1725
            for attr in column_attrs:
5✔
1726
                if attr.model is relationship.source:
5✔
1727
                    rendered.append(attr.name)
5✔
1728
                else:
UNCOV
1729
                    rendered.append(repr(f"{attr.model.name}.{attr.name}"))
×
1730

1731
            return "[" + ", ".join(rendered) + "]"
5✔
1732

1733
        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
5✔
1734
            rendered = []
5✔
1735
            render_as_string = False
5✔
1736
            # Assume that column_attrs are all in relationship.source or none
1737
            for attr in column_attrs:
5✔
1738
                if not self.explicit_foreign_keys and attr.model is relationship.source:
5✔
1739
                    rendered.append(attr.name)
5✔
1740
                else:
1741
                    rendered.append(f"{attr.model.name}.{attr.name}")
5✔
1742
                    render_as_string = True
5✔
1743

1744
            if render_as_string:
5✔
1745
                return "'[" + ", ".join(rendered) + "]'"
5✔
1746
            else:
1747
                return "[" + ", ".join(rendered) + "]"
5✔
1748

1749
        def render_join(terms: list[JoinType]) -> str:
5✔
1750
            rendered_joins = []
5✔
1751
            for source, source_col, target, target_col in terms:
5✔
1752
                rendered = f"lambda: {source.name}.{source_col} == {target.name}."
5✔
1753
                if target.__class__ is Model:
5✔
1754
                    rendered += "c."
5✔
1755

1756
                rendered += str(target_col)
5✔
1757
                rendered_joins.append(rendered)
5✔
1758

1759
            if len(rendered_joins) > 1:
5✔
UNCOV
1760
                rendered = ", ".join(rendered_joins)
×
UNCOV
1761
                return f"and_({rendered})"
×
1762
            else:
1763
                return rendered_joins[0]
5✔
1764

1765
        # Render keyword arguments
1766
        kwargs: dict[str, Any] = {}
5✔
1767
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
5✔
1768
            if relationship.constraint.referred_table is relationship.source.table:
5✔
1769
                kwargs["uselist"] = False
5✔
1770

1771
        # Add the "secondary" keyword for many-to-many relationships
1772
        if relationship.association_table:
5✔
1773
            table_ref = relationship.association_table.table.name
5✔
1774
            if relationship.association_table.schema:
5✔
1775
                table_ref = f"{relationship.association_table.schema}.{table_ref}"
5✔
1776

1777
            kwargs["secondary"] = repr(table_ref)
5✔
1778

1779
        if relationship.remote_side:
5✔
1780
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)
5✔
1781

1782
        if relationship.foreign_keys:
5✔
1783
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)
5✔
1784

1785
        if relationship.primaryjoin:
5✔
1786
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)
5✔
1787

1788
        if relationship.secondaryjoin:
5✔
1789
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)
5✔
1790

1791
        if relationship.backref:
5✔
1792
            kwargs["back_populates"] = repr(relationship.backref.name)
5✔
1793

1794
        return kwargs
5✔
1795

1796

1797
class DataclassGenerator(DeclarativeGenerator):
5✔
1798
    def __init__(
5✔
1799
        self,
1800
        metadata: MetaData,
1801
        bind: Connection | Engine,
1802
        options: Sequence[str],
1803
        *,
1804
        indentation: str = "    ",
1805
        base_class_name: str = "Base",
1806
        quote_annotations: bool = False,
1807
        metadata_key: str = "sa",
1808
    ):
1809
        super().__init__(
5✔
1810
            metadata,
1811
            bind,
1812
            options,
1813
            indentation=indentation,
1814
            base_class_name=base_class_name,
1815
        )
1816
        self.metadata_key: str = metadata_key
5✔
1817
        self.quote_annotations: bool = quote_annotations
5✔
1818

1819
    def generate_base(self) -> None:
5✔
1820
        self.base = Base(
5✔
1821
            literal_imports=[
1822
                LiteralImport("sqlalchemy.orm", "DeclarativeBase"),
1823
                LiteralImport("sqlalchemy.orm", "MappedAsDataclass"),
1824
            ],
1825
            declarations=[
1826
                (f"class {self.base_class_name}(MappedAsDataclass, DeclarativeBase):"),
1827
                f"{self.indentation}pass",
1828
            ],
1829
            metadata_ref=f"{self.base_class_name}.metadata",
1830
        )
1831

1832

1833
class SQLModelGenerator(DeclarativeGenerator):
5✔
1834
    def __init__(
5✔
1835
        self,
1836
        metadata: MetaData,
1837
        bind: Connection | Engine,
1838
        options: Sequence[str],
1839
        *,
1840
        indentation: str = "    ",
1841
        base_class_name: str = "SQLModel",
1842
    ):
1843
        super().__init__(
5✔
1844
            metadata,
1845
            bind,
1846
            options,
1847
            indentation=indentation,
1848
            base_class_name=base_class_name,
1849
            explicit_foreign_keys=True,
1850
        )
1851

1852
    @property
5✔
1853
    def views_supported(self) -> bool:
5✔
UNCOV
1854
        return False
×
1855

1856
    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
5✔
1857
        self.add_import(Column)
5✔
1858
        return render_callable("Column", *args, kwargs=kwargs)
5✔
1859

1860
    def generate_base(self) -> None:
5✔
1861
        self.base = Base(
5✔
1862
            literal_imports=[],
1863
            declarations=[],
1864
            metadata_ref="",
1865
        )
1866

1867
    def collect_imports(self, models: Iterable[Model]) -> None:
5✔
1868
        super(DeclarativeGenerator, self).collect_imports(models)
5✔
1869
        if any(isinstance(model, ModelClass) for model in models):
5✔
1870
            self.remove_literal_import("sqlalchemy", "MetaData")
5✔
1871
            self.add_literal_import("sqlmodel", "SQLModel")
5✔
1872
            self.add_literal_import("sqlmodel", "Field")
5✔
1873

1874
    def collect_imports_for_model(self, model: Model) -> None:
5✔
1875
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
5✔
1876
        if isinstance(model, ModelClass):
5✔
1877
            for column_attr in model.columns:
5✔
1878
                if column_attr.column.nullable:
5✔
1879
                    self.add_literal_import("typing", "Optional")
5✔
1880
                    break
5✔
1881

1882
            if model.relationships:
5✔
1883
                self.add_literal_import("sqlmodel", "Relationship")
5✔
1884

1885
    def render_module_variables(self, models: list[Model]) -> str:
5✔
1886
        declarations: list[str] = []
5✔
1887
        if any(not isinstance(model, ModelClass) for model in models):
5✔
UNCOV
1888
            if self.base.table_metadata_declaration is not None:
×
UNCOV
1889
                declarations.append(self.base.table_metadata_declaration)
×
1890

1891
        return "\n".join(declarations)
5✔
1892

1893
    def render_class_declaration(self, model: ModelClass) -> str:
5✔
1894
        if model.parent_class:
5✔
UNCOV
1895
            parent = model.parent_class.name
×
1896
        else:
1897
            parent = self.base_class_name
5✔
1898

1899
        superclass_part = f"({parent}, table=True)"
5✔
1900
        return f"class {model.name}{superclass_part}:"
5✔
1901

1902
    def render_class_variables(self, model: ModelClass) -> str:
5✔
1903
        variables = []
5✔
1904

1905
        if model.table.name != model.name.lower():
5✔
1906
            variables.append(f"__tablename__ = {model.table.name!r}")
5✔
1907

1908
        # Render constraints and indexes as __table_args__
1909
        table_args = self.render_table_args(model.table)
5✔
1910
        if table_args:
5✔
1911
            variables.append(f"__table_args__ = {table_args}")
5✔
1912

1913
        return "\n".join(variables)
5✔
1914

1915
    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
5✔
1916
        column = column_attr.column
5✔
1917
        rendered_column = self.render_column(column, True)
5✔
1918
        rendered_column_python_type = self.render_column_python_type(column)
5✔
1919

1920
        kwargs: dict[str, Any] = {}
5✔
1921
        if column.nullable:
5✔
1922
            kwargs["default"] = None
5✔
1923
        kwargs["sa_column"] = f"{rendered_column}"
5✔
1924

1925
        rendered_field = render_callable("Field", kwargs=kwargs)
5✔
1926

1927
        return f"{column_attr.name}: {rendered_column_python_type} = {rendered_field}"
5✔
1928

1929
    def render_relationship(self, relationship: RelationshipAttribute) -> str:
5✔
1930
        kwargs = self.render_relationship_arguments(relationship)
5✔
1931
        annotation = self.render_relationship_annotation(relationship)
5✔
1932

1933
        native_kwargs: dict[str, Any] = {}
5✔
1934
        non_native_kwargs: dict[str, Any] = {}
5✔
1935
        for key, value in kwargs.items():
5✔
1936
            # The following keyword arguments are natively supported in Relationship
1937
            if key in ("back_populates", "cascade_delete", "passive_deletes"):
5✔
1938
                native_kwargs[key] = value
5✔
1939
            else:
1940
                non_native_kwargs[key] = value
5✔
1941

1942
        if non_native_kwargs:
5✔
1943
            native_kwargs["sa_relationship_kwargs"] = (
5✔
1944
                "{"
1945
                + ", ".join(
1946
                    f"{key!r}: {value}" for key, value in non_native_kwargs.items()
1947
                )
1948
                + "}"
1949
            )
1950

1951
        rendered_field = render_callable("Relationship", kwargs=native_kwargs)
5✔
1952
        return f"{relationship.name}: {annotation} = {rendered_field}"
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc