diff --git a/adapter.go b/adapter.go index b7dd8a5..69591bb 100644 --- a/adapter.go +++ b/adapter.go @@ -1,305 +1,304 @@ package casbin import ( + "context" "errors" "fmt" + "math" + "strings" + "github.com/casbin/casbin/v2/model" "github.com/casbin/casbin/v2/persist" "github.com/gogf/gf/database/gdb" - "runtime" + "github.com/gogf/gf/frame/g" ) const ( - defaultTableName = "casbin_policy" + defaultGroupName = "casbin_database" + defaultTableName = "casbin_policy" + createPolicyTableSql = ` +CREATE TABLE IF NOT EXISTS %s ( + ptype VARCHAR(10) NOT NULL DEFAULT '' COMMENT '', + v0 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', + v1 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', + v2 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', + v3 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', + v4 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', + v5 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '' +) ENGINE = InnoDB COMMENT = 'policy table'; +` ) var ( - ErrMissingDatabaseDriver = errors.New("missing database driver") - ErrMissingDatabaseSource = errors.New("missing database source") + ErrInvalidDatabaseLink = errors.New("invalid database link") + policyColumns = defaultPolicyColumns{ + PType: "ptype", + V0: "v0", + V1: "v1", + V2: "v2", + V3: "v3", + V4: "v4", + V5: "v5", + } ) -type Adapter struct { - db gdb.DB - TableName string - DatabaseDriver string - DatabaseSource string -} - -type Rule struct { - PType string `json:"ptype"` - V0 string `json:"v0"` - V1 string `json:"v1"` - V2 string `json:"v2"` - V3 string `json:"v3"` - V4 string `json:"v4"` - V5 string `json:"v5"` -} - -// Create a casbin adapter -func NewAdapter(a *Adapter) (*Adapter, error) { - if err := a.init(); err != nil { - return nil, err +type ( + adapter struct { + db gdb.DB + table string } - return a, nil -} - -// Init database arguments -func (a *Adapter) init() error { - if a.DatabaseDriver == "" { - return ErrMissingDatabaseDriver + defaultPolicyColumns struct { + PType string // ptype + V0 string // V0 + V1 string // V1 + V2 string // V2 + V3 string // V3 + V4 string // V4 + V5 string // V5 } - if a.DatabaseSource == "" { - return ErrMissingDatabaseSource + // policy rule entity + policyRule struct { + PType string `orm:"ptype" json:"ptype"` + V0 string `orm:"v0" json:"v0"` + V1 string `orm:"v1" json:"v1"` + V2 string `orm:"v2" json:"v2"` + V3 string `orm:"v3" json:"v3"` + V4 string `orm:"v4" json:"v4"` + V5 string `orm:"v5" json:"v5"` } +) - if a.TableName == "" { - a.TableName = defaultTableName +// Create a casbin adapter +func newAdapter(link, table string, debug bool) (*adapter, error) { + config := strings.SplitN(link, ":", 2) + + if len(config) != 2 { + return nil, ErrInvalidDatabaseLink } - gdb.SetConfigGroup("casbin", gdb.ConfigGroup{ + gdb.SetConfigGroup(defaultGroupName, gdb.ConfigGroup{ gdb.ConfigNode{ - Type: a.DatabaseDriver, - LinkInfo: a.DatabaseSource, - Role: "master", - Weight: 100, + Debug: debug, + Type: config[0], + Link: config[1], }, }) - db, err := gdb.New("casbin") - - if err != nil { - return err + if table == "" { + table = defaultTableName } - a.db = db - - if err = a.db.PingMaster(); err != nil { - return err + a := &adapter{ + db: g.DB(defaultGroupName), + table: table, } - if err = a.createTable(); err != nil { - return err + if err := a.createPolicyTable(); err != nil { + return nil, err } - runtime.SetFinalizer(a, func(a *Adapter) { - a.db = nil - }) + return a, nil +} - return nil +func (a *adapter) model() *gdb.Model { + return a.db.Model(a.table).Safe() } -// Create this policy table -func (a *Adapter) createTable() error { - sql := ` - CREATE TABLE IF NOT EXISTS %s ( - ptype VARCHAR(10) NOT NULL DEFAULT '' COMMENT '', - v0 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', - v1 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', - v2 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', - v3 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', - v4 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '', - v5 VARCHAR(256) NOT NULL DEFAULT '' COMMENT '' - ) ENGINE = InnoDB COMMENT = 'policy table'; - ` - _, err := a.db.Exec(fmt.Sprintf(sql, a.TableName)) - - return err +// create a policy table when it's not exists. +func (a *adapter) createPolicyTable() (err error) { + _, err = a.db.Exec(fmt.Sprintf(createPolicyTableSql, a.table)) + + return } -// Drop the policy table -func (a *Adapter) dropTable() error { - _, err := a.db.Exec(fmt.Sprintf("DROP TABLE %s", a.TableName)) - return err +// drop policy table from the storage. +func (a *adapter) dropPolicyTable() (err error) { + _, err = a.db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", a.table)) + + return } -// Loads all policy rules from the storage. -func (a *Adapter) LoadPolicy(model model.Model) error { - var rules []Rule +// LoadPolicy loads all policy rules from the storage. +func (a *adapter) LoadPolicy(model model.Model) (err error) { + var rules []policyRule - if err := a.db.Model(a.TableName).Scan(&rules); err != nil { - return err + if err = a.model().Scan(&rules); err != nil { + return } for _, rule := range rules { a.loadPolicyRule(rule, model) } - return nil + return } -// Saves all policy rules to the storage. -func (a *Adapter) SavePolicy(model model.Model) error { - var ( - err error - rules = make([]Rule, 0) - ) - - if err = a.dropTable(); err != nil { - return err +// SavePolicy Saves all policy rules to the storage. +func (a *adapter) SavePolicy(model model.Model) (err error) { + if err = a.dropPolicyTable(); err != nil { + return } - if err = a.createTable(); err != nil { - return err + if err = a.createPolicyTable(); err != nil { + return } + policyRules := make([]policyRule, 0) + for ptype, ast := range model["p"] { for _, rule := range ast.Policy { - rules = append(rules, a.buildPolicyRule(ptype, rule)) + policyRules = append(policyRules, a.buildPolicyRule(ptype, rule)) } } for ptype, ast := range model["g"] { for _, rule := range ast.Policy { - rules = append(rules, a.buildPolicyRule(ptype, rule)) + policyRules = append(policyRules, a.buildPolicyRule(ptype, rule)) } } - if count := len(rules); count > 0 { - _, err = a.db.Model(a.TableName).Data(&rules).Insert() - - if err != nil { - return err + if count := len(policyRules); count > 0 { + if _, err = a.model().Insert(policyRules); err != nil { + return } } - return nil + return +} + +// AddPolicy adds a policy rule to the storage. +func (a *adapter) AddPolicy(sec string, ptype string, rule []string) (err error) { + _, err = a.model().Insert(a.buildPolicyRule(ptype, rule)) + + return } -// Adds a policy rule to the storage. -func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { - line := a.buildPolicyRule(ptype, rule) - _, err := a.db.Model(a.TableName).Data(&line).Insert() - return err +// AddPolicies adds policy rules to the storage. +func (a *adapter) AddPolicies(sec string, ptype string, rules [][]string) (err error) { + if len(rules) == 0 { + return + } + + policyRules := make([]policyRule, 0, len(rules)) + + for _, rule := range rules { + policyRules = append(policyRules, a.buildPolicyRule(ptype, rule)) + } + + _, err = a.model().Insert(policyRules) + + return } -// Removes a policy rule from the storage. -func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { +// RemovePolicy removes a policy rule from the storage. +func (a *adapter) RemovePolicy(sec string, ptype string, rule []string) error { return a.deletePolicyRule(a.buildPolicyRule(ptype, rule)) } -// Removes policy rules that match the filter from the storage. -func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { - rule := Rule{} +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + rule := policyRule{PType: ptype} - rule.PType = ptype if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { rule.V0 = fieldValues[0-fieldIndex] } + if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) { rule.V1 = fieldValues[1-fieldIndex] } + if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) { rule.V2 = fieldValues[2-fieldIndex] } + if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) { rule.V3 = fieldValues[3-fieldIndex] } + if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) { rule.V4 = fieldValues[4-fieldIndex] } + if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { rule.V5 = fieldValues[5-fieldIndex] } - err := a.deletePolicyRule(rule) - return err -} -// Adds a policy rule to the storage. -func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error { - var lines []Rule - for _, rule := range rules { - lines = append(lines, a.buildPolicyRule(ptype, rule)) - } - - _, err := a.db.Model(a.TableName).Data(&lines).Insert() - return err + return a.deletePolicyRule(rule) } -// Removes a policy rule from the storage. -func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { - db := a.db.Model(a.TableName) +// RemovePolicies removes policy rules from the storage (implements the persist.BatchAdapter interface). +func (a *adapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) { + db := a.model() for _, rule := range rules { - line := a.buildPolicyRule(ptype, rule) - sql := "" - val := make([]interface{}, 0) - - sql = "(ptype = ?" - val = append(val, ptype) - if line.V0 != "" { - sql += " and v0 = ?" - val = append(val, line.V0) - } - if line.V1 != "" { - sql += " and v1 = ?" - val = append(val, line.V1) - } - if line.V2 != "" { - sql += " and v2 = ?" - val = append(val, line.V2) - } - if line.V3 != "" { - sql += " and v3 = ?" - val = append(val, line.V3) - } - if line.V4 != "" { - sql += " and v4 = ?" - val = append(val, line.V4) - } - if line.V5 != "" { - sql += " and v5 = ?" - val = append(val, line.V5) + where := map[string]interface{}{policyColumns.PType: ptype} + + for i := 0; i <= 5; i++ { + if len(rule) > i { + where[fmt.Sprintf("v%d", i)] = rule[i] + } } - sql += ")" - db.Or(sql, val...) + db = db.WhereOr(where) } - _, err := db.Delete() - return err + _, err = db.Delete() + + return } -// Updates a policy rule from storage. -func (a *Adapter) UpdatePolicy(sec string, ptype string, oldPolicy, newPolicy []string) error { - oldRule := a.buildPolicyRule(ptype, oldPolicy) - newRule := a.buildPolicyRule(ptype, newPolicy) - _, err := a.db.Model(a.TableName).Update(&oldRule, &newRule) - return err +// UpdatePolicy updates a policy rule from storage. +func (a *adapter) UpdatePolicy(sec string, ptype string, oldRule, newRule []string) (err error) { + _, err = a.model().Update(a.buildPolicyRule(ptype, newRule), a.buildPolicyRule(ptype, oldRule)) + + return } -// Updates some policy rules to storage, like db, redis. -func (a *Adapter) UpdatePolicies(sec string, ptype string, oldPolicies, newPolicies [][]string) error { - for i, oldPolicy := range oldPolicies { - oldRule := a.buildPolicyRule(ptype, oldPolicy) - newRule := a.buildPolicyRule(ptype, newPolicies[i]) - _, err := a.db.Model(a.TableName).Update(&oldRule, &newRule) - return err +// UpdatePolicies updates some policy rules to storage, like db, redis. +func (a *adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (err error) { + if len(oldRules) == 0 || len(newRules) == 0 { + return } - return nil + err = a.db.Transaction(context.TODO(), func(ctx context.Context, tx *gdb.TX) error { + for i := 0; i < int(math.Min(float64(len(oldRules)), float64(len(newRules)))); i++ { + if _, err = tx.Model(a.table).Update(a.buildPolicyRule(ptype, newRules[i]), a.buildPolicyRule(ptype, oldRules[i])); err != nil { + return err + } + } + + return nil + }) + + return } -// Load policy rules -func (a *Adapter) loadPolicyRule(rule Rule, model model.Model) { +// 加载策略规则 +func (a *adapter) loadPolicyRule(rule policyRule, model model.Model) { ruleText := rule.PType if rule.V0 != "" { ruleText += ", " + rule.V0 } + if rule.V1 != "" { ruleText += ", " + rule.V1 } + if rule.V2 != "" { ruleText += ", " + rule.V2 } + if rule.V3 != "" { ruleText += ", " + rule.V3 } + if rule.V4 != "" { ruleText += ", " + rule.V4 } + if rule.V5 != "" { ruleText += ", " + rule.V5 } @@ -307,27 +306,30 @@ func (a *Adapter) loadPolicyRule(rule Rule, model model.Model) { persist.LoadPolicyLine(ruleText, model) } -// Build policy rules -func (a *Adapter) buildPolicyRule(ptype string, data []string) Rule { - rule := Rule{} - - rule.PType = ptype +// 构建策略规则 +func (a *adapter) buildPolicyRule(ptype string, data []string) policyRule { + rule := policyRule{PType: ptype} if len(data) > 0 { rule.V0 = data[0] } + if len(data) > 1 { rule.V1 = data[1] } + if len(data) > 2 { rule.V2 = data[2] } + if len(data) > 3 { rule.V3 = data[3] } + if len(data) > 4 { rule.V4 = data[4] } + if len(data) > 5 { rule.V5 = data[5] } @@ -335,30 +337,35 @@ func (a *Adapter) buildPolicyRule(ptype string, data []string) Rule { return rule } -// Delete policy rules -func (a *Adapter) deletePolicyRule(rule Rule) error { - db := a.db.Model(a.TableName) +// deletes a policy rule. +func (a *adapter) deletePolicyRule(rule policyRule) (err error) { + where := map[string]interface{}{policyColumns.PType: rule.PType} - db.Where("ptype = ?", rule.PType) if rule.V0 != "" { - db.Where("v0 = ?", rule.V0) + where[policyColumns.V0] = rule.V0 } + if rule.V1 != "" { - db.Where("v1 = ?", rule.V1) + where[policyColumns.V1] = rule.V1 } + if rule.V2 != "" { - db.Where("v2 = ?", rule.V2) + where[policyColumns.V2] = rule.V2 } + if rule.V3 != "" { - db.Where("v3 = ?", rule.V3) + where[policyColumns.V3] = rule.V3 } + if rule.V4 != "" { - db.Where("v4 = ?", rule.V4) + where[policyColumns.V4] = rule.V4 } + if rule.V5 != "" { - db.Where("v5 = ?", rule.V5) + where[policyColumns.V5] = rule.V5 } - _, err := db.Delete() - return err + _, err = a.model().Delete(where) + + return } diff --git a/enforcer.go b/enforcer.go index ef080ae..ea6aadc 100644 --- a/enforcer.go +++ b/enforcer.go @@ -1,50 +1,39 @@ package casbin import ( - "github.com/casbin/casbin/v2" "time" + + "github.com/casbin/casbin/v2" ) type Enforcer = casbin.Enforcer -type Casbin struct { - Model string // model config file path - Debug bool // debug mode - Enable bool // enable permission - AutoLoad bool // auto load policy - Duration time.Duration // auto load duration - TableName string // policy table name - DatabaseDriver string // database driver,support MySQL, SQLite, PostgreSQL, Oracle, SQL Server - DatabaseSource string // database source url +type Options struct { + Model string // model config file path + Debug bool // debug mode + Enable bool // enable permission + AutoLoad bool // auto load policy + Duration time.Duration // auto load duration + DbTable string // policy table name + DbLink string // database source url, example: mysql:root:12345678@tcp(127.0.0.1:3306)/test } -// Create a casbin enforcer -func NewEnforcer(c *Casbin) (*Enforcer, error) { - var ( - err error - adapter *Adapter - enforcer *Enforcer - ) - - adapter, err = NewAdapter(&Adapter{ - TableName: c.TableName, - DatabaseDriver: c.DatabaseDriver, - DatabaseSource: c.DatabaseSource, - }) - - if err != nil { - return nil, err - } +// NewEnforcer create a casbin enforcer. +func NewEnforcer(opt *Options) (enforcer *Enforcer, err error) { + var adp *adapter - enforcer, err = casbin.NewEnforcer(c.Model, adapter) + if adp, err = newAdapter(opt.DbLink, opt.DbTable, opt.Debug); err != nil { + return + } - if err != nil { - return nil, err + if enforcer, err = casbin.NewEnforcer(opt.Model, adp); err != nil { + return } - enforcer.EnableLog(c.Debug) - enforcer.EnableEnforce(c.Enable) - enforcer.EnableAutoNotifyWatcher(c.AutoLoad) + enforcer.EnableLog(opt.Debug) + enforcer.EnableEnforce(opt.Enable) + enforcer.EnableAutoNotifyWatcher(opt.AutoLoad) + enforcer.EnableAutoSave(true) - return enforcer, nil + return } diff --git a/example/main.go b/example/main.go index df71065..8010f2a 100644 --- a/example/main.go +++ b/example/main.go @@ -2,21 +2,21 @@ package main import ( "fmt" - "github.com/dobyte/gf-casbin" "log" + + "github.com/dobyte/gf-casbin" ) var enforcer *casbin.Enforcer func init() { - e, err := casbin.NewEnforcer(&casbin.Casbin{ - Model: "./example/model.conf", - Debug: false, - Enable: true, - AutoLoad: true, - TableName: "casbin_policy_test", - DatabaseDriver: "mysql", - DatabaseSource: "root:123456@tcp(127.0.0.1:3306)/ftft", + e, err := casbin.NewEnforcer(&casbin.Options{ + Model: "./example/model.conf", + Debug: true, + Enable: true, + AutoLoad: true, + DbTable: "casbin_policy_test", + DbLink: "mysql:root:123456@tcp(127.0.0.1:3306)/topic1", }) if err != nil { @@ -51,48 +51,42 @@ func main() { }) // check role_1 policy - ok = enforcer.HasPolicy("role_1", "node_1") - if ok { + if ok = enforcer.HasPolicy("role_1", "node_1"); ok { fmt.Println("role_1 is allowed access node_1") } else { fmt.Println("role_1 is not allowed access node_1") } // check role_1 policy - ok = enforcer.HasPolicy("role_1", "node_2") - if ok { + if ok = enforcer.HasPolicy("role_1", "node_2"); ok { fmt.Println("role_1 is allowed access node_2") } else { fmt.Println("role_1 is not allowed access node_2") } // check user_1 policy - ok = enforcer.HasGroupingPolicy("user_1", "role_1") - if ok { + if ok = enforcer.HasGroupingPolicy("user_1", "role_1"); ok { fmt.Println("user_1 has role_1") } else { fmt.Println("user_1 has not role_1") } // check user_1 policy - ok = enforcer.HasGroupingPolicy("user_1", "role_2") - if ok { + if ok = enforcer.HasGroupingPolicy("user_1", "role_2"); ok { fmt.Println("user_1 has role_2") } else { fmt.Println("user_1 has not role_2") } // check access permission of user_1 - ok, _ = enforcer.Enforce("user_1", "node_1") - if ok { + if ok, _ = enforcer.Enforce("user_1", "node_1"); ok { fmt.Println("user_1 is allowed access node_1") } else { fmt.Println("user_1 is not allowed access node_1") } // check access permission of user_1 - ok, _ = enforcer.Enforce("user_1", "node_2") - if ok { + if ok, _ = enforcer.Enforce("user_1", "node_2"); ok { fmt.Println("user_1 is allowed access node_2") } else { fmt.Println("user_1 is not allowed access node_2") @@ -142,32 +136,28 @@ func main() { fmt.Println() // check role_1 policy - ok = enforcer.HasPolicy("role_1", "node_1") - if ok { + if ok = enforcer.HasPolicy("role_1", "node_1"); ok { fmt.Println("role_1 is allowed access node_1") } else { fmt.Println("role_1 is not allowed access node_1") } // check role_1 policy - ok = enforcer.HasPolicy("role_1", "node_2") - if ok { + if ok = enforcer.HasPolicy("role_1", "node_2"); ok { fmt.Println("role_1 is allowed access node_2") } else { fmt.Println("role_1 is not allowed access node_2") } // check user_1 policy - ok = enforcer.HasGroupingPolicy("user_1", "role_1") - if ok { + if ok = enforcer.HasGroupingPolicy("user_1", "role_1"); ok { fmt.Println("user_1 has role_1") } else { fmt.Println("user_1 has not role_1") } // check user_1 policy - ok = enforcer.HasGroupingPolicy("user_1", "role_2") - if ok { + if ok = enforcer.HasGroupingPolicy("user_1", "role_2"); ok { fmt.Println("user_1 has role_2") } else { fmt.Println("user_1 has not role_2") diff --git a/go.mod b/go.mod index a6d997f..b944503 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/dobyte/gf-casbin go 1.15 require ( - github.com/casbin/casbin/v2 v2.25.5 - github.com/gogf/gf v1.15.5 + github.com/casbin/casbin/v2 v2.41.0 + github.com/gogf/gf v1.16.6 + github.com/gomodule/redigo v2.0.0+incompatible // indirect + github.com/mattn/go-runewidth v0.0.10 // indirect )