• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 19626218723

24 Nov 2025 07:10AM UTC coverage: 97.365% (-0.3%) from 97.64%
19626218723

Pull #438

github

web-flow
Merge ae9118dd2 into d7a6024df
Pull Request #438: Add support for rendering dialect kwargs and info, and introduce keep-dialect-types option

74 of 80 new or added lines in 3 files covered. (92.5%)

23 existing lines in 1 file now uncovered.

1515 of 1556 relevant lines covered (97.37%)

4.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.61
/src/sqlacodegen/generators.py
1
from __future__ import annotations
5✔
2

3
import inspect
5✔
4
import re
5✔
5
import sys
5✔
6
from abc import ABCMeta, abstractmethod
5✔
7
from collections import defaultdict
5✔
8
from collections.abc import Collection, Iterable, Sequence
5✔
9
from dataclasses import dataclass
5✔
10
from importlib import import_module
5✔
11
from inspect import Parameter
5✔
12
from itertools import count
5✔
13
from keyword import iskeyword
5✔
14
from pprint import pformat
5✔
15
from textwrap import indent
5✔
16
from typing import Any, ClassVar, Literal, cast
5✔
17

18
import inflect
5✔
19
import sqlalchemy
5✔
20
from sqlalchemy import (
5✔
21
    ARRAY,
22
    Boolean,
23
    CheckConstraint,
24
    Column,
25
    Computed,
26
    Constraint,
27
    DefaultClause,
28
    Enum,
29
    ForeignKey,
30
    ForeignKeyConstraint,
31
    Identity,
32
    Index,
33
    MetaData,
34
    PrimaryKeyConstraint,
35
    String,
36
    Table,
37
    Text,
38
    TypeDecorator,
39
    UniqueConstraint,
40
)
41
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
5✔
42
from sqlalchemy.engine import Connection, Engine
5✔
43
from sqlalchemy.exc import CompileError
5✔
44
from sqlalchemy.sql.elements import TextClause
5✔
45
from sqlalchemy.sql.type_api import UserDefinedType
5✔
46
from sqlalchemy.types import TypeEngine
5✔
47

48
from .models import (
5✔
49
    ColumnAttribute,
50
    JoinType,
51
    Model,
52
    ModelClass,
53
    RelationshipAttribute,
54
    RelationshipType,
55
)
56
from .utils import (
5✔
57
    decode_postgresql_sequence,
58
    get_column_names,
59
    get_common_fk_constraints,
60
    get_compiled_expression,
61
    get_constraint_sort_key,
62
    get_stdlib_module_names,
63
    qualified_table_name,
64
    render_callable,
65
    uses_default_name,
66
)
67

68
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
5✔
69
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
5✔
70
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
5✔
71
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
5✔
72
_re_invalid_identifier = re.compile(r"(?u)\W")
5✔
73

74

75
@dataclass
class LiteralImport:
    """A single ``from <pkgname> import <name>`` entry, stored as plain strings."""

    # Package (module path) the name is imported from
    pkgname: str
    # The imported name itself
    name: str
5✔
79

80

81
@dataclass
class Base:
    """
    Description of the MetaData object (for Tables) or the Base (for classes)
    that the rendered module declares and refers to.
    """

    # Literal imports required by the declarations below
    literal_imports: list[LiteralImport]
    # Source lines rendered as module level declarations
    declarations: list[str]
    # Expression used in rendered code to reference the metadata
    metadata_ref: str
    # Optional decorator text; unset by default
    decorator: str | None = None
    # Optional extra declaration appended when plain tables are rendered
    table_metadata_declaration: str | None = None
5✔
90

91

92
class CodeGenerator(metaclass=ABCMeta):
    """
    Abstract base for code generators.

    Stores the metadata, the bind and the option set, and rejects any option
    that is not listed in the class level ``valid_options``.
    """

    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        :param metadata: the metadata to generate code for
        :param bind: connection or engine associated with the metadata
        :param options: option names; each must be in ``valid_options``
        :raises ValueError: if any option is not recognized
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Validate options
        invalid_options = {
            opt for opt in self.options if opt not in self.valid_options
        }
        if invalid_options:
            # Sort so that the error message is deterministic regardless of set
            # iteration order
            raise ValueError(
                "Unrecognized options: " + ", ".join(sorted(invalid_options))
            )

    @property
    @abstractmethod
    def views_supported(self) -> bool:
        """Boolean flag that every concrete generator has to provide."""
        pass

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
118

119

120
@dataclass(eq=False)
5✔
121
class TablesGenerator(CodeGenerator):
5✔
122
    valid_options: ClassVar[set[str]] = {
5✔
123
        "noindexes",
124
        "noconstraints",
125
        "nocomments",
126
        "include-dialect-options",
127
        "keep-dialect-types",
128
    }
129
    stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
5✔
130

131
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        Set up the generator state.

        :param metadata: forwarded to the base class initializer
        :param bind: forwarded to the base class initializer
        :param options: forwarded to (and validated by) the base class initializer
        :param indentation: indentation string stored on the instance
        """
        super().__init__(metadata, bind, options)

        # Import bookkeeping: "from X import ..." names per package, and plain
        # "import X" modules
        self.imports: dict[str, set[str]] = defaultdict(set)
        self.module_imports: set[str] = set()
        self.indentation: str = indentation

        # Opt-in flags derived from the validated option set.
        # Render SchemaItem.info and dialect kwargs (Table/Column) into output
        self.include_dialect_options_and_info: bool = (
            "include-dialect-options" in self.options
        )
        # Keep dialect-specific types instead of adapting to generic SQLAlchemy types
        self.keep_dialect_types: bool = "keep-dialect-types" in self.options
5✔
148

149
    @property
5✔
150
    def views_supported(self) -> bool:
5✔
UNCOV
151
        return True
×
152

153
    def generate_base(self) -> None:
        """Store on ``self.base`` a description of a module level ``metadata`` object."""
        metadata_import = LiteralImport("sqlalchemy", "MetaData")
        self.base = Base(
            literal_imports=[metadata_import],
            declarations=["metadata = MetaData()"],
            metadata_ref="metadata",
        )
159

160
    def generate(self) -> str:
        """
        Produce the complete module source: imports, module variables, models.

        The metadata is modified along the way: ignored tables are removed,
        the ``no*`` options strip indexes/constraints/comments, and column
        types are adjusted via :meth:`fix_column_types`.
        """
        self.generate_base()

        strip_indexes = "noindexes" in self.options
        strip_constraints = "noconstraints" in self.options
        strip_comments = "nocomments" in self.options

        # Remove unwanted elements from the metadata (iterate over a snapshot as
        # tables may be removed on the way)
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if strip_indexes:
                table.indexes.clear()

            if strip_constraints:
                table.constraints.clear()

            if strip_comments:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended column
        # types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        # Generate the models
        models: list[Model] = self.generate_models()

        sections: list[str] = []

        # Module level variables come first, followed by the rendered models
        variables = self.render_module_variables(models)
        if variables:
            sections.append(variables + "\n")

        rendered_models = self.render_models(models)
        if rendered_models:
            sections.append(rendered_models)

        # The imports are rendered last (rendering may still register imports)
        # but placed at the top of the output
        import_block = "\n\n".join(
            "\n".join(group) for group in self.group_imports()
        )
        if import_block:
            sections.insert(0, import_block)

        return "\n\n".join(sections) + "\n"
5✔
209

210
    def collect_imports(self, models: Iterable[Model]) -> None:
5✔
211
        for literal_import in self.base.literal_imports:
5✔
212
            self.add_literal_import(literal_import.pkgname, literal_import.name)
5✔
213

214
        for model in models:
5✔
215
            self.collect_imports_for_model(model)
5✔
216

217
    def collect_imports_for_model(self, model: Model) -> None:
        """
        Register the imports needed by one model: ``Table`` for exact ``Model``
        instances, plus whatever its columns, constraints and indexes require.
        """
        # Subclasses of Model are deliberately excluded by the identity check
        if model.__class__ is Model:
            self.add_import(Table)

        table = model.table
        for column in table.c:
            self.collect_imports_for_column(column)

        for constraint in table.constraints:
            self.collect_imports_for_constraint(constraint)

        # Indexes go through the same entry point as constraints
        for index in table.indexes:
            self.collect_imports_for_constraint(index)
5✔
229

230
    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """
        Register the imports needed to render a column: its type, any nested
        type, its default and its server default.
        """
        coltype = column.type
        self.add_import(coltype)

        # Some types wrap another type that has to be importable as well
        if isinstance(coltype, ARRAY):
            self.add_import(coltype.item_type.__class__)
        elif isinstance(coltype, (JSONB, JSON)):
            astext = coltype.astext_type
            # A Text without a length is treated as the default and needs no import
            is_default_astext = isinstance(astext, Text) and astext.length is None
            if not is_default_astext:
                self.add_import(astext)
        elif isinstance(coltype, DOMAIN):
            self.add_import(coltype.data_type.__class__)

        if column.default:
            self.add_import(column.default)

        server_default = column.server_default
        if server_default:
            if isinstance(server_default, (Computed, Identity)):
                self.add_import(server_default)
            elif isinstance(server_default, DefaultClause):
                # DefaultClause values are rendered as text(...)
                self.add_literal_import("sqlalchemy", "text")
5✔
252

253
    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Register the import needed to render a constraint or an index.

        Indexes, unique constraints and foreign key constraints only need
        their own class when they span several columns or carry a non-default
        name; primary key constraints only when the name is non-default.
        """
        if isinstance(constraint, Index):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "Index")
            return

        if isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
            return

        if isinstance(constraint, UniqueConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
            return

        if isinstance(constraint, ForeignKeyConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                # Single-column, default-named foreign keys use ForeignKey instead
                self.add_import(ForeignKey)
            return

        # Any other constraint type is imported under its own class
        self.add_import(constraint)
5✔
270

271
    def add_import(self, obj: Any) -> None:
        """
        Register an import for the class of ``obj`` (or for ``obj`` itself when
        it is a class). Objects whose ``__module__`` is ``builtins`` (or
        missing) are skipped.
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        cls = obj if isinstance(obj, type) else type(obj)
        pkgname = cls.__module__

        # The column types have already been adapted towards generic types if possible,
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
        # use that rather than the generic sqlalchemy type as it might have different
        # constructor parameters.
        if pkgname.startswith("sqlalchemy.dialects."):
            # Shorten to "sqlalchemy.dialects.<name>" when that package exports
            # the class
            dialect_pkgname = ".".join(pkgname.split(".")[:3])
            if cls.__name__ in import_module(dialect_pkgname).__all__:
                pkgname = dialect_pkgname
        elif cls is getattr(sqlalchemy, cls.__name__, None):
            # The class is exported by the top level sqlalchemy package
            pkgname = "sqlalchemy"

        self.add_literal_import(pkgname, cls.__name__)
5✔
295

296
    def add_literal_import(self, pkgname: str, name: str) -> None:
5✔
297
        names = self.imports.setdefault(pkgname, set())
5✔
298
        names.add(name)
5✔
299

300
    def remove_literal_import(self, pkgname: str, name: str) -> None:
5✔
301
        names = self.imports.setdefault(pkgname, set())
5✔
302
        if name in names:
5✔
UNCOV
303
            names.remove(name)
×
304

305
    def add_module_import(self, pgkname: str) -> None:
5✔
306
        self.module_imports.add(pgkname)
5✔
307

308
    def group_imports(self) -> list[list[str]]:
5✔
309
        future_imports: list[str] = []
5✔
310
        stdlib_imports: list[str] = []
5✔
311
        thirdparty_imports: list[str] = []
5✔
312

313
        def get_collection(package: str) -> list[str]:
5✔
314
            collection = thirdparty_imports
5✔
315
            if package == "__future__":
5✔
UNCOV
316
                collection = future_imports
×
317
            elif package in self.stdlib_module_names:
5✔
318
                collection = stdlib_imports
5✔
319
            elif package in sys.modules:
5✔
320
                if "site-packages" not in (sys.modules[package].__file__ or ""):
5✔
321
                    collection = stdlib_imports
5✔
322
            return collection
5✔
323

324
        for package in sorted(self.imports):
5✔
325
            imports = ", ".join(sorted(self.imports[package]))
5✔
326

327
            collection = get_collection(package)
5✔
328
            collection.append(f"from {package} import {imports}")
5✔
329

330
        for module in sorted(self.module_imports):
5✔
331
            collection = get_collection(module)
5✔
332
            collection.append(f"import {module}")
5✔
333

334
        return [
5✔
335
            group
336
            for group in (future_imports, stdlib_imports, thirdparty_imports)
337
            if group
338
        ]
339

340
    def generate_models(self) -> list[Model]:
        """
        Wrap every table (in ``sorted_tables`` order) in a :class:`Model`,
        collect the imports and give each model a name that clashes neither
        with an imported name nor with a previously named model.
        """
        models = [Model(table) for table in self.metadata.sorted_tables]

        # Collect the imports
        self.collect_imports(models)

        # Every imported name is already taken at module level
        global_names: set[str] = set()
        for imported_names in self.imports.values():
            global_names.update(imported_names)

        # Generate names for models
        for model in models:
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return models
5✔
355

356
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Name the model ``t_<table name>``, adjusted so that it is free to use."""
        model.name = self.find_free_name(f"t_{model.table.name}", global_names)
5✔
359

360
    def render_module_variables(self, models: list[Model]) -> str:
5✔
361
        declarations = self.base.declarations
5✔
362

363
        if any(not isinstance(model, ModelClass) for model in models):
5✔
364
            if self.base.table_metadata_declaration is not None:
5✔
UNCOV
365
                declarations.append(self.base.table_metadata_declaration)
×
366

367
        return "\n".join(declarations)
5✔
368

369
    def render_models(self, models: list[Model]) -> str:
5✔
370
        rendered: list[str] = []
5✔
371
        for model in models:
5✔
372
            rendered_table = self.render_table(model.table)
5✔
373
            rendered.append(f"{model.name} = {rendered_table}")
5✔
374

375
        return "\n\n".join(rendered)
5✔
376

377
    def render_table(self, table: Table) -> str:
        """
        Render a ``Table(...)`` call for the given table.

        Positional arguments are the table name and metadata reference, the
        columns, the constraints and the indexes that are not expressed on
        the columns themselves. Keyword arguments are ``schema`` and
        ``comment`` when set, plus info/dialect kwargs when opted in.
        """
        positional: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        keyword: dict[str, object] = {}

        positional.extend(
            self.render_column(column, True, is_table=True)
            for column in table.columns
        )

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                # Default-named primary keys, and default-named single-column
                # foreign key / unique constraints, are not rendered here
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue

                if (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                ):
                    continue

            positional.append(self.render_constraint(constraint))

        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                positional.append(self.render_index(index))

        if table.schema:
            keyword["schema"] = repr(table.schema)

        if getattr(table, "comment", None):
            keyword["comment"] = repr(table.comment)

        # add info + dialect kwargs for callable context (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(table, keyword, values_for_dict=False)

        return render_callable(
            "Table", *positional, kwargs=keyword, indentation="    "
        )
5✔
412

413
    def render_index(self, index: Index) -> str:
        """Render an ``Index(...)`` call: name, column names and ``unique=True`` if set."""
        column_names = [repr(col.name) for col in index.columns]
        kwargs = {"unique": True} if index.unique else {}
        return render_callable("Index", repr(index.name), *column_names, kwargs=kwargs)
5✔
420

421
    # TODO find better solution for is_table
422
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render a column definition call.

        :param column: the column to render
        :param show_name: whether the column name is the first argument
        :param is_table: passed on to :meth:`render_column_callable`, which
            selects ``Column`` or ``mapped_column`` as the callable
        :return: the rendered call

        .. note:: Sets ``column.unique`` / ``column.index`` to ``True`` when
            the column has a matching default-named single-column unique
            constraint or index.
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_part_of_composite_pk = (
            column.primary_key and len(column.table.primary_key) > 1
        )
        # Foreign keys that can be rendered inline on the column: single-column
        # constraints with a default name
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            kwargs["key"] = column.key
        if is_primary:
            kwargs["primary_key"] = True
        if not column.nullable and not column.primary_key:
            kwargs["nullable"] = False
        if column.nullable and is_part_of_composite_pk:
            kwargs["nullable"] = True

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(cast(TextClause, column.server_default.arg).text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        # add column info + dialect kwargs for callable context (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(column, kwargs, values_for_dict=False)

        return self.render_column_callable(is_table, *args, **kwargs)
5✔
521

522
    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render the column call with the given arguments: ``Column(...)`` (also
        registering the ``Column`` import) when ``is_table`` is true,
        ``mapped_column(...)`` otherwise.
        """
        if not is_table:
            return render_callable("mapped_column", *args, kwargs=kwargs)

        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)
5✔
528

529
    def render_column_type(self, coltype: TypeEngine[Any]) -> str:
        """
        Render a column type, either as a bare class name or as a constructor
        call.

        Constructor arguments are recovered by reading, for every parameter of
        the type's ``__init__``, the attribute of the same name from the type
        instance. Values equal to the parameter default are left out; once a
        parameter has been left out, the remaining ones are rendered as
        keyword arguments.
        """
        init_params = list(
            inspect.signature(coltype.__class__.__init__).parameters.values()
        )
        sentinel = object()
        positional = []
        keyword: dict[str, Any] = {}
        force_keywords = False

        # The first parameter (self) is skipped
        for param in init_params[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue

            if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                force_keywords = True
                continue

            current = getattr(coltype, param.name, sentinel)

            if isinstance(current, (JSONB, JSON)):
                # Remove astext_type if it's the default
                nested_astext = current.astext_type
                if isinstance(nested_astext, Text) and nested_astext.length is None:
                    current.astext_type = None  # type: ignore[assignment]
                else:
                    self.add_import(Text)

            if isinstance(current, TextClause):
                self.add_literal_import("sqlalchemy", "text")
                rendered = render_callable("text", repr(current.text))
            else:
                rendered = repr(current)

            if current is sentinel or current == param.default:
                # Skipped a parameter, so later ones can't be positional anymore
                force_keywords = True
            elif force_keywords:
                keyword[param.name] = rendered
            else:
                positional.append(rendered)

        # Values captured by a *args parameter are rendered positionally
        vararg = next(
            (
                param.name
                for param in init_params
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(coltype, vararg):
            positional.extend(repr(arg) for arg in getattr(coltype, vararg))

        # These arguments cannot be autodetected from the Enum initializer
        if isinstance(coltype, Enum):
            for attr in ("name", "schema"):
                attr_value = getattr(coltype, attr)
                if attr_value is not None:
                    keyword[attr] = repr(attr_value)

        if isinstance(coltype, (JSONB, JSON)):
            # Remove astext_type if it's the default
            astext = coltype.astext_type
            if isinstance(astext, Text) and astext.length is None:
                del keyword["astext_type"]

        type_name = coltype.__class__.__name__
        if positional or keyword:
            return render_callable(type_name, *positional, kwargs=keyword)

        return type_name
5✔
600

601
    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a constraint (or a single ``ForeignKey``) as a constructor call
        named after its class.

        :raises TypeError: if the constraint is of a type that is not handled
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}
        has_fk_options = False

        if isinstance(constraint, ForeignKey):
            target = constraint.column
            args.append(repr(f"{target.table.fullname}.{target.name}"))
            has_fk_options = True
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            args.append(repr(local_columns))
            args.append(repr(remote_columns))
            has_fk_options = True
        elif isinstance(constraint, CheckConstraint):
            args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        if has_fk_options:
            # Foreign key options are only rendered when they have a truthy value
            for attr in ("ondelete", "onupdate", "deferrable", "initially", "match"):
                option = getattr(constraint, attr, None)
                if option:
                    kwargs[attr] = repr(option)

        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)
5✔
636

637
    def _add_dialect_kwargs_and_info(
5✔
638
        self, obj: Any, target_kwargs: dict[str, object], *, values_for_dict: bool
639
    ) -> None:
640
        """
641
        Merge SchemaItem-like object's .info and .dialect_kwargs into target_kwargs.
642
        - values_for_dict=True: keep raw values so pretty-printer emits repr() (for __table_args__ dict)
643
        - values_for_dict=False: set values to repr() strings (for callable kwargs)
644
        """
645
        info_dict = getattr(obj, "info", None)
5✔
646
        if info_dict:
5✔
647
            target_kwargs["info"] = info_dict if values_for_dict else repr(info_dict)
5✔
648

649
        dialect_keys: list[str]
650
        try:
5✔
651
            dialect_keys = sorted(getattr(obj, "dialect_kwargs"))
5✔
NEW
652
        except Exception:
×
NEW
653
            return
×
654

655
        dialect_kwargs = getattr(obj, "dialect_kwargs", {})
5✔
656
        for key in dialect_keys:
5✔
657
            try:
5✔
658
                value = dialect_kwargs[key]
5✔
NEW
659
            except Exception:
×
NEW
660
                continue
×
661

662
            # Render values:
663
            # - callable context (values_for_dict=False): produce a string expression.
664
            #   primitives use repr(value); custom objects stringify then repr().
665
            # - dict context (values_for_dict=True): pass raw primitives / str;
666
            #   custom objects become str(value) so pformat quotes them.
667
            if values_for_dict:
5✔
668
                if isinstance(value, type(None) | bool | int | float):
5✔
NEW
669
                    target_kwargs[key] = value
×
670
                elif isinstance(value, str | dict | list):
5✔
671
                    target_kwargs[key] = value
5✔
672
                else:
673
                    target_kwargs[key] = str(value)
5✔
674
            else:
675
                if isinstance(
5✔
676
                    value, type(None) | bool | int | float | str | dict | list
677
                ):
678
                    target_kwargs[key] = repr(value)
5✔
679
                else:
680
                    target_kwargs[key] = repr(str(value))
5✔
681

682
    def should_ignore_table(self, table: Table) -> bool:
5✔
683
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
684
        # tables
685
        return table.name in ("alembic_version", "migrate_version")
5✔
686

687
    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Turn ``name`` into an identifier that is not yet present in either
        ``global_names`` or ``local_names``.

        Invalid characters become underscores, a leading digit gets an
        underscore prefix, and keywords as well as "metadata" get an
        underscore suffix. On a clash the candidates tried are the name
        followed by ``_``, then by ``1``, ``2`` and so on.
        """
        candidate = name.strip()
        assert candidate, "Identifier cannot be empty"
        candidate = _re_invalid_identifier.sub("_", candidate)
        if candidate[0].isdigit():
            candidate = "_" + candidate
        elif iskeyword(candidate) or candidate == "metadata":
            candidate += "_"

        stem = candidate
        attempts = count()
        while candidate in global_names or candidate in local_names:
            attempt = next(attempts)
            # The first fallback (attempt 0) is a plain underscore suffix
            candidate = stem + (str(attempt) if attempt else "_")

        return candidate
5✔
709

710
    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types of a table in place.

        Check constraints that merely encode a boolean or an enumeration are
        removed and the column type is changed instead. Column types are then
        adapted via :meth:`get_adapted_type` (unless dialect types are kept),
        and on PostgreSQL sequence-based server defaults are converted.
        """
        # Detect check constraints for boolean and enum columns (iterate over a
        # copy since constraints are removed on the way)
        for constraint in table.constraints.copy():
            if not isinstance(constraint, CheckConstraint):
                continue

            sqltext = get_compiled_expression(constraint.sqltext, self.bind)

            # Turn any integer-like column with a CheckConstraint like
            # "column IN (0, 1)" into a Boolean
            boolean_match = _re_boolean_check_constraint.match(sqltext)
            if boolean_match:
                column_match = _re_column_name.match(boolean_match.group(1))
                if column_match:
                    table.constraints.remove(constraint)
                    table.c[column_match.group(3)].type = Boolean()
                    continue

            # Turn any string-type column with a CheckConstraint like
            # "column IN (...)" into an Enum
            enum_match = _re_enum_check_constraint.match(sqltext)
            if not enum_match:
                continue

            column_match = _re_column_name.match(enum_match.group(1))
            if not column_match:
                continue

            target_column = table.c[column_match.group(3)]
            if not isinstance(target_column.type, String):
                continue

            table.constraints.remove(constraint)
            # Enum is itself a String type; an existing Enum is left alone
            if not isinstance(target_column.type, Enum):
                enum_values = _re_enum_item.findall(enum_match.group(2))
                target_column.type = Enum(*enum_values, native_enum=False)

        for column in table.c:
            if not self.keep_dialect_types:
                try:
                    column.type = self.get_adapted_type(column.type)
                except CompileError:
                    # Skips the rest of the processing for this column too
                    continue

            # PostgreSQL specific fix: detect sequences from server_default
            server_default = column.server_default
            if not server_default or self.bind.dialect.name != "postgresql":
                continue

            if not (
                isinstance(server_default, DefaultClause)
                and isinstance(server_default.arg, TextClause)
            ):
                continue

            schema, seqname = decode_postgresql_sequence(server_default.arg)
            if not seqname:
                continue

            # Add an explicit sequence
            if seqname != f"{column.table.name}_{column.name}_seq":
                column.default = sqlalchemy.Sequence(seqname, schema=schema)

            column.server_default = None
5✔
767

768
    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Return the most generic type that still compiles like ``coltype``.

        The classes in the MRO of ``coltype`` are tried in order. Each public
        class that has a ``__visit_name__`` is used as an adaptation target, and
        the adapted type is accepted only if it compiles to the same string as
        the original on the bound dialect. The walk ends at the first accepted
        class whose name is not all upper case, or at the first failure.

        :param coltype: the column type instance to adapt
        :return: the adapted type, or ``coltype`` itself if nothing fits
        """
        dialect = self.bind.engine.dialect
        expected_sql = coltype.compile(dialect)
        for candidate in coltype.__class__.__mro__:
            if candidate.__name__.startswith("_"):
                continue

            if not hasattr(candidate, "__visit_name__"):
                continue

            # Don't try to adapt UserDefinedType as it's not a proper column type
            if candidate is UserDefinedType or issubclass(candidate, TypeDecorator):
                return coltype

            overrides: dict[str, Any] = {}
            if candidate is Enum:
                # Hack to fix adaptation of the Enum class which is broken since
                # SQLAlchemy 1.2
                overrides["name"] = coltype.name
                if coltype.schema:
                    overrides["schema"] = coltype.schema
            elif candidate is DOMAIN:
                # Hack to fix Postgres DOMAIN type adaptation, broken as of
                # SQLAlchemy 2.0.42. For additional information, see
                # https://github.com/agronholm/sqlacodegen/issues/416#issuecomment-3417480599
                if coltype.default:
                    overrides["default"] = coltype.default
                if coltype.constraint_name is not None:
                    overrides["constraint_name"] = coltype.constraint_name
                if coltype.not_null:
                    overrides["not_null"] = coltype.not_null
                if coltype.check is not None:
                    overrides["check"] = coltype.check
                if coltype.create_type:
                    overrides["create_type"] = coltype.create_type

            try:
                adapted = coltype.adapt(candidate)
            except TypeError:
                # If the adaptation fails, don't try again
                break

            # Restore the attributes captured above onto the adapted type
            for attr_name, attr_value in overrides.items():
                setattr(adapted, attr_name, attr_value)

            if isinstance(coltype, ARRAY):
                adapted.item_type = self.get_adapted_type(adapted.item_type)

            try:
                renders_identically = adapted.compile(dialect) == expected_sql
            except CompileError:
                # If the adapted column type can't be compiled, don't substitute it
                break

            # If the adapted column type does not render the same as the
            # original, don't substitute it
            if not renders_identically:
                break

            # Stop on the first valid non-uppercase column type class
            coltype = adapted
            if candidate.__name__ != candidate.__name__.upper():
                break

        return coltype
827

828

829
class DeclarativeGenerator(TablesGenerator):
    """Generator that renders tables as declarative ORM classes where it can."""

    # Options accepted on top of those understood by TablesGenerator
    valid_options: ClassVar[set[str]] = {
        *TablesGenerator.valid_options,
        "use_inflect",
        "nojoined",
        "nobidi",
        "noidsuffix",
    }
836

837
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ) -> None:
        """
        Set up the generator.

        :param metadata: forwarded to the parent generator
        :param bind: forwarded to the parent generator
        :param options: forwarded to the parent generator
        :param indentation: forwarded to the parent generator
        :param base_class_name: name of the generated declarative base class
        """
        super().__init__(metadata, bind, options, indentation=indentation)
        self.base_class_name: str = base_class_name
        # Used for singularizing/pluralizing names with the "use_inflect" option
        self.inflect_engine = inflect.engine()
849

850
    def generate_base(self) -> None:
        """Set ``self.base`` to a ``DeclarativeBase`` subclass declaration."""
        base_name = self.base_class_name
        class_declaration = [
            f"class {base_name}(DeclarativeBase):",
            f"{self.indentation}pass",
        ]
        self.base = Base(
            literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")],
            declarations=class_declaration,
            metadata_ref=f"{base_name}.metadata",
        )
859

860
    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports needed by ``models``.

        On top of what the parent class collects, ``Mapped`` and
        ``mapped_column`` are imported if at least one model is a class.
        """
        super().collect_imports(models)
        has_model_class = any(isinstance(model, ModelClass) for model in models)
        if not has_model_class:
            return

        for orm_name in ("Mapped", "mapped_column"):
            self.add_literal_import("sqlalchemy.orm", orm_name)
865

866
    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports needed by a single model.

        Model classes that have relationships also need ``relationship``.
        """
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")
871

872
    def generate_models(self) -> list[Model]:
        """
        Build the models for all tables in the metadata.

        Tables with exactly two foreign key constraints where every column takes
        part in a foreign key are treated as association tables and become plain
        :class:`Model` instances, as do tables without a primary key. All other
        tables become :class:`ModelClass` instances with column attributes and
        relationships. Finally the imports are collected and the models (and
        their attributes) are given names that don't clash with the imports.

        :return: the generated models, in the order of the sorted tables
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't process
        # them normally. The buckets are keyed by the schema qualified name of the
        # table referred to by the first foreign key, so that same-named tables in
        # different schemas don't pick up each other's association tables.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                owning_table = fk_constraints[0].elements[0].column.table
                links[qualified_table_name(owning_table)].append(model)
                continue

            # Only form model classes for tables that have a primary key and are not
            # association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model,
                    models_by_table_name,
                    links[qualified_table_name(model.table)],
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            # generate_relationships() registers the child under
                            # the same conditions, so compare by identity to
                            # avoid listing it twice
                            if not any(
                                child is model for child in target.children
                            ):
                                target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())
949

950
    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """
        Add relationship attributes to ``source`` and to the classes it refers to.

        Foreign keys that cover exactly the primary key of ``source`` make it a
        child of the referred class (unless the ``nojoined`` option is set).
        Other foreign keys to model classes produce a one-to-one or many-to-one
        relationship, and the given association tables produce many-to-many
        relationships. Reverse relationships are added to the target class
        unless the ``nobidi`` option is set.

        :param source: the model class the relationships start from
        :param models_by_table_name: all models, keyed by qualified table name
        :param association_tables: association tables linked to ``source``
        :return: a list that this method never adds anything to
        """
        # NOTE(review): nothing is ever appended to this list, so the return
        # value is always empty
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        sorted_fk_constraints = sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        )
        for constraint in sorted_fk_constraints:
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            # A foreign key spanning the whole primary key means joined table
            # inheritance rather than a relationship
            if (
                "nojoined" not in self.options
                and set(get_column_names(constraint)) == pk_column_names
            ):
                source.parent_class = target
                target.children.append(source)
                continue

            # Add uselist=False to One-to-One relationships
            column_names = get_column_names(constraint)
            is_unique = any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == set(column_names)
                for c in constraint.table.constraints
            )
            r_type = (
                RelationshipType.ONE_TO_ONE
                if is_unique
                else RelationshipType.MANY_TO_ONE
            )

            forward = RelationshipAttribute(r_type, source, target, constraint)
            source.relationships.append(forward)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                forward.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # If the two tables share more than one foreign key constraint,
            # SQLAlchemy needs an explicit primaryjoin to figure out which column(s)
            # it needs
            shared_constraints = get_common_fk_constraints(source.table, target.table)
            if len(shared_constraints) > 1:
                forward.foreign_keys = [
                    source.get_column_attribute(key) for key in constraint.column_keys
                ]

            if "nobidi" in self.options:
                continue

            # Generate the opposite end of the relationship in the target class
            if r_type is RelationshipType.MANY_TO_ONE:
                r_type = RelationshipType.ONE_TO_MANY

            reverse_relationship = RelationshipAttribute(
                r_type,
                target,
                source,
                constraint,
                foreign_keys=forward.foreign_keys,
                backref=forward,
            )
            forward.backref = reverse_relationship
            target.relationships.append(reverse_relationship)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                reverse_relationship.remote_side = [
                    source.get_column_attribute(colname)
                    for colname in constraint.column_keys
                ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            forward = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            source.relationships.append(forward)

            # Generate the opposite end of the relationship in the target class
            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    forward,
                )
                forward.backref = reverse_relationship
                target.relationships.append(reverse_relationship)

            if source is not target:
                continue

            # Add a primary/secondary join for self-referential many-to-many
            # relationships
            both_ends = [forward]
            if reverse_relationship:
                both_ends.append(reverse_relationship)

            # The reverse end uses the two constraints in the opposite order
            for end, flipped in zip(both_ends, [False, True]):
                if not end.association_table or not end.constraint:
                    continue

                constraints = sorted(
                    end.constraint.table.foreign_key_constraints,
                    key=get_constraint_sort_key,
                    reverse=flipped,
                )
                primary, secondary = constraints[0], constraints[1]
                end.primaryjoin = [
                    (end.source, elem.column.name, end.association_table, col)
                    for col, elem in zip(get_column_names(primary), primary.elements)
                ]
                end.secondaryjoin = [
                    (end.target, elem.column.name, end.association_table, col)
                    for col, elem in zip(
                        get_column_names(secondary), secondary.elements
                    )
                ]

        return relationships
1113

1114
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """
        Assign a name to ``model`` and, for model classes, to its attributes.

        Class names are derived from the table name by capitalizing the first
        letter of each underscore separated part, optionally singularized with
        the ``use_inflect`` option. Plain table models are named by the parent
        class.

        :param model: the model to name
        :param global_names: module level names that are already taken
        """
        if not isinstance(model, ModelClass):
            super().generate_model_name(model, global_names)
            return

        words = _re_invalid_identifier.sub("_", model.table.name).split("_")
        preferred_name = "".join(word[:1].upper() + word[1:] for word in words)
        if "use_inflect" in self.options:
            # singular_noun() returns a falsy value if the word is not a plural
            singular_name = self.inflect_engine.singular_noun(preferred_name)
            if singular_name:
                preferred_name = singular_name

        model.name = self.find_free_name(preferred_name, global_names)

        # Column attributes are named first, then relationships; each assigned
        # name is reserved within the class before the next one is chosen
        local_names: set[str] = set()
        for column_attr in model.columns:
            self.generate_column_attr_name(column_attr, global_names, local_names)
            local_names.add(column_attr.name)

        for relationship in model.relationships:
            self.generate_relationship_name(relationship, global_names, local_names)
            local_names.add(relationship.name)
1139

1140
    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Name a column attribute after its column, avoiding taken names.

        :param column_attr: the attribute to name
        :param global_names: module level names that are already taken
        :param local_names: names already taken within the class
        """
        preferred_name = column_attr.column.name
        column_attr.name = self.find_free_name(
            preferred_name, global_names, local_names
        )
1149

1150
    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Pick a free attribute name for ``relationship``.

        The reverse side of a self referential relationship is named after its
        already named counterpart with a ``_reverse`` suffix. Otherwise the name
        comes from the target table, or from a lone ``*_id`` foreign key column
        with the suffix removed (unless ``noidsuffix`` is set), and is optionally
        pluralized/singularized with the ``use_inflect`` option.

        :param relationship: the relationship to name
        :param global_names: module level names that are already taken
        :param local_names: names already taken within the class
        """
        preferred_name: str
        is_named_self_reverse = (
            relationship.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        )
        if is_named_self_reverse:
            # Self referential reverse relationships
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id", use the
            # preceding part as the relationship name
            constraint = relationship.constraint
            if constraint and "noidsuffix" not in self.options:
                owns_constraint = relationship.source.table is constraint.table
                if owns_constraint or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    fk_column_names = [c.name for c in constraint.columns]
                    if len(fk_column_names) == 1 and fk_column_names[0].endswith(
                        "_id"
                    ):
                        preferred_name = fk_column_names[0][:-3]

            if "use_inflect" in self.options:
                inflector = self.inflect_engine
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    # Collections get a plural name unless it already is one
                    if not inflector.singular_noun(preferred_name):
                        preferred_name = inflector.plural_noun(preferred_name)
                else:
                    inflected_name: str | Literal[False]
                    inflected_name = inflector.singular_noun(preferred_name)
                    if inflected_name:
                        preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )
1197

1198
    def render_models(self, models: list[Model]) -> str:
        """
        Render all models, separated from each other by two blank lines.

        Model classes are rendered as classes and everything else as a module
        level table assignment.
        """
        rendered = [
            self.render_class(model)
            if isinstance(model, ModelClass)
            else f"{model.name} = {self.render_table(model.table)}"
            for model in models
        ]
        return "\n\n\n".join(rendered)
1207

1208
    def render_class(self, model: ModelClass) -> str:
        """
        Render the complete class definition for ``model``.

        The body consists of up to three sections separated by a blank line:
        class variables, column attributes (non-nullable columns first) and
        relationship attributes.
        """
        # Render class variables / special declarations
        class_vars: str = self.render_class_variables(model)

        # Render column attributes, non-nullable ones before nullable ones
        column_lines: list[str] = [
            self.render_column_attribute(column_attr)
            for wanted_nullability in (False, True)
            for column_attr in model.columns
            if column_attr.column.nullable is wanted_nullability
        ]

        # Render relationship attributes
        relationship_lines: list[str] = [
            self.render_relationship(relationship)
            for relationship in model.relationships
        ]

        candidate_sections = (
            class_vars,
            "\n".join(column_lines),
            "\n".join(relationship_lines),
        )
        # Sections that came out empty are left out altogether
        sections: list[str] = [section for section in candidate_sections if section]

        declaration = self.render_class_declaration(model)
        body = "\n\n".join(indent(section, self.indentation) for section in sections)
        return f"{declaration}\n{body}"
1242

1243
    def render_class_declaration(self, model: ModelClass) -> str:
5✔
1244
        parent_class_name = (
5✔
1245
            model.parent_class.name if model.parent_class else self.base_class_name
1246
        )
1247
        return f"class {model.name}({parent_class_name}):"
5✔
1248

1249
    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render the special class variables of ``model``, one per line.

        ``__tablename__`` is always present; ``__table_args__`` only when there
        is something to put in it.
        """
        lines = [f"__tablename__ = {model.table.name!r}"]

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            lines += [f"__table_args__ = {table_args}"]

        return "\n".join(lines)
1258

1259
    def render_table_args(self, table: Table) -> str:
        """
        Render the value of ``__table_args__`` for ``table``.

        :return: a dict literal when there are only keyword arguments, a tuple
            literal when there are constraints or indexes (with the keyword
            arguments as its last item), or an empty string when there is
            nothing to render
        """
        args: list[str] = []
        kwargs: dict[str, object] = {}

        # Render constraints
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            # Default-named primary keys and default-named single column
            # foreign key / unique constraints are not rendered here
            skipped = uses_default_name(constraint) and (
                isinstance(constraint, PrimaryKeyConstraint)
                or (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                )
            )
            if not skipped:
                args.append(self.render_constraint(constraint))

        # Render indexes
        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = table.schema

        if table.comment:
            kwargs["comment"] = table.comment

        # add info + dialect kwargs for dict context (__table_args__) (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(table, kwargs, values_for_dict=True)

        if kwargs:
            formatted_kwargs = pformat(kwargs)
            if not args:
                # Keyword arguments alone are rendered as a bare dict
                return formatted_kwargs

            args.append(formatted_kwargs)

        if not args:
            return ""

        rendered_args = f",\n{self.indentation}".join(args)
        # A single item needs a trailing comma to make the result a tuple
        trailing_comma = "," if len(args) == 1 else ""
        return f"(\n{self.indentation}{rendered_args}{trailing_comma}\n)"
1306

1307
    def render_column_python_type(self, column: Column[Any]) -> str:
        """
        Render the Python type annotation for the values of ``column``.

        Nullable columns are wrapped in ``Optional[...]`` and array columns in
        one ``list[...]`` per dimension. Types that don't implement
        ``python_type`` are rendered as ``Any``. Any imports needed by the
        annotation are registered as a side effect.
        """
        wrappers: list[str] = []
        column_type: TypeEngine[Any] = column.type

        if column.nullable:
            self.add_literal_import("typing", "Optional")
            wrappers.append("Optional[")

        if isinstance(column_type, ARRAY):
            # An array without explicit dimensions counts as one dimensional
            dimensions = getattr(column_type, "dimensions", None) or 1
            wrappers.extend(["list["] * dimensions)
            column_type = column_type.item_type

        # For a domain, the underlying data type determines the Python type
        if isinstance(column_type, DOMAIN):
            column_type = column_type.data_type

        try:
            python_type = column_type.python_type
            type_module = python_type.__module__
            type_name = python_type.__name__
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
            rendered_type = "Any"
        else:
            if type_module == "builtins":
                rendered_type = type_name
            else:
                self.add_module_import(type_module)
                rendered_type = f"{type_module}.{type_name}"

        # Every wrapper opened above needs one closing bracket
        return "".join(wrappers) + rendered_type + "]" * len(wrappers)
1347

1348
    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column attribute as an annotated ``Mapped[...]`` assignment.

        The column name is passed on to be rendered explicitly only when it
        differs from the attribute name.
        """
        column = column_attr.column
        needs_explicit_name = column_attr.name != column.name
        rendered_column = self.render_column(column, needs_explicit_name)
        python_type = self.render_column_python_type(column)
        return f"{column_attr.name}: Mapped[{python_type}] = {rendered_column}"
1354

1355
    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship as an annotated ``relationship()`` assignment.

        The keyword arguments (``uselist``, ``secondary``, ``remote_side``,
        ``foreign_keys``, ``primaryjoin``, ``secondaryjoin`` and
        ``back_populates``) are only rendered when the relationship calls for
        them. Imports needed by the rendered code are registered as a side
        effect.

        :param relationship: the relationship to render
        :return: the rendered attribute, in the form
            ``name: Mapped[type] = relationship(...)``
        """

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Attributes of other classes are referred to by a quoted, class
            # qualified name
            rendered = []
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(repr(f"{attr.model.name}.{attr.name}"))

            return "[" + ", ".join(rendered) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            rendered = []
            render_as_string = False
            # Assume that column_attrs are all in relationship.source or none
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            # References to other classes make the whole list a quoted string
            if render_as_string:
                return "'[" + ", ".join(rendered) + "]'"
            else:
                return "[" + ", ".join(rendered) + "]"

        def render_join(terms: list[JoinType]) -> str:
            conditions = []
            for source, source_col, target, target_col in terms:
                condition = f"{source.name}.{source_col} == {target.name}."
                # Columns of plain tables are accessed through the "c" collection
                if target.__class__ is Model:
                    condition += "c."

                condition += str(target_col)
                conditions.append(condition)

            if len(conditions) > 1:
                # Combine the conditions inside a single lambda; and_() has to
                # be imported in the generated module for that to work
                self.add_literal_import("sqlalchemy", "and_")
                return f"lambda: and_({', '.join(conditions)})"
            else:
                return f"lambda: {conditions[0]}"

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
            if relationship.constraint.referred_table is relationship.source.table:
                kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        relationship_type: str
        if relationship.type == RelationshipType.ONE_TO_MANY:
            relationship_type = f"list['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            relationship_type = f"'{relationship.target.name}'"
            # The target is optional if any of the foreign key columns is nullable
            if relationship.constraint and any(
                col.nullable for col in relationship.constraint.columns
            ):
                self.add_literal_import("typing", "Optional")
                relationship_type = f"Optional[{relationship_type}]"
        elif relationship.type == RelationshipType.MANY_TO_MANY:
            relationship_type = f"list['{relationship.target.name}']"
        else:
            self.add_literal_import("typing", "Any")
            relationship_type = "Any"

        return (
            f"{relationship.name}: Mapped[{relationship_type}] "
            f"= {rendered_relationship}"
        )
1454

1455

1456
class DataclassGenerator(DeclarativeGenerator):
    """
    Declarative generator whose generated base class also derives from
    ``MappedAsDataclass``.

    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """
        Initialize the generator.

        :param metadata: forwarded unchanged to the parent generator
        :param bind: forwarded unchanged to the parent generator
        :param options: forwarded unchanged to the parent generator
        :param indentation: forwarded to the parent generator; also used to indent
            the body of the generated base class
        :param base_class_name: name of the generated base class
        :param quote_annotations: stored on the instance as ``quote_annotations``
        :param metadata_key: stored on the instance as ``metadata_key``

        """
        parent_options = {
            "indentation": indentation,
            "base_class_name": base_class_name,
        }
        super().__init__(metadata, bind, options, **parent_options)

        # Neither value is read anywhere else in this class; they are only kept
        # as instance attributes.
        self.quote_annotations: bool = quote_annotations
        self.metadata_key: str = metadata_key

    def generate_base(self) -> None:
        """
        Set ``self.base`` to a base class definition that inherits from both
        ``MappedAsDataclass`` and ``DeclarativeBase``.

        """
        base_name = self.base_class_name
        required_imports = [
            LiteralImport("sqlalchemy.orm", class_name)
            for class_name in ("DeclarativeBase", "MappedAsDataclass")
        ]
        class_header = f"class {base_name}(MappedAsDataclass, DeclarativeBase):"
        class_body = f"{self.indentation}pass"

        self.base = Base(
            literal_imports=required_imports,
            declarations=[class_header, class_body],
            metadata_ref=f"{base_name}.metadata",
        )
1490

1491

1492
class SQLModelGenerator(DeclarativeGenerator):
    """
    Generator that renders model classes as SQLModel table classes.

    Classes are declared as ``class Name(<parent>, table=True)``, column attributes
    are rendered as ``Field(sa_column=Column(...))`` and relationships as
    ``Relationship(...)``.

    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        """
        Initialize the generator.

        All arguments are forwarded unchanged to the parent generator; the only
        difference is that ``base_class_name`` defaults to ``"SQLModel"``.

        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    @property
    def views_supported(self) -> bool:
        """Always ``False`` for this generator."""
        return False

    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render a ``Column(...)`` call and register the ``Column`` import.

        :param is_table: accepted for signature compatibility; not used here
        :param args: positional arguments of the rendered call
        :param kwargs: keyword arguments of the rendered call
        :return: the rendered call expression

        """
        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)

    def generate_base(self) -> None:
        """Set ``self.base`` to an empty base definition (no imports, no code)."""
        self.base = Base(
            literal_imports=[],
            declarations=[],
            metadata_ref="",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports needed by ``models``.

        When at least one model is a :class:`ModelClass`, the literal ``MetaData``
        import is removed and ``SQLModel`` and ``Field`` are imported from
        ``sqlmodel`` instead.

        :param models: the models to scan; may be any iterable, including a
            one-shot iterator

        """
        # ``models`` is consumed twice below, so materialize it first. Otherwise a
        # generator would already be exhausted when the ``any()`` check runs and
        # the SQLModel imports would be silently skipped.
        model_list = list(models)

        # Two-argument ``super()`` deliberately starts the lookup *after*
        # DeclarativeGenerator in the MRO, bypassing its override.
        super(DeclarativeGenerator, self).collect_imports(model_list)
        if any(isinstance(model, ModelClass) for model in model_list):
            self.remove_literal_import("sqlalchemy", "MetaData")
            self.add_literal_import("sqlmodel", "SQLModel")
            self.add_literal_import("sqlmodel", "Field")

    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports needed by a single model.

        For model classes, ``typing.Optional`` is imported if any column is
        nullable, and ``sqlmodel.Relationship`` if the model has relationships.

        :param model: the model to scan

        """
        # Same MRO bypass as in collect_imports()
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
        if not isinstance(model, ModelClass):
            return

        if any(column_attr.column.nullable for column_attr in model.columns):
            self.add_literal_import("typing", "Optional")

        if model.relationships:
            self.add_literal_import("sqlmodel", "Relationship")

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module level variable declarations.

        The base's table metadata declaration is emitted only if it is set and at
        least one model is not a :class:`ModelClass`.

        :param models: all models being rendered
        :return: the declarations joined by newlines (possibly an empty string)

        """
        declarations: list[str] = []
        if (
            any(not isinstance(model, ModelClass) for model in models)
            and self.base.table_metadata_declaration is not None
        ):
            declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_class_declaration(self, model: ModelClass) -> str:
        """
        Render the ``class`` statement line for a model class.

        :param model: the model class to render
        :return: ``class <name>(<parent>, table=True):`` where the parent is the
            model's parent class name if it has one, otherwise the base class name

        """
        if model.parent_class:
            parent = model.parent_class.name
        else:
            parent = self.base_class_name

        return f"class {model.name}({parent}, table=True):"

    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render ``__tablename__`` and ``__table_args__`` for a model class.

        ``__tablename__`` is emitted only when the table name differs from the
        lowercased class name; ``__table_args__`` only when there is something to
        render.

        :param model: the model class to render
        :return: the assignments joined by newlines (possibly an empty string)

        """
        variables: list[str] = []

        if model.table.name != model.name.lower():
            variables.append(f"__tablename__ = {model.table.name!r}")

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column attribute as ``<name>: <type> = Field(...)``.

        The rendered column is passed as ``sa_column``; nullable columns
        additionally get ``default=None`` (rendered before ``sa_column``).

        :param column_attr: the column attribute to render
        :return: the rendered attribute line

        """
        column = column_attr.column
        rendered_column = self.render_column(column, True)
        rendered_column_python_type = self.render_column_python_type(column)

        kwargs: dict[str, Any] = {}
        if column.nullable:
            kwargs["default"] = None
        kwargs["sa_column"] = rendered_column

        rendered_field = render_callable("Field", kwargs=kwargs)

        return f"{column_attr.name}: {rendered_column_python_type} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship as ``<name>: <annotation> = Relationship(...)``.

        One-to-many and many-to-many relationships are annotated as
        ``list[<target>]``; every other type as ``Optional[<target>]``, where
        ``<target>`` is the quoted target class name.

        :param relationship: the relationship to render
        :return: the rendered attribute line

        """
        # Keep only what follows the first " = " of the parent's rendering, i.e.
        # the call expression without the annotated attribute name.
        rendered = super().render_relationship(relationship).partition(" = ")[2]
        args = self.render_relationship_args(rendered)
        annotation = repr(relationship.target.name)

        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            annotation = f"list[{annotation}]"
        else:
            self.add_literal_import("typing", "Optional")
            annotation = f"Optional[{annotation}]"

        # Everything is passed positionally as pre-rendered "key=value" strings,
        # so there are no separate keyword arguments to render.
        rendered_field = render_callable("Relationship", *args, kwargs={})
        return f"{relationship.name}: {annotation} = {rendered_field}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """
        Translate a rendered ``relationship(...)`` call into ``Relationship()``
        arguments.

        Only two arguments are carried over, in their original order:
        ``back_populates=...`` is kept verbatim and ``uselist=False`` becomes
        ``sa_relationship_kwargs={'uselist': False}``. Everything else is dropped.

        :param arguments: the rendered call expression, e.g.
            ``relationship('Target', back_populates='items')``
        :return: the rendered arguments for ``Relationship()``

        """
        call_text = arguments.strip()
        # Drop the closing parenthesis of the call
        if call_text.endswith(")"):
            call_text = call_text[:-1]

        rendered_args: list[str] = []
        for fragment in call_text.split(","):
            arg = fragment.strip()
            # Match on the keyword name rather than on a substring: the naive
            # comma split also yields pieces of other arguments' values (e.g. a
            # column called "back_populates_id" inside foreign_keys), and those
            # must not leak into the Relationship() call.
            if arg.startswith("back_populates="):
                rendered_args.append(arg)
            elif arg == "uselist=False":
                rendered_args.append("sa_relationship_kwargs={'uselist': False}")

        return rendered_args
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc