Add MessageHandler for Kafka Python API
jerryshao committed Nov 6, 2015
commit 5e7a093663f1689a019111c56aa6c146e118ccad
@@ -17,6 +17,7 @@

package org.apache.spark.streaming.kafka

import java.io.OutputStream
import java.lang.{Integer => JInt, Long => JLong}
import java.util.{List => JList, Map => JMap, Set => JSet}

@@ -25,17 +26,19 @@ import scala.reflect.ClassTag

import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder}
import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler}

import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.api.java._
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.util.WriteAheadLogUtils
import org.apache.spark.{SparkContext, SparkException}

object KafkaUtils {
/**
@@ -184,6 +187,27 @@ object KafkaUtils {
}
}

private[kafka] def getFromOffsets(
kc: KafkaCluster,
kafkaParams: Map[String, String],
topics: Set[String]
): Map[TopicAndPartition, Long] = {
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
} else {
kc.getLatestLeaderOffsets(topicPartitions)
}).right
} yield {
leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
}
KafkaCluster.checkErrors(result)
}
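
Editorial note (not part of this commit): the extracted helper centralizes the Either-based offset lookup so the Scala createDirectStream and the new Python path can share it. A minimal sketch of what it computes, assuming a reachable broker at localhost:9092 and a topic named "events", and assuming the call site lives inside the org.apache.spark.streaming.kafka package (the helper is private[kafka]):

// Hypothetical usage, inside package org.apache.spark.streaming.kafka:
import kafka.common.TopicAndPartition

val kafkaParams = Map(
  "metadata.broker.list" -> "localhost:9092",  // assumed broker address
  "auto.offset.reset" -> "smallest")           // start from the earliest available offsets
val kc = new KafkaCluster(kafkaParams)
// Maps every partition of "events" to its earliest leader offset; with any other
// (or no) auto.offset.reset value the latest leader offsets are used instead.
val fromOffsets: Map[TopicAndPartition, Long] =
  KafkaUtils.getFromOffsets(kc, kafkaParams, Set("events"))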

/**
* Create a RDD from Kafka using offset ranges for each topic and partition.
*
@@ -246,7 +270,7 @@ object KafkaUtils {
// This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
leaders.map {
case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
}.toMap
}
}
val cleanedHandler = sc.clean(messageHandler)
checkOffsets(kc, offsetRanges)
@@ -406,23 +430,9 @@ object KafkaUtils {
): InputDStream[(K, V)] = {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
val kc = new KafkaCluster(kafkaParams)
val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)

val result = for {
topicPartitions <- kc.getPartitions(topics).right
leaderOffsets <- (if (reset == Some("smallest")) {
kc.getEarliestLeaderOffsets(topicPartitions)
} else {
kc.getLatestLeaderOffsets(topicPartitions)
}).right
} yield {
val fromOffsets = leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}
KafkaCluster.checkErrors(result)
val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}
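
Editorial note (not part of this commit): the public behaviour of createDirectStream is unchanged by this refactoring; the default message handler still yields (key, message) pairs, and getFromOffsets is only an internal extraction. A minimal sketch of the existing Scala API, with an assumed local broker and topic name:

import kafka.serializer.StringDecoder
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

val conf = new SparkConf().setAppName("kafka-direct-example").setMaster("local[2]")
val ssc = new StreamingContext(conf, Seconds(2))
val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")  // assumed broker address
// Starting offsets are resolved internally via getFromOffsets, honouring auto.offset.reset.
val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, Set("events"))
stream.map(_._2).print()
ssc.start()
ssc.awaitTermination()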

/**
@@ -550,6 +560,8 @@ object KafkaUtils {
* takes care of known parameters instead of passing them from Python
*/
private[kafka] class KafkaUtilsPythonHelper {
import KafkaUtilsPythonHelper._

def createStream(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
@@ -570,82 +582,67 @@ private[kafka] class KafkaUtilsPythonHelper {
jsc: JavaSparkContext,
kafkaParams: JMap[String, String],
offsetRanges: JList[OffsetRange],
leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = {
val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
(Array[Byte], Array[Byte])] {
def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
(t1.key(), t1.message())
}
leaders: JMap[TopicAndPartition, Broker]
): JavaRDD[Array[Byte]] = {
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(
mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())

val jrdd = KafkaUtils.createRDD[
KafkaUtils.createRDD[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
(Array[Byte], Array[Byte])](
jsc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
classOf[(Array[Byte], Array[Byte])],
kafkaParams,
PythonMessageAndMetadata](
jsc.sc,
Map(kafkaParams.asScala.toSeq: _*),
offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())),
leaders,
messageHandler
)
new JavaPairRDD(jrdd.rdd)
Map(leaders.asScala.toSeq: _*),
messageHandler).mapPartitions { iter =>
initialize()
new SerDeUtil.AutoBatchedPickler(iter)
}
}
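
Editorial note (not part of this commit): the helper now returns a JavaRDD[Array[Byte]] whose elements are pickled batches of message-and-metadata records, rather than a JavaPairRDD of raw (key, value) pairs. A hedged driver-side sketch, assuming the call lives inside the org.apache.spark.streaming.kafka package (the helper class is private[kafka]) and assuming a local broker, a topic "events", and offsets 0 to 100 on partition 0:

import kafka.common.TopicAndPartition
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(
  new SparkConf().setAppName("kafka-python-helper-example").setMaster("local[2]"))
val helper = new KafkaUtilsPythonHelper()
val kafkaParams = new java.util.HashMap[String, String]()
kafkaParams.put("metadata.broker.list", "localhost:9092")        // assumed broker address
val ranges = java.util.Arrays.asList(
  helper.createOffsetRange("events", 0, 0L, 100L))                // partition 0, offsets [0, 100)
val leaders = new java.util.HashMap[TopicAndPartition, Broker]() // empty: leaders looked up on the driver
// Each element of the result is one pickled batch that the PySpark worker
// unpickles into pyspark.streaming.kafka.KafkaMessageAndMetadata objects.
val pickledRdd = helper.createRDD(jsc, kafkaParams, ranges, leaders)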

def createDirectStream(
jssc: JavaStreamingContext,
kafkaParams: JMap[String, String],
topics: JSet[String],
fromOffsets: JMap[TopicAndPartition, JLong]
): JavaPairInputDStream[Array[Byte], Array[Byte]] = {
): JavaDStream[Array[Byte]] = {

if (!fromOffsets.isEmpty) {
val currentFromOffsets = if (!fromOffsets.isEmpty) {
val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic)
if (topicsFromOffsets != topics.asScala.toSet) {
throw new IllegalStateException(
s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " +
s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}")
}
Map(fromOffsets.asScala.mapValues { _.longValue() }.toSeq: _*)
} else {
val kc = new KafkaCluster(Map(kafkaParams.asScala.toSeq: _*))
KafkaUtils.getFromOffsets(
kc, Map(kafkaParams.asScala.toSeq: _*), Set(topics.asScala.toSeq: _*))
}

if (fromOffsets.isEmpty) {
KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
jssc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
kafkaParams,
topics)
} else {
val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]],
(Array[Byte], Array[Byte])] {
def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) =
(t1.key(), t1.message())
}
val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) =>
new PythonMessageAndMetadata(
mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message())

val jstream = KafkaUtils.createDirectStream[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
(Array[Byte], Array[Byte])](
jssc,
classOf[Array[Byte]],
classOf[Array[Byte]],
classOf[DefaultDecoder],
classOf[DefaultDecoder],
classOf[(Array[Byte], Array[Byte])],
kafkaParams,
fromOffsets,
messageHandler)
new JavaPairInputDStream(jstream.inputDStream)
val stream = KafkaUtils.createDirectStream[
Array[Byte],
Array[Byte],
DefaultDecoder,
DefaultDecoder,
PythonMessageAndMetadata](
jssc.ssc,
Map(kafkaParams.asScala.toSeq: _*),
Map(currentFromOffsets.toSeq: _*),
messageHandler).mapPartitions { iter =>
initialize()
new SerDeUtil.AutoBatchedPickler(iter)
}
new JavaDStream(stream)
}

def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong
@@ -669,3 +666,53 @@
kafkaRDD.offsetRanges.toSeq.asJava
}
}

private object KafkaUtilsPythonHelper {
private var initialized = false

def initialize(): Unit = {
SerDeUtil.initialize()
synchronized {
if (!initialized) {
new PythonMessageAndMetadataPickler().register()
initialized = true
}
}
}
// this constructor-time call only runs on the driver; executors are not
// initialized automatically, so initialize() is also called inside mapPartitions
initialize()

case class PythonMessageAndMetadata(
topic: String,
partition: JInt,
offset: JLong,
key: Array[Byte],
message: Array[Byte])

class PythonMessageAndMetadataPickler extends IObjectPickler {
private val module = "pyspark.streaming.kafka"

def register(): Unit = {
Pickler.registerCustomPickler(classOf[PythonMessageAndMetadata], this)
Pickler.registerCustomPickler(this.getClass, this)
}

def pickle(obj: Object, out: OutputStream, pickler: Pickler) {
if (obj == this) {
out.write(Opcodes.GLOBAL)
out.write(s"$module\nKafkaMessageAndMetadata\n".getBytes("utf-8"))
} else {
pickler.save(this)
val msgAndMetaData = obj.asInstanceOf[PythonMessageAndMetadata]
out.write(Opcodes.MARK)
pickler.save(msgAndMetaData.topic)
pickler.save(msgAndMetaData.partition)
pickler.save(msgAndMetaData.offset)
pickler.save(msgAndMetaData.key)
pickler.save(msgAndMetaData.message)
out.write(Opcodes.TUPLE)
out.write(Opcodes.REDUCE)
}
}
}
}
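
Editorial note (not part of this commit): the custom pickler first emits a GLOBAL opcode naming pyspark.streaming.kafka.KafkaMessageAndMetadata (via the obj == this branch), then the five fields between MARK and TUPLE, and finally REDUCE, so unpickling on the Python side calls that class with the tuple as arguments. A minimal sketch of pickling one record on the driver, again assuming the code sits inside the org.apache.spark.streaming.kafka package because the helper object is private:

import net.razorvine.pickle.Pickler
import KafkaUtilsPythonHelper._

initialize()  // registers PythonMessageAndMetadataPickler (idempotent)
val record = PythonMessageAndMetadata("events", 0, 42L, null, "payload".getBytes("utf-8"))
// The byte stream below unpickles in Python to roughly
// KafkaMessageAndMetadata("events", 0, 42, None, b"payload").
val pickled: Array[Byte] = new Pickler().dumps(record)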