Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
cogengen FromUnsafe
  • Loading branch information
cloud-fan committed Aug 7, 2015
commit a93fd4b41a1f80ef77df08fc0c7d99d2fb63f05e

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,7 @@ object FromUnsafeProjection {
*/
def apply(fields: Seq[DataType]): Projection = {
create(fields.zipWithIndex.map(x => {
val b = new BoundReference(x._2, x._1, true)
// todo: this is quite slow, maybe remove this whole projection after remove generic getter of
// InternalRow?
b.dataType match {
case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
case _ => b
}
new BoundReference(x._2, x._1, true)
}))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ class CodeGenContext {
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
// The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
case StringType => s"$row.update($ordinal, $value.clone())"
case StringType => s"$row.update($ordinal, $value)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be used in generated MutableProjection, so we need the copy() here, or we should make sure that UnsafeRow will not be used with MutableProjection (the fallback version of new aggregation)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the copy here, just want to tighten up the scope that we only need to do copy when convert a unsafe row to safe row(according to the fact that we only copy nested struct in GenerateSafe not here before).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SafeProjection may be not the only one projection that turn an UnsafeRow into InternalRow. For some operators that accept both safe row and unsafe row, that will use generated mutable projection, it will got crupted UTF8String if we remove the copy here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, thanks for your explanation!

case _ => s"$row.update($ordinal, $value)"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types.{StringType, StructType, DataType}
import org.apache.spark.sql.types._


/**
Expand All @@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))

private def genUpdater(
private def createCodeForStruct(
ctx: CodeGenContext,
setter: String,
dataType: DataType,
ordinal: Int,
value: String): String = {
dataType match {
case struct: StructType =>
val rowTerm = ctx.freshName("row")
val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val colTerm = ctx.freshName("col")
s"""
if ($value.isNullAt($i)) {
$rowTerm.setNullAt($i);
} else {
${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
${genUpdater(ctx, rowTerm, dt, i, colTerm)};
}
"""
}.mkString("\n")
s"""
$genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
$updates
$setter.update($ordinal, $rowTerm.copy());
"""
case _ =>
ctx.setColumn(setter, dataType, ordinal, value)
}
input: String,
schema: StructType): GeneratedExpressionCode = {
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
s"""
if (!$tmp.isNullAt($i)) {
${converter.code}
$values[$i] = ${converter.primitive};
}
"""
}.mkString("\n")

val code = s"""
final InternalRow $tmp = $input;
final Object[] $values = new Object[${schema.length}];
$fieldWriters
final InternalRow $output = new $rowClass($values);
"""

GeneratedExpressionCode(code, "false", output)
}

private def createCodeForArray(
ctx: CodeGenContext,
input: String,
elementType: DataType): GeneratedExpressionCode = {
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeArray")
val values = ctx.freshName("values")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
val arrayClass = classOf[GenericArrayData].getName

val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
val code = s"""
final ArrayData $tmp = $input;
final int $numElements = $tmp.numElements();
final Object[] $values = new Object[$numElements];
for (int $index = 0; $index < $numElements; $index++) {
if (!$tmp.isNullAt($index)) {
${elementConverter.code}
$values[$index] = ${elementConverter.primitive};
}
}
final ArrayData $output = new $arrayClass($values);
"""

GeneratedExpressionCode(code, "false", output)
}

private def createCodeForMap(
ctx: CodeGenContext,
input: String,
keyType: DataType,
valueType: DataType): GeneratedExpressionCode = {
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeMap")
val mapClass = classOf[ArrayBasedMapData].getName

val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType)
val code = s"""
final MapData $tmp = $input;
${keyConverter.code}
${valueConverter.code}
final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive});
"""

GeneratedExpressionCode(code, "false", output)
}

private def convertToSafe(
ctx: CodeGenContext,
input: String,
dataType: DataType): GeneratedExpressionCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
case _ => GeneratedExpressionCode("", "false", input)
}

protected def create(expressions: Seq[Expression]): Projection = {
Expand All @@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType)
evaluationCode.code +
s"""
if (${evaluationCode.isNull}) {
mutableRow.setNullAt($i);
} else {
${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
${converter.code}
${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)};
}
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode {

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
// cache all strings to make sure we have deep copied UTF8String inside incoming
// This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some
// values gotten from the incoming rows.
// we cache all strings here to make sure we have deep copied UTF8String inside incoming
// safe InternalRow.
val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
iter.foreach { row =>
Expand Down