diff --git a/.gitignore b/.gitignore index 0f1d00e11..3243952a4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ *.test *~ *.swp +.idea +.vscode \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 3498c53dc..f378207f2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.13.x - 1.14.x + - 1.15.x - master sudo: true @@ -13,6 +13,7 @@ env: - PQGOSSLTESTS=1 - PQSSLCERTTEST_PATH=$PWD/certs - PGHOST=127.0.0.1 + - GODEBUG=x509ignoreCN=0 matrix: - PGVERSION=10 - PGVERSION=9.6 diff --git a/array.go b/array.go index e4933e227..405da2368 100644 --- a/array.go +++ b/array.go @@ -35,19 +35,31 @@ func Array(a interface{}) interface { return (*BoolArray)(&a) case []float64: return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) case []int64: return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) case []string: return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) case *[]bool: return (*BoolArray)(a) case *[]float64: return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) case *[]int64: return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) case *[]string: return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) } return GenericArray{a} @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. +func (a *Float32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + var x float64 + if x, err = strconv.ParseFloat(string(v), 32); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) { return "{}", nil } +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. +func (a *Int32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + var x int + if x, err = strconv.Atoi(string(v)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string diff --git a/array_test.go b/array_test.go index f724bcd88..5ca9f7a55 100644 --- a/array_test.go +++ b/array_test.go @@ -104,16 +104,33 @@ func TestArrayScanner(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", s) } + s = Array(&[]float32{}) + if _, ok := s.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", s) + } + + s = Array(&[]int32{}) + if _, ok := s.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", s) + } + s = Array(&[]string{}) if _, ok := s.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", s) } + s = Array(&[][]byte{}) + if _, ok := s.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", s) + } + for _, tt := range []interface{}{ &[]sql.Scanner{}, &[][]bool{}, &[][]float64{}, &[][]int64{}, + &[][]float32{}, + &[][]int32{}, &[][]string{}, } { s = Array(tt) @@ -139,17 +156,34 @@ func TestArrayValuer(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", v) } + v = Array([]float32{}) + if _, ok := v.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", v) + } + + v = Array([]int32{}) + if _, ok := v.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", v) + } + v = Array([]string{}) if _, ok := v.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", v) } + v = Array([][]byte{}) + if _, ok := v.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", v) + } + for _, tt := range []interface{}{ nil, []driver.Value{}, [][]bool{}, [][]float64{}, [][]int64{}, + [][]float32{}, + [][]int32{}, [][]string{}, } { v = Array(tt) @@ -773,6 +807,313 @@ func BenchmarkInt64ArrayValue(b *testing.B) { } } +func TestFloat32ArrayScanUnsupported(t *testing.T) { + var arr Float32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestFloat32ArrayScanEmpty(t *testing.T) { + var arr Float32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestFloat32ArrayScanNil(t *testing.T) { + arr := Float32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Float32ArrayStringTests = []struct { + str string + arr Float32Array +}{ + {`{}`, Float32Array{}}, + {`{1.2}`, Float32Array{1.2}}, + {`{3.456,7.89}`, Float32Array{3.456, 7.89}}, + {`{3,1,2}`, Float32Array{3, 1, 2}}, +} + +func TestFloat32ArrayScanBytes(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + bytes := []byte(tt.str) + arr := Float32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat32ArrayScanBytes(b *testing.B) { + var a Float32Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float32Array{} + a.Scan(x) + } +} + +func TestFloat32ArrayScanString(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat32ArrayValue(t *testing.T) { + result, err := Float32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float32Array([]float32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float32Array([]float32{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Float32() + } + a := Float32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt32ArrayScanUnsupported(t *testing.T) { + var arr Int32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestInt32ArrayScanEmpty(t *testing.T) { + var arr Int32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestInt32ArrayScanNil(t *testing.T) { + arr := Int32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Int32ArrayStringTests = []struct { + str string + arr Int32Array +}{ + {`{}`, Int32Array{}}, + {`{12}`, Int32Array{12}}, + {`{345,678}`, Int32Array{345, 678}}, +} + +func TestInt32ArrayScanBytes(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + bytes := []byte(tt.str) + arr := Int32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt32ArrayScanBytes(b *testing.B) { + var a Int32Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int32Array{} + a.Scan(x) + } +} + +func TestInt32ArrayScanString(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt32ArrayValue(t *testing.T) { + result, err := Int32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int32Array([]int32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int32Array([]int32{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int31() + } + a := Int32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + func TestStringArrayScanUnsupported(t *testing.T) { var arr StringArray err := arr.Scan(true) diff --git a/conn.go b/conn.go index f313c1498..db0b6cef5 100644 --- a/conn.go +++ b/conn.go @@ -18,6 +18,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" "unicode" @@ -38,13 +39,18 @@ var ( errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) +// Compile time validation that our types implement the expected interfaces +var ( + _ driver.Driver = Driver{} +) + // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. -func (d *Driver) Open(name string) (driver.Conn, error) { +func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } @@ -136,7 +142,7 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. - bad bool + bad *atomic.Value // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for @@ -294,9 +300,12 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { o := c.opts + bad := &atomic.Value{} + bad.Store(false) cn = &conn{ opts: o, dialer: c.dialer, + bad: bad, } err = cn.handleDriverSettings(o) if err != nil { @@ -501,9 +510,22 @@ func (cn *conn) isInTransaction() bool { cn.txnStatus == txnStatusInFailedTransaction } +func (cn *conn) setBad() { + if cn.bad != nil { + cn.bad.Store(true) + } +} + +func (cn *conn) getBad() bool { + if cn.bad != nil { + return cn.bad.Load().(bool) + } + return false +} + func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { - cn.bad = true + cn.setBad() errorf("unexpected transaction status %v", cn.txnStatus) } } @@ -513,7 +535,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -524,11 +546,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { return nil, err } if commandTag != "BEGIN" { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil @@ -542,7 +564,7 @@ func (cn *conn) closeTxn() { func (cn *conn) Commit() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -564,12 +586,12 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } if commandTag != "COMMIT" { - cn.bad = true + cn.setBad() return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) @@ -578,7 +600,7 @@ func (cn *conn) Commit() (err error) { func (cn *conn) Rollback() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -590,7 +612,7 @@ func (cn *conn) rollback() (err error) { _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } @@ -630,7 +652,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err case 'T', 'D': // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -652,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected message %q in simple query execution", t) } if res == nil { @@ -663,8 +685,11 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' && res.colNames == nil { + if t == 'C' { res.result, res.tag = cn.parseComplete(r.string()) + if res.colNames != nil { + return + } } res.done = true case 'Z': @@ -676,7 +701,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { err = parseError(r) case 'D': if res == nil { - cn.bad = true + cn.setBad() errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next @@ -691,7 +716,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -784,7 +809,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -823,7 +848,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } if cn.inCopy { @@ -857,7 +882,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -891,9 +916,20 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } +type safeRetryError struct { + Err error +} + +func (se *safeRetryError) Error() string { + return se.Err.Error() +} + func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) + n, err := cn.c.Write(m.wrap()) if err != nil { + if n == 0 { + err = &safeRetryError{Err: err} + } panic(err) } } @@ -918,7 +954,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { - cn.bad = true + cn.setBad() errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ @@ -1288,7 +1324,7 @@ func (st *stmt) Close() (err error) { if st.closed { return nil } - if st.cn.bad { + if st.cn.getBad() { return driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1302,14 +1338,14 @@ func (st *stmt) Close() (err error) { t, _ := st.cn.recv1() if t != '3' { - st.cn.bad = true + st.cn.setBad() errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { - st.cn.bad = true + st.cn.setBad() errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) @@ -1318,7 +1354,7 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1331,7 +1367,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1418,7 +1454,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { - cn.bad = true + cn.setBad() errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] @@ -1430,7 +1466,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { - cn.bad = true + cn.setBad() errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag @@ -1497,7 +1533,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } conn := rs.cn - if conn.bad { + if conn.getBad() { return driver.ErrBadConn } defer conn.errRecover(&err) @@ -1522,7 +1558,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'D': n := rs.rb.int16() if err != nil { - conn.bad = true + conn.setBad() errorf("unexpected DataRow after error %s", err) } if n < len(dest) { @@ -1717,7 +1753,7 @@ func (cn *conn) readReadyForQuery() { cn.processReadyForQuery(r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message %q; expected ReadyForQuery", t) } } @@ -1737,7 +1773,7 @@ func (cn *conn) readParseResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Parse response %q", t) } } @@ -1762,7 +1798,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe statement response %q", t) } } @@ -1780,7 +1816,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe response %q", t) } panic("not reached") @@ -1796,7 +1832,7 @@ func (cn *conn) readBindResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Bind response %q", t) } } @@ -1823,7 +1859,7 @@ func (cn *conn) postExecuteWorkaround() { cn.saveMessage(t, r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message during extended query execution: %q", t) } } @@ -1836,7 +1872,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co switch t { case 'C': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) @@ -1850,7 +1886,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co err = parseError(r) case 'T', 'D', 'I': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected %q after error %s", t, err) } if t == 'I' { @@ -1858,7 +1894,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown %s response: %q", protocolState, t) } } diff --git a/conn_go18.go b/conn_go18.go index 09e2ea464..8cab67c9d 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "sync/atomic" "time" ) @@ -89,10 +90,21 @@ func (cn *conn) Ping(ctx context.Context) error { func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { - finished := make(chan struct{}) + finished := make(chan struct{}, 1) go func() { select { case <-done: + select { + case finished <- struct{}{}: + default: + // We raced with the finish func, let the next query handle this with the + // context. + return + } + + // Set the connection state to bad so it does not get reused. + cn.setBad() + // At this point the function level context is canceled, // so it must not be used for the additional network // request to cancel the query. @@ -101,13 +113,14 @@ func (cn *conn) watchCancel(ctx context.Context) func() { defer cancel() _ = cn.cancel(ctxCancel) - finished <- struct{}{} case <-finished: } }() return func() { select { case <-finished: + cn.setBad() + cn.Close() case finished <- struct{}{}: } } @@ -123,8 +136,11 @@ func (cn *conn) cancel(ctx context.Context) error { defer c.Close() { + bad := &atomic.Value{} + bad.Store(false) can := conn{ - c: c, + c: c, + bad: bad, } err = can.ssl(cn.opts) if err != nil { diff --git a/conn_test.go b/conn_test.go index 0d25c9554..a05d81d0b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,6 +10,7 @@ import ( "os" "reflect" "strings" + "sync/atomic" "testing" "time" ) @@ -695,7 +696,9 @@ func TestErrorDuringStartupClosesConn(t *testing.T) { func TestBadConn(t *testing.T) { var err error - cn := conn{} + bad := &atomic.Value{} + bad.Store(false) + cn := conn{bad: bad} func() { defer cn.errRecover(&err) panic(io.EOF) @@ -703,11 +706,13 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.bad { + if !cn.getBad() { t.Fatalf("expected cn.bad") } - cn = conn{} + badd := &atomic.Value{} + badd.Store(false) + cn = conn{bad: badd} func() { defer cn.errRecover(&err) e := &Error{Severity: Efatal} @@ -716,7 +721,7 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.bad { + if !cn.getBad() { t.Fatalf("expected cn.bad") } } @@ -1640,10 +1645,6 @@ func TestRowsResultTag(t *testing.T) { { query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1", }, - // Verify that an no-results query doesn't set the tag. - { - query: "CREATE TEMP TABLE t (a int); SELECT 1 WHERE FALSE; DROP TABLE t;", - }, } // If this is the only test run, this will correct the connection string. @@ -1748,6 +1749,36 @@ func TestMultipleResult(t *testing.T) { } } +func TestMultipleEmptyResult(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("select 1 where false; select 2") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + t.Fatal("unexpected row") + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + var i int + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } + } + if rows.NextResultSet() { + t.Fatal("unexpected result set") + } +} + func TestCopyInStmtAffectedRows(t *testing.T) { db := openTestConn(t) defer db.Close() diff --git a/copy.go b/copy.go index 38d5bb693..bb3cbd7b9 100644 --- a/copy.go +++ b/copy.go @@ -176,13 +176,13 @@ func (ci *copyin) resploop() { func (ci *copyin) setBad() { ci.Lock() - ci.cn.bad = true + ci.cn.setBad() ci.Unlock() } func (ci *copyin) isBad() bool { ci.Lock() - b := ci.cn.bad + b := ci.cn.getBad() ci.Unlock() return b } @@ -213,10 +213,10 @@ func (ci *copyin) setResult(result driver.Result) { func (ci *copyin) getResult() driver.Result { ci.Lock() result := ci.Result + ci.Unlock() if result == nil { return driver.RowsAffected(0) } - ci.Unlock() return result } diff --git a/copy_test.go b/copy_test.go index 2b6d6bca8..852f1be5a 100644 --- a/copy_test.go +++ b/copy_test.go @@ -4,9 +4,11 @@ import ( "bytes" "database/sql" "database/sql/driver" + "fmt" "net" "strings" "testing" + "time" ) func TestCopyInStmt(t *testing.T) { @@ -130,7 +132,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) { _, err = txn.Exec(` CREATE OR REPLACE FUNCTION pg_temp.temptest() - RETURNS trigger AS + RETURNS trigger AS $BODY$ begin raise notice 'Hello world'; return new; @@ -143,7 +145,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) { _, err = txn.Exec(` CREATE TRIGGER temptest_trigger BEFORE INSERT - ON temp + ON temp FOR EACH ROW EXECUTE PROCEDURE pg_temp.temptest()`) if err != nil { @@ -406,10 +408,13 @@ func TestCopyRespLoopConnectionError(t *testing.T) { t.Fatal(err) } } - _, err = stmt.Exec() - if err == nil { - t.Fatalf("expected error") - } + retry(t, time.Second*5, func() error { + _, err = stmt.Exec() + if err == nil { + return fmt.Errorf("expected error") + } + return nil + }) switch pge := err.(type) { case *Error: if pge.Code.Name() != "admin_shutdown" { @@ -420,6 +425,8 @@ func TestCopyRespLoopConnectionError(t *testing.T) { default: if err == driver.ErrBadConn { // likely an EPIPE + } else if err == errCopyInClosed { + // ignore } else { t.Fatalf("unexpected error, got %+#v", err) } @@ -428,6 +435,24 @@ func TestCopyRespLoopConnectionError(t *testing.T) { _ = stmt.Close() } +// retry executes f in a backoff loop until it doesn't return an error. If this +// doesn't happen within duration, t.Fatal is called with the latest error. +func retry(t *testing.T, duration time.Duration, f func() error) { + start := time.Now() + next := time.Millisecond * 100 + for { + err := f() + if err == nil { + return + } + if time.Since(start) > duration { + t.Fatal(err) + } + time.Sleep(next) + next *= 2 + } +} + func BenchmarkCopyIn(b *testing.B) { db := openTestConn(b) defer db.Close() diff --git a/error.go b/error.go index 3d66ba7c5..c19c349f1 100644 --- a/error.go +++ b/error.go @@ -484,7 +484,7 @@ func (cn *conn) errRecover(err *error) { case nil: // Do nothing case runtime.Error: - cn.bad = true + cn.setBad() panic(v) case *Error: if v.Fatal() { @@ -493,8 +493,11 @@ func (cn *conn) errRecover(err *error) { *err = v } case *net.OpError: - cn.bad = true + cn.setBad() *err = v + case *safeRetryError: + cn.setBad() + *err = driver.ErrBadConn case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn @@ -503,13 +506,13 @@ func (cn *conn) errRecover(err *error) { } default: - cn.bad = true + cn.setBad() panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - cn.bad = true + cn.setBad() } } diff --git a/go18_test.go b/go18_test.go index 72cd71fe9..95c08cd79 100644 --- a/go18_test.go +++ b/go18_test.go @@ -3,6 +3,7 @@ package pq import ( "context" "database/sql" + "database/sql/driver" "runtime" "strings" "testing" @@ -142,7 +143,7 @@ func TestContextCancelQuery(t *testing.T) { cancel() if err != nil { t.Fatal(err) - } else if err := rows.Close(); err != nil { + } else if err := rows.Close(); err != nil && err != driver.ErrBadConn { t.Fatal(err) } }() @@ -176,13 +177,26 @@ func TestIssue617(t *testing.T) { } }() } - numGoroutineFinish := runtime.NumGoroutine() - // We use N/2 and not N because the GC and other actors may increase or - // decrease the number of goroutines. - if numGoroutineFinish-numGoroutineStart >= N/2 { - t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) + // Give time for goroutines to terminate + delayTime := time.Millisecond * 50 + waitTime := time.Second + iterations := int(waitTime / delayTime) + + var numGoroutineFinish int + for i := 0; i < iterations; i++ { + time.Sleep(delayTime) + + numGoroutineFinish = runtime.NumGoroutine() + + // We use N/2 and not N because the GC and other actors may increase or + // decrease the number of goroutines. + if numGoroutineFinish-numGoroutineStart < N/2 { + return + } } + + t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) } func TestContextCancelBegin(t *testing.T) { @@ -228,7 +242,7 @@ func TestContextCancelBegin(t *testing.T) { t.Fatal(err) } else if err := tx.Rollback(); err != nil && err.Error() != "pq: canceling statement due to user request" && - err != sql.ErrTxDone { + err != sql.ErrTxDone && err != driver.ErrBadConn { t.Fatal(err) } }() diff --git a/ssl.go b/ssl.go index d90208455..881c2219b 100644 --- a/ssl.go +++ b/ssl.go @@ -59,6 +59,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { return nil, err } + // This pseudo-parameter is not recognized by the PostgreSQL server, so let's delete it after use. + delete(o, "sslinline") + // Accept renegotiation requests initiated by the backend. // // Renegotiation was deprecated then removed from PostgreSQL 9.5, but @@ -83,6 +86,19 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // in the user's home directory. The configured files must exist and have // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) error { + sslinline := o["sslinline"] + if sslinline == "true" { + cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) + // Clear out these params, in case they were to be sent to the PostgreSQL server by mistake + o["sslcert"] = "" + o["sslkey"] = "" + if err != nil { + return err + } + tlsConf.Certificates = []tls.Certificate{cert} + return nil + } + // user.Current() might fail when cross-compiling. We have to ignore the // error and continue without home directory defaults, since we wouldn't // know from where to load them. @@ -137,9 +153,19 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) error { if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - return err + sslinline := o["sslinline"] + + var cert []byte + if sslinline == "true" { + // // Clear out this param, in case it were to be sent to the PostgreSQL server by mistake + o["sslrootcert"] = "" + cert = []byte(sslrootcert) + } else { + var err error + cert, err = ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } } if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { diff --git a/ssl_test.go b/ssl_test.go index 3eafbfd20..e00522e76 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -81,7 +81,10 @@ func TestSSLVerifyFull(t *testing.T) { } _, ok := err.(x509.UnknownAuthorityError) if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) + _, ok := err.(x509.HostnameError) + if !ok { + t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err) + } } rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") diff --git a/url.go b/url.go index f4d8a7c20..aec6e95be 100644 --- a/url.go +++ b/url.go @@ -40,10 +40,10 @@ func ParseURL(url string) (string, error) { } var kvs []string - escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) accrue := func(k, v string) { if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") } } diff --git a/url_test.go b/url_test.go index 4ff0ce034..9b6345595 100644 --- a/url_test.go +++ b/url_test.go @@ -5,7 +5,7 @@ import ( ) func TestSimpleParseURL(t *testing.T) { - expected := "host=hostname.remote" + expected := "host='hostname.remote'" str, err := ParseURL("postgres://hostname.remote") if err != nil { t.Fatal(err) @@ -17,7 +17,7 @@ func TestSimpleParseURL(t *testing.T) { } func TestIPv6LoopbackParseURL(t *testing.T) { - expected := "host=::1 port=1234" + expected := "host='::1' port='1234'" str, err := ParseURL("postgres://[::1]:1234") if err != nil { t.Fatal(err) @@ -29,7 +29,7 @@ func TestIPv6LoopbackParseURL(t *testing.T) { } func TestFullParseURL(t *testing.T) { - expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` + expected := `dbname='database' host='hostname.remote' password='top secret' port='1234' user='username'` str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") if err != nil { t.Fatal(err)