Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@ var DefaultCallback = &Callback{}
// Field `deletes` contains callbacks will be call when deleting object
// Field `queries` contains callbacks will be call when querying object with query methods like Find, First, Related, Association...
// Field `rowQueries` contains callbacks will be call when querying object with Row, Rows...
// Field `beforeSQL` contains callbacks will be call before any sql query was executed
// Field `afterSQL` contains callbacks will be call after any sql query was executed
// Field `processors` contains all callback processors, will be used to generate above callbacks in order
type Callback struct {
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
beforeSQL []*func(scope *Scope)
afterSQL []*func(scope *Scope)
processors []*CallbackProcessor
}

Expand All @@ -43,6 +47,8 @@ func (c *Callback) clone() *Callback {
queries: c.queries,
rowQueries: c.rowQueries,
processors: c.processors,
beforeSQL: c.beforeSQL,
afterSQL: c.afterSQL,
}
}

Expand Down Expand Up @@ -79,6 +85,16 @@ func (c *Callback) RowQuery() *CallbackProcessor {
return &CallbackProcessor{kind: "row_query", parent: c}
}

// Trace could be used to register callbacks before any sql queries, refer `Create` for usage
func (c *Callback) BeforeSQL() *CallbackProcessor {
return &CallbackProcessor{kind: "beforeSQL", parent: c}
}

// Trace could be used to register callbacks after any sql queries, refer `Create` for usage
func (c *Callback) AfterSQL() *CallbackProcessor {
return &CallbackProcessor{kind: "afterSQL", parent: c}
}

// After insert a new callback after callback `callbackName`, refer `Callbacks.Create`
func (cp *CallbackProcessor) After(callbackName string) *CallbackProcessor {
cp.after = callbackName
Expand Down Expand Up @@ -210,7 +226,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) {

// reorder all registered processors, and reset CURD callbacks
func (c *Callback) reorder() {
var creates, updates, deletes, queries, rowQueries []*CallbackProcessor
var creates, updates, deletes, queries, rowQueries, beforeSQL, afterSQL []*CallbackProcessor

for _, processor := range c.processors {
if processor.name != "" {
Expand All @@ -225,6 +241,10 @@ func (c *Callback) reorder() {
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
case "beforeSQL":
beforeSQL = append(beforeSQL, processor)
case "afterSQL":
afterSQL = append(afterSQL, processor)
}
}
}
Expand All @@ -234,4 +254,6 @@ func (c *Callback) reorder() {
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
c.beforeSQL = sortProcessors(beforeSQL)
c.afterSQL = sortProcessors(afterSQL)
}
3 changes: 2 additions & 1 deletion callback_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.beforeSQL)
defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL)

var (
columns, placeholders []string
Expand Down
3 changes: 2 additions & 1 deletion callback_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ func init() {

// queryCallback used to query data from database
func queryCallback(scope *Scope) {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.beforeSQL)
defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL)

var (
isSlice, isPtr bool
Expand Down
29 changes: 29 additions & 0 deletions callback_trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package gorm

import (
"time"
)

// Define callbacks for tracing
func init() {
DefaultCallback.BeforeSQL().Register("gorm:start-time", startTimeCallback)
DefaultCallback.AfterSQL().Register("gorm:log", logCallback)
}

// startTimeCallback puts time when sql started in scope
func startTimeCallback(scope *Scope) {
scope.Set("gorm:trace-start-time", NowFunc())
}

// logCallback prints sql log
func logCallback(scope *Scope) {
if len(scope.SQL) <= 0 {
return
}

t, ok := scope.Get("gorm:trace-start-time")
if !ok {
return
}
scope.db.slog(scope.SQL, t.(time.Time), scope.SQLVars...)
}
21 changes: 9 additions & 12 deletions scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"regexp"
"strconv"
"strings"
"time"

"reflect"
)
Expand Down Expand Up @@ -346,9 +345,10 @@ func (scope *Scope) Raw(sql string) *Scope {

// Exec perform generated SQL
func (scope *Scope) Exec() *Scope {
defer scope.trace(NowFunc())

if !scope.HasError() {
scope.callCallbacks(scope.db.parent.callbacks.beforeSQL)
defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL)

if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count
Expand Down Expand Up @@ -884,14 +884,18 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
}

func (scope *Scope) row() *sql.Row {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.beforeSQL)
defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL)

scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySQL()
return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...)
}

func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(NowFunc())
scope.callCallbacks(scope.db.parent.callbacks.beforeSQL)
defer scope.callCallbacks(scope.db.parent.callbacks.afterSQL)

scope.callCallbacks(scope.db.parent.callbacks.rowQueries)
scope.prepareQuerySQL()
return scope.SQLDB().Query(scope.SQL, scope.SQLVars...)
Expand Down Expand Up @@ -945,13 +949,6 @@ func (scope *Scope) typeName() string {
return typ.Name()
}

// trace print sql log
func (scope *Scope) trace(t time.Time) {
if len(scope.SQL) > 0 {
scope.db.slog(scope.SQL, t, scope.SQLVars...)
}
}

func (scope *Scope) changeableField(field *Field) bool {
if selectAttrs := scope.SelectAttrs(); len(selectAttrs) > 0 {
for _, attr := range selectAttrs {
Expand Down