@@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging {
val props = new Properties()
props.put("metadata.broker.list", brokerAddress)
props.put("serializer.class", classOf[StringEncoder].getName)
// wait for all in-sync replicas to ack sends
props.put("request.required.acks", "-1")
props
}
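Note: with `request.required.acks` set to `-1`, the Kafka 0.8 producer does not return from a send until every in-sync replica has acknowledged it, so by the time `sendMessages` returns, the leader's log-end offset already covers the sent messages. That guarantee is what lets the `waitUntilLeaderOffset` polling below be deleted. A minimal sketch of the idea against the 0.8 producer API (`brokerAddress` and `topic` are placeholders):

```scala
import java.util.Properties
import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
import kafka.serializer.StringEncoder

val props = new Properties()
props.put("metadata.broker.list", brokerAddress)
props.put("serializer.class", classOf[StringEncoder].getName)
props.put("request.required.acks", "-1") // -1: block until all ISRs ack

val producer = new Producer[String, String](new ProducerConfig(props))
// send() returns only after every in-sync replica has the message, so the
// leader offset is already past it when control comes back to the test.
producer.send(new KeyedMessage[String, String](topic, "payload"))
producer.close()
```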

@@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging {
tryAgain(1)
}

/** Wait until the leader offset for the given topic/partition equals the specified offset */
def waitUntilLeaderOffset(
topic: String,
partition: Int,
offset: Long): Unit = {
eventually(Time(10000), Time(100)) {
val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress))
val tp = TopicAndPartition(topic, partition)
val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset
assert(
llo == offset,
s"$topic $partition $offset not reached after timeout")
}
}

private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match {
case Some(partitionState) =>
@@ -72,9 +72,6 @@ public void testKafkaRDD() throws InterruptedException {
HashMap<String, String> kafkaParams = new HashMap<String, String>();
kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());

kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length);
kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length);

OffsetRange[] offsetRanges = {
OffsetRange.create(topic1, 0, 0, 1),
OffsetRange.create(topic2, 0, 0, 1)
@@ -61,8 +61,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
"group.id" -> s"test-consumer-${Random.nextInt}")

kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size)

val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))

val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
@@ -86,7 +84,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
// this is the "lots of messages" case
kafkaTestUtils.sendMessages(topic, sent)
val sentCount = sent.values.sum
kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount)

// rdd defined from leaders after sending messages, should get the number sent
val rdd = getRdd(kc, Set(topic))
@@ -113,7 +110,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
val sentOnlyOne = Map("d" -> 1)

kafkaTestUtils.sendMessages(topic, sentOnlyOne)
kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1)

assert(rdd2.isDefined)
assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -47,6 +47,9 @@ object MimaExcludes {
// Mima false positive (was a private[spark] class)
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.util.collection.PairIterator"),
// Removing a testing method from a private class
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
// SQL execution is considered private.
excludePackage("org.apache.spark.sql.execution")
)
5 changes: 0 additions & 5 deletions python/pyspark/streaming/tests.py
@@ -615,7 +615,6 @@ def test_kafka_stream(self):

self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))

stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
"test-streaming-consumer", {topic: 1},
@@ -631,7 +630,6 @@ def test_kafka_direct_stream(self):

self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))

stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
self._validateStreamResult(sendData, stream)
@@ -646,7 +644,6 @@ def test_kafka_direct_stream_from_offset(self):

self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))

stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
self._validateStreamResult(sendData, stream)
@@ -661,7 +658,6 @@ def test_kafka_rdd(self):

self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
self._validateRddResult(sendData, rdd)

@@ -677,7 +673,6 @@ def test_kafka_rdd_with_leaders(self):

self._kafkaTestUtils.createTopic(topic)
self._kafkaTestUtils.sendMessages(topic, sendData)
self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
self._validateRddResult(sendData, rdd)

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees

@@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def qualifiers: Seq[String] = throw new UnsupportedOperationException

override def exprId: ExprId = throw new UnsupportedOperationException

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
}
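For concreteness, here is roughly what this template could expand to for `BoundReference(ordinal = 0, dataType = IntegerType, nullable = true)`. The identifier suffixes and the `"int"`, `"-1"`, and `"i.getInt(0)"` strings are assumptions about what `freshName`, `javaType`, `defaultValue`, and `getColumn` produce, inferred from context rather than quoted from `CodeGenContext`:

```scala
// Hypothetical expansion of BoundReference.genCode for an int column at ordinal 0:
val expanded: String =
  """
    boolean isNull0 = i.isNullAt(0);
    int primitive0 = isNull0 ? -1 : (i.getInt(0));
  """
```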

object BindReferences extends Logging {
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._

@@ -160,7 +161,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null }
})
case BooleanType =>
buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0)))
buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0))
case LongType =>
buildCast[Long](_, l => new Timestamp(l))
case IntegerType =>
@@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {

case (BinaryType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
// Special handling is required for timestamps in Hive test cases, since the
// toString function does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")

// Fallback for DecimalType; this case must come before the other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)

case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")

case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")

case other =>
super.genCode(ctx, ev)
}
}
}
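The boolean and numeric arms of this match reduce to plain Java casts wrapped in the usual null-propagation template. A stand-alone sketch of the string fragments those arms build (the `javaType` strings here are stand-ins for `ctx.javaType`, which is not reproduced in this sketch):

```scala
// Mirrors f in the (BooleanType, NumericType) and (NumericType, NumericType)
// arms above; c is the name of the child's primitive variable.
def boolToNumeric(c: String, javaType: String): String = s"($javaType)($c ? 1 : 0)"
def numericToNumeric(c: String, javaType: String): String = s"($javaType)($c)"

assert(boolToNumeric("primitive1", "int") == "(int)(primitive1 ? 1 : 0)")
assert(numericToNumeric("primitive1", "double") == "(double)(primitive1)")
```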

object Cast {
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
@@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] {
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): Any

/**
* Returns a [[GeneratedExpressionCode]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
*
* @param ctx a [[CodeGenContext]]
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve
}

/**
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
* implementations should override this to do actual code generation.
*
* @param ctx a [[CodeGenContext]]
* @param ev a [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm;
}
"""
}
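This fallback keeps every expression compilable even before it has a dedicated `genCode`: the expression object is appended to `ctx.references`, and the generated class evaluates it reflectively through an `expressions` array, then unboxes the result. For an `IntegerType` expression that landed at index 0, the emitted block might look like the following; the identifier suffixes and the `"int"`/`"-1"` strings are assumed outputs of `freshName`, `javaType`, and `defaultValue`:

```scala
val fallbackExpansion: String =
  """
    /* expression: someIntExpr */
    Object obj1 = expressions[0].eval(i);
    boolean isNull1 = obj1 == null;
    int primitive1 = -1;
    if (!isNull1) {
      primitive1 = (Integer) obj1;  // cast to ctx.boxedType, auto-unboxed to int
    }
  """
```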

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
@@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable

override def toString: String = s"($left $symbol $right)"

/**
* Shorthand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (Term, Term) => Code): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
}

val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitive, eval2.primitive)

s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if (!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.isNull} = true;
}
}
"""
}
}
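A concrete binary expression only has to supply the result fragment; null propagation comes from the helper. As a sketch (a fragment, assuming it sits inside a numeric `BinaryExpression` subclass; Catalyst's arithmetic operators follow this shape, though the real `Add` is not quoted here):

```scala
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code =
  defineCodeGen(ctx, ev, (a, b) => s"$a + $b")
```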

private[sql] object BinaryExpression {
@@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]

abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>

/**
* Called by unary expressions to generate a code block that returns null if its child returns
* null, and if not null, uses `f` to generate the expression's result.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
* defineCodeGen(ctx, ev, c => s"!($c)")
* }}}
*
* @param f function that accepts a variable name and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
val eval = child.gen(ctx)
// reuse the previous isNull
ev.isNull = eval.isNull
eval.code + s"""
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = ${f(eval.primitive)};
}
"""
}
}
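Following the Scaladoc example above, `defineCodeGen(ctx, ev, c => s"!($c)")` over a boolean bound reference could expand to something like the block below; the identifier names and the `"boolean"`/`"false"` strings are assumptions about `freshName`, `javaType`, and `defaultValue`:

```scala
val notExpansion: String =
  """
    boolean isNull0 = i.isNullAt(0);                          // child's eval.code
    boolean primitive0 = isNull0 ? false : (i.getBoolean(0));
    boolean primitive1 = false;                               // ev.primitive; ev.isNull reuses isNull0
    if (!isNull0) {
      primitive1 = !(primitive0);
    }
  """
```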

// TODO: Semantically we probably do not need GroupExpression