Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ public Object get(int ordinal, DataType dataType) {
return getBinary(ordinal);
} else if (dataType instanceof StringType) {
return getUTF8String(ordinal);
} else if (dataType instanceof IntervalType) {
return getInterval(ordinal);
} else if (dataType instanceof StructType) {
return getStruct(ordinal, ((StructType) dataType).size());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
Expand Down Expand Up @@ -81,6 +82,52 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input)
}
}

/**
* Writer for struct type where the struct field is backed by an {@link UnsafeRow}.
*
* We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}.
* Non-UnsafeRow struct fields are handled directly in
* {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}
* by generating the Java code needed to convert them into UnsafeRow.
*/
public static class StructWriter {
public static int getSize(InternalRow input) {
int numBytes = 0;
if (input instanceof UnsafeRow) {
numBytes = ((UnsafeRow) input).getSizeInBytes();
} else {
// This is handled directly in GenerateUnsafeProjection.
throw new UnsupportedOperationException();
}
return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
}

public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) {
int numBytes = 0;
final long offset = target.getBaseOffset() + cursor;
if (input instanceof UnsafeRow) {
final UnsafeRow row = (UnsafeRow) input;
numBytes = row.getSizeInBytes();

// zero-out the padding bytes
if ((numBytes & 0x07) > 0) {
PlatformDependent.UNSAFE.putLong(
target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
}

// Write the string to the variable length portion.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: string -> struct

row.writeToMemory(target.getBaseObject(), offset);

// Set the fixed length portion.
target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
} else {
// This is handled directly in GenerateUnsafeProjection.
throw new UnsupportedOperationException();
}
return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
}
}

/** Writer for interval type. */
public static class IntervalWriter {

Expand All @@ -96,5 +143,4 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Interval inpu
return 16;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
case BinaryType => input.getBinary(ordinal)
case IntervalType => input.getInterval(ordinal)
case t: StructType => input.getStruct(ordinal, t.size)
case dataType => input.get(ordinal, dataType)
case _ => input.get(ordinal, dataType)
}
}
}
Expand All @@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def exprId: ExprId = throw new UnsupportedOperationException

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getColumn("i", dataType, ordinal)
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)});
boolean ${ev.isNull} = i.isNullAt($ordinal);
$javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
"""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName

/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
case t: AtomicType if !t.isInstanceOf[DecimalType] => true
case _: IntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case NullType => true
case _ => false
}
Expand All @@ -55,15 +57,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

val ret = ev.primitive
ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();")
val bufferTerm = ctx.freshName("buffer")
ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];")
val cursorTerm = ctx.freshName("cursor")
val numBytesTerm = ctx.freshName("numBytes")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val cursor = ctx.freshName("cursor")
val numBytes = ctx.freshName("numBytes")

val exprs = expressions.map(_.gen(ctx))
val exprs = expressions.zipWithIndex.map { case (e, i) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the index i here is unused?

e.dataType match {
case st: StructType =>
createCodeForStruct(ctx, e.gen(ctx), st)
case _ =>
e.gen(ctx)
}
}
val allExprs = exprs.map(_.code).mkString("\n")
val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)

val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
e.dataType match {
case StringType =>
Expand All @@ -72,6 +81,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
case IntervalType =>
s" + (${exprs(i).isNull} ? 0 : 16)"
case _: StructType =>
s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))"
case _ => ""
}
}.mkString("")
Expand All @@ -81,11 +92,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case dt if ctx.isPrimitiveType(dt) =>
s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
case StringType =>
s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case BinaryType =>
s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case IntervalType =>
s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case t: StructType =>
s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
Expand All @@ -99,24 +112,139 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

s"""
$allExprs
int $numBytesTerm = $fixedSize $additionalSize;
if ($numBytesTerm > $bufferTerm.length) {
$bufferTerm = new byte[$numBytesTerm];
int $numBytes = $fixedSize $additionalSize;
if ($numBytes > $buffer.length) {
$buffer = new byte[$numBytes];
}

$ret.pointTo(
$bufferTerm,
$buffer,
org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
${expressions.size},
$numBytesTerm);
int $cursorTerm = $fixedSize;

$numBytes);
int $cursor = $fixedSize;

$writers
boolean ${ev.isNull} = false;
"""
}

/**
* Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
*
* This function also handles nested structs by recursively generating the code to do conversion.
*
* @param ctx code generation context
* @param input the input struct, identified by a [[GeneratedExpressionCode]]
* @param schema schema of the struct field
*/
// TODO: refactor createCode and this function to reduce code duplication.
private def createCodeForStruct(
ctx: CodeGenContext,
input: GeneratedExpressionCode,
schema: StructType): GeneratedExpressionCode = {

val isNull = input.isNull
val primitive = ctx.freshName("structConvert")
ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val cursor = ctx.freshName("cursor")

val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map {
case (dt, i) => dt match {
case st: StructType =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks underindented?

val nestedStructEv = GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
)
createCodeForStruct(ctx, nestedStructEv, st)
case _ =>
GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
)
}
}
val allExprs = exprs.map(_.code).mkString("\n")

val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
dt match {
case StringType =>
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
case BinaryType =>
s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
case IntervalType =>
s" + (${ev.isNull} ? 0 : 16)"
case _: StructType =>
s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
case _ => ""
}
}.mkString("")

val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
val update = dt match {
case _ if ctx.isPrimitiveType(dt) =>
s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}"
case StringType =>
s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case BinaryType =>
s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case IntervalType =>
s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case t: StructType =>
s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: $dt")
}
s"""
if (${exprs(i).isNull}) {
$primitive.setNullAt($i);
} else {
$update;
}
"""
}.mkString("\n ")

// Note that we add a shortcut here for performance: if the input is already an UnsafeRow,
// just copy the bytes directly into our buffer space without running any conversion.
// We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from
// complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow.
val tmp = ctx.freshName("tmp")
val numBytes = ctx.freshName("numBytes")
val code = s"""
|${input.code}
|if (!${input.isNull}) {
| Object $tmp = (Object) ${input.primitive};
| if ($tmp instanceof UnsafeRow) {
| $primitive = (UnsafeRow) $tmp;
| } else {
| $allExprs
|
| int $numBytes = $fixedSize $additionalSize;
| if ($numBytes > $buffer.length) {
| $buffer = new byte[$numBytes];
| }
|
| $primitive.pointTo(
| $buffer,
| org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET,
| ${exprs.size},
| $numBytes);
| int $cursor = $fixedSize;
|
| $writers
| }
|}
""".stripMargin

GeneratedExpressionCode(code, isNull, primitive)
}

protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
in.map(ExpressionCanonicalizer.execute)

Expand Down Expand Up @@ -159,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}
"""

logDebug(s"code for ${expressions.mkString(",")}:\n$code")
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")

val c = compile(code)
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
Expand Down
Loading