Skip to content

Commit 36c74dd

Browse files
committed
wip, generate code for deferred using zombiezen sqlite driver
1 parent 8bf2817 commit 36c74dd

File tree

18 files changed

+473
-889
lines changed

18 files changed

+473
-889
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ __pycache__
99
.devenv*
1010
devenv.local.nix
1111

12+
# Binaries
13+
sqlc
14+

internal/codegen/golang/field.go

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,86 @@ import (
1111
)
1212

1313
type Field struct {
14-
Name string // CamelCased name for Go
15-
DBName string // Name as used in the DB
16-
Type string
17-
Tags map[string]string
18-
Comment string
19-
Column *plugin.Column
14+
Name string // CamelCased name for Go
15+
VariableForField string // Variable name for the field (including structVar.name, if part of a struct)
16+
DBName string // Name as used in the DB
17+
Type string
18+
Tags map[string]string
19+
Comment string
20+
Column *plugin.Column
2021
// EmbedFields contains the embedded fields that require scanning.
2122
EmbedFields []Field
2223
}
2324

24-
func (gf Field) Tag() string {
25+
func (gf *Field) Tag() string {
2526
return TagsToString(gf.Tags)
2627
}
2728

28-
func (gf Field) HasSqlcSlice() bool {
29-
return gf.Column.IsSqlcSlice
29+
func (gf *Field) HasSqlcSlice() bool {
30+
return gf.Column != nil && gf.Column.IsSqlcSlice
31+
}
32+
33+
func (gf *Field) IsNullable() bool {
34+
return strings.HasPrefix(gf.Type, "types.Null") || strings.HasSuffix(gf.Type, "ID") || gf.Column != nil && !gf.Column.NotNull
35+
}
36+
37+
func (gf *Field) Serialize() bool {
38+
return strings.HasPrefix(gf.Type, "*") || strings.HasSuffix(gf.Type, "ID")
39+
}
40+
41+
func (gf *Field) HasLen() bool {
42+
return gf.Type == "[]byte" || gf.Type == "string"
43+
}
44+
45+
func (gf *Field) Is64Bit() bool {
46+
return strings.HasSuffix(gf.Type, "64")
47+
}
48+
49+
func (gf *Field) BindType() string {
50+
ty := gf.Type
51+
if gf.HasSqlcSlice() {
52+
ty = ty[2:]
53+
}
54+
switch {
55+
case ty == "uint8", ty == "int8", ty == "uint32":
56+
return "int64"
57+
case ty == "float32":
58+
return "float64"
59+
case ty == "TimeOrderedID":
60+
return "[]byte"
61+
case strings.HasSuffix(ty, "Bool"):
62+
return "bool"
63+
case strings.HasSuffix(ty, "ID"):
64+
return "int64"
65+
default:
66+
return ty
67+
}
68+
}
69+
70+
func (gf *Field) BindMethod() string {
71+
bindType := gf.BindType()
72+
switch bindType {
73+
case "string":
74+
return "r.bindString"
75+
case "[]byte":
76+
return "r.bindBytes"
77+
case "float64":
78+
return "r.stmt.BindFloat"
79+
default:
80+
return "r.stmt.Bind" + toPascalCase(bindType)
81+
}
82+
}
83+
84+
func (gf *Field) FetchMethod() string {
85+
bindType := gf.BindType()
86+
switch bindType {
87+
case "[]byte":
88+
return "r.stmt.ColumnBytes"
89+
case "float64":
90+
return "r.stmt.ColumnFloat"
91+
default:
92+
return "r.stmt.Column" + toPascalCase(bindType)
93+
}
3094
}
3195

3296
func TagsToString(tags map[string]string) string {

internal/codegen/golang/gen.go

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ import (
1717
)
1818

1919
type tmplCtx struct {
20-
Q string
21-
Package string
22-
SQLDriver opts.SQLDriver
23-
Enums []Enum
24-
Structs []Struct
25-
GoQueries []Query
26-
SqlcVersion string
20+
Q string
21+
Package string
22+
SQLDriver opts.SQLDriver
23+
Enums []Enum
24+
Structs []Struct
25+
ReadQueries []Query
26+
WriteQueries []Query
27+
SqlcVersion string
2728

2829
// TODO: Race conditions
2930
SourceName string
@@ -104,6 +105,23 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
104105
}
105106
}
106107

108+
func (t *tmplCtx) AllQueries(sourceFile string) []Query {
109+
var allQueries []Query
110+
for i := range t.ReadQueries {
111+
q := &t.ReadQueries[i]
112+
if q.SourceName == sourceFile {
113+
allQueries = append(allQueries, *q)
114+
}
115+
}
116+
for i := range t.WriteQueries {
117+
q := &t.WriteQueries[i]
118+
if q.SourceName == sourceFile {
119+
allQueries = append(allQueries, *q)
120+
}
121+
}
122+
return allQueries
123+
}
124+
107125
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
108126
options, err := opts.Parse(req)
109127
if err != nil {
@@ -235,11 +253,34 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
235253
execute := func(name, templateName string) error {
236254
imports := i.Imports(name)
237255
replacedQueries := replaceConflictedArg(imports, queries)
256+
var readQueries []Query
257+
var writeQueries []Query
258+
for _, q := range replacedQueries {
259+
if q.Arg.Struct != nil {
260+
for i := range q.Arg.Struct.Fields {
261+
f := &q.Arg.Struct.Fields[i]
262+
f.VariableForField = q.Arg.VariableForField(*f)
263+
}
264+
}
265+
if q.Ret.Struct != nil {
266+
for i := range q.Ret.Struct.Fields {
267+
f := &q.Ret.Struct.Fields[i]
268+
f.VariableForField = q.Ret.VariableForField(*f)
269+
}
270+
}
271+
272+
if q.IsReadOnly() {
273+
readQueries = append(readQueries, q)
274+
} else {
275+
writeQueries = append(writeQueries, q)
276+
}
277+
}
238278

239279
var b bytes.Buffer
240280
w := bufio.NewWriter(&b)
241281
tctx.SourceName = name
242-
tctx.GoQueries = replacedQueries
282+
tctx.ReadQueries = readQueries
283+
tctx.WriteQueries = writeQueries
243284
err := tmpl.ExecuteTemplate(w, templateName, &tctx)
244285
w.Flush()
245286
if err != nil {

internal/codegen/golang/imports.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ func (i *importer) Imports(filename string) [][]ImportSpec {
119119
}
120120

121121
func (i *importer) dbImports() fileImports {
122-
var pkg []ImportSpec
123-
std := []ImportSpec{
124-
{Path: "context"},
122+
pkg := []ImportSpec{
123+
{Path: "github.com/deferred-dev/deferred/lib"},
125124
}
125+
var std []ImportSpec
126126

127127
sqlpkg := parseDriver(i.Options.SqlPackage)
128128
switch sqlpkg {
@@ -133,7 +133,7 @@ func (i *importer) dbImports() fileImports {
133133
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"})
134134
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"})
135135
default:
136-
std = append(std, ImportSpec{Path: "database/sql"})
136+
std = append(std, ImportSpec{Path: "zombiezen.com/go/sqlite"})
137137
if i.Options.EmitPreparedQueries {
138138
std = append(std, ImportSpec{Path: "fmt"})
139139
}
@@ -162,6 +162,7 @@ var pqtypeTypes = map[string]struct{}{
162162

163163
func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) {
164164
pkg := make(map[ImportSpec]struct{})
165+
pkg[ImportSpec{Path: "github.com/deferred-dev/deferred/lib"}] = struct{}{}
165166
std := make(map[string]struct{})
166167

167168
if uses("sql.Null") {
@@ -298,16 +299,12 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp
298299

299300
func (i *importer) queryImports(filename string) fileImports {
300301
var gq []Query
301-
anyNonCopyFrom := false
302302
for _, query := range i.Queries {
303303
if usesBatch([]Query{query}) {
304304
continue
305305
}
306306
if query.SourceName == filename {
307307
gq = append(gq, query)
308-
if query.Cmd != metadata.CmdCopyFrom {
309-
anyNonCopyFrom = true
310-
}
311308
}
312309
}
313310

@@ -390,10 +387,6 @@ func (i *importer) queryImports(filename string) fileImports {
390387
return false
391388
}
392389

393-
if anyNonCopyFrom {
394-
std["context"] = struct{}{}
395-
}
396-
397390
sqlpkg := parseDriver(i.Options.SqlPackage)
398391
if sqlcSliceScan() && !sqlpkg.IsPGX() {
399392
std["strings"] = struct{}{}

0 commit comments

Comments
 (0)