5 changes: 5 additions & 0 deletions core/pom.xml
@@ -372,6 +372,11 @@
<artifactId>junit-interface</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.curator</groupId>
<artifactId>curator-test</artifactId>
Contributor: what is this thing used for?

Member Author: org.apache.curator.test.TestingServer is from this artifact. An embedded ZooKeeper server for testing.
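For context, a minimal sketch of how the embedded server could back a persistence-engine test (illustrative only: it assumes the test lives in the org.apache.spark.deploy.master package, since ZooKeeperPersistenceEngine is private[master], and the wiring is not taken from this PR):

```scala
import org.apache.curator.test.TestingServer
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// Starts an in-process ZooKeeper on a free port.
val zkServer = new TestingServer()
try {
  val conf = new SparkConf()
    .set("spark.deploy.zookeeper.url", zkServer.getConnectString)
  // ZooKeeperPersistenceEngine is private[master], so this assumes test code in that package.
  val engine = new ZooKeeperPersistenceEngine(conf, new JavaSerializer(conf))
  // ... persist and read back ApplicationInfo / DriverInfo / WorkerInfo here ...
  engine.close()
} finally {
  zkServer.stop()  // shut down the embedded ZooKeeper
}
```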

Should this be in test scope?

Member Author: Yes. Good catch. I was thinking it but forgot to add it here.

<scope>test</scope>
</dependency>
<dependency>
<groupId>net.razorvine</groupId>
<artifactId>pyrolite</artifactId>
core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -21,9 +21,8 @@ import java.io._

import scala.reflect.ClassTag

import akka.serialization.Serialization

import org.apache.spark.Logging
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer}
import org.apache.spark.util.Utils


@@ -32,11 +31,11 @@ import org.apache.spark.util.Utils
* Files are deleted when applications and workers are removed.
*
* @param dir Directory to store files. Created if non-existent (but not recursively).
* @param serialization Used to serialize our objects.
* @param serializer Used to serialize our objects.
*/
private[master] class FileSystemPersistenceEngine(
val dir: String,
val serialization: Serialization)
val serializer: Serializer)
extends PersistenceEngine with Logging {

new File(dir).mkdir()
@@ -57,27 +56,31 @@ private[master] class FileSystemPersistenceEngine(
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
val out = new FileOutputStream(file)
val fileOut = new FileOutputStream(file)
var out: SerializationStream = null
Utils.tryWithSafeFinally {
out.write(serialized)
out = serializer.newInstance().serializeStream(fileOut)
out.writeObject(value)
} {
out.close()
fileOut.close()
if (out != null) {
out.close()
}
}
}

private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
val fileIn = new FileInputStream(file)
var in: DeserializationStream = null
try {
dis.readFully(fileData)
in = serializer.newInstance().deserializeStream(fileIn)
in.readObject[T]()
} finally {
dis.close()
fileIn.close()
if (in != null) {
in.close()
}
}
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}

}
18 changes: 7 additions & 11 deletions core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -27,11 +27,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.language.postfixOps
import scala.util.Random

import akka.serialization.Serialization
import akka.serialization.SerializationExtension
import org.apache.hadoop.fs.Path

import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.rpc._
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -44,6 +41,7 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.serializer.{JavaSerializer, Serializer}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils}

@@ -58,9 +56,6 @@ private[master] class Master(
private val forwardMessageThread =
ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread")

// TODO Remove it once we don't use akka.serialization.Serialization
private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem

private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)

private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
@@ -161,20 +156,21 @@
masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)

val serializer = new JavaSerializer(conf)
val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
val zkFactory =
new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem))
new ZooKeeperRecoveryModeFactory(conf, serializer)
(zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
val fsFactory =
new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem))
new FileSystemRecoveryModeFactory(conf, serializer)
(fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
case "CUSTOM" =>
val clazz = Utils.classForName(conf.get("spark.deploy.recoveryMode.factory"))
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
.newInstance(conf, SerializationExtension(actorSystem))
val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serializer])
.newInstance(conf, serializer)
.asInstanceOf[StandaloneRecoveryModeFactory]
(factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
@@ -213,7 +209,7 @@

override def receive: PartialFunction[Any, Unit] = {
case ElectedLeader => {
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData()
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
RecoveryState.ALIVE
} else {
core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.master

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rpc.RpcEnv

import scala.reflect.ClassTag

@@ -80,8 +81,11 @@ abstract class PersistenceEngine {
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
(read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
final def readPersistedData(
rpcEnv: RpcEnv): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
rpcEnv.deserialize { () =>
(read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
}
}

def close() {}
core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -17,10 +17,9 @@

package org.apache.spark.deploy.master

import akka.serialization.Serialization

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.serializer.Serializer

/**
* ::DeveloperApi::
@@ -30,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
*
*/
@DeveloperApi
abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serializer) {

/**
* PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
@@ -49,7 +48,7 @@ abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serial
* LeaderAgent in this case is a no-op. Since leader is forever leader as the actual
* recovery is made by restoring from filesystem.
*/
private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
private[master] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {

val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
@@ -64,7 +63,7 @@
}
}

private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
private[master] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serializer)
extends StandaloneRecoveryModeFactory(conf, serializer) {

def createPersistenceEngine(): PersistenceEngine = {
core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,7 +17,7 @@

package org.apache.spark.deploy.master

import akka.serialization.Serialization
import java.nio.ByteBuffer

import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -27,9 +27,10 @@ import org.apache.zookeeper.CreateMode

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.SparkCuratorUtil
import org.apache.spark.serializer.Serializer


private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer: Serializer)
extends PersistenceEngine
with Logging {

@@ -57,17 +58,16 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializat
}

private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
val serialized = serializer.newInstance().serialize(value)
val bytes = new Array[Byte](serialized.remaining())
serialized.get(bytes)
zk.create().withMode(CreateMode.PERSISTENT).forPath(path, bytes)
}

private def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
try {
Some(serializer.fromBinary(fileData).asInstanceOf[T])
Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
} catch {
case e: Exception => {
logWarning("Exception while reading persisted file, deleting", e)
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -139,6 +139,12 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
* creating it manually because different [[RpcEnv]] may have different formats.
*/
def uriOf(systemName: String, address: RpcAddress, endpointName: String): String

/**
* [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object
* that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method.
*/
def deserialize[T](deserializationAction: () => T): T
Member Author: Add this new method to RpcEnv for RpcEndpointRef deserialization.
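For illustration, a minimal sketch of the calling pattern this enables (rpcEnv, serializer, and bytes are assumed to be in scope; the real caller is PersistenceEngine.readPersistedData further down in this diff):

```scala
// Hypothetical caller: deserializing a WorkerInfo (which holds an RpcEndpointRef)
// is wrapped in rpcEnv.deserialize so the ref can be re-bound to this RpcEnv.
val restored: WorkerInfo = rpcEnv.deserialize { () =>
  serializer.newInstance().deserialize[WorkerInfo](java.nio.ByteBuffer.wrap(bytes))
}
```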

}


14 changes: 13 additions & 1 deletion core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -28,7 +28,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add
import akka.event.Logging.Error
import akka.pattern.{ask => akkaAsk}
import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
import com.google.common.util.concurrent.MoreExecutors
import akka.serialization.JavaSerializer

import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.rpc._
@@ -239,6 +239,12 @@ private[spark] class AkkaRpcEnv private[akka] (
}

override def toString: String = s"${getClass.getSimpleName}($actorSystem)"

override def deserialize[T](deserializationAction: () => T): T = {
JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) {
Contributor: can you explain why this is necessary?

Contributor: I see, because now we no longer pass akka's Serialization, which has information about the actor system, into PersistenceEngine, so here we ensure that we're using the actor system's serializer.

But more generally, since we always serialize with JavaSerializer in the new code, why can't we always deserialize with the same thing? I just find it a little strange that we have to pass a closure into this method.

Member Author: Spark's JavaSerializer is used to deserialize the objects. However, it does not have an actor system in the current context, so I need to use Akka's JavaSerializer.currentSystem to put the current actor system into a thread-local variable.

Contributor: but why do we need the actor system to deserialize it? Can't we just deserialize it with JavaSerializer? @rxin

Member Author: Oh, that's because WorkerInfo and ApplicationInfo contain a reference to RpcEndpointRef.
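To make the constraint concrete: the Akka-backed RpcEndpointRef wraps an ActorRef, and Akka can only rebuild an ActorRef during Java deserialization while JavaSerializer.currentSystem points at the owning actor system. A hedged sketch (helper name and parameters are hypothetical, not part of this PR):

```scala
import java.nio.ByteBuffer
import akka.actor.ExtendedActorSystem
import akka.serialization.JavaSerializer
import org.apache.spark.serializer.Serializer

// Hypothetical helper: without the withValue block, deserializing a WorkerInfo
// (which holds an RpcEndpointRef wrapping an ActorRef) cannot resolve the ActorRef.
def deserializeWithSystem[T: scala.reflect.ClassTag](
    bytes: Array[Byte],
    system: ExtendedActorSystem,
    serializer: Serializer): T = {
  JavaSerializer.currentSystem.withValue(system) {
    serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
  }
}
```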

deserializationAction()
}
}
}

private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
@@ -315,6 +321,12 @@ private[akka] class AkkaRpcEndpointRef(

override def toString: String = s"${getClass.getSimpleName}($actorRef)"

final override def equals(that: Any): Boolean = that match {
case other: AkkaRpcEndpointRef => actorRef == other.actorRef
case _ => false
}

final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode()
}

/**
core/src/test/scala/other/supplier/CustomRecoveryModeFactory.scala
@@ -19,18 +19,19 @@
// when they are outside of org.apache.spark.
package other.supplier

import java.nio.ByteBuffer

import scala.collection.mutable
import scala.reflect.ClassTag

import akka.serialization.Serialization

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master._
import org.apache.spark.serializer.Serializer

class CustomRecoveryModeFactory(
conf: SparkConf,
serialization: Serialization
) extends StandaloneRecoveryModeFactory(conf, serialization) {
serializer: Serializer
) extends StandaloneRecoveryModeFactory(conf, serializer) {

CustomRecoveryModeFactory.instantiationAttempts += 1

@@ -40,7 +41,7 @@ class CustomRecoveryModeFactory(
*
*/
override def createPersistenceEngine(): PersistenceEngine =
new CustomPersistenceEngine(serialization)
new CustomPersistenceEngine(serializer)

/**
* Create an instance of LeaderAgent that decides who gets elected as master.
Expand All @@ -53,7 +54,7 @@ object CustomRecoveryModeFactory {
@volatile var instantiationAttempts = 0
}

class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
class CustomPersistenceEngine(serializer: Serializer) extends PersistenceEngine {
val data = mutable.HashMap[String, Array[Byte]]()

CustomPersistenceEngine.lastInstance = Some(this)
@@ -64,10 +65,10 @@ class CustomPersistenceEngine(serialization: Serialization) extends PersistenceE
*/
override def persist(name: String, obj: Object): Unit = {
CustomPersistenceEngine.persistAttempts += 1
serialization.serialize(obj) match {
case util.Success(bytes) => data += name -> bytes
case util.Failure(cause) => throw new RuntimeException(cause)
}
val serialized = serializer.newInstance().serialize(obj)
val bytes = new Array[Byte](serialized.remaining())
serialized.get(bytes)
data += name -> bytes
}

/**
@@ -84,15 +85,9 @@
*/
override def read[T: ClassTag](prefix: String): Seq[T] = {
CustomPersistenceEngine.readAttempts += 1
val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
val results = for ((name, bytes) <- data; if name.startsWith(prefix))
yield serialization.deserialize(bytes, clazz)

results.find(_.isFailure).foreach {
case util.Failure(cause) => throw new RuntimeException(cause)
}

results.flatMap(_.toOption).toSeq
yield serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
results.toSeq
}
}

core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -105,7 +105,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually {
persistenceEngine.addDriver(driverToPersist)
persistenceEngine.addWorker(workerToPersist)

val (apps, drivers, workers) = persistenceEngine.readPersistedData()
val (apps, drivers, workers) = persistenceEngine.readPersistedData(rpcEnv)

apps.map(_.id) should contain(appToPersist.id)
drivers.map(_.id) should contain(driverToPersist.id)