[SPARK-31937][SQL] Support processing ArrayType/MapType/StructType data using no-serde mode script transform #30957
Changed file (path inferred from the package declaration and trait name): `sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala`

```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

 import java.io._
 import java.nio.charset.StandardCharsets
+import java.util.Map.Entry
 import java.util.concurrent.TimeUnit

 import scala.collection.JavaConverters._
@@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

 trait BaseScriptTransformationExec extends UnaryExecNode {
@@ -47,7 +46,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   def ioschema: ScriptTransformationIOSchema

   protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
-    input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+    input.map { in: Expression =>
+      in.dataType match {
+        case _: ArrayType | _: MapType | _: StructType => in
+        case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone)
+      }
+    }
   }

   override def producedAttributes: AttributeSet = outputSet -- inputSet
```
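For context on what the no-serde path now accepts, here is an illustrative snippet (my own, not from the PR; assumes a `SparkSession` named `spark` on a host where `cat` is available): complex-typed columns are passed through instead of being cast to string up front.

```scala
// Illustrative only: exercises the no-serde script transform path with
// array/map/struct inputs; 'cat' simply echoes the serialized rows back.
val df = spark.sql(
  """SELECT TRANSFORM(a, b, c)
    |USING 'cat' AS (a string, b string, c string)
    |FROM VALUES (array(1, 2), map('k', 1), named_struct('x', 1)) AS t(a, b, c)
  """.stripMargin)
df.show()
```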
|
|
```diff
@@ -186,58 +190,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   }

   private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
-    val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
-    attr.dataType match {
-      case StringType => wrapperConvertException(data => data, converter)
-      case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
-      case ByteType => wrapperConvertException(data => data.toByte, converter)
-      case BinaryType =>
-        wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
-      case IntegerType => wrapperConvertException(data => data.toInt, converter)
-      case ShortType => wrapperConvertException(data => data.toShort, converter)
-      case LongType => wrapperConvertException(data => data.toLong, converter)
-      case FloatType => wrapperConvertException(data => data.toFloat, converter)
-      case DoubleType => wrapperConvertException(data => data.toDouble, converter)
-      case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
-      case DateType if conf.datetimeJava8ApiEnabled =>
-        wrapperConvertException(data => DateTimeUtils.stringToDate(
-          UTF8String.fromString(data),
-          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-          .map(DateTimeUtils.daysToLocalDate).orNull, converter)
-      case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
-        UTF8String.fromString(data),
-        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-        .map(DateTimeUtils.toJavaDate).orNull, converter)
-      case TimestampType if conf.datetimeJava8ApiEnabled =>
-        wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
-          UTF8String.fromString(data),
-          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-          .map(DateTimeUtils.microsToInstant).orNull, converter)
-      case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
-        UTF8String.fromString(data),
-        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
-        .map(DateTimeUtils.toJavaTimestamp).orNull, converter)
-      case CalendarIntervalType => wrapperConvertException(
-        data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
-        converter)
-      case udt: UserDefinedType[_] =>
-        wrapperConvertException(data => udt.deserialize(data), converter)
-      case dt =>
-        throw new SparkException(s"${nodeName} without serde does not support " +
-          s"${dt.getClass.getSimpleName} as output data type")
-    }
+    SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 1)
   }

-  // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null
-  private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
-    (f: String => Any, converter: Any => Any) =>
-      (data: String) => converter {
-        try {
-          f(data)
-        } catch {
-          case NonFatal(_) => null
-        }
-      }
 }

 abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
```
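The deleted `wrapperConvertException` carried the LazySimpleSerDe-compatible contract of returning null on a failed conversion, which `SparkInspectors.unwrapper` presumably has to preserve. A self-contained sketch of that contract (`safeField` is my own name, not PR code):

```scala
import scala.util.control.NonFatal

// Mirror Hive's LazySimpleSerDe: a field that fails to convert
// becomes null instead of failing the whole row.
def safeField[T](parse: String => T): String => Any =
  (data: String) =>
    try parse(data)
    catch { case NonFatal(_) => null }

val toInt = safeField(_.toInt)
assert(toInt("42") == 42)            // well-formed input converts normally
assert(toInt("not-a-number") == null) // malformed input degrades to null
```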
|
|
```diff
@@ -260,18 +214,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging

   protected def processRows(): Unit

+  val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt))
+
   protected def processRowsWithoutSerde(): Unit = {
     val len = inputSchema.length
     iter.foreach { row =>
+      val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map {
+        case (value, wrapper) => wrapper(value)
+      }
       val data = if (len == 0) {
         ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
       } else {
         val sb = new StringBuilder
-        sb.append(row.get(0, inputSchema(0)))
+        buildString(sb, values(0), inputSchema(0), 1)
         var i = 1
         while (i < len) {
           sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
-          sb.append(row.get(i, inputSchema(i)))
+          buildString(sb, values(i), inputSchema(i), 1)
           i += 1
         }
         sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
```
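To make the writer loop concrete: with the default delimiters, a row of already-serialized field values is framed as below. `writeRow` is a hypothetical helper, not PR code:

```scala
// Level-0 framing of one row: fields joined by the field delimiter \u0001,
// the row terminated by the line delimiter \n; a zero-column schema emits
// just the line delimiter, matching the len == 0 branch above.
def writeRow(fields: Seq[String]): String =
  if (fields.isEmpty) "\n"
  else fields.mkString("\u0001") + "\n"

assert(writeRow(Seq("1", "a", "true")) == "1\u0001a\u0001true\n")
```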
|
|
```diff
@@ -281,6 +240,50 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging
     }
   }

+  /**
+   * Convert data to string according to the data type.
+   *
+   * @param sb The StringBuilder to store the serialized data.
+   * @param obj The object for the current field.
+   * @param dataType The DataType for the current Object.
+   * @param level The current level of separator.
+   */
+  private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = {
+    (obj, dataType) match {
+      case (list: java.util.List[_], ArrayType(typ, _)) =>
+        val separator = ioSchema.getSeparator(level)
+        (0 until list.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          buildString(sb, list.get(i), typ, level + 1)
+        }
+      case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+        val separator = ioSchema.getSeparator(level)
+        val keyValueSeparator = ioSchema.getSeparator(level + 1)
+        val entries = map.entrySet().toArray()
+        (0 until entries.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          val entry = entries(i).asInstanceOf[Entry[_, _]]
+          buildString(sb, entry.getKey, keyType, level + 2)
+          sb.append(keyValueSeparator)
+          buildString(sb, entry.getValue, valueType, level + 2)
+        }
+      case (arrayList: java.util.ArrayList[_], StructType(fields)) =>
+        val separator = ioSchema.getSeparator(level)
+        (0 until arrayList.size).foreach { i =>
+          if (i > 0) {
+            sb.append(separator)
+          }
+          buildString(sb, arrayList.get(i), fields(i).dataType, level + 1)
+        }
+      case (other, _) =>
+        sb.append(other)
+    }
+  }
+
   override def run(): Unit = Utils.logUncaughtExceptions {
     TaskContext.setTaskContext(taskContext)
```
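A standalone sketch of the separator-nesting scheme `buildString` follows, using plain Scala collections instead of the wrapped Java ones; `serialize` and `seps` are my own names, and the delimiters are the assumed defaults (`\u0001`, `\u0002`, `\u0003`, then bytes 4 to 8):

```scala
// Each nesting level gets the next delimiter from seps; map entries use
// the level's delimiter between entries and the next one between key/value,
// which is why buildString advances map children by two levels.
def serialize(value: Any, level: Int, seps: IndexedSeq[Char]): String = value match {
  case xs: Seq[_] =>
    xs.map(serialize(_, level + 1, seps)).mkString(seps(level).toString)
  case m: Map[_, _] =>
    m.map { case (k, v) =>
      s"${serialize(k, level + 2, seps)}${seps(level + 1)}${serialize(v, level + 2, seps)}"
    }.mkString(seps(level).toString)
  case other =>
    other.toString
}

val seps = IndexedSeq('\u0001', '\u0002', '\u0003') ++ (4 to 8).map(_.toChar)
// array(1, 2) rendered at level 1: elements joined by \u0002
assert(serialize(Seq(1, 2), 1, seps) == "1\u00022")
// map('k' -> 1) at level 1: key and value joined by \u0003
assert(serialize(Map("k" -> 1), 1, seps) == "k\u00031")
```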
|
|
```diff
@@ -333,14 +336,45 @@ case class ScriptTransformationIOSchema(
     schemaLess: Boolean) extends Serializable {
   import ScriptTransformationIOSchema._

-  val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
-  val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
+  val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k))
+  val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k))
+
+  val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) ::
+    getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) ::
+    getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++
+    (4 to 8).map(_.toByte)
+
+  def getByte(altValue: String, defaultVal: Byte): Byte = {
+    if (altValue != null && altValue.length > 0) {
+      try {
+        java.lang.Byte.parseByte(altValue)
+      } catch {
+        case _: NumberFormatException =>
+          altValue.charAt(0).toByte
+      }
+    } else {
+      defaultVal
+    }
+  }
+
+  def getSeparator(level: Int): Char = {
+    try {
+      separators(level).toChar
+    } catch {
+      case _: IndexOutOfBoundsException =>
+        val msg = "Number of levels of nesting supported for Spark SQL script transform" +
+          " is " + (separators.length - 1) + " Unable to work with level " + level
+        throw new RuntimeException(msg)
+    }
+  }
 }

 object ScriptTransformationIOSchema {
   val defaultFormat = Map(
+    ("TOK_TABLEROWFORMATLINES", "\n"),
     ("TOK_TABLEROWFORMATFIELD", "\u0001"),
-    ("TOK_TABLEROWFORMATLINES", "\n")
+    ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
+    ("TOK_TABLEROWFORMATMAPKEYS", "\u0003")
   )

   val defaultIOSchema = ScriptTransformationIOSchema(
```
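To spell out `getByte`'s fallback rules, here is the same logic with a few worked cases (the assertions reflect my reading of the code above, not tests from the PR):

```scala
// A delimiter string that parses as a number is taken as that byte value;
// anything else contributes its first character's code; null/empty input
// falls back to the supplied default.
def getByte(altValue: String, defaultVal: Byte): Byte =
  if (altValue != null && altValue.nonEmpty) {
    try java.lang.Byte.parseByte(altValue)
    catch { case _: NumberFormatException => altValue.charAt(0).toByte }
  } else defaultVal

assert(getByte("9", 0) == 9.toByte)       // parsed as a number
assert(getByte(",", 0) == ','.toByte)     // first character wins
assert(getByte("", 7) == 7.toByte)        // empty -> default
assert(getByte("\u0001", 0) == 1.toByte)  // \u0001 is not numeric; char code 1
```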
|
|
||
Review discussion on this change:

Reviewer: `toCatalystImpl` converts Scala data into a Catalyst value, but `GenericArrayData` is Catalyst-internal, so this change looks weird.

Author: For complex types, after using `JsonToStruct` we don't need this converter any more, so I removed it. How about the current change?