Skip to content

Commit cf79d29

Browse files
committed
fix some bugs
1 parent 9fd465c commit cf79d29

File tree

6 files changed

+115
-44
lines changed

6 files changed

+115
-44
lines changed

internal/codegen/golang/field.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ func (gf *Field) IsNullable() bool {
3636
}
3737

3838
func (gf *Field) Serialize() bool {
39-
return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "64") && !strings.HasSuffix(gf.Type, "32") && !strings.HasSuffix(gf.Type, "Type"))
39+
return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "Duration"))
40+
}
41+
42+
func (gf *Field) Deserialize() bool {
43+
return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "64") && !strings.HasSuffix(gf.Type, "32") && !strings.HasSuffix(gf.Type, "Type") && !strings.HasSuffix(gf.Type, "Duration"))
4044
}
4145

4246
func (gf *Field) HasLen() bool {
@@ -65,7 +69,7 @@ func (gf *Field) BindType() string {
6569
ty = ty[2:]
6670
}
6771
switch {
68-
case ty == "uint8", ty == "int8", ty == "uint32":
72+
case ty == "uint8", ty == "int8", ty == "uint32", strings.HasSuffix(ty, "Duration"):
6973
return "int64"
7074
case ty == "float32":
7175
return "float64"

internal/codegen/golang/imports.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ func (i *importer) dbImports() fileImports {
124124
}
125125
std := []ImportSpec{
126126
{Path: "runtime"},
127-
{Path: "context"},
128127
}
129128

130129
sqlpkg := parseDriver(i.Options.SqlPackage)

internal/codegen/golang/query.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@ func (q *Query) hasRetType() bool {
326326
return scanned && !q.Ret.isEmpty()
327327
}
328328

329+
func (q *Query) FetchRowInNext() bool {
330+
return q.Cmd != metadata.CmdExec && q.Cmd != metadata.CmdIter
331+
}
332+
329333
func (q *Query) TableIdentifierAsGoSlice() string {
330334
escapedNames := make([]string, 0, 3)
331335
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {

internal/codegen/golang/result.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,6 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []
327327
func getExistingStructIfDuplicate(req *plugin.GenerateRequest, options *opts.Options, structs []Struct, query *plugin.Query, columns []*plugin.Column) (*Struct, bool) {
328328
StructLoop:
329329
for _, s := range structs {
330-
if s.Name == "Queue" && query.Name == "CreateQueue" {
331-
fmt.Println("getExistingStructIfDuplicate for Queue and CreateQueueParams", len(s.Fields), len(columns))
332-
}
333330
if len(s.Fields) != len(columns) {
334331
continue
335332
}

internal/codegen/golang/templates/sqlite/dbCode.tmpl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@ type Queries struct {
44
conn *sqlite.Conn
55
pin runtime.Pinner
66
beginRead iter
7-
beginWrite iter
87
commit iter
9-
rollback iter
108
{{- range .ReadQueries }}
119
{{.FieldName}} {{.MethodName}}Stmt
1210
{{- end}}
1311
}
1412

1513
type WriteTxn struct {
1614
Queries
15+
beginWrite iter
16+
rollback iter
1717
{{- range .WriteQueries }}
1818
{{.FieldName}} {{.MethodName}}Stmt
1919
{{- end}}
@@ -25,42 +25,46 @@ func (q *Queries) init(conn *sqlite.Conn) {
2525
}
2626
q.conn = conn
2727
q.beginRead.init(q, "begin", "begin", false)
28-
q.beginWrite.init(q, "begin immediate", "begin immediate", true)
2928
q.commit.init(q, "commit", "commit", false)
30-
q.rollback.init(q, "rollback", "rollback", false)
29+
}
30+
31+
func (txn *WriteTxn) init(conn *sqlite.Conn) {
32+
txn.Queries.init(conn)
33+
txn.beginWrite.init(&txn.Queries, "begin immediate", "begin immediate", false)
34+
txn.rollback.init(&txn.Queries, "rollback", "rollback", false)
3135
}
3236

3337
func (q *Queries) BeginRead() (int64, *lib.Error) {
3438
q.beginRead.Next()
35-
err := q.beginRead.Err()
39+
err := q.beginRead.takeErr()
3640
q.beginRead.reset()
3741
if err != nil {
3842
return 0, err
3943
}
4044
return q.GetVersion()
4145
}
4246

43-
func (q *Queries) BeginWrite() (int64, *lib.Error) {
44-
q.beginWrite.Next()
45-
err := q.beginWrite.Err()
46-
q.beginWrite.reset()
47+
func (txn *WriteTxn) BeginWrite() (int64, *lib.Error) {
48+
txn.beginWrite.Next()
49+
err := txn.beginWrite.takeErr()
50+
txn.beginWrite.reset()
4751
if err != nil {
4852
return 0, err
4953
}
50-
return q.GetVersion()
54+
return txn.GetVersion()
5155
}
5256

5357
func (q *Queries) Commit() (*lib.Error) {
5458
q.commit.Next()
55-
err := q.commit.Err()
59+
err := q.commit.takeErr()
5660
q.commit.reset()
5761
return err
5862
}
5963

60-
func (q *Queries) Rollback() (*lib.Error) {
61-
q.rollback.Next()
62-
err := q.rollback.Err()
63-
q.rollback.reset()
64+
func (txn *WriteTxn) Rollback() (*lib.Error) {
65+
txn.rollback.Next()
66+
err := txn.rollback.takeErr()
67+
txn.rollback.reset()
6468
return err
6569
}
6670

@@ -69,7 +73,7 @@ func (q *Queries) Close() *lib.Error {
6973
var err, firstErr *lib.Error
7074
{{- range .ReadQueries }}
7175
q.{{.FieldName}}.close()
72-
err = q.{{.FieldName}}.Err()
76+
err = q.{{.FieldName}}.takeErr()
7377
if firstErr == nil {
7478
firstErr = err
7579
}
@@ -85,7 +89,7 @@ func (txn *WriteTxn) Close() *lib.Error {
8589
var err, firstErr *lib.Error
8690
{{- range .WriteQueries }}
8791
txn.{{.FieldName}}.close()
88-
err = txn.{{.FieldName}}.Err()
92+
err = txn.{{.FieldName}}.takeErr()
8993
if firstErr == nil {
9094
firstErr = err
9195
}

internal/codegen/golang/templates/sqlite/queryCode.tmpl

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{{define "queryCodeStd"}}
22
{{range $.AllQueries .SourceName}}
3+
{{$query := .}}
34
const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
45
{{escape .SQL}}
56
{{$.Q}}
@@ -34,11 +35,7 @@ func (r *{{ .MethodName }}Stmt) bind(q *{{ .ReceiverType }}, {{ .Arg.Pair }}) {
3435
{{- if .Arg.HasSqlcSlices }}
3536
{{- range .Arg.Fields }}
3637
{{- if .HasSqlcSlice }}
37-
if len({{ .VariableForField }}) > 0 {
38-
query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{ .VariableForField }}))[1:], 1)
39-
} else {
40-
query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1)
41-
}
38+
query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", slicePlaceholders(len({{ .VariableForField }})), 1)
4239
{{- end }}
4340
{{- end }}
4441
{{- end }}
@@ -74,7 +71,14 @@ func (r *{{ .MethodName }}Stmt) Next() bool {
7471
if r.err != nil {
7572
return false
7673
}
77-
{{- if ne .Cmd ":exec" }}
74+
return r.Fetch()
75+
}
76+
77+
func (r *{{ .MethodName }}Stmt) Fetch() bool {
78+
if r.err != nil {
79+
return false
80+
}
81+
7882
{{ $fields := .Ret.Fields }}
7983
{{- if $fields }}
8084
col := 0
@@ -85,7 +89,7 @@ func (r *{{ .MethodName }}Stmt) Next() bool {
8589
{{- template "fetchColumn" . }}
8690
{{- end }}
8791
{{- end }}
88-
{{- end }}
92+
8993
return r.err == nil
9094
}
9195

@@ -95,14 +99,14 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De
9599
r.bind(q, {{ .Arg.Names }})
96100
defer r.reset()
97101
ok := r.Next()
98-
err = r.Err()
102+
err = r.takeErr()
99103
if err != nil {
100104
return
101105
}
102106
if !ok {
103-
err = lib.ErrorWithDepth(errNoRows, 2)
107+
err = lib.WrapWithDepth(errNoRows, 3)
104108
} else if r.Next() {
105-
err = lib.ErrorWithDepth(errTooManyRows, 2)
109+
err = lib.WrapWithDepth(errTooManyRows, 3)
106110
} else {
107111
result = r.Row
108112
}
@@ -118,7 +122,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret
118122
for r.Next() {
119123
results = append(results, r.Row)
120124
}
121-
err = r.Err()
125+
err = r.takeErr()
122126
if err != nil {
123127
results = nil
124128
}
@@ -132,7 +136,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R
132136
r.bind(q, {{ .Arg.Names }})
133137
defer r.reset()
134138
if !r.Next() {
135-
err = r.Err()
139+
err = r.takeErr()
136140
return
137141
}
138142
results = map[{{.Ret.KeyType}}]{{.Ret.DefineType}}{
@@ -141,7 +145,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R
141145
for r.Next() {
142146
results[r.Key] = r.Row
143147
}
144-
err = r.Err()
148+
err = r.takeErr()
145149
if err != nil {
146150
results = nil
147151
}
@@ -150,20 +154,79 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R
150154
{{end}}
151155

152156
{{if eq .Cmd ":iter"}}
153-
type {{.MethodName}}Iterator func(yield func(row {{.Ret.DefineType}}) bool)
157+
func (r *{{ .MethodName }}Stmt) HasNext() bool {
158+
if !r.iter.Next() {
159+
return false
160+
}
161+
if r.err != nil {
162+
return false
163+
}
164+
return r.err == nil
165+
}
154166

155-
func (q *{{ .ReceiverType }}) {{.MethodName}}Cursor({{ .Arg.Pair }}) {{.MethodName}}Iterator {
156-
r := &q.{{ .FieldName }}
157-
r.bind(q, {{ .Arg.Names }})
167+
func (r *{{ .MethodName }}Stmt) Close() *lib.Error {
168+
err := r.takeErr()
169+
r.reset()
170+
return err
171+
}
172+
173+
func (r *{{ .MethodName }}Stmt) Range() func(yield func(row {{.Ret.DefineType}}) bool) {
158174
return func(yield func(row {{.Ret.DefineType}}) bool) {
159175
defer r.reset()
160176
for r.Next() && yield(r.Row) {}
161177
}
162178
}
163179

164-
func (_ {{.MethodName}}Iterator) Close(q *{{ .ReceiverType }}) *lib.Error {
165-
return q.{{ .FieldName }}.Err()
180+
func (q *{{ .ReceiverType }}) {{.MethodName}}Cursor({{ .Arg.Pair }}) *{{ .MethodName }}Stmt {
181+
r := &q.{{ .FieldName }}
182+
r.bind(q, {{ .Arg.Names }})
183+
return r
184+
}
185+
186+
{{- range $index, $field := .Ret.Fields }}
187+
{{- if $field.IsNullable }}
188+
func (r *{{ $query.MethodName }}Stmt) Peek{{ $field.Name }}() {{ $field.BindType }} {
189+
const col = {{$index}}
190+
return {{ $field.FetchMethod }}(col)
191+
}
192+
{{- else }}
193+
func (r *{{ $query.MethodName }}Stmt) Peek{{ $field.Name }}() (val {{ $field.BindType }}, isNull bool) {
194+
const col = {{$index}}
195+
if r.stmt.ColumnIsNull(col) {
196+
isNull = true
197+
return
198+
}
199+
val = {{ $field.FetchMethod }}(col)
200+
return
201+
}
202+
{{- end }}
203+
204+
func (r *{{ $query.MethodName }}Stmt) Get{{ $field.Name }}() (val {{ $field.Type }}, ok bool) {
205+
var raw {{ $field.BindType }}
206+
{{- if $field.IsNullable }}
207+
raw = r.Peek{{ $field.Name }}()
208+
{{- else }}
209+
var isNull bool
210+
raw, isNull = r.Peek{{ $field.Name }}()
211+
if isNull {
212+
return val, true
213+
}
214+
{{- end }}
215+
{{- if $field.Deserialize }}
216+
var err *lib.Error
217+
val, err = {{ $field.DeserializeMethod }}(raw)
218+
if err != nil {
219+
r.setErr(err)
220+
return val, false
221+
}
222+
{{- else if $field.NeedsCast $field.BindType }}
223+
val = {{ $field.Type }}(raw)
224+
{{- else }}
225+
val = raw
226+
{{- end }}
227+
return val, true
166228
}
229+
{{- end }}
167230
{{end}}
168231

169232
{{if eq .Cmd ":exec"}}
@@ -172,7 +235,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error {
172235
r.bind(q, {{ .Arg.Names }})
173236
defer r.reset()
174237
r.Next()
175-
return r.Err()
238+
return r.takeErr()
176239
}
177240
{{end}}
178241

@@ -253,7 +316,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error {
253316
{
254317
{{- end }}
255318
raw := {{ .FetchMethod }}(col)
256-
{{- if .Serialize }}
319+
{{- if .Deserialize }}
257320
v, err := {{ .DeserializeMethod }}(raw)
258321
if err != nil {
259322
r.setErr(err)

0 commit comments

Comments
 (0)