Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,7 @@ object FromUnsafeProjection {
*/
def apply(fields: Seq[DataType]): Projection = {
create(fields.zipWithIndex.map(x => {
val b = new BoundReference(x._2, x._1, true)
// todo: this is quite slow, maybe remove this whole projection after remove generic getter of
// InternalRow?
b.dataType match {
case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
case _ => b
}
new BoundReference(x._2, x._1, true)
}))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.types.{StringType, StructType, DataType}
import org.apache.spark.sql.types._


/**
Expand All @@ -36,34 +36,94 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))

private def genUpdater(
private def createCodeForStruct(
ctx: CodeGenContext,
setter: String,
dataType: DataType,
ordinal: Int,
value: String): String = {
dataType match {
case struct: StructType =>
val rowTerm = ctx.freshName("row")
val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val colTerm = ctx.freshName("col")
s"""
if ($value.isNullAt($i)) {
$rowTerm.setNullAt($i);
} else {
${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
${genUpdater(ctx, rowTerm, dt, i, colTerm)};
}
"""
}.mkString("\n")
s"""
$genericMutableRowType $rowTerm = new $genericMutableRowType(${struct.fields.length});
$updates
$setter.update($ordinal, $rowTerm.copy());
"""
case _ =>
ctx.setColumn(setter, dataType, ordinal, value)
}
input: String,
schema: StructType): GeneratedExpressionCode = {
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeRow")
val values = ctx.freshName("values")
val rowClass = classOf[GenericInternalRow].getName

val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
s"""
if (!$tmp.isNullAt($i)) {
${converter.code}
$values[$i] = ${converter.primitive};
}
"""
}.mkString("\n")

val code = s"""
final InternalRow $tmp = $input;
final Object[] $values = new Object[${schema.length}];
$fieldWriters
final InternalRow $output = new $rowClass($values);
"""

GeneratedExpressionCode(code, "false", output)
}

private def createCodeForArray(
ctx: CodeGenContext,
input: String,
elementType: DataType): GeneratedExpressionCode = {
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeArray")
val values = ctx.freshName("values")
val numElements = ctx.freshName("numElements")
val index = ctx.freshName("index")
val arrayClass = classOf[GenericArrayData].getName

val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType)
val code = s"""
final ArrayData $tmp = $input;
final int $numElements = $tmp.numElements();
final Object[] $values = new Object[$numElements];
for (int $index = 0; $index < $numElements; $index++) {
if (!$tmp.isNullAt($index)) {
${elementConverter.code}
$values[$index] = ${elementConverter.primitive};
}
}
final ArrayData $output = new $arrayClass($values);
"""

GeneratedExpressionCode(code, "false", output)
}

private def createCodeForMap(
ctx: CodeGenContext,
input: String,
keyType: DataType,
valueType: DataType): GeneratedExpressionCode = {
val tmp = ctx.freshName("tmp")
val output = ctx.freshName("safeMap")
val mapClass = classOf[ArrayBasedMapData].getName

val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType)
val code = s"""
final MapData $tmp = $input;
${keyConverter.code}
${valueConverter.code}
final MapData $output = new $mapClass(${keyConverter.primitive}, ${valueConverter.primitive});
"""

GeneratedExpressionCode(code, "false", output)
}

private def convertToSafe(
ctx: CodeGenContext,
input: String,
dataType: DataType): GeneratedExpressionCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
case _ => GeneratedExpressionCode("", "false", input)
}

protected def create(expressions: Seq[Expression]): Projection = {
Expand All @@ -72,12 +132,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
val converter = convertToSafe(ctx, evaluationCode.primitive, e.dataType)
evaluationCode.code +
s"""
if (${evaluationCode.isNull}) {
mutableRow.setNullAt($i);
} else {
${genUpdater(ctx, "mutableRow", e.dataType, i, evaluationCode.primitive)};
${converter.code}
${ctx.setColumn("mutableRow", e.dataType, i, converter.primitive)};
}
"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode {

override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
// cache all strings to make sure we have deep copied UTF8String inside incoming
// This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some
// values gotten from the incoming rows.
// we cache all strings here to make sure we have deep copied UTF8String inside incoming
// safe InternalRow.
val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
iter.foreach { row =>
Expand Down