From 36c74dd02006e4c54f2022289bfe07db7f2b2f06 Mon Sep 17 00:00:00 2001 From: Eloff Date: Sat, 26 Oct 2024 19:42:41 -0600 Subject: [PATCH 01/18] wip, generate code for deferred using zombiezen sqlite driver --- .gitignore | 3 + internal/codegen/golang/field.go | 82 ++++++- internal/codegen/golang/gen.go | 57 ++++- internal/codegen/golang/imports.go | 17 +- internal/codegen/golang/query.go | 76 +++++-- internal/codegen/golang/sqlite_type.go | 53 +---- .../go-sql-driver-mysql/copyfromCopy.tmpl | 52 ----- .../golang/templates/pgx/batchCode.tmpl | 134 ----------- .../golang/templates/pgx/copyfromCopy.tmpl | 51 ----- .../codegen/golang/templates/pgx/dbCode.tmpl | 37 --- .../golang/templates/pgx/interfaceCode.tmpl | 73 ------ .../golang/templates/pgx/queryCode.tmpl | 124 ---------- .../golang/templates/sqlite/dbCode.tmpl | 57 +++++ .../golang/templates/sqlite/queryCode.tmpl | 211 ++++++++++++++++++ .../golang/templates/stdlib/dbCode.tmpl | 105 --------- .../templates/stdlib/interfaceCode.tmpl | 63 ------ .../golang/templates/stdlib/queryCode.tmpl | 155 ------------- .../codegen/golang/templates/template.tmpl | 12 +- 18 files changed, 473 insertions(+), 889 deletions(-) delete mode 100644 internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl delete mode 100644 internal/codegen/golang/templates/pgx/batchCode.tmpl delete mode 100644 internal/codegen/golang/templates/pgx/copyfromCopy.tmpl delete mode 100644 internal/codegen/golang/templates/pgx/dbCode.tmpl delete mode 100644 internal/codegen/golang/templates/pgx/interfaceCode.tmpl delete mode 100644 internal/codegen/golang/templates/pgx/queryCode.tmpl create mode 100644 internal/codegen/golang/templates/sqlite/dbCode.tmpl create mode 100644 internal/codegen/golang/templates/sqlite/queryCode.tmpl delete mode 100644 internal/codegen/golang/templates/stdlib/dbCode.tmpl delete mode 100644 internal/codegen/golang/templates/stdlib/interfaceCode.tmpl delete mode 100644 internal/codegen/golang/templates/stdlib/queryCode.tmpl diff --git a/.gitignore b/.gitignore index 39961ebb02..e1920d9165 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ __pycache__ .devenv* devenv.local.nix +# Binaries +sqlc + diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 2a63b6d342..2959bd1c68 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -11,22 +11,86 @@ import ( ) type Field struct { - Name string // CamelCased name for Go - DBName string // Name as used in the DB - Type string - Tags map[string]string - Comment string - Column *plugin.Column + Name string // CamelCased name for Go + VariableForField string // Variable name for the field (including structVar.name, if part of a struct) + DBName string // Name as used in the DB + Type string + Tags map[string]string + Comment string + Column *plugin.Column // EmbedFields contains the embedded fields that require scanning. EmbedFields []Field } -func (gf Field) Tag() string { +func (gf *Field) Tag() string { return TagsToString(gf.Tags) } -func (gf Field) HasSqlcSlice() bool { - return gf.Column.IsSqlcSlice +func (gf *Field) HasSqlcSlice() bool { + return gf.Column != nil && gf.Column.IsSqlcSlice +} + +func (gf *Field) IsNullable() bool { + return strings.HasPrefix(gf.Type, "types.Null") || strings.HasSuffix(gf.Type, "ID") || gf.Column != nil && !gf.Column.NotNull +} + +func (gf *Field) Serialize() bool { + return strings.HasPrefix(gf.Type, "*") || strings.HasSuffix(gf.Type, "ID") +} + +func (gf *Field) HasLen() bool { + return gf.Type == "[]byte" || gf.Type == "string" +} + +func (gf *Field) Is64Bit() bool { + return strings.HasSuffix(gf.Type, "64") +} + +func (gf *Field) BindType() string { + ty := gf.Type + if gf.HasSqlcSlice() { + ty = ty[2:] + } + switch { + case ty == "uint8", ty == "int8", ty == "uint32": + return "int64" + case ty == "float32": + return "float64" + case ty == "TimeOrderedID": + return "[]byte" + case strings.HasSuffix(ty, "Bool"): + return "bool" + case strings.HasSuffix(ty, "ID"): + return "int64" + default: + return ty + } +} + +func (gf *Field) BindMethod() string { + bindType := gf.BindType() + switch bindType { + case "string": + return "r.bindString" + case "[]byte": + return "r.bindBytes" + case "float64": + return "r.stmt.BindFloat" + default: + return "r.stmt.Bind" + toPascalCase(bindType) + } +} + +func (gf *Field) FetchMethod() string { + bindType := gf.BindType() + switch bindType { + case "[]byte": + return "r.stmt.ColumnBytes" + case "float64": + return "r.stmt.ColumnFloat" + default: + return "r.stmt.Column" + toPascalCase(bindType) + } } func TagsToString(tags map[string]string) string { diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 5b7977f500..749bfaa9e1 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -17,13 +17,14 @@ import ( ) type tmplCtx struct { - Q string - Package string - SQLDriver opts.SQLDriver - Enums []Enum - Structs []Struct - GoQueries []Query - SqlcVersion string + Q string + Package string + SQLDriver opts.SQLDriver + Enums []Enum + Structs []Struct + ReadQueries []Query + WriteQueries []Query + SqlcVersion string // TODO: Race conditions SourceName string @@ -104,6 +105,23 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) { } } +func (t *tmplCtx) AllQueries(sourceFile string) []Query { + var allQueries []Query + for i := range t.ReadQueries { + q := &t.ReadQueries[i] + if q.SourceName == sourceFile { + allQueries = append(allQueries, *q) + } + } + for i := range t.WriteQueries { + q := &t.WriteQueries[i] + if q.SourceName == sourceFile { + allQueries = append(allQueries, *q) + } + } + return allQueries +} + func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) { options, err := opts.Parse(req) if err != nil { @@ -235,11 +253,34 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, execute := func(name, templateName string) error { imports := i.Imports(name) replacedQueries := replaceConflictedArg(imports, queries) + var readQueries []Query + var writeQueries []Query + for _, q := range replacedQueries { + if q.Arg.Struct != nil { + for i := range q.Arg.Struct.Fields { + f := &q.Arg.Struct.Fields[i] + f.VariableForField = q.Arg.VariableForField(*f) + } + } + if q.Ret.Struct != nil { + for i := range q.Ret.Struct.Fields { + f := &q.Ret.Struct.Fields[i] + f.VariableForField = q.Ret.VariableForField(*f) + } + } + + if q.IsReadOnly() { + readQueries = append(readQueries, q) + } else { + writeQueries = append(writeQueries, q) + } + } var b bytes.Buffer w := bufio.NewWriter(&b) tctx.SourceName = name - tctx.GoQueries = replacedQueries + tctx.ReadQueries = readQueries + tctx.WriteQueries = writeQueries err := tmpl.ExecuteTemplate(w, templateName, &tctx) w.Flush() if err != nil { diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 9e7819e4b1..ebd8e864b2 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -119,10 +119,10 @@ func (i *importer) Imports(filename string) [][]ImportSpec { } func (i *importer) dbImports() fileImports { - var pkg []ImportSpec - std := []ImportSpec{ - {Path: "context"}, + pkg := []ImportSpec{ + {Path: "github.com/deferred-dev/deferred/lib"}, } + var std []ImportSpec sqlpkg := parseDriver(i.Options.SqlPackage) switch sqlpkg { @@ -133,7 +133,7 @@ func (i *importer) dbImports() fileImports { pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}) pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"}) default: - std = append(std, ImportSpec{Path: "database/sql"}) + std = append(std, ImportSpec{Path: "zombiezen.com/go/sqlite"}) if i.Options.EmitPreparedQueries { std = append(std, ImportSpec{Path: "fmt"}) } @@ -162,6 +162,7 @@ var pqtypeTypes = map[string]struct{}{ func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) { pkg := make(map[ImportSpec]struct{}) + pkg[ImportSpec{Path: "github.com/deferred-dev/deferred/lib"}] = struct{}{} std := make(map[string]struct{}) if uses("sql.Null") { @@ -298,16 +299,12 @@ func sortedImports(std map[string]struct{}, pkg map[ImportSpec]struct{}) fileImp func (i *importer) queryImports(filename string) fileImports { var gq []Query - anyNonCopyFrom := false for _, query := range i.Queries { if usesBatch([]Query{query}) { continue } if query.SourceName == filename { gq = append(gq, query) - if query.Cmd != metadata.CmdCopyFrom { - anyNonCopyFrom = true - } } } @@ -390,10 +387,6 @@ func (i *importer) queryImports(filename string) fileImports { return false } - if anyNonCopyFrom { - std["context"] = struct{}{} - } - sqlpkg := parseDriver(i.Options.SqlPackage) if sqlcSliceScan() && !sqlpkg.IsPGX() { std["strings"] = struct{}{} diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 3b4fb2fa1a..71a172187b 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -23,19 +23,19 @@ type QueryValue struct { Column *plugin.Column } -func (v QueryValue) EmitStruct() bool { +func (v *QueryValue) EmitStruct() bool { return v.Emit } -func (v QueryValue) IsStruct() bool { +func (v *QueryValue) IsStruct() bool { return v.Struct != nil } -func (v QueryValue) IsPointer() bool { +func (v *QueryValue) IsPointer() bool { return v.EmitPointer && v.Struct != nil } -func (v QueryValue) isEmpty() bool { +func (v *QueryValue) isEmpty() bool { return v.Typ == "" && v.Name == "" && v.Struct == nil } @@ -44,7 +44,15 @@ type Argument struct { Type string } -func (v QueryValue) Pair() string { +func (v *QueryValue) Names() string { + var out []string + for _, arg := range v.Pairs() { + out = append(out, arg.Name) + } + return strings.Join(out, ",") +} + +func (v *QueryValue) Pair() string { var out []string for _, arg := range v.Pairs() { out = append(out, arg.Name+" "+arg.Type) @@ -54,7 +62,7 @@ func (v QueryValue) Pair() string { // Return the argument name and type for query methods. Should only be used in // the context of method arguments. -func (v QueryValue) Pairs() []Argument { +func (v *QueryValue) Pairs() []Argument { if v.isEmpty() { return nil } @@ -76,14 +84,14 @@ func (v QueryValue) Pairs() []Argument { } } -func (v QueryValue) SlicePair() string { +func (v *QueryValue) SlicePair() string { if v.isEmpty() { return "" } return v.Name + " []" + v.DefineType() } -func (v QueryValue) Type() string { +func (v *QueryValue) Type() string { if v.Typ != "" { return v.Typ } @@ -108,7 +116,7 @@ func (v *QueryValue) ReturnName() string { return escape(v.Name) } -func (v QueryValue) UniqueFields() []Field { +func (v *QueryValue) UniqueFields() []Field { seen := map[string]struct{}{} fields := make([]Field, 0, len(v.Struct.Fields)) @@ -123,7 +131,25 @@ func (v QueryValue) UniqueFields() []Field { return fields } -func (v QueryValue) Params() string { +func (v *QueryValue) Fields() []Field { + if v.Struct == nil { + if v.Typ == "" { + return nil + } + return []Field{ + { + Name: v.Name, + VariableForField: v.Name, + DBName: v.DBName, + Type: v.Type(), + Column: v.Column, + }, + } + } + return v.Struct.Fields +} + +func (v *QueryValue) Params() string { if v.isEmpty() { return "" } @@ -150,7 +176,7 @@ func (v QueryValue) Params() string { return "\n" + strings.Join(out, ",\n") } -func (v QueryValue) ColumnNames() []string { +func (v *QueryValue) ColumnNames() []string { if v.Struct == nil { return []string{v.DBName} } @@ -161,7 +187,7 @@ func (v QueryValue) ColumnNames() []string { return names } -func (v QueryValue) ColumnNamesAsGoSlice() string { +func (v *QueryValue) ColumnNamesAsGoSlice() string { if v.Struct == nil { return fmt.Sprintf("[]string{%q}", v.DBName) } @@ -178,7 +204,7 @@ func (v QueryValue) ColumnNamesAsGoSlice() string { // When true, we have to build the arguments to q.db.QueryContext in addition to // munging the SQL -func (v QueryValue) HasSqlcSlices() bool { +func (v *QueryValue) HasSqlcSlices() bool { if v.Struct == nil { return v.Column != nil && v.Column.IsSqlcSlice } @@ -190,7 +216,7 @@ func (v QueryValue) HasSqlcSlices() bool { return false } -func (v QueryValue) Scan() string { +func (v *QueryValue) Scan() string { var out []string if v.Struct == nil { if strings.HasPrefix(v.Typ, "[]") && v.Typ != "[]byte" && !v.SQLDriver.IsPGX() { @@ -230,7 +256,7 @@ func (v QueryValue) Scan() string { // Deprecated: This method does not respect the Emit field set on the // QueryValue. It's used by the go-sql-driver-mysql/copyfromCopy.tmpl and should // not be used other places. -func (v QueryValue) CopyFromMySQLFields() []Field { +func (v *QueryValue) CopyFromMySQLFields() []Field { // fmt.Printf("%#v\n", v) if v.Struct != nil { return v.Struct.Fields @@ -240,11 +266,12 @@ func (v QueryValue) CopyFromMySQLFields() []Field { Name: v.Name, DBName: v.DBName, Type: v.Typ, + Column: v.Column, }, } } -func (v QueryValue) VariableForField(f Field) string { +func (v *QueryValue) VariableForField(f Field) string { if !v.IsStruct() { return v.Name } @@ -269,13 +296,24 @@ type Query struct { Table *plugin.Identifier } -func (q Query) hasRetType() bool { +func (q *Query) IsReadOnly() bool { + return q.Cmd != metadata.CmdExec && strings.EqualFold(q.SQL[:6], "select") +} + +func (q *Query) ReceiverType() string { + if q.IsReadOnly() { + return "Queries" + } + return "WriteTx" +} + +func (q *Query) hasRetType() bool { scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany || q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne return scanned && !q.Ret.isEmpty() } -func (q Query) TableIdentifierAsGoSlice() string { +func (q *Query) TableIdentifierAsGoSlice() string { escapedNames := make([]string, 0, 3) for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} { if p != "" { @@ -285,7 +323,7 @@ func (q Query) TableIdentifierAsGoSlice() string { return "[]string{" + strings.Join(escapedNames, ", ") + "}" } -func (q Query) TableIdentifierForMySQL() string { +func (q *Query) TableIdentifierForMySQL() string { escapedNames := make([]string, 0, 3) for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} { if p != "" { diff --git a/internal/codegen/golang/sqlite_type.go b/internal/codegen/golang/sqlite_type.go index 92c13557f6..7eb1215329 100644 --- a/internal/codegen/golang/sqlite_type.go +++ b/internal/codegen/golang/sqlite_type.go @@ -10,90 +10,61 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) -func sqliteType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string { +func sqliteType(_req *plugin.GenerateRequest, _options *opts.Options, col *plugin.Column) string { dt := strings.ToLower(sdk.DataType(col.Type)) notNull := col.NotNull || col.IsArray - emitPointersForNull := options.EmitPointersForNullTypes switch dt { - - case "int", "integer", "tinyint", "smallint", "mediumint", "bigint", "unsignedbigint", "int2", "int8": + case "int", "integer", "tinyint", "smallint", "mediumint", "int2": if notNull { - return "int64" + return "uint32" } - if emitPointersForNull { - return "*int64" + return "types.NullUint32" + case "bigint", "unsignedbigint", "int8": + if notNull { + return "int64" } - return "sql.NullInt64" - + return "types.NullInt64" case "blob": return "[]byte" - case "real", "double", "doubleprecision", "float": if notNull { return "float64" } - if emitPointersForNull { - return "*float64" - } return "sql.NullFloat64" - case "boolean", "bool": if notNull { return "bool" } - if emitPointersForNull { - return "*bool" - } return "sql.NullBool" - case "date", "datetime", "timestamp": if notNull { return "time.Time" } - if emitPointersForNull { - return "*time.Time" - } return "sql.NullTime" - case "any": return "interface{}" - } switch { - - case strings.HasPrefix(dt, "character"), + case dt == "text", + strings.HasPrefix(dt, "character"), strings.HasPrefix(dt, "varchar"), strings.HasPrefix(dt, "varyingcharacter"), strings.HasPrefix(dt, "nchar"), strings.HasPrefix(dt, "nativecharacter"), strings.HasPrefix(dt, "nvarchar"), - dt == "text", dt == "clob": - if notNull { - return "string" - } - if emitPointersForNull { - return "*string" - } - return "sql.NullString" - + return "string" case strings.HasPrefix(dt, "decimal"), dt == "numeric": if notNull { return "float64" } - if emitPointersForNull { - return "*float64" - } - return "sql.NullFloat64" - + return "types.NullFloat64" default: if debug.Active { log.Printf("unknown SQLite type: %s\n", dt) } - return "interface{}" - } } diff --git a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl deleted file mode 100644 index e21475b148..0000000000 --- a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl +++ /dev/null @@ -1,52 +0,0 @@ -{{define "copyfromCodeGoSqlDriver"}} -{{range .GoQueries}} -{{if eq .Cmd ":copyfrom" }} -var readerHandlerSequenceFor{{.MethodName}} uint32 = 1 - -func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) { - e := mysqltsv.NewEncoder(w, {{ len .Arg.CopyFromMySQLFields }}, nil) - for _, row := range {{.Arg.Name}} { -{{- with $arg := .Arg }} -{{- range $arg.CopyFromMySQLFields}} -{{- if eq .Type "string"}} - e.AppendString({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- else if or (eq .Type "[]byte") (eq .Type "json.RawMessage")}} - e.AppendBytes({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- else}} - e.AppendValue({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) -{{- end}} -{{- end}} -{{- end}} - } - w.CloseWithError(e.Close()) -} - -{{range .Comments}}//{{.}} -{{end -}} -// {{.MethodName}} uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. -// -// Errors and duplicate keys are treated as warnings and insertion will -// continue, even without an error for some cases. Use this in a transaction -// and use SHOW WARNINGS to check for any problems and roll back if you want to. -// -// Check the documentation for more information: -// https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling -func (q *Queries) {{.MethodName}}(ctx context.Context{{if $.EmitMethodsWithDBArgument}}, db DBTX{{end}}, {{.Arg.SlicePair}}) (int64, error) { - pr, pw := io.Pipe() - defer pr.Close() - rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1)) - mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) - defer mysql.DeregisterReaderHandler(rh) - go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}}) - // The string interpolation is necessary because LOAD DATA INFILE requires - // the file name to be given as a literal string. - result, err := {{if (not $.EmitMethodsWithDBArgument)}}q.{{end}}db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) - if err != nil { - return 0, err - } - return result.RowsAffected() -} - -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl deleted file mode 100644 index 35bd701bd3..0000000000 --- a/internal/codegen/golang/templates/pgx/batchCode.tmpl +++ /dev/null @@ -1,134 +0,0 @@ -{{define "batchCodePgx"}} - -var ( - ErrBatchAlreadyClosed = errors.New("batch already closed") -) - -{{range .GoQueries}} -{{if eq (hasPrefix .Cmd ":batch") true }} -const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} -{{escape .SQL}} -{{$.Q}} - -type {{.MethodName}}BatchResults struct { - br pgx.BatchResults - tot int - closed bool -} - -{{if .Arg.Struct}} -type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if .Ret.EmitStruct}} -type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDBArgument}}db DBTX,{{end}} {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults { - batch := &pgx.Batch{} - for _, a := range {{index .Arg.Name}} { - vals := []interface{}{ - {{- if .Arg.Struct }} - {{- range .Arg.Struct.Fields }} - a.{{.Name}}, - {{- end }} - {{- else }} - a, - {{- end }} - } - batch.Queue({{.ConstantName}}, vals...) - } - br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) - return &{{.MethodName}}BatchResults{br,len({{.Arg.Name}}),false} -} - -{{if eq .Cmd ":batchexec"}} -func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { - defer b.br.Close() - for t := 0; t < b.tot; t++ { - if b.closed { - if f != nil { - f(t, ErrBatchAlreadyClosed) - } - continue - } - _, err := b.br.Exec() - if f != nil { - f(t, err) - } - } -} -{{end}} - -{{if eq .Cmd ":batchmany"}} -func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { - defer b.br.Close() - for t := 0; t < b.tot; t++ { - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - if b.closed { - if f != nil { - f(t, items, ErrBatchAlreadyClosed) - } - continue - } - err := func() error { - rows, err := b.br.Query() - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - return err - } - items = append(items, {{.Ret.ReturnName}}) - } - return rows.Err() - }() - if f != nil { - f(t, items, err) - } - } -} -{{end}} - -{{if eq .Cmd ":batchone"}} -func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { - defer b.br.Close() - for t := 0; t < b.tot; t++ { - var {{.Ret.Name}} {{.Ret.Type}} - if b.closed { - if f != nil { - f(t, {{if .Ret.IsPointer}}nil{{else}}{{.Ret.Name}}{{end}}, ErrBatchAlreadyClosed) - } - continue - } - row := b.br.QueryRow() - err := row.Scan({{.Ret.Scan}}) - if f != nil { - f(t, {{.Ret.ReturnName}}, err) - } - } -} -{{end}} - -func (b *{{.MethodName}}BatchResults) Close() error { - b.closed = true - return b.br.Close() -} -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl b/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl deleted file mode 100644 index c1cfa68d1d..0000000000 --- a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl +++ /dev/null @@ -1,51 +0,0 @@ -{{define "copyfromCodePgx"}} -{{range .GoQueries}} -{{if eq .Cmd ":copyfrom" }} -// iteratorFor{{.MethodName}} implements pgx.CopyFromSource. -type iteratorFor{{.MethodName}} struct { - rows []{{.Arg.DefineType}} - skippedFirstNextCall bool -} - -func (r *iteratorFor{{.MethodName}}) Next() bool { - if len(r.rows) == 0 { - return false - } - if !r.skippedFirstNextCall { - r.skippedFirstNextCall = true - return true - } - r.rows = r.rows[1:] - return len(r.rows) > 0 -} - -func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) { - return []interface{}{ -{{- if .Arg.Struct }} -{{- range .Arg.Struct.Fields }} - r.rows[0].{{.Name}}, -{{- end }} -{{- else }} - r.rows[0], -{{- end }} - }, nil -} - -func (r iteratorFor{{.MethodName}}) Err() error { - return nil -} - -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { - return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { - return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) -{{- end}} -} - -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl deleted file mode 100644 index 236554d9f2..0000000000 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ /dev/null @@ -1,37 +0,0 @@ -{{define "dbCodeTemplatePgx"}} - -type DBTX interface { - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) - Query(context.Context, string, ...interface{}) (pgx.Rows, error) - QueryRow(context.Context, string, ...interface{}) pgx.Row -{{- if .UsesCopyFrom }} - CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) -{{- end }} -{{- if .UsesBatch }} - SendBatch(context.Context, *pgx.Batch) pgx.BatchResults -{{- end }} -} - -{{ if .EmitMethodsWithDBArgument}} -func New() *Queries { - return &Queries{} -{{- else -}} -func New(db DBTX) *Queries { - return &Queries{db: db} -{{- end}} -} - -type Queries struct { - {{if not .EmitMethodsWithDBArgument}} - db DBTX - {{end}} -} - -{{if not .EmitMethodsWithDBArgument}} -func (q *Queries) WithTx(tx pgx.Tx) *Queries { - return &Queries{ - db: tx, - } -} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl deleted file mode 100644 index cf7cd36cb9..0000000000 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ /dev/null @@ -1,73 +0,0 @@ -{{define "interfaceCodePgx"}} - type Querier interface { - {{- $dbtxParam := .EmitMethodsWithDBArgument -}} - {{- range .GoQueries}} - {{- if and (eq .Cmd ":one") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":one" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":many") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":many" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":exec") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error - {{- else if eq .Cmd ":exec" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error - {{- end}} - {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":execrows" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) - {{- end}} - {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) - {{- else if eq .Cmd ":execresult" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) - {{- end}} - {{- if and (eq .Cmd ":copyfrom") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) - {{- else if eq .Cmd ":copyfrom" }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) - {{- end}} - {{- if and (or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone")) ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults - {{- else if or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone") }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults - {{- end}} - - {{- end}} - } - - var _ Querier = (*Queries)(nil) -{{end}} diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl deleted file mode 100644 index 18de5db2ba..0000000000 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ /dev/null @@ -1,124 +0,0 @@ -{{define "queryCodePgx"}} -{{range .GoQueries}} -{{if $.OutputQuery .SourceName}} -{{if and (ne .Cmd ":copyfrom") (ne (hasPrefix .Cmd ":batch") true)}} -const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} -{{escape .SQL}} -{{$.Q}} -{{end}} - -{{if ne (hasPrefix .Cmd ":batch") true}} -{{if .Arg.EmitStruct}} -type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if .Ret.EmitStruct}} -type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} -{{end}} - -{{if eq .Cmd ":one"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - row := q.db.QueryRow(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} - var {{.Ret.Name}} {{.Ret.Type}} - {{- end}} - err := row.Scan({{.Ret.Scan}}) - return {{.Ret.ReturnName}}, err -} -{{end}} - -{{if eq .Cmd ":many"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { - rows, err := db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { - rows, err := q.db.Query(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - if err != nil { - return nil, err - } - defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - return nil, err - } - items = append(items, {{.Ret.ReturnName}}) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} -{{end}} - -{{if eq .Cmd ":exec"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error { - _, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error { - _, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - return err -} -{{end}} - -{{if eq .Cmd ":execrows"}} -{{range .Comments}}//{{.}} -{{end -}} -{{if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) { - result, err := db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) { - result, err := q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} - if err != nil { - return 0, err - } - return result.RowsAffected(), nil -} -{{end}} - -{{if eq .Cmd ":execresult"}} -{{range .Comments}}//{{.}} -{{end -}} -{{- if $.EmitMethodsWithDBArgument -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (pgconn.CommandTag, error) { - return db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- else -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error) { - return q.db.Exec(ctx, {{.ConstantName}}, {{.Arg.Params}}) -{{- end}} -} -{{end}} - - -{{end}} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl new file mode 100644 index 0000000000..b428f42db9 --- /dev/null +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -0,0 +1,57 @@ +{{define "dbCodeTemplateStd"}} + +type Queries struct { + conn *sqlite.Conn + {{- range .ReadQueries }} + {{.FieldName}} {{.MethodName}}Stmt + {{- end}} +} + +type WriteTx struct { + Queries + finish func(*error) + {{- range .WriteQueries }} + {{.FieldName}} {{.MethodName}}Stmt + {{- end}} +} + +func (q *Queries) init(conn *sqlite.Conn) { + if q.conn != nil { + panic("Queries already initialized") + } + q.conn = conn +} + +func (q *Queries) Close() *lib.Error { + var err, firstErr *lib.Error + {{- range .ReadQueries }} + q.{{.FieldName}}.Close() + err = q.{{.FieldName}}.Err() + if firstErr == nil { + firstErr = err + } + {{- end}} + extErr := q.conn.Close() + if firstErr == nil && extErr != nil { + firstErr = lib.Wrap(extErr) + } + return firstErr +} + +func (txn *WriteTx) Close() *lib.Error { + var err, firstErr *lib.Error + {{- range .WriteQueries }} + txn.{{.FieldName}}.Close() + err = txn.{{.FieldName}}.Err() + if firstErr == nil { + firstErr = err + } + {{- end}} + err = txn.Queries.Close() + if firstErr == nil { + firstErr = err + } + return firstErr +} + +{{end}} diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl new file mode 100644 index 0000000000..1cc2819961 --- /dev/null +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -0,0 +1,211 @@ +{{define "queryCodeStd"}} +{{range $.AllQueries .SourceName}} +const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} +{{escape .SQL}} +{{$.Q}} + +type {{ .MethodName }}Stmt struct { + iter + {{- if ne .Cmd ":exec" }} + Row {{.Ret.DefineType}} + {{- end }} +} + +{{if .Arg.EmitStruct}} +type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} + {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} + {{- end}} +} +{{end}} + +{{if .Ret.EmitStruct}} +type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} + {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} + {{- end}} +} +{{end}} + +func (r *{{ .MethodName }}Stmt) bind({{ .Arg.Pair }}) { + {{- $arg := .Arg }} + query := {{.ConstantName}} + {{- if .Arg.HasSqlcSlices }} + {{- range .Arg.Fields }} + {{- if .HasSqlcSlice }} + length := len({{ .VariableForField }}) + if length > 0 { + query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", length)[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) + } + {{- end }} + {{- end }} + {{- end }} + + r.init(r.queries, "{{.MethodName}}", query, {{ .Arg.HasSqlcSlices }}) + + param := 1 + {{- range .Arg.Fields }} + {{- if .HasSqlcSlice }} + for _, v := range {{ .VariableForField }} { + {{- template "bindField" . }} + } + {{- else }} + {{- template "bindField" . }} + {{- end }} + {{- end }} +} + +func (r *{{ .MethodName }}Stmt) Next() bool { + if !r.iter.Next() { + return false + } + if r.err != nil { + return false + } + {{- if ne .Cmd ":exec" }} + {{ $fields := .Ret.Fields }} + {{- if $fields }} + col := 1 + {{- range $fields }} + {{- template "fetchColumn" . }} + {{- end }} + {{- end }} + {{- end }} + return r.err == nil +} + +{{if eq .Cmd ":one"}} +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.DefineType}}, err *lib.Error) { + r := &q.{{ .FieldName }} + r.bind({{ .Arg.Names }}) + defer r.Reset() + ok := r.Next() + err = r.Err() + if err != nil { + return + } + if !ok { + return ErrNoRows + } + if r.Next() { + return ErrTooManyRows + } + result = r.Row + return +} +{{end}} + +{{if eq .Cmd ":many"}} +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret.DefineType}}, err *lib.Error) { + r := &q.{{ .FieldName }} + r.bind({{ .Arg.Names }}) + defer r.Reset() + for r.Next() { + results = append(results, r.Row) + } + err = r.Err() + if err != nil { + results = nil + } + return +} +{{end}} + +{{if eq .Cmd ":iter"}} +func (q *{{ .ReceiverType }}) Begin{{.MethodName}}({{ .Arg.Pair }}) func(yield func(row {{.Ret.DefineType}}) bool) { + r := &q.{{ .FieldName }} + r.bind({{ .Arg.Names }}) + return func(yield func(row {{.Ret.DefineType}}) bool) { + defer r.Reset() + for r.Next() && yield(r.Row) {} + } +} + +func (q *{{ .ReceiverType }}) End{{.MethodName}}() *lib.Error { + return q.{{ .FieldName }}.Err() +} +{{end}} + +{{if eq .Cmd ":exec"}} +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { + r := &q.{{ .FieldName }} + r.bind({{ .Arg.Names }}) + defer r.Reset() + r.Next() + return r.Err() +} +{{end}} + +{{end}} +{{end}} + +{{define "bindField"}} + { + {{- $bindType := .BindType }} + var v {{ $bindType }} + isNull := false + {{- if .Serialize }} + var err *lib.Error + {{- if eq $bindType "string" }} + v, isNull, err = {{ .VariableForField }}.SerializeString() + if err != nil { + r.setErr(err) + } + {{- else if eq $bindType "bool" }} + v, isNull = {{ .VariableForField }}.SerializeBool() + {{- else if eq $bindType "float64" }} + v, isNull = {{ .VariableForField }}.SerializeFloat() + {{- else if eq $bindType "int64" }} + v, isNull = {{ .VariableForField }}.SerializeInt() + {{- else }} + v, isNull, err = {{ .VariableForField }}.SerializeBytes() + if err != nil { + r.setErr(err) + } + {{- end }} + {{- else if .HasLen }} + v = {{ .VariableForField }} + {{- if .IsNullable }} + isNull = len(v) == 0 + {{- end }} + {{- else }} + {{- if eq $bindType "bool" }} + v = {{ .VariableForField }} + {{- else if not .Is64Bit }} + v = {{ $bindType }}({{ .VariableForField }}) + {{- else }} + v = {{ .VariableForField }} + {{- end }} + {{- end }} + if isNull { + r.stmt.BindNull(param) + } else { + {{ .BindMethod }}(param, v) + } + } + param++ +{{end}} + +{{define "fetchColumn"}} + {{- if .IsNullable }} + if r.stmt.ColumnIsNull(col) { + var v {{ .Type }} + r.Row.{{ .Name }} = v + } else { + {{- else }} + { + {{- end }} + raw := {{ .FetchMethod }}(col) + {{- if .Serialize }} + res, err := Deserialize{{ .Type }}(raw) + if err != nil { + r.setErr(err) + } + {{- else if not (or .Is64Bit (eq .BindType "bool")) }} + r.Row.{{ .Name }} = {{ .Type }}(v) + {{- else }} + r.Row.{{ .Name }} = v + {{- end }} + } + col++ +{{end}} \ No newline at end of file diff --git a/internal/codegen/golang/templates/stdlib/dbCode.tmpl b/internal/codegen/golang/templates/stdlib/dbCode.tmpl deleted file mode 100644 index 7433d522f6..0000000000 --- a/internal/codegen/golang/templates/stdlib/dbCode.tmpl +++ /dev/null @@ -1,105 +0,0 @@ -{{define "dbCodeTemplateStd"}} -type DBTX interface { - ExecContext(context.Context, string, ...interface{}) (sql.Result, error) - PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...interface{}) *sql.Row -} - -{{ if .EmitMethodsWithDBArgument}} -func New() *Queries { - return &Queries{} -{{- else -}} -func New(db DBTX) *Queries { - return &Queries{db: db} -{{- end}} -} - -{{if .EmitPreparedQueries}} -func Prepare(ctx context.Context, db DBTX) (*Queries, error) { - q := Queries{db: db} - var err error - {{- if eq (len .GoQueries) 0 }} - _ = err - {{- end }} - {{- range .GoQueries }} - if q.{{.FieldName}}, err = db.PrepareContext(ctx, {{.ConstantName}}); err != nil { - return nil, fmt.Errorf("error preparing query {{.MethodName}}: %w", err) - } - {{- end}} - return &q, nil -} - -func (q *Queries) Close() error { - var err error - {{- range .GoQueries }} - if q.{{.FieldName}} != nil { - if cerr := q.{{.FieldName}}.Close(); cerr != nil { - err = fmt.Errorf("error closing {{.FieldName}}: %w", cerr) - } - } - {{- end}} - return err -} - -func (q *Queries) exec(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (sql.Result, error) { - switch { - case stmt != nil && q.tx != nil: - return q.tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) - case stmt != nil: - return stmt.ExecContext(ctx, args...) - default: - return q.db.ExecContext(ctx, query, args...) - } -} - -func (q *Queries) query(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Rows, error) { - switch { - case stmt != nil && q.tx != nil: - return q.tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) - case stmt != nil: - return stmt.QueryContext(ctx, args...) - default: - return q.db.QueryContext(ctx, query, args...) - } -} - -func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, args ...interface{}) (*sql.Row) { - switch { - case stmt != nil && q.tx != nil: - return q.tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) - case stmt != nil: - return stmt.QueryRowContext(ctx, args...) - default: - return q.db.QueryRowContext(ctx, query, args...) - } -} -{{end}} - -type Queries struct { - {{- if not .EmitMethodsWithDBArgument}} - db DBTX - {{- end}} - - {{- if .EmitPreparedQueries}} - tx *sql.Tx - {{- range .GoQueries}} - {{.FieldName}} *sql.Stmt - {{- end}} - {{- end}} -} - -{{if not .EmitMethodsWithDBArgument}} -func (q *Queries) WithTx(tx *sql.Tx) *Queries { - return &Queries{ - db: tx, - {{- if .EmitPreparedQueries}} - tx: tx, - {{- range .GoQueries}} - {{.FieldName}}: q.{{.FieldName}}, - {{- end}} - {{- end}} - } -} -{{end}} -{{end}} diff --git a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl b/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl deleted file mode 100644 index 3cbefe6df4..0000000000 --- a/internal/codegen/golang/templates/stdlib/interfaceCode.tmpl +++ /dev/null @@ -1,63 +0,0 @@ -{{define "interfaceCodeStd"}} - type Querier interface { - {{- $dbtxParam := .EmitMethodsWithDBArgument -}} - {{- range .GoQueries}} - {{- if and (eq .Cmd ":one") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":one"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":many") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- else if eq .Cmd ":many"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) - {{- end}} - {{- if and (eq .Cmd ":exec") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) error - {{- else if eq .Cmd ":exec"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) error - {{- end}} - {{- if and (eq .Cmd ":execrows") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":execrows"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) - {{- end}} - {{- if and (eq .Cmd ":execlastid") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error) - {{- else if eq .Cmd ":execlastid"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error) - {{- end}} - {{- if and (eq .Cmd ":execresult") ($dbtxParam) }} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (sql.Result, error) - {{- else if eq .Cmd ":execresult"}} - {{range .Comments}}//{{.}} - {{end -}} - {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (sql.Result, error) - {{- end}} - {{- end}} - } - - var _ Querier = (*Queries)(nil) -{{end}} diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl deleted file mode 100644 index cf56000ec6..0000000000 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ /dev/null @@ -1,155 +0,0 @@ -{{define "queryCodeStd"}} -{{range .GoQueries}} -{{if $.OutputQuery .SourceName}} -const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} -{{escape .SQL}} -{{$.Q}} - -{{if .Arg.EmitStruct}} -type {{.Arg.Type}} struct { {{- range .Arg.UniqueFields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if .Ret.EmitStruct}} -type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} - {{.Name}} {{.Type}} {{if .Tag}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} - {{- end}} -} -{{end}} - -{{if eq .Cmd ":one"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { - {{- template "queryCodeStdExec" . }} - {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} - var {{.Ret.Name}} {{.Ret.Type}} - {{- end}} - err := row.Scan({{.Ret.Scan}}) - return {{.Ret.ReturnName}}, err -} -{{end}} - -{{if eq .Cmd ":many"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { - {{- template "queryCodeStdExec" . }} - if err != nil { - return nil, err - } - defer rows.Close() - {{- if $.EmitEmptySlices}} - items := []{{.Ret.DefineType}}{} - {{else}} - var items []{{.Ret.DefineType}} - {{end -}} - for rows.Next() { - var {{.Ret.Name}} {{.Ret.Type}} - if err := rows.Scan({{.Ret.Scan}}); err != nil { - return nil, err - } - items = append(items, {{.Ret.ReturnName}}) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} -{{end}} - -{{if eq .Cmd ":exec"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { - {{- template "queryCodeStdExec" . }} - return err -} -{{end}} - -{{if eq .Cmd ":execrows"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { - {{- template "queryCodeStdExec" . }} - if err != nil { - return 0, err - } - return result.RowsAffected() -} -{{end}} - -{{if eq .Cmd ":execlastid"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { - {{- template "queryCodeStdExec" . }} - if err != nil { - return 0, err - } - return result.LastInsertId() -} -{{end}} - -{{if eq .Cmd ":execresult"}} -{{range .Comments}}//{{.}} -{{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { - {{- template "queryCodeStdExec" . }} -} -{{end}} - -{{end}} -{{end}} -{{end}} - -{{define "queryCodeStdExec"}} - {{- if .Arg.HasSqlcSlices }} - query := {{.ConstantName}} - var queryParams []interface{} - {{- if .Arg.Struct }} - {{- $arg := .Arg }} - {{- range .Arg.Struct.Fields }} - {{- if .HasSqlcSlice }} - if len({{$arg.VariableForField .}}) > 0 { - for _, v := range {{$arg.VariableForField .}} { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{$arg.VariableForField .}}))[1:], 1) - } else { - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) - } - {{- else }} - queryParams = append(queryParams, {{$arg.VariableForField .}}) - {{- end }} - {{- end }} - {{- else }} - {{- /* Single argument parameter to this goroutine (they are not packed - in a struct), because .Arg.HasSqlcSlices further up above was true, - this section is 100% a slice (impossible to get here otherwise). - */}} - if len({{.Arg.Name}}) > 0 { - for _, v := range {{.Arg.Name}} { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", strings.Repeat(",?", len({{.Arg.Name}}))[1:], 1) - } else { - query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL", 1) - } - {{- end }} - {{- if emitPreparedQueries }} - {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) - {{- else}} - {{ queryRetval . }} {{ queryMethod . }}(ctx, query, queryParams...) - {{- end -}} - {{- else if emitPreparedQueries }} - {{- queryRetval . }} {{ queryMethod . }}(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) - {{- else}} - {{- queryRetval . }} {{ queryMethod . }}(ctx, {{.ConstantName}}, {{.Arg.Params}}) - {{- end -}} -{{end}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index afd50c01ac..702b8f0624 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -2,7 +2,7 @@ {{if .BuildTags}} //go:build {{.BuildTags}} -{{end}}// Code generated by sqlc. DO NOT EDIT. +{{end}}// Code generated by deferred-dev fork of sqlc. DO NOT EDIT. {{if not .OmitSqlcVersion}}// versions: // sqlc {{.SqlcVersion}} {{end}} @@ -35,7 +35,7 @@ import ( {{if .BuildTags}} //go:build {{.BuildTags}} -{{end}}// Code generated by sqlc. DO NOT EDIT. +{{end}}// Code generated by deferred-dev fork of sqlc. DO NOT EDIT. {{if not .OmitSqlcVersion}}// versions: // sqlc {{.SqlcVersion}} {{end}} @@ -66,7 +66,7 @@ import ( {{if .BuildTags}} //go:build {{.BuildTags}} -{{end}}// Code generated by sqlc. DO NOT EDIT. +{{end}}// Code generated by deferred-dev fork of sqlc. DO NOT EDIT. {{if not .OmitSqlcVersion}}// versions: // sqlc {{.SqlcVersion}} {{end}} @@ -166,7 +166,7 @@ type {{.Name}} struct { {{- range .Fields}} {{if .BuildTags}} //go:build {{.BuildTags}} -{{end}}// Code generated by sqlc. DO NOT EDIT. +{{end}}// Code generated by deferred-dev fork of sqlc. DO NOT EDIT. {{if not .OmitSqlcVersion}}// versions: // sqlc {{.SqlcVersion}} {{end}}// source: {{.SourceName}} @@ -197,7 +197,7 @@ import ( {{if .BuildTags}} //go:build {{.BuildTags}} -{{end}}// Code generated by sqlc. DO NOT EDIT. +{{end}}// Code generated by deferred-dev fork of sqlc. DO NOT EDIT. {{if not .OmitSqlcVersion}}// versions: // sqlc {{.SqlcVersion}} {{end}}// source: {{.SourceName}} @@ -228,7 +228,7 @@ import ( {{if .BuildTags}} //go:build {{.BuildTags}} -{{end}}// Code generated by sqlc. DO NOT EDIT. +{{end}}// Code generated by deferred-dev fork of sqlc. DO NOT EDIT. {{if not .OmitSqlcVersion}}// versions: // sqlc {{.SqlcVersion}} {{end}}// source: {{.SourceName}} From 62ab9f701fa1eab749a4ee0992c849c44820ed2f Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 27 Oct 2024 11:54:10 -0600 Subject: [PATCH 02/18] get the new code generation working with existing deferred code base --- internal/codegen/golang/field.go | 64 +++++++++- internal/codegen/golang/gen.go | 60 +++++---- internal/codegen/golang/imports.go | 7 +- internal/codegen/golang/query.go | 6 +- .../golang/templates/sqlite/dbCode.tmpl | 2 + .../golang/templates/sqlite/queryCode.tmpl | 117 +++++++++++------- 6 files changed, 184 insertions(+), 72 deletions(-) diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 2959bd1c68..ce3e678b31 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -35,7 +35,7 @@ func (gf *Field) IsNullable() bool { } func (gf *Field) Serialize() bool { - return strings.HasPrefix(gf.Type, "*") || strings.HasSuffix(gf.Type, "ID") + return gf.IsPointer() || strings.HasSuffix(gf.Type, "ID") || strings.HasSuffix(gf.Type, "Blob") } func (gf *Field) HasLen() bool { @@ -46,6 +46,18 @@ func (gf *Field) Is64Bit() bool { return strings.HasSuffix(gf.Type, "64") } +func (gf *Field) IsPointer() bool { + return strings.HasPrefix(gf.Type, "*") +} + +func (gf *Field) NeedsCast(toType string) bool { + ty := gf.Type + if gf.HasSqlcSlice() { + ty = ty[2:] + } + return ty != toType +} + func (gf *Field) BindType() string { ty := gf.Type if gf.HasSqlcSlice() { @@ -56,12 +68,16 @@ func (gf *Field) BindType() string { return "int64" case ty == "float32": return "float64" - case ty == "TimeOrderedID": + case strings.HasSuffix(ty, "TimeOrderedID"): return "[]byte" case strings.HasSuffix(ty, "Bool"): return "bool" - case strings.HasSuffix(ty, "ID"): + case strings.Contains(ty, "Null"): + return "int64" + case strings.HasSuffix(ty, "ID"), strings.HasSuffix(ty, "Type"), strings.HasSuffix(ty, "Priority"): return "int64" + case strings.IndexByte(ty, '.') >= 0: + return "[]byte" default: return ty } @@ -84,8 +100,14 @@ func (gf *Field) BindMethod() string { func (gf *Field) FetchMethod() string { bindType := gf.BindType() switch bindType { + case "string": + return "r.stmt.ColumnText" case "[]byte": - return "r.stmt.ColumnBytes" + if gf.Serialize() { + // We can use zero-copy bytes to deserialize the field + return "r.columnPeekBytes" + } + return "r.columnBytes" case "float64": return "r.stmt.ColumnFloat" default: @@ -93,6 +115,40 @@ func (gf *Field) FetchMethod() string { } } +func (gf *Field) FetchInto() string { + if gf.Name == "" { + return "r.Row" + } + return "r.Row." + gf.Name +} + +func (gf *Field) DeserializeMethod() string { + ty := gf.Type + if gf.HasSqlcSlice() { + ty = ty[2:] + } else if strings.HasPrefix(ty, "*") { + ty = ty[1:] + } + i := strings.IndexByte(ty, '.') + if i >= 0 { + return ty[:i] + ".Deserialize" + toPascalCase(ty[i+1:]) + } + panic("not a deserializable type: " + ty) +} + +func (gf *Field) WithVariable(v string) *Field { + return &Field{ + Name: gf.Name, + VariableForField: v, + DBName: gf.DBName, + Type: gf.Type, + Tags: gf.Tags, + Comment: gf.Comment, + Column: gf.Column, + EmbedFields: gf.EmbedFields, + } +} + func TagsToString(tags map[string]string) string { if len(tags) == 0 { return "" diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 749bfaa9e1..c028019e88 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -250,37 +250,45 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, output := map[string]string{} - execute := func(name, templateName string) error { - imports := i.Imports(name) - replacedQueries := replaceConflictedArg(imports, queries) - var readQueries []Query - var writeQueries []Query - for _, q := range replacedQueries { - if q.Arg.Struct != nil { - for i := range q.Arg.Struct.Fields { - f := &q.Arg.Struct.Fields[i] - f.VariableForField = q.Arg.VariableForField(*f) + var readQueries []Query + var writeQueries []Query + for _, q := range queries { + if q.Arg.Struct != nil { + for i := range q.Arg.Struct.Fields { + f := &q.Arg.Struct.Fields[i] + f.VariableForField = q.Arg.VariableForField(*f) + if len(f.EmbedFields) != 0 { + fixEmbedFieldsNames(f.EmbedFields, f.Name) } } - if q.Ret.Struct != nil { - for i := range q.Ret.Struct.Fields { - f := &q.Ret.Struct.Fields[i] - f.VariableForField = q.Ret.VariableForField(*f) + } + if q.Ret.Struct != nil { + for i := range q.Ret.Struct.Fields { + f := &q.Ret.Struct.Fields[i] + f.VariableForField = q.Ret.VariableForField(*f) + if len(f.EmbedFields) != 0 { + fixEmbedFieldsNames(f.EmbedFields, f.Name) } } + } - if q.IsReadOnly() { - readQueries = append(readQueries, q) - } else { - writeQueries = append(writeQueries, q) - } + if q.IsReadOnly() { + readQueries = append(readQueries, q) + } else { + writeQueries = append(writeQueries, q) } + } + + execute := func(name, templateName string) error { + imports := i.Imports(name) + replacedReadQueries := replaceConflictedArg(imports, readQueries) + replacedWriteQueries := replaceConflictedArg(imports, writeQueries) var b bytes.Buffer w := bufio.NewWriter(&b) tctx.SourceName = name - tctx.ReadQueries = readQueries - tctx.WriteQueries = writeQueries + tctx.ReadQueries = replacedReadQueries + tctx.WriteQueries = replacedWriteQueries err := tmpl.ExecuteTemplate(w, templateName, &tctx) w.Flush() if err != nil { @@ -369,6 +377,16 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, return &resp, nil } +func fixEmbedFieldsNames(fields []Field, name string) { + for i := range fields { + f := &fields[i] + f.Name = name + "." + f.Name + if len(f.EmbedFields) != 0 { + fixEmbedFieldsNames(f.EmbedFields, f.Name) + } + } +} + func usesCopyFrom(queries []Query) bool { for _, q := range queries { if q.Cmd == metadata.CmdCopyFrom { diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index ebd8e864b2..e6e0400c19 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -122,7 +122,9 @@ func (i *importer) dbImports() fileImports { pkg := []ImportSpec{ {Path: "github.com/deferred-dev/deferred/lib"}, } - var std []ImportSpec + std := []ImportSpec{ + {Path: "runtime"}, + } sqlpkg := parseDriver(i.Options.SqlPackage) switch sqlpkg { @@ -162,7 +164,6 @@ var pqtypeTypes = map[string]struct{}{ func buildImports(options *opts.Options, queries []Query, uses func(string) bool) (map[string]struct{}, map[ImportSpec]struct{}) { pkg := make(map[ImportSpec]struct{}) - pkg[ImportSpec{Path: "github.com/deferred-dev/deferred/lib"}] = struct{}{} std := make(map[string]struct{}) if uses("sql.Null") { @@ -340,6 +341,8 @@ func (i *importer) queryImports(filename string) fileImports { return false }) + pkg[ImportSpec{Path: "github.com/deferred-dev/deferred/lib"}] = struct{}{} + sliceScan := func() bool { for _, q := range gq { if q.hasRetType() { diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 71a172187b..d32f86146d 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -136,9 +136,13 @@ func (v *QueryValue) Fields() []Field { if v.Typ == "" { return nil } + name := v.Name + if !v.IsStruct() { + name = "" + } return []Field{ { - Name: v.Name, + Name: name, VariableForField: v.Name, DBName: v.DBName, Type: v.Type(), diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index b428f42db9..b88fe6af59 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -2,6 +2,7 @@ type Queries struct { conn *sqlite.Conn + pin runtime.Pinner {{- range .ReadQueries }} {{.FieldName}} {{.MethodName}}Stmt {{- end}} @@ -23,6 +24,7 @@ func (q *Queries) init(conn *sqlite.Conn) { } func (q *Queries) Close() *lib.Error { + q.pin.Unpin() var err, firstErr *lib.Error {{- range .ReadQueries }} q.{{.FieldName}}.Close() diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 1cc2819961..f2f0bca20d 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -25,7 +25,7 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} } {{end}} -func (r *{{ .MethodName }}Stmt) bind({{ .Arg.Pair }}) { +func (r *{{ .MethodName }}Stmt) bind(q *{{ .ReceiverType }}, {{ .Arg.Pair }}) { {{- $arg := .Arg }} query := {{.ConstantName}} {{- if .Arg.HasSqlcSlices }} @@ -41,18 +41,28 @@ func (r *{{ .MethodName }}Stmt) bind({{ .Arg.Pair }}) { {{- end }} {{- end }} - r.init(r.queries, "{{.MethodName}}", query, {{ .Arg.HasSqlcSlices }}) + {{- if .IsReadOnly }} + r.init(q, "{{.MethodName}}", query, {{ .Arg.HasSqlcSlices }}) + {{- else }} + r.init(&q.Queries, "{{.MethodName}}", query, {{ .Arg.HasSqlcSlices }}) + {{- end }} + {{- $fields := .Arg.Fields }} + {{- if $fields }} param := 1 - {{- range .Arg.Fields }} + {{- range $fields }} {{- if .HasSqlcSlice }} - for _, v := range {{ .VariableForField }} { - {{- template "bindField" . }} - } + for _, raw := range {{ .VariableForField }} { + {{ $field := .WithVariable "raw" }} + {{- template "bindField" $field }} + } {{- else }} + { {{- template "bindField" . }} + } {{- end }} {{- end }} + {{- end }} } func (r *{{ .MethodName }}Stmt) Next() bool { @@ -77,7 +87,7 @@ func (r *{{ .MethodName }}Stmt) Next() bool { {{if eq .Cmd ":one"}} func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.DefineType}}, err *lib.Error) { r := &q.{{ .FieldName }} - r.bind({{ .Arg.Names }}) + r.bind(q, {{ .Arg.Names }}) defer r.Reset() ok := r.Next() err = r.Err() @@ -85,12 +95,12 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De return } if !ok { - return ErrNoRows - } - if r.Next() { - return ErrTooManyRows + err = ErrNoRows + } else if r.Next() { + err = ErrTooManyRows + } else { + result = r.Row } - result = r.Row return } {{end}} @@ -98,7 +108,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De {{if eq .Cmd ":many"}} func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret.DefineType}}, err *lib.Error) { r := &q.{{ .FieldName }} - r.bind({{ .Arg.Names }}) + r.bind(q, {{ .Arg.Names }}) defer r.Reset() for r.Next() { results = append(results, r.Row) @@ -114,7 +124,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret {{if eq .Cmd ":iter"}} func (q *{{ .ReceiverType }}) Begin{{.MethodName}}({{ .Arg.Pair }}) func(yield func(row {{.Ret.DefineType}}) bool) { r := &q.{{ .FieldName }} - r.bind({{ .Arg.Names }}) + r.bind(q, {{ .Arg.Names }}) return func(yield func(row {{.Ret.DefineType}}) bool) { defer r.Reset() for r.Next() && yield(r.Row) {} @@ -129,7 +139,7 @@ func (q *{{ .ReceiverType }}) End{{.MethodName}}() *lib.Error { {{if eq .Cmd ":exec"}} func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { r := &q.{{ .FieldName }} - r.bind({{ .Arg.Names }}) + r.bind(q, {{ .Arg.Names }}) defer r.Reset() r.Next() return r.Err() @@ -140,29 +150,40 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { {{end}} {{define "bindField"}} - { + {{- if .EmbedFields }} + {{- range .EmbedFields }} + {{ template "bindField" . }} + {{- end }} + {{- else }} {{- $bindType := .BindType }} var v {{ $bindType }} + {{- if and .IsPointer .IsNullable }} + isNull := {{ .VariableForField }} == nil + if !isNull { + {{- else }} isNull := false + { + {{- end }} {{- if .Serialize }} + {{- if eq $bindType "string" }} var err *lib.Error - {{- if eq $bindType "string" }} - v, isNull, err = {{ .VariableForField }}.SerializeString() - if err != nil { - r.setErr(err) - } - {{- else if eq $bindType "bool" }} - v, isNull = {{ .VariableForField }}.SerializeBool() - {{- else if eq $bindType "float64" }} - v, isNull = {{ .VariableForField }}.SerializeFloat() - {{- else if eq $bindType "int64" }} - v, isNull = {{ .VariableForField }}.SerializeInt() - {{- else }} - v, isNull, err = {{ .VariableForField }}.SerializeBytes() - if err != nil { - r.setErr(err) - } - {{- end }} + v, isNull, err = {{ .VariableForField }}.SerializeString() + if err != nil { + r.setErr(err) + } + {{- else if eq $bindType "bool" }} + v, isNull = {{ .VariableForField }}.SerializeBool() + {{- else if eq $bindType "float64" }} + v, isNull = {{ .VariableForField }}.SerializeFloat() + {{- else if eq $bindType "int64" }} + v, isNull = {{ .VariableForField }}.SerializeInt() + {{- else }} + var err *lib.Error + v, isNull, err = {{ .VariableForField }}.SerializeBytes() + if err != nil { + r.setErr(err) + } + {{- end }} {{- else if .HasLen }} v = {{ .VariableForField }} {{- if .IsNullable }} @@ -171,41 +192,49 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { {{- else }} {{- if eq $bindType "bool" }} v = {{ .VariableForField }} - {{- else if not .Is64Bit }} + {{- else if .NeedsCast $bindType }} v = {{ $bindType }}({{ .VariableForField }}) {{- else }} v = {{ .VariableForField }} {{- end }} {{- end }} - if isNull { - r.stmt.BindNull(param) - } else { - {{ .BindMethod }}(param, v) - } + } + if isNull { + r.stmt.BindNull(param) + } else { + {{ .BindMethod }}(param, v) } param++ + {{- end }} {{end}} {{define "fetchColumn"}} + {{- if .EmbedFields }} + {{- range .EmbedFields }} + {{ template "fetchColumn" . }} + {{- end }} + {{- else }} {{- if .IsNullable }} if r.stmt.ColumnIsNull(col) { var v {{ .Type }} - r.Row.{{ .Name }} = v + {{ .FetchInto }} = v } else { {{- else }} { {{- end }} raw := {{ .FetchMethod }}(col) {{- if .Serialize }} - res, err := Deserialize{{ .Type }}(raw) + v, err := {{ .DeserializeMethod }}(raw) if err != nil { r.setErr(err) } - {{- else if not (or .Is64Bit (eq .BindType "bool")) }} - r.Row.{{ .Name }} = {{ .Type }}(v) + {{- else if .NeedsCast .BindType }} + v := {{ .Type }}(raw) {{- else }} - r.Row.{{ .Name }} = v + v := raw {{- end }} + {{ .FetchInto }} = v } col++ + {{- end }} {{end}} \ No newline at end of file From 5bf19bf917392e4c27a13d80d74f339958cc2da1 Mon Sep 17 00:00:00 2001 From: abdullah2993 Date: Wed, 10 Jan 2024 02:34:16 +0500 Subject: [PATCH 03/18] Add support for omitempty in JSON tags Adds a new configuration option json_tags_omit_empty --- docs/reference/config.md | 2 + internal/codegen/golang/field.go | 11 ++++-- internal/codegen/golang/opts/options.go | 1 + internal/config/v_one.go | 2 + .../omit_empty/postgresql/pgx/v4/go/db.go | 32 ++++++++++++++++ .../omit_empty/postgresql/pgx/v4/go/models.go | 15 ++++++++ .../postgresql/pgx/v4/go/query.sql.go | 34 +++++++++++++++++ .../omit_empty/postgresql/pgx/v4/query.sql | 2 + .../omit_empty/postgresql/pgx/v4/schema.sql | 6 +++ .../omit_empty/postgresql/pgx/v4/sqlc.json | 15 ++++++++ .../omit_empty/postgresql/pgx/v5/go/db.go | 32 ++++++++++++++++ .../omit_empty/postgresql/pgx/v5/go/models.go | 15 ++++++++ .../postgresql/pgx/v5/go/query.sql.go | 34 +++++++++++++++++ .../omit_empty/postgresql/pgx/v5/query.sql | 2 + .../omit_empty/postgresql/pgx/v5/schema.sql | 6 +++ .../omit_empty/postgresql/pgx/v5/sqlc.json | 15 ++++++++ .../omit_empty/postgresql/stdlib/go/db.go | 31 ++++++++++++++++ .../omit_empty/postgresql/stdlib/go/models.go | 15 ++++++++ .../postgresql/stdlib/go/query.sql.go | 37 +++++++++++++++++++ .../omit_empty/postgresql/stdlib/query.sql | 2 + .../omit_empty/postgresql/stdlib/schema.sql | 6 +++ .../omit_empty/postgresql/stdlib/sqlc.json | 14 +++++++ 22 files changed, 325 insertions(+), 4 deletions(-) create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/db.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/models.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/query.sql.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/query.sql create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/schema.sql create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/sqlc.json create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/db.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/models.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/query.sql.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/query.sql create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/schema.sql create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/sqlc.json create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/db.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/models.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/query.sql.go create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/query.sql create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/schema.sql create mode 100644 internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/sqlc.json diff --git a/docs/reference/config.md b/docs/reference/config.md index 2629babf4b..7f2afe08e4 100644 --- a/docs/reference/config.md +++ b/docs/reference/config.md @@ -173,6 +173,8 @@ The `gen` mapping supports the following keys: - If true, "Id" in json tags will be uppercase. If false, will be camelcase. Defaults to `false` - `json_tags_case_style`: - `camel` for camelCase, `pascal` for PascalCase, `snake` for snake_case or `none` to use the column name in the DB. Defaults to `none`. +- `json_tags_omit_empty`: + - If true will add `omitempty` to JSON tags. Defaults to `false`. - `omit_unused_structs`: - If `true`, sqlc won't generate table and enum structs that aren't used in queries for a given package. Defaults to `false`. - `output_batch_file_name`: diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index ce3e678b31..87e202cc09 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -164,11 +164,14 @@ func TagsToString(tags map[string]string) string { func JSONTagName(name string, options *opts.Options) string { style := options.JsonTagsCaseStyle idUppercase := options.JsonTagsIdUppercase - if style == "" || style == "none" { - return name - } else { - return SetJSONCaseStyle(name, style, idUppercase) + addOmitEmpty := options.JsonTagsOmitEmpty + if style != "" && style != "none" { + name = SetJSONCaseStyle(name, style, idUppercase) } + if addOmitEmpty { + name = name + ",omitempty" + } + return name } func SetCaseStyle(name string, style string) string { diff --git a/internal/codegen/golang/opts/options.go b/internal/codegen/golang/opts/options.go index 30a6c2246c..cc75415bbc 100644 --- a/internal/codegen/golang/opts/options.go +++ b/internal/codegen/golang/opts/options.go @@ -26,6 +26,7 @@ type Options struct { EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"` EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"` JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` + JsonTagsOmitEmpty bool `json:"json_tags_omit_empty,omitempty" yaml:"json_tags_omit_empty"` Package string `json:"package" yaml:"package"` Out string `json:"out" yaml:"out"` Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` diff --git a/internal/config/v_one.go b/internal/config/v_one.go index 8efa9f42fc..cf7b22e519 100644 --- a/internal/config/v_one.go +++ b/internal/config/v_one.go @@ -43,6 +43,7 @@ type v1PackageSettings struct { EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"` EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"` JSONTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` + JSONTagsOmitEmpty bool `json:"json_tags_omit_empty,omitempty" yaml:"json_tags_omit_empty"` SQLPackage string `json:"sql_package" yaml:"sql_package"` SQLDriver string `json:"sql_driver" yaml:"sql_driver"` Overrides []golang.Override `json:"overrides" yaml:"overrides"` @@ -158,6 +159,7 @@ func (c *V1GenerateSettings) Translate() Config { SqlDriver: pkg.SQLDriver, Overrides: pkg.Overrides, JsonTagsCaseStyle: pkg.JSONTagsCaseStyle, + JsonTagsOmitEmpty: pkg.JSONTagsOmitEmpty, OutputBatchFileName: pkg.OutputBatchFileName, OutputDbFileName: pkg.OutputDBFileName, OutputModelsFileName: pkg.OutputModelsFileName, diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/db.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/db.go new file mode 100644 index 0000000000..64a2dde21d --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/models.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/models.go new file mode 100644 index 0000000000..334fc1d3fc --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package querytest + +import ( + "database/sql" +) + +type User struct { + FirstName sql.NullString `json:"first_name,omitempty"` + LastName sql.NullString `json:"last_name,omitempty"` + Age sql.NullInt16 `json:"age,omitempty"` +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/query.sql.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/query.sql.go new file mode 100644 index 0000000000..3f6d25f6d6 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/go/query.sql.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan(&i.FirstName, &i.LastName, &i.Age); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/query.sql b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/query.sql new file mode 100644 index 0000000000..237b20193b --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/query.sql @@ -0,0 +1,2 @@ +-- name: GetAll :many +SELECT * FROM users; diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/schema.sql b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/schema.sql new file mode 100644 index 0000000000..8ca8f26d32 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + first_name varchar(255), + last_name varchar(255), + age smallint +); + diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/sqlc.json b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/sqlc.json new file mode 100644 index 0000000000..a6cd874015 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v4/sqlc.json @@ -0,0 +1,15 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_json_tags": true, + "json_tags_omit_empty": true + } + ] +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/db.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/db.go new file mode 100644 index 0000000000..4b7184a242 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/models.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/models.go new file mode 100644 index 0000000000..e6909e9e0d --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package querytest + +import ( + "github.com/jackc/pgx/v5/pgtype" +) + +type User struct { + FirstName pgtype.Text `json:"first_name,omitempty"` + LastName pgtype.Text `json:"last_name,omitempty"` + Age pgtype.Int2 `json:"age,omitempty"` +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/query.sql.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/query.sql.go new file mode 100644 index 0000000000..3f6d25f6d6 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/go/query.sql.go @@ -0,0 +1,34 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan(&i.FirstName, &i.LastName, &i.Age); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/query.sql b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/query.sql new file mode 100644 index 0000000000..237b20193b --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/query.sql @@ -0,0 +1,2 @@ +-- name: GetAll :many +SELECT * FROM users; diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/schema.sql b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/schema.sql new file mode 100644 index 0000000000..8ca8f26d32 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + first_name varchar(255), + last_name varchar(255), + age smallint +); + diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/sqlc.json b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/sqlc.json new file mode 100644 index 0000000000..5f27d5fb58 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/pgx/v5/sqlc.json @@ -0,0 +1,15 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v5", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_json_tags": true, + "json_tags_omit_empty": true + } + ] +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..df0488f428 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..334fc1d3fc --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/models.go @@ -0,0 +1,15 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 + +package querytest + +import ( + "database/sql" +) + +type User struct { + FirstName sql.NullString `json:"first_name,omitempty"` + LastName sql.NullString `json:"last_name,omitempty"` + Age sql.NullInt16 `json:"age,omitempty"` +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..c936e761b6 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.25.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const getAll = `-- name: GetAll :many +SELECT first_name, last_name, age FROM users +` + +func (q *Queries) GetAll(ctx context.Context) ([]User, error) { + rows, err := q.db.QueryContext(ctx, getAll) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan(&i.FirstName, &i.LastName, &i.Age); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/query.sql b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..237b20193b --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/query.sql @@ -0,0 +1,2 @@ +-- name: GetAll :many +SELECT * FROM users; diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..8ca8f26d32 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE users ( + first_name varchar(255), + last_name varchar(255), + age smallint +); + diff --git a/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..c5da08c629 --- /dev/null +++ b/internal/endtoend/testdata/json_tags/omit_empty/postgresql/stdlib/sqlc.json @@ -0,0 +1,14 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql", + "emit_json_tags": true, + "json_tags_omit_empty": true + } + ] +} From 3233b3eeb4ce8a05ef5ba357a49f880753a6b2d3 Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 27 Oct 2024 14:11:17 -0600 Subject: [PATCH 04/18] make errNoRows and errTooManyRows private to package --- internal/codegen/golang/imports.go | 1 + internal/codegen/golang/query.go | 2 +- internal/codegen/golang/templates/sqlite/dbCode.tmpl | 5 +++-- internal/codegen/golang/templates/sqlite/queryCode.tmpl | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index e6e0400c19..fffe511767 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -124,6 +124,7 @@ func (i *importer) dbImports() fileImports { } std := []ImportSpec{ {Path: "runtime"}, + {Path: "context"}, } sqlpkg := parseDriver(i.Options.SqlPackage) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index d32f86146d..ea2ba8cd66 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -308,7 +308,7 @@ func (q *Query) ReceiverType() string { if q.IsReadOnly() { return "Queries" } - return "WriteTx" + return "WriteTxn" } func (q *Query) hasRetType() bool { diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index b88fe6af59..31fcc2eabb 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -1,6 +1,7 @@ {{define "dbCodeTemplateStd"}} type Queries struct { + Ctx context.Context conn *sqlite.Conn pin runtime.Pinner {{- range .ReadQueries }} @@ -8,7 +9,7 @@ type Queries struct { {{- end}} } -type WriteTx struct { +type WriteTxn struct { Queries finish func(*error) {{- range .WriteQueries }} @@ -40,7 +41,7 @@ func (q *Queries) Close() *lib.Error { return firstErr } -func (txn *WriteTx) Close() *lib.Error { +func (txn *WriteTxn) Close() *lib.Error { var err, firstErr *lib.Error {{- range .WriteQueries }} txn.{{.FieldName}}.Close() diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index f2f0bca20d..95db07dc4d 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -95,9 +95,9 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De return } if !ok { - err = ErrNoRows + err = errNoRows } else if r.Next() { - err = ErrTooManyRows + err = errTooManyRows } else { result = r.Row } From b4241bc8379856cc7c422d882529da60a03ab8a2 Mon Sep 17 00:00:00 2001 From: Eloff Date: Mon, 28 Oct 2024 09:09:24 -0600 Subject: [PATCH 05/18] bind columns start at 1, result columns start at 0 (wtf) --- internal/codegen/golang/templates/sqlite/queryCode.tmpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 95db07dc4d..cb50fedd48 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -75,7 +75,7 @@ func (r *{{ .MethodName }}Stmt) Next() bool { {{- if ne .Cmd ":exec" }} {{ $fields := .Ret.Fields }} {{- if $fields }} - col := 1 + col := 0 {{- range $fields }} {{- template "fetchColumn" . }} {{- end }} From ba1a0f7c562fcbcbc379035e0bf2fca6232aa4b4 Mon Sep 17 00:00:00 2001 From: Eloff Date: Mon, 28 Oct 2024 18:05:52 -0600 Subject: [PATCH 06/18] keep Column around for code generation --- internal/codegen/golang/field.go | 2 +- internal/codegen/golang/result.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 87e202cc09..be95c81a5d 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -31,7 +31,7 @@ func (gf *Field) HasSqlcSlice() bool { } func (gf *Field) IsNullable() bool { - return strings.HasPrefix(gf.Type, "types.Null") || strings.HasSuffix(gf.Type, "ID") || gf.Column != nil && !gf.Column.NotNull + return !gf.Column.GetNotNull() } func (gf *Field) Serialize() bool { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 515d0a654f..367632333b 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -99,6 +99,7 @@ func buildStructs(req *plugin.GenerateRequest, options *opts.Options) []Struct { Type: goType(req, options, column), Tags: tags, Comment: column.Comment, + Column: column, }) } structs = append(structs, s) From 6b1aea3dd2051fe6262898def598c85f0a81a46e Mon Sep 17 00:00:00 2001 From: Eloff Date: Tue, 29 Oct 2024 07:54:46 -0600 Subject: [PATCH 07/18] some tweaks to codegen --- internal/codegen/golang/field.go | 2 +- internal/codegen/golang/sqlite_type.go | 4 ++-- internal/codegen/golang/templates/sqlite/queryCode.tmpl | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index be95c81a5d..950655bfac 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -35,7 +35,7 @@ func (gf *Field) IsNullable() bool { } func (gf *Field) Serialize() bool { - return gf.IsPointer() || strings.HasSuffix(gf.Type, "ID") || strings.HasSuffix(gf.Type, "Blob") + return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "64") && !strings.HasSuffix(gf.Type, "32") && !strings.HasSuffix(gf.Type, "Type")) } func (gf *Field) HasLen() bool { diff --git a/internal/codegen/golang/sqlite_type.go b/internal/codegen/golang/sqlite_type.go index 7eb1215329..47ad6f0a2e 100644 --- a/internal/codegen/golang/sqlite_type.go +++ b/internal/codegen/golang/sqlite_type.go @@ -43,7 +43,7 @@ func sqliteType(_req *plugin.GenerateRequest, _options *opts.Options, col *plugi } return "sql.NullTime" case "any": - return "interface{}" + return "any" } switch { @@ -65,6 +65,6 @@ func sqliteType(_req *plugin.GenerateRequest, _options *opts.Options, col *plugi if debug.Active { log.Printf("unknown SQLite type: %s\n", dt) } - return "interface{}" + return "any" } } diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index cb50fedd48..08efc3526e 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -95,9 +95,9 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De return } if !ok { - err = errNoRows + err = lib.ErrorWithDepth(errNoRows, 2) } else if r.Next() { - err = errTooManyRows + err = lib.ErrorWithDepth(errTooManyRows, 2) } else { result = r.Row } From 108907d85261407cc9bbd80b1334a969a2960dfc Mon Sep 17 00:00:00 2001 From: Eloff Date: Tue, 29 Oct 2024 08:01:46 -0600 Subject: [PATCH 08/18] make close and reset package private --- internal/codegen/golang/templates/sqlite/dbCode.tmpl | 4 ++-- internal/codegen/golang/templates/sqlite/queryCode.tmpl | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index 31fcc2eabb..ce1133e8c6 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -28,7 +28,7 @@ func (q *Queries) Close() *lib.Error { q.pin.Unpin() var err, firstErr *lib.Error {{- range .ReadQueries }} - q.{{.FieldName}}.Close() + q.{{.FieldName}}.close() err = q.{{.FieldName}}.Err() if firstErr == nil { firstErr = err @@ -44,7 +44,7 @@ func (q *Queries) Close() *lib.Error { func (txn *WriteTxn) Close() *lib.Error { var err, firstErr *lib.Error {{- range .WriteQueries }} - txn.{{.FieldName}}.Close() + txn.{{.FieldName}}.close() err = txn.{{.FieldName}}.Err() if firstErr == nil { firstErr = err diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 08efc3526e..3420580872 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -88,7 +88,7 @@ func (r *{{ .MethodName }}Stmt) Next() bool { func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.DefineType}}, err *lib.Error) { r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) - defer r.Reset() + defer r.reset() ok := r.Next() err = r.Err() if err != nil { @@ -109,7 +109,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret.DefineType}}, err *lib.Error) { r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) - defer r.Reset() + defer r.reset() for r.Next() { results = append(results, r.Row) } @@ -126,7 +126,7 @@ func (q *{{ .ReceiverType }}) Begin{{.MethodName}}({{ .Arg.Pair }}) func(yield f r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) return func(yield func(row {{.Ret.DefineType}}) bool) { - defer r.Reset() + defer r.reset() for r.Next() && yield(r.Row) {} } } @@ -140,7 +140,7 @@ func (q *{{ .ReceiverType }}) End{{.MethodName}}() *lib.Error { func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) - defer r.Reset() + defer r.reset() r.Next() return r.Err() } From 08619e8cd95bea4a834acce3f91264b74a867f2e Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 3 Nov 2024 15:58:25 -0700 Subject: [PATCH 09/18] add iter command type --- internal/metadata/meta.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 97ff36dbd2..bee2f1a61a 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -32,6 +32,7 @@ const ( CmdBatchExec = ":batchexec" CmdBatchMany = ":batchmany" CmdBatchOne = ":batchone" + CmdIter = ":iter" ) // A query name must be a valid Go identifier @@ -101,7 +102,7 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string queryName := part[2] queryType := strings.TrimSpace(part[3]) switch queryType { - case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne: + case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne, CmdIter: default: return "", "", fmt.Errorf("invalid query type: %s", queryType) } From 71cda54632bf4e3f4ba25e89797e779b234feead Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 3 Nov 2024 16:08:20 -0700 Subject: [PATCH 10/18] fix some more issues with :iter support --- internal/codegen/golang/query.go | 3 ++- internal/codegen/golang/result.go | 1 + internal/sql/validate/cmd.go | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index ea2ba8cd66..ad391b2182 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -313,7 +313,8 @@ func (q *Query) ReceiverType() string { func (q *Query) hasRetType() bool { scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany || - q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne + q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne || + q.Cmd == metadata.CmdIter return scanned && !q.Ret.isEmpty() } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 367632333b..4bc997e3c6 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -337,6 +337,7 @@ var cmdReturnsData = map[string]struct{}{ metadata.CmdBatchOne: {}, metadata.CmdMany: {}, metadata.CmdOne: {}, + metadata.CmdIter: {}, } func putOutColumns(query *plugin.Query) bool { diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index 66e849de6c..b83882ac15 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -71,7 +71,7 @@ func Cmd(n ast.Node, name, cmd string) error { return err } } - if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne) { + if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne || cmd == metadata.CmdIter) { return nil } var list *ast.List From b37da2f5274c7b96ee9c5f357a4f25c065d4a032 Mon Sep 17 00:00:00 2001 From: Eloff Date: Tue, 26 Nov 2024 08:28:28 -0700 Subject: [PATCH 11/18] some codegen tweaks --- .../golang/templates/sqlite/dbCode.tmpl | 44 ++++++++++++++++++- .../golang/templates/sqlite/queryCode.tmpl | 6 ++- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index ce1133e8c6..5d3c32a889 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -1,9 +1,12 @@ {{define "dbCodeTemplateStd"}} type Queries struct { - Ctx context.Context conn *sqlite.Conn pin runtime.Pinner + beginRead iter + beginWrite iter + commit iter + rollback iter {{- range .ReadQueries }} {{.FieldName}} {{.MethodName}}Stmt {{- end}} @@ -11,7 +14,6 @@ type Queries struct { type WriteTxn struct { Queries - finish func(*error) {{- range .WriteQueries }} {{.FieldName}} {{.MethodName}}Stmt {{- end}} @@ -22,6 +24,44 @@ func (q *Queries) init(conn *sqlite.Conn) { panic("Queries already initialized") } q.conn = conn + q.beginRead.init(q, "begin", "begin", false) + q.beginWrite.init(q, "begin immediate", "begin immediate", true) + q.commit.init(q, "commit", "commit", false) + q.rollback.init(q, "rollback", "rollback", false) +} + +func (q *Queries) BeginRead() (int64, *lib.Error) { + q.beginRead.Next(q) + err := q.beginRead.Err() + q.beginRead.Reset() + if err != nil { + return 0, err + } + return q.GetVersion() +} + +func (q *Queries) BeginWrite() (int64, *lib.Error) { + q.beginWrite.Next(q) + err := q.beginWrite.Err() + q.beginWrite.Reset() + if err != nil { + return 0, err + } + return q.GetVersion() +} + +func (q *Queries) Commit() (*lib.Error) { + q.commit.Next() + err := q.commit.Err() + q.commit.Reset() + return err +} + +func (q *Queries) Rollback() (*lib.Error) { + q.rollback.Next() + err := q.rollback.Err() + q.rollback.Reset() + return err } func (q *Queries) Close() *lib.Error { diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 3420580872..f8da65840f 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -122,7 +122,9 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret {{end}} {{if eq .Cmd ":iter"}} -func (q *{{ .ReceiverType }}) Begin{{.MethodName}}({{ .Arg.Pair }}) func(yield func(row {{.Ret.DefineType}}) bool) { +type {{.MethodName}}Iterator func(yield func(row {{.Ret.DefineType}}) bool) + +func (q *{{ .ReceiverType }}) {{.MethodName}}Cursor({{ .Arg.Pair }}) {{.MethodName}}Iterator { r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) return func(yield func(row {{.Ret.DefineType}}) bool) { @@ -131,7 +133,7 @@ func (q *{{ .ReceiverType }}) Begin{{.MethodName}}({{ .Arg.Pair }}) func(yield f } } -func (q *{{ .ReceiverType }}) End{{.MethodName}}() *lib.Error { +func (_ {{.MethodName}}Iterator) Close(q *{{ .ReceiverType }}) *lib.Error { return q.{{ .FieldName }}.Err() } {{end}} From 9fd465c1d1af23fca4e4992c97423961ca9057f3 Mon Sep 17 00:00:00 2001 From: Eloff Date: Sat, 30 Nov 2024 06:00:52 -0700 Subject: [PATCH 12/18] added :map command --- internal/codegen/golang/field.go | 4 ++ internal/codegen/golang/query.go | 10 ++- internal/codegen/golang/result.go | 68 ++++++++++++------- .../golang/templates/sqlite/dbCode.tmpl | 12 ++-- .../golang/templates/sqlite/queryCode.tmpl | 34 +++++++++- internal/metadata/meta.go | 3 +- internal/sql/validate/cmd.go | 2 +- 7 files changed, 97 insertions(+), 36 deletions(-) diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 950655bfac..0a753737f7 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -13,6 +13,7 @@ import ( type Field struct { Name string // CamelCased name for Go VariableForField string // Variable name for the field (including structVar.name, if part of a struct) + IsKeyField bool // If this field is being output as the key for a map DBName string // Name as used in the DB Type string Tags map[string]string @@ -116,6 +117,9 @@ func (gf *Field) FetchMethod() string { } func (gf *Field) FetchInto() string { + if gf.IsKeyField { + return "r.Key" + } if gf.Name == "" { return "r.Row" } diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index ad391b2182..a0d0dd2186 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -16,6 +16,7 @@ type QueryValue struct { DBName string // The name of the field in the database. Only set if Struct==nil. Struct *Struct Typ string + KeyField *Field // Only set if Struct!=nil && Cmd==metadata.CmdMap SQLDriver opts.SQLDriver // Column is kept so late in the generation process around to differentiate @@ -109,6 +110,13 @@ func (v *QueryValue) DefineType() string { return t } +func (v *QueryValue) KeyType() string { + if v.KeyField == nil { + return "" + } + return v.KeyField.Type +} + func (v *QueryValue) ReturnName() string { if v.IsPointer() { return "&" + escape(v.Name) @@ -314,7 +322,7 @@ func (q *Query) ReceiverType() string { func (q *Query) hasRetType() bool { scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany || q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne || - q.Cmd == metadata.CmdIter + q.Cmd == metadata.CmdIter || q.Cmd == metadata.CmdMap return scanned && !q.Ret.isEmpty() } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 4bc997e3c6..53ea9057bd 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -248,14 +248,14 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] Column: p.Column, }) } - s, err := columnsToStruct(req, options, gq.MethodName+"Params", cols, false) + gs, err := columnsToStruct(req, options, gq.MethodName+"Params", cols, false) if err != nil { return nil, err } gq.Arg = QueryValue{ Emit: true, Name: "arg", - Struct: s, + Struct: gs, SQLDriver: sqlpkg, EmitPointer: options.EmitParamsStructPointers, } @@ -278,29 +278,21 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] SQLDriver: sqlpkg, } } else if putOutColumns(query) { - var gs *Struct - var emit bool - - for _, s := range structs { - if len(s.Fields) != len(query.Columns) { - continue - } - same := true - for i, f := range s.Fields { - c := query.Columns[i] - sameName := f.Name == StructName(columnName(c, i), options) - sameType := f.Type == goType(req, options, c) - sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema) - if !sameName || !sameType || !sameTable { - same = false - } - } - if same { - gs = &s - break + var keyField *Field + if query.Cmd == metadata.CmdMap { + c := query.Columns[0] + query.Columns = query.Columns[1:] + colName := columnName(c, 0) + keyField = &Field{ + Name: StructName(colName, options), + DBName: colName, + IsKeyField: true, + Type: goType(req, options, c), + Column: c, } } + gs, noEmit := getExistingStructIfDuplicate(req, options, structs, query, query.Columns) if gs == nil { var columns []goColumn for i, c := range query.Columns { @@ -315,12 +307,12 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] if err != nil { return nil, err } - emit = true } gq.Ret = QueryValue{ - Emit: emit, + Emit: !noEmit, Name: "i", Struct: gs, + KeyField: keyField, SQLDriver: sqlpkg, EmitPointer: options.EmitResultStructPointers, } @@ -332,12 +324,40 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] return qs, nil } +func getExistingStructIfDuplicate(req *plugin.GenerateRequest, options *opts.Options, structs []Struct, query *plugin.Query, columns []*plugin.Column) (*Struct, bool) { +StructLoop: + for _, s := range structs { + if s.Name == "Queue" && query.Name == "CreateQueue" { + fmt.Println("getExistingStructIfDuplicate for Queue and CreateQueueParams", len(s.Fields), len(columns)) + } + if len(s.Fields) != len(columns) { + continue + } + for i, f := range s.Fields { + c := columns[i] + // Identical column + if c == f.Column { + continue + } + sameName := f.Name == StructName(columnName(c, i), options) + sameType := f.Type == goType(req, options, c) + sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema) + if !sameName || !sameType || !sameTable { + continue StructLoop + } + } + return &s, true + } + return nil, false +} + var cmdReturnsData = map[string]struct{}{ metadata.CmdBatchMany: {}, metadata.CmdBatchOne: {}, metadata.CmdMany: {}, metadata.CmdOne: {}, metadata.CmdIter: {}, + metadata.CmdMap: {}, } func putOutColumns(query *plugin.Query) bool { diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index 5d3c32a889..3b359f4af5 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -31,9 +31,9 @@ func (q *Queries) init(conn *sqlite.Conn) { } func (q *Queries) BeginRead() (int64, *lib.Error) { - q.beginRead.Next(q) + q.beginRead.Next() err := q.beginRead.Err() - q.beginRead.Reset() + q.beginRead.reset() if err != nil { return 0, err } @@ -41,9 +41,9 @@ func (q *Queries) BeginRead() (int64, *lib.Error) { } func (q *Queries) BeginWrite() (int64, *lib.Error) { - q.beginWrite.Next(q) + q.beginWrite.Next() err := q.beginWrite.Err() - q.beginWrite.Reset() + q.beginWrite.reset() if err != nil { return 0, err } @@ -53,14 +53,14 @@ func (q *Queries) BeginWrite() (int64, *lib.Error) { func (q *Queries) Commit() (*lib.Error) { q.commit.Next() err := q.commit.Err() - q.commit.Reset() + q.commit.reset() return err } func (q *Queries) Rollback() (*lib.Error) { q.rollback.Next() err := q.rollback.Err() - q.rollback.Reset() + q.rollback.reset() return err } diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index f8da65840f..c513635a24 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -8,6 +8,9 @@ type {{ .MethodName }}Stmt struct { iter {{- if ne .Cmd ":exec" }} Row {{.Ret.DefineType}} + {{- if eq .Cmd ":map" }} + Key {{.Ret.KeyType}} + {{- end }} {{- end }} } @@ -31,9 +34,8 @@ func (r *{{ .MethodName }}Stmt) bind(q *{{ .ReceiverType }}, {{ .Arg.Pair }}) { {{- if .Arg.HasSqlcSlices }} {{- range .Arg.Fields }} {{- if .HasSqlcSlice }} - length := len({{ .VariableForField }}) - if length > 0 { - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", length)[1:], 1) + if len({{ .VariableForField }}) > 0 { + query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{ .VariableForField }}))[1:], 1) } else { query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) } @@ -76,6 +78,9 @@ func (r *{{ .MethodName }}Stmt) Next() bool { {{ $fields := .Ret.Fields }} {{- if $fields }} col := 0 + {{- if eq .Cmd ":map" }} + {{- template "fetchColumn" .Ret.KeyField }} + {{- end }} {{- range $fields }} {{- template "fetchColumn" . }} {{- end }} @@ -121,6 +126,29 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret } {{end}} +{{if eq .Cmd ":map"}} +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.Ret.KeyType}}]{{.Ret.DefineType}}, err *lib.Error) { + r := &q.{{ .FieldName }} + r.bind(q, {{ .Arg.Names }}) + defer r.reset() + if !r.Next() { + err = r.Err() + return + } + results = map[{{.Ret.KeyType}}]{{.Ret.DefineType}}{ + r.Key: r.Row, + } + for r.Next() { + results[r.Key] = r.Row + } + err = r.Err() + if err != nil { + results = nil + } + return +} +{{end}} + {{if eq .Cmd ":iter"}} type {{.MethodName}}Iterator func(yield func(row {{.Ret.DefineType}}) bool) diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index bee2f1a61a..c7fa6d969a 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -33,6 +33,7 @@ const ( CmdBatchMany = ":batchmany" CmdBatchOne = ":batchone" CmdIter = ":iter" + CmdMap = ":map" ) // A query name must be a valid Go identifier @@ -102,7 +103,7 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string queryName := part[2] queryType := strings.TrimSpace(part[3]) switch queryType { - case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne, CmdIter: + case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne, CmdIter, CmdMap: default: return "", "", fmt.Errorf("invalid query type: %s", queryType) } diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index b83882ac15..8aca1f9840 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -71,7 +71,7 @@ func Cmd(n ast.Node, name, cmd string) error { return err } } - if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne || cmd == metadata.CmdIter) { + if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne || cmd == metadata.CmdIter || cmd == metadata.CmdMap) { return nil } var list *ast.List From cf79d29ff726b2894418da96ebe7822811d95f7e Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 15 Dec 2024 10:29:09 -0700 Subject: [PATCH 13/18] fix some bugs --- internal/codegen/golang/field.go | 8 +- internal/codegen/golang/imports.go | 1 - internal/codegen/golang/query.go | 4 + internal/codegen/golang/result.go | 3 - .../golang/templates/sqlite/dbCode.tmpl | 38 ++++--- .../golang/templates/sqlite/queryCode.tmpl | 105 ++++++++++++++---- 6 files changed, 115 insertions(+), 44 deletions(-) diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index 0a753737f7..2712555f58 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -36,7 +36,11 @@ func (gf *Field) IsNullable() bool { } func (gf *Field) Serialize() bool { - return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "64") && !strings.HasSuffix(gf.Type, "32") && !strings.HasSuffix(gf.Type, "Type")) + return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "Duration")) +} + +func (gf *Field) Deserialize() bool { + return gf.IsPointer() || (strings.IndexByte(gf.Type, '.') > 0 && !strings.HasSuffix(gf.Type, "64") && !strings.HasSuffix(gf.Type, "32") && !strings.HasSuffix(gf.Type, "Type") && !strings.HasSuffix(gf.Type, "Duration")) } func (gf *Field) HasLen() bool { @@ -65,7 +69,7 @@ func (gf *Field) BindType() string { ty = ty[2:] } switch { - case ty == "uint8", ty == "int8", ty == "uint32": + case ty == "uint8", ty == "int8", ty == "uint32", strings.HasSuffix(ty, "Duration"): return "int64" case ty == "float32": return "float64" diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index fffe511767..e6e0400c19 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -124,7 +124,6 @@ func (i *importer) dbImports() fileImports { } std := []ImportSpec{ {Path: "runtime"}, - {Path: "context"}, } sqlpkg := parseDriver(i.Options.SqlPackage) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index a0d0dd2186..ddc7601c1d 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -326,6 +326,10 @@ func (q *Query) hasRetType() bool { return scanned && !q.Ret.isEmpty() } +func (q *Query) FetchRowInNext() bool { + return q.Cmd != metadata.CmdExec && q.Cmd != metadata.CmdIter +} + func (q *Query) TableIdentifierAsGoSlice() string { escapedNames := make([]string, 0, 3) for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 53ea9057bd..b7aa91dcff 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -327,9 +327,6 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] func getExistingStructIfDuplicate(req *plugin.GenerateRequest, options *opts.Options, structs []Struct, query *plugin.Query, columns []*plugin.Column) (*Struct, bool) { StructLoop: for _, s := range structs { - if s.Name == "Queue" && query.Name == "CreateQueue" { - fmt.Println("getExistingStructIfDuplicate for Queue and CreateQueueParams", len(s.Fields), len(columns)) - } if len(s.Fields) != len(columns) { continue } diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index 3b359f4af5..656da75312 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -4,9 +4,7 @@ type Queries struct { conn *sqlite.Conn pin runtime.Pinner beginRead iter - beginWrite iter commit iter - rollback iter {{- range .ReadQueries }} {{.FieldName}} {{.MethodName}}Stmt {{- end}} @@ -14,6 +12,8 @@ type Queries struct { type WriteTxn struct { Queries + beginWrite iter + rollback iter {{- range .WriteQueries }} {{.FieldName}} {{.MethodName}}Stmt {{- end}} @@ -25,14 +25,18 @@ func (q *Queries) init(conn *sqlite.Conn) { } q.conn = conn q.beginRead.init(q, "begin", "begin", false) - q.beginWrite.init(q, "begin immediate", "begin immediate", true) q.commit.init(q, "commit", "commit", false) - q.rollback.init(q, "rollback", "rollback", false) +} + +func (txn *WriteTxn) init(conn *sqlite.Conn) { + txn.Queries.init(conn) + txn.beginWrite.init(&txn.Queries, "begin immediate", "begin immediate", false) + txn.rollback.init(&txn.Queries, "rollback", "rollback", false) } func (q *Queries) BeginRead() (int64, *lib.Error) { q.beginRead.Next() - err := q.beginRead.Err() + err := q.beginRead.takeErr() q.beginRead.reset() if err != nil { return 0, err @@ -40,27 +44,27 @@ func (q *Queries) BeginRead() (int64, *lib.Error) { return q.GetVersion() } -func (q *Queries) BeginWrite() (int64, *lib.Error) { - q.beginWrite.Next() - err := q.beginWrite.Err() - q.beginWrite.reset() +func (txn *WriteTxn) BeginWrite() (int64, *lib.Error) { + txn.beginWrite.Next() + err := txn.beginWrite.takeErr() + txn.beginWrite.reset() if err != nil { return 0, err } - return q.GetVersion() + return txn.GetVersion() } func (q *Queries) Commit() (*lib.Error) { q.commit.Next() - err := q.commit.Err() + err := q.commit.takeErr() q.commit.reset() return err } -func (q *Queries) Rollback() (*lib.Error) { - q.rollback.Next() - err := q.rollback.Err() - q.rollback.reset() +func (txn *WriteTxn) Rollback() (*lib.Error) { + txn.rollback.Next() + err := txn.rollback.takeErr() + txn.rollback.reset() return err } @@ -69,7 +73,7 @@ func (q *Queries) Close() *lib.Error { var err, firstErr *lib.Error {{- range .ReadQueries }} q.{{.FieldName}}.close() - err = q.{{.FieldName}}.Err() + err = q.{{.FieldName}}.takeErr() if firstErr == nil { firstErr = err } @@ -85,7 +89,7 @@ func (txn *WriteTxn) Close() *lib.Error { var err, firstErr *lib.Error {{- range .WriteQueries }} txn.{{.FieldName}}.close() - err = txn.{{.FieldName}}.Err() + err = txn.{{.FieldName}}.takeErr() if firstErr == nil { firstErr = err } diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index c513635a24..082cbd6bfb 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -1,5 +1,6 @@ {{define "queryCodeStd"}} {{range $.AllQueries .SourceName}} +{{$query := .}} const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{escape .SQL}} {{$.Q}} @@ -34,11 +35,7 @@ func (r *{{ .MethodName }}Stmt) bind(q *{{ .ReceiverType }}, {{ .Arg.Pair }}) { {{- if .Arg.HasSqlcSlices }} {{- range .Arg.Fields }} {{- if .HasSqlcSlice }} - if len({{ .VariableForField }}) > 0 { - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", strings.Repeat(",?", len({{ .VariableForField }}))[1:], 1) - } else { - query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", "NULL", 1) - } + query = strings.Replace(query, "/*SLICE:{{.Column.Name}}*/?", slicePlaceholders(len({{ .VariableForField }})), 1) {{- end }} {{- end }} {{- end }} @@ -74,7 +71,14 @@ func (r *{{ .MethodName }}Stmt) Next() bool { if r.err != nil { return false } - {{- if ne .Cmd ":exec" }} + return r.Fetch() +} + +func (r *{{ .MethodName }}Stmt) Fetch() bool { + if r.err != nil { + return false + } + {{ $fields := .Ret.Fields }} {{- if $fields }} col := 0 @@ -85,7 +89,7 @@ func (r *{{ .MethodName }}Stmt) Next() bool { {{- template "fetchColumn" . }} {{- end }} {{- end }} - {{- end }} + return r.err == nil } @@ -95,14 +99,14 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De r.bind(q, {{ .Arg.Names }}) defer r.reset() ok := r.Next() - err = r.Err() + err = r.takeErr() if err != nil { return } if !ok { - err = lib.ErrorWithDepth(errNoRows, 2) + err = lib.WrapWithDepth(errNoRows, 3) } else if r.Next() { - err = lib.ErrorWithDepth(errTooManyRows, 2) + err = lib.WrapWithDepth(errTooManyRows, 3) } else { result = r.Row } @@ -118,7 +122,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret for r.Next() { results = append(results, r.Row) } - err = r.Err() + err = r.takeErr() if err != nil { results = nil } @@ -132,7 +136,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R r.bind(q, {{ .Arg.Names }}) defer r.reset() if !r.Next() { - err = r.Err() + err = r.takeErr() return } results = map[{{.Ret.KeyType}}]{{.Ret.DefineType}}{ @@ -141,7 +145,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R for r.Next() { results[r.Key] = r.Row } - err = r.Err() + err = r.takeErr() if err != nil { results = nil } @@ -150,20 +154,79 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R {{end}} {{if eq .Cmd ":iter"}} -type {{.MethodName}}Iterator func(yield func(row {{.Ret.DefineType}}) bool) +func (r *{{ .MethodName }}Stmt) HasNext() bool { + if !r.iter.Next() { + return false + } + if r.err != nil { + return false + } + return r.err == nil +} -func (q *{{ .ReceiverType }}) {{.MethodName}}Cursor({{ .Arg.Pair }}) {{.MethodName}}Iterator { - r := &q.{{ .FieldName }} - r.bind(q, {{ .Arg.Names }}) +func (r *{{ .MethodName }}Stmt) Close() *lib.Error { + err := r.takeErr() + r.reset() + return err +} + +func (r *{{ .MethodName }}Stmt) Range() func(yield func(row {{.Ret.DefineType}}) bool) { return func(yield func(row {{.Ret.DefineType}}) bool) { defer r.reset() for r.Next() && yield(r.Row) {} } } -func (_ {{.MethodName}}Iterator) Close(q *{{ .ReceiverType }}) *lib.Error { - return q.{{ .FieldName }}.Err() +func (q *{{ .ReceiverType }}) {{.MethodName}}Cursor({{ .Arg.Pair }}) *{{ .MethodName }}Stmt { + r := &q.{{ .FieldName }} + r.bind(q, {{ .Arg.Names }}) + return r +} + +{{- range $index, $field := .Ret.Fields }} +{{- if $field.IsNullable }} +func (r *{{ $query.MethodName }}Stmt) Peek{{ $field.Name }}() {{ $field.BindType }} { + const col = {{$index}} + return {{ $field.FetchMethod }}(col) +} +{{- else }} +func (r *{{ $query.MethodName }}Stmt) Peek{{ $field.Name }}() (val {{ $field.BindType }}, isNull bool) { + const col = {{$index}} + if r.stmt.ColumnIsNull(col) { + isNull = true + return + } + val = {{ $field.FetchMethod }}(col) + return +} +{{- end }} + +func (r *{{ $query.MethodName }}Stmt) Get{{ $field.Name }}() (val {{ $field.Type }}, ok bool) { + var raw {{ $field.BindType }} +{{- if $field.IsNullable }} + raw = r.Peek{{ $field.Name }}() +{{- else }} + var isNull bool + raw, isNull = r.Peek{{ $field.Name }}() + if isNull { + return val, true + } +{{- end }} + {{- if $field.Deserialize }} + var err *lib.Error + val, err = {{ $field.DeserializeMethod }}(raw) + if err != nil { + r.setErr(err) + return val, false + } + {{- else if $field.NeedsCast $field.BindType }} + val = {{ $field.Type }}(raw) + {{- else }} + val = raw + {{- end }} + return val, true } +{{- end }} {{end}} {{if eq .Cmd ":exec"}} @@ -172,7 +235,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { r.bind(q, {{ .Arg.Names }}) defer r.reset() r.Next() - return r.Err() + return r.takeErr() } {{end}} @@ -253,7 +316,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) *lib.Error { { {{- end }} raw := {{ .FetchMethod }}(col) - {{- if .Serialize }} + {{- if .Deserialize }} v, err := {{ .DeserializeMethod }}(raw) if err != nil { r.setErr(err) From 502669909bfe57523870ee879fcb0fed286e3888 Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 12 Jan 2025 10:15:41 -0600 Subject: [PATCH 14/18] add finalizers to db connections --- .../golang/templates/sqlite/dbCode.tmpl | 37 ++++++++++--------- .../golang/templates/sqlite/queryCode.tmpl | 11 +----- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/internal/codegen/golang/templates/sqlite/dbCode.tmpl b/internal/codegen/golang/templates/sqlite/dbCode.tmpl index 656da75312..4fb07bd627 100644 --- a/internal/codegen/golang/templates/sqlite/dbCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/dbCode.tmpl @@ -19,19 +19,23 @@ type WriteTxn struct { {{- end}} } -func (q *Queries) init(conn *sqlite.Conn) { +func (q *Queries) init(conn *sqlite.Conn, isWriteTxn bool) { if q.conn != nil { panic("Queries already initialized") } q.conn = conn q.beginRead.init(q, "begin", "begin", false) q.commit.init(q, "commit", "commit", false) + if !isWriteTxn { + runtime.SetFinalizer(q, (*Queries).Close) + } } func (txn *WriteTxn) init(conn *sqlite.Conn) { - txn.Queries.init(conn) + txn.Queries.init(conn, true) txn.beginWrite.init(&txn.Queries, "begin immediate", "begin immediate", false) txn.rollback.init(&txn.Queries, "rollback", "rollback", false) + runtime.SetFinalizer(txn, (*WriteTxn).Close) } func (q *Queries) BeginRead() (int64, *lib.Error) { @@ -55,6 +59,7 @@ func (txn *WriteTxn) BeginWrite() (int64, *lib.Error) { } func (q *Queries) Commit() (*lib.Error) { + q.pin.Unpin() q.commit.Next() err := q.commit.takeErr() q.commit.reset() @@ -62,6 +67,7 @@ func (q *Queries) Commit() (*lib.Error) { } func (txn *WriteTxn) Rollback() (*lib.Error) { + txn.Queries.pin.Unpin() txn.rollback.Next() err := txn.rollback.takeErr() txn.rollback.reset() @@ -69,16 +75,17 @@ func (txn *WriteTxn) Rollback() (*lib.Error) { } func (q *Queries) Close() *lib.Error { + if q.conn == nil { + return nil + } q.pin.Unpin() - var err, firstErr *lib.Error + var firstErr *lib.Error {{- range .ReadQueries }} q.{{.FieldName}}.close() - err = q.{{.FieldName}}.takeErr() - if firstErr == nil { - firstErr = err - } + firstErr = firstErr.Or(q.{{.FieldName}}.takeErr()) {{- end}} extErr := q.conn.Close() + q.conn = nil if firstErr == nil && extErr != nil { firstErr = lib.Wrap(extErr) } @@ -86,19 +93,15 @@ func (q *Queries) Close() *lib.Error { } func (txn *WriteTxn) Close() *lib.Error { - var err, firstErr *lib.Error + if txn.conn == nil { + return nil + } + var firstErr *lib.Error {{- range .WriteQueries }} txn.{{.FieldName}}.close() - err = txn.{{.FieldName}}.takeErr() - if firstErr == nil { - firstErr = err - } + firstErr = firstErr.Or(txn.{{.FieldName}}.takeErr()) {{- end}} - err = txn.Queries.Close() - if firstErr == nil { - firstErr = err - } - return firstErr + return firstErr.Or(txn.Queries.Close()) } {{end}} diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 082cbd6bfb..deb90e48ba 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -68,9 +68,6 @@ func (r *{{ .MethodName }}Stmt) Next() bool { if !r.iter.Next() { return false } - if r.err != nil { - return false - } return r.Fetch() } @@ -155,13 +152,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R {{if eq .Cmd ":iter"}} func (r *{{ .MethodName }}Stmt) HasNext() bool { - if !r.iter.Next() { - return false - } - if r.err != nil { - return false - } - return r.err == nil + return r.iter.Next() } func (r *{{ .MethodName }}Stmt) Close() *lib.Error { From 86cb4cb806fa0296c0cb7fb9d161a5677e534054 Mon Sep 17 00:00:00 2001 From: Eloff Date: Sun, 12 Jan 2025 12:50:31 -0600 Subject: [PATCH 15/18] take three, don't return ErrNoRows for :one queries with a returning clause --- internal/codegen/golang/query.go | 4 ++++ internal/codegen/golang/templates/sqlite/queryCode.tmpl | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index ddc7601c1d..116f2a31c5 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -312,6 +312,10 @@ func (q *Query) IsReadOnly() bool { return q.Cmd != metadata.CmdExec && strings.EqualFold(q.SQL[:6], "select") } +func (q *Query) HasReturning() bool { + return !q.IsReadOnly() && !q.Ret.isEmpty() +} + func (q *Query) ReceiverType() string { if q.IsReadOnly() { return "Queries" diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index deb90e48ba..a42befd4c1 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -98,6 +98,11 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De ok := r.Next() err = r.takeErr() if err != nil { + {{if .HasReturning}} + if IsErrNoRows(err) { + err = nil + } + {{end}} return } if !ok { From d6135d15d0ab9c86b70a5947cfde7ccafbf3e66a Mon Sep 17 00:00:00 2001 From: Eloff Date: Sat, 18 Jan 2025 12:36:58 -0700 Subject: [PATCH 16/18] modify :many to use a slice passed into the query method, more flexible / efficient --- internal/codegen/golang/templates/sqlite/queryCode.tmpl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index a42befd4c1..316f478080 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -117,7 +117,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De {{end}} {{if eq .Cmd ":many"}} -func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret.DefineType}}, err *lib.Error) { +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}, results []{{.Ret.DefineType}}) ([]{{.Ret.DefineType}}, err *lib.Error) { r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) defer r.reset() @@ -126,9 +126,9 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results []{{.Ret } err = r.takeErr() if err != nil { - results = nil + return nil, err } - return + return results, nil } {{end}} From a9d5e06347c8509980351b50cc6795840663b1b9 Mon Sep 17 00:00:00 2001 From: Eloff Date: Sat, 18 Jan 2025 21:45:57 -0700 Subject: [PATCH 17/18] fix codegen --- internal/codegen/golang/templates/sqlite/queryCode.tmpl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 316f478080..7a3033e132 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -117,14 +117,14 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (result {{.Ret.De {{end}} {{if eq .Cmd ":many"}} -func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}, results []{{.Ret.DefineType}}) ([]{{.Ret.DefineType}}, err *lib.Error) { +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}, results []{{.Ret.DefineType}}) ([]{{.Ret.DefineType}}, *lib.Error) { r := &q.{{ .FieldName }} r.bind(q, {{ .Arg.Names }}) defer r.reset() for r.Next() { results = append(results, r.Row) } - err = r.takeErr() + err := r.takeErr() if err != nil { return nil, err } From 9ea542d158083767f6d128ff772b864fc9853da4 Mon Sep 17 00:00:00 2001 From: Eloff Date: Fri, 24 Jan 2025 09:14:57 -0700 Subject: [PATCH 18/18] add support for :set return type --- internal/codegen/golang/query.go | 4 +-- internal/codegen/golang/result.go | 15 ++++++++++- .../golang/templates/sqlite/queryCode.tmpl | 25 +++++++++++++++++++ internal/metadata/meta.go | 3 ++- internal/sql/validate/cmd.go | 2 +- 5 files changed, 44 insertions(+), 5 deletions(-) diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 116f2a31c5..1210f934d0 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -16,7 +16,7 @@ type QueryValue struct { DBName string // The name of the field in the database. Only set if Struct==nil. Struct *Struct Typ string - KeyField *Field // Only set if Struct!=nil && Cmd==metadata.CmdMap + KeyField *Field // Only set if Struct!=nil && Cmd==metadata.CmdMap || Cmd==metadata.CmdSet SQLDriver opts.SQLDriver // Column is kept so late in the generation process around to differentiate @@ -326,7 +326,7 @@ func (q *Query) ReceiverType() string { func (q *Query) hasRetType() bool { scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany || q.Cmd == metadata.CmdBatchMany || q.Cmd == metadata.CmdBatchOne || - q.Cmd == metadata.CmdIter || q.Cmd == metadata.CmdMap + q.Cmd == metadata.CmdIter || q.Cmd == metadata.CmdMap || q.Cmd == metadata.CmdSet return scanned && !q.Ret.isEmpty() } diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index b7aa91dcff..712faeb6ce 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -271,15 +271,27 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] c := query.Columns[0] name := columnName(c, 0) name = strings.Replace(name, "$", "_", -1) + var keyField *Field + if query.Cmd == metadata.CmdSet { + query.Columns = query.Columns[1:] + keyField = &Field{ + Name: StructName(name, options), + DBName: name, + IsKeyField: true, + Type: goType(req, options, c), + Column: c, + } + } gq.Ret = QueryValue{ Name: escape(name), DBName: name, Typ: goType(req, options, c), + KeyField: keyField, SQLDriver: sqlpkg, } } else if putOutColumns(query) { var keyField *Field - if query.Cmd == metadata.CmdMap { + if query.Cmd == metadata.CmdMap || query.Cmd == metadata.CmdSet { c := query.Columns[0] query.Columns = query.Columns[1:] colName := columnName(c, 0) @@ -355,6 +367,7 @@ var cmdReturnsData = map[string]struct{}{ metadata.CmdOne: {}, metadata.CmdIter: {}, metadata.CmdMap: {}, + metadata.CmdSet: {}, } func putOutColumns(query *plugin.Query) bool { diff --git a/internal/codegen/golang/templates/sqlite/queryCode.tmpl b/internal/codegen/golang/templates/sqlite/queryCode.tmpl index 7a3033e132..b69e54d3e5 100644 --- a/internal/codegen/golang/templates/sqlite/queryCode.tmpl +++ b/internal/codegen/golang/templates/sqlite/queryCode.tmpl @@ -139,6 +139,7 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R defer r.reset() if !r.Next() { err = r.takeErr() + results = map[{{.Ret.KeyType}}]{{.Ret.DefineType}}{} return } results = map[{{.Ret.KeyType}}]{{.Ret.DefineType}}{ @@ -155,6 +156,30 @@ func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.R } {{end}} +{{if eq .Cmd ":set"}} +func (q *{{ .ReceiverType }}) {{.MethodName}}({{ .Arg.Pair }}) (results map[{{.Ret.KeyType}}]struct{}, err *lib.Error) { + r := &q.{{ .FieldName }} + r.bind(q, {{ .Arg.Names }}) + defer r.reset() + if !r.Next() { + err = r.takeErr() + results = map[{{.Ret.KeyType}}]struct{}{} + return + } + results = map[{{.Ret.KeyType}}]struct{}{ + r.Row: {}, + } + for r.Next() { + results[r.Row] = struct{}{} + } + err = r.takeErr() + if err != nil { + results = nil + } + return +} +{{end}} + {{if eq .Cmd ":iter"}} func (r *{{ .MethodName }}Stmt) HasNext() bool { return r.iter.Next() diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index c7fa6d969a..909b5c988d 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -34,6 +34,7 @@ const ( CmdBatchOne = ":batchone" CmdIter = ":iter" CmdMap = ":map" + CmdSet = ":set" ) // A query name must be a valid Go identifier @@ -103,7 +104,7 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string queryName := part[2] queryType := strings.TrimSpace(part[3]) switch queryType { - case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne, CmdIter, CmdMap: + case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdExecLastId, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne, CmdIter, CmdMap, CmdSet: default: return "", "", fmt.Errorf("invalid query type: %s", queryType) } diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index 8aca1f9840..6e1a0def4a 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -71,7 +71,7 @@ func Cmd(n ast.Node, name, cmd string) error { return err } } - if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne || cmd == metadata.CmdIter || cmd == metadata.CmdMap) { + if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne || cmd == metadata.CmdBatchMany || cmd == metadata.CmdBatchOne || cmd == metadata.CmdIter || cmd == metadata.CmdMap || cmd == metadata.CmdSet) { return nil } var list *ast.List