From db81460c3513b412ba5e75e505e9440d0ea2d1a6 Mon Sep 17 00:00:00 2001 From: pixel <303176530@qq.com> Date: Tue, 6 Apr 2021 17:12:51 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86gorm=E9=80=82=E9=85=8D=E5=99=A8?= =?UTF-8?q?=E8=B0=83=E6=95=B4=E8=87=B3=E6=9C=AC=E5=9C=B0=E9=81=BF=E5=85=8D?= =?UTF-8?q?windows=E6=B2=A1=E6=9C=89gcc=E7=9A=84=E6=83=85=E5=86=B5?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E7=B3=BB=E7=BB=9F=E6=97=A0=E6=B3=95=E5=90=AF?= =?UTF-8?q?=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/go.mod | 8 +- server/service/sys_casbin.go | 2 +- server/source/casbin.go | 2 +- server/utils/gormadapter/adapter.go | 655 ++++++++++++++++++++++++++++ 4 files changed, 662 insertions(+), 5 deletions(-) create mode 100644 server/utils/gormadapter/adapter.go diff --git a/server/go.mod b/server/go.mod index 8dd9857d..82f0d98c 100644 --- a/server/go.mod +++ b/server/go.mod @@ -9,7 +9,7 @@ require ( github.com/aliyun/aliyun-oss-go-sdk v2.1.6+incompatible github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect github.com/casbin/casbin/v2 v2.25.6 - github.com/casbin/gorm-adapter/v3 v3.2.5 + //github.com/casbin/gorm-adapter/v3 v3.2.5 github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect github.com/fsnotify/fsnotify v1.4.9 @@ -23,6 +23,7 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/golang/protobuf v1.4.2 // indirect github.com/gookit/color v1.3.1 + github.com/jackc/pgconn v1.8.1 github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 // indirect github.com/jordan-wright/email v0.0.0-20200824153738-3f5bafa1cd84 github.com/json-iterator/go v1.1.10 // indirect @@ -49,11 +50,12 @@ require ( github.com/tencentyun/cos-go-sdk-v5 v0.7.19 github.com/unrolled/secure v1.0.7 go.uber.org/zap v1.10.0 - golang.org/x/net v0.0.0-20201224014010-6772e930b67b // indirect golang.org/x/tools v0.0.0-20200324003944-a576cf524670 // indirect google.golang.org/protobuf v1.24.0 // indirect gopkg.in/ini.v1 v1.55.0 // indirect gopkg.in/yaml.v2 v2.3.0 // indirect gorm.io/driver/mysql v1.0.1 - gorm.io/gorm v1.20.9 + gorm.io/driver/postgres v1.0.8 + gorm.io/driver/sqlserver v1.0.7 + gorm.io/gorm v1.21.4 ) diff --git a/server/service/sys_casbin.go b/server/service/sys_casbin.go index 457be1b7..a10093e8 100644 --- a/server/service/sys_casbin.go +++ b/server/service/sys_casbin.go @@ -7,9 +7,9 @@ import ( "gin-vue-admin/model/request" "strings" + "gin-vue-admin/utils/gormadapter" "github.com/casbin/casbin/v2" "github.com/casbin/casbin/v2/util" - gormadapter "github.com/casbin/gorm-adapter/v3" _ "github.com/go-sql-driver/mysql" ) diff --git a/server/source/casbin.go b/server/source/casbin.go index 6c56f9a0..54d22134 100644 --- a/server/source/casbin.go +++ b/server/source/casbin.go @@ -3,7 +3,7 @@ package source import ( "gin-vue-admin/global" - gormadapter "github.com/casbin/gorm-adapter/v3" + "gin-vue-admin/utils/gormadapter" "github.com/gookit/color" "gorm.io/gorm" ) diff --git a/server/utils/gormadapter/adapter.go b/server/utils/gormadapter/adapter.go new file mode 100644 index 00000000..daa96f92 --- /dev/null +++ b/server/utils/gormadapter/adapter.go @@ -0,0 +1,655 @@ +// Copyright 2017 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// Because SQLite3 needs CGO and windows environment does not actively have GCC +// It is necessary to comment out the SQLite3 code. +// The modification content is to comment the package about SQLite3 + +package gormadapter + +import ( + "context" + "errors" + "fmt" + "runtime" + "strings" + + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" + "github.com/jackc/pgconn" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + //"gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" +) + +const ( + defaultDatabaseName = "casbin" + defaultTableName = "casbin_rule" +) + +type customTableKey struct{} + +type CasbinRule struct { + ID uint `gorm:"primaryKey;autoIncrement"` + Ptype string `gorm:"size:100"` + V0 string `gorm:"size:100"` + V1 string `gorm:"size:100"` + V2 string `gorm:"size:100"` + V3 string `gorm:"size:100"` + V4 string `gorm:"size:100"` + V5 string `gorm:"size:100"` +} + +func (CasbinRule) TableName() string { + return "casbin_rule" +} + +type Filter struct { + PType []string + V0 []string + V1 []string + V2 []string + V3 []string + V4 []string + V5 []string +} + +// Adapter represents the Gorm adapter for policy storage. +type Adapter struct { + driverName string + dataSourceName string + databaseName string + tablePrefix string + tableName string + dbSpecified bool + db *gorm.DB + isFiltered bool +} + +// finalizer is the destructor for Adapter. +func finalizer(a *Adapter) { + sqlDB, err := a.db.DB() + if err != nil { + panic(err) + } + err = sqlDB.Close() + if err != nil { + panic(err) + } +} + +// NewAdapter is the constructor for Adapter. +// Params : databaseName,tableName,dbSpecified +// databaseName,{tableName/dbSpecified} +// {database/dbSpecified} +// databaseName and tableName are user defined. +// Their default value are "casbin" and "casbin_rule" +// +// dbSpecified is an optional bool parameter. The default value is false. +// It's up to whether you have specified an existing DB in dataSourceName. +// If dbSpecified == true, you need to make sure the DB in dataSourceName exists. +// If dbSpecified == false, the adapter will automatically create a DB named databaseName. +func NewAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) { + a := &Adapter{} + a.driverName = driverName + a.dataSourceName = dataSourceName + + a.tableName = defaultTableName + a.databaseName = defaultDatabaseName + a.dbSpecified = false + + if len(params) == 1 { + switch p1 := params[0].(type) { + case bool: + a.dbSpecified = p1 + case string: + a.databaseName = p1 + default: + return nil, errors.New("wrong format") + } + } else if len(params) == 2 { + switch p2 := params[1].(type) { + case bool: + a.dbSpecified = p2 + p1, ok := params[0].(string) + if !ok { + return nil, errors.New("wrong format") + } + a.databaseName = p1 + case string: + p1, ok := params[0].(string) + if !ok { + return nil, errors.New("wrong format") + } + a.databaseName = p1 + a.tableName = p2 + default: + return nil, errors.New("wrong format") + } + } else if len(params) == 3 { + if p3, ok := params[2].(bool); ok { + a.dbSpecified = p3 + a.databaseName = params[0].(string) + a.tableName = params[1].(string) + } else { + return nil, errors.New("wrong format") + } + } else if len(params) != 0 { + return nil, errors.New("too many parameters") + } + + // Open the DB, create it if not existed. + err := a.open() + if err != nil { + return nil, err + } + + // Call the destructor when the object is released. + runtime.SetFinalizer(a, finalizer) + + return a, nil +} + +// NewAdapterByDBUseTableName creates gorm-adapter by an existing Gorm instance and the specified table prefix and table name +// Example: gormadapter.NewAdapterByDBUseTableName(&db, "cms", "casbin") Automatically generate table name like this "cms_casbin" +func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*Adapter, error) { + if len(tableName) == 0 { + tableName = defaultTableName + } + + a := &Adapter{ + tablePrefix: prefix, + tableName: tableName, + } + + a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context}) + err := a.createTable() + if err != nil { + return nil, err + } + + return a, nil +} + +// NewFilteredAdapter is the constructor for FilteredAdapter. +// Casbin will not automatically call LoadPolicy() for a filtered adapter. +func NewFilteredAdapter(driverName string, dataSourceName string, params ...interface{}) (*Adapter, error) { + adapter, err := NewAdapter(driverName, dataSourceName, params...) + if err != nil { + return nil, err + } + adapter.isFiltered = true + return adapter, err +} + +// NewAdapterByDB creates gorm-adapter by an existing Gorm instance +func NewAdapterByDB(db *gorm.DB) (*Adapter, error) { + return NewAdapterByDBUseTableName(db, "", defaultTableName) +} + +func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}) (*Adapter, error) { + ctx := db.Statement.Context + if ctx == nil { + ctx = context.Background() + } + + ctx = context.WithValue(ctx, customTableKey{}, t) + + return NewAdapterByDBUseTableName(db.WithContext(ctx), "", defaultTableName) +} + +func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) { + var err error + var db *gorm.DB + if driverName == "postgres" { + db, err = gorm.Open(postgres.Open(dataSourceName), &gorm.Config{}) + } else if driverName == "mysql" { + db, err = gorm.Open(mysql.Open(dataSourceName), &gorm.Config{}) + } else if driverName == "sqlserver" { + db, err = gorm.Open(sqlserver.Open(dataSourceName), &gorm.Config{}) + } else { + return nil, errors.New("database dialect is not supported") + } + + // If you need SQLite, fill in the code above + /* else if driverName == "sqlite3" { + db, err = gorm.Open(sqlite.Open(dataSourceName), &gorm.Config{}) + } */ + + if err != nil { + return nil, err + } + return db, err +} + +func (a *Adapter) createDatabase() error { + var err error + db, err := openDBConnection(a.driverName, a.dataSourceName) + if err != nil { + return err + } + if a.driverName == "postgres" { + if err = db.Exec("CREATE DATABASE " + a.databaseName).Error; err != nil { + // 42P04 is duplicate_database + if err.(*pgconn.PgError).Code == "42P04" { + return nil + } + } + } + + // If you need SQLite, fill in the code above + /* else if a.driverName != "sqlite3" { + err = db.Exec("CREATE DATABASE IF NOT EXISTS " + a.databaseName).Error + } */ + if err != nil { + return err + } + return nil +} + +func (a *Adapter) open() error { + var err error + var db *gorm.DB + + if a.dbSpecified { + db, err = openDBConnection(a.driverName, a.dataSourceName) + if err != nil { + return err + } + } else { + if err = a.createDatabase(); err != nil { + return err + } + if a.driverName == "postgres" { + db, err = openDBConnection(a.driverName, a.dataSourceName+" dbname="+a.databaseName) + } else { + db, err = openDBConnection(a.driverName, a.dataSourceName+a.databaseName) + } + + // If you need SQLite, fill in the code above + /* else if a.driverName == "sqlite3" { + db, err = openDBConnection(a.driverName, a.dataSourceName) + } */ + if err != nil { + return err + } + } + + a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{}) + return a.createTable() +} + +func (a *Adapter) close() error { + a.db = nil + return nil +} + +// getTableInstance return the dynamic table name +func (a *Adapter) getTableInstance() *CasbinRule { + return &CasbinRule{} +} + +func (a *Adapter) getFullTableName() string { + if a.tablePrefix != "" { + return a.tablePrefix + "_" + a.tableName + } + return a.tableName +} + +func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + tableName := a.getFullTableName() + return db.Table(tableName) + } +} + +func (a *Adapter) createTable() error { + t := a.db.Statement.Context.Value(customTableKey{}) + + if t != nil { + return a.db.AutoMigrate(t) + } + + t = a.getTableInstance() + if err := a.db.AutoMigrate(t); err != nil { + return err + } + + tableName := a.getFullTableName() + index := "idx_" + tableName + hasIndex := a.db.Migrator().HasIndex(t, index) + if !hasIndex { + if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5)", index, tableName)).Error; err != nil { + return err + } + } + return nil +} + +func (a *Adapter) dropTable() error { + t := a.db.Statement.Context.Value(customTableKey{}) + if t == nil { + return a.db.Migrator().DropTable(a.getTableInstance()) + } + + return a.db.Migrator().DropTable(t) +} + +func loadPolicyLine(line CasbinRule, model model.Model) { + var p = []string{line.Ptype, + line.V0, line.V1, line.V2, line.V3, line.V4, line.V5} + + var lineText string + if line.V5 != "" { + lineText = strings.Join(p, ", ") + } else if line.V4 != "" { + lineText = strings.Join(p[:6], ", ") + } else if line.V3 != "" { + lineText = strings.Join(p[:5], ", ") + } else if line.V2 != "" { + lineText = strings.Join(p[:4], ", ") + } else if line.V1 != "" { + lineText = strings.Join(p[:3], ", ") + } else if line.V0 != "" { + lineText = strings.Join(p[:2], ", ") + } + + persist.LoadPolicyLine(lineText, model) +} + +// LoadPolicy loads policy from database. +func (a *Adapter) LoadPolicy(model model.Model) error { + var lines []CasbinRule + if err := a.db.Order("ID").Find(&lines).Error; err != nil { + return err + } + + for _, line := range lines { + loadPolicyLine(line, model) + } + + return nil +} + +// LoadFilteredPolicy loads only policy rules that match the filter. +func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error { + var lines []CasbinRule + + filterValue, ok := filter.(Filter) + if !ok { + return errors.New("invalid filter type") + } + + if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Order("ID").Find(&lines).Error; err != nil { + return err + } + + for _, line := range lines { + loadPolicyLine(line, model) + } + a.isFiltered = true + + return nil +} + +// IsFiltered returns true if the loaded policy has been filtered. +func (a *Adapter) IsFiltered() bool { + return a.isFiltered +} + +// filterQuery builds the gorm query to match the rule filter to use within a scope. +func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + if len(filter.PType) > 0 { + db = db.Where("ptype in (?)", filter.PType) + } + if len(filter.V0) > 0 { + db = db.Where("v0 in (?)", filter.V0) + } + if len(filter.V1) > 0 { + db = db.Where("v1 in (?)", filter.V1) + } + if len(filter.V2) > 0 { + db = db.Where("v2 in (?)", filter.V2) + } + if len(filter.V3) > 0 { + db = db.Where("v3 in (?)", filter.V3) + } + if len(filter.V4) > 0 { + db = db.Where("v4 in (?)", filter.V4) + } + if len(filter.V5) > 0 { + db = db.Where("v5 in (?)", filter.V5) + } + return db + } +} + +func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { + line := a.getTableInstance() + + line.Ptype = ptype + if len(rule) > 0 { + line.V0 = rule[0] + } + if len(rule) > 1 { + line.V1 = rule[1] + } + if len(rule) > 2 { + line.V2 = rule[2] + } + if len(rule) > 3 { + line.V3 = rule[3] + } + if len(rule) > 4 { + line.V4 = rule[4] + } + if len(rule) > 5 { + line.V5 = rule[5] + } + + return *line +} + +// SavePolicy saves policy to database. +func (a *Adapter) SavePolicy(model model.Model) error { + err := a.dropTable() + if err != nil { + return err + } + err = a.createTable() + if err != nil { + return err + } + + for ptype, ast := range model["p"] { + for _, rule := range ast.Policy { + line := a.savePolicyLine(ptype, rule) + err := a.db.Create(&line).Error + if err != nil { + return err + } + } + } + + for ptype, ast := range model["g"] { + for _, rule := range ast.Policy { + line := a.savePolicyLine(ptype, rule) + err := a.db.Create(&line).Error + if err != nil { + return err + } + } + } + + return nil +} + +// AddPolicy adds a policy rule to the storage. +func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { + line := a.savePolicyLine(ptype, rule) + err := a.db.Create(&line).Error + return err +} + +// RemovePolicy removes a policy rule from the storage. +func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error { + line := a.savePolicyLine(ptype, rule) + err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete + return err +} + +// AddPolicies adds multiple policy rules to the storage. +func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error { + return a.db.Transaction(func(tx *gorm.DB) error { + for _, rule := range rules { + line := a.savePolicyLine(ptype, rule) + if err := tx.Create(&line).Error; err != nil { + return err + } + } + return nil + }) +} + +// RemovePolicies removes multiple policy rules from the storage. +func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error { + return a.db.Transaction(func(tx *gorm.DB) error { + for _, rule := range rules { + line := a.savePolicyLine(ptype, rule) + if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete + return err + } + } + return nil + }) +} + +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + line := a.getTableInstance() + + line.Ptype = ptype + if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { + line.V0 = fieldValues[0-fieldIndex] + } + if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) { + line.V1 = fieldValues[1-fieldIndex] + } + if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) { + line.V2 = fieldValues[2-fieldIndex] + } + if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) { + line.V3 = fieldValues[3-fieldIndex] + } + if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) { + line.V4 = fieldValues[4-fieldIndex] + } + if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { + line.V5 = fieldValues[5-fieldIndex] + } + err := a.rawDelete(a.db, *line) + return err +} + +func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error { + queryArgs := []interface{}{line.Ptype} + + queryStr := "ptype = ?" + if line.V0 != "" { + queryStr += " and v0 = ?" + queryArgs = append(queryArgs, line.V0) + } + if line.V1 != "" { + queryStr += " and v1 = ?" + queryArgs = append(queryArgs, line.V1) + } + if line.V2 != "" { + queryStr += " and v2 = ?" + queryArgs = append(queryArgs, line.V2) + } + if line.V3 != "" { + queryStr += " and v3 = ?" + queryArgs = append(queryArgs, line.V3) + } + if line.V4 != "" { + queryStr += " and v4 = ?" + queryArgs = append(queryArgs, line.V4) + } + if line.V5 != "" { + queryStr += " and v5 = ?" + queryArgs = append(queryArgs, line.V5) + } + args := append([]interface{}{queryStr}, queryArgs...) + err := db.Delete(a.getTableInstance(), args...).Error + return err +} + +func appendWhere(line CasbinRule) (string, []interface{}) { + queryArgs := []interface{}{line.Ptype} + + queryStr := "ptype = ?" + if line.V0 != "" { + queryStr += " and v0 = ?" + queryArgs = append(queryArgs, line.V0) + } + if line.V1 != "" { + queryStr += " and v1 = ?" + queryArgs = append(queryArgs, line.V1) + } + if line.V2 != "" { + queryStr += " and v2 = ?" + queryArgs = append(queryArgs, line.V2) + } + if line.V3 != "" { + queryStr += " and v3 = ?" + queryArgs = append(queryArgs, line.V3) + } + if line.V4 != "" { + queryStr += " and v4 = ?" + queryArgs = append(queryArgs, line.V4) + } + if line.V5 != "" { + queryStr += " and v5 = ?" + queryArgs = append(queryArgs, line.V5) + } + return queryStr, queryArgs +} + +// UpdatePolicy updates a new policy rule to DB. +func (a *Adapter) UpdatePolicy(sec string, ptype string, oldRule, newPolicy []string) error { + oldLine := a.savePolicyLine(ptype, oldRule) + newLine := a.savePolicyLine(ptype, newPolicy) + return a.db.Model(&oldLine).Where(&oldLine).Updates(newLine).Error +} + +func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error { + oldPolicies := make([]CasbinRule, 0, len(oldRules)) + newPolicies := make([]CasbinRule, 0, len(oldRules)) + for _, oldRule := range oldRules { + oldPolicies = append(oldPolicies, a.savePolicyLine(ptype, oldRule)) + } + for _, newRule := range newRules { + newPolicies = append(newPolicies, a.savePolicyLine(ptype, newRule)) + } + tx := a.db.Begin() + for i := range oldPolicies { + if err := tx.Model(&oldPolicies[i]).Where(&oldPolicies[i]).Updates(newPolicies[i]).Error; err != nil { + tx.Rollback() + return err + } + } + return tx.Commit().Error +}