• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pycasbin / postgresql-watcher / 9952906225

16 Jul 2024 07:54AM UTC coverage: 78.017%. First build
9952906225

Pull #27

github

web-flow
Merge df0c4e750 into 11466e0d6
Pull Request #27: feat: calling update_callback fun if set and updated docs

181 of 232 relevant lines covered (78.02%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.51
/postgresql_watcher/watcher.py
1
from logging import Logger, getLogger
1✔
2
from multiprocessing import Process, Pipe
1✔
3
from multiprocessing.connection import Connection
1✔
4
from time import sleep, time
1✔
5
from typing import Optional, Callable
1✔
6

7
from psycopg2 import connect, extensions
1✔
8

9
from .casbin_channel_subscription import (
1✔
10
    casbin_channel_subscription,
11
    _ChannelSubscriptionMessage,
12
)
13

14

15
# Default PostgreSQL channel name; PostgresqlWatcher falls back to it when no channel_name is given.
POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher"
1✔
16

17

18
class PostgresqlWatcher(object):
    """
    Casbin watcher that propagates policy updates through a PostgreSQL channel.

    A child `multiprocessing.Process` running `casbin_channel_subscription` watches the
    channel and reports back to this object through a `multiprocessing.Pipe`, using
    `_ChannelSubscriptionMessage` values (IS_READY, RECEIVED_UPDATE). `update` sends a
    NOTIFY on the channel, and `should_reload` drains the pipe to find out whether an
    update was received.
    """

    def __init__(
        self,
        host: str,
        user: str,
        password: str,
        port: int = 5432,
        dbname: str = "postgres",
        channel_name: Optional[str] = None,
        start_listening: bool = True,
        sslmode: Optional[str] = None,
        sslrootcert: Optional[str] = None,
        sslcert: Optional[str] = None,
        sslkey: Optional[str] = None,
        logger: Optional[Logger] = None,
    ) -> None:
        """
        Initialize a PostgresqlWatcher object.

        Args:
            host (str): Hostname of the PostgreSQL server.
            user (str): PostgreSQL username.
            password (str): Password for the user.
            port (int): Port of the PostgreSQL server. Defaults to 5432.
            dbname (str): Database name. Defaults to "postgres".
            channel_name (str): The name of the channel to listen to and to send updates to. When None a default is used.
            start_listening (bool, optional): Flag whether to start listening to updates on the PostgreSQL channel. Defaults to True.
            sslmode (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            sslrootcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            sslcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            sslkey (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None.
            logger (Optional[Logger], optional): Custom logger to use. Defaults to None (the root logger).

        Raises:
            PostgresqlWatcherChannelSubscriptionTimeoutError: If `start_listening` is True and the
                subscription process does not report readiness within `start`'s timeout.
        """
        # Invoked (without arguments) by `should_reload` when an update was received.
        self.update_callback: Optional[Callable[[], None]] = None
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.dbname = dbname
        self.channel_name = (
            channel_name if channel_name is not None else POSTGRESQL_CHANNEL_NAME
        )
        self.sslmode = sslmode
        self.sslrootcert = sslrootcert
        self.sslcert = sslcert
        self.sslkey = sslkey
        self.logger = logger if logger is not None else getLogger()
        # These must exist before `_create_subscription_process`, which cleans them up first.
        self.parent_conn: Optional[Connection] = None
        self.child_conn: Optional[Connection] = None
        self.subscription_process: Optional[Process] = None
        self._create_subscription_process(start_listening)

    def __del__(self) -> None:
        # Cleanup tolerates missing attributes, so this is safe even if __init__ failed early.
        self._cleanup_connections_and_processes()

    def _create_subscription_process(
        self,
        start_listening=True,
        delay: Optional[int] = 2,
    ) -> None:
        """
        (Re)create the Pipe and the subscription Process, optionally starting it.

        Any previously existing Pipe ends and Process are cleaned up first.

        Args:
            start_listening (bool): Whether to call `start` right away.
            delay (Optional[int]): Forwarded to `casbin_channel_subscription`. Presumably the
                number of seconds the child waits before subscribing (see the log message in
                `should_reload`); confirm in that module.
        """
        self._cleanup_connections_and_processes()

        self.parent_conn, self.child_conn = Pipe()
        # Use the same attribute name as __init__ and the cleanup, so the process is
        # actually terminated when it is recreated or the watcher is destroyed.
        self.subscription_process = Process(
            target=casbin_channel_subscription,
            args=(
                self.child_conn,
                self.logger,
                self.host,
                self.user,
                self.password,
                self.channel_name,
                self.port,
                self.dbname,
                delay,
                self.sslmode,
                self.sslrootcert,
                self.sslcert,
                self.sslkey,
            ),
            daemon=True,
        )
        if start_listening:
            self.start()

    def start(
        self,
        timeout=20,  # seconds
    ):
        """
        Start the subscription process, if it is not alive yet, and wait until it is ready.

        Args:
            timeout: Maximum number of seconds to wait for the IS_READY message.

        Raises:
            PostgresqlWatcherChannelSubscriptionTimeoutError: If IS_READY does not arrive in time.
        """
        if self.subscription_process.is_alive():
            return
        # Start listening to messages
        self.subscription_process.start()
        # And wait for the Process to be ready to listen for updates from PostgreSQL
        timeout_time = time() + timeout
        while True:
            if self.parent_conn.poll():
                message = int(self.parent_conn.recv())
                if message == _ChannelSubscriptionMessage.IS_READY:
                    break
            if time() > timeout_time:
                raise PostgresqlWatcherChannelSubscriptionTimeoutError(timeout)
            sleep(1 / 1000)  # wait for 1 ms

    def _cleanup_connections_and_processes(self) -> None:
        """Close both Pipe ends and terminate the subscription process, if they exist."""
        # getattr: this also runs from __del__, possibly on a partially initialised object.
        parent_conn = getattr(self, "parent_conn", None)
        if parent_conn is not None:
            parent_conn.close()
        self.parent_conn = None

        child_conn = getattr(self, "child_conn", None)
        if child_conn is not None:
            child_conn.close()
        self.child_conn = None

        process = getattr(self, "subscription_process", None)
        # Process.terminate() fails on a never-started Process, so only stop live ones.
        if process is not None and process.is_alive():
            process.terminate()
        self.subscription_process = None

    def set_update_callback(self, update_handler: Optional[Callable[[], None]]):
        """
        Set the handler called, when the Watcher detects an update.
        Recommendation: `casbin_enforcer.adapter.load_policy`

        The handler is called without arguments from `should_reload`. Pass None to remove it.
        """
        self.update_callback = update_handler

    def update(self) -> None:
        """
        Called by `casbin.Enforcer` when an update to the model was made.
        Informs other watchers via the PostgreSQL channel.

        Opens a short-lived connection, sends a NOTIFY on `self.channel_name` and always
        closes the connection, even if the NOTIFY fails.
        """
        conn = connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.dbname,
            sslmode=self.sslmode,
            sslrootcert=self.sslrootcert,
            sslcert=self.sslcert,
            sslkey=self.sslkey,
        )
        try:
            # Autocommit, so the NOTIFY is not held back inside an open transaction.
            conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT)
            with conn.cursor() as curs:
                # NOTE(review): channel_name is interpolated as an SQL identifier, so it must
                # come from trusted configuration.
                curs.execute(
                    f"NOTIFY {self.channel_name},'casbin policy update at {time()}'"
                )
        finally:
            conn.close()

    def should_reload(self) -> bool:
        """
        Drain pending messages from the subscription process and report whether to reload.

        Returns:
            bool: True if at least one RECEIVED_UPDATE message was pending. In that case the
            update callback, if set, is invoked once. Returns False if no update was pending,
            or if the child process has stopped; the process is then recreated with a
            10-second delay.
        """
        should_reload_flag = False
        try:
            while self.parent_conn.poll():
                message = int(self.parent_conn.recv())
                if message == _ChannelSubscriptionMessage.RECEIVED_UPDATE:
                    should_reload_flag = True
        except EOFError:
            # The other end of the Pipe is gone: the child process has stopped.
            self.logger.warning(
                "Child casbin-watcher subscribe process has stopped, "
                "attempting to recreate the process in 10 seconds..."
            )
            self._create_subscription_process(delay=10)
            return False

        # Outside the try, so an EOFError raised by user code is not mistaken
        # for a dead child process.
        if should_reload_flag and self.update_callback is not None:
            self.update_callback()

        return should_reload_flag
189

190

191
class PostgresqlWatcherChannelSubscriptionTimeoutError(RuntimeError):
    """
    Signals that the PostgreSQL channel subscription did not become ready in time.

    The error message reports the timeout rounded to whole seconds.
    """

    def __init__(self, timeout_in_seconds: float) -> None:
        super().__init__(
            f"The channel subscription could not be established within {timeout_in_seconds:.0f} seconds."
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc