From 5e47d5c99fafe3f9e23a62920e4b09574ee6e169 Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Tue, 11 Aug 2020 16:05:18 -0600 Subject: [PATCH 01/15] test go 1.15 --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 3498c53dc..108f99ec0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ language: go go: - 1.13.x - 1.14.x + - 1.15.x - master sudo: true From e8c1c95c16249bd63907445cbbc05ee74d57e5c9 Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Tue, 11 Aug 2020 17:30:55 -0600 Subject: [PATCH 02/15] retry in TestCopyRespLoopConnectionError to prevent flakes --- copy_test.go | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/copy_test.go b/copy_test.go index 2b6d6bca8..852f1be5a 100644 --- a/copy_test.go +++ b/copy_test.go @@ -4,9 +4,11 @@ import ( "bytes" "database/sql" "database/sql/driver" + "fmt" "net" "strings" "testing" + "time" ) func TestCopyInStmt(t *testing.T) { @@ -130,7 +132,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) { _, err = txn.Exec(` CREATE OR REPLACE FUNCTION pg_temp.temptest() - RETURNS trigger AS + RETURNS trigger AS $BODY$ begin raise notice 'Hello world'; return new; @@ -143,7 +145,7 @@ func TestCopyInRaiseStmtTrigger(t *testing.T) { _, err = txn.Exec(` CREATE TRIGGER temptest_trigger BEFORE INSERT - ON temp + ON temp FOR EACH ROW EXECUTE PROCEDURE pg_temp.temptest()`) if err != nil { @@ -406,10 +408,13 @@ func TestCopyRespLoopConnectionError(t *testing.T) { t.Fatal(err) } } - _, err = stmt.Exec() - if err == nil { - t.Fatalf("expected error") - } + retry(t, time.Second*5, func() error { + _, err = stmt.Exec() + if err == nil { + return fmt.Errorf("expected error") + } + return nil + }) switch pge := err.(type) { case *Error: if pge.Code.Name() != "admin_shutdown" { @@ -420,6 +425,8 @@ func TestCopyRespLoopConnectionError(t *testing.T) { default: if err == driver.ErrBadConn { // likely an EPIPE + } else if err == errCopyInClosed { + // ignore } else { t.Fatalf("unexpected error, got %+#v", err) } @@ -428,6 +435,24 @@ func TestCopyRespLoopConnectionError(t *testing.T) { _ = stmt.Close() } +// retry executes f in a backoff loop until it doesn't return an error. If this +// doesn't happen within duration, t.Fatal is called with the latest error. +func retry(t *testing.T, duration time.Duration, f func() error) { + start := time.Now() + next := time.Millisecond * 100 + for { + err := f() + if err == nil { + return + } + if time.Since(start) > duration { + t.Fatal(err) + } + time.Sleep(next) + next *= 2 + } +} + func BenchmarkCopyIn(b *testing.B) { db := openTestConn(b) defer db.Close() From 6d207a9c1857c14aa92e2d0e721a9f9772983ff6 Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Tue, 11 Aug 2020 17:39:33 -0600 Subject: [PATCH 03/15] fix ssl test for go1.15 --- .travis.yml | 1 + ssl_test.go | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 108f99ec0..68e89e88d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,7 @@ env: - PQGOSSLTESTS=1 - PQSSLCERTTEST_PATH=$PWD/certs - PGHOST=127.0.0.1 + - GODEBUG=x509ignoreCN=0 matrix: - PGVERSION=10 - PGVERSION=9.6 diff --git a/ssl_test.go b/ssl_test.go index 3eafbfd20..e00522e76 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -81,7 +81,10 @@ func TestSSLVerifyFull(t *testing.T) { } _, ok := err.(x509.UnknownAuthorityError) if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) + _, ok := err.(x509.HostnameError) + if !ok { + t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err) + } } rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") From d97e962beda231cb266a1d97afd01c007d029470 Mon Sep 17 00:00:00 2001 From: David Hill Date: Sat, 29 Aug 2020 00:29:21 -0400 Subject: [PATCH 04/15] unlock mutex earlier --- copy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/copy.go b/copy.go index 38d5bb693..9d4f850c3 100644 --- a/copy.go +++ b/copy.go @@ -213,10 +213,10 @@ func (ci *copyin) setResult(result driver.Result) { func (ci *copyin) getResult() driver.Result { ci.Lock() result := ci.Result + ci.Unlock() if result == nil { return driver.RowsAffected(0) } - ci.Unlock() return result } From 4c8c217da681ae63bbd7713e3dc975113770fcac Mon Sep 17 00:00:00 2001 From: Roel van der Goot Date: Mon, 12 Oct 2020 18:02:43 -0600 Subject: [PATCH 05/15] Handle empty result sets --- conn.go | 6 +++++- conn_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index f313c1498..02e78f8a0 100644 --- a/conn.go +++ b/conn.go @@ -663,8 +663,12 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' && res.colNames == nil { + if t == 'C' { res.result, res.tag = cn.parseComplete(r.string()) + if res.colNames != nil { + res.tag = "" + return + } } res.done = true case 'Z': diff --git a/conn_test.go b/conn_test.go index 0d25c9554..c900f66e3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1748,6 +1748,36 @@ func TestMultipleResult(t *testing.T) { } } +func TestMultipleEmptyResult(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("select 1 where false; select 2") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + t.Fatal("unexpected row") + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + var i int + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } + } + if rows.NextResultSet() { + t.Fatal("unexpected result set") + } +} + func TestCopyInStmtAffectedRows(t *testing.T) { db := openTestConn(t) defer db.Close() From 999f18fba04bbf793b27b9ec11396264086718d0 Mon Sep 17 00:00:00 2001 From: Roel van der Goot Date: Wed, 14 Oct 2020 07:31:44 -0600 Subject: [PATCH 06/15] Erroneous test? --- conn.go | 1 - conn_test.go | 4 ---- 2 files changed, 5 deletions(-) diff --git a/conn.go b/conn.go index 02e78f8a0..d7a93d6a5 100644 --- a/conn.go +++ b/conn.go @@ -666,7 +666,6 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { if t == 'C' { res.result, res.tag = cn.parseComplete(r.string()) if res.colNames != nil { - res.tag = "" return } } diff --git a/conn_test.go b/conn_test.go index c900f66e3..3b8847055 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1640,10 +1640,6 @@ func TestRowsResultTag(t *testing.T) { { query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1", }, - // Verify that an no-results query doesn't set the tag. - { - query: "CREATE TEMP TABLE t (a int); SELECT 1 WHERE FALSE; DROP TABLE t;", - }, } // If this is the only test run, this will correct the connection string. From 730d555dd5aa0f0f5bd365036f148b5dfce03e79 Mon Sep 17 00:00:00 2001 From: Igor Wiedler Date: Thu, 22 Oct 2020 14:42:20 +0200 Subject: [PATCH 07/15] Correctly implement database/sql/driver.Driver for better wrappability --- conn.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index f313c1498..1b4fc02e5 100644 --- a/conn.go +++ b/conn.go @@ -38,13 +38,18 @@ var ( errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) +// Compile time validation that our types implement the expected interfaces +var ( + _ driver.Driver = Driver{} +) + // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. -func (d *Driver) Open(name string) (driver.Conn, error) { +func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } From 25eb21e556c442fe806c846b3936e6e7fa46f4ef Mon Sep 17 00:00:00 2001 From: Danny Hermes Date: Thu, 19 Nov 2020 12:38:16 -0600 Subject: [PATCH 08/15] Mark `net.Conn` failed writes as recoverable when 0 bytes were written. --- conn.go | 13 ++++++++++++- error.go | 3 +++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index f313c1498..0d96600ed 100644 --- a/conn.go +++ b/conn.go @@ -891,9 +891,20 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } +type safeRetryError struct { + Err error +} + +func (se *safeRetryError) Error() string { + return se.Err.Error() +} + func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) + n, err := cn.c.Write(m.wrap()) if err != nil { + if n == 0 { + err = &safeRetryError{Err: err} + } panic(err) } } diff --git a/error.go b/error.go index 3d66ba7c5..a227e0831 100644 --- a/error.go +++ b/error.go @@ -495,6 +495,9 @@ func (cn *conn) errRecover(err *error) { case *net.OpError: cn.bad = true *err = v + case *safeRetryError: + cn.bad = true + *err = driver.ErrBadConn case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn From f0b8e153ff52e7d43c4934c222c96f5addfafd4e Mon Sep 17 00:00:00 2001 From: Andrew Huang Date: Thu, 1 Oct 2020 13:00:06 -0600 Subject: [PATCH 09/15] Fix query cancellation collateralizing future queries using the same connection --- conn.go | 89 +++++++++++++++++++++++++++++++--------------------- conn_go18.go | 22 +++++++++++-- conn_test.go | 13 +++++--- copy.go | 4 +-- error.go | 8 ++--- go18_test.go | 28 ++++++++++++----- 6 files changed, 108 insertions(+), 56 deletions(-) diff --git a/conn.go b/conn.go index f313c1498..a2b7d48cd 100644 --- a/conn.go +++ b/conn.go @@ -18,6 +18,7 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" "unicode" @@ -136,7 +137,7 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. - bad bool + bad *atomic.Value // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for @@ -294,9 +295,12 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { o := c.opts + bad := &atomic.Value{} + bad.Store(false) cn = &conn{ opts: o, dialer: c.dialer, + bad: bad, } err = cn.handleDriverSettings(o) if err != nil { @@ -501,9 +505,22 @@ func (cn *conn) isInTransaction() bool { cn.txnStatus == txnStatusInFailedTransaction } +func (cn *conn) setBad() { + if cn.bad != nil { + cn.bad.Store(true) + } +} + +func (cn *conn) getBad() bool { + if cn.bad != nil { + return cn.bad.Load().(bool) + } + return false +} + func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { - cn.bad = true + cn.setBad() errorf("unexpected transaction status %v", cn.txnStatus) } } @@ -513,7 +530,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -524,11 +541,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { return nil, err } if commandTag != "BEGIN" { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil @@ -542,7 +559,7 @@ func (cn *conn) closeTxn() { func (cn *conn) Commit() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -564,12 +581,12 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } if commandTag != "COMMIT" { - cn.bad = true + cn.setBad() return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) @@ -578,7 +595,7 @@ func (cn *conn) Commit() (err error) { func (cn *conn) Rollback() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -590,7 +607,7 @@ func (cn *conn) rollback() (err error) { _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } @@ -630,7 +647,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err case 'T', 'D': // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -652,7 +669,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected message %q in simple query execution", t) } if res == nil { @@ -676,7 +693,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { err = parseError(r) case 'D': if res == nil { - cn.bad = true + cn.setBad() errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next @@ -691,7 +708,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -784,7 +801,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -823,7 +840,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } if cn.inCopy { @@ -857,7 +874,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -918,7 +935,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { - cn.bad = true + cn.setBad() errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ @@ -1288,7 +1305,7 @@ func (st *stmt) Close() (err error) { if st.closed { return nil } - if st.cn.bad { + if st.cn.getBad() { return driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1302,14 +1319,14 @@ func (st *stmt) Close() (err error) { t, _ := st.cn.recv1() if t != '3' { - st.cn.bad = true + st.cn.setBad() errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { - st.cn.bad = true + st.cn.setBad() errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) @@ -1318,7 +1335,7 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1331,7 +1348,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1418,7 +1435,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { - cn.bad = true + cn.setBad() errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] @@ -1430,7 +1447,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { - cn.bad = true + cn.setBad() errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag @@ -1497,7 +1514,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } conn := rs.cn - if conn.bad { + if conn.getBad() { return driver.ErrBadConn } defer conn.errRecover(&err) @@ -1522,7 +1539,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'D': n := rs.rb.int16() if err != nil { - conn.bad = true + conn.setBad() errorf("unexpected DataRow after error %s", err) } if n < len(dest) { @@ -1717,7 +1734,7 @@ func (cn *conn) readReadyForQuery() { cn.processReadyForQuery(r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message %q; expected ReadyForQuery", t) } } @@ -1737,7 +1754,7 @@ func (cn *conn) readParseResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Parse response %q", t) } } @@ -1762,7 +1779,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe statement response %q", t) } } @@ -1780,7 +1797,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe response %q", t) } panic("not reached") @@ -1796,7 +1813,7 @@ func (cn *conn) readBindResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Bind response %q", t) } } @@ -1823,7 +1840,7 @@ func (cn *conn) postExecuteWorkaround() { cn.saveMessage(t, r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message during extended query execution: %q", t) } } @@ -1836,7 +1853,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co switch t { case 'C': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) @@ -1850,7 +1867,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co err = parseError(r) case 'T', 'D', 'I': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected %q after error %s", t, err) } if t == 'I' { @@ -1858,7 +1875,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown %s response: %q", protocolState, t) } } diff --git a/conn_go18.go b/conn_go18.go index 09e2ea464..8cab67c9d 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "sync/atomic" "time" ) @@ -89,10 +90,21 @@ func (cn *conn) Ping(ctx context.Context) error { func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { - finished := make(chan struct{}) + finished := make(chan struct{}, 1) go func() { select { case <-done: + select { + case finished <- struct{}{}: + default: + // We raced with the finish func, let the next query handle this with the + // context. + return + } + + // Set the connection state to bad so it does not get reused. + cn.setBad() + // At this point the function level context is canceled, // so it must not be used for the additional network // request to cancel the query. @@ -101,13 +113,14 @@ func (cn *conn) watchCancel(ctx context.Context) func() { defer cancel() _ = cn.cancel(ctxCancel) - finished <- struct{}{} case <-finished: } }() return func() { select { case <-finished: + cn.setBad() + cn.Close() case finished <- struct{}{}: } } @@ -123,8 +136,11 @@ func (cn *conn) cancel(ctx context.Context) error { defer c.Close() { + bad := &atomic.Value{} + bad.Store(false) can := conn{ - c: c, + c: c, + bad: bad, } err = can.ssl(cn.opts) if err != nil { diff --git a/conn_test.go b/conn_test.go index 0d25c9554..f50278525 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,6 +10,7 @@ import ( "os" "reflect" "strings" + "sync/atomic" "testing" "time" ) @@ -695,7 +696,9 @@ func TestErrorDuringStartupClosesConn(t *testing.T) { func TestBadConn(t *testing.T) { var err error - cn := conn{} + bad := &atomic.Value{} + bad.Store(false) + cn := conn{bad: bad} func() { defer cn.errRecover(&err) panic(io.EOF) @@ -703,11 +706,13 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.bad { + if !cn.getBad() { t.Fatalf("expected cn.bad") } - cn = conn{} + badd := &atomic.Value{} + badd.Store(false) + cn = conn{bad: badd} func() { defer cn.errRecover(&err) e := &Error{Severity: Efatal} @@ -716,7 +721,7 @@ func TestBadConn(t *testing.T) { if err != driver.ErrBadConn { t.Fatalf("expected driver.ErrBadConn, got: %#v", err) } - if !cn.bad { + if !cn.getBad() { t.Fatalf("expected cn.bad") } } diff --git a/copy.go b/copy.go index 9d4f850c3..bb3cbd7b9 100644 --- a/copy.go +++ b/copy.go @@ -176,13 +176,13 @@ func (ci *copyin) resploop() { func (ci *copyin) setBad() { ci.Lock() - ci.cn.bad = true + ci.cn.setBad() ci.Unlock() } func (ci *copyin) isBad() bool { ci.Lock() - b := ci.cn.bad + b := ci.cn.getBad() ci.Unlock() return b } diff --git a/error.go b/error.go index 3d66ba7c5..cb8ef5dc7 100644 --- a/error.go +++ b/error.go @@ -484,7 +484,7 @@ func (cn *conn) errRecover(err *error) { case nil: // Do nothing case runtime.Error: - cn.bad = true + cn.setBad() panic(v) case *Error: if v.Fatal() { @@ -493,7 +493,7 @@ func (cn *conn) errRecover(err *error) { *err = v } case *net.OpError: - cn.bad = true + cn.setBad() *err = v case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { @@ -503,13 +503,13 @@ func (cn *conn) errRecover(err *error) { } default: - cn.bad = true + cn.setBad() panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - cn.bad = true + cn.setBad() } } diff --git a/go18_test.go b/go18_test.go index 72cd71fe9..95c08cd79 100644 --- a/go18_test.go +++ b/go18_test.go @@ -3,6 +3,7 @@ package pq import ( "context" "database/sql" + "database/sql/driver" "runtime" "strings" "testing" @@ -142,7 +143,7 @@ func TestContextCancelQuery(t *testing.T) { cancel() if err != nil { t.Fatal(err) - } else if err := rows.Close(); err != nil { + } else if err := rows.Close(); err != nil && err != driver.ErrBadConn { t.Fatal(err) } }() @@ -176,13 +177,26 @@ func TestIssue617(t *testing.T) { } }() } - numGoroutineFinish := runtime.NumGoroutine() - // We use N/2 and not N because the GC and other actors may increase or - // decrease the number of goroutines. - if numGoroutineFinish-numGoroutineStart >= N/2 { - t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) + // Give time for goroutines to terminate + delayTime := time.Millisecond * 50 + waitTime := time.Second + iterations := int(waitTime / delayTime) + + var numGoroutineFinish int + for i := 0; i < iterations; i++ { + time.Sleep(delayTime) + + numGoroutineFinish = runtime.NumGoroutine() + + // We use N/2 and not N because the GC and other actors may increase or + // decrease the number of goroutines. + if numGoroutineFinish-numGoroutineStart < N/2 { + return + } } + + t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) } func TestContextCancelBegin(t *testing.T) { @@ -228,7 +242,7 @@ func TestContextCancelBegin(t *testing.T) { t.Fatal(err) } else if err := tx.Rollback(); err != nil && err.Error() != "pq: canceling statement due to user request" && - err != sql.ErrTxDone { + err != sql.ErrTxDone && err != driver.ErrBadConn { t.Fatal(err) } }() From 01786cbdd08acd6334ad9746fbb633d60e7b3ab5 Mon Sep 17 00:00:00 2001 From: Matt Jibson Date: Mon, 23 Nov 2020 13:35:56 -0700 Subject: [PATCH 10/15] Remove go 1.13 tests Cockroach is now in 1.15 so we can bump the version. --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 68e89e88d..f378207f2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: go go: - - 1.13.x - 1.14.x - 1.15.x - master From f373854cec99bd835c8fd49045dbe7b44479b758 Mon Sep 17 00:00:00 2001 From: splicefracture Date: Tue, 24 Nov 2020 11:52:09 -0800 Subject: [PATCH 11/15] Update error.go Missed one. --- error.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/error.go b/error.go index ee1ca6df1..c19c349f1 100644 --- a/error.go +++ b/error.go @@ -496,7 +496,7 @@ func (cn *conn) errRecover(err *error) { cn.setBad() *err = v case *safeRetryError: - cn.bad = true + cn.setBad() *err = driver.ErrBadConn case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { From d726827bf762cdd26468d69f9e3f7b4d2ee4df90 Mon Sep 17 00:00:00 2001 From: Shivam Rathore Date: Thu, 20 Jun 2019 10:07:06 +0530 Subject: [PATCH 12/15] feature: inbuilt support for scanner and valuer in pq.Array for int32/float32/[]byte slices --- .gitignore | 2 + array.go | 139 ++++++++++++++++++++ array_test.go | 341 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 482 insertions(+) diff --git a/.gitignore b/.gitignore index 0f1d00e11..3243952a4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ *.test *~ *.swp +.idea +.vscode \ No newline at end of file diff --git a/array.go b/array.go index e4933e227..405da2368 100644 --- a/array.go +++ b/array.go @@ -35,19 +35,31 @@ func Array(a interface{}) interface { return (*BoolArray)(&a) case []float64: return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) case []int64: return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) case []string: return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) case *[]bool: return (*BoolArray)(a) case *[]float64: return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) case *[]int64: return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) case *[]string: return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) } return GenericArray{a} @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. +func (a *Float32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + var x float64 + if x, err = strconv.ParseFloat(string(v), 32); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) { return "{}", nil } +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. +func (a *Int32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + var x int + if x, err = strconv.Atoi(string(v)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string diff --git a/array_test.go b/array_test.go index f724bcd88..5ca9f7a55 100644 --- a/array_test.go +++ b/array_test.go @@ -104,16 +104,33 @@ func TestArrayScanner(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", s) } + s = Array(&[]float32{}) + if _, ok := s.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", s) + } + + s = Array(&[]int32{}) + if _, ok := s.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", s) + } + s = Array(&[]string{}) if _, ok := s.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", s) } + s = Array(&[][]byte{}) + if _, ok := s.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", s) + } + for _, tt := range []interface{}{ &[]sql.Scanner{}, &[][]bool{}, &[][]float64{}, &[][]int64{}, + &[][]float32{}, + &[][]int32{}, &[][]string{}, } { s = Array(tt) @@ -139,17 +156,34 @@ func TestArrayValuer(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", v) } + v = Array([]float32{}) + if _, ok := v.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", v) + } + + v = Array([]int32{}) + if _, ok := v.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", v) + } + v = Array([]string{}) if _, ok := v.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", v) } + v = Array([][]byte{}) + if _, ok := v.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", v) + } + for _, tt := range []interface{}{ nil, []driver.Value{}, [][]bool{}, [][]float64{}, [][]int64{}, + [][]float32{}, + [][]int32{}, [][]string{}, } { v = Array(tt) @@ -773,6 +807,313 @@ func BenchmarkInt64ArrayValue(b *testing.B) { } } +func TestFloat32ArrayScanUnsupported(t *testing.T) { + var arr Float32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestFloat32ArrayScanEmpty(t *testing.T) { + var arr Float32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestFloat32ArrayScanNil(t *testing.T) { + arr := Float32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Float32ArrayStringTests = []struct { + str string + arr Float32Array +}{ + {`{}`, Float32Array{}}, + {`{1.2}`, Float32Array{1.2}}, + {`{3.456,7.89}`, Float32Array{3.456, 7.89}}, + {`{3,1,2}`, Float32Array{3, 1, 2}}, +} + +func TestFloat32ArrayScanBytes(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + bytes := []byte(tt.str) + arr := Float32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat32ArrayScanBytes(b *testing.B) { + var a Float32Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float32Array{} + a.Scan(x) + } +} + +func TestFloat32ArrayScanString(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat32ArrayValue(t *testing.T) { + result, err := Float32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float32Array([]float32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float32Array([]float32{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Float32() + } + a := Float32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt32ArrayScanUnsupported(t *testing.T) { + var arr Int32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestInt32ArrayScanEmpty(t *testing.T) { + var arr Int32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestInt32ArrayScanNil(t *testing.T) { + arr := Int32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Int32ArrayStringTests = []struct { + str string + arr Int32Array +}{ + {`{}`, Int32Array{}}, + {`{12}`, Int32Array{12}}, + {`{345,678}`, Int32Array{345, 678}}, +} + +func TestInt32ArrayScanBytes(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + bytes := []byte(tt.str) + arr := Int32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt32ArrayScanBytes(b *testing.B) { + var a Int32Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int32Array{} + a.Scan(x) + } +} + +func TestInt32ArrayScanString(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt32ArrayValue(t *testing.T) { + result, err := Int32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int32Array([]int32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int32Array([]int32{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int31() + } + a := Int32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + func TestStringArrayScanUnsupported(t *testing.T) { var arr StringArray err := arr.Scan(true) From b9bb726ebf154627a21b50f9ffa4b28c6ed3f4d8 Mon Sep 17 00:00:00 2001 From: Eirik Date: Sun, 16 Dec 2018 22:34:36 +0100 Subject: [PATCH 13/15] Support inline SSL certificates Presently, pq only supports SSL connections by loading PEM certificates from files on disk. There are some situations (for example integration with HashiCorp Vault) where it's not so feasible to load certificates from a file system, but better to store them in-memory. This patch lets you set ?sslinline=true in the connection string, which changes the behavior of the paramters sslrootcert, sslcert and sslkey, so they contain the contents of the certificates directly, instead of file names pointing to the certificates on disk. --- ssl.go | 33 ++++++++++++++++++++++++++++++--- url.go | 4 ++-- url_test.go | 6 +++--- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/ssl.go b/ssl.go index d90208455..5db5af83a 100644 --- a/ssl.go +++ b/ssl.go @@ -59,6 +59,9 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { return nil, err } + // This pseudo-parameter is not recognized by the PostgreSQL server, so let's delete it after use. + delete(o, "sslinline") + // Accept renegotiation requests initiated by the backend. // // Renegotiation was deprecated then removed from PostgreSQL 9.5, but @@ -83,6 +86,20 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // in the user's home directory. The configured files must exist and have // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) error { + sslinline := o["sslinline"] + if "true" == sslinline { + cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) + // Clear out these params, in case they were to be sent to the PostgreSQL server by mistake + o["sslcert"] = "" + o["sslkey"] = "" + if err != nil { + return err + } + tlsConf.Certificates = []tls.Certificate{cert} + return nil + } + + // user.Current() might fail when cross-compiling. We have to ignore the // error and continue without home directory defaults, since we wouldn't // know from where to load them. @@ -137,9 +154,19 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) error { if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - return err + sslinline := o["sslinline"] + + var cert []byte + if "true" == sslinline { + // // Clear out this param, in case it were to be sent to the PostgreSQL server by mistake + o["sslrootcert"] = "" + cert = []byte(sslrootcert) + } else { + var err error + cert, err = ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } } if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { diff --git a/url.go b/url.go index f4d8a7c20..aec6e95be 100644 --- a/url.go +++ b/url.go @@ -40,10 +40,10 @@ func ParseURL(url string) (string, error) { } var kvs []string - escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`) accrue := func(k, v string) { if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) + kvs = append(kvs, k+"='"+escaper.Replace(v)+"'") } } diff --git a/url_test.go b/url_test.go index 4ff0ce034..9b6345595 100644 --- a/url_test.go +++ b/url_test.go @@ -5,7 +5,7 @@ import ( ) func TestSimpleParseURL(t *testing.T) { - expected := "host=hostname.remote" + expected := "host='hostname.remote'" str, err := ParseURL("postgres://hostname.remote") if err != nil { t.Fatal(err) @@ -17,7 +17,7 @@ func TestSimpleParseURL(t *testing.T) { } func TestIPv6LoopbackParseURL(t *testing.T) { - expected := "host=::1 port=1234" + expected := "host='::1' port='1234'" str, err := ParseURL("postgres://[::1]:1234") if err != nil { t.Fatal(err) @@ -29,7 +29,7 @@ func TestIPv6LoopbackParseURL(t *testing.T) { } func TestFullParseURL(t *testing.T) { - expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` + expected := `dbname='database' host='hostname.remote' password='top secret' port='1234' user='username'` str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") if err != nil { t.Fatal(err) From b7c85eeec276e594e29cd6ceec5e007965ff0d51 Mon Sep 17 00:00:00 2001 From: Eirik Date: Sun, 16 Dec 2018 23:22:54 +0100 Subject: [PATCH 14/15] remove empty line --- ssl.go | 1 - 1 file changed, 1 deletion(-) diff --git a/ssl.go b/ssl.go index 5db5af83a..3c1f1fc37 100644 --- a/ssl.go +++ b/ssl.go @@ -99,7 +99,6 @@ func sslClientCertificates(tlsConf *tls.Config, o values) error { return nil } - // user.Current() might fail when cross-compiling. We have to ignore the // error and continue without home directory defaults, since we wouldn't // know from where to load them. From 1467baf0954574fc17008b8c564938ffe3bfaa95 Mon Sep 17 00:00:00 2001 From: Eirik Sletteberg Date: Fri, 12 Feb 2021 12:47:39 +0100 Subject: [PATCH 15/15] Fix code style (no yoda conditions) --- ssl.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ssl.go b/ssl.go index 3c1f1fc37..881c2219b 100644 --- a/ssl.go +++ b/ssl.go @@ -87,7 +87,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) error { sslinline := o["sslinline"] - if "true" == sslinline { + if sslinline == "true" { cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) // Clear out these params, in case they were to be sent to the PostgreSQL server by mistake o["sslcert"] = "" @@ -156,7 +156,7 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) error { sslinline := o["sslinline"] var cert []byte - if "true" == sslinline { + if sslinline == "true" { // // Clear out this param, in case it were to be sent to the PostgreSQL server by mistake o["sslrootcert"] = "" cert = []byte(sslrootcert)