Skip to content

Commit aecc811

Browse files
authored
Merge pull request #1000 from quillchat/master
Fix query cancellation collateralizing future queries using the same connection
2 parents f1ff737 + f0b8e15 commit aecc811

File tree

6 files changed

+108
-56
lines changed

6 files changed

+108
-56
lines changed

conn.go

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"path/filepath"
1919
"strconv"
2020
"strings"
21+
"sync/atomic"
2122
"time"
2223
"unicode"
2324

@@ -141,7 +142,7 @@ type conn struct {
141142

142143
// If true, this connection is bad and all public-facing functions should
143144
// return ErrBadConn.
144-
bad bool
145+
bad *atomic.Value
145146

146147
// If set, this connection should never use the binary format when
147148
// receiving query results from prepared statements. Only provided for
@@ -299,9 +300,12 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
299300

300301
o := c.opts
301302

303+
bad := &atomic.Value{}
304+
bad.Store(false)
302305
cn = &conn{
303306
opts: o,
304307
dialer: c.dialer,
308+
bad: bad,
305309
}
306310
err = cn.handleDriverSettings(o)
307311
if err != nil {
@@ -506,9 +510,22 @@ func (cn *conn) isInTransaction() bool {
506510
cn.txnStatus == txnStatusInFailedTransaction
507511
}
508512

513+
func (cn *conn) setBad() {
514+
if cn.bad != nil {
515+
cn.bad.Store(true)
516+
}
517+
}
518+
519+
func (cn *conn) getBad() bool {
520+
if cn.bad != nil {
521+
return cn.bad.Load().(bool)
522+
}
523+
return false
524+
}
525+
509526
func (cn *conn) checkIsInTransaction(intxn bool) {
510527
if cn.isInTransaction() != intxn {
511-
cn.bad = true
528+
cn.setBad()
512529
errorf("unexpected transaction status %v", cn.txnStatus)
513530
}
514531
}
@@ -518,7 +535,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) {
518535
}
519536

520537
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
521-
if cn.bad {
538+
if cn.getBad() {
522539
return nil, driver.ErrBadConn
523540
}
524541
defer cn.errRecover(&err)
@@ -529,11 +546,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
529546
return nil, err
530547
}
531548
if commandTag != "BEGIN" {
532-
cn.bad = true
549+
cn.setBad()
533550
return nil, fmt.Errorf("unexpected command tag %s", commandTag)
534551
}
535552
if cn.txnStatus != txnStatusIdleInTransaction {
536-
cn.bad = true
553+
cn.setBad()
537554
return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
538555
}
539556
return cn, nil
@@ -547,7 +564,7 @@ func (cn *conn) closeTxn() {
547564

548565
func (cn *conn) Commit() (err error) {
549566
defer cn.closeTxn()
550-
if cn.bad {
567+
if cn.getBad() {
551568
return driver.ErrBadConn
552569
}
553570
defer cn.errRecover(&err)
@@ -569,12 +586,12 @@ func (cn *conn) Commit() (err error) {
569586
_, commandTag, err := cn.simpleExec("COMMIT")
570587
if err != nil {
571588
if cn.isInTransaction() {
572-
cn.bad = true
589+
cn.setBad()
573590
}
574591
return err
575592
}
576593
if commandTag != "COMMIT" {
577-
cn.bad = true
594+
cn.setBad()
578595
return fmt.Errorf("unexpected command tag %s", commandTag)
579596
}
580597
cn.checkIsInTransaction(false)
@@ -583,7 +600,7 @@ func (cn *conn) Commit() (err error) {
583600

584601
func (cn *conn) Rollback() (err error) {
585602
defer cn.closeTxn()
586-
if cn.bad {
603+
if cn.getBad() {
587604
return driver.ErrBadConn
588605
}
589606
defer cn.errRecover(&err)
@@ -595,7 +612,7 @@ func (cn *conn) rollback() (err error) {
595612
_, commandTag, err := cn.simpleExec("ROLLBACK")
596613
if err != nil {
597614
if cn.isInTransaction() {
598-
cn.bad = true
615+
cn.setBad()
599616
}
600617
return err
601618
}
@@ -635,7 +652,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
635652
case 'T', 'D':
636653
// ignore any results
637654
default:
638-
cn.bad = true
655+
cn.setBad()
639656
errorf("unknown response for simple query: %q", t)
640657
}
641658
}
@@ -657,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
657674
// the user can close, though, to avoid connections from being
658675
// leaked. A "rows" with done=true works fine for that purpose.
659676
if err != nil {
660-
cn.bad = true
677+
cn.setBad()
661678
errorf("unexpected message %q in simple query execution", t)
662679
}
663680
if res == nil {
@@ -684,7 +701,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
684701
err = parseError(r)
685702
case 'D':
686703
if res == nil {
687-
cn.bad = true
704+
cn.setBad()
688705
errorf("unexpected DataRow in simple query execution")
689706
}
690707
// the query didn't fail; kick off to Next
@@ -699,7 +716,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
699716
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
700717
// until the first DataRow has been received.
701718
default:
702-
cn.bad = true
719+
cn.setBad()
703720
errorf("unknown response for simple query: %q", t)
704721
}
705722
}
@@ -792,7 +809,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
792809
}
793810

794811
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
795-
if cn.bad {
812+
if cn.getBad() {
796813
return nil, driver.ErrBadConn
797814
}
798815
defer cn.errRecover(&err)
@@ -831,7 +848,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
831848
}
832849

833850
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
834-
if cn.bad {
851+
if cn.getBad() {
835852
return nil, driver.ErrBadConn
836853
}
837854
if cn.inCopy {
@@ -865,7 +882,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
865882

866883
// Implement the optional "Execer" interface for one-shot queries
867884
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
868-
if cn.bad {
885+
if cn.getBad() {
869886
return nil, driver.ErrBadConn
870887
}
871888
defer cn.errRecover(&err)
@@ -937,7 +954,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) {
937954
// the message yourself.
938955
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
939956
if cn.saveMessageType != 0 {
940-
cn.bad = true
957+
cn.setBad()
941958
errorf("unexpected saveMessageType %d", cn.saveMessageType)
942959
}
943960
cn.saveMessageType = typ
@@ -1307,7 +1324,7 @@ func (st *stmt) Close() (err error) {
13071324
if st.closed {
13081325
return nil
13091326
}
1310-
if st.cn.bad {
1327+
if st.cn.getBad() {
13111328
return driver.ErrBadConn
13121329
}
13131330
defer st.cn.errRecover(&err)
@@ -1321,14 +1338,14 @@ func (st *stmt) Close() (err error) {
13211338

13221339
t, _ := st.cn.recv1()
13231340
if t != '3' {
1324-
st.cn.bad = true
1341+
st.cn.setBad()
13251342
errorf("unexpected close response: %q", t)
13261343
}
13271344
st.closed = true
13281345

13291346
t, r := st.cn.recv1()
13301347
if t != 'Z' {
1331-
st.cn.bad = true
1348+
st.cn.setBad()
13321349
errorf("expected ready for query, but got: %q", t)
13331350
}
13341351
st.cn.processReadyForQuery(r)
@@ -1337,7 +1354,7 @@ func (st *stmt) Close() (err error) {
13371354
}
13381355

13391356
func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
1340-
if st.cn.bad {
1357+
if st.cn.getBad() {
13411358
return nil, driver.ErrBadConn
13421359
}
13431360
defer st.cn.errRecover(&err)
@@ -1350,7 +1367,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
13501367
}
13511368

13521369
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
1353-
if st.cn.bad {
1370+
if st.cn.getBad() {
13541371
return nil, driver.ErrBadConn
13551372
}
13561373
defer st.cn.errRecover(&err)
@@ -1437,7 +1454,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
14371454
if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
14381455
parts := strings.Split(commandTag, " ")
14391456
if len(parts) != 3 {
1440-
cn.bad = true
1457+
cn.setBad()
14411458
errorf("unexpected INSERT command tag %s", commandTag)
14421459
}
14431460
affectedRows = &parts[len(parts)-1]
@@ -1449,7 +1466,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
14491466
}
14501467
n, err := strconv.ParseInt(*affectedRows, 10, 64)
14511468
if err != nil {
1452-
cn.bad = true
1469+
cn.setBad()
14531470
errorf("could not parse commandTag: %s", err)
14541471
}
14551472
return driver.RowsAffected(n), commandTag
@@ -1516,7 +1533,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
15161533
}
15171534

15181535
conn := rs.cn
1519-
if conn.bad {
1536+
if conn.getBad() {
15201537
return driver.ErrBadConn
15211538
}
15221539
defer conn.errRecover(&err)
@@ -1541,7 +1558,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
15411558
case 'D':
15421559
n := rs.rb.int16()
15431560
if err != nil {
1544-
conn.bad = true
1561+
conn.setBad()
15451562
errorf("unexpected DataRow after error %s", err)
15461563
}
15471564
if n < len(dest) {
@@ -1736,7 +1753,7 @@ func (cn *conn) readReadyForQuery() {
17361753
cn.processReadyForQuery(r)
17371754
return
17381755
default:
1739-
cn.bad = true
1756+
cn.setBad()
17401757
errorf("unexpected message %q; expected ReadyForQuery", t)
17411758
}
17421759
}
@@ -1756,7 +1773,7 @@ func (cn *conn) readParseResponse() {
17561773
cn.readReadyForQuery()
17571774
panic(err)
17581775
default:
1759-
cn.bad = true
1776+
cn.setBad()
17601777
errorf("unexpected Parse response %q", t)
17611778
}
17621779
}
@@ -1781,7 +1798,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
17811798
cn.readReadyForQuery()
17821799
panic(err)
17831800
default:
1784-
cn.bad = true
1801+
cn.setBad()
17851802
errorf("unexpected Describe statement response %q", t)
17861803
}
17871804
}
@@ -1799,7 +1816,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader {
17991816
cn.readReadyForQuery()
18001817
panic(err)
18011818
default:
1802-
cn.bad = true
1819+
cn.setBad()
18031820
errorf("unexpected Describe response %q", t)
18041821
}
18051822
panic("not reached")
@@ -1815,7 +1832,7 @@ func (cn *conn) readBindResponse() {
18151832
cn.readReadyForQuery()
18161833
panic(err)
18171834
default:
1818-
cn.bad = true
1835+
cn.setBad()
18191836
errorf("unexpected Bind response %q", t)
18201837
}
18211838
}
@@ -1842,7 +1859,7 @@ func (cn *conn) postExecuteWorkaround() {
18421859
cn.saveMessage(t, r)
18431860
return
18441861
default:
1845-
cn.bad = true
1862+
cn.setBad()
18461863
errorf("unexpected message during extended query execution: %q", t)
18471864
}
18481865
}
@@ -1855,7 +1872,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
18551872
switch t {
18561873
case 'C':
18571874
if err != nil {
1858-
cn.bad = true
1875+
cn.setBad()
18591876
errorf("unexpected CommandComplete after error %s", err)
18601877
}
18611878
res, commandTag = cn.parseComplete(r.string())
@@ -1869,15 +1886,15 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
18691886
err = parseError(r)
18701887
case 'T', 'D', 'I':
18711888
if err != nil {
1872-
cn.bad = true
1889+
cn.setBad()
18731890
errorf("unexpected %q after error %s", t, err)
18741891
}
18751892
if t == 'I' {
18761893
res = emptyRows
18771894
}
18781895
// ignore any results
18791896
default:
1880-
cn.bad = true
1897+
cn.setBad()
18811898
errorf("unknown %s response: %q", protocolState, t)
18821899
}
18831900
}

0 commit comments

Comments
 (0)