From 3edbf373bff6677ec28bbed3c2c76b885d932eb8 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 12:17:51 +0200 Subject: [PATCH 01/33] chore: updated dev requirements --- dev_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 8a3a4d0..b4f398b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1 +1 @@ -black==20.8b1 \ No newline at end of file +black==24.4.2 From 8101bcd046fdf605f16e3c8e8c34cdf4deaf5c3f Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 12:18:54 +0200 Subject: [PATCH 02/33] chore: format code with black --- postgresql_watcher/watcher.py | 4 ++-- setup.py | 6 +++--- tests/test_postgresql_watcher.py | 10 ++++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 95cd9b8..6710949 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -35,7 +35,7 @@ def casbin_subscription( sslmode=sslmode, sslrootcert=sslrootcert, sslcert=sslcert, - sslkey=sslkey + sslkey=sslkey, ) # Can only receive notifications when not in transaction, set this for easier usage conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) @@ -130,7 +130,7 @@ def update(self): sslmode=self.sslmode, sslrootcert=self.sslrootcert, sslcert=self.sslcert, - sslkey=self.sslkey + sslkey=self.sslkey, ) # Can only receive notifications when not in transaction, set this for easier usage conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) diff --git a/setup.py b/setup.py index 03a55e3..fe4faff 100644 --- a/setup.py +++ b/setup.py @@ -14,10 +14,10 @@ long_description_content_type="text/markdown", url="https://github.com/pycasbin/postgresql-watcher", packages=find_packages(), - install_requires=open('requirements.txt').read().splitlines(), + install_requires=open("requirements.txt").read().splitlines(), extras_require={ - 'dev': [ - open('dev_requirements.txt').read().splitlines(), + "dev": [ + open("dev_requirements.txt").read().splitlines(), ] }, classifiers=[ diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index fa0356a..b66e440 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -14,7 +14,9 @@ def get_watcher(): - return PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME) + return PostgresqlWatcher( + host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME + ) pg_watcher = get_watcher() @@ -22,9 +24,9 @@ def get_watcher(): try: import _winapi from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE -except ImportError: - if sys.platform == 'win32': - raise +except ImportError as e: + if sys.platform == "win32": + raise e _winapi = None From 98ed5687651cd5a98869f10702c80d4ee02a43a1 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 12:20:46 +0200 Subject: [PATCH 03/33] chore: updated .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8acb2f2..54431a9 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,4 @@ dmypy.json .idea/ *.iml +.vscode From bb999173f5c083967d809c69e433722427c1c2f7 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 13:58:02 +0200 Subject: [PATCH 04/33] fix: type hint, multiprocessing.Pipe is a Callable and not a type --- postgresql_watcher/watcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 6710949..3e216ec 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -1,6 +1,7 @@ from typing import Optional, Callable from psycopg2 import connect, extensions from multiprocessing import Process, Pipe +from multiprocessing.connection import Connection import time from select import select from logging import Logger, getLogger @@ -10,7 +11,7 @@ def casbin_subscription( - process_conn: Pipe, + process_conn: Connection, logger: Logger, host: str, user: str, From df1b8040dd49e012796c74d7335201b6e02f894b Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:07:16 +0200 Subject: [PATCH 05/33] fix: make Watcher.should_reload return value consistent --- postgresql_watcher/watcher.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 3e216ec..2f0327d 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -142,7 +142,7 @@ def update(self): conn.close() return True - def should_reload(self): + def should_reload(self) -> bool: try: if self.parent_conn.poll(None): message = self.parent_conn.recv() @@ -153,7 +153,6 @@ def should_reload(self): "Child casbin-watcher subscribe process has stopped, " "attempting to recreate the process in 10 seconds..." ) - self.subscribed_process, self.parent_conn = self.create_subscriber_process( - delay=10 - ) - return False + self.create_subscription_process(delay=10) + + return False From c6fb08751b499068d076f27f81ad99503b5d335e Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:08:59 +0200 Subject: [PATCH 06/33] fix: Handle Connection and Process objects consistenly and close them before creating new ones --- postgresql_watcher/watcher.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 2f0327d..8922652 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -84,20 +84,32 @@ def __init__( if logger is None: logger = getLogger() self.logger = logger - self.subscribed_process = self.create_subscriber_process(start_process) + self.parent_conn: Connection = None + self.child_conn: Connection = None + self.subscription_process: Process = None + self.create_subscription_process(start_process) - def create_subscriber_process( + def create_subscription_process( self, start_process: Optional[bool] = True, delay: Optional[int] = 2, - ): - parent_conn, child_conn = Pipe() - if not self.parent_conn: - self.parent_conn = parent_conn - p = Process( + ) -> None: + # Clean up potentially existing Connections and Processes + if self.parent_conn is not None: + self.parent_conn.close() + self.parent_conn = None + if self.child_conn is not None: + self.child_conn.close() + self.child_conn = None + if self.subscription_process is not None: + self.subscription_process.terminate() + self.subscription_process = None + + self.parent_conn, self.child_conn = Pipe() + self.subscribed_process = Process( target=casbin_subscription, args=( - child_conn, + self.child_conn, self.logger, self.host, self.user, @@ -114,8 +126,7 @@ def create_subscriber_process( daemon=True, ) if start_process: - p.start() - return p + self.subscribed_process.start() def set_update_callback(self, fn_name: Callable): self.logger.debug(f"runtime is set update callback {fn_name}") From 434a84c1bfa21f8be16734b00d3e79718f7ab493 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:12:25 +0200 Subject: [PATCH 07/33] feat: Customize the postgres channel name --- postgresql_watcher/watcher.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 8922652..2d35fc4 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -54,6 +54,18 @@ def casbin_subscription( class PostgresqlWatcher(object): + + @staticmethod + def set_channel_name(channel_name: str) -> None: + """ + Customize the Postgres channel name. This have to be done before initializing a PostgresqlWatcher object. + + Args: + channel_name (str): New channel name + """ + global POSTGRESQL_CHANNEL_NAME + POSTGRESQL_CHANNEL_NAME = channel_name + def __init__( self, host: str, @@ -165,5 +177,5 @@ def should_reload(self) -> bool: "attempting to recreate the process in 10 seconds..." ) self.create_subscription_process(delay=10) - + return False From 90689ef151a4cc796bb6f6061fc91b9c61d30621 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:19:11 +0200 Subject: [PATCH 08/33] chore: Some code reorg - Make PostgresqlWatcher.create_subscription_process a private method - Rename casbin_subscription to _casbin_channel_subscription --- postgresql_watcher/watcher.py | 95 ++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 2d35fc4..5cd058a 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -10,49 +10,6 @@ POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher" -def casbin_subscription( - process_conn: Connection, - logger: Logger, - host: str, - user: str, - password: str, - port: Optional[int] = 5432, - dbname: Optional[str] = "postgres", - delay: Optional[int] = 2, - channel_name: Optional[str] = POSTGRESQL_CHANNEL_NAME, - sslmode: Optional[str] = None, - sslrootcert: Optional[str] = None, - sslcert: Optional[str] = None, - sslkey: Optional[str] = None, -): - # delay connecting to postgresql (postgresql connection failure) - time.sleep(delay) - conn = connect( - host=host, - port=port, - user=user, - password=password, - dbname=dbname, - sslmode=sslmode, - sslrootcert=sslrootcert, - sslcert=sslcert, - sslkey=sslkey, - ) - # Can only receive notifications when not in transaction, set this for easier usage - conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) - curs = conn.cursor() - curs.execute(f"LISTEN {channel_name};") - logger.debug("Waiting for casbin policy update") - while True and not curs.closed: - if not select([conn], [], [], 5) == ([], [], []): - logger.debug("Casbin policy update identified..") - conn.poll() - while conn.notifies: - notify = conn.notifies.pop(0) - logger.debug(f"Notify: {notify.payload}") - process_conn.send(notify.payload) - - class PostgresqlWatcher(object): @staticmethod @@ -99,9 +56,9 @@ def __init__( self.parent_conn: Connection = None self.child_conn: Connection = None self.subscription_process: Process = None - self.create_subscription_process(start_process) + self._create_subscription_process(start_process) - def create_subscription_process( + def _create_subscription_process( self, start_process: Optional[bool] = True, delay: Optional[int] = 2, @@ -119,7 +76,7 @@ def create_subscription_process( self.parent_conn, self.child_conn = Pipe() self.subscribed_process = Process( - target=casbin_subscription, + target=_casbin_channel_subscription, args=( self.child_conn, self.logger, @@ -145,6 +102,7 @@ def set_update_callback(self, fn_name: Callable): self.update_callback = fn_name def update(self): + conn = connect( host=self.host, port=self.port, @@ -176,6 +134,49 @@ def should_reload(self) -> bool: "Child casbin-watcher subscribe process has stopped, " "attempting to recreate the process in 10 seconds..." ) - self.create_subscription_process(delay=10) + self._create_subscription_process(delay=10) return False + + +def _casbin_channel_subscription( + process_conn: Connection, + logger: Logger, + host: str, + user: str, + password: str, + port: Optional[int] = 5432, + dbname: Optional[str] = "postgres", + delay: Optional[int] = 2, + channel_name: Optional[str] = POSTGRESQL_CHANNEL_NAME, + sslmode: Optional[str] = None, + sslrootcert: Optional[str] = None, + sslcert: Optional[str] = None, + sslkey: Optional[str] = None, +): + # delay connecting to postgresql (postgresql connection failure) + time.sleep(delay) + conn = connect( + host=host, + port=port, + user=user, + password=password, + dbname=dbname, + sslmode=sslmode, + sslrootcert=sslrootcert, + sslcert=sslcert, + sslkey=sslkey, + ) + # Can only receive notifications when not in transaction, set this for easier usage + conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) + curs = conn.cursor() + curs.execute(f"LISTEN {channel_name};") + logger.debug("Waiting for casbin policy update") + while True and not curs.closed: + if not select([conn], [], [], 5) == ([], [], []): + logger.debug("Casbin policy update identified..") + conn.poll() + while conn.notifies: + notify = conn.notifies.pop(0) + logger.debug(f"Notify: {notify.payload}") + process_conn.send(notify.payload) From 772a2617345b82c74e39476966d93596b3259479 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:30:30 +0200 Subject: [PATCH 09/33] docs: added doc string for PostgresqlWatcher.update --- postgresql_watcher/watcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 5cd058a..eb9e149 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -101,8 +101,11 @@ def set_update_callback(self, fn_name: Callable): self.logger.debug(f"runtime is set update callback {fn_name}") self.update_callback = fn_name - def update(self): - + def update(self) -> None: + """ + Called by `casbin.Enforcer` when an update to the model was made. + Informs other watchers via the PostgreSQL channel. + """ conn = connect( host=self.host, port=self.port, From f94294e09c6454724aef871502b1a450d796618a Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:36:01 +0200 Subject: [PATCH 10/33] refactor: PostgresqlWatcher.set_update_callback --- postgresql_watcher/watcher.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index eb9e149..11bb7b6 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -57,6 +57,7 @@ def __init__( self.child_conn: Connection = None self.subscription_process: Process = None self._create_subscription_process(start_process) + self.update_callback: Optional[Callable] = None def _create_subscription_process( self, @@ -97,15 +98,18 @@ def _create_subscription_process( if start_process: self.subscribed_process.start() - def set_update_callback(self, fn_name: Callable): - self.logger.debug(f"runtime is set update callback {fn_name}") - self.update_callback = fn_name + def set_update_callback(self, update_handler: Callable): + """ + Set the handler called, when the Watcher detects an update. + Recommendation: `casbin_enforcer.adapter.load_policy` + """ + self.update_callback = update_handler def update(self) -> None: """ Called by `casbin.Enforcer` when an update to the model was made. Informs other watchers via the PostgreSQL channel. - """ + """ conn = connect( host=self.host, port=self.port, From c26887c0866a222530b6f938c9edfe71f4d6162c Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:54:30 +0200 Subject: [PATCH 11/33] refactor!: Rename 'start_process' flag to 'start_listening' --- postgresql_watcher/watcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 11bb7b6..5a88195 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -31,7 +31,7 @@ def __init__( port: Optional[int] = 5432, dbname: Optional[str] = "postgres", channel_name: Optional[str] = POSTGRESQL_CHANNEL_NAME, - start_process: Optional[bool] = True, + start_listening: bool = True, sslmode: Optional[str] = None, sslrootcert: Optional[str] = None, sslcert: Optional[str] = None, @@ -56,12 +56,12 @@ def __init__( self.parent_conn: Connection = None self.child_conn: Connection = None self.subscription_process: Process = None - self._create_subscription_process(start_process) + self._create_subscription_process(start_listening) self.update_callback: Optional[Callable] = None def _create_subscription_process( self, - start_process: Optional[bool] = True, + start_listening=True, delay: Optional[int] = 2, ) -> None: # Clean up potentially existing Connections and Processes @@ -95,7 +95,7 @@ def _create_subscription_process( ), daemon=True, ) - if start_process: + if start_listening: self.subscribed_process.start() def set_update_callback(self, update_handler: Callable): From bc33cf5977c57392ff2288da57c4075e473e3d6b Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 15:54:55 +0200 Subject: [PATCH 12/33] docs: Added doc string to PostgresqlWatcher.__init__ --- postgresql_watcher/watcher.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 5a88195..d5eaed6 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -37,7 +37,24 @@ def __init__( sslcert: Optional[str] = None, sslkey: Optional[str] = None, logger: Optional[Logger] = None, - ): + ) -> None: + """ + Initialize a PostgresqlWatcher object. + + Args: + host (str): Hostname of the PostgreSQL server. + user (str): PostgreSQL username. + password (str): Password for the user. + port (Optional[int], optional): Post of the PostgreSQL server. Defaults to 5432. + dbname (Optional[str], optional): Database name. Defaults to "postgres". + channel_name (Optional[str], optional): The name of the channel to listen to and to send updates to. Defaults to 'casbin_role_watcher'. + start_listening (bool, optional): Flag whether to start listening to updates on the PostgreSQL channel. Defaults to True. + sslmode (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None. + sslrootcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None. + sslcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None. + sslkey (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None. + logger (Optional[Logger], optional): Custom logger to use. Defaults to None. + """ self.update_callback = None self.parent_conn = None self.host = host From 9aba002949eb828aff6b0d8dcf640960494d4e40 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 16:02:45 +0200 Subject: [PATCH 13/33] fix: Added proper destructor for PostgresqlWatcher --- postgresql_watcher/watcher.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index d5eaed6..b5e2757 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -70,27 +70,21 @@ def __init__( if logger is None: logger = getLogger() self.logger = logger - self.parent_conn: Connection = None - self.child_conn: Connection = None - self.subscription_process: Process = None + self.parent_conn: Connection | None = None + self.child_conn: Connection | None = None + self.subscription_process: Process | None = None self._create_subscription_process(start_listening) self.update_callback: Optional[Callable] = None + def __del__(self) -> None: + self._cleanup_connections_and_processes() + def _create_subscription_process( self, start_listening=True, delay: Optional[int] = 2, ) -> None: - # Clean up potentially existing Connections and Processes - if self.parent_conn is not None: - self.parent_conn.close() - self.parent_conn = None - if self.child_conn is not None: - self.child_conn.close() - self.child_conn = None - if self.subscription_process is not None: - self.subscription_process.terminate() - self.subscription_process = None + self._cleanup_connections_and_processes() self.parent_conn, self.child_conn = Pipe() self.subscribed_process = Process( @@ -115,6 +109,18 @@ def _create_subscription_process( if start_listening: self.subscribed_process.start() + def _cleanup_connections_and_processes(self) -> None: + # Clean up potentially existing Connections and Processes + if self.parent_conn is not None: + self.parent_conn.close() + self.parent_conn = None + if self.child_conn is not None: + self.child_conn.close() + self.child_conn = None + if self.subscription_process is not None: + self.subscription_process.terminate() + self.subscription_process = None + def set_update_callback(self, update_handler: Callable): """ Set the handler called, when the Watcher detects an update. From eb5c8f3ead910b4f1875af9e35d04e41f281e4d0 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 16:16:18 +0200 Subject: [PATCH 14/33] chore: fix type hints and proper handling of the channel_name argument and its default value --- postgresql_watcher/watcher.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index b5e2757..eea7c41 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -12,25 +12,14 @@ class PostgresqlWatcher(object): - @staticmethod - def set_channel_name(channel_name: str) -> None: - """ - Customize the Postgres channel name. This have to be done before initializing a PostgresqlWatcher object. - - Args: - channel_name (str): New channel name - """ - global POSTGRESQL_CHANNEL_NAME - POSTGRESQL_CHANNEL_NAME = channel_name - def __init__( self, host: str, user: str, password: str, - port: Optional[int] = 5432, - dbname: Optional[str] = "postgres", - channel_name: Optional[str] = POSTGRESQL_CHANNEL_NAME, + port: int = 5432, + dbname: str = "postgres", + channel_name: Optional[str] = None, start_listening: bool = True, sslmode: Optional[str] = None, sslrootcert: Optional[str] = None, @@ -45,9 +34,9 @@ def __init__( host (str): Hostname of the PostgreSQL server. user (str): PostgreSQL username. password (str): Password for the user. - port (Optional[int], optional): Post of the PostgreSQL server. Defaults to 5432. - dbname (Optional[str], optional): Database name. Defaults to "postgres". - channel_name (Optional[str], optional): The name of the channel to listen to and to send updates to. Defaults to 'casbin_role_watcher'. + port (int): Post of the PostgreSQL server. Defaults to 5432. + dbname (str): Database name. Defaults to "postgres". + channel_name (str): The name of the channel to listen to and to send updates to. When None a default is used. start_listening (bool, optional): Flag whether to start listening to updates on the PostgreSQL channel. Defaults to True. sslmode (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None. sslrootcert (Optional[str], optional): See `psycopg2.connect` for details. Defaults to None. @@ -62,7 +51,7 @@ def __init__( self.user = user self.password = password self.dbname = dbname - self.channel_name = channel_name + self.channel_name = channel_name if channel_name is not None else POSTGRESQL_CHANNEL_NAME self.sslmode = sslmode self.sslrootcert = sslrootcert self.sslcert = sslcert From 94e23d7446b10b2d1d7c66c85a4ffb942c342c28 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 17:34:03 +0200 Subject: [PATCH 15/33] test: fix tests decrease select timeout to one second in child Process remove infinite timout in PostgresqlWatcher.should_reload create a new watcher instance for every test case --- postgresql_watcher/watcher.py | 8 ++++---- tests/test_postgresql_watcher.py | 16 +++++++++++----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index eea7c41..18b10a9 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -8,6 +8,7 @@ POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher" +CASBIN_CHANNEL_SELECT_TIMEOUT = 1 # seconds class PostgresqlWatcher(object): @@ -140,11 +141,10 @@ def update(self) -> None: f"NOTIFY {self.channel_name},'casbin policy update at {time.time()}'" ) conn.close() - return True def should_reload(self) -> bool: try: - if self.parent_conn.poll(None): + if self.parent_conn.poll(): message = self.parent_conn.recv() self.logger.debug(f"message:{message}") return True @@ -191,8 +191,8 @@ def _casbin_channel_subscription( curs = conn.cursor() curs.execute(f"LISTEN {channel_name};") logger.debug("Waiting for casbin policy update") - while True and not curs.closed: - if not select([conn], [], [], 5) == ([], [], []): + while not curs.closed: + if not select([conn], [], [], CASBIN_CHANNEL_SELECT_TIMEOUT) == ([], [], []): logger.debug("Casbin policy update identified..") conn.poll() while conn.notifies: diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index b66e440..e2ecf75 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -1,8 +1,10 @@ import sys import unittest from multiprocessing.connection import Pipe +from time import sleep from postgresql_watcher import PostgresqlWatcher +from postgresql_watcher.watcher import CASBIN_CHANNEL_SELECT_TIMEOUT from multiprocessing import connection, context # Warning!!! , Please setup yourself config @@ -18,9 +20,6 @@ def get_watcher(): host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME ) - -pg_watcher = get_watcher() - try: import _winapi from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE @@ -32,6 +31,7 @@ def get_watcher(): class TestConfig(unittest.TestCase): def test_pg_watcher_init(self): + pg_watcher = get_watcher() if _winapi: assert isinstance(pg_watcher.parent_conn, connection.PipeConnection) else: @@ -39,12 +39,18 @@ def test_pg_watcher_init(self): assert isinstance(pg_watcher.subscribed_process, context.Process) def test_update_pg_watcher(self): - assert pg_watcher.update() is True + pg_watcher = get_watcher() + sleep(5) # Wait for casbin_channel_subscription initialization + pg_watcher.update() + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + self.assertTrue(pg_watcher.should_reload()) def test_default_update_callback(self): - assert pg_watcher.update_callback() is None + pg_watcher = get_watcher() + assert pg_watcher.update_callback is None def test_add_update_callback(self): + pg_watcher = get_watcher() def _test_callback(): pass From 1db8e23ce73483dbf0024da6bc7cddd82b66484e Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 17:39:38 +0200 Subject: [PATCH 16/33] feat: Setup logging module for unit tests --- tests/test_postgresql_watcher.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index e2ecf75..bce1f19 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -2,6 +2,7 @@ import unittest from multiprocessing.connection import Pipe from time import sleep +import logging from postgresql_watcher import PostgresqlWatcher from postgresql_watcher.watcher import CASBIN_CHANNEL_SELECT_TIMEOUT @@ -14,10 +15,19 @@ PASSWORD = "123456" DBNAME = "postgres" +logger = logging.getLogger() +logger.level = logging.DEBUG +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) def get_watcher(): return PostgresqlWatcher( - host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME + host=HOST, + port=PORT, + user=USER, + password=PASSWORD, + dbname=DBNAME, + logger=logger, ) try: From f0e647919746d7d9ea6af51cafce410223060de0 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 18:17:06 +0200 Subject: [PATCH 17/33] fix: typo --- postgresql_watcher/watcher.py | 4 ++-- tests/test_postgresql_watcher.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 18b10a9..0db5534 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -77,7 +77,7 @@ def _create_subscription_process( self._cleanup_connections_and_processes() self.parent_conn, self.child_conn = Pipe() - self.subscribed_process = Process( + self.subscription_proces = Process( target=_casbin_channel_subscription, args=( self.child_conn, @@ -97,7 +97,7 @@ def _create_subscription_process( daemon=True, ) if start_listening: - self.subscribed_process.start() + self.subscription_proces.start() def _cleanup_connections_and_processes(self) -> None: # Clean up potentially existing Connections and Processes diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index bce1f19..6ce4f64 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -46,7 +46,7 @@ def test_pg_watcher_init(self): assert isinstance(pg_watcher.parent_conn, connection.PipeConnection) else: assert isinstance(pg_watcher.parent_conn, connection.Connection) - assert isinstance(pg_watcher.subscribed_process, context.Process) + assert isinstance(pg_watcher.subscription_proces, context.Process) def test_update_pg_watcher(self): pg_watcher = get_watcher() From 1640e75cb8fc2c793314a84e2b882eb396f42a13 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 19:01:39 +0200 Subject: [PATCH 18/33] feat: channel subscription with proper resource cleanup Moved channel subscription function to separate file and added context manager for the connection, that handles SIGINT, SIGTERM for proper resource cleanup --- .../casbin_channel_subscription.py | 93 +++++++++++++++++++ postgresql_watcher/watcher.py | 51 +--------- tests/test_postgresql_watcher.py | 2 +- 3 files changed, 98 insertions(+), 48 deletions(-) create mode 100644 postgresql_watcher/casbin_channel_subscription.py diff --git a/postgresql_watcher/casbin_channel_subscription.py b/postgresql_watcher/casbin_channel_subscription.py new file mode 100644 index 0000000..49efc9b --- /dev/null +++ b/postgresql_watcher/casbin_channel_subscription.py @@ -0,0 +1,93 @@ +from logging import Logger +from multiprocessing.connection import Connection +from select import select +from signal import signal, SIGINT, SIGTERM +from time import sleep +from typing import Optional + +from psycopg2 import connect, extensions, InterfaceError + + +CASBIN_CHANNEL_SELECT_TIMEOUT = 1 # seconds + + +def casbin_channel_subscription( + process_conn: Connection, + logger: Logger, + host: str, + user: str, + password: str, + channel_name: str, + port: int = 5432, + dbname: str = "postgres", + delay: int = 2, + sslmode: Optional[str] = None, + sslrootcert: Optional[str] = None, + sslcert: Optional[str] = None, + sslkey: Optional[str] = None, +): + # delay connecting to postgresql (postgresql connection failure) + sleep(delay) + db_connection = connect( + host=host, + port=port, + user=user, + password=password, + dbname=dbname, + sslmode=sslmode, + sslrootcert=sslrootcert, + sslcert=sslcert, + sslkey=sslkey, + ) + # Can only receive notifications when not in transaction, set this for easier usage + db_connection.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) + db_cursor = db_connection.cursor() + context_manager = _ConnectionManager(db_connection, db_cursor) + + db_cursor.execute(f"LISTEN {channel_name};") + logger.debug("Waiting for casbin policy update") + with context_manager: + while not db_cursor.closed: + try: + if not select([db_connection], [], [], CASBIN_CHANNEL_SELECT_TIMEOUT) == ([], [], []): + logger.debug("Casbin policy update identified..") + db_connection.poll() + while db_connection.notifies: + notify = db_connection.notifies.pop(0) + logger.debug(f"Notify: {notify.payload}") + process_conn.send(notify.payload) + except (InterfaceError, OSError) as e: + # Log an exception if these errors occurred without the context beeing closed + if not context_manager.connections_were_closed: + logger.critical(e, exc_info=True) + break + +class _ConnectionManager: + """ + You can not use 'with' and a connection / cursor directly in this setup. + For more details see this issue: https://github.com/psycopg/psycopg2/issues/941#issuecomment-864025101. + As a workaround this connection manager / context manager class is used, that also handles SIGINT and SIGTERM and + closes the database connection. + """ + + def __init__(self, connection, cursor) -> None: + self.connection = connection + self.cursor = cursor + self.connections_were_closed = False + + def __enter__(self): + signal(SIGINT, self._close_connections) + signal(SIGTERM, self._close_connections) + return self + + def _close_connections(self, *_): + if self.cursor is not None: + self.cursor.close() + self.cursor = None + if self.connection is not None: + self.connection.close() + self.connection = None + self.connections_were_closed = True + + def __exit__(self, *_): + self._close_connections() diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 0db5534..af65ad8 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -5,11 +5,10 @@ import time from select import select from logging import Logger, getLogger +from .casbin_channel_subscription import casbin_channel_subscription POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher" -CASBIN_CHANNEL_SELECT_TIMEOUT = 1 # seconds - class PostgresqlWatcher(object): @@ -78,17 +77,17 @@ def _create_subscription_process( self.parent_conn, self.child_conn = Pipe() self.subscription_proces = Process( - target=_casbin_channel_subscription, + target=casbin_channel_subscription, args=( self.child_conn, self.logger, self.host, self.user, self.password, + self.channel_name, self.port, self.dbname, delay, - self.channel_name, self.sslmode, self.sslrootcert, self.sslcert, @@ -96,6 +95,7 @@ def _create_subscription_process( ), daemon=True, ) + if start_listening: self.subscription_proces.start() @@ -156,46 +156,3 @@ def should_reload(self) -> bool: self._create_subscription_process(delay=10) return False - - -def _casbin_channel_subscription( - process_conn: Connection, - logger: Logger, - host: str, - user: str, - password: str, - port: Optional[int] = 5432, - dbname: Optional[str] = "postgres", - delay: Optional[int] = 2, - channel_name: Optional[str] = POSTGRESQL_CHANNEL_NAME, - sslmode: Optional[str] = None, - sslrootcert: Optional[str] = None, - sslcert: Optional[str] = None, - sslkey: Optional[str] = None, -): - # delay connecting to postgresql (postgresql connection failure) - time.sleep(delay) - conn = connect( - host=host, - port=port, - user=user, - password=password, - dbname=dbname, - sslmode=sslmode, - sslrootcert=sslrootcert, - sslcert=sslcert, - sslkey=sslkey, - ) - # Can only receive notifications when not in transaction, set this for easier usage - conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) - curs = conn.cursor() - curs.execute(f"LISTEN {channel_name};") - logger.debug("Waiting for casbin policy update") - while not curs.closed: - if not select([conn], [], [], CASBIN_CHANNEL_SELECT_TIMEOUT) == ([], [], []): - logger.debug("Casbin policy update identified..") - conn.poll() - while conn.notifies: - notify = conn.notifies.pop(0) - logger.debug(f"Notify: {notify.payload}") - process_conn.send(notify.payload) diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 6ce4f64..7783c6a 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -5,7 +5,7 @@ import logging from postgresql_watcher import PostgresqlWatcher -from postgresql_watcher.watcher import CASBIN_CHANNEL_SELECT_TIMEOUT +from postgresql_watcher.casbin_channel_subscription import CASBIN_CHANNEL_SELECT_TIMEOUT from multiprocessing import connection, context # Warning!!! , Please setup yourself config From b86a400a0ee36b5ae56fbd1ad363971889b4ee1e Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 19:10:19 +0200 Subject: [PATCH 19/33] chore: removed unnecessary tests --- tests/test_postgresql_watcher.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 7783c6a..1104026 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -55,18 +55,5 @@ def test_update_pg_watcher(self): sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) self.assertTrue(pg_watcher.should_reload()) - def test_default_update_callback(self): - pg_watcher = get_watcher() - assert pg_watcher.update_callback is None - - def test_add_update_callback(self): - pg_watcher = get_watcher() - def _test_callback(): - pass - - pg_watcher.set_update_callback(_test_callback) - assert pg_watcher.update_callback == _test_callback - - if __name__ == "__main__": unittest.main() From e0a6337bd83c103de09a3bbcc5df70fcebbfaee7 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 21:47:46 +0200 Subject: [PATCH 20/33] feat: Wait for Process to be ready to receive messages from PostgreSQL --- .../casbin_channel_subscription.py | 20 +++++++++---- postgresql_watcher/watcher.py | 29 +++++++++++++++---- tests/test_postgresql_watcher.py | 4 ++- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/postgresql_watcher/casbin_channel_subscription.py b/postgresql_watcher/casbin_channel_subscription.py index 49efc9b..a7bb276 100644 --- a/postgresql_watcher/casbin_channel_subscription.py +++ b/postgresql_watcher/casbin_channel_subscription.py @@ -1,3 +1,4 @@ +from enum import IntEnum from logging import Logger from multiprocessing.connection import Connection from select import select @@ -8,7 +9,7 @@ from psycopg2 import connect, extensions, InterfaceError -CASBIN_CHANNEL_SELECT_TIMEOUT = 1 # seconds +CASBIN_CHANNEL_SELECT_TIMEOUT = 1 # seconds def casbin_channel_subscription( @@ -43,25 +44,34 @@ def casbin_channel_subscription( db_connection.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) db_cursor = db_connection.cursor() context_manager = _ConnectionManager(db_connection, db_cursor) - + db_cursor.execute(f"LISTEN {channel_name};") logger.debug("Waiting for casbin policy update") + process_conn.send(_ChannelSubscriptionMessage.IS_READY) with context_manager: while not db_cursor.closed: try: - if not select([db_connection], [], [], CASBIN_CHANNEL_SELECT_TIMEOUT) == ([], [], []): + if not select( + [db_connection], [], [], CASBIN_CHANNEL_SELECT_TIMEOUT + ) == ([], [], []): logger.debug("Casbin policy update identified..") db_connection.poll() while db_connection.notifies: notify = db_connection.notifies.pop(0) logger.debug(f"Notify: {notify.payload}") - process_conn.send(notify.payload) + process_conn.send(_ChannelSubscriptionMessage.RECEIVED_UPDATE) except (InterfaceError, OSError) as e: # Log an exception if these errors occurred without the context beeing closed if not context_manager.connections_were_closed: logger.critical(e, exc_info=True) break + +class _ChannelSubscriptionMessage(IntEnum): + IS_READY = 1 + RECEIVED_UPDATE = 2 + + class _ConnectionManager: """ You can not use 'with' and a connection / cursor directly in this setup. @@ -79,7 +89,7 @@ def __enter__(self): signal(SIGINT, self._close_connections) signal(SIGTERM, self._close_connections) return self - + def _close_connections(self, *_): if self.cursor is not None: self.cursor.close() diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index af65ad8..ec11526 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -5,11 +5,15 @@ import time from select import select from logging import Logger, getLogger -from .casbin_channel_subscription import casbin_channel_subscription +from .casbin_channel_subscription import ( + casbin_channel_subscription, + _ChannelSubscriptionMessage, +) POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher" + class PostgresqlWatcher(object): def __init__( @@ -51,7 +55,9 @@ def __init__( self.user = user self.password = password self.dbname = dbname - self.channel_name = channel_name if channel_name is not None else POSTGRESQL_CHANNEL_NAME + self.channel_name = ( + channel_name if channel_name is not None else POSTGRESQL_CHANNEL_NAME + ) self.sslmode = sslmode self.sslrootcert = sslrootcert self.sslcert = sslcert @@ -95,9 +101,21 @@ def _create_subscription_process( ), daemon=True, ) - if start_listening: + self.start() + + def start(self): + if not self.subscription_proces.is_alive(): + # Start listening to messages self.subscription_proces.start() + # And wait for the Process to be ready to listen for updates + # from PostgreSQL + while True: + if self.parent_conn.poll(): + message = int(self.parent_conn.recv()) + if message == _ChannelSubscriptionMessage.IS_READY: + break + time.sleep(1 / 1000) # wait for 1 ms def _cleanup_connections_and_processes(self) -> None: # Clean up potentially existing Connections and Processes @@ -145,9 +163,8 @@ def update(self) -> None: def should_reload(self) -> bool: try: if self.parent_conn.poll(): - message = self.parent_conn.recv() - self.logger.debug(f"message:{message}") - return True + message = int(self.parent_conn.recv()) + return message == _ChannelSubscriptionMessage.RECEIVED_UPDATE except EOFError: self.logger.warning( "Child casbin-watcher subscribe process has stopped, " diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 1104026..72b7435 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -20,6 +20,7 @@ stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) + def get_watcher(): return PostgresqlWatcher( host=HOST, @@ -30,6 +31,7 @@ def get_watcher(): logger=logger, ) + try: import _winapi from _winapi import WAIT_OBJECT_0, WAIT_ABANDONED_0, WAIT_TIMEOUT, INFINITE @@ -50,10 +52,10 @@ def test_pg_watcher_init(self): def test_update_pg_watcher(self): pg_watcher = get_watcher() - sleep(5) # Wait for casbin_channel_subscription initialization pg_watcher.update() sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) self.assertTrue(pg_watcher.should_reload()) + if __name__ == "__main__": unittest.main() From 87056d6a06576da6dd5ec47dcd84c78040a486b3 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 21:59:28 +0200 Subject: [PATCH 21/33] test: multiple instances of the watcher --- tests/test_postgresql_watcher.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 72b7435..2339b04 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -50,12 +50,21 @@ def test_pg_watcher_init(self): assert isinstance(pg_watcher.parent_conn, connection.Connection) assert isinstance(pg_watcher.subscription_proces, context.Process) - def test_update_pg_watcher(self): + def test_update_single_pg_watcher(self): pg_watcher = get_watcher() pg_watcher.update() sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) self.assertTrue(pg_watcher.should_reload()) + def test_update_mutiple_pg_watcher(self): + main_watcher = get_watcher() + + other_watchers = [get_watcher() for _ in range(5)] + main_watcher.update() + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + for watcher in other_watchers: + self.assertTrue(watcher.should_reload()) + if __name__ == "__main__": unittest.main() From af31dffacbb045fd6a981f905492d78b340faf27 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 22:07:14 +0200 Subject: [PATCH 22/33] test: make sure every test case uses its own channel --- tests/test_postgresql_watcher.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 2339b04..f1ab098 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -21,7 +21,7 @@ logger.addHandler(stream_handler) -def get_watcher(): +def get_watcher(channel_name): return PostgresqlWatcher( host=HOST, port=PORT, @@ -29,6 +29,7 @@ def get_watcher(): password=PASSWORD, dbname=DBNAME, logger=logger, + channel_name=channel_name, ) @@ -43,7 +44,7 @@ def get_watcher(): class TestConfig(unittest.TestCase): def test_pg_watcher_init(self): - pg_watcher = get_watcher() + pg_watcher = get_watcher("test_pg_watcher_init") if _winapi: assert isinstance(pg_watcher.parent_conn, connection.PipeConnection) else: @@ -51,15 +52,16 @@ def test_pg_watcher_init(self): assert isinstance(pg_watcher.subscription_proces, context.Process) def test_update_single_pg_watcher(self): - pg_watcher = get_watcher() + pg_watcher = get_watcher("test_update_single_pg_watcher") pg_watcher.update() sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) self.assertTrue(pg_watcher.should_reload()) def test_update_mutiple_pg_watcher(self): - main_watcher = get_watcher() + channel_name = "test_update_mutiple_pg_watcher" + main_watcher = get_watcher(channel_name) - other_watchers = [get_watcher() for _ in range(5)] + other_watchers = [get_watcher(channel_name) for _ in range(5)] main_watcher.update() sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) for watcher in other_watchers: From f31ba260a0aa9aa507fa0f70956a78843b056800 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 22:10:25 +0200 Subject: [PATCH 23/33] test: no update --- tests/test_postgresql_watcher.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index f1ab098..4b6d87b 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -57,6 +57,11 @@ def test_update_single_pg_watcher(self): sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) self.assertTrue(pg_watcher.should_reload()) + def test_no_update_single_pg_watcher(self): + pg_watcher = get_watcher("test_no_update_single_pg_watcher") + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + self.assertFalse(pg_watcher.should_reload()) + def test_update_mutiple_pg_watcher(self): channel_name = "test_update_mutiple_pg_watcher" main_watcher = get_watcher(channel_name) @@ -67,6 +72,15 @@ def test_update_mutiple_pg_watcher(self): for watcher in other_watchers: self.assertTrue(watcher.should_reload()) + def test_no_update_mutiple_pg_watcher(self): + channel_name = "test_no_update_mutiple_pg_watcher" + main_watcher = get_watcher(channel_name) + + other_watchers = [get_watcher(channel_name) for _ in range(5)] + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + for watcher in other_watchers: + self.assertFalse(watcher.should_reload()) + if __name__ == "__main__": unittest.main() From b464dcb6081dc296dfa5220ffe1cb936e17a6831 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 22:13:12 +0200 Subject: [PATCH 24/33] refactor: moved code into with block --- postgresql_watcher/casbin_channel_subscription.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/postgresql_watcher/casbin_channel_subscription.py b/postgresql_watcher/casbin_channel_subscription.py index a7bb276..0a5e19e 100644 --- a/postgresql_watcher/casbin_channel_subscription.py +++ b/postgresql_watcher/casbin_channel_subscription.py @@ -45,10 +45,11 @@ def casbin_channel_subscription( db_cursor = db_connection.cursor() context_manager = _ConnectionManager(db_connection, db_cursor) - db_cursor.execute(f"LISTEN {channel_name};") - logger.debug("Waiting for casbin policy update") - process_conn.send(_ChannelSubscriptionMessage.IS_READY) with context_manager: + db_cursor.execute(f"LISTEN {channel_name};") + logger.debug("Waiting for casbin policy update") + process_conn.send(_ChannelSubscriptionMessage.IS_READY) + while not db_cursor.closed: try: if not select( From b13046309d4363b4aecc00b09dd58e5cde9fa46e Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 22:32:06 +0200 Subject: [PATCH 25/33] feat: automaticall call the update handler if it is provided --- postgresql_watcher/watcher.py | 9 ++++++--- tests/test_postgresql_watcher.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index ec11526..986b5c3 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -69,7 +69,7 @@ def __init__( self.child_conn: Connection | None = None self.subscription_process: Process | None = None self._create_subscription_process(start_listening) - self.update_callback: Optional[Callable] = None + self.update_callback: Optional[Callable[[None], None]] = None def __del__(self) -> None: self._cleanup_connections_and_processes() @@ -129,7 +129,7 @@ def _cleanup_connections_and_processes(self) -> None: self.subscription_process.terminate() self.subscription_process = None - def set_update_callback(self, update_handler: Callable): + def set_update_callback(self, update_handler: Optional[Callable[[None], None]]): """ Set the handler called, when the Watcher detects an update. Recommendation: `casbin_enforcer.adapter.load_policy` @@ -164,7 +164,10 @@ def should_reload(self) -> bool: try: if self.parent_conn.poll(): message = int(self.parent_conn.recv()) - return message == _ChannelSubscriptionMessage.RECEIVED_UPDATE + received_update = message == _ChannelSubscriptionMessage.RECEIVED_UPDATE + if received_update and self.update_callback is not None: + self.update_callback() + return received_update except EOFError: self.logger.warning( "Child casbin-watcher subscribe process has stopped, " diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 4b6d87b..dadc251 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -1,5 +1,6 @@ import sys import unittest +from unittest.mock import MagicMock from multiprocessing.connection import Pipe from time import sleep import logging @@ -80,6 +81,26 @@ def test_no_update_mutiple_pg_watcher(self): sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) for watcher in other_watchers: self.assertFalse(watcher.should_reload()) + self.assertFalse(main_watcher.should_reload()) + + def test_update_handler_called(self): + channel_name = "test_update_handler_called" + main_watcher = get_watcher(channel_name) + handler = MagicMock() + main_watcher.set_update_callback(handler) + main_watcher.update() + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + self.assertTrue(main_watcher.should_reload()) + self.assertTrue(handler.call_count == 1) + + def test_update_handler_not_called(self): + channel_name = "test_update_handler_not_called" + main_watcher = get_watcher(channel_name) + handler = MagicMock() + main_watcher.set_update_callback(handler) + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * 2) + self.assertFalse(main_watcher.should_reload()) + self.assertTrue(handler.call_count == 0) if __name__ == "__main__": From 5ffe939647dd551fdedd648372f490214a27f68a Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 22:37:43 +0200 Subject: [PATCH 26/33] refactor: sorted imports --- postgresql_watcher/watcher.py | 17 ++++++++--------- tests/test_postgresql_watcher.py | 22 +++++++++++----------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 986b5c3..bc33db0 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -1,10 +1,11 @@ -from typing import Optional, Callable -from psycopg2 import connect, extensions +from logging import Logger, getLogger from multiprocessing import Process, Pipe from multiprocessing.connection import Connection -import time -from select import select -from logging import Logger, getLogger +from time import sleep, time +from typing import Optional, Callable + +from psycopg2 import connect, extensions + from .casbin_channel_subscription import ( casbin_channel_subscription, _ChannelSubscriptionMessage, @@ -115,7 +116,7 @@ def start(self): message = int(self.parent_conn.recv()) if message == _ChannelSubscriptionMessage.IS_READY: break - time.sleep(1 / 1000) # wait for 1 ms + sleep(1 / 1000) # wait for 1 ms def _cleanup_connections_and_processes(self) -> None: # Clean up potentially existing Connections and Processes @@ -155,9 +156,7 @@ def update(self) -> None: # Can only receive notifications when not in transaction, set this for easier usage conn.set_isolation_level(extensions.ISOLATION_LEVEL_AUTOCOMMIT) curs = conn.cursor() - curs.execute( - f"NOTIFY {self.channel_name},'casbin policy update at {time.time()}'" - ) + curs.execute(f"NOTIFY {self.channel_name},'casbin policy update at {time()}'") conn.close() def should_reload(self) -> bool: diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index dadc251..10ef81e 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -1,13 +1,13 @@ -import sys -import unittest -from unittest.mock import MagicMock -from multiprocessing.connection import Pipe +from logging import DEBUG, getLogger, StreamHandler +from multiprocessing import connection, context from time import sleep -import logging +from unittest import TestCase, main +from unittest.mock import MagicMock +import sys from postgresql_watcher import PostgresqlWatcher from postgresql_watcher.casbin_channel_subscription import CASBIN_CHANNEL_SELECT_TIMEOUT -from multiprocessing import connection, context + # Warning!!! , Please setup yourself config HOST = "127.0.0.1" @@ -16,9 +16,9 @@ PASSWORD = "123456" DBNAME = "postgres" -logger = logging.getLogger() -logger.level = logging.DEBUG -stream_handler = logging.StreamHandler(sys.stdout) +logger = getLogger() +logger.level = DEBUG +stream_handler = StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -43,7 +43,7 @@ def get_watcher(channel_name): _winapi = None -class TestConfig(unittest.TestCase): +class TestConfig(TestCase): def test_pg_watcher_init(self): pg_watcher = get_watcher("test_pg_watcher_init") if _winapi: @@ -104,4 +104,4 @@ def test_update_handler_not_called(self): if __name__ == "__main__": - unittest.main() + main() From e6a7bdc31599d1f4c7a863f914d1138de73611fe Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Mon, 8 Jul 2024 23:05:03 +0200 Subject: [PATCH 27/33] docs: updated README --- README.md | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index a78cc00..acd58b2 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ pip install casbin-postgresql-watcher ``` ## Basic Usage Example -### With Flask-authz + ```python from flask_authz import CasbinEnforcer from postgresql_watcher import PostgresqlWatcher @@ -25,15 +25,22 @@ from casbin.persist.adapters import FileAdapter casbin_enforcer = CasbinEnforcer(app, adapter) watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME) -watcher.set_update_callback(casbin_enforcer.e.load_policy) +watcher.set_update_callback(casbin_enforcer.adapter.load_policy) casbin_enforcer.set_watcher(watcher) -``` -## Basic Usage Example With SSL Enabled +# Call should_reload before every call of enforce to make sure +# the policy is update to date +watcher.should_reload() +if casbin_enforcer.enforce("alice", "data1", "read"): + # permit alice to read data1 + pass +else: + # deny the request, show an error + pass +``` -See [PostgresQL documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) for full details of SSL parameters. +alternatively, if you need more control -### With Flask-authz ```python from flask_authz import CasbinEnforcer from postgresql_watcher import PostgresqlWatcher @@ -41,7 +48,28 @@ from flask import Flask from casbin.persist.adapters import FileAdapter casbin_enforcer = CasbinEnforcer(app, adapter) -watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME, sslmode="verify_full", sslcert=SSLCERT, sslrootcert=SSLROOTCERT, sslkey=SSLKEY) -watcher.set_update_callback(casbin_enforcer.e.load_policy) +watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME) casbin_enforcer.set_watcher(watcher) + +# Call should_reload before every call of enforce to make sure +# the policy is update to date +if watcher.should_reload(): + adapter.load_policy() + +if casbin_enforcer.enforce("alice", "data1", "read"): + # permit alice to read data1 + pass +else: + # deny the request, show an error + pass +``` + +## Basic Usage Example With SSL Enabled + +See [PostgresQL documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS) for full details of SSL parameters. + +```python +... +watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME, sslmode="verify_full", sslcert=SSLCERT, sslrootcert=SSLROOTCERT, sslkey=SSLKEY) +... ``` From 09b2e8dc369b02f38cd01fe12550f0cb7339da03 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Tue, 9 Jul 2024 11:40:38 +0200 Subject: [PATCH 28/33] refactor: improved readibility --- postgresql_watcher/casbin_channel_subscription.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/postgresql_watcher/casbin_channel_subscription.py b/postgresql_watcher/casbin_channel_subscription.py index 0a5e19e..f568de8 100644 --- a/postgresql_watcher/casbin_channel_subscription.py +++ b/postgresql_watcher/casbin_channel_subscription.py @@ -52,10 +52,14 @@ def casbin_channel_subscription( while not db_cursor.closed: try: - if not select( - [db_connection], [], [], CASBIN_CHANNEL_SELECT_TIMEOUT - ) == ([], [], []): - logger.debug("Casbin policy update identified..") + select_result = select( + [db_connection], + [], + [], + CASBIN_CHANNEL_SELECT_TIMEOUT, + ) + if select_result != ([], [], []): + logger.debug("Casbin policy update identified") db_connection.poll() while db_connection.notifies: notify = db_connection.notifies.pop(0) From 6f3c031bbe7e59ec4ffc210db5abb2084a48e127 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Tue, 9 Jul 2024 13:02:34 +0200 Subject: [PATCH 29/33] refactor: resolve a potential infinite loop with a custom Exception --- postgresql_watcher/__init__.py | 2 +- postgresql_watcher/watcher.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/postgresql_watcher/__init__.py b/postgresql_watcher/__init__.py index 5f86668..c40d1dc 100644 --- a/postgresql_watcher/__init__.py +++ b/postgresql_watcher/__init__.py @@ -1 +1 @@ -from .watcher import PostgresqlWatcher +from .watcher import PostgresqlWatcher, PostgresqlWatcherChannelSubscriptionTimeoutError diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index bc33db0..95f0aff 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -111,11 +111,15 @@ def start(self): self.subscription_proces.start() # And wait for the Process to be ready to listen for updates # from PostgreSQL + timeout = 20 # seconds + timeout_time = time() + timeout while True: if self.parent_conn.poll(): message = int(self.parent_conn.recv()) if message == _ChannelSubscriptionMessage.IS_READY: break + if time() > timeout_time: + raise PostgresqlWatcherChannelSubscriptionTimeoutError(timeout) sleep(1 / 1000) # wait for 1 ms def _cleanup_connections_and_processes(self) -> None: @@ -175,3 +179,13 @@ def should_reload(self) -> bool: self._create_subscription_process(delay=10) return False + + +class PostgresqlWatcherChannelSubscriptionTimeoutError(RuntimeError): + """ + Raised if the channel subscription could not be established within a given timeout. + """ + + def __init__(self, timeout_in_seconds: float) -> None: + msg = f"The channel subscription could not be established within {timeout_in_seconds:.0f} seconds." + super().__init__(msg) From 241ac5c3dffec4570635c1f684c73897b2c5ccb6 Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Tue, 9 Jul 2024 13:06:54 +0200 Subject: [PATCH 30/33] refactor: make timeout configurable by the user --- postgresql_watcher/watcher.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 95f0aff..bdd4f21 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -105,13 +105,15 @@ def _create_subscription_process( if start_listening: self.start() - def start(self): + def start( + self, + timeout=20, # seconds + ): if not self.subscription_proces.is_alive(): # Start listening to messages self.subscription_proces.start() # And wait for the Process to be ready to listen for updates # from PostgreSQL - timeout = 20 # seconds timeout_time = time() + timeout while True: if self.parent_conn.poll(): From c053715c67f98ac23cfec6c4acf701f270c94e4b Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Tue, 9 Jul 2024 16:26:02 +0200 Subject: [PATCH 31/33] fix: docs --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index acd58b2..33736e5 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ from casbin.persist.adapters import FileAdapter casbin_enforcer = CasbinEnforcer(app, adapter) watcher = PostgresqlWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME) -watcher.set_update_callback(casbin_enforcer.adapter.load_policy) +watcher.set_update_callback(casbin_enforcer.load_policy) casbin_enforcer.set_watcher(watcher) # Call should_reload before every call of enforce to make sure @@ -54,7 +54,7 @@ casbin_enforcer.set_watcher(watcher) # Call should_reload before every call of enforce to make sure # the policy is update to date if watcher.should_reload(): - adapter.load_policy() + casbin_enforcer.load_policy() if casbin_enforcer.enforce("alice", "data1", "read"): # permit alice to read data1 From 63a78c38042916610db3e26e39bf5518787be9cf Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Wed, 10 Jul 2024 10:30:36 +0200 Subject: [PATCH 32/33] fix: ensure type hint compatibility with Python 3.9 --- postgresql_watcher/watcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index bdd4f21..8ff3b55 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -66,9 +66,9 @@ def __init__( if logger is None: logger = getLogger() self.logger = logger - self.parent_conn: Connection | None = None - self.child_conn: Connection | None = None - self.subscription_process: Process | None = None + self.parent_conn: Optional[Connection] = None + self.child_conn: Optional[Connection] = None + self.subscription_process: Optional[Process] = None self._create_subscription_process(start_listening) self.update_callback: Optional[Callable[[None], None]] = None From 89395315dc999cad5e0a74a025034530189b517c Mon Sep 17 00:00:00 2001 From: Thore Bartholomaeus Date: Wed, 10 Jul 2024 11:19:06 +0200 Subject: [PATCH 33/33] feat: make sure multiple calls of update() get resolved by one call of should_reload() thanks to @pradeepranwa1 --- postgresql_watcher/watcher.py | 13 +++++++++---- tests/test_postgresql_watcher.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/postgresql_watcher/watcher.py b/postgresql_watcher/watcher.py index 8ff3b55..818b38a 100644 --- a/postgresql_watcher/watcher.py +++ b/postgresql_watcher/watcher.py @@ -167,12 +167,17 @@ def update(self) -> None: def should_reload(self) -> bool: try: - if self.parent_conn.poll(): + should_reload_flag = False + while self.parent_conn.poll(): message = int(self.parent_conn.recv()) received_update = message == _ChannelSubscriptionMessage.RECEIVED_UPDATE - if received_update and self.update_callback is not None: - self.update_callback() - return received_update + if received_update: + should_reload_flag = True + + if should_reload_flag and self.update_callback is not None: + self.update_callback() + + return should_reload_flag except EOFError: self.logger.warning( "Child casbin-watcher subscribe process has stopped, " diff --git a/tests/test_postgresql_watcher.py b/tests/test_postgresql_watcher.py index 10ef81e..d3f9d70 100644 --- a/tests/test_postgresql_watcher.py +++ b/tests/test_postgresql_watcher.py @@ -93,6 +93,19 @@ def test_update_handler_called(self): self.assertTrue(main_watcher.should_reload()) self.assertTrue(handler.call_count == 1) + def test_update_handler_called_multiple_channel_messages(self): + channel_name = "test_update_handler_called_multiple_channel_messages" + main_watcher = get_watcher(channel_name) + handler = MagicMock() + main_watcher.set_update_callback(handler) + number_of_updates = 5 + for _ in range(number_of_updates): + main_watcher.update() + sleep(CASBIN_CHANNEL_SELECT_TIMEOUT * (number_of_updates + 1)) + while main_watcher.should_reload(): + pass + self.assertTrue(handler.call_count == 1) + def test_update_handler_not_called(self): channel_name = "test_update_handler_not_called" main_watcher = get_watcher(channel_name)