Skip to content

Commit

Permalink
Merge pull request #24 from yyellowsun/master
Browse files Browse the repository at this point in the history
fix: Filtered Policy Support
  • Loading branch information
leeqvip authored Jan 12, 2021
2 parents 4b3d9d9 + d1ce630 commit da97331
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 18 deletions.
81 changes: 65 additions & 16 deletions casbin_sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from casbin import persist
from sqlalchemy import Column, Integer, String
from sqlalchemy import create_engine
from sqlalchemy import create_engine, and_, or_
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class CasbinRule(Base):
__tablename__ = 'casbin_rule'

Expand All @@ -31,38 +29,89 @@ def __repr__(self):
return '<CasbinRule {}: "{}">'.format(self.id, str(self))


class Filter:
ptype = []
v0 = []
v1 = []
v2 = []
v3 = []
v4 = []
v5 = []

class Adapter(persist.Adapter):
"""the interface for Casbin adapters."""

def __init__(self, engine):
def __init__(self, engine, db_class=None, filtered=False):
if isinstance(engine, str):
self._engine = create_engine(engine)
else:
self._engine = engine

if db_class is None:
db_class = CasbinRule
self._db_class = db_class
session = sessionmaker(bind=self._engine)
self._session = session()

Base.metadata.create_all(self._engine)
self._filtered = filtered

def load_policy(self, model):
"""loads all policy rules from the storage."""
lines = self._session.query(CasbinRule).all()
lines = self._session.query(self._db_class).all()
for line in lines:
persist.load_policy_line(str(line), model)
self._commit()

def _save_policy_line(self, ptype, rule):
line = CasbinRule(ptype=ptype)
def is_filtered(self):
return self._filtered

def load_filtered_policy(self, model, filter) -> None:
"""loads all policy rules from the storage."""
query = self._session.query(self._db_class)
filters = self.filter_query(query,filter)
filters = filters.all()

for line in filters:
persist.load_policy_line(str(line), model)
self._filtered = True

def filter_query(self,querydb,filter):
ret = []
if len(filter.ptype) >0:
ret = querydb.filter(CasbinRule.ptype.in_(filter.ptype)).order_by(CasbinRule.id)
return ret
if len(filter.v0) >0:
ret = querydb.filter(CasbinRule.v0.in_(filter.v0)).order_by(CasbinRule.id)
return ret
if len(filter.v1) >0:
ret = querydb.filter(CasbinRule.v1.in_(filter.v1)).order_by(CasbinRule.id)
return ret
if len(filter.v2) >0:
ret = querydb.filter(CasbinRule.v2.in_(filter.v2)).order_by(CasbinRule.id)
return ret
if len(filter.v3) >0:
ret = querydb.filter(CasbinRule.v3.in_(filter.v3)).order_by(CasbinRule.id)
return ret
if len(filter.v4) >0:
ret = querydb.filter(CasbinRule.v4.in_(filter.v4)).order_by(CasbinRule.id)
return ret
if len(filter.v5) >0:
ret = querydb.filter(CasbinRule.v5.in_(filter.v5)).order_by(CasbinRule.id)
return ret

def _save_policy_line(self,ptype,rule):
line = self._db_class(ptype=ptype)
for i, v in enumerate(rule):
setattr(line, 'v{}'.format(i), v)
setattr(line, "v{}".format(i), v)
self._session.add(line)

def _commit(self):
self._session.commit()

def save_policy(self, model):
"""saves all policy rules to the storage."""
query = self._session.query(CasbinRule)
query = self._session.query(self._db_class)
query.delete()
for sec in ["p", "g"]:
if sec not in model.model.keys():
Expand All @@ -80,10 +129,10 @@ def add_policy(self, sec, ptype, rule):

def remove_policy(self, sec, ptype, rule):
"""removes a policy rule from the storage."""
query = self._session.query(CasbinRule)
query = query.filter(CasbinRule.ptype == ptype)
query = self._session.query(self._db_class)
query = query.filter(self._db_class.ptype == ptype)
for i, v in enumerate(rule):
query = query.filter(getattr(CasbinRule, 'v{}'.format(i)) == v)
query = query.filter(getattr(self._db_class, "v{}".format(i)) == v)
r = query.delete()
self._commit()

Expand All @@ -93,18 +142,18 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
"""removes policy rules that match the filter from the storage.
This is part of the Auto-Save feature.
"""
query = self._session.query(CasbinRule)
query = query.filter(CasbinRule.ptype == ptype)
query = self._session.query(self._db_class)
query = query.filter(self._db_class.ptype == ptype)
if not (0 <= field_index <= 5):
return False
if not (1 <= field_index + len(field_values) <= 6):
return False
for i, v in enumerate(field_values):
query = query.filter(getattr(CasbinRule, 'v{}'.format(field_index + i)) == v)
query = query.filter(getattr(self._db_class, "v{}".format(field_index + i)) == v)
r = query.delete()
self._commit()

return True if r > 0 else False

def __del__(self):
self._session.close()
self._session.close()
136 changes: 134 additions & 2 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from casbin_sqlalchemy_adapter import Adapter
from casbin_sqlalchemy_adapter import Base
from casbin_sqlalchemy_adapter import CasbinRule
from casbin_sqlalchemy_adapter.adapter import Filter
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from unittest import TestCase
import casbin
import os


def get_fixture(path):
dir_path = os.path.split(os.path.realpath(__file__))[0] + "/"
return os.path.abspath(dir_path + path)
Expand All @@ -34,13 +34,15 @@ def get_enforcer():


class TestConfig(TestCase):

def test_enforcer_basic(self):
e = get_enforcer()

self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'read'))
self.assertTrue(e.enforce('alice', 'data2', 'write'))

Expand Down Expand Up @@ -111,6 +113,14 @@ def test_remove_filtered_policy(self):
def test_str(self):
rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')
self.assertEqual(str(rule), 'p, alice, data1, read')
rule = CasbinRule(ptype='p', v0='bob', v1='data2', v2='write')
self.assertEqual(str(rule), 'p, bob, data2, write')
rule = CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='read')
self.assertEqual(str(rule), 'p, data2_admin, data2, read')
rule = CasbinRule(ptype='p', v0='data2_admin', v1='data2', v2='write')
self.assertEqual(str(rule), 'p, data2_admin, data2, write')
rule = CasbinRule(ptype='g', v0='alice', v1 = 'data2_admin')
self.assertEqual(str(rule), 'g, alice, data2_admin')

def test_repr(self):
rule = CasbinRule(ptype='p', v0='alice', v1='data1', v2='read')
Expand All @@ -125,3 +135,125 @@ def test_repr(self):
s.commit()
self.assertRegex(repr(rule), r'<CasbinRule \d+: "p, alice, data1, read">')
s.close()

def test_filtered_policy(self):
e= get_enforcer()
filter = Filter()

filter.ptype = ['p']
e.load_filtered_policy(filter)
self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))

filter.ptype = []
filter.v0 = ['alice']
e.load_filtered_policy(filter)
self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertFalse(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
self.assertFalse(e.enforce('data2_admin', 'data2','write'))

filter.v0 = ['bob']
e.load_filtered_policy(filter)
self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
self.assertFalse(e.enforce('data2_admin', 'data2','write'))

filter.v0 = ['data2_admin']
e.load_filtered_policy(filter)
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertFalse(e.enforce('bob', 'data2', 'write'))

filter.v0 = ['alice','bob']
e.load_filtered_policy(filter)
self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
self.assertFalse(e.enforce('data2_admin', 'data2','write'))

filter.v0 = []
filter.v1 = ['data1']
e.load_filtered_policy(filter)
self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertFalse(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
self.assertFalse(e.enforce('data2_admin', 'data2','write'))

filter.v1 = ['data2']
e.load_filtered_policy(filter)
self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
self.assertTrue(e.enforce('data2_admin', 'data2','write'))

filter.v1 = []
filter.v2 = ['read']
e.load_filtered_policy(filter)
self.assertTrue(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertFalse(e.enforce('bob', 'data2', 'write'))
self.assertTrue(e.enforce('data2_admin', 'data2','read'))
self.assertFalse(e.enforce('data2_admin', 'data2','write'))

filter.v2 = ['write']
e.load_filtered_policy(filter)
self.assertFalse(e.enforce('alice', 'data1', 'read'))
self.assertFalse(e.enforce('alice','data1','write'))
self.assertFalse(e.enforce('alice', 'data2', 'read'))
self.assertFalse(e.enforce('alice', 'data2', 'write'))
self.assertFalse(e.enforce('bob', 'data1', 'read'))
self.assertFalse(e.enforce('bob', 'data1', 'write'))
self.assertFalse(e.enforce('bob', 'data2', 'read'))
self.assertTrue(e.enforce('bob', 'data2', 'write'))
self.assertFalse(e.enforce('data2_admin', 'data2','read'))
self.assertTrue(e.enforce('data2_admin', 'data2','write'))

0 comments on commit da97331

Please sign in to comment.