diff --git a/project/ScalaRedisProject.scala b/project/ScalaRedisProject.scala index e69c13f..0543374 100644 --- a/project/ScalaRedisProject.scala +++ b/project/ScalaRedisProject.scala @@ -24,7 +24,8 @@ object ScalaRedisProject extends Build "org.slf4j" % "slf4j-api" % "1.7.6" % "provided", "ch.qos.logback" % "logback-classic" % "1.1.1" % "provided", "junit" % "junit" % "4.11" % "test", - "org.scalatest" %% "scalatest" % "2.1.0" % "test", + "org.scalatest" %% "scalatest" % "2.1.0" % "test", + "com.typesafe.akka" %% "akka-testkit" % "2.3.0" % "test", // Third-party serialization libraries "net.liftweb" %% "lift-json" % "2.5.1" % "provided, test", diff --git a/src/main/scala/com/redis/RedisClientSettings.scala b/src/main/scala/com/redis/RedisClientSettings.scala index 782f1ec..2f39f55 100644 --- a/src/main/scala/com/redis/RedisClientSettings.scala +++ b/src/main/scala/com/redis/RedisClientSettings.scala @@ -1,11 +1,12 @@ package com.redis -import RedisClientSettings._ +import java.lang.{Long => JLong} +import RedisClientSettings._ case class RedisClientSettings( backpressureBufferSettings: Option[BackpressureBufferSettings] = None, - reconnectionSettings: Option[ReconnectionSettings] = None + reconnectionSettings: ReconnectionSettings = NoReconnectionSettings ) object RedisClientSettings { @@ -24,32 +25,62 @@ object RedisClientSettings { def newSchedule: ReconnectionSchedule trait ReconnectionSchedule { + val maxAttempts: Long + var attempts: Long = 0 + + /** + * Gets the number of milliseconds until the next reconnection attempt. + * + * This method is expected to increment attempts like an iterator + * + * @return milliseconds until the next attempt + */ def nextDelayMs: Long } } - case class ConstantReconnectionSettings(constantDelayMs: Long) extends ReconnectionSettings { + case object NoReconnectionSettings extends ReconnectionSettings{ + def newSchedule: ReconnectionSchedule = new ReconnectionSchedule { + val maxAttempts: Long = 0 + def nextDelayMs: Long = throw new NoSuchElementException("No delay available") + } + } + + case class ConstantReconnectionSettings(constantDelayMs: Long, maximumAttempts: Long = Long.MaxValue) extends ReconnectionSettings { require(constantDelayMs >= 0, s"Invalid negative reconnection delay (received $constantDelayMs)") + require(maximumAttempts >= 0, s"Invalid negative maximum attempts (received $maximumAttempts)") def newSchedule: ReconnectionSchedule = new ConstantSchedule class ConstantSchedule extends ReconnectionSchedule { - def nextDelayMs = constantDelayMs + val maxAttempts = maximumAttempts + def nextDelayMs = { + attempts += 1 + constantDelayMs + } } } - case class ExponentialReconnectionPolicy(baseDelayMs: Long, maxDelayMs: Long) extends ReconnectionSettings { + case class ExponentialReconnectionSettings(baseDelayMs: Long, maxDelayMs: Long, maximumAttempts: Long = Long.MaxValue) extends ReconnectionSettings { require(baseDelayMs > 0, s"Base reconnection delay must be greater than 0. Received $baseDelayMs") require(maxDelayMs > 0, s"Maximum reconnection delay must be greater than 0. Received $maxDelayMs") require(maxDelayMs >= baseDelayMs, "Maximum reconnection delay cannot be smaller than base reconnection delay") def newSchedule = new ExponentialSchedule + private val ceil = if ((baseDelayMs & (baseDelayMs - 1)) == 0) 0 else 1 + private val attemptCeiling = JLong.SIZE - JLong.numberOfLeadingZeros(Long.MaxValue / baseDelayMs) - ceil + class ExponentialSchedule extends ReconnectionSchedule { - var attempts = 0 + val maxAttempts = maximumAttempts def nextDelayMs = { attempts += 1 - Math.min(baseDelayMs * (1L << attempts), maxDelayMs) + if (attempts > attemptCeiling) { + maxDelayMs + } else { + val factor = 1L << (attempts - 1) + Math.min(baseDelayMs * factor, maxDelayMs) + } } } } diff --git a/src/main/scala/com/redis/RedisConnection.scala b/src/main/scala/com/redis/RedisConnection.scala index 783f83b..847236f 100644 --- a/src/main/scala/com/redis/RedisConnection.scala +++ b/src/main/scala/com/redis/RedisConnection.scala @@ -25,7 +25,7 @@ private [redis] class RedisConnection(remote: InetSocketAddress, settings: Redis private[this] var pendingRequests = Queue.empty[RedisRequest] private[this] var txnRequests = Queue.empty[RedisRequest] - private[this] var reconnectionSchedule: Option[_ <: ReconnectionSettings#ReconnectionSchedule] = None + private[this] lazy val reconnectionSchedule = settings.reconnectionSettings.newSchedule IO(Tcp) ! Connect(remote) @@ -46,18 +46,8 @@ private [redis] class RedisConnection(remote: InetSocketAddress, settings: Redis context watch pipe case CommandFailed(c: Connect) => - settings.reconnectionSettings match { - case Some(r) => - if (reconnectionSchedule.isEmpty) { - reconnectionSchedule = Some(settings.reconnectionSettings.get.newSchedule) - } - val delayMs = reconnectionSchedule.get.nextDelayMs - log.error("Connect failed for {} with {}. Reconnecting in {} ms... ", c.remoteAddress, c.failureMessage, delayMs) - context.system.scheduler.scheduleOnce(Duration(delayMs, TimeUnit.MILLISECONDS), IO(Tcp), Connect(remote))(context.dispatcher, self) - case None => - log.error("Connect failed for {} with {}. Stopping... ", c.remoteAddress, c.failureMessage) - context stop self - } + log.error("Connect failed for {} with {}. Stopping... ", c.remoteAddress, c.failureMessage) + context stop self } def transactional(pipe: ActorRef): Receive = withTerminationManagement { @@ -123,18 +113,14 @@ private [redis] class RedisConnection(remote: InetSocketAddress, settings: Redis def withTerminationManagement(handler: Receive): Receive = handler orElse { case Terminated(x) => { - settings.reconnectionSettings match { - case Some(r) => - if (reconnectionSchedule.isEmpty) { - reconnectionSchedule = Some(settings.reconnectionSettings.get.newSchedule) - } - val delayMs = reconnectionSchedule.get.nextDelayMs - log.error("Child termination detected: {}. Reconnecting in {} ms... ", x, delayMs) - context become unconnected - context.system.scheduler.scheduleOnce(Duration(delayMs, TimeUnit.MILLISECONDS), IO(Tcp), Connect(remote))(context.dispatcher, self) - case None => - log.error("Child termination detected: {}", x) - context stop self + if (reconnectionSchedule.attempts < reconnectionSchedule.maxAttempts) { + val delayMs = reconnectionSchedule.nextDelayMs + log.error("Child termination detected: {}. Reconnecting in {} ms... ", x, delayMs) + context become unconnected + context.system.scheduler.scheduleOnce(Duration(delayMs, TimeUnit.MILLISECONDS), IO(Tcp), Connect(remote))(context.dispatcher, self) + } else { + log.error("Child termination detected: {}", x) + context stop self } } } diff --git a/src/main/scala/com/redis/protocol/ServerCommands.scala b/src/main/scala/com/redis/protocol/ServerCommands.scala index 093d865..0c0158e 100644 --- a/src/main/scala/com/redis/protocol/ServerCommands.scala +++ b/src/main/scala/com/redis/protocol/ServerCommands.scala @@ -49,7 +49,7 @@ object ServerCommands { } case class SetName(name: String) extends RedisCommand[Boolean]("CLIENT") { - def params = "SETNAME" +: ANil + def params = "SETNAME" +: name +: ANil } case class Kill(ipPort: String) extends RedisCommand[Boolean]("CLIENT") { diff --git a/src/test/scala/com/redis/ClientSpec.scala b/src/test/scala/com/redis/ClientSpec.scala index 76c7819..b5e353a 100644 --- a/src/test/scala/com/redis/ClientSpec.scala +++ b/src/test/scala/com/redis/ClientSpec.scala @@ -2,12 +2,16 @@ package com.redis import scala.concurrent.Future +import akka.testkit.TestProbe +import org.junit.runner.RunWith import org.scalatest.exceptions.TestFailedException import org.scalatest.junit.JUnitRunner -import org.junit.runner.RunWith - import serialization._ - +import akka.io.Tcp.{Connected, CommandFailed} +import scala.reflect.ClassTag +import scala.concurrent.duration._ +import com.redis.RedisClientSettings.ConstantReconnectionSettings +import com.redis.protocol.ServerCommands.Client.Kill @RunWith(classOf[JUnitRunner]) class ClientSpec extends RedisSpecBase { @@ -75,21 +79,51 @@ class ClientSpec extends RedisSpecBase { } describe("reconnections based on policy") { - it("should reconnect") { - val key = "reconnect_test" - client.lpush(key, 0) + def killClientsNamed(rc: RedisClient, name: String): Future[List[Boolean]] = { + // Clients are a list of lines similar to + // addr=127.0.0.1:65227 fd=9 name= age=0 idle=0 flags=N db=0 sub=0 psub=0 multi=-1 qbuf=0 qbuf-free=32768 obl=0 oll=0 omem=0 events=r cmd=client + // We'll split them up and make a map + val clients = rc.client.list().futureValue.get.toString + .split('\n') + .map(_.trim) + .filterNot(_.isEmpty) + .map( + _.split(" ").map( + _.split("=").padTo(2, "") + ).map( + item => (item(0), item(1)) + ) + ).map(_.toMap) + Future.sequence(clients.filter(_("name") == name).map(_("addr")).map(rc.client.kill).toList) + } + + it("should not reconnect by default") { + val name = "test-client-1" + client.client.setname(name).futureValue should equal (true) - // Extract our address - // TODO Cleaner address extraction, perhaps in ServerOperations.client? - val address = client.client.list().futureValue.get.toString.split(" ").head.split("=").last - client.client.kill(address).futureValue should be (true) + val probe = TestProbe() + probe watch client.clientRef + killClientsNamed(client, name).futureValue.reduce(_ && _) should equal (true) + probe.expectTerminated(client.clientRef) + } - client.lpush(key, 1 to 100).futureValue should equal (101) - val list = client.lrange[Long](key, 0, -1).futureValue + it("should reconnect with settings") { + withReconnectingClient { client => + val name = "test-client-2" + client.client.setname(name).futureValue should equal (true) - list.size should equal (101) - list.reverse should equal (0 to 100) + val key = "reconnect_test" + client.lpush(key, 0) + + killClientsNamed(client, name).futureValue.reduce(_ && _) should equal (true) + + client.lpush(key, 1 to 100).futureValue should equal(101) + val list = client.lrange[Long](key, 0, -1).futureValue + + list.size should equal(101) + list.reverse should equal(0 to 100) + } } } } diff --git a/src/test/scala/com/redis/RedisSpecBase.scala b/src/test/scala/com/redis/RedisSpecBase.scala index 6271b18..7f38c97 100644 --- a/src/test/scala/com/redis/RedisSpecBase.scala +++ b/src/test/scala/com/redis/RedisSpecBase.scala @@ -2,23 +2,23 @@ package com.redis import scala.concurrent.duration._ -import akka.util.Timeout import akka.actor._ +import akka.testkit.TestKit +import akka.util.Timeout import com.redis.RedisClientSettings.ConstantReconnectionSettings import org.scalatest._ import org.scalatest.concurrent.{Futures, ScalaFutures} import org.scalatest.time._ -trait RedisSpecBase extends FunSpec +class RedisSpecBase(_system: ActorSystem) extends TestKit(_system) + with FunSpecLike with Matchers with Futures with ScalaFutures with BeforeAndAfterEach with BeforeAndAfterAll { - import RedisSpecBase._ - // Akka setup - implicit val system = ActorSystem("redis-test-"+ iter.next) + def this() = this(ActorSystem("redis-test-"+ RedisSpecBase.iter.next)) implicit val executionContext = system.dispatcher implicit val timeout = Timeout(2 seconds) @@ -26,7 +26,13 @@ trait RedisSpecBase extends FunSpec implicit val defaultPatience = PatienceConfig(timeout = Span(5, Seconds), interval = Span(5, Millis)) // Redis client setup - val client = RedisClient("localhost", 6379, settings = RedisClientSettings(reconnectionSettings = Some(ConstantReconnectionSettings(1000)))) + val client = RedisClient("localhost", 6379) + + def withReconnectingClient(testCode: RedisClient => Any) = { + val client = RedisClient("localhost", 6379, settings = RedisClientSettings(reconnectionSettings = ConstantReconnectionSettings(100))) + testCode(client) + client.quit().futureValue should equal (true) + } override def beforeEach = { client.flushdb()