From 02fb382ec0b67a320fc26cdd460a70468d037779 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 15:01:02 +0800 Subject: [PATCH 01/16] Support scan into int, string data types --- finisher_api.go | 4 +++- scan.go | 2 +- tests/scan_test.go | 10 ++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 6ece0f798..f426839ab 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -384,7 +384,9 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() - tx.Error = tx.Statement.Parse(dest) + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { diff --git a/scan.go b/scan.go index 89d9a07a6..be8782ede 100644 --- a/scan.go +++ b/scan.go @@ -82,7 +82,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64: + case *int, *int64, *uint, *uint64, *float32, *float64, *string: for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/tests/scan_test.go b/tests/scan_test.go index 92e895218..785bb97e4 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -91,4 +91,14 @@ func TestScanRows(t *testing.T) { if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") } + + var ages int + if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + } } From ed1b134e1c6d8d791fc87a7286e9c534fa2840f2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 17:33:31 +0800 Subject: [PATCH 02/16] Fix use uint to for autoCreateTime, autoUpdateTime --- schema/field.go | 8 ++++++++ tests/customize_field_test.go | 22 +++++++++++----------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index 60dc80950..4b8a5a2a7 100644 --- a/schema/field.go +++ b/schema/field.go @@ -624,6 +624,14 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetUint(uint64(data)) case []byte: return field.Set(value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { field.ReflectValueOf(value).SetUint(i) diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index bf3c78faf..7802eb115 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) { FieldAllowSave3 string `gorm:"->:false;<-:create"` FieldReadonly string `gorm:"->"` FieldIgnore string `gorm:"-"` - AutoUnixCreateTime int64 `gorm:"autocreatetime"` - AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` + AutoUnixCreateTime int32 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` - AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` - AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` - AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) @@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid result: %#v", result) } - if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { + if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { t.Fatalf("invalid create/update unix time: %#v", result) } - if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { + if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { t.Fatalf("invalid create/update unix milli time: %#v", result) } - if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { + if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } @@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) { var createWithDefaultTimeResult CustomizeFieldStruct DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) - if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } } From 0ec10d4907762e94ac942903670184a93e7ed456 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Sep 2020 12:37:16 +0800 Subject: [PATCH 03/16] Fix format SQL log, close #3465 --- logger/sql.go | 16 ++++++++++++++-- logger/sql_test.go | 6 ++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 096b94078..69a6b10e0 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -96,9 +96,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } if numericPlaceholder == nil { - for _, v := range vars { - sql = strings.Replace(sql, "?", v, 1) + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) } + + sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { diff --git a/logger/sql_test.go b/logger/sql_test.go index 180570b8a..b78f761c3 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -29,6 +29,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), From 1d5f910b6e1a377f7f7defadb606a3e9c7a09c01 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Sep 2020 15:29:47 +0800 Subject: [PATCH 04/16] Update workflows template --- .github/labels.json | 5 +++++ .github/workflows/invalid_question.yml | 22 ++++++++++++++++++++++ .github/workflows/missing_playground.yml | 2 +- 3 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/invalid_question.yml diff --git a/.github/labels.json b/.github/labels.json index 8b1ce8490..6b9c2034e 100644 --- a/.github/labels.json +++ b/.github/labels.json @@ -10,6 +10,11 @@ "colour": "#EDEDED", "description": "general questions" }, + "invalid_question": { + "name": "type:invalid question", + "colour": "#CF2E1F", + "description": "invalid question (not related to GORM or described in document or not enough information provided)" + }, "with_playground": { "name": "type:with reproduction steps", "colour": "#00ff00", diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml new file mode 100644 index 000000000..5b0bd9813 --- /dev/null +++ b/.github/workflows/invalid_question.yml @@ -0,0 +1,22 @@ +name: "Close invalid questions issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:invalid question" + diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 422cb9f56..ea3207d68 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,7 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 2 From 06d534d6eaa7f8534e51742b9930818511aaf28c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Sep 2020 12:41:45 +0800 Subject: [PATCH 05/16] Cascade delete associations, close #3473 --- callbacks/delete.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 510dfae40..549a94e7b 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -34,8 +34,23 @@ func DeleteBeforeAssociations(db *gorm.DB) { queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{}).Model(modelValue) - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return + withoutConditions := false + + if len(db.Statement.Selects) > 0 { + tx = tx.Select(db.Statement.Selects) + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions { + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } } case schema.Many2Many: var ( From a932175ccf98130aaa3028b75daf047a32b6dca0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Sep 2020 14:28:26 +0800 Subject: [PATCH 06/16] Refactor cascade delete associations --- callbacks/delete.go | 14 +++++++++++++- tests/delete_test.go | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 549a94e7b..85f11f4ba 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -37,7 +38,18 @@ func DeleteBeforeAssociations(db *gorm.DB) { withoutConditions := false if len(db.Statement.Selects) > 0 { - tx = tx.Select(db.Statement.Selects) + var selects []string + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if strings.HasPrefix(s, column+".") { + selects = append(selects, strings.TrimPrefix(s, column+".")) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } } for _, cond := range queryConds { diff --git a/tests/delete_test.go b/tests/delete_test.go index 4945e837a..ecd5ec396 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -136,7 +136,7 @@ func TestDeleteWithAssociations(t *testing.T) { t.Fatalf("failed to create user, got error %v", err) } - if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { + if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { t.Fatalf("failed to delete user, got error %v", err) } From d002c70cf6ac6f35e4a2840606e65d84d33c5391 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Sep 2020 21:52:41 +0800 Subject: [PATCH 07/16] Support named argument for struct --- clause/expression.go | 12 ++++++++++++ clause/expression_test.go | 10 ++++++++++ tests/go.mod | 4 ++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index dde236d35..49924ef73 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -3,6 +3,7 @@ package clause import ( "database/sql" "database/sql/driver" + "go/ast" "reflect" ) @@ -89,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) { for k, v := range value { namedMap[k] = v } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + } + } + } } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 17af737da..53d79c8f1 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -37,6 +37,11 @@ func TestExpr(t *testing.T) { } func TestNamedExpr(t *testing.T) { + type NamedArgument struct { + Name1 string + Name2 string + } + results := []struct { SQL string Result string @@ -66,6 +71,11 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, }} for idx, result := range results { diff --git a/tests/go.mod b/tests/go.mod index 17a3b156a..0db879341 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,9 +8,9 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.2 + gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 - gorm.io/gorm v1.20.0 + gorm.io/gorm v1.20.1 ) replace gorm.io/gorm => ../ From 072f1de83a842a991ea76cecfd14a7e93d5e67c1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:34:44 +0800 Subject: [PATCH 08/16] Add DryRunModeUnsupported Error for Row/Rows --- errors.go | 2 ++ finisher_api.go | 12 ++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/errors.go b/errors.go index 508f69571..087550832 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrInvalidField = errors.New("invalid field") // ErrEmptySlice empty slice found ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") ) diff --git a/finisher_api.go b/finisher_api.go index f426839ab..2c56d7632 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -334,13 +334,21 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Row) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row } func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Rows), tx.Error + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error } // Scan scan value to a struct From c9165fe3cafc9a66e2513caae381e6864fa0a15b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:42:27 +0800 Subject: [PATCH 09/16] Don't panic when using unmatched vars in query, close #3488 --- clause/expression.go | 4 ++-- clause/expression_test.go | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 49924ef73..6a0dde8de 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -31,7 +31,7 @@ func (expr Expr) Build(builder Builder) { ) for _, v := range []byte(expr.SQL) { - if v == '?' { + if v == '?' && len(expr.Vars) > idx { if afterParenthesis { if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) @@ -122,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) { } builder.WriteByte(v) - } else if v == '?' { + } else if v == '?' && len(expr.Vars) > idx { builder.AddVar(builder, expr.Vars[idx]) idx++ } else if inName { diff --git a/clause/expression_test.go b/clause/expression_test.go index 53d79c8f1..19e30e6c0 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -76,6 +76,10 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{}, + Result: "create table ? (? ?, ? ?)", }} for idx, result := range results { From 089939c767f89087366799e47ab24d5b7b36c5e4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:50:11 +0800 Subject: [PATCH 10/16] AutoMigrate should auto create indexes, close #3486 --- migrator/migrator.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4b069c8ad..f390ff9f2 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + + for _, idx := range stmt.Schema.ParseIndexes() { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + return nil }); err != nil { return err From 68920449f92f24c8b17d90986eb155c251ed8fc7 Mon Sep 17 00:00:00 2001 From: caelansar <31852257+caelansar@users.noreply.github.com> Date: Sat, 19 Sep 2020 13:48:34 +0800 Subject: [PATCH 11/16] Fix format sql log (#3492) --- logger/sql.go | 4 ++-- logger/sql_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 69a6b10e0..138a35ec0 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -52,9 +52,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper case driver.Valuer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && (reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) { + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { r, _ := v.Value() - vars[idx] = fmt.Sprintf("%v", r) + convertParams(r, idx) } else { vars[idx] = "NULL" } diff --git a/logger/sql_test.go b/logger/sql_test.go index b78f761c3..71aa841af 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -1,13 +1,39 @@ package logger_test import ( + "database/sql/driver" + "encoding/json" + "fmt" "regexp" + "strings" "testing" "github.com/jinzhu/now" "gorm.io/gorm/logger" ) +type JSON json.RawMessage + +func (j JSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return json.RawMessage(j).MarshalJSON() +} + +type ExampleStruct struct { + Name string + Val string +} + +func (s ExampleStruct) Value() (driver.Value, error) { + return json.Marshal(s) +} + +func format(v []byte, escaper string) string { + return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper +} + func TestExplainSQL(t *testing.T) { type role string type password []byte @@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) { tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") pwd = password([]byte("pass")) + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} ) results := []struct { @@ -53,6 +83,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { From 1a526e6802a9692a1340277551a9117644af21f0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 11:32:38 +0800 Subject: [PATCH 12/16] Fix NamingStrategy with embedded struct, close #3513 --- schema/field.go | 2 +- schema/naming.go | 2 +- schema/naming_test.go | 26 ++++++++++++++++ schema/schema.go | 3 ++ schema/schema_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++ schema/utils.go | 5 ++++ tests/go.mod | 2 +- 7 files changed, 107 insertions(+), 3 deletions(-) diff --git a/schema/field.go b/schema/field.go index 4b8a5a2a7..ce2808a8c 100644 --- a/schema/field.go +++ b/schema/field.go @@ -326,7 +326,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } diff --git a/schema/naming.go b/schema/naming.go index ecdab791a..af753ce59 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,7 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - JoinTableName(table string) string + JoinTableName(joinTable string) string RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string diff --git a/schema/naming_test.go b/schema/naming_test.go index 96b83ced9..a4600cebf 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,6 +1,7 @@ package schema import ( + "strings" "testing" ) @@ -32,3 +33,28 @@ func TestToDBName(t *testing.T) { } } } + +type NewNamingStrategy struct { + NamingStrategy +} + +func (ns NewNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} diff --git a/schema/schema.go b/schema/schema.go index c3d3f6e0b..cffc19a7b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } schema := &Schema{ Name: modelType.Name(), diff --git a/schema/schema_test.go b/schema/schema_test.go index 6ca5b2693..a426cd905 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "strings" "sync" "testing" @@ -227,3 +228,72 @@ func TestEmbeddedStruct(t *testing.T) { }) } } + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} diff --git a/schema/utils.go b/schema/utils.go index 41bd9d60e..55cbdeb4d 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa return columns, queryValues } } + +type embeddedNamer struct { + Table string + Namer +} diff --git a/tests/go.mod b/tests/go.mod index 0db879341..c92fa0cf4 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 + gorm.io/driver/postgres v1.0.1 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.1 From 52287359153b5788d95960c963f74bebcdea88c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 15:00:13 +0800 Subject: [PATCH 13/16] Don't build IN condition if value implemented Valuer interface, #3517 --- statement.go | 16 +++++++++++----- tests/query_test.go | 5 +++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index ee80f8cd3..38d359269 100644 --- a/statement.go +++ b/statement.go @@ -299,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() - } + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: key, Values: values}) + } default: conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } diff --git a/tests/query_test.go b/tests/query_test.go index d3bcbdbe4..9c9ad9f28 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -345,6 +345,11 @@ func TestNot(t *testing.T) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) From c0de3c505176b0fea74c2e09fb9cae7c595b7020 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 19:28:52 +0800 Subject: [PATCH 14/16] Support FullSaveAssociations Mode, close #3487, #3506 --- callbacks/associations.go | 61 +++++++++++++++++++-------------- callbacks/create.go | 5 ++- gorm.go | 7 ++++ logger/logger.go | 7 ++-- tests/update_belongs_to_test.go | 19 ++++++++++ tests/update_has_many_test.go | 41 ++++++++++++++++++++++ tests/update_has_one_test.go | 35 +++++++++++++++++++ tests/update_many2many_test.go | 25 ++++++++++++++ 8 files changed, 171 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 0c677f477..64d79f243 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(f.Interface()).Error) } } } @@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } } @@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) { } } -func onConflictColumns(s *schema.Schema) (columns []clause.Column) { - if s.PrioritizedPrimaryField != nil { - return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } } - for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) + if len(defaultUpdatingColumns) > 0 { + var columns []clause.Column + if s.PrioritizedPrimaryField != nil { + columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } else { + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } } - return + + return clause.OnConflict{DoNothing: true} } diff --git a/callbacks/create.go b/callbacks/create.go index c00a0a73b..8e2454e8f 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -88,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) { } case reflect.Struct: if insertID > 0 { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } else { diff --git a/gorm.go b/gorm.go index 8efd8a73d..e5c4a8a49 100644 --- a/gorm.go +++ b/gorm.go @@ -20,6 +20,8 @@ type Config struct { SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool // Logger Logger logger.Interface // NowFunc the function to be used when creating a new timestamp @@ -64,6 +66,7 @@ type Session struct { WithConditions bool SkipDefaultTransaction bool AllowGlobalUpdate bool + FullSaveAssociations bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.AllowGlobalUpdate = true } + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/logger/logger.go b/logger/logger.go index 831192fc7..e568fb247 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -20,6 +20,7 @@ const ( Magenta = "\033[35m" Cyan = "\033[36m" White = "\033[37m" + BlueBold = "\033[34;1m" MagentaBold = "\033[35;1m" RedBold = "\033[31;1m" YellowBold = "\033[33;1m" @@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface { if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset - warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" } return &logger{ diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 47076e691..736dfc5b6 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) { var user2 User DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 01ea2e3ae..9066cbacc 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) { DB.Preload("Pets").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var user = *GetUser("update-has-many", Config{}) @@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) { var user2 User DB.Preload("Toys").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) }) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 7b29f424b..545685465 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Account").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"} @@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) { var pet2 Pet DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) }) } diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index a46deeb04..d94ef4ab9 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) { var user2 User DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } From ba253982bf558543187f3eb88295b88610cdc83b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 20:08:24 +0800 Subject: [PATCH 15/16] Fix Pluck with Time and Scanner --- scan.go | 13 +++++++++++-- schema/field.go | 6 ++++-- tests/query_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/scan.go b/scan.go index be8782ede..d7cddbe6c 100644 --- a/scan.go +++ b/scan.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "reflect" "strings" + "time" "gorm.io/gorm/schema" ) @@ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64, *string: + case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: for initialized || rows.Next() { initialized = false db.RowsAffected++ @@ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } // pluck values into slice of data - isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { + isPluck = true + } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + isPluck = true + } + } + for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/schema/field.go b/schema/field.go index ce2808a8c..db516c331 100644 --- a/schema/field.go +++ b/schema/field.go @@ -18,6 +18,8 @@ type DataType string type TimeType int64 +var TimeReflectType = reflect.TypeOf(time.Time{}) + const ( UnixSecond TimeType = 1 UnixMillisecond TimeType = 2 @@ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { for i := 0; i < rv.Type().NumField(); i++ { newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { @@ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time diff --git a/tests/query_test.go b/tests/query_test.go index 9c9ad9f28..431ccce24 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "fmt" "reflect" "regexp" @@ -431,6 +432,33 @@ func TestPluck(t *testing.T) { t.Errorf("Unexpected result on pluck id, got %+v", ids) } } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } } func TestSelect(t *testing.T) { From 9eec6ae06638665661f9872e783a42613527e146 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Sep 2020 12:25:38 +0800 Subject: [PATCH 16/16] Fix affected rows for Scan, change affected rows count for row/rows to '-', close #3532 --- callbacks.go | 1 - callbacks/row.go | 2 ++ finisher_api.go | 8 ++++++++ logger/logger.go | 49 +++++++++++++++++++++++++++++++++++++++--------- scan.go | 1 + 5 files changed, 51 insertions(+), 10 deletions(-) diff --git a/callbacks.go b/callbacks.go index 83d103df1..fdde21e9f 100644 --- a/callbacks.go +++ b/callbacks.go @@ -74,7 +74,6 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() stmt := db.Statement - db.RowsAffected = 0 if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/callbacks/row.go b/callbacks/row.go index a36c0116c..4f985d7bb 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -16,6 +16,8 @@ func RowQuery(db *gorm.DB) { } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } + + db.RowsAffected = -1 } } } diff --git a/finisher_api.go b/finisher_api.go index 2c56d7632..630615530 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -353,7 +354,9 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { + currentLogger, newLogger := db.Logger, logger.Recorder.New() tx = db.getInstance() + tx.Logger = newLogger if rows, err := tx.Rows(); err != nil { tx.AddError(err) } else { @@ -362,6 +365,11 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger return } diff --git a/logger/logger.go b/logger/logger.go index e568fb247..b278ad5d4 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -63,6 +63,7 @@ var ( LogLevel: Warn, Colorful: true, }) + Recorder = traceRecorder{Interface: Default} ) func New(writer Writer, config Config) Interface { @@ -70,18 +71,18 @@ func New(writer Writer, config Config) Interface { infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " - traceStr = "%s\n[%.3fms] [rows:%d] %s" - traceWarnStr = "%s\n[%.3fms] [rows:%d] %s" - traceErrStr = "%s %s\n[%.3fms] [rows:%d] %s" + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" ) if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" - traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" } return &logger{ @@ -138,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i switch { case err != nil && l.LogLevel >= Error: sql, rows := fc() - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case l.LogLevel >= Info: sql, rows := fc() - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } } } } + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface} +} + +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/scan.go b/scan.go index d7cddbe6c..8d737b175 100644 --- a/scan.go +++ b/scan.go @@ -52,6 +52,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: