Skip to content

Commit f3113d7

Browse files
committed
feat: supprt where in
1 parent 3064854 commit f3113d7

File tree

4 files changed

+89
-10
lines changed

4 files changed

+89
-10
lines changed

internal/dinosql/gen.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ type GoQueryValue struct {
6666
Name string
6767
Struct *GoStruct
6868
Typ string
69+
// LocalSQLQuery 函数内部SQL
70+
LocalSQLQuery string
6971
}
7072

7173
func (v GoQueryValue) EmitStruct() bool {
@@ -513,6 +515,7 @@ func queryImports(r Generateable, settings config.CombinedSettings, filename str
513515
}
514516

515517
if sliceScan() {
518+
std["strings"] = struct{}{}
516519
pkg["github.com/lib/pq"] = struct{}{}
517520
}
518521
_, overrideNullTime := overrideTypes["pq.NullTime"]
@@ -1233,9 +1236,10 @@ import (
12331236
{{define "queryCode"}}
12341237
{{range .GoQueries}}
12351238
{{if $.OutputQuery .SourceName}}
1239+
{{if not .Arg.LocalSQLQuery}}
12361240
// {{.ConstantName}} -- name: {{.MethodName}} {{.Cmd}}
12371241
const {{.ConstantName}} = "{{.SQL}}"
1238-
1242+
{{end}}
12391243
{{if .Arg.EmitStruct}}
12401244
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
12411245
{{.Name}} {{.Type}} {{if $.EmitJSONTags}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}}
@@ -1269,16 +1273,20 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.Ty
12691273
{{range .Comments}}//{{.}}
12701274
{{end -}}
12711275
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.Type}}, error) {
1276+
var items []{{.Ret.Type}}
1277+
{{if .Arg.LocalSQLQuery}}
1278+
querySQL := "{{.Arg.LocalSQLQuery}}"
1279+
{{end}}
12721280
{{- if $.EmitPreparedQueries}}
1273-
rows, err := q.query(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}})
1281+
rows, err := q.query(ctx, q.{{.FieldName}},{{if .Arg.LocalSQLQuery}} querySQL {{else}} {{.ConstantName}} {{end}}, {{.Arg.Params}})
12741282
{{- else}}
1275-
rows, err := q.db.QueryContext(ctx, {{.ConstantName}}, {{.Arg.Params}})
1283+
rows, err := q.db.QueryContext(ctx, {{if .Arg.LocalSQLQuery}} querySQL {{else}} {{.ConstantName}} {{end}}, {{.Arg.Params}})
12761284
{{- end}}
12771285
if err != nil {
12781286
return nil, err
12791287
}
12801288
defer rows.Close()
1281-
var items []{{.Ret.Type}}
1289+
12821290
for rows.Next() {
12831291
var {{.Ret.Name}} {{.Ret.Type}}
12841292
if err := rows.Scan({{.Ret.Scan}}); err != nil {

internal/mysql/gen.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,20 +161,33 @@ func (r *Result) GoQueries(settings config.CombinedSettings) []dinosql.GoQuery {
161161
// Comments: query.Comments,
162162
Meta: query.Meta,
163163
}
164-
165164
if len(query.Params) == 1 {
166165
p := query.Params[0]
166+
167167
gq.Arg = dinosql.GoQueryValue{
168-
Name: p.Name,
169-
Typ: p.Typ,
168+
Name: p.GetVariableName(),
169+
Typ: p.GetTypeName(),
170+
}
171+
if p.IsLikeInStmt() {
172+
old, newer := p.ReplaceLikeInStmt("\"", "")
173+
if old != "" && newer != "" {
174+
gq.Arg.LocalSQLQuery = strings.Replace(query.SQL, old, newer, 1)
175+
}
170176
}
171177
} else if len(query.Params) > 1 {
172-
173178
structInfo := make([]structParams, len(query.Params))
179+
localSQLQuery := query.SQL
174180
for i := range query.Params {
181+
item := query.Params[i]
182+
originalName := item.GetVariableName()
183+
goTypeName := item.GetTypeName()
184+
old, newer := item.ReplaceLikeInStmt("\"", "arg."+strings.Title(originalName))
185+
if old != "" && newer != "" {
186+
localSQLQuery = strings.Replace(localSQLQuery, old, newer, 1)
187+
}
175188
structInfo[i] = structParams{
176-
originalName: query.Params[i].Name,
177-
goType: query.Params[i].Typ,
189+
originalName: originalName,
190+
goType: goTypeName,
178191
}
179192
}
180193

@@ -183,6 +196,9 @@ func (r *Result) GoQueries(settings config.CombinedSettings) []dinosql.GoQuery {
183196
Name: "arg",
184197
Struct: r.columnsToStruct(gq.MethodName+"Params", structInfo, settings),
185198
}
199+
if localSQLQuery != query.SQL {
200+
gq.Arg.LocalSQLQuery = localSQLQuery
201+
}
186202
}
187203

188204
if len(query.Columns) == 1 {

internal/mysql/param.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,56 @@ type Param struct {
1515
OriginalName string
1616
Name string
1717
Typ string
18+
// ComparisonExpr 比较表达式
19+
ComparisonExpr *sqlparser.ComparisonExpr
20+
}
21+
22+
// IsLikeInStmt 是否为in/not in表达式
23+
func (p Param) IsLikeInStmt() bool {
24+
if p.ComparisonExpr == nil {
25+
return false
26+
}
27+
op := p.ComparisonExpr.Operator
28+
return op == sqlparser.InStr || op == sqlparser.NotInStr
29+
}
30+
31+
// ReplaceLikeInStmt like stmt reaplce
32+
func (p Param) ReplaceLikeInStmt(quote string, paramsName string) (string, string) {
33+
if !p.IsLikeInStmt() {
34+
return "", ""
35+
}
36+
op := p.ComparisonExpr.Operator
37+
var old string
38+
if op == sqlparser.InStr {
39+
old = fmt.Sprintf("%s in (?)", p.Name)
40+
41+
}
42+
if op == sqlparser.NotInStr {
43+
old = fmt.Sprintf("%s not in (?)", p.Name)
44+
}
45+
if old == "" {
46+
return "", ""
47+
}
48+
if paramsName == "" {
49+
paramsName = p.GetVariableName()
50+
}
51+
newer := strings.Replace(old, "?", fmt.Sprintf("?%s + strings.Repeat(\",?\",len(%s)-1)+%s", quote, paramsName, quote), 1)
52+
return old, newer
53+
}
54+
55+
// GetVariableName 获取变量名称
56+
func (p Param) GetVariableName() string {
57+
if p.IsLikeInStmt() {
58+
return p.Name + "List"
59+
}
60+
return p.Name
61+
}
62+
63+
func (p Param) GetTypeName() string {
64+
if p.IsLikeInStmt() {
65+
return "[]" + p.Typ
66+
}
67+
return p.Typ
1868
}
1969

2070
func (pGen PackageGenerator) paramsInLimitExpr(limit *sqlparser.Limit, tableAliasMap FromTables) ([]*Param, error) {
@@ -70,6 +120,7 @@ func (pGen PackageGenerator) paramsInWhereExpr(e sqlparser.SQLNode, tableAliasMa
70120
return params, nil
71121
}
72122
e = expr.Expr
123+
73124
}
74125
switch v := e.(type) {
75126
case *sqlparser.Where:
@@ -82,6 +133,9 @@ func (pGen PackageGenerator) paramsInWhereExpr(e sqlparser.SQLNode, tableAliasMa
82133
if err != nil {
83134
return nil, err
84135
}
136+
if p != nil {
137+
p.ComparisonExpr = v
138+
}
85139
if found {
86140
params = append(params, p)
87141
}

internal/mysql/parse.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ func (pGen PackageGenerator) parseSelect(tree *sqlparser.Select, query string) (
190190
parsedQuery.Columns = cols
191191

192192
whereParams, err := pGen.paramsInWhereExpr(tree.Where, tableAliasMap, defaultTableName)
193+
193194
if err != nil {
194195
return nil, err
195196
}

0 commit comments

Comments
 (0)