diff --git a/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java b/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java index 9020478..54c6095 100644 --- a/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java +++ b/src/main/java/org/casbin/adapter/JDBCBaseAdapter.java @@ -132,7 +132,7 @@ protected void migrate() throws SQLException { } stmt.executeUpdate(sql); - if (productName.equals("Oracle")) { + if ("Oracle".equals(productName)) { sql = renderActualSql("declare " + "V_NUM number;" + "BEGIN " + @@ -152,16 +152,16 @@ protected void migrate() throws SQLException { "if V_NUM > 0 then " + "null;" + "else " + - "execute immediate 'create trigger casbin_id_autoincrement before "+ - " insert on CASBIN_RULE for each row "+ - " when (new.id is null) "+ - " begin "+ - " select casbin_sequence.nextval into:new.id from dual;"+ + "execute immediate 'create trigger casbin_id_autoincrement before " + + " insert on CASBIN_RULE for each row " + + " when (new.id is null) " + + " begin " + + " select casbin_sequence.nextval into:new.id from dual;" + " end;';" + "end if;" + "END;"); stmt.executeUpdate(sql); - } else if (productName.equals("PostgreSQL")) { + } else if ("PostgreSQL".equals(productName)) { sql = renderActualSql("CREATE TABLE IF NOT EXISTS casbin_rule(id int NOT NULL PRIMARY KEY default nextval('CASBIN_SEQUENCE'::regclass), ptype VARCHAR(100) NOT NULL, v0 VARCHAR(100), v1 VARCHAR(100), v2 VARCHAR(100), v3 VARCHAR(100), v4 VARCHAR(100), v5 VARCHAR(100))"); stmt.executeUpdate(sql); } @@ -206,12 +206,12 @@ public void loadPolicy(Model model) { while (rSet.next()) { CasbinRule line = new CasbinRule(); line.ptype = rSet.getObject(1) == null ? "" : (String) rSet.getObject(1); - line.v0 = rSet.getObject(2) == null ? "" : (String) rSet.getObject(2); - line.v1 = rSet.getObject(3) == null ? "" : (String) rSet.getObject(3); - line.v2 = rSet.getObject(4) == null ? "" : (String) rSet.getObject(4); - line.v3 = rSet.getObject(5) == null ? "" : (String) rSet.getObject(5); - line.v4 = rSet.getObject(6) == null ? "" : (String) rSet.getObject(6); - line.v5 = rSet.getObject(7) == null ? "" : (String) rSet.getObject(7); + line.v0 = rSet.getObject(2) == null ? "" : (String) rSet.getObject(2); + line.v1 = rSet.getObject(3) == null ? "" : (String) rSet.getObject(3); + line.v2 = rSet.getObject(4) == null ? "" : (String) rSet.getObject(4); + line.v3 = rSet.getObject(5) == null ? "" : (String) rSet.getObject(5); + line.v4 = rSet.getObject(6) == null ? "" : (String) rSet.getObject(6); + line.v5 = rSet.getObject(7) == null ? "" : (String) rSet.getObject(7); loadPolicyLine(line, model); } } @@ -310,7 +310,7 @@ public void savePolicy(Model model) { } } - if(count!=0){ + if (count != 0) { ps.executeBatch(); } @@ -325,6 +325,7 @@ public void savePolicy(Model model) { } }); } + /** * addPolicy adds a policy rule to the storage. */ @@ -332,7 +333,7 @@ public void savePolicy(Model model) { public void addPolicy(String sec, String ptype, List rule) { List> rules = new ArrayList<>(); rules.add(rule); - this.addPolicies(sec,ptype,rules); + this.addPolicies(sec, ptype, rules); } @Override @@ -351,7 +352,7 @@ public void addPolicies(String sec, String ptype, List> rules) { conn.setAutoCommit(false); int count = 0; try (PreparedStatement ps = conn.prepareStatement(sql)) { - for(List rule:rules){ + for (List rule : rules) { CasbinRule line = savePolicyLine(ptype, rule); ps.setString(1, line.ptype); @@ -363,12 +364,12 @@ public void addPolicies(String sec, String ptype, List> rules) { ps.setString(7, line.v5); ps.addBatch(); if (++count == batchSize) { - count=0; + count = 0; ps.executeBatch(); ps.clearBatch(); } } - if(count!=0){ + if (count != 0) { ps.executeBatch(); } conn.commit(); @@ -391,7 +392,32 @@ public void removePolicy(String sec, String ptype, List rule) { if (CollectionUtils.isEmpty(rule)) { return; } - removeFilteredPolicy(sec, ptype, 0, rule.toArray(new String[0])); + + Failsafe.with(retryPolicy).run(ctx -> { + if (ctx.isRetry()) { + retry(ctx); + } + String sql = renderActualSql("DELETE FROM casbin_rule WHERE ptype = ?"); + int columnIndex = 0; + for (int i = 0; i < rule.size(); i++) { + sql = String.format("%s%s%s%s", sql, " AND v", columnIndex, " = ?"); + columnIndex++; + } + while (columnIndex <= 5) { + sql = String.format("%s%s%s%s", sql, " AND v", columnIndex, " IS NULL"); + columnIndex++; + } + try (PreparedStatement ps = conn.prepareStatement(sql)) { + ps.setString(1, ptype); + for (int j = 0; j < rule.size(); j++) { + ps.setString(j + 2, rule.get(j)); + } + int rows = ps.executeUpdate(); + if (rows < 1 && removePolicyFailed) { + throw new CasbinAdapterException(String.format("Remove policy error, remove %d rows, expect least 1 rows", rows)); + } + } + }); } @Override @@ -407,7 +433,7 @@ public void removePolicies(String sec, String ptype, List> rules) { conn.setAutoCommit(false); try { for (List rule : rules) { - removeFilteredPolicy(sec, ptype, 0, rule.toArray(new String[0])); + removePolicy(sec, ptype, rule); } conn.commit(); } catch (SQLException e) { diff --git a/src/test/java/org/casbin/adapter/JDBCAdapterTest.java b/src/test/java/org/casbin/adapter/JDBCAdapterTest.java index a813067..011553c 100644 --- a/src/test/java/org/casbin/adapter/JDBCAdapterTest.java +++ b/src/test/java/org/casbin/adapter/JDBCAdapterTest.java @@ -213,4 +213,48 @@ public void testConstructorParams() throws Exception { adapter.close(); adapterViaDataSource.close(); } + + @Test + public void testRemovePolicy() throws Exception { + JDBCAdapter adapter = new MySQLAdapterCreator().create(); + + // Because the DB is empty at first, + // so we need to load the policy from the file adapter (.CSV) first. + Enforcer e = new Enforcer("examples/rbac_model.conf", "examples/rbac_policy.csv"); + + // This is a trick to save the current policy to the DB. + // We can't call e.savePolicy() because the adapter in the enforcer is still the file adapter. + // The current policy means the policy in the jCasbin enforcer (aka in memory). + adapter.savePolicy(e.getModel()); + + e.clearPolicy(); + testGetPolicy(e, asList()); + + e = new Enforcer("examples/rbac_model.conf", adapter); + testGetPolicy(e, asList( + asList("alice", "data1", "read"), + asList("bob", "data2", "write"), + asList("data2_admin", "data2", "read"), + asList("data2_admin", "data2", "write"))); + + adapter.removePolicy("p", "p", Arrays.asList("alice", "data1", "read")); + e = new Enforcer("examples/rbac_model.conf", adapter); + testGetPolicy(e, asList( + asList("bob", "data2", "write"), + asList("data2_admin", "data2", "read"), + asList("data2_admin", "data2", "write"))); + + adapter.removePolicy("p", "p", Arrays.asList("bob", "data2")); + e = new Enforcer("examples/rbac_model.conf", adapter); + testGetPolicy(e, asList( + asList("bob", "data2", "write"), + asList("data2_admin", "data2", "read"), + asList("data2_admin", "data2", "write"))); + + adapter.removePolicy("p", "p", Arrays.asList("bob", "data2", "write")); + e = new Enforcer("examples/rbac_model.conf", adapter); + testGetPolicy(e, asList( + asList("data2_admin", "data2", "read"), + asList("data2_admin", "data2", "write"))); + } }