From 321859a3e7a5e618c629fab701634f43eb23769f Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sun, 26 Jul 2015 23:53:34 -0700
Subject: [PATCH 01/10] [SPARK-9373][SQL] Support StructType in Tungsten projection [WIP]

---
 .../expressions/UnsafeRowWriters.java         |  39 +++++-
 .../catalyst/expressions/BoundAttribute.scala |   2 +-
 .../codegen/GenerateUnsafeProjection.scala    | 117 +++++++++++++++++-
 .../expressions/complexTypeCreator.scala      |  36 ++++++
 .../spark/sql/execution/SparkStrategies.scala |   6 +-
 .../spark/sql/execution/basicOperators.scala  |  24 ++++
 .../spark/sql/DataFrameTungstenSuite.scala    |  58 +++++++++
 7 files changed, 276 insertions(+), 6 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 0ba31d3b9b74..1d97f66bef13 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions;
 
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
@@ -81,6 +82,43 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input)
     }
   }
 
+  /** Writer for struct type. */
+  public static class StructWriter {
+    public static int getSize(InternalRow input) {
+      int numBytes = 0;
+      if (input instanceof UnsafeRow) {
+        numBytes = ((UnsafeRow) input).getSizeInBytes();
+      } else {
+        throw new UnsupportedOperationException();
+      }
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) {
+      int numBytes = 0;
+      final long offset = target.getBaseOffset() + cursor;
+      if (input instanceof UnsafeRow) {
+        final UnsafeRow row = (UnsafeRow) input;
+        numBytes = row.getSizeInBytes();
+
+        // zero-out the padding bytes
+        if ((numBytes & 0x07) > 0) {
+          PlatformDependent.UNSAFE.putLong(
+            target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+        }
+
+        // Write the struct to the variable length portion.
+        row.writeToMemory(target.getBaseObject(), offset);
+
+        // Set the fixed length portion.
+        target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      } else {
+        throw new UnsupportedOperationException();
+      }
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
   /** Writer for interval type.
*/ public static class IntervalWriter { @@ -96,5 +134,4 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Interval inpu return 16; } } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 41a877f214e5..950c85670a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -50,7 +50,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case BinaryType => input.getBinary(ordinal) case IntervalType => input.getInterval(ordinal) case t: StructType => input.getStruct(ordinal, t.size) - case dataType => input.get(ordinal, dataType) + case _ => input.get(ordinal, dataType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 9d2161947b35..e9f043771b08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -34,15 +34,116 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName + private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case t: AtomicType if !t.isInstanceOf[DecimalType] => true case _: IntervalType => true + case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true case _ => false } + private def createCodeForStruct( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + dataType: StructType): GeneratedExpressionCode = { + + val isNull = input.isNull + val primitive = ctx.freshName("structConvert") + ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") + val bufferTerm = ctx.freshName("buffer") + ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];") + val cursorTerm = ctx.freshName("cursor") + + val exprs: Seq[GeneratedExpressionCode] = dataType.map(_.dataType).zipWithIndex.map { + case (dt, i) => dt match { + case st: StructType => + val nestedStructEv = GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + ) + createCodeForStruct(ctx, nestedStructEv, st) + case _ => + GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + ) + } + } + val allExprs = exprs.map(_.code).mkString("\n") + + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = dataType.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => + dt match { + case StringType => + s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 
0 : $BinaryWriter.getSize(${ev.primitive}))" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case IntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _ => "" + } + }.mkString("") + + val writers = dataType.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => + val update = dt match { + case _ if ctx.isPrimitiveType(dt) => + s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}" + case StringType => + s"$cursorTerm += $StringWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + case BinaryType => + s"$cursorTerm += $BinaryWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + case t: StructType => + s"$cursorTerm += $StructWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + case t: StructType => + s"$cursorTerm += $StructWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $dt") + } + s""" + if (${exprs(i).isNull}) { + $primitive.setNullAt($i); + } else { + $update; + } + """ + }.mkString("\n ") + + val code = s""" + |${input.code} + |if (!${input.isNull}) { + | if (${input.primitive} instanceof UnsafeRow) { + | $primitive = (UnsafeRow) ${input.primitive}; + | } else { + | $allExprs + | + | int numBytes = $fixedSize $additionalSize; + | if (numBytes > $bufferTerm.length) { + | $bufferTerm = new byte[numBytes]; + | } + | + | $primitive.pointTo( + | $bufferTerm, + | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + | ${exprs.size}, + | numBytes); + | int $cursorTerm = $fixedSize; + | + | $writers + | } + |} + """.stripMargin + + GeneratedExpressionCode(code, isNull, primitive) + } + /** * Generates the code to create an [[UnsafeRow]] object based on the input expressions. * @param ctx context for code generation @@ -60,10 +161,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val cursorTerm = ctx.freshName("cursor") val numBytesTerm = ctx.freshName("numBytes") - val exprs = expressions.map(_.gen(ctx)) + val exprs = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case st: StructType => + createCodeForStruct(ctx, e.gen(ctx), st) + case _ => + e.gen(ctx) + } + } val allExprs = exprs.map(_.code).mkString("\n") - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) val additionalSize = expressions.zipWithIndex.map { case (e, i) => e.dataType match { case StringType => @@ -72,6 +180,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" case IntervalType => s" + (${exprs(i).isNull} ? 0 : 16)" + case _: StructType => + s" + (${exprs(i).isNull} ? 
0 : $StructWriter.getSize(${exprs(i).primitive}))" case _ => "" } }.mkString("") @@ -86,6 +196,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case IntervalType => s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" + case t: StructType => + s"$cursorTerm += $StructWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") @@ -111,7 +223,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $numBytesTerm); int $cursorTerm = $fixedSize; - $writers boolean ${ev.isNull} = false; """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 119168fa59f1..ee7516dde91f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -116,6 +116,42 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def prettyName: String = "struct" } + +/** + * Returns a Row containing the evaluation of all children expressions. + */ +case class UnsafeCreateStruct(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, children) + } + + override def prettyName: String = "struct_unsafe" +} + + /** * Creates a struct with the given field names and values * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 306bbfec624c..ecda743ca755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,7 +363,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Sort(sortExprs, global, child) => getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => - execution.Project(projectList, planLater(child)) :: Nil + if (UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { + execution.TungstenProject(projectList, planLater(child)) :: Nil + } else { + execution.Project(projectList, planLater(child)) :: Nil + } case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, _, child) => diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fe429d862a0a..ffd4a22a6616 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -49,6 +49,30 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override def outputOrdering: Seq[SortOrder] = child.outputOrdering } + +/** + * A variant of [[Project]] that returns [[UnsafeRow]]s. + */ +case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + val exprs = this.transformAllExpressions { case CreateStruct(children) => + UnsafeCreateStruct(children) + }.projectList + val project = UnsafeProjection.create(exprs, child.output) + iter.map(project) + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + + /** * :: DeveloperApi :: */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala new file mode 100644 index 000000000000..7deda0023771 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -0,0 +1,58 @@ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + + +class DataFrameTungstenSuite extends SparkFunSuite { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("test simple types") { + ctx.setConf("spark.sql.unsafe.enabled", "true") + val df = ctx.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } + + test("test struct type") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = ctx.sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = ctx.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } + + test("test nested struct type") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = ctx.sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) + + val df = ctx.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } +} From 525b95b9fd4ece99d59f54a7a7f3fbf37804c226 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 12:31:57 -0700 Subject: [PATCH 02/10] Merged with master, more documentation & test cases. 
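
For the record, here is the byte-layout contract that the new StructWriter
relies on, as a minimal self-contained Scala sketch (illustration only, not
part of this patch; the helper names are made up):

    // The fixed-length slot of a struct column packs two ints into one long:
    // the cursor (the offset of the struct's bytes within the variable-length
    // region) in the upper 32 bits and the payload size in the lower 32 bits,
    // mirroring target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)).
    def encodeSlot(cursor: Int, numBytes: Int): Long =
      (cursor.toLong << 32) | numBytes.toLong

    def decodeSlot(slot: Long): (Int, Int) =
      ((slot >>> 32).toInt, slot.toInt)

    // Variable-length payloads are rounded up to whole 8-byte words, matching
    // ByteArrayMethods.roundNumberOfBytesToNearestWord.
    def roundToWord(numBytes: Int): Int = ((numBytes + 7) / 8) * 8

    assert(encodeSlot(16, 24) == ((16L << 32) | 24L))
    assert(decodeSlot(encodeSlot(16, 24)) == ((16, 24)))
    assert(roundToWord(13) == 16)

The rounding is also why write() explicitly zeroes the last word before
copying: the target buffer is reused across rows, so stale bytes could
otherwise leak into the padding.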
--- .../expressions/UnsafeRowWriters.java | 11 +- .../codegen/GenerateUnsafeProjection.scala | 214 +++++++++--------- .../expressions/complexTypeCreator.scala | 99 +++----- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../spark/sql/execution/basicOperators.scala | 5 +- .../org/apache/spark/sql/DataFrameSuite.scala | 2 + .../spark/sql/DataFrameTungstenSuite.scala | 104 +++++---- 7 files changed, 230 insertions(+), 210 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 1d97f66bef13..8fdd7399602d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -82,13 +82,21 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) } } - /** Writer for struct type. */ + /** + * Writer for struct type where the struct field is backed by an {@link UnsafeRow}. + * + * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}. + * Non-UnsafeRow struct fields are handled directly in + * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection} + * by generating the Java code needed to convert them into UnsafeRow. + */ public static class StructWriter { public static int getSize(InternalRow input) { int numBytes = 0; if (input instanceof UnsafeRow) { numBytes = ((UnsafeRow) input).getSizeInBytes(); } else { + // This is handled directly in GenerateUnsafeProjection. throw new UnsupportedOperationException(); } return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); @@ -113,6 +121,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i // Set the fixed length portion. target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); } else { + // This is handled directly in GenerateUnsafeProjection. throw new UnsupportedOperationException(); } return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index e9f043771b08..c482f7601810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -45,19 +45,113 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } + /** + * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
+ * @param ctx context for code generation + * @param ev specifies the name of the variable for the output [[UnsafeRow]] object + * @param expressions input expressions + * @return generated code to put the expression output into an [[UnsafeRow]] + */ + def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) + : String = { + + val ret = ev.primitive + ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + val numBytes = ctx.freshName("numBytes") + + val exprs = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case st: StructType => + createCodeForStruct(ctx, e.gen(ctx), st) + case _ => + e.gen(ctx) + } + } + val allExprs = exprs.map(_.code).mkString("\n") + + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case StringType => + s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" + case BinaryType => + s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" + case IntervalType => + s" + (${exprs(i).isNull} ? 0 : 16)" + case _: StructType => + s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))" + case _ => "" + } + }.mkString("") + + val writers = expressions.zipWithIndex.map { case (e, i) => + val update = e.dataType match { + case dt if ctx.isPrimitiveType(dt) => + s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" + case StringType => + s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" + case IntervalType => + s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") + } + s"""if (${exprs(i).isNull}) { + $ret.setNullAt($i); + } else { + $update; + }""" + }.mkString("\n ") + + s""" + $allExprs + int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; + } + + $ret.pointTo( + $buffer, + org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, + $numBytes); + int $cursor = $fixedSize; + + $writers + boolean ${ev.isNull} = false; + """ + } + + /** + * Generates the Java code to convert a struct (backed by InternalRow) into UnsafeRow. + * + * This function also handles nested structs by recursively generating the code to + * + * @param ctx code generation context + * @param input the input struct, identified by a [[GeneratedExpressionCode]] + * @param schema schema of the struct field + */ + // TODO: refactor createCode and this function to reduce code duplication. 
private def createCodeForStruct( ctx: CodeGenContext, input: GeneratedExpressionCode, - dataType: StructType): GeneratedExpressionCode = { + schema: StructType): GeneratedExpressionCode = { val isNull = input.isNull val primitive = ctx.freshName("structConvert") ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") - val bufferTerm = ctx.freshName("buffer") - ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];") - val cursorTerm = ctx.freshName("cursor") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") - val exprs: Seq[GeneratedExpressionCode] = dataType.map(_.dataType).zipWithIndex.map { + val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => dt match { case st: StructType => val nestedStructEv = GeneratedExpressionCode( @@ -77,32 +171,32 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val allExprs = exprs.map(_.code).mkString("\n") val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = dataType.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => + val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => dt match { case StringType => s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" case BinaryType => s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" - case _: StructType => - s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" case IntervalType => s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" case _ => "" } }.mkString("") - val writers = dataType.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => + val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => val update = dt match { case _ if ctx.isPrimitiveType(dt) => s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}" case StringType => - s"$cursorTerm += $StringWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" case BinaryType => - s"$cursorTerm += $BinaryWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" - case t: StructType => - s"$cursorTerm += $StructWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" + case IntervalType => + s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" case t: StructType => - s"$cursorTerm += $StructWriter.write($primitive, $i, $cursorTerm, ${exprs(i).primitive})" + s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: $dt") @@ -125,16 +219,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | $allExprs | | int numBytes = $fixedSize $additionalSize; - | if (numBytes > $bufferTerm.length) { - | $bufferTerm = new byte[numBytes]; + | if (numBytes > $buffer.length) { + | $buffer = new byte[numBytes]; | } | | $primitive.pointTo( - | $bufferTerm, + | $buffer, | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, | ${exprs.size}, | numBytes); - | int $cursorTerm = $fixedSize; + | int $cursor = $fixedSize; | | $writers | } @@ 
-144,90 +238,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro GeneratedExpressionCode(code, isNull, primitive) } - /** - * Generates the code to create an [[UnsafeRow]] object based on the input expressions. - * @param ctx context for code generation - * @param ev specifies the name of the variable for the output [[UnsafeRow]] object - * @param expressions input expressions - * @return generated code to put the expression output into an [[UnsafeRow]] - */ - def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) - : String = { - - val ret = ev.primitive - ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") - val bufferTerm = ctx.freshName("buffer") - ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];") - val cursorTerm = ctx.freshName("cursor") - val numBytesTerm = ctx.freshName("numBytes") - - val exprs = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case st: StructType => - createCodeForStruct(ctx, e.gen(ctx), st) - case _ => - e.gen(ctx) - } - } - val allExprs = exprs.map(_.code).mkString("\n") - - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case StringType => - s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" - case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" - case IntervalType => - s" + (${exprs(i).isNull} ? 0 : 16)" - case _: StructType => - s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))" - case _ => "" - } - }.mkString("") - - val writers = expressions.zipWithIndex.map { case (e, i) => - val update = e.dataType match { - case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case BinaryType => - s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case IntervalType => - s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case t: StructType => - s"$cursorTerm += $StructWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") - } - s"""if (${exprs(i).isNull}) { - $ret.setNullAt($i); - } else { - $update; - }""" - }.mkString("\n ") - - s""" - $allExprs - int $numBytesTerm = $fixedSize $additionalSize; - if ($numBytesTerm > $bufferTerm.length) { - $bufferTerm = new byte[$numBytesTerm]; - } - - $ret.pointTo( - $bufferTerm, - org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, - ${expressions.size}, - $numBytesTerm); - int $cursorTerm = $fixedSize; - - $writers - boolean ${ev.isNull} = false; - """ - } - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index ee7516dde91f..878b9a1675d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -96,59 +96,28 @@ case class 
CreateStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName - s""" - boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${children.size}); - """ + - children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ - }.mkString("\n") - } - - override def prettyName: String = "struct" -} - - -/** - * Returns a Row containing the evaluation of all children expressions. - */ -case class UnsafeCreateStruct(children: Seq[Expression]) extends Expression { - - override def foldable: Boolean = children.forall(_.foldable) - - override lazy val resolved: Boolean = childrenResolved - - override lazy val dataType: StructType = { - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } + if (GenerateUnsafeProjection.canSupport(dataType)) { + GenerateUnsafeProjection.createCode(ctx, ev, children) + } else { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") } - StructType(fields) - } - - override def nullable: Boolean = false - - override def eval(input: InternalRow): Any = { - throw new UnsupportedOperationException } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - GenerateUnsafeProjection.createCode(ctx, ev, children) - } - - override def prettyName: String = "struct_unsafe" + override def prettyName: String = "struct" } @@ -196,21 +165,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val rowClass = classOf[GenericMutableRow].getName - s""" + if (GenerateUnsafeProjection.canSupport(dataType)) { + GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + } else { + val rowClass = classOf[GenericMutableRow].getName + s""" boolean ${ev.isNull} = false; final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); - """ + - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ - }.mkString("\n") + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } } override def prettyName: String = "named_struct" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ecda743ca755..995eb8f8c017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,7 +363,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Sort(sortExprs, global, child) => getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => - if (UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { + // If unsafe mode is enabled and we support these data types in Unsafe, use the + // Tungsten project. Otherwise, use the normal project. + if (sqlContext.conf.unsafeEnabled && + UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { execution.TungstenProject(projectList, planLater(child)) :: Nil } else { execution.Project(projectList, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ffd4a22a6616..3131fda127bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -62,10 +62,7 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val exprs = this.transformAllExpressions { case CreateStruct(children) => - UnsafeCreateStruct(children) - }.projectList - val project = UnsafeProjection.create(exprs, child.output) + val project = UnsafeProjection.create(projectList, child.output) iter.map(project) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f67f2c60c0e1..a79b74317362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,6 +33,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + ctx.setConf(SQLConf.UNSAFE_ENABLED, true) + def sqlContext: SQLContext = ctx test("analysis error should be eagerly reported") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 7deda0023771..bf8ef9a97bc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -1,58 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +/** + * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode. + * + * This is here for now so I can make sure Tungsten project is tested without refactoring existing + * end-to-end test infra. In the long run this should just go away. + */ +class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { -class DataFrameTungstenSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ test("test simple types") { - ctx.setConf("spark.sql.unsafe.enabled", "true") - val df = ctx.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") - assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } } test("test struct type") { - val struct = Row(1, 2L, 3.0F, 3.0) - val data = ctx.sparkContext.parallelize(Seq(Row(1, struct))) - - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType)) - - val df = ctx.createDataFrame(data, schema) - assert(df.select("b").first() === Row(struct)) + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } } test("test nested struct type") { - val innerStruct = Row(1, "abcd") - val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = ctx.sparkContext.parallelize(Seq(Row(1, outerStruct))) - - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType) - .add("b5", new StructType() + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() .add("b5a", IntegerType) .add("b5b", StringType)) - .add("b6", StringType)) + .add("b6", StringType)) - val df = ctx.createDataFrame(data, schema) - assert(df.select("b").first() === Row(outerStruct)) + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } } } From 6b781feb37c3c2e153b16cd0afb8d7cc772b6eb8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 12:33:10 -0700 Subject: [PATCH 03/10] Reset the change in DataFrameSuite. 
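
The suite-wide setConf in the previous commit flipped
spark.sql.unsafe.enabled for every test sharing TestSQLContext, so it is
reverted here in favor of the scoped withSQLConf pattern that
DataFrameTungstenSuite already uses. A rough, self-contained Scala sketch of
that save/override/restore pattern (a plain mutable map stands in for the
real SQLConf; the names are made up):

    val conf = scala.collection.mutable.Map[String, String]()

    def withConf[T](pairs: (String, String)*)(body: => T): T = {
      // Snapshot the current values of the keys we are about to override.
      val saved = pairs.map { case (k, _) => k -> conf.get(k) }
      pairs.foreach { case (k, v) => conf(k) = v }
      try body
      finally saved.foreach {
        case (k, Some(old)) => conf(k) = old  // restore the previous value
        case (k, None) => conf.remove(k)      // the key was not set before
      }
    }

    withConf("spark.sql.unsafe.enabled" -> "true") {
      assert(conf("spark.sql.unsafe.enabled") == "true")
    }
    assert(!conf.contains("spark.sql.unsafe.enabled"))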
--- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a79b74317362..f67f2c60c0e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,8 +33,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ - ctx.setConf(SQLConf.UNSAFE_ENABLED, true) - def sqlContext: SQLContext = ctx test("analysis error should be eagerly reported") { From 9f36216056f7edb35ee8daa2355a988bec3e00ce Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 12:35:41 -0700 Subject: [PATCH 04/10] Updated comment. --- .../expressions/codegen/GenerateUnsafeProjection.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index c482f7601810..43bc62d7ad68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -130,9 +130,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } /** - * Generates the Java code to convert a struct (backed by InternalRow) into UnsafeRow. + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. * - * This function also handles nested structs by recursively generating the code to + * This function also handles nested structs by recursively generating the code to do conversion. * * @param ctx code generation context * @param input the input struct, identified by a [[GeneratedExpressionCode]] From ac203bf5705264c2df4d5a3dbc28975b7c7d8b94 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 12:36:39 -0700 Subject: [PATCH 05/10] More comments. --- .../catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 43bc62d7ad68..63e6f468d29b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -210,6 +210,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ }.mkString("\n ") + // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, + // just copy the bytes directly into our buffer space without running any conversion. val code = s""" |${input.code} |if (!${input.isNull}) { From ac4951d30c143211c44c364da3a40e1fec4944a3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 14:42:33 -0700 Subject: [PATCH 06/10] Yay. 
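
Among other things, this commit stops hard-coding the numBytes local in the
generated struct-conversion code and takes the name from ctx.freshName
instead: createCodeForStruct is instantiated once per (possibly nested)
struct inside a single generated class, so two copies of the template must
not declare the same Java local twice. A hypothetical miniature of the
uniqueness guarantee freshName provides (the real CodeGenContext keeps a
similar counter):

    class MiniCodeGenContext {
      private var counter = 0
      // Every call returns a name that no previous call has returned.
      def freshName(prefix: String): String = {
        val name = s"$prefix$counter"
        counter += 1
        name
      }
    }

    val ctx = new MiniCodeGenContext
    assert(ctx.freshName("numBytes") == "numBytes0")
    assert(ctx.freshName("numBytes") == "numBytes1")  // no clash when nested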
--- .../codegen/GenerateUnsafeProjection.scala | 17 ++- .../expressions/complexTypeCreator.scala | 133 +++++++++++++----- .../expressions/ExpressionEvalHelper.scala | 26 ++-- .../spark/sql/execution/basicOperators.scala | 4 + 4 files changed, 130 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 63e6f468d29b..96fc99384d86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -212,24 +212,29 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, // just copy the bytes directly into our buffer space without running any conversion. + // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from + // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow. + val tmp = ctx.freshName("tmp") + val numBytes = ctx.freshName("numBytes") val code = s""" |${input.code} |if (!${input.isNull}) { - | if (${input.primitive} instanceof UnsafeRow) { - | $primitive = (UnsafeRow) ${input.primitive}; + | Object $tmp = (Object) ${input.primitive}; + | if ($tmp instanceof UnsafeRow) { + | $primitive = (UnsafeRow) $tmp; | } else { | $allExprs | - | int numBytes = $fixedSize $additionalSize; - | if (numBytes > $buffer.length) { - | $buffer = new byte[numBytes]; + | int $numBytes = $fixedSize $additionalSize; + | if ($numBytes > $buffer.length) { + | $buffer = new byte[$numBytes]; | } | | $primitive.pointTo( | $buffer, | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, | ${exprs.size}, - | numBytes); + | $numBytes); | int $cursor = $fixedSize; | | $writers diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 878b9a1675d8..d8c9087ff538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -96,25 +96,21 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (GenerateUnsafeProjection.canSupport(dataType)) { - GenerateUnsafeProjection.createCode(ctx, ev, children) - } else { - val rowClass = classOf[GenericMutableRow].getName - s""" - boolean ${ev.isNull} = false; - final $rowClass ${ev.primitive} = new $rowClass(${children.size}); - """ + - children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ - }.mkString("\n") - } + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + 
${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") } override def prettyName: String = "struct" @@ -165,26 +161,91 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (GenerateUnsafeProjection.canSupport(dataType)) { - GenerateUnsafeProjection.createCode(ctx, ev, valExprs) - } else { - val rowClass = classOf[GenericMutableRow].getName - s""" + val rowClass = classOf[GenericMutableRow].getName + s""" boolean ${ev.isNull} = false; final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); - """ + - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ - }.mkString("\n") - } + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") } override def prettyName: String = "named_struct" } + +/** + * Returns a Row containing the evaluation of all children expressions. This is a variant that + * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + */ +case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, children) + } + + override def prettyName: String = "struct_unsafe" +} + + +/** + * Creates a struct with the given field names and values. This is a variant that returns + * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + * + * @param children Seq(name1, val1, name2, val2, ...) 
+ */ +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + } + + override def prettyName: String = "named_struct_unsafe" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index ab0cdc857c80..136368bf5b36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -114,7 +114,7 @@ trait ExpressionEvalHelper { val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -146,7 +146,8 @@ trait ExpressionEvalHelper { if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + fail("Incorrect Evaluation in codegen mode: " + + s"$expression, actual: $actual, expected: $expectedRow$input") } if (actual.copy() != expectedRow) { fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") @@ -163,12 +164,21 @@ trait ExpressionEvalHelper { expression) val unsafeRow = plan(inputRow) - // UnsafeRow cannot be compared with GenericInternalRow directly - val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) - val expectedRow = InternalRow(expected) - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected) + val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3131fda127bc..b02e60dc85cd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -62,6 +62,10 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } val project = UnsafeProjection.create(projectList, child.output) iter.map(project) } From 77e8d0e34814c030545f45b4e27b89b0c1d76ac2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 15:16:50 -0700 Subject: [PATCH 07/10] Fixed NondeterministicSuite. --- .../spark/sql/execution/expression/NondeterministicSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala index 99e11fd64b2b..1c5a2ed2c0a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.expressions.{SparkPartitionID, Monotonical class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { - checkEvaluation(MonotonicallyIncreasingID(), 0) + checkEvaluation(MonotonicallyIncreasingID(), 0L) } test("SparkPartitionID") { From 10c4b7ccf1493ec4edaa757f50ae1c22472a379c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 15:59:05 -0700 Subject: [PATCH 08/10] Format generated code. --- .../catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 96fc99384d86..3e87f7285847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] From be9f377a84cf42defe57d21e4f9a6b636c17e85f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 17:59:07 -0700 Subject: [PATCH 09/10] Fixed tests. 
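
The toByte/toShort changes below are needed because checkEvaluation now
compares the exact UnsafeRow produced in unsafe mode, which is sensitive to
the runtime type of the expected value: in Scala, bitwise operators on Byte
and Short widen their operands to Int, so the old expected values were Ints
even when the expression's type was ByteType or ShortType. A quick
illustration (not part of the diff):

    val a: Byte = 1
    val widened = ~a                  // inferred type is Int, value -2
    val narrowed: Byte = (~a).toByte  // explicit narrowing back to Byte
    assert(widened == -2)
    assert(narrowed == (-2: Byte))

The earlier MonotonicallyIncreasingID fix (0 -> 0L) was the same issue: that
expression evaluates to a Long, so the expected literal must be a Long.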
--- .../catalyst/expressions/BoundAttribute.scala | 7 ++++--- .../ArithmeticExpressionSuite.scala | 2 +- .../expressions/BitwiseFunctionsSuite.scala | 20 +++++++++++-------- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 950c85670a60..8304d4ccd47f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val value = ctx.getColumn("i", dataType, ordinal) s""" - boolean ${ev.isNull} = i.isNullAt($ordinal); - ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); + boolean ${ev.isNull} = i.isNullAt($ordinal); + $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e7e5231d32c9..7773e098e0ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -170,6 +170,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(-7, 3), 2) checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) - checkEvaluation(Pmod(2L, Long.MaxValue), 2) + checkEvaluation(Pmod(2L, Long.MaxValue), 2L) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 648fbf5a4c30..fa30fbe52847 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, ~1.toByte) - check(1000.toShort, ~1000.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, (~1.toByte).toByte) + check(1000.toShort, (~1000.toShort).toShort) check(1000000, ~1000000) check(123456789123L, ~123456789123L) @@ -45,8 +46,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte & 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. 
+ check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort) check(1000000, 4, 1000000 & 4) check(123456789123L, 5L, 123456789123L & 5L) @@ -63,8 +65,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte | 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort) check(1000000, 4, 1000000 | 4) check(123456789123L, 5L, 123456789123L | 5L) @@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort) check(1000000, 4, 1000000 ^ 4) check(123456789123L, 5L, 123456789123L ^ 5L) From 9162f4204a5a38b84b9c758a20764c9d32307168 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 21:40:17 -0700 Subject: [PATCH 10/10] Support IntervalType in UnsafeRow's getter. --- .../org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fb084dd13b62..955fb4226fc0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -265,6 +265,8 @@ public Object get(int ordinal, DataType dataType) { return getBinary(ordinal); } else if (dataType instanceof StringType) { return getUTF8String(ordinal); + } else if (dataType instanceof IntervalType) { + return getInterval(ordinal); } else if (dataType instanceof StructType) { return getStruct(ordinal, ((StructType) dataType).size()); } else {