• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

agronholm / sqlacodegen / 6294698958

25 Sep 2023 04:00AM UTC coverage: 97.479% (-0.2%) from 97.639%
6294698958

Pull #286

github

web-flow
Merge 3a8081352 into c68d2a8a0
Pull Request #286: Fixed omission of 'collation' keyword in existing CHAR type

1 of 4 new or added lines in 1 file covered. (25.0%)

36 existing lines in 1 file now uncovered.

1779 of 1825 relevant lines covered (97.48%)

9.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.93
/src/sqlacodegen/generators.py
1
from __future__ import annotations
12✔
2

3
import inspect
12✔
4
import re
12✔
5
import sys
12✔
6
from abc import ABCMeta, abstractmethod
12✔
7
from collections import defaultdict
12✔
8
from collections.abc import Collection, Iterable, Sequence
12✔
9
from dataclasses import dataclass
12✔
10
from importlib import import_module
12✔
11
from inspect import Parameter
12✔
12
from itertools import count
12✔
13
from keyword import iskeyword
12✔
14
from pprint import pformat
12✔
15
from textwrap import indent
12✔
16
from typing import Any, ClassVar
12✔
17

18
import inflect
12✔
19
import sqlalchemy
12✔
20
from sqlalchemy import (
12✔
21
    ARRAY,
22
    CHAR,
23
    Boolean,
24
    CheckConstraint,
25
    Column,
26
    Computed,
27
    Constraint,
28
    DefaultClause,
29
    Enum,
30
    Float,
31
    ForeignKey,
32
    ForeignKeyConstraint,
33
    Identity,
34
    Index,
35
    MetaData,
36
    PrimaryKeyConstraint,
37
    String,
38
    Table,
39
    Text,
40
    UniqueConstraint,
41
)
42
from sqlalchemy.dialects.postgresql import JSONB
12✔
43
from sqlalchemy.engine import Connection, Engine
12✔
44
from sqlalchemy.exc import CompileError
12✔
45
from sqlalchemy.sql.elements import TextClause
12✔
46

47
from .models import (
12✔
48
    ColumnAttribute,
49
    JoinType,
50
    Model,
51
    ModelClass,
52
    RelationshipAttribute,
53
    RelationshipType,
54
)
55
from .utils import (
12✔
56
    decode_postgresql_sequence,
57
    get_column_names,
58
    get_common_fk_constraints,
59
    get_compiled_expression,
60
    get_constraint_sort_key,
61
    qualified_table_name,
62
    render_callable,
63
    uses_default_name,
64
)
65

66
if sys.version_info < (3, 10):
12✔
67
    from importlib_metadata import version
6✔
68
else:
69
    from importlib.metadata import version
6✔
70

71
_sqla_version = tuple(int(x) for x in version("sqlalchemy").split(".")[:2])
12✔
72
_re_boolean_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \(0, 1\)")
12✔
73
_re_column_name = re.compile(r'(?:(["`]?).*\1\.)?(["`]?)(.*)\2')
12✔
74
_re_enum_check_constraint = re.compile(r"(?:.*?\.)?(.*?) IN \((.+)\)")
12✔
75
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
12✔
76
_re_invalid_identifier = re.compile(r"(?u)\W")
12✔
77

78

79
@dataclass
class LiteralImport:
    """A single name to import, spelled out as a package/name pair of strings."""

    # Package (or module) the name is imported from
    pkgname: str
    # The imported name itself
    name: str
12✔
83

84

85
@dataclass
class Base:
    """Representation of MetaData for Tables, respectively Base for classes"""

    # Names the generator registers as imports before collecting model imports
    literal_imports: list[LiteralImport]
    # Source lines joined with newlines to form the module level variables
    declarations: list[str]
    # Expression rendered as the metadata argument of each ``Table(...)`` call
    metadata_ref: str
    decorator: str | None = None
    # Extra declaration added to the module variables when at least one model is
    # not a ModelClass
    table_metadata_declaration: str | None = None
12✔
94

95

96
class CodeGenerator(metaclass=ABCMeta):
    """
    Abstract base class for code generators.

    Subclasses list the option names they accept in ``valid_options`` and implement
    :meth:`generate`.
    """

    valid_options: ClassVar[set[str]] = set()

    def __init__(
        self, metadata: MetaData, bind: Connection | Engine, options: Sequence[str]
    ):
        """
        Store the constructor arguments and reject unknown options.

        :param metadata: the metadata to generate code from
        :param bind: the connection or engine the generator works against
        :param options: option names; each must be present in ``valid_options``
        :raises ValueError: if any of the given options is not recognized
        """
        self.metadata: MetaData = metadata
        self.bind: Connection | Engine = bind
        self.options: set[str] = set(options)

        # Anything left after removing the known options is unrecognized
        unknown = self.options.difference(self.valid_options)
        if unknown:
            raise ValueError("Unrecognized options: " + ", ".join(unknown))

    @abstractmethod
    def generate(self) -> str:
        """
        Generate the code for the given metadata.
        .. note:: May modify the metadata.
        """
117

118

119
@dataclass(eq=False)
12✔
120
class TablesGenerator(CodeGenerator):
12✔
121
    valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
12✔
122
    builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | {
12✔
123
        "dataclasses"
124
    }
125

126
    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
    ):
        """
        Initialize the generator with empty import registries.

        :param metadata: passed on to the superclass initializer
        :param bind: passed on to the superclass initializer
        :param options: passed on to the superclass initializer
        :param indentation: stored as ``self.indentation``
        """
        super().__init__(metadata, bind, options)
        # ``import <module>`` style imports
        self.module_imports: set[str] = set()
        # ``from <package> import <names>`` style imports, keyed by package
        self.imports: dict[str, set[str]] = defaultdict(set)
        self.indentation: str = indentation
12✔
138

139
    def generate_base(self) -> None:
        """Set ``self.base`` to a plain ``MetaData()`` based declaration."""
        metadata_import = LiteralImport("sqlalchemy", "MetaData")
        self.base = Base(
            literal_imports=[metadata_import],
            declarations=["metadata = MetaData()"],
            metadata_ref="metadata",
        )
145

146
    def generate(self) -> str:
        """
        Generate the complete module source for ``self.metadata``.

        .. note:: Modifies the metadata: ignored tables are removed, and indexes,
            constraints and comments are stripped according to the options.

        :return: the import, variable and model sections separated by blank lines,
            followed by a trailing newline
        """
        self.generate_base()

        strip_indexes = "noindexes" in self.options
        strip_constraints = "noconstraints" in self.options
        strip_comments = "nocomments" in self.options

        # Remove unwanted elements from the metadata. A snapshot is iterated because
        # tables may be removed along the way.
        for table in tuple(self.metadata.tables.values()):
            if self.should_ignore_table(table):
                self.metadata.remove(table)
                continue

            if strip_indexes:
                table.indexes.clear()

            if strip_constraints:
                table.constraints.clear()

            if strip_comments:
                table.comment = None
                for column in table.columns:
                    column.comment = None

        # Use information from column constraints to figure out the intended column
        # types
        for table in self.metadata.tables.values():
            self.fix_column_types(table)

        models: list[Model] = self.generate_models()

        body: list[str] = []
        variables = self.render_module_variables(models)
        if variables:
            body.append(f"{variables}\n")

        rendered_models = self.render_models(models)
        if rendered_models:
            body.append(rendered_models)

        # The imports are grouped only after rendering (which can register further
        # imports) but they go in front of everything else
        import_block = "\n\n".join("\n".join(group) for group in self.group_imports())
        sections = [import_block, *body] if import_block else body
        return "\n\n".join(sections) + "\n"
12✔
195

196
    def collect_imports(self, models: Iterable[Model]) -> None:
12✔
197
        for literal_import in self.base.literal_imports:
12✔
198
            self.add_literal_import(literal_import.pkgname, literal_import.name)
12✔
199

200
        for model in models:
12✔
201
            self.collect_imports_for_model(model)
12✔
202

203
    def collect_imports_for_model(self, model: Model) -> None:
        """Register the imports needed by a model's columns, constraints and indexes."""
        # Only exact Model instances (not subclasses) trigger the Table import
        if type(model) is Model:
            self.add_import(Table)

        table = model.table
        for column in table.c:
            self.collect_imports_for_column(column)

        # Constraints first, then indexes; both go through the same collector
        for constraint_or_index in [*table.constraints, *table.indexes]:
            self.collect_imports_for_constraint(constraint_or_index)
12✔
215

216
    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Register the imports needed by a column's type and its defaults."""
        coltype = column.type
        self.add_import(coltype)

        if isinstance(coltype, ARRAY):
            self.add_import(coltype.item_type.__class__)
        elif isinstance(coltype, JSONB):
            astext_type = coltype.astext_type
            # A Text without a length is not imported
            is_plain_text = isinstance(astext_type, Text) and astext_type.length is None
            if not is_plain_text:
                self.add_import(astext_type)

        if column.default:
            self.add_import(column.default)

        server_default = column.server_default
        if server_default:
            if isinstance(server_default, (Computed, Identity)):
                self.add_import(server_default)
            elif isinstance(server_default, DefaultClause):
                # render_column renders DefaultClause defaults as text(...)
                self.add_literal_import("sqlalchemy", "text")
12✔
236

237
    def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
        """
        Register the import needed to render the given constraint or index.

        Indexes, unique constraints and foreign key constraints are only imported when
        they span several columns or carry a non-default name; primary key constraints
        only when they carry a non-default name.
        """

        def is_explicit() -> bool:
            # Multi-column or explicitly named
            return len(constraint.columns) > 1 or not uses_default_name(constraint)

        if isinstance(constraint, Index):
            if is_explicit():
                self.add_literal_import("sqlalchemy", "Index")
            return

        if isinstance(constraint, PrimaryKeyConstraint):
            if not uses_default_name(constraint):
                self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
            return

        if isinstance(constraint, UniqueConstraint):
            if is_explicit():
                self.add_literal_import("sqlalchemy", "UniqueConstraint")
            return

        if isinstance(constraint, ForeignKeyConstraint):
            if is_explicit():
                self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
            else:
                self.add_import(ForeignKey)
            return

        # Any other kind of constraint is imported by its own class
        self.add_import(constraint)
12✔
254

255
    def add_import(self, obj: Any) -> None:
12✔
256
        # Don't store builtin imports
257
        if getattr(obj, "__module__", "builtins") == "builtins":
12✔
258
            return
6✔
259

260
        type_ = type(obj) if not isinstance(obj, type) else obj
12✔
261
        pkgname = type_.__module__
12✔
262

263
        # The column types have already been adapted towards generic types if possible,
264
        # so if this is still a vendor specific type (e.g., MySQL INTEGER) be sure to
265
        # use that rather than the generic sqlalchemy type as it might have different
266
        # constructor parameters.
267
        if pkgname.startswith("sqlalchemy.dialects."):
12✔
268
            dialect_pkgname = ".".join(pkgname.split(".")[0:3])
12✔
269
            dialect_pkg = import_module(dialect_pkgname)
12✔
270

271
            if type_.__name__ in dialect_pkg.__all__:
12✔
272
                pkgname = dialect_pkgname
12✔
273
        elif type_.__name__ in dir(sqlalchemy):
12✔
274
            pkgname = "sqlalchemy"
12✔
275
        else:
276
            pkgname = type_.__module__
12✔
277

278
        self.add_literal_import(pkgname, type_.__name__)
12✔
279

280
    def add_literal_import(self, pkgname: str, name: str) -> None:
12✔
281
        names = self.imports.setdefault(pkgname, set())
12✔
282
        names.add(name)
12✔
283

284
    def remove_literal_import(self, pkgname: str, name: str) -> None:
12✔
285
        names = self.imports.setdefault(pkgname, set())
6✔
286
        if name in names:
6✔
UNCOV
287
            names.remove(name)
×
288

289
    def add_module_import(self, pgkname: str) -> None:
12✔
290
        self.module_imports.add(pgkname)
6✔
291

292
    def group_imports(self) -> list[list[str]]:
12✔
293
        future_imports: list[str] = []
12✔
294
        stdlib_imports: list[str] = []
12✔
295
        thirdparty_imports: list[str] = []
12✔
296

297
        for package in sorted(self.imports):
12✔
298
            imports = ", ".join(sorted(self.imports[package]))
12✔
299
            collection = thirdparty_imports
12✔
300
            if package == "__future__":
12✔
301
                collection = future_imports
6✔
302
            elif package in self.builtin_module_names:
12✔
303
                collection = stdlib_imports
6✔
304
            elif package in sys.modules:
12✔
305
                if "site-packages" not in (sys.modules[package].__file__ or ""):
12✔
306
                    collection = stdlib_imports
12✔
307

308
            collection.append(f"from {package} import {imports}")
12✔
309

310
        for module in sorted(self.module_imports):
12✔
311
            thirdparty_imports.append(f"import {module}")
6✔
312

313
        return [
12✔
314
            group
315
            for group in (future_imports, stdlib_imports, thirdparty_imports)
316
            if group
317
        ]
318

319
    def generate_models(self) -> list[Model]:
        """
        Create a model for every table in the metadata and assign each a unique name.

        :return: the models, in the order of ``self.metadata.sorted_tables``
        """
        models = [Model(table) for table in self.metadata.sorted_tables]

        # The imports must be collected first since model names must not clash with
        # any imported name
        self.collect_imports(models)
        taken_names: set[str] = set().union(*self.imports.values())

        for model in models:
            self.generate_model_name(model, taken_names)
            taken_names.add(model.name)

        return models
12✔
334

335
    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """Name the model ``t_<table name>``, adjusted to avoid the given names."""
        model.name = self.find_free_name(f"t_{model.table.name}", global_names)
12✔
338

339
    def render_module_variables(self, models: list[Model]) -> str:
12✔
340
        declarations = self.base.declarations
12✔
341

342
        if any(not isinstance(model, ModelClass) for model in models):
12✔
343
            if self.base.table_metadata_declaration is not None:
12✔
344
                declarations.append(self.base.table_metadata_declaration)
6✔
345

346
        return "\n".join(declarations)
12✔
347

348
    def render_models(self, models: list[Model]) -> str:
12✔
349
        rendered: list[str] = []
12✔
350
        for model in models:
12✔
351
            rendered_table = self.render_table(model.table)
12✔
352
            rendered.append(f"{model.name} = {rendered_table}")
12✔
353

354
        return "\n\n".join(rendered)
12✔
355

356
    def render_table(self, table: Table) -> str:
        """
        Render the given table as a ``Table(...)`` call.

        Constraints and indexes that are expressed on the columns themselves (default
        named primary keys, and default named single-column foreign keys, unique
        constraints and indexes) are not rendered as separate arguments.
        """
        args: list[str] = [f"{table.name!r}, {self.base.metadata_ref}"]
        args.extend(
            self.render_column(column, True, is_table=True) for column in table.columns
        )

        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue

                single_column = isinstance(
                    constraint, (ForeignKeyConstraint, UniqueConstraint)
                ) and (len(constraint.columns) == 1)
                if single_column:
                    continue

            args.append(self.render_constraint(constraint))

        for index in sorted(table.indexes, key=lambda idx: idx.name):
            # One-column indexes should be rendered as index=True on columns
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        kwargs: dict[str, object] = {}
        if table.schema:
            kwargs["schema"] = repr(table.schema)

        table_comment = getattr(table, "comment", None)
        if table_comment:
            kwargs["comment"] = repr(table_comment)

        return render_callable("Table", *args, kwargs=kwargs, indentation="    ")
12✔
387

388
    def render_index(self, index: Index) -> str:
        """Render the given index as an ``Index(name, *column_names)`` call."""
        column_names = [repr(col.name) for col in index.columns]
        # unique is only rendered when it is set
        kwargs = {"unique": True} if index.unique else {}
        return render_callable("Index", repr(index.name), *column_names, kwargs=kwargs)
12✔
395

396
    # TODO find better solution for is_table
397
    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """
        Render the given column as a ``Column(...)`` or ``mapped_column(...)`` call.

        ``Column`` is rendered (and imported) on SQLAlchemy versions before 2.0 or when
        ``is_table`` is true; ``mapped_column`` is rendered otherwise.

        .. note:: Sets ``column.unique`` / ``column.index`` to ``True`` when a default
            named single-column unique constraint or index is folded into the column.

        :param column: the column to render
        :param show_name: ``True`` to render the column name as the first argument
        :param is_table: ``True`` if the column is rendered inside a ``Table(...)``
        :return: the rendered call expression
        """
        args = []
        kwargs: dict[str, Any] = {}
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        # Foreign keys that can be rendered inline as ForeignKey(...) on the column
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            # Quoted with repr() like the other string-valued keyword arguments
            # (comment, server_default); otherwise the key would be rendered as a
            # bare identifier
            kwargs["key"] = repr(column.key)
        if is_primary:
            kwargs["primary_key"] = True
        if (
            not column.nullable
            and not is_sole_pk
            and (_sqla_version < (2, 0) or is_table)
        ):
            kwargs["nullable"] = False

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        if isinstance(column.server_default, DefaultClause):
            kwargs["server_default"] = render_callable(
                "text", repr(column.server_default.arg.text)
            )
        elif isinstance(column.server_default, Computed):
            expression = str(column.server_default.sqltext)

            computed_kwargs = {}
            if column.server_default.persisted is not None:
                computed_kwargs["persisted"] = column.server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(column.server_default, Identity):
            args.append(repr(column.server_default))
        elif column.server_default:
            kwargs["server_default"] = repr(column.server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        if _sqla_version < (2, 0) or is_table:
            self.add_import(Column)
            return render_callable("Column", *args, kwargs=kwargs)
        else:
            return render_callable("mapped_column", *args, kwargs=kwargs)
6✔
496

497
    def render_column_type(self, coltype: object) -> str:
        """
        Render a column type instance as a constructor call, or as the bare class name
        if no arguments need to be passed.

        Constructor parameters are discovered from the signature of the type's
        ``__init__`` and read back from same-named attributes of ``coltype``.
        Parameters whose value is missing or equal to the default are left out; once
        one has been left out, the remaining ones are passed as keyword arguments.

        :param coltype: the column type instance to render
        :return: the rendered expression
        """
        sig = inspect.signature(coltype.__class__.__init__)
        missing = object()
        # Positional arguments keyed by parameter name, in call order. Keeping the
        # names makes it possible to drop or move a specific argument later on.
        positional: dict[str, str] = {}
        kwargs: dict[str, Any] = {}
        use_kwargs = False
        for param in list(sig.parameters.values())[1:]:
            # Remove annoyances like _warn_on_bytestring
            if param.name.startswith("_"):
                continue
            elif param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                continue

            value = getattr(coltype, param.name, missing)
            if value is missing or value == param.default:
                use_kwargs = True
            elif use_kwargs:
                kwargs[param.name] = repr(value)
            else:
                positional[param.name] = repr(value)

        def drop_positional(name: str) -> None:
            # Remove a positional argument by name. Arguments that followed it would
            # shift position, so they are turned into keyword arguments instead.
            if name not in positional:
                return

            names = list(positional)
            for later in names[names.index(name) + 1 :]:
                kwargs[later] = positional.pop(later)

            del positional[name]

        if isinstance(coltype, Enum) and coltype.name is not None:
            kwargs["name"] = repr(coltype.name)

        if isinstance(coltype, JSONB):
            # Remove astext_type if it's the default, regardless of whether it was
            # collected as a keyword or as a positional argument
            if (
                isinstance(coltype.astext_type, Text)
                and coltype.astext_type.length is None
            ):
                kwargs.pop("astext_type", None)
                drop_positional("astext_type")

        if isinstance(coltype, CHAR) and coltype.collation is not None:
            # Always pass the collation as a keyword argument. It is only removed
            # from the positional arguments if it is actually among them, so that an
            # unrelated argument (such as the length) is never dropped.
            drop_positional("collation")
            kwargs["collation"] = repr(coltype.collation)

        args = list(positional.values())
        vararg = next(
            (
                param.name
                for param in sig.parameters.values()
                if param.kind is Parameter.VAR_POSITIONAL
            ),
            None,
        )
        if vararg and hasattr(coltype, vararg):
            args.extend(repr(arg) for arg in getattr(coltype, vararg))

        if args or kwargs:
            return render_callable(coltype.__class__.__name__, *args, kwargs=kwargs)
        else:
            return coltype.__class__.__name__
12✔
555

556
    def render_constraint(self, constraint: Constraint | ForeignKey) -> str:
        """
        Render a foreign key, check, unique or primary key constraint as a constructor
        call.

        :param constraint: the constraint (or column level foreign key) to render
        :return: the rendered expression
        :raises TypeError: if the constraint is of any other type
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}

        def add_fk_options(*opts: Any) -> None:
            # The positional arguments come first, then every foreign key option
            # that has a truthy value
            args.extend(map(repr, opts))
            for attr in ("ondelete", "onupdate", "deferrable", "initially", "match"):
                value = getattr(constraint, attr, None)
                if value:
                    kwargs[attr] = repr(value)

        if isinstance(constraint, ForeignKey):
            target = constraint.column
            add_fk_options(f"{target.table.fullname}.{target.name}")
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = get_column_names(constraint)
            remote_columns = [
                f"{fk.column.table.fullname}.{fk.column.name}"
                for fk in constraint.elements
            ]
            add_fk_options(local_columns, remote_columns)
        elif isinstance(constraint, CheckConstraint):
            expression = get_compiled_expression(constraint.sqltext, self.bind)
            args.append(repr(expression))
        elif isinstance(constraint, (UniqueConstraint, PrimaryKeyConstraint)):
            for col in constraint.columns:
                args.append(repr(col.name))
        else:
            raise TypeError(
                f"Cannot render constraint of type {constraint.__class__.__name__}"
            )

        # Column level ForeignKey objects are not Constraint instances and never
        # get a name
        if isinstance(constraint, Constraint) and not uses_default_name(constraint):
            kwargs["name"] = repr(constraint.name)

        return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)
12✔
591

592
    def should_ignore_table(self, table: Table) -> bool:
12✔
593
        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
594
        # tables
595
        return table.name in ("alembic_version", "migrate_version")
12✔
596

597
    def find_free_name(
        self, name: str, global_names: set[str], local_names: Collection[str] = ()
    ) -> str:
        """
        Turn the given name into an identifier that clashes with neither the global nor
        the local names.

        Invalid characters are replaced with underscores, a leading digit is prefixed
        with an underscore, and keywords (as well as ``metadata``) get an underscore
        appended. On a clash, the suffixes ``_``, ``1``, ``2``, ... are tried in turn.
        """
        name = name.strip()
        assert name, "Identifier cannot be empty"

        base = _re_invalid_identifier.sub("_", name)
        if base[0].isdigit():
            base = f"_{base}"
        elif iskeyword(base) or base == "metadata":
            base = f"{base}_"

        candidate = base
        attempt = 0
        while candidate in global_names or candidate in local_names:
            # First retry appends "_", later retries append the attempt number
            candidate = base + (str(attempt) if attempt else "_")
            attempt += 1

        return candidate
12✔
619

620
    def fix_column_types(self, table: Table) -> None:
        """
        Adjust the reflected column types of the given table in place.

        Check constraints of the form ``column IN (0, 1)`` turn the column into a
        Boolean, and ``column IN (...)`` constraints on string columns turn it into a
        non-native Enum; the consumed constraints are removed from the table.
        Afterwards each column type is passed through :meth:`get_adapted_type`, and on
        PostgreSQL, sequence based server defaults are removed, or replaced with an
        explicit Sequence default if the sequence has a non-default name.
        """
        # Detect check constraints for boolean and enum columns. A copy is iterated
        # because constraints are removed from the table along the way.
        for constraint in table.constraints.copy():
            if not isinstance(constraint, CheckConstraint):
                continue

            sqltext = get_compiled_expression(constraint.sqltext, self.bind)

            # Turn any integer-like column with a CheckConstraint like
            # "column IN (0, 1)" into a Boolean
            bool_match = _re_boolean_check_constraint.match(sqltext)
            colname_match = bool_match and _re_column_name.match(bool_match.group(1))
            if colname_match:
                table.constraints.remove(constraint)
                table.c[colname_match.group(3)].type = Boolean()
                continue

            # Turn any string-type column with a CheckConstraint like
            # "column IN (...)" into an Enum
            enum_match = _re_enum_check_constraint.match(sqltext)
            colname_match = enum_match and _re_column_name.match(enum_match.group(1))
            if not colname_match:
                continue

            target = table.c[colname_match.group(3)]
            if isinstance(target.type, String):
                table.constraints.remove(constraint)
                if not isinstance(target.type, Enum):
                    options = _re_enum_item.findall(enum_match.group(2))
                    target.type = Enum(*options, native_enum=False)

        for column in table.c:
            try:
                column.type = self.get_adapted_type(column.type)
            except CompileError:
                # Keep the original type if it cannot be compiled
                pass

            # PostgreSQL specific fix: detect sequences from server_default
            if not (column.server_default and self.bind.dialect.name == "postgresql"):
                continue

            server_default = column.server_default
            if isinstance(server_default, DefaultClause) and isinstance(
                server_default.arg, TextClause
            ):
                schema, seqname = decode_postgresql_sequence(server_default.arg)
                if seqname:
                    # Add an explicit sequence
                    if seqname != f"{column.table.name}_{column.name}_seq":
                        column.default = sqlalchemy.Sequence(seqname, schema=schema)

                    column.server_default = None
12✔
676

677
    def get_adapted_type(self, coltype: Any) -> Any:
        """
        Adapt the given column type towards the classes in its MRO, as long as the
        adapted type compiles to the same string as the original on the bound dialect.

        :param coltype: the column type instance to adapt
        :return: the last successfully adapted type, or ``coltype`` itself if none of
            the candidates was accepted
        """
        dialect = self.bind.engine.dialect
        compiled_type = coltype.compile(dialect)
        for supercls in coltype.__class__.__mro__:
            # Only public classes that are visitable types are candidates
            if supercls.__name__.startswith("_") or not hasattr(
                supercls, "__visit_name__"
            ):
                continue

            # Hack to fix adaptation of the Enum class which is broken since
            # SQLAlchemy 1.2
            restore_name = supercls is Enum
            enum_name = coltype.name if restore_name else None

            try:
                new_coltype = coltype.adapt(supercls)
            except TypeError:
                # If the adaptation fails, don't try again
                break

            if restore_name:
                new_coltype.name = enum_name

            if isinstance(coltype, ARRAY):
                new_coltype.item_type = self.get_adapted_type(new_coltype.item_type)

            try:
                renders_same = new_coltype.compile(dialect) == compiled_type
            except CompileError:
                # If the adapted column type can't be compiled, don't substitute it
                break

            if not renders_same:
                # If the adapted column type does not render the same as the
                # original, don't substitute it. Make an exception to the rule for
                # Float and arrays of Float, since at least on PostgreSQL, Float can
                # accurately represent both REAL and DOUBLE_PRECISION
                is_float_like = isinstance(new_coltype, Float) or (
                    isinstance(new_coltype, ARRAY)
                    and isinstance(new_coltype.item_type, Float)
                )
                if not is_float_like:
                    break

            # Stop on the first valid non-uppercase column type class
            coltype = new_coltype
            if supercls.__name__ != supercls.__name__.upper():
                break

        return coltype
12✔
723

724

725
class DeclarativeGenerator(TablesGenerator):
    """
    Generator that renders tables having a primary key as declarative classes.

    Association tables and tables without a primary key are kept as plain
    ``Model`` objects and rendered through ``render_table()``.

    Options recognized in addition to those of ``TablesGenerator``:

    * ``use_inflect``: run class and relationship names through ``inflect``
    * ``nojoined``: don't turn a foreign key covering the whole primary key into
      class inheritance
    * ``nobidi``: don't generate the reverse side of relationships
    """

    valid_options: ClassVar[set[str]] = TablesGenerator.valid_options | {
        "use_inflect",
        "nojoined",
        "nobidi",
    }

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ):
        """
        :param metadata: passed through to ``TablesGenerator``
        :param bind: passed through to ``TablesGenerator``
        :param options: passed through to ``TablesGenerator``
        :param indentation: passed through to ``TablesGenerator``
        :param base_class_name: name used for the generated declarative base class
        """
        super().__init__(metadata, bind, options, indentation=indentation)
        self.base_class_name: str = base_class_name
        self.inflect_engine = inflect.engine()

    def generate_base(self) -> None:
        """Set ``self.base`` to the declarative base matching the SQLAlchemy version."""
        if _sqla_version >= (2, 0):
            # 2.0 and later: the base is a class deriving from DeclarativeBase
            self.base = Base(
                literal_imports=[LiteralImport("sqlalchemy.orm", "DeclarativeBase")],
                declarations=[
                    f"class {self.base_class_name}(DeclarativeBase):",
                    f"{self.indentation}pass",
                ],
                metadata_ref=f"{self.base_class_name}.metadata",
            )
            return

        # Before 2.0 the base comes from declarative_base(); only the module it is
        # imported from and the metadata reference differ between 1.3 and 1.4
        if _sqla_version < (1, 4):
            declarative_module = "sqlalchemy.ext.declarative"
            metadata_ref = self.base_class_name
        else:
            declarative_module = "sqlalchemy.orm"
            metadata_ref = "metadata"

        self.base = Base(
            literal_imports=[LiteralImport(declarative_module, "declarative_base")],
            declarations=[f"{self.base_class_name} = declarative_base()"],
            metadata_ref=metadata_ref,
            table_metadata_declaration=f"metadata = {self.base_class_name}.metadata",
        )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """
        Collect the imports of the given models.

        On SQLAlchemy 2.0+, ``Mapped`` and ``mapped_column`` are imported as well if
        at least one of the models is a class.
        """
        # The models are traversed twice (by the superclass and below), so a one-shot
        # iterator must be materialized first or the second pass would see nothing
        models = list(models)
        super().collect_imports(models)
        if _sqla_version >= (2, 0) and any(
            isinstance(model, ModelClass) for model in models
        ):
            self.add_literal_import("sqlalchemy.orm", "Mapped")
            self.add_literal_import("sqlalchemy.orm", "mapped_column")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect the imports of one model, adding ``relationship`` when it's used."""
        super().collect_imports_for_model(model)
        if isinstance(model, ModelClass) and model.relationships:
            self.add_literal_import("sqlalchemy.orm", "relationship")

    def generate_models(self) -> list[Model]:
        """
        Build the models for all the tables in the metadata.

        Association tables and tables lacking a primary key become ``Model``
        instances, everything else becomes a ``ModelClass`` with column attributes,
        relationships and (unless ``nojoined`` is set) a parent class. The imports
        are collected and the names of the models filled in before returning.

        :return: the generated models, in the order of ``metadata.sorted_tables``
        """
        models_by_table_name: dict[str, Model] = {}

        # Pick association tables from the metadata into their own set, don't process
        # them normally. They are grouped by the qualified name of the table their
        # first foreign key constraint refers to, so that same-named tables in
        # different schemas don't receive each other's association tables.
        links: defaultdict[str, list[Model]] = defaultdict(list)
        for table in self.metadata.sorted_tables:
            qualified_name = qualified_table_name(table)

            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = sorted(
                table.foreign_key_constraints, key=get_constraint_sort_key
            )
            if len(fk_constraints) == 2 and all(
                col.foreign_keys for col in table.columns
            ):
                model = models_by_table_name[qualified_name] = Model(table)
                referred_table = fk_constraints[0].elements[0].column.table
                links[qualified_table_name(referred_table)].append(model)
                continue

            # Only form model classes for tables that have a primary key and are not
            # association tables
            if not table.primary_key:
                models_by_table_name[qualified_name] = Model(table)
            else:
                model = ModelClass(table)
                models_by_table_name[qualified_name] = model

                # Fill in the columns
                for column in table.c:
                    model.columns.append(ColumnAttribute(model, column))

        # Add relationships
        for qualified_name, model in models_by_table_name.items():
            if isinstance(model, ModelClass):
                self.generate_relationships(
                    model, models_by_table_name, links[qualified_name]
                )

        # Nest inherited classes in their superclasses to ensure proper ordering
        if "nojoined" not in self.options:
            for model in list(models_by_table_name.values()):
                if not isinstance(model, ModelClass):
                    continue

                pk_column_names = {col.name for col in model.table.primary_key.columns}
                for constraint in model.table.foreign_key_constraints:
                    if set(get_column_names(constraint)) != pk_column_names:
                        continue

                    target = models_by_table_name[
                        qualified_table_name(constraint.elements[0].column.table)
                    ]
                    if isinstance(target, ModelClass):
                        model.parent_class = target
                        # generate_relationships() may have registered the child
                        # already; compare by identity to avoid listing it twice
                        if not any(child is model for child in target.children):
                            target.children.append(model)

        # Change base if we only have tables
        if not any(
            isinstance(model, ModelClass) for model in models_by_table_name.values()
        ):
            super().generate_base()

        # Collect the imports
        self.collect_imports(models_by_table_name.values())

        # Rename models and their attributes that conflict with imports or other
        # attributes
        global_names = {
            name for namespace in self.imports.values() for name in namespace
        }
        for model in models_by_table_name.values():
            self.generate_model_name(model, global_names)
            global_names.add(model.name)

        return list(models_by_table_name.values())

    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """
        Generate the relationships originating from the given model class.

        The relationships are appended to the ``relationships`` lists of the models
        they belong to (``source`` for the forward side, the target class for the
        reverse side).

        :param source: the model class whose foreign keys are inspected
        :param models_by_table_name: all models, keyed by qualified table name
        :param association_tables: association tables used for many-to-many
            relationships where ``source`` is on the first side
        :return: every relationship attribute created by this call (forward and
            reverse)
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Add many-to-one (and one-to-many) relationships
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            # A foreign key covering the entire primary key is treated as joined
            # table inheritance instead of a relationship
            if (
                "nojoined" not in self.options
                and set(get_column_names(constraint)) == pk_column_names
            ):
                source.parent_class = target
                if not any(child is source for child in target.children):
                    target.children.append(source)

                continue

            # Add uselist=False to One-to-One relationships
            column_names = get_column_names(constraint)
            if any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == set(column_names)
                for c in constraint.table.constraints
            ):
                r_type = RelationshipType.ONE_TO_ONE
            else:
                r_type = RelationshipType.MANY_TO_ONE

            relationship = RelationshipAttribute(r_type, source, target, constraint)
            source.relationships.append(relationship)
            relationships.append(relationship)

            # For self referential relationships, remote_side needs to be set
            if source is target:
                relationship.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # If the two tables share more than one foreign key constraint,
            # SQLAlchemy needs an explicit primaryjoin to figure out which column(s)
            # it needs
            common_fk_constraints = get_common_fk_constraints(
                source.table, target.table
            )
            if len(common_fk_constraints) > 1:
                relationship.foreign_keys = [
                    source.get_column_attribute(key) for key in constraint.column_keys
                ]

            # Generate the opposite end of the relationship in the target class
            if "nobidi" not in self.options:
                if r_type is RelationshipType.MANY_TO_ONE:
                    r_type = RelationshipType.ONE_TO_MANY

                reverse_relationship = RelationshipAttribute(
                    r_type,
                    target,
                    source,
                    constraint,
                    foreign_keys=relationship.foreign_keys,
                    backref=relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

                # For self referential relationships, remote_side needs to be set
                if source is target:
                    reverse_relationship.remote_side = [
                        source.get_column_attribute(colname)
                        for colname in constraint.column_keys
                    ]

        # Add many-to-many relationships
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            relationship = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            source.relationships.append(relationship)
            relationships.append(relationship)

            # Generate the opposite end of the relationship in the target class
            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    relationship,
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)
                relationships.append(reverse_relationship)

            # Add a primary/secondary join for self-referential many-to-many
            # relationships
            if source is target:
                both_relationships = [relationship]
                reverse_flags = [False, True]
                if reverse_relationship:
                    both_relationships.append(reverse_relationship)

                # The reverse side uses the same two constraints in opposite order
                for m2m_relationship, reverse in zip(both_relationships, reverse_flags):
                    if (
                        not m2m_relationship.association_table
                        or not m2m_relationship.constraint
                    ):
                        continue

                    constraints = sorted(
                        m2m_relationship.constraint.table.foreign_key_constraints,
                        key=get_constraint_sort_key,
                        reverse=reverse,
                    )
                    pri_pairs = zip(
                        get_column_names(constraints[0]), constraints[0].elements
                    )
                    sec_pairs = zip(
                        get_column_names(constraints[1]), constraints[1].elements
                    )
                    m2m_relationship.primaryjoin = [
                        (
                            m2m_relationship.source,
                            elem.column.name,
                            m2m_relationship.association_table,
                            col,
                        )
                        for col, elem in pri_pairs
                    ]
                    m2m_relationship.secondaryjoin = [
                        (
                            m2m_relationship.target,
                            elem.column.name,
                            m2m_relationship.association_table,
                            col,
                        )
                        for col, elem in sec_pairs
                    ]

        return relationships

    def generate_model_name(self, model: Model, global_names: set[str]) -> None:
        """
        Fill in the name of the model and, for classes, of all of its attributes.

        Class names are the table name with invalid identifier characters replaced,
        converted to CamelCase, and singularized if ``use_inflect`` is set. Other
        models are named by the superclass.

        :param model: the model to name
        :param global_names: module level names that are already taken
        """
        if not isinstance(model, ModelClass):
            super().generate_model_name(model, global_names)
            return

        preferred_name = _re_invalid_identifier.sub("_", model.table.name)
        preferred_name = "".join(
            part[:1].upper() + part[1:] for part in preferred_name.split("_")
        )
        if "use_inflect" in self.options:
            # singular_noun() gives a falsy value when it has nothing to offer
            singular_name = self.inflect_engine.singular_noun(preferred_name)
            if singular_name:
                preferred_name = singular_name

        model.name = self.find_free_name(preferred_name, global_names)

        # Fill in the names for column attributes
        local_names: set[str] = set()
        for column_attr in model.columns:
            self.generate_column_attr_name(column_attr, global_names, local_names)
            local_names.add(column_attr.name)

        # Fill in the names for relationship attributes
        for relationship in model.relationships:
            self.generate_relationship_name(relationship, global_names, local_names)
            local_names.add(relationship.name)

    def generate_column_attr_name(
        self,
        column_attr: ColumnAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Name the column attribute after its column, avoiding taken names."""
        column_attr.name = self.find_free_name(
            column_attr.column.name, global_names, local_names
        )

    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """
        Fill in the name of the given relationship attribute.

        :param relationship: the relationship to name
        :param global_names: module level names that are already taken
        :param local_names: names already taken within the class
        """
        # Self referential reverse relationships
        preferred_name: str
        if (
            relationship.type
            in (RelationshipType.ONE_TO_MANY, RelationshipType.ONE_TO_ONE)
            and relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        ):
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            # If there's a constraint with a single column that ends with "_id", use the
            # preceding part as the relationship name
            if relationship.constraint:
                is_source = relationship.source.table is relationship.constraint.table
                if is_source or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    column_names = [c.name for c in relationship.constraint.columns]
                    # A column named exactly "_id" has no preceding part; stick with
                    # the table name then rather than asking for an empty name
                    if (
                        len(column_names) == 1
                        and column_names[0].endswith("_id")
                        and len(column_names[0]) > len("_id")
                    ):
                        preferred_name = column_names[0][:-3]

            if "use_inflect" in self.options:
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    inflected_name = self.inflect_engine.plural_noun(preferred_name)
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)

                if inflected_name:
                    preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )

    def render_models(self, models: list[Model]) -> str:
        """Render the models as classes or table assignments, three newlines apart."""
        rendered: list[str] = [
            self.render_class(model)
            if isinstance(model, ModelClass)
            else f"{model.name} = {self.render_table(model.table)}"
            for model in models
        ]
        return "\n\n\n".join(rendered)

    def render_class(self, model: ModelClass) -> str:
        """
        Render the given model as a class definition.

        The body consists of up to three sections separated by blank lines: class
        variables, column attributes (non-nullable columns first) and relationships.
        """
        sections: list[str] = []

        # Render class variables / special declarations
        class_vars: str = self.render_class_variables(model)
        if class_vars:
            sections.append(class_vars)

        # Render column attributes
        rendered_column_attributes: list[str] = [
            self.render_column_attribute(column_attr)
            for nullable in (False, True)
            for column_attr in model.columns
            if column_attr.column.nullable is nullable
        ]
        if rendered_column_attributes:
            sections.append("\n".join(rendered_column_attributes))

        # Render relationship attributes
        rendered_relationship_attributes: list[str] = [
            self.render_relationship(relationship)
            for relationship in model.relationships
        ]
        if rendered_relationship_attributes:
            sections.append("\n".join(rendered_relationship_attributes))

        declaration = self.render_class_declaration(model)
        rendered_sections = "\n\n".join(
            indent(section, self.indentation) for section in sections
        )
        return f"{declaration}\n{rendered_sections}"

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class`` line, deriving from the parent class or the base."""
        parent_class_name = (
            model.parent_class.name if model.parent_class else self.base_class_name
        )
        return f"class {model.name}({parent_class_name}):"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` and, when there are any, ``__table_args__``."""
        variables = [f"__tablename__ = {model.table.name!r}"]

        # Render constraints and indexes as __table_args__
        table_args = self.render_table_args(model.table)
        if table_args:
            variables.append(f"__table_args__ = {table_args}")

        return "\n".join(variables)

    def render_table_args(self, table: Table) -> str:
        """
        Render the value of ``__table_args__`` for the given table.

        :return: a dict literal when there are only keyword arguments (schema,
            comment), a tuple literal when there are constraints or indexes to
            render, or an empty string when there is nothing to render
        """
        args: list[str] = []
        kwargs: dict[str, str] = {}

        # Render constraints
        for constraint in sorted(table.constraints, key=get_constraint_sort_key):
            if uses_default_name(constraint):
                # Skipped: default-named primary keys and default-named single
                # column foreign key / unique constraints
                if isinstance(constraint, PrimaryKeyConstraint):
                    continue
                if (
                    isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint))
                    and len(constraint.columns) == 1
                ):
                    continue

            args.append(self.render_constraint(constraint))

        # Render indexes (an unnamed index sorts first instead of failing the
        # comparison between None and str)
        for index in sorted(table.indexes, key=lambda i: i.name or ""):
            if len(index.columns) > 1 or not uses_default_name(index):
                args.append(self.render_index(index))

        if table.schema:
            kwargs["schema"] = table.schema

        if table.comment:
            kwargs["comment"] = table.comment

        if kwargs:
            formatted_kwargs = pformat(kwargs)
            if not args:
                return formatted_kwargs

            args.append(formatted_kwargs)

        if not args:
            return ""

        rendered_args = f",\n{self.indentation}".join(args)
        if len(args) == 1:
            # A trailing comma is needed to make this a one element tuple
            rendered_args += ","

        return f"(\n{self.indentation}{rendered_args}\n)"

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """
        Render a column attribute as a class level assignment.

        On SQLAlchemy 2.0+ the attribute is annotated with ``Mapped[...]`` using the
        Python type of the column type (``Any`` if the column type does not provide
        one), wrapped in ``Optional`` for nullable columns.
        """
        column = column_attr.column
        rendered_column = self.render_column(column, column_attr.name != column.name)

        if _sqla_version < (2, 0):
            return f"{column_attr.name} = {rendered_column}"

        try:
            python_type = column.type.python_type
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
            column_python_type = "Any"
        else:
            python_type_name = python_type.__name__
            if python_type.__module__ == "builtins":
                column_python_type = python_type_name
            else:
                python_type_module = python_type.__module__
                column_python_type = f"{python_type_module}.{python_type_name}"
                self.add_module_import(python_type_module)

        if column.nullable:
            self.add_literal_import("typing", "Optional")
            column_python_type = f"Optional[{column_python_type}]"

        return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """
        Render a relationship attribute as a class level assignment.

        On SQLAlchemy 2.0+ the attribute is annotated with ``Mapped[...]``, using
        ``List`` for the "to-many" relationship types.
        """

        def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
            # Attributes of the source class are referred to by their bare name,
            # others by a quoted "Class.attribute" string
            rendered = [
                attr.name
                if attr.model is relationship.source
                else repr(f"{attr.model.name}.{attr.name}")
                for attr in column_attrs
            ]
            return "[" + ", ".join(rendered) + "]"

        def render_foreign_keys(column_attrs: list[ColumnAttribute]) -> str:
            rendered = []
            render_as_string = False
            # Assume that column_attrs are all in relationship.source or none
            for attr in column_attrs:
                if attr.model is relationship.source:
                    rendered.append(attr.name)
                else:
                    rendered.append(f"{attr.model.name}.{attr.name}")
                    render_as_string = True

            # Attributes of another class make the entire list a quoted string
            joined = "[" + ", ".join(rendered) + "]"
            return f"'{joined}'" if render_as_string else joined

        def render_join(terms: list[JoinType]) -> str:
            conditions = []
            for source, source_col, target, target_col in terms:
                # Columns of plain tables are reached through the "c" collection
                column_prefix = "c." if target.__class__ is Model else ""
                conditions.append(
                    f"{source.name}.{source_col} == "
                    f"{target.name}.{column_prefix}{target_col!s}"
                )

            if len(conditions) > 1:
                # Combine the conditions inside a single lambda so that nothing is
                # evaluated before the relationship is configured, and make sure
                # and_ is importable in the generated module
                self.add_literal_import("sqlalchemy", "and_")
                return f"lambda: and_({', '.join(conditions)})"

            return f"lambda: {conditions[0]}"

        # Render keyword arguments
        kwargs: dict[str, Any] = {}
        if (
            relationship.type is RelationshipType.ONE_TO_ONE
            and relationship.constraint
            and relationship.constraint.referred_table is relationship.source.table
        ):
            kwargs["uselist"] = False

        # Add the "secondary" keyword for many-to-many relationships
        if relationship.association_table:
            table_ref = relationship.association_table.table.name
            if relationship.association_table.schema:
                table_ref = f"{relationship.association_table.schema}.{table_ref}"

            kwargs["secondary"] = repr(table_ref)

        if relationship.remote_side:
            kwargs["remote_side"] = render_column_attrs(relationship.remote_side)

        if relationship.foreign_keys:
            kwargs["foreign_keys"] = render_foreign_keys(relationship.foreign_keys)

        if relationship.primaryjoin:
            kwargs["primaryjoin"] = render_join(relationship.primaryjoin)

        if relationship.secondaryjoin:
            kwargs["secondaryjoin"] = render_join(relationship.secondaryjoin)

        if relationship.backref:
            kwargs["back_populates"] = repr(relationship.backref.name)

        rendered_relationship = render_callable(
            "relationship", repr(relationship.target.name), kwargs=kwargs
        )

        if _sqla_version < (2, 0):
            return f"{relationship.name} = {rendered_relationship}"

        relationship_type: str
        if relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        ):
            self.add_literal_import("typing", "List")
            relationship_type = f"List['{relationship.target.name}']"
        elif relationship.type in (
            RelationshipType.ONE_TO_ONE,
            RelationshipType.MANY_TO_ONE,
        ):
            relationship_type = f"'{relationship.target.name}'"
        else:
            self.add_literal_import("typing", "Any")
            relationship_type = "Any"

        return (
            f"{relationship.name}: Mapped[{relationship_type}] "
            f"= {rendered_relationship}"
        )
1343

1344

1345
class DataclassGenerator(DeclarativeGenerator):
    """Declarative generator variant that renders models as dataclasses.

    On SQLAlchemy 2.0+ the generated base class combines ``MappedAsDataclass``
    with ``DeclarativeBase`` and all rendering is delegated to the parent
    generator.  On older versions each model is rendered as a ``@dataclass``
    decorated with ``@mapper_registry.mapped``, and the column/relationship
    constructs are placed in the ``field()`` metadata under ``metadata_key``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
        quote_annotations: bool = False,
        metadata_key: str = "sa",
    ):
        """Initialize the generator.

        :param metadata: forwarded to the parent generator
        :param bind: forwarded to the parent generator
        :param options: forwarded to the parent generator
        :param indentation: forwarded to the parent generator
        :param base_class_name: forwarded to the parent generator
        :param quote_annotations: if ``True``, relationship annotations are
            rendered as quoted strings and ``from __future__ import
            annotations`` is not added (SQLAlchemy < 2.0 code path only)
        :param metadata_key: key used in the dataclass field metadata and as
            the value of ``__sa_dataclass_metadata_key__`` (SQLAlchemy < 2.0
            code path only)
        """
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )
        self.metadata_key: str = metadata_key
        self.quote_annotations: bool = quote_annotations

    def generate_base(self) -> None:
        """Set ``self.base`` to the base definition matching the SQLAlchemy version."""
        if _sqla_version < (2, 0):
            self.base = Base(
                literal_imports=[LiteralImport("sqlalchemy.orm", "registry")],
                declarations=["mapper_registry = registry()"],
                metadata_ref="metadata",
                decorator="@mapper_registry.mapped",
            )
        else:
            self.base = Base(
                literal_imports=[
                    LiteralImport("sqlalchemy.orm", "DeclarativeBase"),
                    LiteralImport("sqlalchemy.orm", "MappedAsDataclass"),
                ],
                declarations=[
                    (
                        f"class {self.base_class_name}(MappedAsDataclass, "
                        "DeclarativeBase):"
                    ),
                    f"{self.indentation}pass",
                ],
                metadata_ref=f"{self.base_class_name}.metadata",
            )

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect module-level imports; adds dataclass imports on SQLAlchemy < 2.0."""
        super().collect_imports(models)
        if _sqla_version < (2, 0):
            if not self.quote_annotations:
                self.add_literal_import("__future__", "annotations")

            if any(isinstance(model, ModelClass) for model in models):
                # The registry-based mapping replaces declarative_base()
                self.remove_literal_import("sqlalchemy.orm", "declarative_base")
                self.add_literal_import("dataclasses", "dataclass")
                self.add_literal_import("dataclasses", "field")
                self.add_literal_import("sqlalchemy.orm", "registry")

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect the ``typing`` imports needed by a model's annotations."""
        super().collect_imports_for_model(model)
        if _sqla_version < (2, 0):
            if isinstance(model, ModelClass):
                # A single nullable column is enough to require Optional
                for column_attr in model.columns:
                    if column_attr.column.nullable:
                        self.add_literal_import("typing", "Optional")
                        break

                for relationship_attr in model.relationships:
                    if relationship_attr.type in (
                        RelationshipType.ONE_TO_MANY,
                        RelationshipType.MANY_TO_MANY,
                    ):
                        self.add_literal_import("typing", "List")

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Collect the import for the Python type used to annotate a column.

        Only applies to SQLAlchemy < 2.0, where the annotation is the column
        type's ``python_type``.
        """
        super().collect_imports_for_column(column)
        if _sqla_version < (2, 0):
            try:
                python_type = column.type.python_type
            except NotImplementedError:
                # render_column_attribute() annotates such columns with "Any",
                # so that name has to be importable in the generated module
                self.add_literal_import("typing", "Any")
            else:
                self.add_import(python_type)

    def render_module_variables(self, models: list[Model]) -> str:
        """Render the module-level declarations preceding the models.

        :param models: all models that will be rendered in the module
        :return: the declarations joined by newlines
        """
        if _sqla_version >= (2, 0):
            return super().render_module_variables(models)
        else:
            if not any(isinstance(model, ModelClass) for model in models):
                return super().render_module_variables(models)

            declarations: list[str] = ["mapper_registry = registry()"]
            # Models that are not classes need a module-level "metadata" name
            if any(not isinstance(model, ModelClass) for model in models):
                declarations.append("metadata = mapper_registry.metadata")

            if not self.quote_annotations:
                self.add_literal_import("__future__", "annotations")

            return "\n".join(declarations)

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class`` line (plus decorators on SQLAlchemy < 2.0)."""
        if _sqla_version >= (2, 0):
            return super().render_class_declaration(model)
        else:
            superclass_part = (
                f"({model.parent_class.name})" if model.parent_class else ""
            )
            return (
                f"@mapper_registry.mapped\n@dataclass"
                f"\nclass {model.name}{superclass_part}:"
            )

    def render_class_variables(self, model: ModelClass) -> str:
        """Render class-level variables, adding the dataclass metadata key on < 2.0."""
        if _sqla_version >= (2, 0):
            return super().render_class_variables(model)
        else:
            variables = [
                super().render_class_variables(model),
                f"__sa_dataclass_metadata_key__ = {self.metadata_key!r}",
            ]
            return "\n".join(variables)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render a column as an annotated dataclass ``field()`` on SQLAlchemy < 2.0.

        :param column_attr: the column attribute to render
        :return: a line of the form ``name: type = field(...)``
        """
        if _sqla_version >= (2, 0):
            return super().render_column_attribute(column_attr)
        else:
            column = column_attr.column
            try:
                python_type = column.type.python_type
            except NotImplementedError:
                python_type_name = "Any"
            else:
                python_type_name = python_type.__name__

            kwargs: dict[str, Any] = {}
            if column.autoincrement and column.name in column.table.primary_key:
                # Autoincrementing primary keys are excluded from __init__()
                kwargs["init"] = False
            elif column.nullable:
                self.add_literal_import("typing", "Optional")
                kwargs["default"] = None
                python_type_name = f"Optional[{python_type_name}]"

            rendered_column = self.render_column(
                column, column_attr.name != column.name
            )
            kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered_column}}}"
            rendered_field = render_callable("field", kwargs=kwargs)
            return f"{column_attr.name}: {python_type_name} = {rendered_field}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render a relationship as an annotated dataclass ``field()`` on < 2.0.

        :param relationship: the relationship attribute to render
        :return: a line of the form ``name: annotation = field(...)``
        """
        if _sqla_version >= (2, 0):
            return super().render_relationship(relationship)
        else:
            # Keep only the "relationship(...)" call from the parent's output
            rendered = super().render_relationship(relationship).partition(" = ")[2]
            kwargs: dict[str, Any] = {}

            annotation = relationship.target.name
            if self.quote_annotations:
                annotation = repr(relationship.target.name)

            if relationship.type in (
                RelationshipType.ONE_TO_MANY,
                RelationshipType.MANY_TO_MANY,
            ):
                self.add_literal_import("typing", "List")
                annotation = f"List[{annotation}]"
                kwargs["default_factory"] = "list"
            else:
                self.add_literal_import("typing", "Optional")
                kwargs["default"] = "None"
                annotation = f"Optional[{annotation}]"

            kwargs["metadata"] = f"{{{self.metadata_key!r}: {rendered}}}"
            rendered_field = render_callable("field", kwargs=kwargs)
            return f"{relationship.name}: {annotation} = {rendered_field}"
1520

1521

1522
class SQLModelGenerator(DeclarativeGenerator):
    """Generator that renders model classes as SQLModel table models.

    Columns are rendered as ``Field(sa_column=...)`` and relationships as
    ``Relationship(...)``; the default base class name is ``SQLModel``.
    """

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "SQLModel",
    ):
        """Forward all arguments unchanged to the parent generator."""
        super().__init__(
            metadata,
            bind,
            options,
            indentation=indentation,
            base_class_name=base_class_name,
        )

    def generate_base(self) -> None:
        """Install an empty base definition (no imports, declarations or metadata)."""
        self.base = Base(literal_imports=[], declarations=[], metadata_ref="")

    def collect_imports(self, models: Iterable[Model]) -> None:
        """Collect module-level imports, bypassing ``DeclarativeGenerator``'s override."""
        super(DeclarativeGenerator, self).collect_imports(models)
        if not any(isinstance(model, ModelClass) for model in models):
            return

        self.remove_literal_import("sqlalchemy", "MetaData")
        for imported_name in ("SQLModel", "Field"):
            self.add_literal_import("sqlmodel", imported_name)

    def collect_imports_for_model(self, model: Model) -> None:
        """Collect the ``typing``/``sqlmodel`` imports a single model requires."""
        super(DeclarativeGenerator, self).collect_imports_for_model(model)
        if not isinstance(model, ModelClass):
            return

        # One nullable column is enough to need Optional
        if any(attr.column.nullable for attr in model.columns):
            self.add_literal_import("typing", "Optional")

        if model.relationships:
            self.add_literal_import("sqlmodel", "Relationship")

        collection_types = (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        )
        for rel_attr in model.relationships:
            if rel_attr.type in collection_types:
                self.add_literal_import("typing", "List")

    def collect_imports_for_column(self, column: Column[Any]) -> None:
        """Collect the import for a column's annotation type (``Any`` as fallback)."""
        super().collect_imports_for_column(column)
        try:
            annotation_type = column.type.python_type
        except NotImplementedError:
            self.add_literal_import("typing", "Any")
            return

        self.add_import(annotation_type)

    def render_module_variables(self, models: list[Model]) -> str:
        """Render module-level declarations.

        The table metadata declaration of the base is emitted only when at
        least one model is not a class and the declaration is set.
        """
        lines: list[str] = []
        has_non_class_model = any(
            not isinstance(model, ModelClass) for model in models
        )
        if has_non_class_model and self.base.table_metadata_declaration is not None:
            lines.append(self.base.table_metadata_declaration)

        return "\n".join(lines)

    def render_class_declaration(self, model: ModelClass) -> str:
        """Render the ``class`` line; every model is declared with ``table=True``."""
        parent = (
            model.parent_class.name if model.parent_class else self.base_class_name
        )
        return f"class {model.name}({parent}, table=True):"

    def render_class_variables(self, model: ModelClass) -> str:
        """Render ``__tablename__`` (when needed) and ``__table_args__``."""
        lines = []

        # The table name is spelled out only if it differs from the lowercased
        # class name
        if model.table.name != model.name.lower():
            lines.append(f"__tablename__ = {model.table.name!r}")

        # Constraints and indexes end up in __table_args__
        rendered_table_args = self.render_table_args(model.table)
        if rendered_table_args:
            lines.append(f"__table_args__ = {rendered_table_args}")

        return "\n".join(lines)

    def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
        """Render a column as ``name: type = Field(sa_column=...)``.

        :param column_attr: the column attribute to render
        :return: the rendered attribute line
        """
        column = column_attr.column
        try:
            type_name = column.type.python_type.__name__
        except NotImplementedError:
            type_name = "Any"

        field_kwargs: dict[str, Any] = {}
        is_autoincrement_pk = (
            column.autoincrement and column.name in column.table.primary_key
        )
        if is_autoincrement_pk or column.nullable:
            self.add_literal_import("typing", "Optional")
            field_kwargs["default"] = None
            type_name = f"Optional[{type_name}]"

        rendered_column = self.render_column(column, True)
        field_kwargs["sa_column"] = f"{rendered_column}"
        field_call = render_callable("Field", kwargs=field_kwargs)
        return f"{column_attr.name}: {type_name} = {field_call}"

    def render_relationship(self, relationship: RelationshipAttribute) -> str:
        """Render a relationship as ``name: annotation = Relationship(...)``.

        :param relationship: the relationship attribute to render
        :return: the rendered attribute line
        """
        # Keep only the "relationship(...)" call from the parent's output
        parent_call = super().render_relationship(relationship).partition(" = ")[2]
        positional = self.render_relationship_args(parent_call)
        quoted_target = repr(relationship.target.name)

        is_collection = relationship.type in (
            RelationshipType.ONE_TO_MANY,
            RelationshipType.MANY_TO_MANY,
        )
        wrapper = "List" if is_collection else "Optional"
        self.add_literal_import("typing", wrapper)
        annotation = f"{wrapper}[{quoted_target}]"

        relationship_call = render_callable("Relationship", *positional, kwargs={})
        return f"{relationship.name}: {annotation} = {relationship_call}"

    def render_relationship_args(self, arguments: str) -> list[str]:
        """Extract the arguments SQLModel's ``Relationship()`` should receive.

        :param arguments: a rendered ``relationship(...)`` call
        :return: the ``back_populates`` argument as rendered, and/or a
            ``sa_relationship_kwargs`` entry replacing ``uselist=False``
        """
        pieces = arguments.split(",")
        # Drop the final character (the closing ')') of the last piece
        pieces[-1] = pieces[-1][:-1]

        selected: list[str] = []
        for piece in pieces:
            # Drop the first character (the space that follows each comma)
            candidate = piece[1:]
            if "back_populates" in candidate:
                selected.append(candidate)
            if "uselist=False" in candidate:
                selected.append("sa_relationship_kwargs={'uselist': False}")

        return selected
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc