Skip to content

Commit

Permalink
feat: code refactoring and bug fixing (#29)
Browse files Browse the repository at this point in the history
* fix: typo in docstring

* refactor: deletion of an unnecessary variable

* refactor: DRY principle and correction E712

* refactor: Adding a check for the presence of created policies and a small code reduction

* fix: correcting the variable - rows_created
  • Loading branch information
InzGIBA authored Nov 3, 2023
1 parent 395e750 commit 81557ff
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 26 deletions.
8 changes: 4 additions & 4 deletions casbin_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def save_policy(self, model):
for ptype, ast in model.model[sec].items():
for rule in ast.policy:
lines.append(self._create_policy_line(ptype, rule))
CasbinRule.objects.using(self.db_alias).bulk_create(lines)
return True
rows_created = CasbinRule.objects.using(self.db_alias).bulk_create(lines)
return len(rows_created) > 0

def add_policy(self, sec, ptype, rule):
"""adds a policy rule to the storage."""
Expand All @@ -67,7 +67,7 @@ def remove_policy(self, sec, ptype, rule):
for i, v in enumerate(rule):
query_params["v{}".format(i)] = v
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
return True if rows_deleted > 0 else False
return rows_deleted > 0

def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules that match the filter from the storage.
Expand All @@ -81,4 +81,4 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
for i, v in enumerate(field_values):
query_params["v{}".format(i + field_index)] = v
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
return True if rows_deleted > 0 else False
return rows_deleted > 0
29 changes: 10 additions & 19 deletions casbin_adapter/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, *args, **kwargs):
logger.info("Deferring casbin enforcer initialisation until django is ready")

def _load(self):
if self._initialized == False:
if self._initialized is False:
logger.info("Performing deferred casbin enforcer initialisation")
self._initialized = True
model = getattr(settings, "CASBIN_MODEL")
Expand Down Expand Up @@ -63,24 +63,15 @@ def __getattribute__(self, name):
def initialize_enforcer(db_alias=None):
try:
row = None
if db_alias:
with connections[db_alias].cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()
else:
with connection.cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()
connect = connections[db_alias] if db_alias else connection
with connect.cursor() as cursor:
cursor.execute(
"""
SELECT app, name applied FROM django_migrations
WHERE app = 'casbin_adapter' AND name = '0001_initial';
"""
)
row = cursor.fetchone()

if row:
enforcer._load()
Expand Down
5 changes: 2 additions & 3 deletions casbin_adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@

def import_class(name):
"""Import class from string
e.g. `package.module.ClassToImport` returns the `ClasToImport` class"""
e.g. `package.module.ClassToImport` returns the `ClassToImport` class"""
components = name.split(".")
module_name = ".".join(components[:-1])
class_name = components[-1]
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
return class_
return getattr(module, class_name)

0 comments on commit 81557ff

Please sign in to comment.