Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 53 additions & 36 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"time"
"unicode"

Expand Down Expand Up @@ -136,7 +137,7 @@ type conn struct {

// If true, this connection is bad and all public-facing functions should
// return ErrBadConn.
bad bool
bad *atomic.Value

// If set, this connection should never use the binary format when
// receiving query results from prepared statements. Only provided for
Expand Down Expand Up @@ -294,9 +295,12 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) {

o := c.opts

bad := &atomic.Value{}
bad.Store(false)
cn = &conn{
opts: o,
dialer: c.dialer,
bad: bad,
}
err = cn.handleDriverSettings(o)
if err != nil {
Expand Down Expand Up @@ -501,9 +505,22 @@ func (cn *conn) isInTransaction() bool {
cn.txnStatus == txnStatusInFailedTransaction
}

func (cn *conn) setBad() {
if cn.bad != nil {
cn.bad.Store(true)
}
}

func (cn *conn) getBad() bool {
if cn.bad != nil {
return cn.bad.Load().(bool)
}
return false
}

func (cn *conn) checkIsInTransaction(intxn bool) {
if cn.isInTransaction() != intxn {
cn.bad = true
cn.setBad()
errorf("unexpected transaction status %v", cn.txnStatus)
}
}
Expand All @@ -513,7 +530,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) {
}

func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand All @@ -524,11 +541,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
return nil, err
}
if commandTag != "BEGIN" {
cn.bad = true
cn.setBad()
return nil, fmt.Errorf("unexpected command tag %s", commandTag)
}
if cn.txnStatus != txnStatusIdleInTransaction {
cn.bad = true
cn.setBad()
return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
}
return cn, nil
Expand All @@ -542,7 +559,7 @@ func (cn *conn) closeTxn() {

func (cn *conn) Commit() (err error) {
defer cn.closeTxn()
if cn.bad {
if cn.getBad() {
return driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand All @@ -564,12 +581,12 @@ func (cn *conn) Commit() (err error) {
_, commandTag, err := cn.simpleExec("COMMIT")
if err != nil {
if cn.isInTransaction() {
cn.bad = true
cn.setBad()
}
return err
}
if commandTag != "COMMIT" {
cn.bad = true
cn.setBad()
return fmt.Errorf("unexpected command tag %s", commandTag)
}
cn.checkIsInTransaction(false)
Expand All @@ -578,7 +595,7 @@ func (cn *conn) Commit() (err error) {

func (cn *conn) Rollback() (err error) {
defer cn.closeTxn()
if cn.bad {
if cn.getBad() {
return driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand All @@ -590,7 +607,7 @@ func (cn *conn) rollback() (err error) {
_, commandTag, err := cn.simpleExec("ROLLBACK")
if err != nil {
if cn.isInTransaction() {
cn.bad = true
cn.setBad()
}
return err
}
Expand Down Expand Up @@ -630,7 +647,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
case 'T', 'D':
// ignore any results
default:
cn.bad = true
cn.setBad()
errorf("unknown response for simple query: %q", t)
}
}
Expand All @@ -652,7 +669,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// the user can close, though, to avoid connections from being
// leaked. A "rows" with done=true works fine for that purpose.
if err != nil {
cn.bad = true
cn.setBad()
errorf("unexpected message %q in simple query execution", t)
}
if res == nil {
Expand All @@ -676,7 +693,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
err = parseError(r)
case 'D':
if res == nil {
cn.bad = true
cn.setBad()
errorf("unexpected DataRow in simple query execution")
}
// the query didn't fail; kick off to Next
Expand All @@ -691,7 +708,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
default:
cn.bad = true
cn.setBad()
errorf("unknown response for simple query: %q", t)
}
}
Expand Down Expand Up @@ -784,7 +801,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
}

func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand Down Expand Up @@ -823,7 +840,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
}

func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
if cn.inCopy {
Expand Down Expand Up @@ -857,7 +874,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {

// Implement the optional "Execer" interface for one-shot queries
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
if cn.bad {
if cn.getBad() {
return nil, driver.ErrBadConn
}
defer cn.errRecover(&err)
Expand Down Expand Up @@ -918,7 +935,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) {
// the message yourself.
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
if cn.saveMessageType != 0 {
cn.bad = true
cn.setBad()
errorf("unexpected saveMessageType %d", cn.saveMessageType)
}
cn.saveMessageType = typ
Expand Down Expand Up @@ -1288,7 +1305,7 @@ func (st *stmt) Close() (err error) {
if st.closed {
return nil
}
if st.cn.bad {
if st.cn.getBad() {
return driver.ErrBadConn
}
defer st.cn.errRecover(&err)
Expand All @@ -1302,14 +1319,14 @@ func (st *stmt) Close() (err error) {

t, _ := st.cn.recv1()
if t != '3' {
st.cn.bad = true
st.cn.setBad()
errorf("unexpected close response: %q", t)
}
st.closed = true

t, r := st.cn.recv1()
if t != 'Z' {
st.cn.bad = true
st.cn.setBad()
errorf("expected ready for query, but got: %q", t)
}
st.cn.processReadyForQuery(r)
Expand All @@ -1318,7 +1335,7 @@ func (st *stmt) Close() (err error) {
}

func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
if st.cn.bad {
if st.cn.getBad() {
return nil, driver.ErrBadConn
}
defer st.cn.errRecover(&err)
Expand All @@ -1331,7 +1348,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
}

func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
if st.cn.bad {
if st.cn.getBad() {
return nil, driver.ErrBadConn
}
defer st.cn.errRecover(&err)
Expand Down Expand Up @@ -1418,7 +1435,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
parts := strings.Split(commandTag, " ")
if len(parts) != 3 {
cn.bad = true
cn.setBad()
errorf("unexpected INSERT command tag %s", commandTag)
}
affectedRows = &parts[len(parts)-1]
Expand All @@ -1430,7 +1447,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
}
n, err := strconv.ParseInt(*affectedRows, 10, 64)
if err != nil {
cn.bad = true
cn.setBad()
errorf("could not parse commandTag: %s", err)
}
return driver.RowsAffected(n), commandTag
Expand Down Expand Up @@ -1497,7 +1514,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}

conn := rs.cn
if conn.bad {
if conn.getBad() {
return driver.ErrBadConn
}
defer conn.errRecover(&err)
Expand All @@ -1522,7 +1539,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
case 'D':
n := rs.rb.int16()
if err != nil {
conn.bad = true
conn.setBad()
errorf("unexpected DataRow after error %s", err)
}
if n < len(dest) {
Expand Down Expand Up @@ -1717,7 +1734,7 @@ func (cn *conn) readReadyForQuery() {
cn.processReadyForQuery(r)
return
default:
cn.bad = true
cn.setBad()
errorf("unexpected message %q; expected ReadyForQuery", t)
}
}
Expand All @@ -1737,7 +1754,7 @@ func (cn *conn) readParseResponse() {
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Parse response %q", t)
}
}
Expand All @@ -1762,7 +1779,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Describe statement response %q", t)
}
}
Expand All @@ -1780,7 +1797,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader {
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Describe response %q", t)
}
panic("not reached")
Expand All @@ -1796,7 +1813,7 @@ func (cn *conn) readBindResponse() {
cn.readReadyForQuery()
panic(err)
default:
cn.bad = true
cn.setBad()
errorf("unexpected Bind response %q", t)
}
}
Expand All @@ -1823,7 +1840,7 @@ func (cn *conn) postExecuteWorkaround() {
cn.saveMessage(t, r)
return
default:
cn.bad = true
cn.setBad()
errorf("unexpected message during extended query execution: %q", t)
}
}
Expand All @@ -1836,7 +1853,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
switch t {
case 'C':
if err != nil {
cn.bad = true
cn.setBad()
errorf("unexpected CommandComplete after error %s", err)
}
res, commandTag = cn.parseComplete(r.string())
Expand All @@ -1850,15 +1867,15 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
err = parseError(r)
case 'T', 'D', 'I':
if err != nil {
cn.bad = true
cn.setBad()
errorf("unexpected %q after error %s", t, err)
}
if t == 'I' {
res = emptyRows
}
// ignore any results
default:
cn.bad = true
cn.setBad()
errorf("unknown %s response: %q", protocolState, t)
}
}
Expand Down
Loading