From 81557ffb7c7b2756fc2676662a693cd2d684e20e Mon Sep 17 00:00:00 2001 From: Almas Baktubayev Date: Fri, 3 Nov 2023 20:03:23 +0600 Subject: [PATCH] feat: code refactoring and bug fixing (#29) * fix: typo in docstring * refactor: deletion of an unnecessary variable * refactor: DRY principle and correction E712 * refactor: Adding a check for the presence of created policies and a small code reduction * fix: correcting the variable - rows_created --- casbin_adapter/adapter.py | 8 ++++---- casbin_adapter/enforcer.py | 29 ++++++++++------------------- casbin_adapter/utils.py | 5 ++--- 3 files changed, 16 insertions(+), 26 deletions(-) diff --git a/casbin_adapter/adapter.py b/casbin_adapter/adapter.py index 477bcce..29bdd68 100644 --- a/casbin_adapter/adapter.py +++ b/casbin_adapter/adapter.py @@ -53,8 +53,8 @@ def save_policy(self, model): for ptype, ast in model.model[sec].items(): for rule in ast.policy: lines.append(self._create_policy_line(ptype, rule)) - CasbinRule.objects.using(self.db_alias).bulk_create(lines) - return True + rows_created = CasbinRule.objects.using(self.db_alias).bulk_create(lines) + return len(rows_created) > 0 def add_policy(self, sec, ptype, rule): """adds a policy rule to the storage.""" @@ -67,7 +67,7 @@ def remove_policy(self, sec, ptype, rule): for i, v in enumerate(rule): query_params["v{}".format(i)] = v rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete() - return True if rows_deleted > 0 else False + return rows_deleted > 0 def remove_filtered_policy(self, sec, ptype, field_index, *field_values): """removes policy rules that match the filter from the storage. @@ -81,4 +81,4 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values): for i, v in enumerate(field_values): query_params["v{}".format(i + field_index)] = v rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete() - return True if rows_deleted > 0 else False + return rows_deleted > 0 diff --git a/casbin_adapter/enforcer.py b/casbin_adapter/enforcer.py index 96ac280..70d2eaf 100644 --- a/casbin_adapter/enforcer.py +++ b/casbin_adapter/enforcer.py @@ -21,7 +21,7 @@ def __init__(self, *args, **kwargs): logger.info("Deferring casbin enforcer initialisation until django is ready") def _load(self): - if self._initialized == False: + if self._initialized is False: logger.info("Performing deferred casbin enforcer initialisation") self._initialized = True model = getattr(settings, "CASBIN_MODEL") @@ -63,24 +63,15 @@ def __getattribute__(self, name): def initialize_enforcer(db_alias=None): try: row = None - if db_alias: - with connections[db_alias].cursor() as cursor: - cursor.execute( - """ - SELECT app, name applied FROM django_migrations - WHERE app = 'casbin_adapter' AND name = '0001_initial'; - """ - ) - row = cursor.fetchone() - else: - with connection.cursor() as cursor: - cursor.execute( - """ - SELECT app, name applied FROM django_migrations - WHERE app = 'casbin_adapter' AND name = '0001_initial'; - """ - ) - row = cursor.fetchone() + connect = connections[db_alias] if db_alias else connection + with connect.cursor() as cursor: + cursor.execute( + """ + SELECT app, name applied FROM django_migrations + WHERE app = 'casbin_adapter' AND name = '0001_initial'; + """ + ) + row = cursor.fetchone() if row: enforcer._load() diff --git a/casbin_adapter/utils.py b/casbin_adapter/utils.py index fb6654a..1c87a24 100644 --- a/casbin_adapter/utils.py +++ b/casbin_adapter/utils.py @@ -3,10 +3,9 @@ def import_class(name): """Import class from string - e.g. `package.module.ClassToImport` returns the `ClasToImport` class""" + e.g. `package.module.ClassToImport` returns the `ClassToImport` class""" components = name.split(".") module_name = ".".join(components[:-1]) class_name = components[-1] module = importlib.import_module(module_name) - class_ = getattr(module, class_name) - return class_ + return getattr(module, class_name)