• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 20554364289

28 Dec 2025 01:20PM UTC coverage: 96.088% (-1.3%) from 97.36%
20554364289

Pull #446

github

web-flow
Merge 5b02d70f9 into 90831a745
Pull Request #446: Support native python enum generation

89 of 112 new or added lines in 4 files covered. (79.46%)

1 existing line in 1 file now uncovered.

1572 of 1636 relevant lines covered (96.09%)

4.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.39
/src/sqlacodegen/generators.py
1
from __future__ import annotations
5✔
2

3
import inspect
5✔
4
import re
5✔
5
import sys
5✔
6
from abc import ABCMeta, abstractmethod
5✔
7
from collections import defaultdict
5✔
8
from collections.abc import Collection, Iterable, Sequence
5✔
9
from dataclasses import dataclass
5✔
10
from importlib import import_module
5✔
11
from inspect import Parameter
5✔
12
from itertools import count
5✔
13
from keyword import iskeyword
5✔
14
from pprint import pformat
5✔
15
from textwrap import indent
5✔
16
from typing import Any, ClassVar, Literal, cast
5✔
17

18
import inflect
5✔
19
import sqlalchemy
5✔
20
from sqlalchemy import (
5✔
21
    ARRAY,
22
    Boolean,
23
    CheckConstraint,
24
    Column,
25
    Computed,
26
    Constraint,
27
    DefaultClause,
28
    Enum,
29
    ForeignKey,
30
    ForeignKeyConstraint,
31
    Identity,
32
    Index,
33
    MetaData,
34
    PrimaryKeyConstraint,
35
    Table,
36
    Text,
37
    TypeDecorator,
38
    UniqueConstraint,
39
)
40
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
5✔
41
from sqlalchemy.engine import Connection, Engine
5✔
42
from sqlalchemy.exc import CompileError
5✔
43
from sqlalchemy.sql.elements import TextClause
5✔
44
from sqlalchemy.sql.type_api import UserDefinedType
5✔
45
from sqlalchemy.types import TypeEngine
5✔
46

47
from .models import (
5✔
48
    ColumnAttribute,
49
    JoinType,
50
    Model,
51
    ModelClass,
52
    RelationshipAttribute,
53
    RelationshipType,
54
)
55
from .utils import (
5✔
56
    decode_postgresql_sequence,
57
    get_column_names,
58
    get_common_fk_constraints,
59
    get_compiled_expression,
60
    get_constraint_sort_key,
61
    get_stdlib_module_names,
62
    qualified_table_name,
63
    render_callable,
64
    uses_default_name,
65
)
66

67
# Matches text of the form "[<qualifier>.]<name> IN (0, 1)" and captures <name>;
# presumably used to recognize CHECK constraints that emulate booleans
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
# Matches an optionally qualified, optionally quoted (" or `) column reference;
# group 3 holds the bare column name
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
# Matches any single character that is not a Unicode word character
_re_invalid_identifier = re.compile(r"(?u)\W")
70

71

72
@dataclass
class LiteralImport:
    """A single name to be imported as ``from <pkgname> import <name>``."""

    # Package or module the name is imported from
    pkgname: str
    # The name imported from that package
    name: str
76

77

78
@dataclass
class Base:
    """Representation of MetaData for Tables, respectively Base for classes"""

    # Names that have to be imported for the declarations below
    literal_imports: list[LiteralImport]
    # Module-level source lines rendered ahead of the models
    declarations: list[str]
    # Expression used to refer to the MetaData object in rendered tables
    metadata_ref: str
    # NOTE(review): not used in this part of the file; presumably a decorator
    # rendered on generated classes -- confirm against the subclasses
    decorator: str | None = None
    # Extra declaration added to the module variables when at least one model is
    # not a ModelClass
    table_metadata_declaration: str | None = None
87

88

89
class CodeGenerator(metaclass=ABCMeta):
    """
    Abstract base class for the code generators.

    Subclasses list the option names they understand in ``valid_options``; any
    other option passed to the constructor is rejected.
    """

    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        :param metadata: the metadata to generate code for
        :param bind: the connection or engine made available to the generator
        :param options: names of the generator options to enable
        :raises ValueError: if any of the options is not in ``valid_options``
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Validate options
        invalid_options = {opt for opt in options if opt not in self.valid_options}
        if invalid_options:
            # Sorted, so that the message does not depend on set iteration order
            raise ValueError(
                "Unrecognized options: " + ", ".join(sorted(invalid_options))
            )

    @property
    @abstractmethod
    def views_supported(self) -> bool:
        """Return ``True`` if the generator supports views."""

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
115

116

117
@dataclass(eq=False)
5✔
118
class TablesGenerator(CodeGenerator):
5✔
119
    valid_options: ClassVar[set[str]] = {
5✔
120
        "noindexes",
121
        "noconstraints",
122
        "nocomments",
123
        "noenums",
124
        "include_dialect_options",
125
        "keep_dialect_types",
126
    }
127
    stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
5✔
128

129
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        Validate the options through the base class and set up the generator state.

        :param metadata: the metadata to generate code for
        :param bind: the connection or engine made available to the generator
        :param options: names of the generator options to enable
        :param indentation: the string stored as ``self.indentation``
        """
        super().__init__(metadata, bind, options)
        self.indentation: str = indentation

        # Import bookkeeping: package -> names for "from X import ...", plus the
        # modules that are imported as a whole
        self.imports: dict[str, set[str]] = defaultdict(set)
        self.module_imports: set[str] = set()

        enabled_options = self.options
        # Whether SchemaItem.info and dialect kwargs (Table/Column) are rendered
        self.include_dialect_options_and_info: bool = (
            "include_dialect_options" in enabled_options
        )
        # Whether dialect-specific types are kept rather than adapted to the
        # generic SQLAlchemy types
        self.keep_dialect_types: bool = "keep_dialect_types" in enabled_options

        # Python enum bookkeeping:
        #   (table_name, column_name) -> enum class name
        self.enum_classes: dict[tuple[str, str], str] = {}
        #   enum class name -> list of values
        self.enum_values: dict[str, list[str]] = {}
153

154
    @property
    def views_supported(self) -> bool:
        """Return ``True``: this generator reports views as supported."""
        return True
157

158
    def generate_base(self) -> None:
        """Set ``self.base`` to a module-level ``metadata = MetaData()`` declaration."""
        metadata_import = LiteralImport("sqlalchemy", "MetaData")
        self.base = Base(
            metadata_ref="metadata",
            declarations=["metadata = MetaData()"],
            literal_imports=[metadata_import],
        )
164

165
    def generate(self) -> str:
        """
        Render the complete module source for ``self.metadata``.

        The metadata is modified in place: ignored tables are removed, and
        indexes, constraints and comments are dropped when the matching option
        is set.

        :return: the rendered sections joined by blank lines, ending in a newline
        """
        self.generate_base()

        drop_indexes = "noindexes" in self.options
        drop_constraints = "noconstraints" in self.options
        drop_comments = "nocomments" in self.options

        # Prune the metadata; iterate over a snapshot since tables get removed
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if drop_indexes:
                table.indexes.clear()

            if drop_constraints:
                table.constraints.clear()

            if drop_comments:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Column constraints carry hints about the intended column types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        models: list[Model] = self.generate_models()

        # Rendering registers further imports (e.g. "enum", Column), so the
        # import section is built last even though it is emitted first
        sections: list[str] = []

        variables = self.render_module_variables(models)
        if variables:
            sections.append(variables + "\n")

        enum_classes = self.render_enum_classes()
        if enum_classes:
            sections.append(enum_classes + "\n")

        rendered_models = self.render_models(models)
        if rendered_models:
            sections.append(rendered_models)

        imports = "\n\n".join("\n".join(group) for group in self.group_imports())
        if imports:
            sections.insert(0, imports)

        return "\n\n".join(sections) + "\n"
217

218
    def collect_imports(self, models: Iterable[Model]) -> None:
5✔
219
        for literal_import in self.base.literal_imports:
5✔
220
            self.add_literal_import(literal_import.pkgname, literal_import.name)
5✔
221

222
        for model in models:
5✔
223
            self.collect_imports_for_model(model)
5✔
224

225
    def collect_imports_for_model(self, model: Model) -> None:
        """Register the imports for a model's table, columns, constraints and indexes."""
        # Exact class check: subclasses of Model do not trigger the Table import
        is_plain_model = model.__class__ is Model
        if is_plain_model:
            self.add_import(Table)

        table = model.table
        for column in table.c:
            self.collect_imports_for_column(column)

        for constraint in table.constraints:
            self.collect_imports_for_constraint(constraint)

        # Indexes go through the same routine as constraints
        for index in table.indexes:
            self.collect_imports_for_constraint(index)
237

238
    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the imports for a column's type, default and server default."""
        coltype = column.type
        self.add_import(coltype)

        # Types that wrap another type need that one imported as well
        if isinstance(coltype, ARRAY):
            self.add_import(coltype.item_type.__class__)
        elif isinstance(coltype, (JSONB, JSON)):
            astext_type = coltype.astext_type
            # An unbounded Text is skipped, anything else is imported
            is_unbounded_text = (
                isinstance(astext_type, Text) and astext_type.length is None
            )
            if not is_unbounded_text:
                self.add_import(astext_type)
        elif isinstance(coltype, DOMAIN):
            self.add_import(coltype.data_type.__class__)

        if column.default:
            self.add_import(column.default)

        server_default = column.server_default
        if server_default:
            if isinstance(server_default, (Computed, Identity)):
                self.add_import(server_default)
            elif isinstance(server_default, DefaultClause):
                # Rendered as text("..."), see render_column()
                self.add_literal_import("sqlalchemy", "text")
260

261
    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Register the import needed to render a constraint or an index.

        Nothing is registered for an index, primary key or unique constraint that
        has a default name and (except for primary keys) a single column. A
        foreign key constraint in that situation needs ``ForeignKey`` instead of
        ``ForeignKeyConstraint``.
        """

        def is_rendered_standalone() -> bool:
            return len(constraint.columns) > 1 or not uses_default_name(constraint)

        if isinstance(constraint, Index):
            if is_rendered_standalone():
                self.add_literal_import("sqlalchemy", "Index")
        elif isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
        elif isinstance(constraint, UniqueConstraint):
            if is_rendered_standalone():
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
        elif isinstance(constraint, ForeignKeyConstraint):
            if is_rendered_standalone():
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
        else:
            self.add_import(constraint)
278

279
    def add_import(self, obj: Any) -> None:
5✔
280
        # Don't store builtin imports
281
        if getattr(obj, "__module__", "builtins") == "builtins":
5✔
282
            return
×
283

284
        type_ = type(obj) if not isinstance(obj, type) else obj
5✔
285
        pkgname = type_.__module__
5✔
286

287
        # The column types have already been adapted towards generic types if possible,
288
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
289
        # use that rather than the generic sqlalchemy type as it might have different
290
        # constructor parameters.
291
        if pkgname.startswith("sqlalchemy.dialects."):
5✔
292
            dialect_pkgname = ".".join(pkgname.split(".")[0:3])
5✔
293
            dialect_pkg = import_module(dialect_pkgname)
5✔
294

295
            if type_.__name__ in dialect_pkg.__all__:
5✔
296
                pkgname = dialect_pkgname
5✔
297
        elif type_ is getattr(sqlalchemy, type_.__name__, None):
5✔
298
            pkgname = "sqlalchemy"
5✔
299
        else:
300
            pkgname = type_.__module__
5✔
301

302
        self.add_literal_import(pkgname, type_.__name__)
5✔
303

304
    def add_literal_import(self, pkgname: str, name: str) -> None:
5✔
305
        names = self.imports.setdefault(pkgname, set())
5✔
306
        names.add(name)
5✔
307

308
    def remove_literal_import(self, pkgname: str, name: str) -> None:
5✔
309
        names = self.imports.setdefault(pkgname, set())
5✔
310
        if name in names:
5✔
311
            names.remove(name)
×
312

313
    def add_module_import(self, pgkname: str) -> None:
        """Record that the module ``pgkname`` is imported as a whole (``import X``)."""
        self.module_imports.add(pgkname)
315

316
    def group_imports(self) -> list[list[str]]:
5✔
317
        future_imports: list[str] = []
5✔
318
        stdlib_imports: list[str] = []
5✔
319
        thirdparty_imports: list[str] = []
5✔
320

321
        def get_collection(package: str) -> list[str]:
5✔
322
            collection = thirdparty_imports
5✔
323
            if package == "__future__":
5✔
324
                collection = future_imports
×
325
            elif package in self.stdlib_module_names:
5✔
326
                collection = stdlib_imports
5✔
327
            elif package in sys.modules:
5✔
328
                if "site-packages" not in (sys.modules[package].__file__ or ""):
5✔
329
                    collection = stdlib_imports
5✔
330
            return collection
5✔
331

332
        for package in sorted(self.imports):
5✔
333
            imports_list = sorted(self.imports[package])
5✔
334
            imports = ", ".join(imports_list)
5✔
335

336
            collection = get_collection(package)
5✔
337
            collection.append(f"from {package} import {imports}")
5✔
338

339
        for module in sorted(self.module_imports):
5✔
340
            collection = get_collection(module)
5✔
341
            collection.append(f"import {module}")
5✔
342

343
        return [
5✔
344
            group
345
            for group in (future_imports, stdlib_imports, thirdparty_imports)
346
            if group
347
        ]
348

349
    def generate_models(self) -> list[Model]:
5✔
350
        models = [Model(table) for table in self.metadata.sorted_tables]
5✔
351

352
        # Collect the imports
353
        self.collect_imports(models)
5✔
354

355
        # Generate names for models
356
        global_names = {
5✔
357
            name for namespace in self.imports.values() for name in namespace
358
        }
359
        for model in models:
5✔
360
            self.generate_model_name(model, global_names)
5✔
361
            global_names.add(model.name)
5✔
362

363
        return models
5✔
364

365
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Set ``model.name`` to ``t_<table name>``, adjusted to avoid name clashes."""
        model.name = self.find_free_name(f"t_{model.table.name}", global_names)
368

369
    def render_module_variables(self, models: list[Model]) -> str:
5✔
370
        declarations = self.base.declarations
5✔
371

372
        if any(not isinstance(model, ModelClass) for model in models):
5✔
373
            if self.base.table_metadata_declaration is not None:
5✔
374
                declarations.append(self.base.table_metadata_declaration)
×
375

376
        return "\n".join(declarations)
5✔
377

378
    def render_models(self, models: list[Model]) -> str:
5✔
379
        rendered: list[str] = []
5✔
380
        for model in models:
5✔
381
            rendered_table = self.render_table(model.table)
5✔
382
            rendered.append(f"{model.name} = {rendered_table}")
5✔
383

384
        return "\n\n".join(rendered)
5✔
385

386
    def render_table(self, table: Table) -> str:
        """
        Render the ``Table(...)`` expression for a table.

        The arguments are the table name and metadata reference, the columns, the
        constraints and indexes that are not expressed on a column, and finally
        the ``schema``, ``comment`` and (opt-in) info/dialect keyword arguments.
        """
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        args.extend(
            self.render_column(column, True, is_table=True)
            for column in table.columns
        )

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            # Skipped here: default-named primary keys, and default-named
            # single-column foreign key / unique constraints
            is_skipped = uses_default_name(constraint) and (
                isinstance(constraint, PrimaryKeyConstraint)
                or (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                )
            )
            if not is_skipped:
                args.append(self.render_constraint(constraint))

        # One-column indexes should be rendered as index=True on columns
        args.extend(
            self.render_index(index)
            for index in sorted(table.indexes, key=lambda i: cast(str, i.name))
            if len(index.columns) > 1 or not uses_default_name(index)
        )

        kwargs: dict[str, object] = {}
        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        # add info + dialect kwargs for callable context (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(table, kwargs, values_for_dict=False)

        return render_callable("Table", *args, kwargs=kwargs, indentation="    ")
421

422
    def render_index(self, index: Index) -> str:
        """Render ``Index(<name>, <column names>...)``, with ``unique=True`` if set."""
        column_names = [repr(col.name) for col in index.columns]
        kwargs = {"unique": True} if index.unique else {}
        return render_callable("Index", repr(index.name), *column_names, kwargs=kwargs)
429

430
    # TODO find better solution for is_table
431
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render the source expression that defines a column.

        Default-named single-column constraints and indexes are folded into the
        column as ``ForeignKey(...)`` arguments or ``primary_key`` / ``unique`` /
        ``index`` keyword arguments.

        .. note:: Sets ``column.unique`` / ``column.index`` to ``True`` when the
            corresponding keyword argument is rendered.

        :param column: the column to render
        :param show_name: ``True`` to render the column name as the first argument
        :param is_table: passed on to render_column_callable() to select the
            callable to render
        :return: the rendered call expression
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_part_of_composite_pk = (
            column.primary_key and len(column.table.primary_key) > 1
        )
        # Foreign keys that can be rendered inline as ForeignKey(...)
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type, column))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            # Keyword values are source expressions, so the key must be rendered
            # as a string literal like the other string-valued kwargs below
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if not column.nullable and not column.primary_key:
            kwargs["nullable"] = False
        if column.nullable and is_part_of_composite_pk:
            kwargs["nullable"] = True

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(cast(TextClause, column.server_default.arg).text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        # add column info + dialect kwargs for callable context (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(column, kwargs, values_for_dict=False)

        return self.render_column_callable(is_table, *args, **kwargs)
530

531
    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render the call that creates a column.

        :param is_table: ``True`` renders ``Column(...)`` and registers its
            import; ``False`` renders ``mapped_column(...)``
        """
        if not is_table:
            return render_callable("mapped_column", *args, kwargs=kwargs)

        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)
537

538
    def render_column_type(
        self, coltype: TypeEngine[Any], column: Column[Any] | None = None
    ) -> str:
        """
        Render the source expression for a column type.

        A column registered in ``self.enum_classes`` is rendered as
        ``Enum(<enum class name>)``. Otherwise the constructor signature of the
        type is inspected, and each parameter whose current value (read from the
        attribute of the same name) differs from its default is rendered.

        :param coltype: the type instance to render
        :param column: the column the type belongs to, if any; only used for the
            enum class lookup
        :return: the class name of the type, as a call expression if there are
            any arguments to render
        """
        # Check if this is an enum column with a Python enum class
        if isinstance(coltype, Enum) and column is not None:
            table_name = column.table.name
            column_name = column.name
            if (table_name, column_name) in self.enum_classes:
                enum_class_name = self.enum_classes[(table_name, column_name)]
                # Import SQLAlchemy Enum (will be handled in collect_imports)
                self.add_import(Enum)
                # Return the Python enum class as the type parameter
                return f"Enum({enum_class_name})"

        args = []
        kwargs: dict[str, Any] = {}
        sig = inspect.signature(coltype.__class__.__init__)
        defaults = {param.name: param.default for param in sig.parameters.values()}
        # Sentinel for "the type has no attribute named like the parameter"
        missing = object()
        # Once a parameter has been skipped, the following ones can no longer be
        # passed positionally
        use_kwargs = False
        # [1:] skips the first parameter (self)
        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                use_kwargs = True
                continue

            value = getattr(coltype, param.name, missing)

            if isinstance(value, (JSONB, JSON)):
                # Remove astext_type if it's the default
                if (
                    isinstance(value.astext_type, Text)
                    and value.astext_type.length is None
                ):
                    value.astext_type = None  # type: ignore[assignment]
                else:
                    self.add_import(Text)

            default = defaults.get(param.name, missing)
            if isinstance(value, TextClause):
                self.add_literal_import("sqlalchemy", "text")
                rendered_value = render_callable("text", repr(value.text))
            else:
                rendered_value = repr(value)

            if value is missing or value == default:
                use_kwargs = True
            elif use_kwargs:
                kwargs[param.name] = rendered_value
            else:
                args.append(rendered_value)

        # Name of the *args parameter of the constructor, if it has one
        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        # Render the values kept in the attribute named like the *args parameter
        if vararg and hasattr(coltype, vararg):
            varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)]
            args.extend(varargs_repr)

        # These arguments cannot be autodetected from the Enum initializer
        if isinstance(coltype, Enum):
            for colname in "name", "schema":
                if (value := getattr(coltype, colname)) is not None:
                    kwargs[colname] = repr(value)

        if isinstance(coltype, (JSONB, JSON)):
            # Remove astext_type if it's the default
            if (
                isinstance(coltype.astext_type, Text)
                and coltype.astext_type.length is None
            ):
                # NOTE(review): raises KeyError if astext_type was rendered
                # positionally or not at all by the loop above -- confirm that
                # this cannot happen (e.g. with a non-default none_as_null)
                del kwargs["astext_type"]

        if args or kwargs:
            return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
        else:
            return coltype.__class__.__name__
622

623
    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render the source expression for a constraint or a single foreign key.

        :param constraint: a ``ForeignKey``, ``ForeignKeyConstraint``,
            ``CheckConstraint``, ``UniqueConstraint`` or ``PrimaryKeyConstraint``
        :return: the rendered call expression
        :raises TypeError: if the constraint is of any other type
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}

        def add_fk_options(*opts: Any) -> None:
            # Positional arguments first, then whichever options are set
            args.extend(map(repr, opts))
            for option in ("ondelete", "onupdate", "deferrable", "initially", "match"):
                if setting := getattr(constraint, option, None):
                    kwargs[option] = repr(setting)

        if isinstance(constraint, ForeignKey):
            target = constraint.column
            add_fk_options(f"{target.table.fullname}.{target.name}")
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{element.column.table.fullname}.{element.column.name}"
                for element in constraint.elements
            ]
            add_fk_options(local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            expression = get_compiled_expression(constraint.sqltext, self.bind)
            args.append(repr(expression))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        # A ForeignKey is not a Constraint, so it never gets an explicit name here
        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)
658

659
    def _add_dialect_kwargs_and_info(
5✔
660
        self, obj: Any, target_kwargs: dict[str, object], *, values_for_dict: bool
661
    ) -> None:
662
        """
663
        Merge SchemaItem-like object's .info and .dialect_kwargs into target_kwargs.
664
        - values_for_dict=True: keep raw values so pretty-printer emits repr() (for __table_args__ dict)
665
        - values_for_dict=False: set values to repr() strings (for callable kwargs)
666
        """
667
        info_dict = getattr(obj, "info", None)
5✔
668
        if info_dict:
5✔
669
            target_kwargs["info"] = info_dict if values_for_dict else repr(info_dict)
5✔
670

671
        dialect_keys: list[str]
672
        try:
5✔
673
            dialect_keys = sorted(getattr(obj, "dialect_kwargs"))
5✔
674
        except Exception:
×
675
            return
×
676

677
        dialect_kwargs = getattr(obj, "dialect_kwargs", {})
5✔
678
        for key in dialect_keys:
5✔
679
            try:
5✔
680
                value = dialect_kwargs[key]
5✔
681
            except Exception:
×
682
                continue
×
683

684
            # Render values:
685
            # - callable context (values_for_dict=False): produce a string expression.
686
            #   primitives use repr(value); custom objects stringify then repr().
687
            # - dict context (values_for_dict=True): pass raw primitives / str;
688
            #   custom objects become str(value) so pformat quotes them.
689
            if values_for_dict:
5✔
690
                if isinstance(value, type(None) | bool | int | float):
5✔
691
                    target_kwargs[key] = value
×
692
                elif isinstance(value, str | dict | list):
5✔
693
                    target_kwargs[key] = value
5✔
694
                else:
695
                    target_kwargs[key] = str(value)
5✔
696
            else:
697
                if isinstance(
5✔
698
                    value, type(None) | bool | int | float | str | dict | list
699
                ):
700
                    target_kwargs[key] = repr(value)
5✔
701
                else:
702
                    target_kwargs[key] = repr(str(value))
5✔
703

704
    def should_ignore_table(self, table: Table) -> bool:
5✔
705
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
706
        # tables
707
        return table.name in ("alembic_version", "migrate_version")
5✔
708

709
    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Turn ``name`` into an identifier that is neither a global nor a local name.

        Non-word characters are replaced with underscores, a leading digit gets an
        underscore prefix, and keywords as well as ``metadata`` get an underscore
        suffix. A name that is taken is retried with the suffixes ``_``, ``1``,
        ``2``, ... until a free one is found.

        :param name: the preferred name; must contain more than just whitespace
        :return: the free identifier
        """
        candidate = name.strip()
        assert candidate, "Identifier cannot be empty"
        candidate = _re_invalid_identifier.sub("_", candidate)
        if candidate[0].isdigit():
            candidate = "_" + candidate
        elif iskeyword(candidate) or candidate == "metadata":
            candidate += "_"

        stem = candidate
        attempt = 0
        while candidate in global_names or candidate in local_names:
            # The first retry appends "_", the following ones append 1, 2, 3, ...
            candidate = stem + (str(attempt) if attempt else "_")
            attempt += 1

        return candidate
731

732
    def _enum_name_to_class_name(self, enum_name: str) -> str:
5✔
733
        """Convert a database enum name to a Python class name (PascalCase)."""
734
        parts = []
5✔
735
        for part in enum_name.split("_"):
5✔
736
            if part:
5✔
737
                parts.append(part.capitalize())
5✔
738
        return "".join(parts)
5✔
739

740
    def _create_enum_class(
        self, table_name: str, column_name: str, values: list[str]
    ) -> str:
        """
        Pick and register a Python enum class name for a table column.

        The base name is the PascalCase concatenation of the table and column names
        (``user`` + ``account_status`` -> ``UserAccountStatus``). If that name is
        already registered in ``self.enum_values`` with the same values, it is
        reused; if it is registered with different values, the numeric suffixes
        1, 2, 3... are tried until a reusable or free name is found.

        :param table_name: name of the table that owns the column
        :param column_name: name of the enum column
        :param values: the enum's values
        :return: the enum class name to use in generated code
        """
        words = table_name.split("_") + column_name.split("_")
        base_name = "".join(word.capitalize() for word in words if word)

        candidate = base_name
        suffix = 0
        while True:
            if candidate not in self.enum_values:
                # Free name: register the new enum class under it
                self.enum_values[candidate] = values
                return candidate

            if self.enum_values[candidate] == values:
                # Same name and same values: share the existing enum class
                return candidate

            suffix += 1
            candidate = f"{base_name}{suffix}"
774

775
    def render_enum_classes(self) -> str:
        """
        Render the Python enum class definitions for all registered enums.

        Every entry of ``self.enum_values`` becomes a ``(str, enum.Enum)`` subclass
        with one member per value. Member names are derived from the values by
        replacing matches of ``_re_invalid_identifier`` with underscores and
        upper-casing the result. Member names that would collide within one class
        (e.g. the values ``a-b`` and ``a_b``) get a numeric suffix, because defining
        the same member name twice makes the enum class body raise ``TypeError``.

        :return: the class definitions sorted by class name and separated by two
            blank lines, or an empty string if no enums have been registered
        """
        if not self.enum_values:
            return ""

        self.add_module_import("enum")

        enum_defs = []
        for enum_class_name, values in sorted(self.enum_values.items()):
            members = []
            # Member names already emitted for this class, used to avoid duplicates
            used_member_names: set[str] = set()
            for value in values:
                # Unescape SQL escape sequences (e.g., \' -> ')
                # The value from the CHECK constraint has SQL escaping
                unescaped_value = value.replace("\\'", "'").replace("\\\\", "\\")

                # Create a valid identifier from the enum value
                member_name = _re_invalid_identifier.sub("_", unescaped_value).upper()
                if not member_name:
                    member_name = "EMPTY"
                elif member_name[0].isdigit():
                    member_name = "_" + member_name
                elif iskeyword(member_name):
                    member_name += "_"

                # Different values can sanitize to the same identifier; keep the
                # first one as is and number the following ones
                base_member_name = member_name
                for suffix in count(1):
                    if member_name not in used_member_names:
                        break

                    member_name = f"{base_member_name}_{suffix}"

                used_member_names.add(member_name)

                # Re-escape for Python string literal
                python_escaped = unescaped_value.replace("\\", "\\\\").replace(
                    "'", "\\'"
                )
                members.append(
                    f"{self.indentation}{member_name} = '{python_escaped}'"
                )

            enum_def = f"class {enum_class_name}(str, enum.Enum):\n" + "\n".join(
                members
            )
            enum_defs.append(enum_def)

        return "\n\n\n".join(enum_defs)
812

813
    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types of ``table`` in place.

        Three things happen here:

        * columns guarded by a boolean-style CHECK constraint become ``Boolean``
          and the constraint is dropped
        * unless the ``noenums`` option is set, every ``Enum`` column with values
          is mapped to a Python enum class in ``self.enum_classes``
        * unless dialect types are kept, column types are replaced with their
          adapted counterparts, and on PostgreSQL, sequence based server defaults
          are turned into explicit sequences where needed
        """
        # Iterate over a copy, as matching constraints are removed from the table
        for constraint in table.constraints.copy():
            if not isinstance(constraint, CheckConstraint):
                continue

            sqltext = get_compiled_expression(constraint.sqltext, self.bind)

            # Turn any integer-like column with a CheckConstraint like
            # "column IN (0, 1)" into a Boolean
            boolean_match = _re_boolean_check_constraint.match(sqltext)
            if not boolean_match:
                continue

            colname_match = _re_column_name.match(boolean_match.group(1))
            if colname_match:
                colname = colname_match.group(3)
                table.constraints.remove(constraint)
                table.c[colname].type = Boolean()

        enums_enabled = "noenums" not in self.options
        for column in table.c:
            # Handle native database Enum types (e.g., PostgreSQL ENUM)
            enum_key = (table.name, column.name)
            if (
                enums_enabled
                and isinstance(column.type, Enum)
                and column.type.enums
                and enum_key not in self.enum_classes
            ):
                if column.type.name:
                    # Named enum - all columns of this enum type share one class
                    enum_class_name = self._enum_name_to_class_name(column.type.name)
                    already_mapped = bool(enum_class_name) and (
                        enum_class_name in self.enum_classes.values()
                    )
                    # Register the values only the first time the class is seen
                    if not already_mapped and enum_class_name not in self.enum_values:
                        self.enum_values[enum_class_name] = list(column.type.enums)
                else:
                    # Unnamed enum - create enum class per column
                    enum_class_name = self._create_enum_class(
                        table.name, column.name, list(column.type.enums)
                    )

                self.enum_classes[enum_key] = enum_class_name

            if not self.keep_dialect_types:
                try:
                    column.type = self.get_adapted_type(column.type)
                except CompileError:
                    # The rest of the processing is skipped for this column
                    continue

            # PostgreSQL specific fix: detect sequences from server_default
            server_default = column.server_default
            if not server_default or self.bind.dialect.name != "postgresql":
                continue

            if not (
                isinstance(server_default, DefaultClause)
                and isinstance(server_default.arg, TextClause)
            ):
                continue

            schema, seqname = decode_postgresql_sequence(server_default.arg)
            if not seqname:
                continue

            # Add an explicit sequence, unless it carries the conventional
            # "<table>_<column>_seq" name
            if seqname != f"{column.table.name}_{column.name}_seq":
                column.default = sqlalchemy.Sequence(seqname, schema=schema)

            column.server_default = None
888

889
    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Try to replace ``coltype`` with an instance of one of its superclasses.

        The classes in the MRO of ``coltype`` are tried in order. An adapted type is
        only accepted if it compiles to the same string as the original type on the
        bound dialect. The search ends at the first accepted class whose name is not
        all upper case, or as soon as an adaptation fails or renders differently.

        :param coltype: the column type instance to adapt
        :return: the last accepted adapted type, or ``coltype`` itself if none was
            accepted
        """
        # Reference rendering that every adapted type has to reproduce
        compiled_type = coltype.compile(self.bind.engine.dialect)
        for supercls in coltype.__class__.__mro__:
            # Only consider classes that have a non-underscore name and a
            # __visit_name__ attribute
            if not supercls.__name__.startswith("_") and hasattr(
                supercls, "__visit_name__"
            ):
                # Don't try to adapt UserDefinedType as it's not a proper column type
                if supercls is UserDefinedType or issubclass(supercls, TypeDecorator):
                    return coltype

                # Hack to fix adaptation of the Enum class which is broken since
                # SQLAlchemy 1.2
                # (kw collects attributes that are set on the adapted type afterwards)
                kw = {}
                if supercls is Enum:
                    kw["name"] = coltype.name
                    if coltype.schema:
                        kw["schema"] = coltype.schema

                # Hack to fix Postgres DOMAIN type adaptation, broken as of SQLAlchemy 2.0.42
                # For additional information - https://github.com/agronholm/sqlacodegen/issues/416#issuecomment-3417480599
                if supercls is DOMAIN:
                    if coltype.default:
                        kw["default"] = coltype.default
                    if coltype.constraint_name is not None:
                        kw["constraint_name"] = coltype.constraint_name
                    if coltype.not_null:
                        kw["not_null"] = coltype.not_null
                    if coltype.check is not None:
                        kw["check"] = coltype.check
                    if coltype.create_type:
                        kw["create_type"] = coltype.create_type

                try:
                    new_coltype = coltype.adapt(supercls)
                except TypeError:
                    # If the adaptation fails, don't try again
                    break

                # Apply the attributes collected above to the adapted type
                for key, value in kw.items():
                    setattr(new_coltype, key, value)

                # For arrays, the item type is adapted recursively as well
                if isinstance(coltype, ARRAY):
                    new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

                try:
                    # If the adapted column type does not render the same as the
                    # original, don't substitute it
                    if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
                        break
                except CompileError:
                    # If the adapted column type can't be compiled, don't substitute it
                    break

                # Stop on the first valid non-uppercase column type class
                coltype = new_coltype
                if supercls.__name__ != supercls.__name__.upper():
                    break

        return coltype
948

949

950
class DeclarativeGenerator(TablesGenerator):
5✔
951
    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
5✔
952
        "use_inflect",
953
        "nojoined",
954
        "nobidi",
955
        "noidsuffix",
956
    }
957

958
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ):
        """
        Set up the generator.

        ``metadata``, ``bind``, ``options`` and ``indentation`` are handed to the
        superclass unchanged.

        :param base_class_name: name of the declarative base class in the generated
            code
        """
        super().__init__(metadata, bind, options, indentation=indentation)
        # Engine used for the singular/plural conversions of the "use_inflect" option
        self.inflect_engine = inflect.engine()
        self.base_class_name: str = base_class_name
970

971
    def generate_base(self) -> None:
        """
        Set ``self.base`` to the description of the declarative base class.

        The generated class is an empty ``DeclarativeBase`` subclass named after
        ``self.base_class_name``, and its ``metadata`` attribute serves as the
        metadata reference.
        """
        class_name = self.base_class_name
        declarative_base_import = LiteralImport("sqlalchemy.orm", "DeclarativeBase")
        class_source = [
            f"class {class_name}(DeclarativeBase):",
            f"{self.indentation}pass",
        ]
        self.base = Base(
            literal_imports=[declarative_base_import],
            declarations=class_source,
            metadata_ref=f"{class_name}.metadata",
        )
980

981
    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports required by ``models``.

        On top of what the superclass collects, ``Mapped`` and ``mapped_column`` are
        imported from ``sqlalchemy.orm`` if at least one model is a model class.
        """
        super().collect_imports(models)
        for model in models:
            if isinstance(model, ModelClass):
                # A single model class is enough to require both names
                for orm_name in ("Mapped", "mapped_column"):
                    self.add_literal_import("sqlalchemy.orm", orm_name)

                break
986

987
    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports required by a single model.

        On top of what the superclass collects, ``relationship`` is imported from
        ``sqlalchemy.orm`` for model classes that have relationships.
        """
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")
992

993
    def generate_models(self) -> list[Model]:
        """
        Build the models for all tables in ``self.metadata``.

        Tables with a primary key become model classes with column attributes;
        association tables and tables without a primary key become plain models.
        Afterwards the relationships are generated, joined-table inheritance is
        detected (unless the ``nojoined`` option is set), the imports are collected
        and the models and their attributes are named.

        :return: the models in the order of ``self.metadata.sorted_tables``
        """
        # Keyed by qualified_table_name(table)
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't process
        # them normally
        # NOTE(review): links is keyed by the unqualified name of the table that the
        # first foreign key points to -- confirm that same-named tables in different
        # schemas cannot get mixed up here
        links: defaultdict[str, list[Model]] = defaultdict(lambda: [])
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                tablename = fk_constraints[0].elements[0].column.table.name
                links[tablename].append(model)
                continue

            # Only form model classes for tables that have a primary key and are not
            # association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        # (the return value of generate_relationships() is not used here)
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model, models_by_table_name, links[model.table.name]
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        # NOTE(review): generate_relationships() performs the same parent_class /
        # children linking for the same condition, so a child may end up in its
        # parent's children list more than once -- confirm this is harmless
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    # A foreign key covering exactly the primary key columns is
                    # treated as a link to the parent class
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            # Each model name is reserved before the next model is named
            global_names.add(model.name)

        return list(models_by_table_name.values())
1070

1071
    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """
        Create the relationship attributes that originate from ``source``.

        Foreign key constraints of the source table yield many-to-one or one-to-one
        relationships, and the given association tables yield many-to-many
        relationships. Unless the ``nobidi`` option is set, the opposite end of each
        relationship is added to the target model class too. The relationships are
        appended to the ``relationships`` lists of the model classes involved.

        :param source: the model class to generate the relationships for
        :param models_by_table_name: all models keyed by their qualified table names
        :param association_tables: the association table models linked to ``source``
        :return: a list of relationships
            (NOTE(review): nothing is ever appended to this list in this method, so
            it is always empty -- confirm whether callers rely on the return value)
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            # Relationships are only generated towards model classes
            if isinstance(target, ModelClass):
                if "nojoined" not in self.options:
                    # A foreign key covering exactly the primary key columns is
                    # treated as joined-table inheritance instead of a relationship
                    if set(get_column_names(constraint)) == pk_column_names:
                        parent = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(parent, ModelClass):
                            source.parent_class = parent
                            parent.children.append(source)
                            continue

                # Add uselist=False to One-to-One relationships
                # (the foreign key columns match a primary key or unique constraint)
                column_names = get_column_names(constraint)
                if any(
                    isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                    and {col.name for col in c.columns} == set(column_names)
                    for c in constraint.table.constraints
                ):
                    r_type = RelationshipType.ONE_TO_ONE
                else:
                    r_type = RelationshipType.MANY_TO_ONE

                relationship = RelationshipAttribute(r_type, source, target, constraint)
                source.relationships.append(relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    relationship.remote_side = [
                        source.get_column_attribute(col.name)
                        for col in constraint.referred_table.primary_key
                    ]

                # If the two tables share more than one foreign key constraint,
                # SQLAlchemy needs an explicit primaryjoin to figure out which column(s)
                # it needs
                common_fk_constraints = get_common_fk_constraints(
                    source.table, target.table
                )
                if len(common_fk_constraints) > 1:
                    relationship.foreign_keys = [
                        source.get_column_attribute(key)
                        for key in constraint.column_keys
                    ]

                # Generate the opposite end of the relationship in the target class
                if "nobidi" not in self.options:
                    # The reverse of many-to-one is one-to-many; one-to-one stays as is
                    if r_type is RelationshipType.MANY_TO_ONE:
                        r_type = RelationshipType.ONE_TO_MANY

                    reverse_relationship = RelationshipAttribute(
                        r_type,
                        target,
                        source,
                        constraint,
                        foreign_keys=relationship.foreign_keys,
                        backref=relationship,
                    )
                    relationship.backref = reverse_relationship
                    target.relationships.append(reverse_relationship)

                    # For self referential relationships, remote_side needs to be set
                    if source is target:
                        reverse_relationship.remote_side = [
                            source.get_column_attribute(colname)
                            for colname in constraint.column_keys
                        ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            # The target is the table referenced by the second constraint (in sort
            # order)
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if isinstance(target, ModelClass):
                relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    source,
                    target,
                    fk_constraints[1],
                    association_table,
                )
                source.relationships.append(relationship)

                # Generate the opposite end of the relationship in the target class
                reverse_relationship = None
                if "nobidi" not in self.options:
                    reverse_relationship = RelationshipAttribute(
                        RelationshipType.MANY_TO_MANY,
                        target,
                        source,
                        fk_constraints[0],
                        association_table,
                        relationship,
                    )
                    relationship.backref = reverse_relationship
                    target.relationships.append(reverse_relationship)

                # Add a primary/secondary join for self-referential many-to-many
                # relationships
                if source is target:
                    both_relationships = [relationship]
                    # The forward relationship uses the constraints in sort order, the
                    # reverse one (if any) in reversed order
                    reverse_flags = [False, True]
                    if reverse_relationship:
                        both_relationships.append(reverse_relationship)

                    # Note: this loop rebinds the name "relationship"
                    for relationship, reverse in zip(both_relationships, reverse_flags):
                        if (
                            not relationship.association_table
                            or not relationship.constraint
                        ):
                            continue

                        constraints = sorted(
                            relationship.constraint.table.foreign_key_constraints,
                            key=get_constraint_sort_key,
                            reverse=reverse,
                        )
                        # Pairs of (association table column name, foreign key element)
                        pri_pairs = zip(
                            get_column_names(constraints[0]), constraints[0].elements
                        )
                        sec_pairs = zip(
                            get_column_names(constraints[1]), constraints[1].elements
                        )
                        relationship.primaryjoin = [
                            (
                                relationship.source,
                                elem.column.name,
                                relationship.association_table,
                                col,
                            )
                            for col, elem in pri_pairs
                        ]
                        relationship.secondaryjoin = [
                            (
                                relationship.target,
                                elem.column.name,
                                relationship.association_table,
                                col,
                            )
                            for col, elem in sec_pairs
                        ]

        return relationships
1234

1235
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """
        Name ``model`` and, for model classes, all of its attributes.

        The preferred name of a model class is derived from its table name: matches
        of ``_re_invalid_identifier`` become underscores, and the first letter of
        every underscore separated part is upper-cased before the parts are joined.
        With the ``use_inflect`` option, the singular form is used when the inflect
        engine provides one. Other models are named by the superclass.

        :param model: the model to name
        :param global_names: module level names that are already taken
        """
        if not isinstance(model, ModelClass):
            super().generate_model_name(model, global_names)
            return

        sanitized = _re_invalid_identifier.sub("_", model.table.name)
        class_name = "".join(
            word[:1].upper() + word[1:] for word in sanitized.split("_")
        )
        if "use_inflect" in self.options:
            # singular_noun() gives a falsy result when there's nothing to singularize
            class_name = self.inflect_engine.singular_noun(class_name) or class_name

        model.name = self.find_free_name(class_name, global_names)

        # Columns and relationships share the namespace of the class, so the names
        # taken so far are tracked across both
        taken_names: set[str] = set()
        for column_attr in model.columns:
            self.generate_column_attr_name(column_attr, global_names, taken_names)
            taken_names.add(column_attr.name)

        for relationship in model.relationships:
            self.generate_relationship_name(relationship, global_names, taken_names)
            taken_names.add(relationship.name)
1260

1261
    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Give ``column_attr`` a free attribute name based on its column's name.

        :param column_attr: the column attribute to name
        :param global_names: module level names that are already taken
        :param local_names: names already taken within the class
        """
        preferred_name = column_attr.column.name
        free_name = self.find_free_name(preferred_name, global_names, local_names)
        column_attr.name = free_name
1270

1271
    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Give ``relationship`` a free attribute name.

        The preferred name is picked as follows:

        * the reverse side of a self referential one-to-many or one-to-one
          relationship is named after its already named backref, with ``_reverse``
          appended
        * otherwise the name of the target table is used, or, unless the
          ``noidsuffix`` option is set, the name of the sole constraint column
          without its ``_id`` suffix where applicable
        * with the ``use_inflect`` option, that name is pluralized for one-to-many
          and many-to-many relationships and singularized for the others

        :param relationship: the relationship attribute to name
        :param global_names: module level names that are already taken
        :param local_names: names already taken within the class
        """
        one_sided_types = (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
        preferred_name: str

        # Self referential reverse relationships
        is_self_referential_reverse = (
            relationship.type in one_sided_types
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        )
        if is_self_referential_reverse:
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id", use the
            # preceding part as the relationship name
            constraint = relationship.constraint
            if constraint and "noidsuffix" not in self.options:
                owns_constraint = relationship.source.table is constraint.table
                if owns_constraint or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    fk_column_names = [col.name for col in constraint.columns]
                    if len(fk_column_names) == 1:
                        sole_name = fk_column_names[0]
                        if sole_name.endswith("_id"):
                            preferred_name = sole_name[:-3]

            if "use_inflect" in self.options:
                inflected_name: str | Literal[False]
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    # Only pluralize names that don't have a singular form
                    if not self.inflect_engine.singular_noun(preferred_name):
                        preferred_name = self.inflect_engine.plural_noun(preferred_name)
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)
                    if inflected_name:
                        preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )
1318

1319
    def render_models(self, models: list[Model]) -> str:
        """
        Render the source code of all the given models.

        Model classes are rendered as classes and every other model as an
        assignment of its rendered table to the model's name.

        :return: the rendered models separated by two blank lines
        """
        rendered_models = [
            self.render_class(model)
            if isinstance(model, ModelClass)
            else f"{model.name} = {self.render_table(model.table)}"
            for model in models
        ]
        return "\n\n\n".join(rendered_models)
1328

1329
    def render_class(self, model: ModelClass) -> str:
        """
        Render the source code of a model class.

        The class body consists of up to three sections separated by blank lines:
        the class variables, the column attributes (non-nullable columns before
        nullable ones) and the relationship attributes. Empty sections are left out.

        :return: the class declaration followed by the indented sections
        """
        # Class variables / special declarations
        class_vars: str = self.render_class_variables(model)
        sections: list[str] = [class_vars] if class_vars else []

        # Column attributes, non-nullable columns first
        required_columns = [
            attr for attr in model.columns if attr.column.nullable is False
        ]
        optional_columns = [
            attr for attr in model.columns if attr.column.nullable is True
        ]
        column_lines = [
            self.render_column_attribute(attr)
            for attr in (*required_columns, *optional_columns)
        ]
        if column_lines:
            sections.append("\n".join(column_lines))

        # Relationship attributes
        relationship_lines = [
            self.render_relationship(relationship)
            for relationship in model.relationships
        ]
        if relationship_lines:
            sections.append("\n".join(relationship_lines))

        declaration = self.render_class_declaration(model)
        indented_sections = [indent(section, self.indentation) for section in sections]
        body = "\n\n".join(indented_sections)
        return f"{declaration}\n{body}"
1363

1364
    def render_class_declaration(self, model: ModelClass) -> str:
        """
        Render the ``class`` line of a model class.

        The class inherits from its parent model class if it has one, and from the
        declarative base class otherwise.
        """
        if model.parent_class:
            superclass_name = model.parent_class.name
        else:
            superclass_name = self.base_class_name

        return f"class {model.name}({superclass_name}):"
1369

1370
    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render the special class variables of a model class.

        ``__tablename__`` is always rendered; ``__table_args__`` follows on its own
        line only if the table has any arguments to render.
        """
        tablename_line = f"__tablename__ = {model.table.name!r}"

        # Constraints and indexes end up in __table_args__
        table_args = self.render_table_args(model.table)
        if not table_args:
            return tablename_line

        return f"{tablename_line}\n__table_args__ = {table_args}"
1379

1380
    def render_table_args(self, table: Table) -> str:
        """
        Render the value of ``__table_args__`` for ``table``.

        Positional arguments are the rendered constraints and indexes; keyword
        settings are the schema, the comment and, if enabled, the dialect options
        and info.

        :return: the formatted settings dict if there are only keyword settings, a
            parenthesized, comma separated list of arguments (ending with the
            settings dict, if any) if there are positional arguments, or an empty
            string if there is nothing to render
        """

        def is_left_out(constraint: Constraint) -> bool:
            # Among the constraints with a default name, primary keys and single
            # column foreign key / unique constraints are not rendered here
            if not uses_default_name(constraint):
                return False

            if isinstance(constraint, PrimaryKeyConstraint):
                return True

            return (
                isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                and len(constraint.columns) == 1
            )

        # Render constraints
        positional: list[str] = [
            self.render_constraint(constraint)
            for constraint in sorted(table.constraints, key=get_constraint_sort_key)
            if not is_left_out(constraint)
        ]

        # Render indexes
        positional.extend(
            self.render_index(index)
            for index in sorted(table.indexes, key=lambda i: cast(str, i.name))
            if len(index.columns) > 1 or not uses_default_name(index)
        )

        settings: dict[str, object] = {}
        if table.schema:
            settings["schema"] = table.schema

        if table.comment:
            settings["comment"] = table.comment

        # add info + dialect kwargs for dict context (__table_args__) (opt-in)
        if self.include_dialect_options_and_info:
            self._add_dialect_kwargs_and_info(table, settings, values_for_dict=True)

        if settings:
            formatted_settings = pformat(settings)
            if not positional:
                return formatted_settings

            positional.append(formatted_settings)

        if not positional:
            return ""

        rendered_args = f",\n{self.indentation}".join(positional)
        if len(positional) == 1:
            # A single element tuple needs a trailing comma
            rendered_args += ","

        return f"(\n{self.indentation}{rendered_args}\n)"
1427

1428
    def render_column_python_type(self, column: Column[Any]) -> str:
        """
        Render the Python type annotation matching a column's SQL type.

        Nullable columns are wrapped in ``Optional[...]`` and ``ARRAY`` columns
        in one ``list[...]`` per dimension (one if no dimension count is set).
        Any imports the annotation needs are registered on the generator.

        :param column: the column to render the annotation for
        :return: the annotation as source text

        """
        item_type: TypeEngine[Any] = column.type
        opening = ""
        closing_count = 0

        if column.nullable:
            self.add_literal_import("typing", "Optional")
            opening += "Optional["
            closing_count += 1

        if isinstance(item_type, ARRAY):
            dimensions = getattr(item_type, "dimensions", None) or 1
            opening += "list[" * dimensions
            closing_count += dimensions
            # The innermost annotation describes the array's element type
            item_type = item_type.item_type

        def render_base_type(base_type: TypeEngine[Any]) -> str:
            # Check if this is an enum column with a Python enum class
            if isinstance(base_type, Enum):
                enum_key = (column.table.name, column.name)
                if enum_key in self.enum_classes:
                    return self.enum_classes[enum_key]

            # A domain is annotated with the type it wraps
            if isinstance(base_type, DOMAIN):
                base_type = base_type.data_type

            try:
                python_type = base_type.python_type
                module_name = python_type.__module__
                type_name = python_type.__name__
            except NotImplementedError:
                # No Python type is available for this SQL type
                self.add_literal_import("typing", "Any")
                return "Any"

            if module_name == "builtins":
                return type_name

            self.add_module_import(module_name)
            return f"{module_name}.{type_name}"

        closing = "]" * closing_count
        return f"{opening}{render_base_type(item_type)}{closing}"
5✔
1476

1477
    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column attribute as an annotated class attribute.

        :param column_attr: the column attribute to render
        :return: a line of the form ``name: Mapped[<type>] = <column>``

        """
        column = column_attr.column
        # The flag tells render_column whether the attribute name differs from
        # the column name
        names_differ = column_attr.name != column.name
        definition = self.render_column(column, names_differ)
        annotation = self.render_column_python_type(column)

        return f"{column_attr.name}: Mapped[{annotation}] = {definition}"
5✔
1483

1484
    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship as an annotated ``relationship()`` assignment.

        Keyword arguments are emitted in this order, each only when applicable:
        ``uselist``, ``secondary``, ``remote_side``, ``foreign_keys``,
        ``primaryjoin``, ``secondaryjoin`` and ``back_populates``.

        :param relationship: the relationship attribute to render
        :return: a line of the form
            ``name: Mapped[<annotation>] = relationship(...)``

        """
        origin = relationship.source

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Attributes of the source model are referenced by bare name; any
            # other attribute becomes a quoted "Model.attr" string
            items = [
                attr.name
                if attr.model is origin
                else repr(f"{attr.model.name}.{attr.name}")
                for attr in column_attrs
            ]
            return "[" + ", ".join(items) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            # Assume that column_attrs are all in relationship.source or none
            items = [
                attr.name
                if attr.model is origin
                else f"{attr.model.name}.{attr.name}"
                for attr in column_attrs
            ]
            listing = "[" + ", ".join(items) + "]"
            # The whole list is quoted as soon as one attribute belongs to
            # another model
            if any(attr.model is not origin for attr in column_attrs):
                return f"'{listing}'"

            return listing

        def render_join(terms: list[JoinType]) -> str:
            conditions = []
            for source, source_col, target, target_col in terms:
                # A target whose class is exactly Model has its column
                # accessed through ".c"
                accessor = "c." if target.__class__ is Model else ""
                conditions.append(
                    f"lambda: {source.name}.{source_col} == "
                    f"{target.name}.{accessor}{target_col!s}"
                )

            if len(conditions) > 1:
                # NOTE(review): every condition keeps its own "lambda: " prefix
                # inside and_(...) - confirm this is the intended output
                return f"and_({', '.join(conditions)})"

            return conditions[0]

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if (
            relationship.type is RelationshipType.ONE_TO_ONE
            and relationship.constraint
            and relationship.constraint.referred_table is origin.table
        ):
            kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        association = relationship.association_table
        if association:
            table_ref = association.table.name
            if association.schema:
                table_ref = f"{association.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        # Work out the Mapped[...] annotation from the relationship type
        annotation: str
        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            annotation = f"list['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            annotation = f"'{relationship.target.name}'"
            # Optional if any column of the constraint is nullable
            if relationship.constraint and any(
                col.nullable for col in relationship.constraint.columns
            ):
                self.add_literal_import("typing", "Optional")
                annotation = f"Optional[{annotation}]"
        else:
            self.add_literal_import("typing", "Any")
            annotation = "Any"

        return f"{relationship.name}: Mapped[{annotation}] = {rendered_relationship}"
1583

1584

1585
class DataclassGenerator(DeclarativeGenerator):
    """
    Declarative generator whose base class also derives from
    ``MappedAsDataclass``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """
        Initialize the generator.

        ``metadata``, ``bind``, ``options``, ``indentation`` and
        ``base_class_name`` are forwarded to the parent initializer;
        ``quote_annotations`` and ``metadata_key`` are stored on the instance.
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )
        self.metadata_key: str = metadata_key
        self.quote_annotations: bool = quote_annotations

    def generate_base(self) -> None:
        """
        Set ``self.base`` to a declaration of the dataclass-enabled base class.
        """
        base_name = self.base_class_name
        orm_imports = [
            LiteralImport("sqlalchemy.orm", imported_name)
            for imported_name in ("DeclarativeBase", "MappedAsDataclass")
        ]
        class_header = f"class {base_name}(MappedAsDataclass, DeclarativeBase):"

        self.base = Base(
            literal_imports=orm_imports,
            declarations=[class_header, f"{self.indentation}pass"],
            metadata_ref=f"{base_name}.metadata",
        )
1619

1620

1621
class SQLModelGenerator(DeclarativeGenerator):
    """
    Generator that renders model classes as SQLModel classes
    (``class X(SQLModel, table=True)``) using ``Field`` and ``Relationship``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        """Initialize the generator; all arguments go to the parent class."""
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    @property
    def views_supported(self) -> bool:
        """Always ``False`` for this generator."""
        return False

    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render a ``Column(...)`` call, registering the ``Column`` import.

        ``is_table`` is accepted but not used here.
        """
        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)

    def generate_base(self) -> None:
        """Set ``self.base`` to an empty base (no imports, no declarations)."""
        self.base = Base(literal_imports=[], declarations=[], metadata_ref="")

    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect imports, swapping ``MetaData`` for the SQLModel names when at
        least one model is a class.
        """
        # Deliberately skips DeclarativeGenerator's own implementation
        super(DeclarativeGenerator, self).collect_imports(models)

        if not any(isinstance(model, ModelClass) for model in models):
            return

        self.remove_literal_import("sqlalchemy", "MetaData")
        self.add_literal_import("sqlmodel", "SQLModel")
        self.add_literal_import("sqlmodel", "Field")

    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports of one model, adding ``Optional`` when a class has
        a nullable column and ``Relationship`` when it has relationships.
        """
        # Deliberately skips DeclarativeGenerator's own implementation
        super(DeclarativeGenerator, self).collect_imports_for_model(model)

        if not isinstance(model, ModelClass):
            return

        if any(column_attr.column.nullable for column_attr in model.columns):
            self.add_literal_import("typing", "Optional")

        if model.relationships:
            self.add_literal_import("sqlmodel", "Relationship")

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render module-level variables: the base's table metadata declaration
        when some model is not a class and such a declaration exists, else an
        empty string.
        """
        if all(isinstance(model, ModelClass) for model in models):
            return ""

        metadata_declaration = self.base.table_metadata_declaration
        if metadata_declaration is None:
            return ""

        return metadata_declaration

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the class header, always passing ``table=True``."""
        parent = (
            model.parent_class.name if model.parent_class else self.base_class_name
        )
        return f"class {model.name}({parent}, table=True):"

    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render ``__tablename__`` (only when it differs from the lowercased
        class name) and ``__table_args__`` (only when non-empty).
        """
        lines: list[str] = []

        table_name = model.table.name
        if table_name != model.name.lower():
            lines.append(f"__tablename__ = {table_name!r}")

        # Render constraints and indexes as __table_args__
        rendered_table_args = self.render_table_args(model.table)
        if rendered_table_args:
            lines.append(f"__table_args__ = {rendered_table_args}")

        return "\n".join(lines)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column as ``name: <type> = Field(...)`` with the rendered
        column passed as ``sa_column``; nullable columns also get
        ``default=None``.
        """
        column = column_attr.column
        sa_column = self.render_column(column, True)
        annotation = self.render_column_python_type(column)

        # "default" must come before "sa_column" in the rendered call
        field_kwargs: dict[str, Any] = {"default": None} if column.nullable else {}
        field_kwargs["sa_column"] = f"{sa_column}"

        rendered_field = render_callable("Field", kwargs=field_kwargs)
        return f"{column_attr.name}: {annotation} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship as ``name: <annotation> = Relationship(...)``.

        The arguments are extracted from the text the parent class renders
        after `` = ``. To-many relationships are annotated as ``list[...]``,
        everything else as ``Optional[...]``.
        """
        _, _, parent_call = (
            super().render_relationship(relationship).partition(" = ")
        )
        positional = self.render_relationship_args(parent_call)
        no_kwargs: dict[str, Any] = {}

        target_ref = repr(relationship.target.name)
        to_many = relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        )
        if to_many:
            annotation = f"list[{target_ref}]"
        else:
            self.add_literal_import("typing", "Optional")
            annotation = f"Optional[{target_ref}]"

        rendered_field = render_callable("Relationship", *positional, kwargs=no_kwargs)
        return f"{relationship.name}: {annotation} = {rendered_field}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """
        Pick the arguments to pass to ``Relationship()`` out of rendered call
        text.

        The text is split on commas. Pieces containing ``back_populates`` are
        kept as they are, and a piece containing ``uselist=False`` adds
        ``sa_relationship_kwargs={'uselist': False}``.
        """
        *leading, final = arguments.split(",")
        # Drop the last character of the final piece and the first character
        # of every piece
        # NOTE(review): presumably the closing ")" and the space after each
        # comma - confirm against render_callable's output format
        pieces = [piece[1:] for piece in (*leading, final[:-1])]

        rendered_args: list[str] = []
        for piece in pieces:
            if "back_populates" in piece:
                rendered_args.append(piece)
            if "uselist=False" in piece:
                rendered_args.append("sa_relationship_kwargs={'uselist': False}")

        return rendered_args
5✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc