diff --git a/README.md b/README.md index c1f547f..5c7be62 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,15 @@ func main() { } ``` ## Turn off AutoMigrate -New an adapter will use ``AutoMigrate`` by default for create table, if you want to turn it off, please use API ``NewAdapterWithoutAutoMigrate(db *gorm.DB) (*Adapter, error)``. Find out more at [gorm-adapter#160](https://github.com/casbin/gorm-adapter/pull/160) +New an adapter will use ``AutoMigrate`` by default for create table, if you want to turn it off, please use API ``TurnOffAutoMigrate(db *gorm.DB) *gorm.DB``. See example: +```go +db, err := gorm.Open(mysql.Open("root:@tcp(127.0.0.1:3306)/casbin"), &gorm.Config{}) +db = TurnOffAutoMigrate(db) +// a,_ := NewAdapterByDB(...) +// a,_ := NewAdapterByDBUseTableName(...) +a,_ := NewAdapterByDBWithCustomTable(...) +``` +Find out more details at [gorm-adapter#162](https://github.com/casbin/gorm-adapter/issues/162) ## Customize table columns example You can change the gorm struct tags, but the table structure must stay the same. ```go diff --git a/adapter.go b/adapter.go index 4f7bbb8..d7badc2 100755 --- a/adapter.go +++ b/adapter.go @@ -198,12 +198,9 @@ func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (* a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context}) - disableMigrate := a.db.Statement.Context.Value(disableMigrateKey) - if disableMigrate == nil { - err := a.createTable() - if err != nil { - return nil, err - } + err := a.createTable() + if err != nil { + return nil, err } return a, nil @@ -255,7 +252,7 @@ func NewAdapterByDB(db *gorm.DB) (*Adapter, error) { return NewAdapterByDBUseTableName(db, "", defaultTableName) } -func NewAdapterWithoutAutoMigrate(db *gorm.DB) (*Adapter, error) { +func TurnOffAutoMigrate(db *gorm.DB) *gorm.DB { ctx := db.Statement.Context if ctx == nil { ctx = context.Background() @@ -263,7 +260,7 @@ func NewAdapterWithoutAutoMigrate(db *gorm.DB) (*Adapter, error) { ctx = context.WithValue(ctx, disableMigrateKey, false) - return NewAdapterByDBUseTableName(db.WithContext(ctx), "", defaultTableName) + return db.WithContext(ctx) } func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) { @@ -383,6 +380,11 @@ func (a *Adapter) casbinRuleTable() func(db *gorm.DB) *gorm.DB { } func (a *Adapter) createTable() error { + disableMigrate := a.db.Statement.Context.Value(disableMigrateKey) + if disableMigrate != nil { + return nil + } + t := a.db.Statement.Context.Value(customTableKey{}) if t != nil { @@ -414,6 +416,13 @@ func (a *Adapter) dropTable() error { return a.db.Migrator().DropTable(t) } +func (a *Adapter) truncateTable() error { + if a.db.Config.Name() == sqlite.DriverName { + return a.db.Exec(fmt.Sprintf("delete from %s", a.getFullTableName())).Error + } + return a.db.Exec(fmt.Sprintf("truncate table %s", a.getFullTableName())).Error +} + func loadPolicyLine(line CasbinRule, model model.Model) { var p = []string{line.Ptype, line.V0, line.V1, line.V2, @@ -538,11 +547,7 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { // SavePolicy saves policy to database. func (a *Adapter) SavePolicy(model model.Model) error { - err := a.dropTable() - if err != nil { - return err - } - err = a.createTable() + err := a.truncateTable() if err != nil { return err } diff --git a/adapter_test.go b/adapter_test.go index 955dfb3..c191853 100755 --- a/adapter_test.go +++ b/adapter_test.go @@ -194,21 +194,36 @@ func initAdapterWithGormInstanceByName(t *testing.T, db *gorm.DB, name string) * func initAdapterWithoutAutoMigrate(t *testing.T, db *gorm.DB) *Adapter { var err error - hasTable := db.Migrator().HasTable(defaultTableName) + var customTableName = "without_auto_migrate_custom_table" + hasTable := db.Migrator().HasTable(customTableName) if hasTable { - err = db.Migrator().DropTable(defaultTableName) + err = db.Migrator().DropTable(customTableName) if err != nil { panic(err) } } - a, err := NewAdapterWithoutAutoMigrate(db) + db = TurnOffAutoMigrate(db) + + type CustomCasbinRule struct { + ID uint `gorm:"primaryKey;autoIncrement"` + Ptype string `gorm:"size:50"` + V0 string `gorm:"size:50"` + V1 string `gorm:"size:50"` + V2 string `gorm:"size:50"` + V3 string `gorm:"size:50"` + V4 string `gorm:"size:50"` + V5 string `gorm:"size:50"` + V6 string `gorm:"size:50"` + V7 string `gorm:"size:50"` + } + a, err := NewAdapterByDBWithCustomTable(db, &CustomCasbinRule{}, customTableName) hasTable = a.db.Migrator().HasTable(a.getFullTableName()) if hasTable { t.Fatal("AutoMigration has been disabled but tables are still created in NewAdapterWithoutAutoMigrate method") } - err = a.db.Migrator().CreateTable(&CasbinRule{}) + err = a.db.Migrator().CreateTable(&CustomCasbinRule{}) if err != nil { panic(err) } @@ -391,7 +406,7 @@ func TestAdapterWithoutAutoMigrate(t *testing.T) { a = initAdapterWithoutAutoMigrate(t, db) testFilteredPolicy(t, a) - db, err = gorm.Open(postgres.Open("user=postgres password=postgres host=127.0.0.1 port=5432 sslmode=disable"), &gorm.Config{}) + db, err = gorm.Open(postgres.Open("user=postgres password=postgres host=localhost port=5432 sslmode=disable TimeZone=Asia/Shanghai"), &gorm.Config{}) if err != nil { panic(err) } @@ -414,6 +429,18 @@ func TestAdapterWithoutAutoMigrate(t *testing.T) { a = initAdapterWithoutAutoMigrate(t, db) testFilteredPolicy(t, a) + + db, err = gorm.Open(sqlite.Open("casbin.db"), &gorm.Config{}) + if err != nil { + panic(err) + } + + a = initAdapterWithoutAutoMigrate(t, db) + testAutoSave(t, a) + testSaveLoad(t, a) + + a = initAdapterWithoutAutoMigrate(t, db) + testFilteredPolicy(t, a) } func TestAdapterWithMulDb(t *testing.T) {