• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 16617628878

30 Jul 2025 08:36AM UTC coverage: 97.595% (+0.2%) from 97.379%
16617628878

Pull #411

github

web-flow
Merge 2c0ee9067 into 240e5b712
Pull Request #411: Fixed same-name imports from wrong package

32 of 32 new or added lines in 3 files covered. (100.0%)

2 existing lines in 2 files now uncovered.

1420 of 1455 relevant lines covered (97.59%)

1.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.9
/src/sqlacodegen/generators.py
1
from __future__ import annotations
2✔
2

3
import inspect
2✔
4
import re
2✔
5
import sys
2✔
6
from abc import ABCMeta, abstractmethod
2✔
7
from collections import defaultdict
2✔
8
from collections.abc import Collection, Iterable, Sequence
2✔
9
from dataclasses import dataclass
2✔
10
from importlib import import_module
2✔
11
from inspect import Parameter
2✔
12
from itertools import count
2✔
13
from keyword import iskeyword
2✔
14
from pprint import pformat
2✔
15
from textwrap import indent
2✔
16
from typing import Any, ClassVar, Literal, cast
2✔
17

18
import inflect
2✔
19
import sqlalchemy
2✔
20
from sqlalchemy import (
2✔
21
    ARRAY,
22
    Boolean,
23
    CheckConstraint,
24
    Column,
25
    Computed,
26
    Constraint,
27
    DefaultClause,
28
    Enum,
29
    ForeignKey,
30
    ForeignKeyConstraint,
31
    Identity,
32
    Index,
33
    MetaData,
34
    PrimaryKeyConstraint,
35
    String,
36
    Table,
37
    Text,
38
    TypeDecorator,
39
    UniqueConstraint,
40
)
41
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
2✔
42
from sqlalchemy.engine import Connection, Engine
2✔
43
from sqlalchemy.exc import CompileError
2✔
44
from sqlalchemy.sql.elements import TextClause
2✔
45
from sqlalchemy.sql.type_api import UserDefinedType
2✔
46
from sqlalchemy.types import TypeEngine
2✔
47

48
from .models import (
2✔
49
    ColumnAttribute,
50
    JoinType,
51
    Model,
52
    ModelClass,
53
    RelationshipAttribute,
54
    RelationshipType,
55
)
56
from .utils import (
2✔
57
    decode_postgresql_sequence,
58
    get_column_names,
59
    get_common_fk_constraints,
60
    get_compiled_expression,
61
    get_constraint_sort_key,
62
    get_stdlib_module_names,
63
    qualified_table_name,
64
    render_callable,
65
    uses_default_name,
66
)
67

68
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
2✔
69
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
2✔
70
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
2✔
71
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
2✔
72
_re_invalid_identifier = re.compile(r"(?u)\W")
2✔
73

74

75
@dataclass
class LiteralImport:
    """A package name / imported name pair registered via ``add_literal_import``."""

    pkgname: str
    name: str
79

80

81
@dataclass
class Base:
    """Representation of MetaData for Tables, respectively Base for classes"""

    # Imports that the declarations below depend on
    literal_imports: list[LiteralImport]
    # Source lines rendered as module-level variables
    declarations: list[str]
    # Expression used to refer to the MetaData object in rendered tables
    metadata_ref: str
    # NOTE(review): not used in the visible part of this file -- confirm usage
    decorator: str | None = None
    # Extra declaration emitted when at least one model is a plain table
    table_metadata_declaration: str | None = None
90

91

92
class CodeGenerator(metaclass=ABCMeta):
    """
    Abstract base class for code generators.

    Subclasses list the option names they accept in ``valid_options`` and
    implement :meth:`generate` and :attr:`views_supported`.
    """

    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        Store the inputs and validate the requested options.

        :param metadata: the metadata to generate code from
        :param bind: the connection or engine the metadata belongs to
        :param options: option names; each must be in ``valid_options``
        :raises ValueError: if any option is not recognized
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Validate options; sort so that the error message is deterministic
        invalid_options = self.options - self.valid_options
        if invalid_options:
            raise ValueError(
                "Unrecognized options: " + ", ".join(sorted(invalid_options))
            )

    @property
    @abstractmethod
    def views_supported(self) -> bool:
        pass

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
118

119

120
@dataclass(eq=False)
2✔
121
class TablesGenerator(CodeGenerator):
2✔
122
    valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
2✔
123
    stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
2✔
124

125
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        Initialize the generator and its (initially empty) import registries.

        :param indentation: the string used for one level of indentation
        """
        super().__init__(metadata, bind, options)
        self.indentation: str = indentation
        # Package name -> names rendered as "from <package> import <names>"
        self.imports: dict[str, set[str]] = defaultdict(set)
        # Module names rendered as "import <module>"
        self.module_imports: set[str] = set()
137

138
    @property
    def views_supported(self) -> bool:
        """Always ``True`` for this generator."""
        return True
141

142
    def generate_base(self) -> None:
        """Set ``self.base`` to a module-level ``metadata = MetaData()`` setup."""
        metadata_import = LiteralImport("sqlalchemy", "MetaData")
        self.base = Base(
            metadata_ref="metadata",
            declarations=["metadata = MetaData()"],
            literal_imports=[metadata_import],
        )
148

149
    def generate(self) -> str:
        """
        Render the complete module source for ``self.metadata``.

        The metadata is modified in place: ignored tables are removed, and
        indexes, constraints and comments are stripped according to the options.

        :return: the import block, module variables and models, separated by blank
            lines and terminated by a newline
        """
        self.generate_base()

        drop_indexes = "noindexes" in self.options
        drop_constraints = "noconstraints" in self.options
        drop_comments = "nocomments" in self.options

        # Remove unwanted elements from the metadata (iterate over a snapshot as
        # tables may be removed along the way)
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if drop_indexes:
                table.indexes.clear()

            if drop_constraints:
                table.constraints.clear()

            if drop_comments:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended column
        # types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        models: list[Model] = self.generate_models()

        # Rendering may register further imports, so the import block is built last
        variables = self.render_module_variables(models)
        rendered_models = self.render_models(models)
        imports = "\n\n".join("\n".join(group) for group in self.group_imports())

        sections = [imports, variables + "\n" if variables else "", rendered_models]
        return "\n\n".join(section for section in sections if section) + "\n"
198

199
    def collect_imports(self, models: Iterable[Model]) -> None:
        """Register the imports needed by the base declarations and all models."""
        for entry in self.base.literal_imports:
            self.add_literal_import(entry.pkgname, entry.name)

        for model in models:
            self.collect_imports_for_model(model)
205

206
    def collect_imports_for_model(self, model: Model) -> None:
        """Register the imports needed by a model's columns, constraints and indexes."""
        # Only plain models (not subclasses of Model) are rendered via Table
        if model.__class__ is Model:
            self.add_import(Table)

        table = model.table
        for column in table.c:
            self.collect_imports_for_column(column)

        for item in (*table.constraints, *table.indexes):
            self.collect_imports_for_constraint(item)
218

219
    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the imports needed by a column's type and defaults."""
        coltype = column.type
        self.add_import(coltype)

        if isinstance(coltype, ARRAY):
            self.add_import(coltype.item_type.__class__)
        elif isinstance(coltype, (JSONB, JSON)):
            # An unbounded Text astext_type is treated as the default and needs no
            # import
            astext_type = coltype.astext_type
            uses_default_astext = (
                isinstance(astext_type, Text) and astext_type.length is None
            )
            if not uses_default_astext:
                self.add_import(astext_type)
        elif isinstance(coltype, DOMAIN):
            self.add_import(coltype.data_type.__class__)

        if column.default:
            self.add_import(column.default)

        server_default = column.server_default
        if server_default:
            if isinstance(server_default, (Computed, Identity)):
                self.add_import(server_default)
            elif isinstance(server_default, DefaultClause):
                self.add_literal_import("sqlalchemy", "text")
241

242
    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """Register the import needed to render a constraint or index, if any."""

        def rendered_standalone() -> bool:
            # Multi-column or explicitly named items get their own construct
            return len(constraint.columns) > 1 or not uses_default_name(constraint)

        if isinstance(constraint, Index):
            if rendered_standalone():
                self.add_literal_import("sqlalchemy", "Index")
            return

        if isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
            return

        if isinstance(constraint, UniqueConstraint):
            if rendered_standalone():
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
            return

        if isinstance(constraint, ForeignKeyConstraint):
            if rendered_standalone():
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
            return

        self.add_import(constraint)
259

260
    def add_import(self, obj: Any) -> None:
        """
        Register an import for the given class, or for the class of the given object.

        Objects whose ``__module__`` is missing or ``"builtins"`` are skipped.
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        type_name = type_.__name__
        pkgname = type_.__module__

        # The column types have already been adapted towards generic types if possible,
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
        # use that rather than the generic sqlalchemy type as it might have different
        # constructor parameters.
        if pkgname.startswith("sqlalchemy.dialects."):
            dialect_pkgname = ".".join(pkgname.split(".")[:3])
            if type_name in import_module(dialect_pkgname).__all__:
                pkgname = dialect_pkgname
        elif type_ is getattr(sqlalchemy, type_name, None):
            pkgname = "sqlalchemy"

        self.add_literal_import(pkgname, type_name)
284

285
    def add_literal_import(self, pkgname: str, name: str) -> None:
        """Register ``name`` to be imported from ``pkgname``."""
        self.imports.setdefault(pkgname, set()).add(name)
288

289
    def remove_literal_import(self, pkgname: str, name: str) -> None:
        """
        Unregister ``name`` from the names imported from ``pkgname``.

        Unknown packages and names are ignored. A package left without any names
        is dropped from ``self.imports`` so that no empty entry remains for it.
        """
        # Use get() so that looking up an unknown package does not create an entry
        names = self.imports.get(pkgname)
        if names is None:
            return

        names.discard(name)
        if not names:
            del self.imports[pkgname]
293

294
    def add_module_import(self, pgkname: str) -> None:
        """Register a module to be imported as a whole (``import <module>``)."""
        self.module_imports.add(pgkname)
296

297
    def group_imports(self) -> list[list[str]]:
        """
        Render the collected imports, grouped by origin.

        :return: up to three non-empty lists of import statements, in the order
            ``__future__``, standard library, third party. Within a group, the
            ``from ... import ...`` lines come first, followed by the plain
            ``import ...`` lines; both are sorted by package name.
        """
        future_imports: list[str] = []
        stdlib_imports: list[str] = []
        thirdparty_imports: list[str] = []

        def get_collection(package: str) -> list[str]:
            if package == "__future__":
                return future_imports

            if package in self.stdlib_module_names:
                return stdlib_imports

            # A loaded module that does not live under site-packages is grouped
            # with the standard library. sys.modules entries may be None, and
            # not every module has a __file__ attribute.
            module = sys.modules.get(package)
            if module is not None:
                origin = getattr(module, "__file__", None) or ""
                if "site-packages" not in origin:
                    return stdlib_imports

            return thirdparty_imports

        for package in sorted(self.imports):
            names = self.imports[package]
            # A package without names would render as a bare "from x import"
            if not names:
                continue

            collection = get_collection(package)
            collection.append(f"from {package} import {', '.join(sorted(names))}")

        for module in sorted(self.module_imports):
            get_collection(module).append(f"import {module}")

        return [
            group
            for group in (future_imports, stdlib_imports, thirdparty_imports)
            if group
        ]
328

329
    def generate_models(self) -> list[Model]:
        """
        Create a model for every table (in ``sorted_tables`` order), collect their
        imports and assign each model a unique name.
        """
        models = [Model(table) for table in self.metadata.sorted_tables]
        self.collect_imports(models)

        # Model names must not clash with imported names or with each other
        taken_names: set[str] = set().union(*self.imports.values())
        for model in models:
            self.generate_model_name(model, taken_names)
            taken_names.add(model.name)

        return models
344

345
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Set ``model.name`` to a free name derived from ``t_<table name>``."""
        model.name = self.find_free_name(f"t_{model.table.name}", global_names)
348

349
    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module-level declarations from ``self.base``.

        The base's ``table_metadata_declaration`` (when set) is included only if
        at least one model is not a ``ModelClass``.

        :return: the declarations joined by newlines
        """
        # Work on a copy so that repeated calls don't accumulate declarations on
        # self.base
        declarations = list(self.base.declarations)

        extra_declaration = self.base.table_metadata_declaration
        if extra_declaration is not None and any(
            not isinstance(model, ModelClass) for model in models
        ):
            declarations.append(extra_declaration)

        return "\n".join(declarations)
357

358
    def render_models(self, models: list[Model]) -> str:
        """Render every model as ``<name> = Table(...)``, separated by blank lines."""
        return "\n\n".join(
            f"{model.name} = {self.render_table(model.table)}" for model in models
        )
365

366
    def render_table(self, table: Table) -> str:
        """
        Render a ``Table(...)`` construct for the given table.

        Constraints and indexes that use their default name are skipped when they
        are conveyed through column arguments instead: primary key constraints
        always, single-column foreign key/unique constraints and single-column
        indexes.

        :return: the rendered call, indented with ``self.indentation``
        """
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        kwargs: dict[str, object] = {}
        for column in table.columns:
            args.append(self.render_column(column, True, is_table=True))

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue

                if (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                ):
                    continue

            args.append(self.render_constraint(constraint))

        for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        # Honor the configured indentation rather than assuming four spaces
        return render_callable(
            "Table", *args, kwargs=kwargs, indentation=self.indentation
        )
397

398
    def render_index(self, index: Index) -> str:
        """Render an ``Index(...)`` construct with its name and column names."""
        column_names = [repr(col.name) for col in index.columns]
        kwargs = {"unique": True} if index.unique else {}
        return render_callable("Index", repr(index.name), *column_names, kwargs=kwargs)
405

406
    # TODO find better solution for is_table
407
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render the construct for a single column.

        :param column: the column to render
        :param show_name: ``True`` to render the column name as the first argument
        :param is_table: ``True`` to render ``Column(...)``, ``False`` to render
            ``mapped_column(...)``
        :return: the rendered call

        .. note:: Sets ``column.unique`` / ``column.index`` to ``True`` when the
            column is covered by a default-named single-column unique constraint
            or index.
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        # Default-named single-column foreign keys are rendered inline
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            # Like every other string-valued argument here, the key must be passed
            # through repr() so that it is rendered as a quoted literal
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if not column.nullable and not is_sole_pk and is_table:
            kwargs["nullable"] = False

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(cast(TextClause, column.server_default.arg).text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        return self.render_column_callable(is_table, *args, **kwargs)
498

499
    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render ``Column(...)`` for table columns, ``mapped_column(...)`` otherwise.

        The ``Column`` import is registered only when ``Column`` is rendered.
        """
        if not is_table:
            return render_callable("mapped_column", *args, kwargs=kwargs)

        self.add_import(Column)
        return render_callable("Column", *args, kwargs=kwargs)
505

506
    def render_column_type(self, coltype: TypeEngine[Any]) -> str:
        """
        Render a column type, including its non-default constructor arguments.

        The arguments are discovered by matching the parameters of the type's
        ``__init__`` against same-named attributes on the instance. Parameters
        whose value is missing or equal to the default are omitted, and every
        parameter after an omitted one is rendered as a keyword argument.

        :return: the bare class name if there are no arguments to render, otherwise
            the rendered constructor call
        """
        args = []
        kwargs: dict[str, Any] = {}
        sig = inspect.signature(coltype.__class__.__init__)
        defaults = {param.name: param.default for param in sig.parameters.values()}
        missing = object()
        use_kwargs = False

        # The default astext_type (an unbounded Text) of JSON/JSONB is not rendered
        omit_astext_type = (
            isinstance(coltype, (JSONB, JSON))
            and isinstance(coltype.astext_type, Text)
            and coltype.astext_type.length is None
        )

        # Skip the first parameter (self)
        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                use_kwargs = True
                continue

            # Skip the parameter here instead of deleting it from kwargs afterwards,
            # which failed if it had been rendered positionally or not at all
            if omit_astext_type and param.name == "astext_type":
                use_kwargs = True
                continue

            value = getattr(coltype, param.name, missing)

            if isinstance(value, (JSONB, JSON)):
                # Remove astext_type if it's the default
                if (
                    isinstance(value.astext_type, Text)
                    and value.astext_type.length is None
                ):
                    value.astext_type = None  # type: ignore[assignment]
                else:
                    self.add_import(Text)

            default = defaults.get(param.name, missing)
            if isinstance(value, TextClause):
                self.add_literal_import("sqlalchemy", "text")
                rendered_value = render_callable("text", repr(value.text))
            else:
                rendered_value = repr(value)

            if value is missing or value == default:
                use_kwargs = True
            elif use_kwargs:
                kwargs[param.name] = rendered_value
            else:
                args.append(rendered_value)

        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(coltype, vararg):
            varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)]
            args.extend(varargs_repr)

        # These arguments cannot be autodetected from the Enum initializer
        if isinstance(coltype, Enum):
            for colname in "name", "schema":
                if (value := getattr(coltype, colname)) is not None:
                    kwargs[colname] = repr(value)

        if args or kwargs:
            return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
        else:
            return coltype.__class__.__name__
577

578
    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a constraint (or a single foreign key) as a constructor call.

        :raises TypeError: if the constraint is of an unsupported type
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}

        def add_fk_options(*opts: Any) -> None:
            args.extend(repr(opt) for opt in opts)
            for attr in ("ondelete", "onupdate", "deferrable", "initially", "match"):
                if value := getattr(constraint, attr, None):
                    kwargs[attr] = repr(value)

        if isinstance(constraint, ForeignKey):
            target = constraint.column
            add_fk_options(f"{target.table.fullname}.{target.name}")
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            add_fk_options(local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            expression = get_compiled_expression(constraint.sqltext, self.bind)
            args.append(repr(expression))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        # A ForeignKey is not a Constraint and thus never gets a name argument
        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)
613

614
    def should_ignore_table(self, table: Table) -> bool:
        """Return ``True`` for the Alembic and sqlalchemy-migrate version tables."""
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
        # tables
        return table.name in {"alembic_version", "migrate_version"}
618

619
    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Turn ``name`` into a valid identifier that clashes with neither the global
        nor the local names.

        Invalid characters are replaced with underscores, a leading digit gets an
        underscore prefix, and keywords (as well as ``metadata``) get an underscore
        suffix. On a clash, the suffixes ``_``, ``1``, ``2``, ... are tried in turn.
        """
        name = name.strip()
        assert name, "Identifier cannot be empty"

        base = _re_invalid_identifier.sub("_", name)
        if base[0].isdigit():
            base = "_" + base
        elif iskeyword(base) or base == "metadata":
            base += "_"

        def candidates() -> Iterable[str]:
            yield base
            yield base + "_"
            for suffix in count(1):
                yield f"{base}{suffix}"

        return next(
            candidate
            for candidate in candidates()
            if candidate not in global_names and candidate not in local_names
        )
641

642
    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types of the table in place.

        Check constraints of the form ``column IN (0, 1)`` turn the column into a
        Boolean, and ``column IN (...)`` on a string column turns it into a
        non-native Enum; the constraint is removed in both cases. Afterwards every
        column type is adapted via ``get_adapted_type``, and on PostgreSQL server
        defaults that decode to a sequence are replaced.
        """
        # Iterate over a copy as matching constraints are removed along the way
        for constraint in table.constraints.copy():
            if not isinstance(constraint, CheckConstraint):
                continue

            sqltext = get_compiled_expression(constraint.sqltext, self.bind)

            # Turn any integer-like column with a CheckConstraint like
            # "column IN (0, 1)" into a Boolean
            boolean_match = _re_boolean_check_constraint.match(sqltext)
            colname_match = boolean_match and _re_column_name.match(
                boolean_match.group(1)
            )
            if colname_match:
                table.constraints.remove(constraint)
                table.c[colname_match.group(3)].type = Boolean()
                continue

            # Turn any string-type column with a CheckConstraint like
            # "column IN (...)" into an Enum
            enum_match = _re_enum_check_constraint.match(sqltext)
            colname_match = enum_match and _re_column_name.match(enum_match.group(1))
            if not colname_match:
                continue

            target = table.c[colname_match.group(3)]
            if isinstance(target.type, String):
                table.constraints.remove(constraint)
                if not isinstance(target.type, Enum):
                    options = _re_enum_item.findall(enum_match.group(2))
                    target.type = Enum(*options, native_enum=False)

        for column in table.c:
            try:
                column.type = self.get_adapted_type(column.type)
            except CompileError:
                # Keep the reflected type if it cannot be compiled
                pass

            # PostgreSQL specific fix: detect sequences from server_default
            server_default = column.server_default
            if not (server_default and self.bind.dialect.name == "postgresql"):
                continue

            if not isinstance(server_default, DefaultClause):
                continue

            if not isinstance(server_default.arg, TextClause):
                continue

            schema, seqname = decode_postgresql_sequence(server_default.arg)
            if not seqname:
                continue

            # Add an explicit sequence unless it has the <table>_<column>_seq name
            if seqname != f"{column.table.name}_{column.name}_seq":
                column.default = sqlalchemy.Sequence(seqname, schema=schema)

            column.server_default = None
698

699
    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Adapt a column type towards a superclass type, as long as the adapted
        type compiles to the same string as the original on the bound dialect.

        The type's MRO is walked; the walk stops at the first adaptation that fails
        or compiles differently, or after the first successful adaptation to a
        class whose name is not all uppercase.

        :return: the adapted type, or the original if no adaptation was applicable
        """
        dialect = self.bind.engine.dialect
        compiled_type = coltype.compile(dialect)
        for supercls in coltype.__class__.__mro__:
            # Only consider public classes that represent an actual type
            if supercls.__name__.startswith("_") or not hasattr(
                supercls, "__visit_name__"
            ):
                continue

            # Don't try to adapt UserDefinedType as it's not a proper column type
            if supercls is UserDefinedType or issubclass(supercls, TypeDecorator):
                return coltype

            # Hack to fix adaptation of the Enum class which is broken since
            # SQLAlchemy 1.2
            restored_attrs = {}
            if supercls is Enum:
                restored_attrs["name"] = coltype.name
                if coltype.schema:
                    restored_attrs["schema"] = coltype.schema

            try:
                new_coltype = coltype.adapt(supercls)
            except TypeError:
                # If the adaptation fails, don't try again
                break

            for attr_name, attr_value in restored_attrs.items():
                setattr(new_coltype, attr_name, attr_value)

            if isinstance(coltype, ARRAY):
                new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

            try:
                renders_identically = new_coltype.compile(dialect) == compiled_type
            except CompileError:
                # If the adapted column type can't be compiled, don't substitute it
                break

            # If the adapted column type does not render the same as the original,
            # don't substitute it
            if not renders_identically:
                break

            # Stop on the first valid non-uppercase column type class
            coltype = new_coltype
            if supercls.__name__ != supercls.__name__.upper():
                break

        return coltype
744

745

746
class DeclarativeGenerator(TablesGenerator):
    """
    Generator that renders tables as declarative ORM classes where possible.

    Tables that have a primary key and are not association tables become mapped
    classes; the remaining tables are rendered as plain ``Table`` assignments.
    """

    # Options understood by this generator, in addition to the inherited ones
    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options.union(
        {"use_inflect", "nojoined", "nobidi"}
    )
752

753
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ):
        """
        Initialize the generator.

        :param metadata: metadata containing the tables to generate code for
        :param bind: connection or engine, passed on to the parent generator
        :param options: generator options, passed on to the parent generator
        :param indentation: string used for one level of indentation
        :param base_class_name: name of the generated declarative base class
        """
        super().__init__(metadata, bind, options, indentation=indentation)
        # Used for singularizing/pluralizing names when "use_inflect" is enabled
        self.inflect_engine = inflect.engine()
        self.base_class_name: str = base_class_name
765

766
    def generate_base(self) -> None:
        """Set ``self.base`` to an empty ``DeclarativeBase`` subclass declaration."""
        base_name = self.base_class_name
        class_header = f"class {base_name}(DeclarativeBase):"
        class_body = f"{self.indentation}pass"
        declarative_base_import = LiteralImport("sqlalchemy.orm", "DeclarativeBase")
        self.base = Base(
            literal_imports=[declarative_base_import],
            declarations=[class_header, class_body],
            metadata_ref=f"{base_name}.metadata",
        )
775

776
    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports for the given models.

        On top of what the parent generator collects, ``Mapped`` and
        ``mapped_column`` are imported as soon as at least one model is a class.

        :param models: the models to collect imports for
        """
        super().collect_imports(models)
        for candidate in models:
            if isinstance(candidate, ModelClass):
                break
        else:
            # Only plain tables: no ORM annotations will be rendered
            return

        for orm_name in ("Mapped", "mapped_column"):
            self.add_literal_import("sqlalchemy.orm", orm_name)
781

782
    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports needed by a single model.

        Model classes that have relationships additionally need ``relationship``
        from ``sqlalchemy.orm``.

        :param model: the model to collect imports for
        """
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")
787

788
    def generate_models(self) -> list[Model]:
        """
        Build the models for every table in the metadata.

        Tables with exactly two foreign key constraints, where every column takes
        part in a foreign key, are treated as association tables and kept as plain
        models. Of the remaining tables, those with a primary key become model
        classes (with column attributes) and the others plain models. Afterwards
        relationships are generated, joined-table inheritance is resolved (unless
        the ``nojoined`` option is set), imports are collected and all models and
        their attributes are named.

        :return: the generated models, in the order of ``metadata.sorted_tables``
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own mapping, don't
        # process them normally. They are keyed by the *qualified* name of the
        # table their first foreign key points to, so that same-named tables in
        # different schemas don't pick up each other's association tables.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns
            # are involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                owner_name = qualified_table_name(
                    fk_constraints[0].elements[0].column.table
                )
                links[owner_name].append(model)
                continue

            # Only form model classes for tables that have a primary key and are
            # not association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model,
                    models_by_table_name,
                    links[qualified_table_name(model.table)],
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in models_by_table_name.values():
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    # A foreign key spanning exactly the primary key columns is
                    # taken to mean joined-table inheritance
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())
865

866
    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """
        Generate the relationships that originate from the given model class.

        Every foreign key constraint of the source table that points to a model
        class yields a many-to-one or one-to-one relationship, unless it spans
        exactly the primary key and ``nojoined`` is not set, in which case it is
        recorded as joined-table inheritance instead. Every association table
        yields a many-to-many relationship. Unless ``nobidi`` is set, the opposite
        end of each relationship is added to the target class as well.

        The relationships are appended to ``source.relationships`` (and the
        reverse ones to the target's ``relationships``) as a side effect.

        :param source: the model class to generate relationships for
        :param models_by_table_name: all models, keyed by qualified table name
        :param association_tables: association table models linked to ``source``
        :return: every relationship attribute created by this call (both
            directions), in creation order
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            column_name_set = set(get_column_names(constraint))

            # A foreign key spanning exactly the primary key is rendered as
            # joined-table inheritance rather than as a relationship
            if "nojoined" not in self.options and column_name_set == pk_column_names:
                source.parent_class = target
                target.children.append(source)
                continue

            # Add uselist=False to One-to-One relationships
            if any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == column_name_set
                for c in constraint.table.constraints
            ):
                r_type = RelationshipType.ONE_TO_ONE
            else:
                r_type = RelationshipType.MANY_TO_ONE

            relationship = RelationshipAttribute(r_type, source, target, constraint)
            source.relationships.append(relationship)
            relationships.append(relationship)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                relationship.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # If the two tables share more than one foreign key constraint,
            # SQLAlchemy needs an explicit primaryjoin to figure out which
            # column(s) it needs
            common_fk_constraints = get_common_fk_constraints(
                source.table, target.table
            )
            if len(common_fk_constraints) > 1:
                relationship.foreign_keys = [
                    source.get_column_attribute(key) for key in constraint.column_keys
                ]

            # Generate the opposite end of the relationship in the target class
            if "nobidi" not in self.options:
                if r_type is RelationshipType.MANY_TO_ONE:
                    r_type = RelationshipType.ONE_TO_MANY

                reverse_relationship = RelationshipAttribute(
                    r_type,
                    target,
                    source,
                    constraint,
                    foreign_keys=relationship.foreign_keys,
                    backref=relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    reverse_relationship.remote_side = [
                        source.get_column_attribute(colname)
                        for colname in constraint.column_keys
                    ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            relationship = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            source.relationships.append(relationship)
            relationships.append(relationship)

            # Generate the opposite end of the relationship in the target class
            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

            # Add a primary/secondary join for self-referential many-to-many
            # relationships
            if source is target:
                both_relationships = [relationship]
                reverse_flags = [False, True]
                if reverse_relationship is not None:
                    both_relationships.append(reverse_relationship)

                # The reverse relationship uses the two constraints in the
                # opposite order, hence the flipped sort
                for rel, reverse in zip(both_relationships, reverse_flags):
                    if not rel.association_table or not rel.constraint:
                        continue

                    constraints = sorted(
                        rel.constraint.table.foreign_key_constraints,
                        key=get_constraint_sort_key,
                        reverse=reverse,
                    )
                    pri_pairs = zip(
                        get_column_names(constraints[0]), constraints[0].elements
                    )
                    sec_pairs = zip(
                        get_column_names(constraints[1]), constraints[1].elements
                    )
                    rel.primaryjoin = [
                        (rel.source, elem.column.name, rel.association_table, col)
                        for col, elem in pri_pairs
                    ]
                    rel.secondaryjoin = [
                        (rel.target, elem.column.name, rel.association_table, col)
                        for col, elem in sec_pairs
                    ]

        return relationships
1029

1030
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """
        Assign names to a model and, for model classes, to its attributes.

        The class name is the table name with invalid identifier characters
        replaced by underscores, split on underscores and joined back with the
        first letter of every part upper-cased. With the ``use_inflect`` option
        the singular form is used when the inflect engine finds one.

        :param model: the model to name
        :param global_names: names already taken at the module level
        """
        if not isinstance(model, ModelClass):
            super().generate_model_name(model, global_names)
            return

        sanitized = _re_invalid_identifier.sub("_", model.table.name)
        class_name = "".join(
            word[:1].upper() + word[1:] for word in sanitized.split("_")
        )
        if "use_inflect" in self.options:
            # singular_noun() gives a falsy value when no singular form is found
            class_name = self.inflect_engine.singular_noun(class_name) or class_name

        model.name = self.find_free_name(class_name, global_names)

        # Attribute names must be unique within the class, so track the ones
        # handed out so far: columns first, then relationships
        taken_names: set[str] = set()
        for column_attr in model.columns:
            self.generate_column_attr_name(column_attr, global_names, taken_names)
            taken_names.add(column_attr.name)

        for relationship in model.relationships:
            self.generate_relationship_name(relationship, global_names, taken_names)
            taken_names.add(relationship.name)
1055

1056
    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Name a column attribute after its column, avoiding taken names.

        :param column_attr: the column attribute to name
        :param global_names: names already taken at the module level
        :param local_names: names already taken within the class
        """
        column_name = column_attr.column.name
        column_attr.name = self.find_free_name(column_name, global_names, local_names)
1065

1066
    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Pick a name for a relationship attribute and store it on the attribute.

        The reverse side of a self-referential relationship is named after its
        already-named counterpart with a ``_reverse`` suffix. Otherwise the name
        defaults to the target table name, or to the foreign key column name
        without its ``_id`` suffix when the constraint has a single such column.
        With the ``use_inflect`` option the name is pluralized for "to-many"
        relationships and singularized for the others.

        :param relationship: the relationship attribute to name
        :param global_names: names already taken at the module level
        :param local_names: names already taken within the class
        """
        # Self referential reverse relationships
        preferred_name: str
        if (
            relationship.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        ):
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id",
            # use the preceding part as the relationship name
            if relationship.constraint:
                is_source = relationship.source.table is relationship.constraint.table
                if is_source or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    column_names = [c.name for c in relationship.constraint.columns]
                    if len(column_names) == 1 and column_names[0].endswith("_id"):
                        # A column named exactly "_id" would leave an empty
                        # name; keep the table name in that case
                        stem = column_names[0][:-3]
                        if stem:
                            preferred_name = stem

            if "use_inflect" in self.options:
                inflected_name: str | Literal[False]
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    # Only pluralize names that aren't plural already
                    if not self.inflect_engine.singular_noun(preferred_name):
                        preferred_name = self.inflect_engine.plural_noun(preferred_name)
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)
                    if inflected_name:
                        preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )
1113

1114
    def render_models(self, models: list[Model]) -> str:
        """
        Render all models, separated by two blank lines.

        Model classes are rendered as class definitions, other models as
        ``name = <table>`` assignments.

        :param models: the models to render
        :return: the rendered source code
        """
        chunks = [
            self.render_class(model)
            if isinstance(model, ModelClass)
            else f"{model.name} = {self.render_table(model.table)}"
            for model in models
        ]
        return "\n\n\n".join(chunks)
1123

1124
    def render_class(self, model: ModelClass) -> str:
        """
        Render a model class definition.

        The class body consists of up to three sections separated by a blank
        line: class variables, column attributes (non-nullable columns before
        nullable ones) and relationship attributes.

        :param model: the model class to render
        :return: the rendered class definition
        """
        # Class variables / special declarations
        class_vars: str = self.render_class_variables(model)

        # Column attributes: one pass for non-nullable columns, then one for
        # nullable columns
        column_lines: list[str] = []
        for wanted_nullability in (False, True):
            column_lines.extend(
                self.render_column_attribute(column_attr)
                for column_attr in model.columns
                if column_attr.column.nullable is wanted_nullability
            )

        # Relationship attributes
        relationship_lines: list[str] = []
        for relationship in model.relationships:
            relationship_lines.append(self.render_relationship(relationship))

        sections: list[str] = []
        if class_vars:
            sections.append(class_vars)
        if column_lines:
            sections.append("\n".join(column_lines))
        if relationship_lines:
            sections.append("\n".join(relationship_lines))

        header = self.render_class_declaration(model)
        body = "\n\n".join(indent(section, self.indentation) for section in sections)
        return f"{header}\n{body}"
1158

1159
    def render_class_declaration(self, model: ModelClass) -> str:
        """
        Render the ``class`` line of a model class.

        The class derives from its parent model class if it has one, otherwise
        from the generated base class.

        :param model: the model class to render the declaration for
        :return: the ``class Name(Parent):`` line
        """
        parent = model.parent_class
        superclass_name = parent.name if parent else self.base_class_name
        return f"class {model.name}({superclass_name}):"
1164

1165
    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render the special class variables of a model class.

        ``__tablename__`` is always rendered; ``__table_args__`` only when
        there are table arguments to render.

        :param model: the model class to render the variables for
        :return: the rendered variable assignments, one per line
        """
        tablename_line = f"__tablename__ = {model.table.name!r}"

        # Constraints and indexes go into __table_args__
        table_args = self.render_table_args(model.table)
        if not table_args:
            return tablename_line

        return f"{tablename_line}\n__table_args__ = {table_args}"
1174

1175
    def render_table_args(self, table: Table) -> str:
        """
        Render the value of ``__table_args__`` for the given table.

        The positional arguments are the table's constraints followed by its
        indexes; the keyword arguments (``schema`` and ``comment``) are rendered
        as a dict. A dict alone is returned when there are no positional
        arguments, and an empty string when there is nothing to render.

        :param table: the table to render the arguments for
        :return: the rendered dict or tuple, or an empty string
        """

        def is_omitted(constraint: Constraint) -> bool:
            # Default-named primary keys and single-column foreign key / unique
            # constraints are left out of __table_args__
            if not uses_default_name(constraint):
                return False
            if isinstance(constraint, PrimaryKeyConstraint):
                return True
            return (
                isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                and len(constraint.columns) == 1
            )

        # Constraints first...
        positional: list[str] = [
            self.render_constraint(constraint)
            for constraint in sorted(table.constraints, key=get_constraint_sort_key)
            if not is_omitted(constraint)
        ]

        # ...then multi-column or explicitly named indexes
        positional.extend(
            self.render_index(index)
            for index in sorted(table.indexes, key=lambda ix: cast(str, ix.name))
            if len(index.columns) > 1 or not uses_default_name(index)
        )

        keyword_args: dict[str, str] = {
            key: value
            for key, value in (("schema", table.schema), ("comment", table.comment))
            if value
        }
        if keyword_args:
            rendered_kwargs = pformat(keyword_args)
            if not positional:
                return rendered_kwargs

            positional.append(rendered_kwargs)

        if not positional:
            return ""

        joined = f",\n{self.indentation}".join(positional)
        # A one-element tuple needs its trailing comma
        trailing_comma = "," if len(positional) == 1 else ""
        return f"(\n{self.indentation}{joined}{trailing_comma}\n)"
1218

1219
    def render_column_python_type(self, column: Column[Any]) -> str:
        """
        Render the Python type annotation for a column.

        Nullable columns are wrapped in ``Optional[...]`` and ``ARRAY`` columns
        in one ``list[...]`` per array dimension. Types from the ``builtins``
        module are rendered by bare name, other types as ``module.name`` (with
        the module imported). Types that don't provide a Python type are
        rendered as ``Any``.

        :param column: the column to render the annotation for
        :return: the rendered type annotation
        """

        def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
            # Returns the annotation prefix, the innermost column type and the
            # closing brackets that match the prefix
            column_type = column.type
            pre: list[str] = []
            post_size = 0
            if column.nullable:
                self.add_literal_import("typing", "Optional")
                pre.append("Optional[")
                post_size += 1

            if isinstance(column_type, ARRAY):
                # Arrays without an explicit dimension count are treated as
                # one-dimensional
                dim = getattr(column_type, "dimensions", None) or 1
                pre.extend("list[" for _ in range(dim))
                post_size += dim

                column_type = column_type.item_type

            return "".join(pre), column_type, "]" * post_size

        def render_python_type(column_type: TypeEngine[Any]) -> str:
            try:
                # TypeEngine.python_type raises NotImplementedError for types
                # that have no known Python type, so the lookup itself has to
                # be inside the try block for the fallback to work
                if isinstance(column_type, DOMAIN):
                    python_type = column_type.data_type.python_type
                else:
                    python_type = column_type.python_type

                python_type_name = python_type.__name__
                python_type_module = python_type.__module__
                if python_type_module == "builtins":
                    return python_type_name

                self.add_module_import(python_type_module)
                return f"{python_type_module}.{python_type_name}"
            except NotImplementedError:
                self.add_literal_import("typing", "Any")
                return "Any"

        pre, col_type, post = get_type_qualifiers()
        column_python_type = f"{pre}{render_python_type(col_type)}{post}"
        return column_python_type
1259

1260
    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column as an annotated class attribute.

        :param column_attr: the column attribute to render
        :return: a line of the form ``name: Mapped[type] = <column>``
        """
        column = column_attr.column
        # render_column() is told whether the attribute name differs from the
        # column name
        name_differs = column_attr.name != column.name
        column_expression = self.render_column(column, name_differs)
        annotation = self.render_column_python_type(column)
        return f"{column_attr.name}: Mapped[{annotation}] = {column_expression}"
1266

1267
    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship as an annotated ``relationship()`` class attribute.

        :param relationship: the relationship attribute to render
        :return: a line of the form ``name: Mapped[type] = relationship(...)``
        """

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Attributes of the relationship's own class are referenced by bare
            # name, attributes of other classes as a quoted "Class.attr" string
            rendered = [
                attr.name
                if attr.model is relationship.source
                else repr(f"{attr.model.name}.{attr.name}")
                for attr in column_attrs
            ]
            return "[" + ", ".join(rendered) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            rendered = []
            render_as_string = False
            # Assume that column_attrs are all in relationship.source or none
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            joined = "[" + ", ".join(rendered) + "]"
            # References to other classes are rendered as one quoted expression
            return f"'{joined}'" if render_as_string else joined

        def render_join(terms: list[JoinType]) -> str:
            conditions = []
            for source, source_col, target, target_col in terms:
                # Plain (non-class) models are tables, whose columns live in ".c"
                accessor = "c." if target.__class__ is Model else ""
                conditions.append(
                    f"{source.name}.{source_col} == "
                    f"{target.name}.{accessor}{target_col!s}"
                )

            if len(conditions) > 1:
                # Several conditions must be combined inside ONE deferred
                # callable; and_() cannot take lambdas as its arguments
                self.add_literal_import("sqlalchemy", "and_")
                return f"lambda: and_({', '.join(conditions)})"
            else:
                return f"lambda: {conditions[0]}"

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
            if relationship.constraint.referred_table is relationship.source.table:
                kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        # Work out the annotation: a list for "to-many" relationships, the bare
        # (possibly optional) target class for "to-one" relationships
        relationship_type: str
        if relationship.type == RelationshipType.ONE_TO_MANY:
            relationship_type = f"list['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            relationship_type = f"'{relationship.target.name}'"
            if relationship.constraint and any(
                col.nullable for col in relationship.constraint.columns
            ):
                self.add_literal_import("typing", "Optional")
                relationship_type = f"Optional[{relationship_type}]"
        elif relationship.type == RelationshipType.MANY_TO_MANY:
            relationship_type = f"list['{relationship.target.name}']"
        else:
            self.add_literal_import("typing", "Any")
            relationship_type = "Any"

        return (
            f"{relationship.name}: Mapped[{relationship_type}] "
            f"= {rendered_relationship}"
        )
1366

1367

1368
class DataclassGenerator(DeclarativeGenerator):
    """
    Declarative generator whose base class also derives from
    ``MappedAsDataclass``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """
        Initialize the generator.

        :param metadata: metadata containing the tables to generate code for
        :param bind: connection or engine, passed on to the parent generator
        :param options: generator options, passed on to the parent generator
        :param indentation: string used for one level of indentation
        :param base_class_name: name of the generated declarative base class
        :param quote_annotations: stored as ``self.quote_annotations``
        :param metadata_key: stored as ``self.metadata_key``
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )
        self.quote_annotations: bool = quote_annotations
        self.metadata_key: str = metadata_key

    def generate_base(self) -> None:
        """Set ``self.base`` to a dataclass-enabled declarative base declaration."""
        base_name = self.base_class_name
        orm_names = ("DeclarativeBase", "MappedAsDataclass")
        self.base = Base(
            literal_imports=[
                LiteralImport("sqlalchemy.orm", orm_name) for orm_name in orm_names
            ],
            declarations=[
                f"class {base_name}(MappedAsDataclass, DeclarativeBase):",
                f"{self.indentation}pass",
            ],
            metadata_ref=f"{base_name}.metadata",
        )
1402

1403

1404
class SQLModelGenerator(DeclarativeGenerator):
    """
    Generator that renders model classes as ``table=True`` subclasses of a base
    class (``SQLModel`` by default).

    Column attributes are rendered as ``Field(...)`` assignments wrapping a
    ``Column(...)`` call, and relationships as ``Relationship(...)`` assignments.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        """Forward every argument to the parent generator's constructor."""
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    @property
    def views_supported(self) -> bool:
        """Always ``False`` for this generator."""
        return False

    def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> str:
        """
        Render a ``Column(...)`` call and register the ``Column`` import.

        ``is_table`` is accepted for signature compatibility but is not used here.
        """
        self.add_import(Column)
        rendered_call = render_callable("Column", *args, kwargs=kwargs)
        return rendered_call

    def generate_base(self) -> None:
        """Set ``self.base`` to a ``Base`` with no imports, declarations or metadata ref."""
        self.base = Base(literal_imports=[], declarations=[], metadata_ref="")

    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect imports for all models, then adjust them when classes are present.

        The ``super(DeclarativeGenerator, self)`` call deliberately starts the
        method lookup *after* ``DeclarativeGenerator`` in the MRO, bypassing any
        override defined on ``DeclarativeGenerator`` itself.
        """
        super(DeclarativeGenerator, self).collect_imports(models)

        has_model_class = any(isinstance(model, ModelClass) for model in models)
        if not has_model_class:
            return

        self.remove_literal_import("sqlalchemy", "MetaData")
        for sqlmodel_name in ("SQLModel", "Field"):
            self.add_literal_import("sqlmodel", sqlmodel_name)

    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports needed by a single model.

        For model classes, ``Optional`` is imported when at least one column is
        nullable and ``Relationship`` when the class has any relationships.
        """
        # Same MRO trick as in collect_imports(): skip DeclarativeGenerator.
        super(DeclarativeGenerator, self).collect_imports_for_model(model)

        if not isinstance(model, ModelClass):
            return

        if any(attribute.column.nullable for attribute in model.columns):
            self.add_literal_import("typing", "Optional")

        if model.relationships:
            self.add_literal_import("sqlmodel", "Relationship")

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module level variables.

        The base's table metadata declaration is emitted only when at least one
        model is not a ``ModelClass`` and the declaration is not ``None``;
        otherwise the result is an empty string.
        """
        lines: list[str] = []

        only_model_classes = all(isinstance(model, ModelClass) for model in models)
        if not only_model_classes:
            metadata_declaration = self.base.table_metadata_declaration
            if metadata_declaration is not None:
                lines.append(metadata_declaration)

        return "\n".join(lines)

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class Name(Parent, table=True):`` line for a model class."""
        # Fall back to the configured base class when the model has no parent.
        parent = model.parent_class.name if model.parent_class else self.base_class_name
        return f"class {model.name}({parent}, table=True):"

    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render ``__tablename__`` and ``__table_args__`` for a model class.

        ``__tablename__`` is emitted only when the table name differs from the
        lowercased class name; ``__table_args__`` only when there is something
        to render for the table.
        """
        table = model.table
        lines = []

        if table.name != model.name.lower():
            lines.append(f"__tablename__ = {table.name!r}")

        # Constraints and indexes end up in __table_args__
        rendered_table_args = self.render_table_args(table)
        if rendered_table_args:
            lines.append(f"__table_args__ = {rendered_table_args}")

        return "\n".join(lines)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column attribute as ``name: type = Field(...)``.

        The rendered column is passed as ``sa_column``; nullable columns
        additionally get ``default=None`` (placed before ``sa_column``).
        """
        column = column_attr.column
        column_source = self.render_column(column, True)
        python_type = self.render_column_python_type(column)

        field_kwargs: dict[str, Any] = {"default": None} if column.nullable else {}
        field_kwargs["sa_column"] = f"{column_source}"

        field_source = render_callable("Field", kwargs=field_kwargs)
        return f"{column_attr.name}: {python_type} = {field_source}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship as ``name: annotation = Relationship(...)``.

        The right-hand side of the parent's rendering is re-parsed by
        :meth:`render_relationship_args` to pick the supported arguments.
        One-to-many and many-to-many relationships are annotated as ``list[...]``,
        all other types as ``Optional[...]``.
        """
        declarative_source = super().render_relationship(relationship)
        _, _, call_source = declarative_source.partition(" = ")
        relationship_args = self.render_relationship_args(call_source)

        target = repr(relationship.target.name)
        collection_types = (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        )
        if relationship.type in collection_types:
            annotation = f"list[{target}]"
        else:
            self.add_literal_import("typing", "Optional")
            annotation = f"Optional[{target}]"

        field_source = render_callable("Relationship", *relationship_args, kwargs={})
        return f"{relationship.name}: {annotation} = {field_source}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """
        Extract the arguments supported by ``Relationship()`` from a rendered call.

        ``arguments`` is split on commas. Any piece mentioning ``back_populates``
        is kept as is, and a piece containing ``uselist=False`` is translated to
        ``sa_relationship_kwargs={'uselist': False}``.
        """
        *leading, final = arguments.split(",")
        # The final piece loses its last character (the closing ')'), and every
        # piece loses its first character (the space following each comma).
        trimmed = [piece[1:] for piece in (*leading, final[:-1])]

        rendered_args: list[str] = []
        for candidate in trimmed:
            if "back_populates" in candidate:
                rendered_args.append(candidate)
            if "uselist=False" in candidate:
                rendered_args.append("sa_relationship_kwargs={'uselist': False}")

        return rendered_args
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc