• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

LeanderCS / sqlalchemy-fake-model / #14

18 Sep 2025 09:38PM UTC coverage: 84.108% (-0.3%) from 84.444%
#14

push

coveralls-python

LeanderCS
Lint

13 of 17 new or added lines in 3 files covered. (76.47%)

1 existing line in 1 file now uncovered.

344 of 409 relevant lines covered (84.11%)

0.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.65
/sqlalchemy_fake_model/ModelFaker.py
1
import json
1✔
2
import logging
1✔
3
import random
1✔
4
import traceback
1✔
5
from datetime import date, datetime
1✔
6
from typing import Any, Dict, List, Optional, Union
1✔
7

8
from faker import Faker
1✔
9
from sqlalchemy import Column, ColumnDefault, Table
1✔
10
from sqlalchemy.exc import IntegrityError
1✔
11
from sqlalchemy.orm import ColumnProperty, Session
1✔
12

13
from .Enum import ModelColumnTypesEnum
1✔
14
from .Error import InvalidAmountError, UniquenessError
1✔
15
from .Model import ModelFakerConfig
1✔
16
from .SmartFieldDetector import SmartFieldDetector
1✔
17

18

19
class ModelFaker:
1✔
20
    """
21
    The ModelFaker class is a utility class that helps in generating fake data
22
    for a given SQLAlchemy model. It uses the faker library to generate fake
23
    data based on the column types of the model. It also handles relationships
24
    between models and can generate data for different relationships.
25
    """
26

27
    def __init__(
1✔
28
        self,
29
        model: Union[Table, ColumnProperty],
30
        db: Optional[Session] = None,
31
        faker: Optional[Faker] = None,
32
        config: Optional[ModelFakerConfig] = None,
33
    ) -> None:
34
        """
35
        Initializes the ModelFaker class with the given model,
36
        database session, faker instance, and configuration.
37

38
        :param model: The SQLAlchemy model for which fake data
39
            needs to be generated.
40
        :param db: Optional SQLAlchemy session to be used for
41
            creating fake data.
42
        :param faker: Optional Faker instance to be used for
43
            generating fake data.
44
        :param config: Optional ModelFakerConfig instance to be
45
            used for configuring the ModelFaker.
46
        """
47
        self.model = model
1✔
48
        self.db = db or self._get_framework_session()
1✔
49
        self.config = config or ModelFakerConfig()
1✔
50
        self.faker = (
1✔
51
            faker or self.config.faker_instance or Faker(self.config.locale)
52
        )
53
        self.logger = logging.getLogger(__name__)
1✔
54
        self._unique_values = {}
1✔
55
        self.smart_detector = (
1✔
56
            SmartFieldDetector(self.faker)
57
            if self.config.smart_detection
58
            else None
59
        )
60

61
        if self.config.seed is not None:
1✔
62
            self.faker.seed_instance(self.config.seed)
1✔
63

64
    def __enter__(self):
1✔
65
        """Context manager entry."""
66
        return self
1✔
67

68
    def __exit__(self, exc_type, exc_val, exc_tb):
1✔
69
        """Context manager exit with automatic cleanup."""
70
        if exc_type is not None:
1✔
71
            self.logger.error(f"Exception in ModelFaker context: {exc_val}")
1✔
72
            if hasattr(self.db, "rollback"):
1✔
73
                try:
1✔
74
                    self.db.rollback()
1✔
75
                    self.logger.info("Database transaction rolled back")
1✔
76
                except Exception as rollback_error:
×
77
                    self.logger.error(f"Failed to rollback: {rollback_error}")
×
78
        return False
1✔
79

80
    @staticmethod
1✔
81
    def _get_framework_session() -> Optional[Session]:
1✔
82
        """
83
        Tries to get the SQLAlchemy session from available frameworks.
84

85
        :return: The SQLAlchemy session if available.
86
        :raises RuntimeError: If no supported framework
87
            is installed or configured
88
        """
89
        try:
1✔
90
            from flask import current_app
1✔
91

92
            if "sqlalchemy" in current_app.extensions:
1✔
93
                db_ext = current_app.extensions["sqlalchemy"]
1✔
94

95
                # In Flask-SQLAlchemy >= 2.0, the db object is the extension
96
                # itself
97
                if hasattr(db_ext, "session"):
1✔
98
                    return db_ext.session
1✔
99

100
                # Some versions might have a different structure
NEW
101
                if hasattr(db_ext, "db") and hasattr(db_ext.db, "session"):
×
NEW
102
                    return db_ext.db.session
×
103

NEW
104
        except (ImportError, KeyError, AttributeError):
×
UNCOV
105
            pass
×
106

107
        try:
×
108
            from tornado.web import Application
×
109

110
            return Application().settings["db"]
×
111
        except (ImportError, KeyError):
×
112
            pass
×
113

114
        try:
×
115
            from django.conf import settings
×
116
            from sqlalchemy import create_engine
×
117
            from sqlalchemy.orm import sessionmaker
×
118

119
            engine = create_engine(settings.DATABASES["default"]["ENGINE"])
×
120
            return sessionmaker(bind=engine)()
×
121
        except (ImportError, KeyError, AttributeError):
×
122
            pass
×
123

124
        raise RuntimeError(
×
125
            "No SQLAlchemy session provided and no supported framework "
126
            "installed or configured."
127
        )
128

129
    def create(self, amount: Optional[int] = 1) -> None:
1✔
130
        """
131
        Creates the specified amount of fake data entries for the model.
132
        It handles exceptions and rolls back the session
133
        in case of any errors.
134

135
        :param amount: The number of fake data entries to create.
136
        :raises InvalidAmountError: If the amount is not an integer or
137
            negative.
138
        """
139
        if not isinstance(amount, int) or amount < 0:
1✔
140
            raise InvalidAmountError(amount)
1✔
141

142
        if amount <= self.config.bulk_size:
1✔
143
            self._create_single_batch(amount)
1✔
144
        else:
145
            self._create_bulk(amount)
1✔
146

147
    def _create_single_batch(self, amount: int) -> None:
1✔
148
        """Creates a single batch of records."""
149
        try:
1✔
150
            batch_data = []
1✔
151

152
            for _ in range(amount):
1✔
153
                data = {}
1✔
154
                for column in self.__get_table_columns():
1✔
155
                    if self.__should_skip_field(column):
1✔
156
                        continue
1✔
157
                    data[column.name] = (
1✔
158
                        self._generate_fake_data_with_overrides(column)
159
                    )
160
                batch_data.append(data)
1✔
161

162
            if self.__is_many_to_many_relation_table():
1✔
163
                self.db.execute(self.model.insert().values(batch_data))
1✔
164
            else:
165
                for data in batch_data:
1✔
166
                    self.db.add(self.model(**data))
1✔
167

168
            self.db.commit()
1✔
169
            self.logger.info(f"Successfully created {amount} records")
1✔
170

171
        except IntegrityError as e:
1✔
172
            self.db.rollback()
×
173
            self.logger.error(f"Integrity error in batch creation: {e}")
×
174
            if "unique" in str(e).lower() or "duplicate" in str(e).lower():
×
175
                raise UniquenessError("unknown_field", self.config.max_retries)
×
176
            raise
×
177
        except Exception as e:
1✔
178
            self.db.rollback()
1✔
179
            self.logger.error(f"Failed to create batch: {e}")
1✔
180
            raise RuntimeError(
1✔
181
                f"Failed to commit: {e} {traceback.format_exc()}"
182
            )
183

184
    def _create_bulk(self, amount: int) -> None:
1✔
185
        """Creates records in multiple batches for better performance."""
186
        remaining = amount
1✔
187
        created = 0
1✔
188

189
        while remaining > 0:
1✔
190
            batch_size = min(remaining, self.config.bulk_size)
1✔
191
            try:
1✔
192
                self._create_single_batch(batch_size)
1✔
193
                created += batch_size
1✔
194
                remaining -= batch_size
1✔
195
                self.logger.info(f"Created {created}/{amount} records")
1✔
196
            except Exception as e:
×
197
                self.logger.error(
×
198
                    f"Failed to create bulk batch at {created}/{amount}: {e}"
199
                )
200
                raise
×
201

202
    def _generate_fake_data(
1✔
203
        self, column: Column
204
    ) -> Optional[Union[str, int, bool, date, datetime, None]]:
205
        """
206
        Generates fake data for a given column based on its type.
207
        It handles Enum, String, Integer, Boolean, DateTime, and Date column
208
        types.
209

210
        :param column: The SQLAlchemy column for which fake data
211
            needs to be generated.
212
        :return: The fake data generated for the column.
213
        """
214
        column_type = column.type
1✔
215

216
        if column.doc:
1✔
217
            return str(self._generate_json_data(column.doc))
1✔
218

219
        # Enum has to be the first type to check, or otherwise it
220
        # uses the options of the corresponding type of the enum options
221
        if isinstance(column_type, ModelColumnTypesEnum.ENUM.value):
1✔
222
            return random.choice(column_type.enums)
×
223

224
        if column.foreign_keys:
1✔
225
            related_attribute = next(iter(column.foreign_keys)).column.name
1✔
226
            return getattr(
1✔
227
                self.__handle_relationship(column), related_attribute
228
            )
229

230
        if column.primary_key:
1✔
231
            return self._generate_primitive(column_type)
×
232

233
        if isinstance(column_type, ModelColumnTypesEnum.STRING.value):
1✔
234
            max_length = (
1✔
235
                column_type.length
236
                if hasattr(column_type, "length")
237
                and column_type.length is not None
238
                else 255
239
            )
240
            return self.faker.text(max_nb_chars=max_length)
1✔
241

242
        if isinstance(column_type, ModelColumnTypesEnum.INTEGER.value):
1✔
243
            info = column.info
1✔
244
            if not info:
1✔
245
                return self.faker.random_int()
1✔
246

247
            min_value = column.info.get("min", 1)
1✔
248
            max_value = column.info.get("max", 100)
1✔
249
            return self.faker.random_int(min=min_value, max=max_value)
1✔
250

251
        if isinstance(column_type, ModelColumnTypesEnum.FLOAT.value):
1✔
252
            precision = column_type.precision
1✔
253
            if not precision:
1✔
254
                return self.faker.pyfloat()
1✔
255

256
            max_value = 10 ** (precision[0] - precision[1]) - 1
1✔
257
            return round(
1✔
258
                self.faker.pyfloat(min_value=0, max_value=max_value),
259
                precision[1],
260
            )
261

262
        if isinstance(column_type, ModelColumnTypesEnum.BOOLEAN.value):
1✔
263
            return self.faker.boolean()
1✔
264

265
        if isinstance(column_type, ModelColumnTypesEnum.DATE.value):
1✔
266
            return self.faker.date_object()
1✔
267

268
        if isinstance(column_type, ModelColumnTypesEnum.DATETIME.value):
1✔
269
            return self.faker.date_time()
1✔
270

271
        if isinstance(column_type, ModelColumnTypesEnum.TIME.value):
1✔
272
            return self.faker.time_object()
1✔
273

274
        if isinstance(column_type, ModelColumnTypesEnum.UUID.value):
1✔
275
            return self.faker.uuid4()
×
276

277
        if isinstance(column_type, ModelColumnTypesEnum.DECIMAL.value):
1✔
278
            precision = getattr(column_type, "precision", None)
1✔
279
            scale = getattr(column_type, "scale", None)
1✔
280
            if precision and scale:
1✔
281
                max_digits = precision - scale
1✔
282
                max_value = 10**max_digits - 1
1✔
283
                return round(
1✔
284
                    self.faker.pyfloat(min_value=0, max_value=max_value), scale
285
                )
286
            return self.faker.pydecimal(
×
287
                left_digits=10, right_digits=2, positive=True
288
            )
289

290
        if isinstance(column_type, ModelColumnTypesEnum.INTERVAL.value):
1✔
291
            days = self.faker.random_int(min=1, max=365)
×
292
            return f"{days} days"
×
293

294
        if isinstance(column_type, ModelColumnTypesEnum.LARGEBINARY.value):
1✔
295
            return self.faker.binary(length=256)
1✔
296

297
        if isinstance(
×
298
            column_type,
299
            (
300
                ModelColumnTypesEnum.JSON.value,
301
                ModelColumnTypesEnum.JSONB.value,
302
            ),
303
        ):
304
            json_structure = {
×
305
                "id": "integer",
306
                "name": "string",
307
                "active": "boolean",
308
            }
309
            return self._populate_json_structure(json_structure)
×
310

311
        return None
×
312

313
    def __handle_relationship(self, column: Column) -> Optional[Table]:
        """
        Creates a record of the model referenced by ``column`` and returns
        a row of that model, from which the caller reads the referenced
        key.

        :param column: Column carrying the foreign key.
        :return: The first row found for the related model, or None if the
            query returns nothing.
        """
        parent_model = self.__get_related_class(column)

        # Reuse this instance's faker so that a configured locale, seed or
        # custom Faker also governs the related record.
        ModelFaker(parent_model, self.db, faker=self.faker).create()

        # NOTE(review): first() is not guaranteed to be the row created
        # just above when the table already holds data - confirm whether
        # reusing an arbitrary existing parent is intended.
        return self.db.query(parent_model).first()
1✔
323

324
    def __is_many_to_many_relation_table(self) -> bool:
1✔
325
        """
326
        Checks if the model is a many-to-many relationship table.
327
        """
328
        return not hasattr(self.model, "__table__") and not hasattr(
1✔
329
            self.model, "__mapper__"
330
        )
331

332
    def __should_skip_field(self, column: Column) -> bool:
        """
        Decides whether no value should be generated for ``column``.

        A column is skipped when it is an auto-incrementing primary key,
        when it has a default that should not be overwritten, or when it is
        nullable and nullable fields are not being filled.
        """
        if column.primary_key and self.__is_field_auto_increment(column):
            return True

        if self.__has_field_default_value(column):
            return True

        return self.__is_field_nullable(column)
341

342
    @staticmethod
1✔
343
    def __is_field_auto_increment(column: Column) -> bool:
1✔
344
        """
345
        Checks if a column is autoincrement.
346
        """
347
        return column.autoincrement and isinstance(
1✔
348
            column.type, ModelColumnTypesEnum.INTEGER.value
349
        )
350

351
    def __has_field_default_value(self, column: Column) -> bool:
        """
        Tells whether ``column`` has a column default that must be left
        alone.

        That is the case when the default is a ``ColumnDefault`` with a
        non-None argument and ``fill_default_fields`` is not enabled.
        """
        default = column.default
        if not isinstance(default, ColumnDefault):
            return False

        if default.arg is None:
            return False

        return not self.config.fill_default_fields
360

361
    def __is_field_nullable(self, column: Column) -> bool:
1✔
362
        """
363
        Checks if a column is nullable.
364
        """
365
        return (
1✔
366
            column.nullable is not None
367
            and column.nullable is True
368
            and not self.config.fill_nullable_fields
369
        )
370

371
    def __get_table_columns(self) -> List[Column]:
1✔
372
        """
373
        Returns the columns of the model's table.
374
        """
375
        return (
1✔
376
            self.model.columns
377
            if self.__is_many_to_many_relation_table()
378
            else self.model.__table__.columns
379
        )
380

381
    def __get_related_class(self, column: Column) -> Table:
1✔
382
        """
383
        Returns the related class of a column if it has
384
        a relationship with another model.
385
        """
386
        if (
1✔
387
            not self.__is_many_to_many_relation_table()
388
            and column.name in self.model.__mapper__.relationships
389
        ):
390
            return self.model.__mapper__.relationships[
×
391
                column.key
392
            ].mapper.class_
393

394
        fk = next(iter(column.foreign_keys))
1✔
395

396
        return fk.column.table
1✔
397

398
    def _generate_json_data(self, docstring: str) -> Dict[str, Any]:
1✔
399
        """
400
        Generates JSON data based on the provided docstring.
401
        """
402
        json_structure = json.loads(docstring)
1✔
403

404
        return self._populate_json_structure(json_structure)
1✔
405

406
    def _populate_json_structure(
1✔
407
        self, structure: Union[Dict[str, Any], List[Any]]
408
    ) -> Any:
409
        """
410
        Populates the JSON structure with fake data based on the defined
411
        schema.
412
        """
413
        if isinstance(structure, dict):
1✔
414
            return {
1✔
415
                key: self._populate_json_structure(value)
416
                if isinstance(value, (dict, list))
417
                else self._generate_primitive(value)
418
                for key, value in structure.items()
419
            }
420

421
        if isinstance(structure, list):
1✔
422
            return [
1✔
423
                self._populate_json_structure(item)
424
                if isinstance(item, (dict, list))
425
                else self._generate_primitive(item)
426
                for item in structure
427
            ]
428

429
        return structure
×
430

431
    def _generate_fake_data_with_overrides(self, column: Column) -> Any:
1✔
432
        """
433
        Generates fake data with custom overrides and optional smart detection.
434
        """
435
        if column.name in self.config.field_overrides:
1✔
436
            return self.config.field_overrides[column.name]()
1✔
437

438
        if self.smart_detector:
1✔
439
            smart_value = self.smart_detector.detect_and_generate(column)
1✔
440
            if smart_value is not None:
1✔
441
                return smart_value
1✔
442

443
        return self._generate_fake_data(column)
1✔
444

445
    def _generate_primitive(self, primitive_type: str) -> Any:
1✔
446
        """
447
        Generates fake data for primitive types.
448
        """
449
        if primitive_type == "boolean":
1✔
450
            return self.faker.boolean()
×
451
        if primitive_type == "datetime":
1✔
452
            return self.faker.date_time().isoformat()
1✔
453
        if primitive_type == "date":
1✔
454
            return self.faker.date()
1✔
455
        if primitive_type == "integer":
1✔
456
            return self.faker.random_int()
1✔
457
        if primitive_type == "string":
1✔
458
            return self.faker.word()
1✔
459
        if primitive_type == "float":
1✔
460
            return self.faker.pyfloat()
1✔
461
        return self.faker.word()
×
462

463
    def create_batch(self, amount: int, commit: bool = False) -> List[Any]:
1✔
464
        """
465
        Creates a batch of model instances without committing to database.
466

467
        :param amount: Number of instances to create
468
        :param commit: Whether to commit the batch to database
469
        :return: List of created model instances
470
        """
471
        if not isinstance(amount, int):
1✔
472
            raise InvalidAmountError(amount)
×
473

474
        instances = []
1✔
475
        try:
1✔
476
            for _ in range(amount):
1✔
477
                data = {}
1✔
478
                for column in self.__get_table_columns():
1✔
479
                    if self.__should_skip_field(column):
1✔
480
                        continue
1✔
481
                    data[column.name] = (
1✔
482
                        self._generate_fake_data_with_overrides(column)
483
                    )
484

485
                if not self.__is_many_to_many_relation_table():
1✔
486
                    instance = self.model(**data)
1✔
487
                    instances.append(instance)
1✔
488
                    if commit:
1✔
489
                        self.db.add(instance)
1✔
490

491
            if commit and instances:
1✔
492
                self.db.commit()
1✔
493
                self.logger.info(
1✔
494
                    f"Committed batch of {len(instances)} instances"
495
                )
496

497
            return instances
1✔
498

499
        except Exception as e:
×
500
            if commit:
×
501
                self.db.rollback()
×
502
            self.logger.error(f"Failed to create batch: {e}")
×
503
            raise
×
504

505
    def create_with(
1✔
506
        self, overrides: Dict[str, Any], amount: int = 1
507
    ) -> List[Any]:
508
        """
509
        Creates model instances with specific field overrides.
510

511
        :param overrides: Dictionary of field values to override
512
        :param amount: Number of instances to create
513
        :return: List of created model instances
514
        """
515
        if not isinstance(amount, int):
1✔
516
            raise InvalidAmountError(amount)
×
517

518
        instances = []
1✔
519
        try:
1✔
520
            for _ in range(amount):
1✔
521
                data = {}
1✔
522
                for column in self.__get_table_columns():
1✔
523
                    if self.__should_skip_field(column):
1✔
524
                        continue
1✔
525

526
                    if column.name in overrides:
1✔
527
                        data[column.name] = overrides[column.name]
1✔
528
                    else:
529
                        data[column.name] = (
1✔
530
                            self._generate_fake_data_with_overrides(column)
531
                        )
532

533
                if self.__is_many_to_many_relation_table():
1✔
534
                    self.db.execute(self.model.insert().values(**data))
×
535
                else:
536
                    instance = self.model(**data)
1✔
537
                    instances.append(instance)
1✔
538
                    self.db.add(instance)
1✔
539

540
            self.db.commit()
1✔
541
            self.logger.info(
1✔
542
                f"Created {len(instances)} instances with overrides"
543
            )
544
            return instances
1✔
545

546
        except Exception as e:
×
547
            self.db.rollback()
×
548
            self.logger.error(f"Failed to create with overrides: {e}")
×
549
            raise
×
550

551
    def reset(self, confirm: bool = False) -> int:
1✔
552
        """
553
        Removes all records from the model's table.
554

555
        :param confirm: Must be True to actually perform the deletion
556
        :return: Number of deleted records
557
        """
558
        if not confirm:
1✔
559
            raise ValueError("Must set confirm=True to delete all records")
1✔
560

561
        try:
1✔
562
            if self.__is_many_to_many_relation_table():
1✔
563
                result = self.db.execute(self.model.delete())
×
564
                deleted_count = result.rowcount
×
565
            else:
566
                deleted_count = self.db.query(self.model).count()
1✔
567
                self.db.query(self.model).delete()
1✔
568

569
            self.db.commit()
1✔
570
            self.logger.info(
1✔
571
                f"Deleted {deleted_count} records from {self.model}"
572
            )
573
            return deleted_count
1✔
574

575
        except Exception as e:
×
576
            self.db.rollback()
×
577
            self.logger.error(f"Failed to reset table: {e}")
×
578
            raise
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc