• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pycasbin / postgresql-watcher / 9854697895

09 Jul 2024 09:40AM UTC coverage: 77.619%. First build
9854697895

Pull #29

github

web-flow
Merge 09b2e8dc3 into 4b808d0b7
Pull Request #29: feat: fixed `should_reload` behaviour, close PostgreSQL connections, block until `PostgresqlWatcher` is ready, refactorings

127 of 169 new or added lines in 3 files covered. (75.15%)

163 of 210 relevant lines covered (77.62%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.21
/postgresql_watcher/watcher.py
1
from logging import Logger, getLogger
1✔
2
from multiprocessing import Process, Pipe
1✔
3
from multiprocessing.connection import Connection
1✔
4
from time import sleep, time
1✔
5
from typing import Optional, Callable
1✔
6

7
from psycopg2 import connect, extensions
1✔
8

9
from .casbin_channel_subscription import (
1✔
10
    casbin_channel_subscription,
11
    _ChannelSubscriptionMessage,
12
)
13

14

15
# Default PostgreSQL channel name; PostgresqlWatcher falls back to it when no
# channel_name is passed, and uses it both for NOTIFY and for the listener.
POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher"
1✔
16

17

18
class PostgresqlWatcher(object):
    """Casbin watcher that propagates policy changes over a PostgreSQL channel.

    ``update()`` sends a ``NOTIFY`` on ``channel_name``. A daemon child process
    running ``casbin_channel_subscription`` listens on that channel and reports
    back (readiness and received updates) through a ``multiprocessing.Pipe``,
    whose parent end is polled by ``start()`` and ``should_reload()``.
    """

    # Upper bound, in seconds, for a single wait on the ready signal in
    # start(); the child's liveness is re-checked between waits.
    _READY_POLL_INTERVAL = 0.05

    def __init__(
        self,
        host: str,
        user: str,
        password: str,
        port: int = 5432,
        dbname: str = "postgres",
        channel_name: Optional[str] = None,
        start_listening: bool = True,
        sslmode: Optional[str] = None,
        sslrootcert: Optional[str] = None,
        sslcert: Optional[str] = None,
        sslkey: Optional[str] = None,
        logger: Optional[Logger] = None,
    ) -> None:
        """
        Initialize a PostgresqlWatcher object.

        Args:
            host (str): Hostname of the PostgreSQL server.
            user (str): PostgreSQL username.
            password (str): Password for the user.
            port (int): Port of the PostgreSQL server. Defaults to 5432.
            dbname (str): Database name. Defaults to "postgres".
            channel_name (str): The name of the channel to listen to and to send updates to. When None a default is used.
            start_listening (bool, optional): Flag whether to start listening to updates on the PostgreSQL channel. Defaults to True.
            sslmode (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            sslrootcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            sslcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            sslkey (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            logger (Optional[Logger], optional): Custom logger to use. Defaults to None (root logger).

        Raises:
            RuntimeError: With ``start_listening=True``, if the subscription
                process exits before it reports that it is ready.
        """
        # Set every attribute that _cleanup_connections_and_processes touches
        # before anything that can fail, so cleanup (and __del__) is always safe.
        self.update_callback: Optional[Callable[[], None]] = None
        self.parent_conn: Optional[Connection] = None
        self.child_conn: Optional[Connection] = None
        self.subscription_process: Optional[Process] = None

        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.dbname = dbname
        self.channel_name = (
            channel_name if channel_name is not None else POSTGRESQL_CHANNEL_NAME
        )
        self.sslmode = sslmode
        self.sslrootcert = sslrootcert
        self.sslcert = sslcert
        self.sslkey = sslkey
        self.logger = logger if logger is not None else getLogger()

        self._create_subscription_process(start_listening)

    def __del__(self) -> None:
        self._cleanup_connections_and_processes()

    @property
    def subscription_proces(self) -> Optional[Process]:
        """Misspelled alias of ``subscription_process``.

        Earlier code stored the process under this name; kept so existing
        callers keep working.
        """
        return self.subscription_process

    @subscription_proces.setter
    def subscription_proces(self, value: Optional[Process]) -> None:
        self.subscription_process = value

    def _create_subscription_process(
        self,
        start_listening=True,
        delay: Optional[int] = 2,
    ) -> None:
        """(Re)create the Pipe and the subscription process.

        Any existing pipe ends and process are cleaned up first.

        Args:
            start_listening: Start the new process right away and block until
                it is ready (see ``start``).
            delay: Passed through to ``casbin_channel_subscription``;
                presumably seconds to wait before subscribing (the caller in
                ``should_reload`` logs "in 10 seconds" for ``delay=10``).
        """
        self._cleanup_connections_and_processes()

        self.parent_conn, self.child_conn = Pipe()
        self.subscription_process = Process(
            target=casbin_channel_subscription,
            args=(
                self.child_conn,
                self.logger,
                self.host,
                self.user,
                self.password,
                self.channel_name,
                self.port,
                self.dbname,
                delay,
                self.sslmode,
                self.sslrootcert,
                self.sslcert,
                self.sslkey,
            ),
            daemon=True,
        )
        if start_listening:
            self.start()

    def start(self, timeout: Optional[float] = None) -> None:
        """Start the subscription process and block until it is ready.

        Does nothing if the process is already running. A ``Process`` that has
        already exited cannot be started again, so a fresh one is created.

        Args:
            timeout: Maximum number of seconds to wait for the ready signal.
                ``None`` (the default) waits for as long as the process lives.

        Raises:
            RuntimeError: The subscription process exited before it reported
                that it is ready.
            TimeoutError: ``timeout`` elapsed before the ready signal arrived.
        """
        from time import monotonic

        process = self.subscription_process
        if process is None or process.exitcode is not None:
            self._create_subscription_process(start_listening=False)
            process = self.subscription_process
        if process.is_alive():
            return

        # Start listening to messages
        process.start()
        # Drop the parent's copy of the child's pipe end. While the parent
        # holds it, the pipe never reports EOF, so a dead child goes unnoticed
        # (both here and in should_reload).
        if self.child_conn is not None:
            self.child_conn.close()
            self.child_conn = None

        # And wait for the Process to be ready to listen for updates
        # from PostgreSQL
        deadline = None if timeout is None else monotonic() + timeout
        while True:
            wait = self._READY_POLL_INTERVAL
            if deadline is not None:
                remaining = deadline - monotonic()
                if remaining <= 0:
                    raise TimeoutError(
                        f"Subscription process was not ready within {timeout} seconds"
                    )
                wait = min(wait, remaining)
            # Sample liveness *before* polling: if the process was already dead
            # and the pipe is still empty afterwards, no ready signal can follow.
            alive = process.is_alive()
            try:
                if self.parent_conn.poll(wait):
                    message = int(self.parent_conn.recv())
                    if message == _ChannelSubscriptionMessage.IS_READY:
                        return
                    continue
            except EOFError:
                alive = False
            if not alive:
                raise RuntimeError(
                    "Subscription process exited before it was ready "
                    f"(exit code {process.exitcode})"
                )

    def _cleanup_connections_and_processes(self) -> None:
        """Close both pipe ends and stop the subscription process, if any.

        Safe to call repeatedly and on a partially initialised instance
        (e.g. from ``__del__`` when ``__init__`` did not complete).
        """
        for attr in ("parent_conn", "child_conn"):
            conn = getattr(self, attr, None)
            if conn is not None:
                conn.close()
            setattr(self, attr, None)

        process = getattr(self, "subscription_process", None)
        self.subscription_process = None
        if process is None:
            return
        try:
            # False for a Process that was never started; terminate() would
            # fail on such a Process.
            alive = process.is_alive()
        except AssertionError:
            # Only the creating process may manage the child (this object may
            # be a copy inherited by a forked process); leave it alone.
            alive = False
        if alive:
            process.terminate()
            # Reap the terminated child so it does not linger as a zombie.
            process.join(timeout=1)

    def set_update_callback(self, update_handler: Optional[Callable[[], None]]):
        """
        Set the handler called, when the Watcher detects an update.
        Recommendation: `casbin_enforcer.adapter.load_policy`

        The handler is called without arguments from ``should_reload``.
        """
        self.update_callback = update_handler

    def update(self) -> None:
        """
        Called by `casbin.Enforcer` when an update to the model was made.
        Informs other watchers via the PostgreSQL channel.

        Opens a short-lived connection, sends a NOTIFY on ``channel_name`` and
        always closes the connection, even if sending fails.
        """
        conn = connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.dbname,
            sslmode=self.sslmode,
            sslrootcert=self.sslrootcert,
            sslcert=self.sslcert,
            sslkey=self.sslkey,
        )
        try:
            # NOTIFY is only delivered when its transaction commits;
            # autocommit sends it immediately.
            conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
            with conn.cursor() as curs:
                # NOTE(review): channel_name is interpolated unquoted as an SQL
                # identifier; it must come from trusted configuration.
                curs.execute(f"NOTIFY {self.channel_name},'casbin policy update at {time()}'")
        finally:
            conn.close()

    def should_reload(self) -> bool:
        """Consume at most one pending message from the subscription process.

        Returns:
            True if the message reported a received update; the update
            callback (if set) has then already been called. False otherwise,
            including when no pipe exists (e.g. after cleanup).

        If the subscription process has gone away (EOF on the pipe), a warning
        is logged, the process is recreated with ``delay=10`` and False is
        returned. Recreating blocks until the new process is ready and may
        raise ``RuntimeError`` (see ``start``).
        """
        parent_conn = getattr(self, "parent_conn", None)
        if parent_conn is None:
            return False
        try:
            if not parent_conn.poll():
                return False
            message = int(parent_conn.recv())
        except EOFError:
            self.logger.warning(
                "Child casbin-watcher subscribe process has stopped, "
                "attempting to recreate the process in 10 seconds..."
            )
            self._create_subscription_process(delay=10)
            return False

        received_update = message == _ChannelSubscriptionMessage.RECEIVED_UPDATE
        if received_update and self.update_callback is not None:
            self.update_callback()
        return received_update
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc