Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion error.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ func isRedisError(err error) bool {
}

func isBadConn(err error, allowTimeout bool) bool {
if err == nil {
switch err {
case nil:
return false
case context.Canceled, context.DeadlineExceeded:
return true
}

if isRedisError(err) {
Expand Down
2 changes: 1 addition & 1 deletion main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ func redisRingOptions() *redis.RingOptions {
func performAsync(n int, cbs ...func(int)) *sync.WaitGroup {
var wg sync.WaitGroup
for _, cb := range cbs {
wg.Add(n)
for i := 0; i < n; i++ {
wg.Add(1)
go func(cb func(int), i int) {
defer GinkgoRecover()
defer wg.Done()
Expand Down
20 changes: 20 additions & 0 deletions race_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package redis_test

import (
"bytes"
"context"
"fmt"
"net"
"strconv"
Expand Down Expand Up @@ -295,6 +296,25 @@ var _ = Describe("races", func() {
Expect(err).NotTo(HaveOccurred())
})
})

It("should abort on context timeout", func() {
opt := redisClusterOptions()
client := cluster.newClusterClient(ctx, opt)

ctx, cancel := context.WithCancel(context.Background())

wg := performAsync(C, func(_ int) {
_, err := client.XRead(ctx, &redis.XReadArgs{
Streams: []string{"test", "$"},
Block: 1 * time.Second,
}).Result()
Expect(err).To(Equal(context.Canceled))
})

time.Sleep(10 * time.Millisecond)
cancel()
wg.Wait()
})
})

var _ = Describe("cluster races", func() {
Expand Down
44 changes: 25 additions & 19 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"

"github.com/go-redis/redis/v8/internal"
Expand Down Expand Up @@ -130,20 +131,7 @@ func (hs hooks) processTxPipeline(
}

func (hs hooks) withContext(ctx context.Context, fn func() error) error {
done := ctx.Done()
if done == nil {
return fn()
}

errc := make(chan error, 1)
go func() { errc <- fn() }()

select {
case <-done:
return ctx.Err()
case err := <-errc:
return err
}
return fn()
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -316,8 +304,24 @@ func (c *baseClient) withConn(
c.releaseConn(ctx, cn, err)
}()

err = fn(ctx, cn)
return err
done := ctx.Done()
if done == nil {
err = fn(ctx, cn)
return err
}

errc := make(chan error, 1)
go func() { errc <- fn(ctx, cn) }()

select {
case <-done:
_ = cn.Close()

err = ctx.Err()
return err
case err = <-errc:
return err
}
})
}

Expand All @@ -334,7 +338,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
}
}

retryTimeout := true
retryTimeout := uint32(1)
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
return writeCmd(wr, cmd)
Expand All @@ -345,7 +349,9 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {

err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply)
if err != nil {
retryTimeout = cmd.readTimeout() == nil
if cmd.readTimeout() == nil {
atomic.StoreUint32(&retryTimeout, 1)
}
return err
}

Expand All @@ -354,7 +360,7 @@ func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
if err == nil {
return nil
}
retry = shouldRetry(err, retryTimeout)
retry = shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1)
return err
})
if err == nil || !retry {
Expand Down