• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

smileservices / data_persistence_repository / 10417671105

16 Aug 2024 09:19AM UTC coverage: 94.335% (+0.7%) from 93.643%
10417671105

push

github

vladimirgorealionstep
fix requirements.txt path in .githubactions

383 of 406 relevant lines covered (94.33%)

0.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.69
/data_persistence_repository/sql_repository_async.py
1
from typing import List, Iterable, Optional
1✔
2
import contextlib
1✔
3
import asyncio
1✔
4

5
from sqlalchemy.ext.asyncio import (
1✔
6
    create_async_engine,
7
    AsyncSession,
8
    AsyncEngine,
9
    async_sessionmaker,
10
)
11
from sqlalchemy import MetaData, select, delete, update
1✔
12
from sqlalchemy.orm import registry
1✔
13
from sqlalchemy.exc import NoResultFound
1✔
14

15
from .repository_interface import Repository
1✔
16

17
import logging
1✔
18

19
# Module logger; AsyncSqlRepository.start_session logs failed transactions here.
# NOTE(review): uses a fixed name rather than __name__, presumably so it is shared
# with a synchronous SQL repository logger — confirm before renaming.
logger = logging.getLogger("sql_repository")
1✔
20

21

22
class AsyncSqlRepository(Repository):
    """Asynchronous SQLAlchemy-backed implementation of ``Repository``.

    Every data-access method takes an explicit ``AsyncSession``; obtain one
    with :meth:`start_session` (transactional context manager) or
    :meth:`get_session` (caller-managed session).
    """

    # Shared metadata for all tables managed through this repository class;
    # ``sync_schema`` creates the tables registered on it.
    metadata_obj = MetaData()
    # Bind the mapper registry to ``metadata_obj``: a bare ``registry()`` would
    # create its own, separate MetaData whose tables ``sync_schema`` never sees.
    registry = registry(metadata=metadata_obj)

    def __init__(
        self,
        url: Optional[str] = None,
        engine: Optional[AsyncEngine] = None,
        *,
        echo: bool = True,
    ):
        """
        Asynchronous SQL repository.

        :param url: SQL URL for the database connection. Ignored when
            ``engine`` is given.
        :param engine: Async engine, if already created. Takes precedence
            over ``url``.
        :param echo: Forwarded to ``create_async_engine`` when the engine is
            built from ``url``; ``True`` (the default) logs every SQL
            statement. Has no effect when ``engine`` is supplied.
        :raises ValueError: If neither ``url`` nor ``engine`` is provided.
        """
        if engine is None and url is None:
            raise ValueError("Either url or engine must be provided")

        if engine is None:
            # pool_pre_ping tests each pooled connection on checkout and
            # replaces dead ones; pool_recycle retires connections older than
            # 3600 seconds.
            engine = create_async_engine(
                url, echo=echo, pool_recycle=3600, pool_pre_ping=True
            )
        self._engine = engine
        self._session_factory = async_sessionmaker(
            self._engine, expire_on_commit=False, class_=AsyncSession
        )

    async def get_session(self):
        """Return a new, not-yet-started ``AsyncSession``.

        The caller owns the session and must commit and close it, e.g.
        ``async with await repo.get_session() as session: ...``.
        Prefer :meth:`start_session` for transactional work.
        """
        return self._session_factory()

    @contextlib.asynccontextmanager
    async def start_session(self, rollback=True):
        """
        Open a session and run the ``async with`` body inside one transaction.

        The transaction is committed when the body exits normally. If the body
        (or the commit) raises, the transaction is rolled back, the error is
        logged and the original exception is re-raised.

        This method performs no retries. Protection against stale connections
        comes only from ``pool_pre_ping`` when this class built the engine
        from a URL.

        :param rollback: When true, also issue an explicit
            ``session.rollback()`` after a failure. ``session.begin()`` already
            rolls back on error, so ``False`` does not preserve the failed
            transaction's changes.
        """
        async with self._session_factory() as session:
            try:
                async with session.begin():
                    yield session
            except Exception as e:
                if rollback:
                    await session.rollback()
                # Lazy %-formatting: the message is only built if emitted.
                logger.error("Exception during session: %s", e)
                raise

    async def sync_schema(self):
        """Create any tables registered on ``metadata_obj`` that do not exist yet.

        Only missing tables are created (``create_all`` checks first). Existing
        tables are not altered, and no migrations are run.
        """
        async with self._engine.begin() as conn:
            # MetaData.create_all is synchronous; run_sync runs it on the
            # async connection.
            await conn.run_sync(self.metadata_obj.create_all)

    async def add(self, session: AsyncSession, instance: object):
        """Stage ``instance`` for insertion; it is persisted when ``session`` flushes/commits."""
        session.add(instance)

    async def add_bulk(self, session: AsyncSession, objects: List[object]):
        """Stage every object in ``objects`` for insertion in ``session``."""
        session.add_all(objects)

    async def get(self, session: AsyncSession, model, **kwargs):
        """Return the single ``model`` instance matching ``kwargs``, or ``None``.

        :raises sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        result = await session.execute(select(model).filter_by(**kwargs))
        # unique() is required when the model eagerly joins collections
        # (lazy="joined"); SQLAlchemy refuses to return such rows otherwise.
        return result.unique().scalar_one_or_none()

    async def exists(self, session: AsyncSession, model, **kwargs) -> bool:
        """Return ``True`` if at least one ``model`` row matches ``kwargs``."""
        # LIMIT 1 + first(): any number of matches means "exists".
        # scalar_one_or_none() would raise MultipleResultsFound for >1 match.
        result = await session.execute(select(model).filter_by(**kwargs).limit(1))
        return result.unique().scalars().first() is not None

    async def delete(self, session: AsyncSession, model, **kwargs):
        """Delete every ``model`` row matching ``kwargs`` with one bulk DELETE.

        Warning: with no ``kwargs`` the statement has no WHERE clause and
        deletes every row of the table.
        """
        await session.execute(delete(model).filter_by(**kwargs))

    async def filter(self, session: AsyncSession, model, *args, **kwargs) -> Iterable:
        """Return the ``model`` instances matching the given criteria.

        :param args: SQL expression criteria (e.g. ``Model.col > 3``), passed
            to ``Select.filter``.
        :param kwargs: Column equality criteria, passed to ``Select.filter_by``.
            With neither ``args`` nor ``kwargs``, all rows are returned.
        :raises ValueError: If both ``args`` and ``kwargs`` are given.
        """
        if args and kwargs:
            raise ValueError('Cannot use filter method with both args and kwargs')

        statement = select(model)
        if args:
            statement = statement.filter(*args)
        elif kwargs:
            statement = statement.filter_by(**kwargs)
        result = await session.execute(statement)

        # unique() because models usually use lazy="joined" eager loading.
        return result.unique().scalars().all()

    async def filter_by_list(self, session: AsyncSession, model, field: str, items_list: List) -> Iterable:
        """Return the ``model`` instances whose ``field`` value is in ``items_list``.

        :raises AttributeError: If ``model`` has no attribute named ``field``.
        """
        query_field = getattr(model, field)
        result = await session.execute(select(model).filter(query_field.in_(items_list)))
        # unique() for the same reason as in filter(): joined eager loads of
        # collections make SQLAlchemy reject non-uniqued results.
        return result.unique().scalars().all()

    async def patch(self, session: AsyncSession, model, update_data: dict, **kwargs):
        """Set the columns in ``update_data`` on every ``model`` row matching ``kwargs``.

        Issued as a single bulk UPDATE. Warning: with no ``kwargs`` every row
        of the table is updated.
        """
        await session.execute(update(model).filter_by(**kwargs).values(update_data))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc