@@ -18,6 +18,7 @@ import (
1818 "path/filepath"
1919 "strconv"
2020 "strings"
21+ "sync/atomic"
2122 "time"
2223 "unicode"
2324
@@ -141,7 +142,7 @@ type conn struct {
141142
142143 // If true, this connection is bad and all public-facing functions should
143144 // return ErrBadConn.
144- bad bool
145+ bad * atomic. Value
145146
146147 // If set, this connection should never use the binary format when
147148 // receiving query results from prepared statements. Only provided for
@@ -299,9 +300,12 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
299300
300301 o := c .opts
301302
303+ bad := & atomic.Value {}
304+ bad .Store (false )
302305 cn = & conn {
303306 opts : o ,
304307 dialer : c .dialer ,
308+ bad : bad ,
305309 }
306310 err = cn .handleDriverSettings (o )
307311 if err != nil {
@@ -506,9 +510,22 @@ func (cn *conn) isInTransaction() bool {
506510 cn .txnStatus == txnStatusInFailedTransaction
507511}
508512
513+ func (cn * conn ) setBad () {
514+ if cn .bad != nil {
515+ cn .bad .Store (true )
516+ }
517+ }
518+
519+ func (cn * conn ) getBad () bool {
520+ if cn .bad != nil {
521+ return cn .bad .Load ().(bool )
522+ }
523+ return false
524+ }
525+
509526func (cn * conn ) checkIsInTransaction (intxn bool ) {
510527 if cn .isInTransaction () != intxn {
511- cn .bad = true
528+ cn .setBad ()
512529 errorf ("unexpected transaction status %v" , cn .txnStatus )
513530 }
514531}
@@ -518,7 +535,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) {
518535}
519536
520537func (cn * conn ) begin (mode string ) (_ driver.Tx , err error ) {
521- if cn .bad {
538+ if cn .getBad () {
522539 return nil , driver .ErrBadConn
523540 }
524541 defer cn .errRecover (& err )
@@ -529,11 +546,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
529546 return nil , err
530547 }
531548 if commandTag != "BEGIN" {
532- cn .bad = true
549+ cn .setBad ()
533550 return nil , fmt .Errorf ("unexpected command tag %s" , commandTag )
534551 }
535552 if cn .txnStatus != txnStatusIdleInTransaction {
536- cn .bad = true
553+ cn .setBad ()
537554 return nil , fmt .Errorf ("unexpected transaction status %v" , cn .txnStatus )
538555 }
539556 return cn , nil
@@ -547,7 +564,7 @@ func (cn *conn) closeTxn() {
547564
548565func (cn * conn ) Commit () (err error ) {
549566 defer cn .closeTxn ()
550- if cn .bad {
567+ if cn .getBad () {
551568 return driver .ErrBadConn
552569 }
553570 defer cn .errRecover (& err )
@@ -569,12 +586,12 @@ func (cn *conn) Commit() (err error) {
569586 _ , commandTag , err := cn .simpleExec ("COMMIT" )
570587 if err != nil {
571588 if cn .isInTransaction () {
572- cn .bad = true
589+ cn .setBad ()
573590 }
574591 return err
575592 }
576593 if commandTag != "COMMIT" {
577- cn .bad = true
594+ cn .setBad ()
578595 return fmt .Errorf ("unexpected command tag %s" , commandTag )
579596 }
580597 cn .checkIsInTransaction (false )
@@ -583,7 +600,7 @@ func (cn *conn) Commit() (err error) {
583600
584601func (cn * conn ) Rollback () (err error ) {
585602 defer cn .closeTxn ()
586- if cn .bad {
603+ if cn .getBad () {
587604 return driver .ErrBadConn
588605 }
589606 defer cn .errRecover (& err )
@@ -595,7 +612,7 @@ func (cn *conn) rollback() (err error) {
595612 _ , commandTag , err := cn .simpleExec ("ROLLBACK" )
596613 if err != nil {
597614 if cn .isInTransaction () {
598- cn .bad = true
615+ cn .setBad ()
599616 }
600617 return err
601618 }
@@ -635,7 +652,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err
635652 case 'T' , 'D' :
636653 // ignore any results
637654 default :
638- cn .bad = true
655+ cn .setBad ()
639656 errorf ("unknown response for simple query: %q" , t )
640657 }
641658 }
@@ -657,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
657674 // the user can close, though, to avoid connections from being
658675 // leaked. A "rows" with done=true works fine for that purpose.
659676 if err != nil {
660- cn .bad = true
677+ cn .setBad ()
661678 errorf ("unexpected message %q in simple query execution" , t )
662679 }
663680 if res == nil {
@@ -684,7 +701,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
684701 err = parseError (r )
685702 case 'D' :
686703 if res == nil {
687- cn .bad = true
704+ cn .setBad ()
688705 errorf ("unexpected DataRow in simple query execution" )
689706 }
690707 // the query didn't fail; kick off to Next
@@ -699,7 +716,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
699716 // To work around a bug in QueryRow in Go 1.2 and earlier, wait
700717 // until the first DataRow has been received.
701718 default :
702- cn .bad = true
719+ cn .setBad ()
703720 errorf ("unknown response for simple query: %q" , t )
704721 }
705722 }
@@ -792,7 +809,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
792809}
793810
794811func (cn * conn ) Prepare (q string ) (_ driver.Stmt , err error ) {
795- if cn .bad {
812+ if cn .getBad () {
796813 return nil , driver .ErrBadConn
797814 }
798815 defer cn .errRecover (& err )
@@ -831,7 +848,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
831848}
832849
833850func (cn * conn ) query (query string , args []driver.Value ) (_ * rows , err error ) {
834- if cn .bad {
851+ if cn .getBad () {
835852 return nil , driver .ErrBadConn
836853 }
837854 if cn .inCopy {
@@ -865,7 +882,7 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
865882
866883// Implement the optional "Execer" interface for one-shot queries
867884func (cn * conn ) Exec (query string , args []driver.Value ) (res driver.Result , err error ) {
868- if cn .bad {
885+ if cn .getBad () {
869886 return nil , driver .ErrBadConn
870887 }
871888 defer cn .errRecover (& err )
@@ -937,7 +954,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) {
937954// the message yourself.
938955func (cn * conn ) saveMessage (typ byte , buf * readBuf ) {
939956 if cn .saveMessageType != 0 {
940- cn .bad = true
957+ cn .setBad ()
941958 errorf ("unexpected saveMessageType %d" , cn .saveMessageType )
942959 }
943960 cn .saveMessageType = typ
@@ -1307,7 +1324,7 @@ func (st *stmt) Close() (err error) {
13071324 if st .closed {
13081325 return nil
13091326 }
1310- if st .cn .bad {
1327+ if st .cn .getBad () {
13111328 return driver .ErrBadConn
13121329 }
13131330 defer st .cn .errRecover (& err )
@@ -1321,14 +1338,14 @@ func (st *stmt) Close() (err error) {
13211338
13221339 t , _ := st .cn .recv1 ()
13231340 if t != '3' {
1324- st .cn .bad = true
1341+ st .cn .setBad ()
13251342 errorf ("unexpected close response: %q" , t )
13261343 }
13271344 st .closed = true
13281345
13291346 t , r := st .cn .recv1 ()
13301347 if t != 'Z' {
1331- st .cn .bad = true
1348+ st .cn .setBad ()
13321349 errorf ("expected ready for query, but got: %q" , t )
13331350 }
13341351 st .cn .processReadyForQuery (r )
@@ -1337,7 +1354,7 @@ func (st *stmt) Close() (err error) {
13371354}
13381355
13391356func (st * stmt ) Query (v []driver.Value ) (r driver.Rows , err error ) {
1340- if st .cn .bad {
1357+ if st .cn .getBad () {
13411358 return nil , driver .ErrBadConn
13421359 }
13431360 defer st .cn .errRecover (& err )
@@ -1350,7 +1367,7 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
13501367}
13511368
13521369func (st * stmt ) Exec (v []driver.Value ) (res driver.Result , err error ) {
1353- if st .cn .bad {
1370+ if st .cn .getBad () {
13541371 return nil , driver .ErrBadConn
13551372 }
13561373 defer st .cn .errRecover (& err )
@@ -1437,7 +1454,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
14371454 if affectedRows == nil && strings .HasPrefix (commandTag , "INSERT " ) {
14381455 parts := strings .Split (commandTag , " " )
14391456 if len (parts ) != 3 {
1440- cn .bad = true
1457+ cn .setBad ()
14411458 errorf ("unexpected INSERT command tag %s" , commandTag )
14421459 }
14431460 affectedRows = & parts [len (parts )- 1 ]
@@ -1449,7 +1466,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
14491466 }
14501467 n , err := strconv .ParseInt (* affectedRows , 10 , 64 )
14511468 if err != nil {
1452- cn .bad = true
1469+ cn .setBad ()
14531470 errorf ("could not parse commandTag: %s" , err )
14541471 }
14551472 return driver .RowsAffected (n ), commandTag
@@ -1516,7 +1533,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
15161533 }
15171534
15181535 conn := rs .cn
1519- if conn .bad {
1536+ if conn .getBad () {
15201537 return driver .ErrBadConn
15211538 }
15221539 defer conn .errRecover (& err )
@@ -1541,7 +1558,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
15411558 case 'D' :
15421559 n := rs .rb .int16 ()
15431560 if err != nil {
1544- conn .bad = true
1561+ conn .setBad ()
15451562 errorf ("unexpected DataRow after error %s" , err )
15461563 }
15471564 if n < len (dest ) {
@@ -1736,7 +1753,7 @@ func (cn *conn) readReadyForQuery() {
17361753 cn .processReadyForQuery (r )
17371754 return
17381755 default :
1739- cn .bad = true
1756+ cn .setBad ()
17401757 errorf ("unexpected message %q; expected ReadyForQuery" , t )
17411758 }
17421759}
@@ -1756,7 +1773,7 @@ func (cn *conn) readParseResponse() {
17561773 cn .readReadyForQuery ()
17571774 panic (err )
17581775 default :
1759- cn .bad = true
1776+ cn .setBad ()
17601777 errorf ("unexpected Parse response %q" , t )
17611778 }
17621779}
@@ -1781,7 +1798,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
17811798 cn .readReadyForQuery ()
17821799 panic (err )
17831800 default :
1784- cn .bad = true
1801+ cn .setBad ()
17851802 errorf ("unexpected Describe statement response %q" , t )
17861803 }
17871804 }
@@ -1799,7 +1816,7 @@ func (cn *conn) readPortalDescribeResponse() rowsHeader {
17991816 cn .readReadyForQuery ()
18001817 panic (err )
18011818 default :
1802- cn .bad = true
1819+ cn .setBad ()
18031820 errorf ("unexpected Describe response %q" , t )
18041821 }
18051822 panic ("not reached" )
@@ -1815,7 +1832,7 @@ func (cn *conn) readBindResponse() {
18151832 cn .readReadyForQuery ()
18161833 panic (err )
18171834 default :
1818- cn .bad = true
1835+ cn .setBad ()
18191836 errorf ("unexpected Bind response %q" , t )
18201837 }
18211838}
@@ -1842,7 +1859,7 @@ func (cn *conn) postExecuteWorkaround() {
18421859 cn .saveMessage (t , r )
18431860 return
18441861 default :
1845- cn .bad = true
1862+ cn .setBad ()
18461863 errorf ("unexpected message during extended query execution: %q" , t )
18471864 }
18481865 }
@@ -1855,7 +1872,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
18551872 switch t {
18561873 case 'C' :
18571874 if err != nil {
1858- cn .bad = true
1875+ cn .setBad ()
18591876 errorf ("unexpected CommandComplete after error %s" , err )
18601877 }
18611878 res , commandTag = cn .parseComplete (r .string ())
@@ -1869,15 +1886,15 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co
18691886 err = parseError (r )
18701887 case 'T' , 'D' , 'I' :
18711888 if err != nil {
1872- cn .bad = true
1889+ cn .setBad ()
18731890 errorf ("unexpected %q after error %s" , t , err )
18741891 }
18751892 if t == 'I' {
18761893 res = emptyRows
18771894 }
18781895 // ignore any results
18791896 default :
1880- cn .bad = true
1897+ cn .setBad ()
18811898 errorf ("unknown %s response: %q" , protocolState , t )
18821899 }
18831900 }
0 commit comments