• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 7052416992

30 Nov 2023 09:20PM UTC coverage: 99.035% (+1.4%) from 97.636%
7052416992

push

github

agronholm
Updated GitHub actions

1129 of 1140 relevant lines covered (99.04%)

7.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.74
/src/sqlacodegen/generators.py
1
from __future__ import annotations
10✔
2

3
import inspect
10✔
4
import re
10✔
5
import sys
10✔
6
from abc import ABCMeta, abstractmethod
10✔
7
from collections import defaultdict
10✔
8
from collections.abc import Collection, Iterable, Sequence
10✔
9
from dataclasses import dataclass
10✔
10
from importlib import import_module
10✔
11
from inspect import Parameter
10✔
12
from itertools import count
10✔
13
from keyword import iskeyword
10✔
14
from pprint import pformat
10✔
15
from textwrap import indent
10✔
16
from typing import Any, ClassVar
10✔
17

18
import inflect
10✔
19
import sqlalchemy
10✔
20
from sqlalchemy import (
10✔
21
    ARRAY,
22
    Boolean,
23
    CheckConstraint,
24
    Column,
25
    Computed,
26
    Constraint,
27
    DefaultClause,
28
    Enum,
29
    Float,
30
    ForeignKey,
31
    ForeignKeyConstraint,
32
    Identity,
33
    Index,
34
    MetaData,
35
    PrimaryKeyConstraint,
36
    String,
37
    Table,
38
    Text,
39
    UniqueConstraint,
40
)
41
from sqlalchemy.dialects.postgresql import JSONB
10✔
42
from sqlalchemy.engine import Connection, Engine
10✔
43
from sqlalchemy.exc import CompileError
10✔
44
from sqlalchemy.sql.elements import TextClause
10✔
45

46
from .models import (
10✔
47
    ColumnAttribute,
48
    JoinType,
49
    Model,
50
    ModelClass,
51
    RelationshipAttribute,
52
    RelationshipType,
53
)
54
from .utils import (
10✔
55
    decode_postgresql_sequence,
56
    get_column_names,
57
    get_common_fk_constraints,
58
    get_compiled_expression,
59
    get_constraint_sort_key,
60
    qualified_table_name,
61
    render_callable,
62
    uses_default_name,
63
)
64

65
if sys.version_info >= (3, 10):
    from importlib.metadata import version
else:
    from importlib_metadata import version

# (major, minor) of the installed SQLAlchemy distribution, used to pick
# version-dependent output
_sqla_version = tuple(map(int, version("sqlalchemy").split(".")[:2]))

# Check constraint of the form "[qualifier.]column IN (0, 1)"
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
# Optionally qualified, optionally quoted column name; group 3 is the bare name
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
# Check constraint of the form "[qualifier.]column IN (...)"; group 2 is the list
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
# One single-quoted item of an IN (...) list (an escaped quote does not end it)
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
# Any non-word character; replaced with "_" when building identifiers
_re_invalid_identifier = re.compile(r"(?u)\W")
76

77

78
@dataclass
class LiteralImport:
    """One ``from <pkgname> import <name>`` statement needed by generated code."""

    # package or module imported from, e.g. "sqlalchemy"
    pkgname: str
    # name imported from that package, e.g. "MetaData"
    name: str
82

83

84
@dataclass
class Base:
    """
    How the generated module declares its metadata: a ``MetaData`` object when
    rendering Core tables, or a declarative base when rendering classes.
    """

    # imports required by ``declarations``
    literal_imports: list[LiteralImport]
    # source lines emitted at module level to declare the metadata / base
    declarations: list[str]
    # expression passed as the metadata argument of rendered ``Table(...)`` calls
    metadata_ref: str
    # NOTE(review): not read by the generators visible in this chunk — confirm use
    decorator: str | None = None
    # extra module-level declaration added when any model is not a ModelClass
    table_metadata_declaration: str | None = None
93

94

95
class CodeGenerator(metaclass=ABCMeta):
    """
    Abstract base class for code generators.

    Subclasses list the option names they accept in ``valid_options`` and
    implement :meth:`generate`.
    """

    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        :param metadata: the metadata to generate code for
        :param bind: connection or engine for the target database
        :param options: option names; each must be present in ``valid_options``
        :raises ValueError: if any option is not in ``valid_options``
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Validate options. Sort the offenders so the error message is
        # deterministic (set iteration order is not).
        invalid_options = sorted(self.options - self.valid_options)
        if invalid_options:
            raise ValueError("Unrecognized options: " + ", ".join(invalid_options))

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
116

117

118
@dataclass(eq=False)
class TablesGenerator(CodeGenerator):
    """
    Code generator that renders every table as a SQLAlchemy Core ``Table(...)``.

    While generating, the imports needed by the output are collected in
    ``self.imports`` (package name -> names imported from it) and
    ``self.module_imports`` (modules imported as a whole). They are rendered at
    the top of the generated module.
    """

    valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
    # Packages whose imports are grouped together with the standard library ones
    builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | {
        "dataclasses"
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        :param metadata: the metadata to generate code for (modified by
            :meth:`generate`)
        :param bind: connection or engine whose dialect is used to compile column
            types and SQL expressions
        :param options: option names; must be a subset of ``valid_options``
        :param indentation: indentation unit used in the generated code
        :raises ValueError: if an unrecognized option is passed
        """
        super().__init__(metadata, bind, options)
        self.indentation: str = indentation
        self.imports: dict[str, set[str]] = defaultdict(set)
        self.module_imports: set[str] = set()

    def generate_base(self) -> None:
        """Set ``self.base`` to a plain ``metadata = MetaData()`` declaration."""
        self.base = Base(
            literal_imports=[LiteralImport("sqlalchemy", "MetaData")],
            declarations=["metadata = MetaData()"],
            metadata_ref="metadata",
        )

    def generate(self) -> str:
        """
        Render the complete module source for ``self.metadata``.

        .. note:: Modifies the metadata. Ignored tables are removed, and indexes,
           constraints and comments are stripped according to the options.
        """
        self.generate_base()

        sections: list[str] = []

        # Remove unwanted elements from the metadata
        for table in list(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if "noindexes" in self.options:
                table.indexes.clear()

            if "noconstraints" in self.options:
                table.constraints.clear()

            if "nocomments" in self.options:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended column
        # types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        # Generate the models
        models: list[Model] = self.generate_models()

        # Render module level variables
        variables = self.render_module_variables(models)
        if variables:
            sections.append(variables + "\n")

        # Render models
        rendered_models = self.render_models(models)
        if rendered_models:
            sections.append(rendered_models)

        # Render collected imports (collected while generating the models, which
        # is why they are rendered last but inserted first)
        groups = self.group_imports()
        imports = "\n\n".join("\n".join(group) for group in groups)
        if imports:
            sections.insert(0, imports)

        return "\n\n".join(sections) + "\n"

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Record the imports needed by the base declaration and every model."""
        for literal_import in self.base.literal_imports:
            self.add_literal_import(literal_import.pkgname, literal_import.name)

        for model in models:
            self.collect_imports_for_model(model)

    def collect_imports_for_model(self, model: Model) -> None:
        """Record the imports needed by a model's columns, constraints and indexes."""
        # Only plain models (not subclasses such as ModelClass) render as Table()
        if model.__class__ is Model:
            self.add_import(Table)

        for column in model.table.c:
            self.collect_imports_for_column(column)

        for constraint in model.table.constraints:
            self.collect_imports_for_constraint(constraint)

        for index in model.table.indexes:
            self.collect_imports_for_constraint(index)

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Record the imports needed to render ``column``'s type and defaults."""
        self.add_import(column.type)

        if isinstance(column.type, ARRAY):
            self.add_import(column.type.item_type.__class__)
        elif isinstance(column.type, JSONB):
            # A plain Text() astext_type is treated as the default and not
            # rendered, so only a customized astext_type needs an import
            if (
                not isinstance(column.type.astext_type, Text)
                or column.type.astext_type.length is not None
            ):
                self.add_import(column.type.astext_type)

        if column.default:
            self.add_import(column.default)

        if column.server_default:
            if isinstance(column.server_default, (Computed, Identity)):
                self.add_import(column.server_default)
            elif isinstance(column.server_default, DefaultClause):
                # Rendered as server_default=text(...)
                self.add_literal_import("sqlalchemy", "text")

    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Record the imports needed to render a table-level constraint or index.

        Single-column constraints/indexes with default names are rendered inline
        on the column, so they need no import of their own (a foreign key still
        needs ``ForeignKey``).
        """
        if isinstance(constraint, Index):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "Index")
        elif isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
        elif isinstance(constraint, UniqueConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
        elif isinstance(constraint, ForeignKeyConstraint):
            if len(constraint.columns) > 1 or not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
        else:
            self.add_import(constraint)

    def add_import(self, obj: Any) -> None:
        """
        Record the import needed to refer to ``obj`` in the generated code.

        :param obj: a class, or an instance whose class is to be imported
        """
        # Don't store builtin imports
        if getattr(obj, "__module__", "builtins") == "builtins":
            return

        type_ = obj if isinstance(obj, type) else type(obj)
        pkgname = type_.__module__

        # The column types have already been adapted towards generic types if possible,
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
        # use that rather than the generic sqlalchemy type as it might have different
        # constructor parameters.
        if pkgname.startswith("sqlalchemy.dialects."):
            dialect_pkgname = ".".join(pkgname.split(".")[0:3])
            dialect_pkg = import_module(dialect_pkgname)

            if type_.__name__ in dialect_pkg.__all__:
                pkgname = dialect_pkgname
        elif type_.__name__ in dir(sqlalchemy):
            # Prefer the public top-level sqlalchemy namespace
            pkgname = "sqlalchemy"

        self.add_literal_import(pkgname, type_.__name__)

    def add_literal_import(self, pkgname: str, name: str) -> None:
        """Record ``from <pkgname> import <name>``."""
        names = self.imports.setdefault(pkgname, set())
        names.add(name)

    def remove_literal_import(self, pkgname: str, name: str) -> None:
        """Forget ``from <pkgname> import <name>`` if it was recorded."""
        # Use get() so that no empty entry is created for an unknown package
        names = self.imports.get(pkgname)
        if names is None:
            return

        names.discard(name)
        # Drop the package once nothing is imported from it, so group_imports()
        # does not render an empty "from <pkgname> import " line
        if not names:
            del self.imports[pkgname]

    def add_module_import(self, pgkname: str) -> None:
        """Record ``import <pgkname>``."""
        self.module_imports.add(pgkname)

    def group_imports(self) -> list[list[str]]:
        """
        Render the recorded imports as sorted lines.

        :return: non-empty groups in the order ``__future__``, standard library,
            third party. Module imports (``import x``) go into the third party
            group.
        """
        future_imports: list[str] = []
        stdlib_imports: list[str] = []
        thirdparty_imports: list[str] = []

        for package in sorted(self.imports):
            names = self.imports[package]
            if not names:
                # Nothing left to import from this package
                continue

            imports = ", ".join(sorted(names))
            collection = thirdparty_imports
            if package == "__future__":
                collection = future_imports
            elif package in self.builtin_module_names:
                collection = stdlib_imports
            elif package in sys.modules:
                # Imported modules outside site-packages are treated as stdlib.
                # Not every module has a __file__ attribute, or it may be None.
                module_file = getattr(sys.modules[package], "__file__", None) or ""
                if "site-packages" not in module_file:
                    collection = stdlib_imports

            collection.append(f"from {package} import {imports}")

        for module in sorted(self.module_imports):
            thirdparty_imports.append(f"import {module}")

        return [
            group
            for group in (future_imports, stdlib_imports, thirdparty_imports)
            if group
        ]

    def generate_models(self) -> list[Model]:
        """
        Wrap each table (in ``metadata.sorted_tables`` order) in a :class:`Model`.

        Collects their imports and gives every model a name that does not clash
        with an imported name or with a previously named model.
        """
        models = [Model(table) for table in self.metadata.sorted_tables]

        # Collect the imports
        self.collect_imports(models)

        # Generate names for models
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models:
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return models

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Name the model's table variable ``t_<table name>`` (made unique)."""
        preferred_name = f"t_{model.table.name}"
        model.name = self.find_free_name(preferred_name, global_names)

    def render_module_variables(self, models: list[Model]) -> str:
        """Return the module-level declarations joined by newlines."""
        # Copy so that self.base.declarations is not modified (repeated calls
        # would otherwise append the table metadata declaration again)
        declarations = list(self.base.declarations)

        if any(not isinstance(model, ModelClass) for model in models):
            if self.base.table_metadata_declaration is not None:
                declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_models(self, models: list[Model]) -> str:
        """Render each model as ``<name> = Table(...)``, separated by blank lines."""
        rendered: list[str] = []
        for model in models:
            rendered_table = self.render_table(model.table)
            rendered.append(f"{model.name} = {rendered_table}")

        return "\n\n".join(rendered)

    def render_table(self, table: Table) -> str:
        """
        Render ``table`` as a ``Table(name, metadata, *columns, *constraints,
        *indexes, schema=..., comment=...)`` call.

        Constraints and indexes that :meth:`render_column` renders inline are
        omitted here: default-named primary keys, and default-named single-column
        foreign key / unique constraints and indexes.
        """
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        kwargs: dict[str, object] = {}
        for column in table.columns:
            args.append(self.render_column(column, True, is_table=True))

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)):
                    if len(constraint.columns) == 1:
                        continue

            args.append(self.render_constraint(constraint))

        # An unnamed index sorts first instead of making the sort fail on None
        for index in sorted(table.indexes, key=lambda i: i.name or ""):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        # Honor the configured indentation instead of a hard-coded four spaces
        return render_callable(
            "Table", *args, kwargs=kwargs, indentation=self.indentation
        )

    def render_index(self, index: Index) -> str:
        """Render ``Index(name, *column_names[, unique=True])``."""
        extra_args = [repr(col.name) for col in index.columns]
        kwargs = {}
        if index.unique:
            kwargs["unique"] = True

        return render_callable("Index", repr(index.name), *extra_args, kwargs=kwargs)

    # TODO find better solution for is_table
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render ``column`` as ``Column(...)``. On SQLAlchemy >= 2.0, when
        ``is_table`` is false, it is rendered as ``mapped_column(...)`` instead.

        Default-named single-column foreign keys, unique constraints and indexes,
        and default-named primary keys, are rendered inline as arguments.

        .. note:: Sets ``column.unique`` / ``column.index`` to True when those
           are rendered inline.

        :param show_name: whether to render the column name as first argument
        :param is_table: whether the column is rendered inside ``Table(...)``
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            # repr() so that the key renders as a string literal
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if (
            not column.nullable
            and not is_sole_pk
            and (_sqla_version < (2, 0) or is_table)
        ):
            kwargs["nullable"] = False

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(column.server_default.arg.text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        if _sqla_version < (2, 0) or is_table:
            self.add_import(Column)
            return render_callable("Column", *args, kwargs=kwargs)
        else:
            return render_callable("mapped_column", *args, kwargs=kwargs)

    def render_column_type(self, coltype: object) -> str:
        """
        Render a column type instance as source code, e.g. ``String(50)``.

        Constructor arguments are recovered by introspecting the type class's
        ``__init__`` signature and reading the same-named attributes from
        ``coltype``. Arguments are rendered positionally until the first one that
        is missing or equal to its default. After that, non-default values are
        rendered as keyword arguments. A ``*args`` parameter is rendered from the
        attribute of the same name, if present.
        """
        args = []
        kwargs: dict[str, Any] = {}
        sig = inspect.signature(coltype.__class__.__init__)
        defaults = {param.name: param.default for param in sig.parameters.values()}
        missing = object()
        use_kwargs = False

        # A plain Text() astext_type on JSONB is treated like a default value, so
        # it is skipped whether it would land in args or kwargs. Deleting it from
        # kwargs afterwards raised KeyError when it had been rendered positionally.
        default_astext = (
            isinstance(coltype, JSONB)
            and isinstance(coltype.astext_type, Text)
            and coltype.astext_type.length is None
        )

        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                continue

            value = getattr(coltype, param.name, missing)
            default = defaults.get(param.name, missing)
            if (
                value is missing
                or value == default
                or (default_astext and param.name == "astext_type")
            ):
                use_kwargs = True
            elif use_kwargs:
                kwargs[param.name] = repr(value)
            else:
                args.append(repr(value))

        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(coltype, vararg):
            varargs_repr = [repr(arg) for arg in getattr(coltype, vararg)]
            args.extend(varargs_repr)

        if isinstance(coltype, Enum) and coltype.name is not None:
            kwargs["name"] = repr(coltype.name)

        if args or kwargs:
            return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
        else:
            return coltype.__class__.__name__

    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a constraint (or a column-level ``ForeignKey``) as a call.

        Foreign keys get their target column(s) as ``"table.column"`` strings,
        plus the ``ondelete``/``onupdate``/``deferrable``/``initially``/``match``
        options that are set. Non-default-named constraints get ``name=...``.

        :raises TypeError: for unsupported constraint types
        """

        def add_fk_options(*opts: Any) -> None:
            args.extend(repr(opt) for opt in opts)
            for attr in "ondelete", "onupdate", "deferrable", "initially", "match":
                value = getattr(constraint, attr, None)
                if value:
                    kwargs[attr] = repr(value)

        args: list[str] = []
        kwargs: dict[str, Any] = {}
        if isinstance(constraint, ForeignKey):
            remote_column = (
                f"{constraint.column.table.fullname}.{constraint.column.name}"
            )
            add_fk_options(remote_column)
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            add_fk_options(local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            args.append(repr(get_compiled_expression(constraint.sqltext, self.bind)))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            args.extend(repr(col.name) for col in constraint.columns)
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)

    def should_ignore_table(self, table: Table) -> bool:
        """Return True for tables that must not appear in the generated code."""
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
        # tables
        return table.name in ("alembic_version", "migrate_version")

    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Generate an attribute name that does not clash with other local or global names.

        ``name`` is stripped and made identifier-like: non-word characters become
        ``_``, a leading digit is prefixed with ``_``, and keywords or
        ``metadata`` get a trailing ``_``. If the result is taken, ``_``, then
        ``1``, ``2``, ... are appended until it is free.
        """
        name = name.strip()
        assert name, "Identifier cannot be empty"
        name = _re_invalid_identifier.sub("_", name)
        if name[0].isdigit():
            name = "_" + name
        elif iskeyword(name) or name == "metadata":
            name += "_"

        original = name
        for i in count():
            if name not in global_names and name not in local_names:
                break

            name = original + (str(i) if i else "_")

        return name

    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types.

        Check constraints that emulate booleans (``col IN (0, 1)``) or enums
        (``col IN ('a', ...)`` on a string column) are removed. The column type
        is replaced with :class:`Boolean` or :class:`Enum` accordingly.

        Every column type is then passed through :meth:`get_adapted_type`. On
        PostgreSQL, a server default referencing a sequence is removed. An
        explicit ``Sequence`` default is added when the sequence name differs
        from ``<table>_<column>_seq``.
        """
        # Detect check constraints for boolean and enum columns
        for constraint in table.constraints.copy():
            if isinstance(constraint, CheckConstraint):
                sqltext = get_compiled_expression(constraint.sqltext, self.bind)

                # Turn any integer-like column with a CheckConstraint like
                # "column IN (0, 1)" into a Boolean
                # NOTE(review): the column's current type is not checked here
                match = _re_boolean_check_constraint.match(sqltext)
                if match:
                    colname_match = _re_column_name.match(match.group(1))
                    # Skip constraints on expressions rather than on a column of
                    # this table (table.c[...] would raise KeyError)
                    if colname_match and colname_match.group(3) in table.c:
                        colname = colname_match.group(3)
                        table.constraints.remove(constraint)
                        table.c[colname].type = Boolean()
                        continue

                # Turn any string-type column with a CheckConstraint like
                # "column IN (...)" into an Enum
                match = _re_enum_check_constraint.match(sqltext)
                if match:
                    colname_match = _re_column_name.match(match.group(1))
                    if colname_match and colname_match.group(3) in table.c:
                        colname = colname_match.group(3)
                        items = match.group(2)
                        if isinstance(table.c[colname].type, String):
                            table.constraints.remove(constraint)
                            if not isinstance(table.c[colname].type, Enum):
                                options = _re_enum_item.findall(items)
                                table.c[colname].type = Enum(
                                    *options, native_enum=False
                                )

                            continue

        for column in table.c:
            try:
                column.type = self.get_adapted_type(column.type)
            except CompileError:
                # Best effort: keep the reflected type if it cannot be compiled
                pass

            # PostgreSQL specific fix: detect sequences from server_default
            if column.server_default and self.bind.dialect.name == "postgresql":
                if isinstance(column.server_default, DefaultClause) and isinstance(
                    column.server_default.arg, TextClause
                ):
                    schema, seqname = decode_postgresql_sequence(
                        column.server_default.arg
                    )
                    if seqname:
                        # Add an explicit sequence
                        if seqname != f"{column.table.name}_{column.name}_seq":
                            column.default = sqlalchemy.Sequence(seqname, schema=schema)

                        column.server_default = None

    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Return the most generic type that compiles like ``coltype``.

        Walks the MRO of ``coltype``'s class, adapting the type to each public
        class that has ``__visit_name__``. The walk stops after the first class
        whose name is not all uppercase. It also stops, keeping the last accepted
        type, when an adaptation fails, cannot be compiled, or compiles to
        different DDL than the original (Float and arrays of Float excepted).
        """
        compiled_type = coltype.compile(self.bind.engine.dialect)
        for supercls in coltype.__class__.__mro__:
            if not supercls.__name__.startswith("_") and hasattr(
                supercls, "__visit_name__"
            ):
                # Hack to fix adaptation of the Enum class which is broken since
                # SQLAlchemy 1.2
                kw = {}
                if supercls is Enum:
                    kw["name"] = coltype.name

                try:
                    new_coltype = coltype.adapt(supercls)
                except TypeError:
                    # If the adaptation fails, don't try again
                    break

                for key, value in kw.items():
                    setattr(new_coltype, key, value)

                if isinstance(coltype, ARRAY):
                    new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

                try:
                    # If the adapted column type does not render the same as the
                    # original, don't substitute it
                    if new_coltype.compile(self.bind.engine.dialect) != compiled_type:
                        # Make an exception to the rule for Float and arrays of Float,
                        # since at least on PostgreSQL, Float can accurately represent
                        # both REAL and DOUBLE_PRECISION
                        if not isinstance(new_coltype, Float) and not (
                            isinstance(new_coltype, ARRAY)
                            and isinstance(new_coltype.item_type, Float)
                        ):
                            break
                except CompileError:
                    # If the adapted column type can't be compiled, don't substitute it
                    break

                # Stop on the first valid non-uppercase column type class
                coltype = new_coltype
                if supercls.__name__ != supercls.__name__.upper():
                    break

        return coltype
714

715

716
class DeclarativeGenerator(TablesGenerator):
    """Generate SQLAlchemy declarative (ORM) model classes from reflected metadata.

    Tables that have a primary key become mapped classes; tables without one, and
    detected association ("link") tables, are rendered as plain ``Table`` objects.

    Options accepted on top of those of :class:`TablesGenerator`:

    * ``use_inflect``: singularize class names and pluralize/singularize
      relationship names with the ``inflect`` library
    * ``nojoined``: don't detect joined-table inheritance
    * ``nobidi``: don't generate the reverse side of relationships
    """

    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
        "use_inflect",
        "nojoined",
        "nobidi",
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ):
        super().__init__(metadata, bind, options, indentation=indentation)
        # Name of the generated declarative base class; model classes without a
        # joined-inheritance parent subclass it
        self.base_class_name: str = base_class_name
        self.inflect_engine = inflect.engine()

    def generate_base(self) -> None:
        """Set ``self.base`` to a declarative base suitable for the installed
        SQLAlchemy version."""
        if _sqla_version < (1, 4):
            # NOTE(review): this branch passes the base class name as metadata_ref
            # while the 1.4 branch passes "metadata", although both declare
            # "metadata = <Base>.metadata"; confirm which reference is intended
            table_decoration = f"metadata = {self.base_class_name}.metadata"
            self.base = Base(
                literal_imports=[
                    LiteralImport("sqlalchemy.ext.declarative", "declarative_base")
                ],
                declarations=[f"{self.base_class_name} = declarative_base()"],
                metadata_ref=self.base_class_name,
                table_metadata_declaration=table_decoration,
            )
        elif (1, 4) <= _sqla_version < (2, 0):
            table_decoration = f"metadata = {self.base_class_name}.metadata"
            self.base = Base(
                literal_imports=[LiteralImport("sqlalchemy.orm", "declarative_base")],
                declarations=[f"{self.base_class_name} = declarative_base()"],
                metadata_ref="metadata",
                table_metadata_declaration=table_decoration,
            )
        else:
            self.base = Base(
                literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")],
                declarations=[
                    f"class {self.base_class_name}(DeclarativeBase):",
                    f"{self.indentation}pass",
                ],
                metadata_ref=f"{self.base_class_name}.metadata",
            )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect the imports needed by the generated module.

        On SQLAlchemy 2.0+, ``Mapped`` and ``mapped_column`` are added when at
        least one mapped class is going to be rendered.
        """
        # Materialize once so that both the superclass and the check below can
        # iterate over the models, even when a one-shot iterator is passed in
        models = list(models)
        super().collect_imports(models)
        if _sqla_version >= (2, 0) and any(
            isinstance(model, ModelClass) for model in models
        ):
            self.add_literal_import("sqlalchemy.orm", "Mapped")
            self.add_literal_import("sqlalchemy.orm", "mapped_column")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect the imports for a single model, adding ``relationship`` when
        a mapped class has relationship attributes."""
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")

    def generate_models(self) -> list[Model]:
        """Build the models (mapped classes and plain tables) for every table.

        Association tables are set aside and turned into many-to-many
        relationships on the classes they link. Relationships, joined-table
        inheritance, imports and attribute names are filled in before returning.

        :return: the models, in the order of ``self.metadata.sorted_tables``
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't process
        # them normally. They are keyed by the (unqualified) name of the table that
        # their first foreign key constraint points to.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                tablename = fk_constraints[0].elements[0].column.table.name
                links[tablename].append(model)
                continue

            # Only form model classes for tables that have a primary key and are not
            # association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    column_attr = ColumnAttribute(model, column)
                    model.columns.append(column_attr)

        # Add relationships
        for model in models_by_table_name.values():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model, models_by_table_name, links[model.table.name]
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) == pk_column_names:
                        target = models_by_table_name[
                            qualified_table_name(constraint.elements[0].column.table)
                        ]
                        if isinstance(target, ModelClass):
                            model.parent_class = target
                            # generate_relationships() has usually registered this
                            # child already; compare by identity to avoid a
                            # duplicate entry
                            if not any(child is model for child in target.children):
                                target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())

    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """Create the relationships originating from ``source``.

        Many-to-one / one-to-one relationships come from ``source``'s own foreign
        key constraints, many-to-many relationships from ``association_tables``.
        Unless the ``nobidi`` option is set, the reverse side is added to the
        target class as well. Unless ``nojoined`` is set, a foreign key constraint
        spanning exactly ``source``'s primary key columns is treated as joined-table
        inheritance instead of a relationship.

        Each created relationship is appended to its owning model's
        ``relationships`` list.

        :param source: the mapped class to generate relationships for
        :param models_by_table_name: all models, keyed by qualified table name
        :param association_tables: link tables whose first foreign key constraint
            points to ``source``'s table (matched by unqualified table name)
        :return: every relationship attribute created by this call (both sides)
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            # A foreign key over exactly the primary key columns means joined-table
            # inheritance: record the parent class instead of a relationship
            if (
                "nojoined" not in self.options
                and set(get_column_names(constraint)) == pk_column_names
            ):
                source.parent_class = target
                if not any(child is source for child in target.children):
                    target.children.append(source)

                continue

            # The relationship is one-to-one when the foreign key columns exactly
            # match a primary key or unique constraint (the reverse side is then
            # rendered with uselist=False)
            column_names = get_column_names(constraint)
            if any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == set(column_names)
                for c in constraint.table.constraints
            ):
                r_type = RelationshipType.ONE_TO_ONE
            else:
                r_type = RelationshipType.MANY_TO_ONE

            relationship = RelationshipAttribute(r_type, source, target, constraint)
            source.relationships.append(relationship)
            relationships.append(relationship)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                relationship.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # If the two tables share more than one foreign key constraint,
            # SQLAlchemy can't tell on its own which one this relationship uses, so
            # the foreign key columns are spelled out explicitly
            common_fk_constraints = get_common_fk_constraints(
                source.table, target.table
            )
            if len(common_fk_constraints) > 1:
                relationship.foreign_keys = [
                    source.get_column_attribute(key)
                    for key in constraint.column_keys
                ]

            # Generate the opposite end of the relationship in the target class
            if "nobidi" not in self.options:
                if r_type is RelationshipType.MANY_TO_ONE:
                    r_type = RelationshipType.ONE_TO_MANY

                reverse_relationship = RelationshipAttribute(
                    r_type,
                    target,
                    source,
                    constraint,
                    foreign_keys=relationship.foreign_keys,
                    backref=relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    reverse_relationship.remote_side = [
                        source.get_column_attribute(colname)
                        for colname in constraint.column_keys
                    ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            relationship = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            source.relationships.append(relationship)
            relationships.append(relationship)

            # Generate the opposite end of the relationship in the target class
            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

            # Add a primary/secondary join for self-referential many-to-many
            # relationships; the reverse side uses the constraints in reverse order
            if source is target:
                both_relationships = [relationship]
                reverse_flags = [False, True]
                if reverse_relationship:
                    both_relationships.append(reverse_relationship)

                for rel, reverse in zip(both_relationships, reverse_flags):
                    if not rel.association_table or not rel.constraint:
                        continue

                    constraints = sorted(
                        rel.constraint.table.foreign_key_constraints,
                        key=get_constraint_sort_key,
                        reverse=reverse,
                    )
                    pri_pairs = zip(
                        get_column_names(constraints[0]), constraints[0].elements
                    )
                    sec_pairs = zip(
                        get_column_names(constraints[1]), constraints[1].elements
                    )
                    rel.primaryjoin = [
                        (rel.source, elem.column.name, rel.association_table, col)
                        for col, elem in pri_pairs
                    ]
                    rel.secondaryjoin = [
                        (rel.target, elem.column.name, rel.association_table, col)
                        for col, elem in sec_pairs
                    ]

        return relationships

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Assign a free name to ``model`` and, for mapped classes, to its
        column and relationship attributes.

        Class names are derived from the table name by replacing characters
        matched by ``_re_invalid_identifier`` with underscores and capitalizing
        each underscore-separated part (optionally singularized with
        ``use_inflect``). Other models are named by the superclass.
        """
        if isinstance(model, ModelClass):
            preferred_name = _re_invalid_identifier.sub("_", model.table.name)
            preferred_name = "".join(
                part[:1].upper() + part[1:] for part in preferred_name.split("_")
            )
            if "use_inflect" in self.options:
                singular_name = self.inflect_engine.singular_noun(preferred_name)
                if singular_name:
                    preferred_name = singular_name

            model.name = self.find_free_name(preferred_name, global_names)

            # Fill in the names for column attributes
            local_names: set[str] = set()
            for column_attr in model.columns:
                self.generate_column_attr_name(column_attr, global_names, local_names)
                local_names.add(column_attr.name)

            # Fill in the names for relationship attributes
            for relationship in model.relationships:
                self.generate_relationship_name(relationship, global_names, local_names)
                local_names.add(relationship.name)
        else:
            super().generate_model_name(model, global_names)

    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Name a column attribute after its column, avoiding taken names."""
        column_attr.name = self.find_free_name(
            column_attr.column.name, global_names, local_names
        )

    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Pick a free attribute name for ``relationship``.

        Self-referential reverse relationships are named ``<backref>_reverse``.
        Otherwise the target table name is used, or, for a single-column
        constraint ending in ``_id``, the column name without that suffix. With
        ``use_inflect``, collection relationships are pluralized and the others
        singularized.
        """
        preferred_name: str
        # Self referential reverse relationships
        if (
            relationship.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        ):
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id", use the
            # preceding part as the relationship name (unless nothing precedes it)
            if relationship.constraint:
                is_source = relationship.source.table is relationship.constraint.table
                if is_source or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    column_names = [c.name for c in relationship.constraint.columns]
                    if (
                        len(column_names) == 1
                        and column_names[0].endswith("_id")
                        and len(column_names[0]) > len("_id")
                    ):
                        preferred_name = column_names[0][:-3]

            if "use_inflect" in self.options:
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    inflected_name = self.inflect_engine.plural_noun(preferred_name)
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)

                if inflected_name:
                    preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )

    def render_models(self, models: list[Model]) -> str:
        """Render all models, separated by two blank lines: mapped classes as
        classes, the rest as ``<name> = Table(...)`` assignments."""
        rendered: list[str] = []
        for model in models:
            if isinstance(model, ModelClass):
                rendered.append(self.render_class(model))
            else:
                rendered.append(f"{model.name} = {self.render_table(model.table)}")

        return "\n\n\n".join(rendered)

    def render_class(self, model: ModelClass) -> str:
        """Render a mapped class.

        The body consists of the class variables, then the non-nullable column
        attributes followed by the nullable ones, then the relationships; the
        sections are separated by blank lines.
        """
        sections: list[str] = []

        # Render class variables / special declarations
        class_vars: str = self.render_class_variables(model)
        if class_vars:
            sections.append(class_vars)

        # Render column attributes (non-nullable first)
        rendered_column_attributes: list[str] = []
        for nullable in (False, True):
            for column_attr in model.columns:
                if column_attr.column.nullable is nullable:
                    rendered_column_attributes.append(
                        self.render_column_attribute(column_attr)
                    )

        if rendered_column_attributes:
            sections.append("\n".join(rendered_column_attributes))

        # Render relationship attributes
        rendered_relationship_attributes: list[str] = [
            self.render_relationship(relationship)
            for relationship in model.relationships
        ]

        if rendered_relationship_attributes:
            sections.append("\n".join(rendered_relationship_attributes))

        declaration = self.render_class_declaration(model)
        rendered_sections = "\n\n".join(
            indent(section, self.indentation) for section in sections
        )
        return f"{declaration}\n{rendered_sections}"

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class`` line; the superclass is the joined-inheritance
        parent when there is one, the declarative base otherwise."""
        parent_class_name = (
            model.parent_class.name if model.parent_class else self.base_class_name
        )
        return f"class {model.name}({parent_class_name}):"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` and, when needed, ``__table_args__``."""
        variables = [f"__tablename__ = {model.table.name!r}"]

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_table_args(self, table: Table) -> str:
        """Render the value of ``__table_args__`` for ``table``.

        Default-named primary key constraints and default-named single-column
        foreign key / unique constraints are skipped (presumably because they are
        expressed on the columns themselves), as are single-column indexes with
        default names. ``schema`` and ``comment`` become a trailing keyword dict.

        :return: a tuple or dict literal, or ``""`` when there is nothing to render
        """
        args: list[str] = []
        kwargs: dict[str, str] = {}

        # Render constraints
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                if (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                ):
                    continue

            args.append(self.render_constraint(constraint))

        # Render indexes (an index name may be None; sort those first instead of
        # failing on a None/str comparison)
        for index in sorted(table.indexes, key=lambda i: i.name or ""):
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = table.schema

        if table.comment:
            kwargs["comment"] = table.comment

        if kwargs:
            formatted_kwargs = pformat(kwargs)
            if not args:
                return formatted_kwargs
            else:
                args.append(formatted_kwargs)

        if args:
            rendered_args = f",\n{self.indentation}".join(args)
            # A one-element tuple needs a trailing comma
            if len(args) == 1:
                rendered_args += ","

            return f"(\n{self.indentation}{rendered_args}\n)"
        else:
            return ""

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render a column attribute assignment.

        On SQLAlchemy 2.0+ the attribute is annotated as ``Mapped[<type>]``, where
        the type comes from ``column.type.python_type`` (``Any`` when the column
        type doesn't implement it), wrapped in ``Optional`` for nullable columns.
        """
        column = column_attr.column
        rendered_column = self.render_column(column, column_attr.name != column.name)

        if _sqla_version < (2, 0):
            return f"{column_attr.name} = {rendered_column}"
        else:
            try:
                python_type = column.type.python_type
                python_type_name = python_type.__name__
                if python_type.__module__ == "builtins":
                    column_python_type = python_type_name
                else:
                    python_type_module = python_type.__module__
                    column_python_type = f"{python_type_module}.{python_type_name}"
                    self.add_module_import(python_type_module)
            except NotImplementedError:
                self.add_literal_import("typing", "Any")
                column_python_type = "Any"

            if column.nullable:
                self.add_literal_import("typing", "Optional")
                column_python_type = f"Optional[{column_python_type}]"
            return (
                f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"
            )

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render a ``relationship()`` attribute.

        On SQLAlchemy 2.0+ the attribute is annotated with ``Mapped[...]``: a
        ``List`` of the target class for one-to-many and many-to-many
        relationships, the (quoted) target class otherwise.
        """

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Attributes of the source class are referenced by bare name; those of
            # other classes as quoted "Class.attr" strings
            rendered = []
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(repr(f"{attr.model.name}.{attr.name}"))

            return "[" + ", ".join(rendered) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            # If any attribute belongs to another class, the whole list is rendered
            # as one quoted string instead of a list literal
            rendered = []
            render_as_string = False
            # Assume that column_attrs are all in relationship.source or none
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            if render_as_string:
                return "'[" + ", ".join(rendered) + "]'"
            else:
                return "[" + ", ".join(rendered) + "]"

        def render_join(terms: list[JoinType]) -> str:
            # Each term compares an attribute of a model with a column of the
            # other side; plain tables (Model, not ModelClass) go through ".c"
            comparisons = []
            for source, source_col, target, target_col in terms:
                rendered = f"{source.name}.{source_col} == {target.name}."
                if target.__class__ is Model:
                    rendered += "c."

                rendered += str(target_col)
                comparisons.append(rendered)

            # Wrap the whole (possibly composite) condition in a single lambda:
            # and_() combines SQL expressions, it cannot take lambdas
            if len(comparisons) > 1:
                self.add_literal_import("sqlalchemy", "and_")
                return f"lambda: and_({', '.join(comparisons)})"
            else:
                return f"lambda: {comparisons[0]}"

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if relationship.type is RelationshipType.ONE_TO_ONE and relationship.constraint:
            if relationship.constraint.referred_table is relationship.source.table:
                kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        if _sqla_version < (2, 0):
            return f"{relationship.name} = {rendered_relationship}"

        relationship_type: str
        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            self.add_literal_import("typing", "List")
            relationship_type = f"List['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            relationship_type = f"'{relationship.target.name}'"
        else:
            self.add_literal_import("typing", "Any")
            relationship_type = "Any"

        return (
            f"{relationship.name}: Mapped[{relationship_type}] "
            f"= {rendered_relationship}"
        )
1334

1335

1336
class DataclassGenerator(DeclarativeGenerator):
    """
    Render mapped classes as dataclasses.

    With SQLAlchemy 2.0 or later, the models inherit from a generated base class
    that combines ``MappedAsDataclass`` and ``DeclarativeBase``, and rendering is
    otherwise delegated to :class:`DeclarativeGenerator`. With older SQLAlchemy
    versions, each class is decorated with ``@mapper_registry.mapped`` and
    ``@dataclass``, and every column and relationship is rendered as a
    ``field()`` whose ``metadata`` dictionary holds the SQLAlchemy construct
    under ``metadata_key``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """
        :param metadata: metadata holding the tables to generate code for
        :param bind: connection or engine, passed through to the superclass
        :param options: generator options, passed through to the superclass
        :param indentation: indentation string used in the generated code
        :param base_class_name: name of the generated declarative base class
        :param quote_annotations: (SQLAlchemy < 2.0) if ``True``, relationship
            annotations are rendered as quoted strings; if ``False``,
            ``from __future__ import annotations`` is added to the output instead
        :param metadata_key: (SQLAlchemy < 2.0) value of
            ``__sa_dataclass_metadata_key__`` and the key used in each field's
            ``metadata`` dictionary
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )
        self.metadata_key: str = metadata_key
        self.quote_annotations: bool = quote_annotations

    def generate_base(self) -> None:
        """Set up the base construct that the generated models rely on."""
        if _sqla_version < (2, 0):
            # Classes are mapped through a registry decorator instead of a base
            self.base = Base(
                literal_imports=[LiteralImport("sqlalchemy.orm", "registry")],
                declarations=["mapper_registry = registry()"],
                metadata_ref="metadata",
                decorator="@mapper_registry.mapped",
            )
        else:
            self.base = Base(
                literal_imports=[
                    LiteralImport("sqlalchemy.orm", "DeclarativeBase"),
                    LiteralImport("sqlalchemy.orm", "MappedAsDataclass"),
                ],
                declarations=[
                    (
                        f"class {self.base_class_name}(MappedAsDataclass, "
                        "DeclarativeBase):"
                    ),
                    f"{self.indentation}pass",
                ],
                metadata_ref=f"{self.base_class_name}.metadata",
            )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports needed by the generated module.

        :param models: the models that will be rendered
        """
        # Materialize once: the models are iterated both by the superclass and by
        # the ModelClass check below, and a one-shot iterator (e.g. a generator)
        # would otherwise be exhausted before the second pass.
        models = list(models)
        super().collect_imports(models)
        if _sqla_version < (2, 0):
            if not self.quote_annotations:
                self.add_literal_import("__future__", "annotations")

            if any(isinstance(model, ModelClass) for model in models):
                # Mapped classes use the registry + @dataclass approach instead
                # of a declarative_base() class
                self.remove_literal_import("sqlalchemy.orm", "declarative_base")
                self.add_literal_import("dataclasses", "dataclass")
                self.add_literal_import("dataclasses", "field")
                self.add_literal_import("sqlalchemy.orm", "registry")

    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the typing imports needed by one model's annotations.

        :param model: the model being inspected
        """
        super().collect_imports_for_model(model)
        if _sqla_version < (2, 0):
            if isinstance(model, ModelClass):
                # Nullable columns are annotated as Optional[...]
                for column_attr in model.columns:
                    if column_attr.column.nullable:
                        self.add_literal_import("typing", "Optional")
                        break

                # Collection-valued relationships are annotated as List[...]
                for relationship_attr in model.relationships:
                    if relationship_attr.type in (
                        RelationshipType.ONE_TO_MANY,
                        RelationshipType.MANY_TO_MANY,
                    ):
                        self.add_literal_import("typing", "List")

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """
        Collect imports for a column, including its Python type (SQLAlchemy < 2.0).

        :param column: the column being inspected
        """
        super().collect_imports_for_column(column)
        if _sqla_version < (2, 0):
            try:
                python_type = column.type.python_type
            except NotImplementedError:
                # The annotation falls back to "Any" in render_column_attribute
                pass
            else:
                self.add_import(python_type)

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module-level declarations.

        :param models: the models that will be rendered
        :return: the declarations, joined by newlines
        """
        if _sqla_version >= (2, 0):
            return super().render_module_variables(models)
        else:
            # Without mapped classes there is no need for a mapper registry
            if not any(isinstance(model, ModelClass) for model in models):
                return super().render_module_variables(models)

            declarations: list[str] = ["mapper_registry = registry()"]
            # Plain tables need a module-level metadata; take it from the registry
            if any(not isinstance(model, ModelClass) for model in models):
                declarations.append("metadata = mapper_registry.metadata")

            if not self.quote_annotations:
                self.add_literal_import("__future__", "annotations")

            return "\n".join(declarations)

    def render_class_declaration(self, model: ModelClass) -> str:
        """
        Render the decorators and ``class`` line of a mapped class.

        :param model: the class model
        :return: the rendered declaration
        """
        if _sqla_version >= (2, 0):
            return super().render_class_declaration(model)
        else:
            superclass_part = (
                f"({model.parent_class.name})" if model.parent_class else ""
            )
            return (
                f"@mapper_registry.mapped\n@dataclass"
                f"\nclass {model.name}{superclass_part}:"
            )

    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render the class-level variables of a mapped class.

        :param model: the class model
        :return: the rendered variables, joined by newlines
        """
        if _sqla_version >= (2, 0):
            return super().render_class_variables(model)
        else:
            variables = [
                super().render_class_variables(model),
                # Declares which field-metadata key holds the SQLAlchemy constructs
                f"__sa_dataclass_metadata_key__ = {self.metadata_key!r}",
            ]
            return "\n".join(variables)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render one column as an annotated dataclass field.

        :param column_attr: the column attribute to render
        :return: a line such as ``name: str = field(metadata={...})``
        """
        if _sqla_version >= (2, 0):
            return super().render_column_attribute(column_attr)
        else:
            column = column_attr.column
            try:
                python_type = column.type.python_type
            except NotImplementedError:
                python_type_name = "Any"
            else:
                python_type_name = python_type.__name__

            kwargs: dict[str, Any] = {}
            if column.autoincrement and column.name in column.table.primary_key:
                # Autoincrementing primary keys are left out of __init__
                kwargs["init"] = False
            elif column.nullable:
                self.add_literal_import("typing", "Optional")
                kwargs["default"] = None
                python_type_name = f"Optional[{python_type_name}]"

            rendered_column = self.render_column(
                column, column_attr.name != column.name
            )
            kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered_column}}}"
            rendered_field = render_callable("field", kwargs=kwargs)
            return f"{column_attr.name}: {python_type_name} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render one relationship as an annotated dataclass field.

        :param relationship: the relationship attribute to render
        :return: a line such as ``children: List[Child] = field(...)``
        """
        if _sqla_version >= (2, 0):
            return super().render_relationship(relationship)
        else:
            # Keep only the "relationship(...)" call from "name = relationship(...)"
            rendered = super().render_relationship(relationship).partition(" = ")[2]
            kwargs: dict[str, Any] = {}

            annotation = relationship.target.name
            if self.quote_annotations:
                annotation = repr(relationship.target.name)

            if relationship.type in (
                RelationshipType.ONE_TO_MANY,
                RelationshipType.MANY_TO_MANY,
            ):
                self.add_literal_import("typing", "List")
                annotation = f"List[{annotation}]"
                kwargs["default_factory"] = "list"
            else:
                self.add_literal_import("typing", "Optional")
                kwargs["default"] = "None"
                annotation = f"Optional[{annotation}]"

            kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered}}}"
            rendered_field = render_callable("field", kwargs=kwargs)
            return f"{relationship.name}: {annotation} = {rendered_field}"
1511

1512

1513
class SQLModelGenerator(DeclarativeGenerator):
    """
    Render mapped classes as SQLModel table models.

    Each mapped class subclasses ``SQLModel`` (or its parent model) with
    ``table=True``. Columns are rendered as ``Field(sa_column=...)`` and
    relationships as ``Relationship(...)``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        """
        :param metadata: metadata holding the tables to generate code for
        :param bind: connection or engine, passed through to the superclass
        :param options: generator options, passed through to the superclass
        :param indentation: indentation string used in the generated code
        :param base_class_name: class that top-level models inherit from
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    def generate_base(self) -> None:
        """Set up an empty base; models subclass ``SQLModel`` directly."""
        self.base = Base(
            literal_imports=[],
            declarations=[],
            metadata_ref="",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports needed by the generated module.

        :param models: the models that will be rendered
        """
        # Materialize once: the models are iterated both by the superclass and by
        # the ModelClass check below, and a one-shot iterator (e.g. a generator)
        # would otherwise be exhausted before the second pass.
        models = list(models)
        # Deliberately skip DeclarativeGenerator's implementation and call the
        # next class in the MRO instead
        super(DeclarativeGenerator, self).collect_imports(models)
        if any(isinstance(model, ModelClass) for model in models):
            self.remove_literal_import("sqlalchemy", "MetaData")
            self.add_literal_import("sqlmodel", "SQLModel")
            self.add_literal_import("sqlmodel", "Field")

    def collect_imports_for_model(self, model: Model) -> None:
        """
        Collect the imports needed by one model's annotations and relationships.

        :param model: the model being inspected
        """
        # Deliberately skip DeclarativeGenerator's implementation
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
        if isinstance(model, ModelClass):
            # Nullable columns are annotated as Optional[...]
            for column_attr in model.columns:
                if column_attr.column.nullable:
                    self.add_literal_import("typing", "Optional")
                    break

            if model.relationships:
                self.add_literal_import("sqlmodel", "Relationship")

            # Collection-valued relationships are annotated as List[...]
            for relationship_attr in model.relationships:
                if relationship_attr.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    self.add_literal_import("typing", "List")

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """
        Collect imports for a column, including its Python annotation type.

        :param column: the column being inspected
        """
        super().collect_imports_for_column(column)
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            # The annotation falls back to "Any" in render_column_attribute
            self.add_literal_import("typing", "Any")
        else:
            self.add_import(python_type)

    def render_module_variables(self, models: list[Model]) -> str:
        """
        Render the module-level declarations.

        A table metadata declaration is only needed when some models are plain
        tables rather than mapped classes.

        :param models: the models that will be rendered
        :return: the declarations, joined by newlines (possibly empty)
        """
        declarations: list[str] = []
        if any(not isinstance(model, ModelClass) for model in models):
            if self.base.table_metadata_declaration is not None:
                declarations.append(self.base.table_metadata_declaration)

        return "\n".join(declarations)

    def render_class_declaration(self, model: ModelClass) -> str:
        """
        Render the ``class`` line of a table model.

        :param model: the class model
        :return: e.g. ``class Foo(SQLModel, table=True):``
        """
        if model.parent_class:
            parent = model.parent_class.name
        else:
            parent = self.base_class_name

        superclass_part = f"({parent}, table=True)"
        return f"class {model.name}{superclass_part}:"

    def render_class_variables(self, model: ModelClass) -> str:
        """
        Render ``__tablename__`` and ``__table_args__`` where needed.

        :param model: the class model
        :return: the rendered variables, joined by newlines (possibly empty)
        """
        variables = []

        # __tablename__ is only emitted when the table name differs from the
        # lowercased class name (presumably SQLModel's default; verify)
        if model.table.name != model.name.lower():
            variables.append(f"__tablename__ = {model.table.name!r}")

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render one column as an annotated ``Field(sa_column=...)``.

        :param column_attr: the column attribute to render
        :return: a line such as ``id: Optional[int] = Field(...)``
        """
        column = column_attr.column
        try:
            python_type = column.type.python_type
        except NotImplementedError:
            python_type_name = "Any"
        else:
            python_type_name = python_type.__name__

        kwargs: dict[str, Any] = {}
        # Autoincrementing primary keys and nullable columns may be omitted on
        # construction, so they are typed Optional and default to None
        if (
            column.autoincrement and column.name in column.table.primary_key
        ) or column.nullable:
            self.add_literal_import("typing", "Optional")
            kwargs["default"] = None
            python_type_name = f"Optional[{python_type_name}]"

        rendered_column = self.render_column(column, True)
        kwargs["sa_column"] = rendered_column
        rendered_field = render_callable("Field", kwargs=kwargs)
        return f"{column_attr.name}: {python_type_name} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render one relationship as an annotated ``Relationship(...)``.

        :param relationship: the relationship attribute to render
        :return: a line such as ``children: List['Child'] = Relationship(...)``
        """
        # Keep only the text after the first " = " (the relationship(...) call)
        rendered = super().render_relationship(relationship).partition(" = ")[2]
        args = self.render_relationship_args(rendered)
        kwargs: dict[str, Any] = {}
        annotation = repr(relationship.target.name)

        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            self.add_literal_import("typing", "List")
            annotation = f"List[{annotation}]"
        else:
            self.add_literal_import("typing", "Optional")
            annotation = f"Optional[{annotation}]"

        rendered_field = render_callable("Relationship", *args, kwargs=kwargs)
        return f"{relationship.name}: {annotation} = {rendered_field}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """
        Translate a rendered ``relationship(...)`` call into ``Relationship()``
        arguments.

        ``back_populates=...`` is carried over verbatim, and ``uselist=False`` is
        converted into ``sa_relationship_kwargs={'uselist': False}``. All other
        arguments are dropped.

        NOTE(review): dropped arguments include things like ``secondary`` or
        ``primaryjoin`` if present, which may matter for ambiguous relationships.

        :param arguments: a rendered call such as
            ``relationship('Parent', back_populates='children')``
        :return: the translated arguments, in their original order
        """
        # Strip the callable name and the enclosing parentheses
        _, _, inner = arguments.partition("(")
        if inner.endswith(")"):
            inner = inner[:-1]

        rendered_args: list[str] = []
        for arg in self._split_top_level_args(inner):
            if arg.startswith("back_populates="):
                rendered_args.append(arg)
            elif arg == "uselist=False":
                rendered_args.append("sa_relationship_kwargs={'uselist': False}")

        return rendered_args

    @staticmethod
    def _split_top_level_args(arguments: str) -> list[str]:
        """
        Split a rendered argument list on commas that are not nested.

        Commas inside quoted strings or inside ``()``/``[]``/``{}`` are kept, so
        that arguments like ``foreign_keys=[a, b]`` stay intact.

        :param arguments: the text between the outer parentheses of a call
        :return: the stripped, non-empty arguments
        """
        parts: list[str] = []
        depth = 0
        quote = ""
        escaped = False
        start = 0
        for index, char in enumerate(arguments):
            if quote:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == quote:
                    quote = ""
            elif char in "'\"":
                quote = char
            elif char in "([{":
                depth += 1
            elif char in ")]}":
                depth -= 1
            elif char == "," and depth == 0:
                parts.append(arguments[start:index].strip())
                start = index + 1

        parts.append(arguments[start:].strip())
        return [part for part in parts if part]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc