diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 8f508219c2de..2c68b73095c4 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -21,7 +21,7 @@ avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
-bcprov-jdk15on-1.51.jar
+bcprov-jdk15on-1.58.jar
bonecp-0.8.0.RELEASE.jar
breeze-macros_2.11-0.13.2.jar
breeze_2.11-0.13.2.jar
@@ -97,7 +97,7 @@ jackson-module-paranamer-2.7.9.jar
jackson-module-scala_2.11-2.6.7.1.jar
jackson-xc-1.9.13.jar
janino-3.0.7.jar
-java-xmlbuilder-1.0.jar
+java-xmlbuilder-1.1.jar
javassist-3.18.1-GA.jar
javax.annotation-api-1.2.jar
javax.inject-1.jar
@@ -115,7 +115,7 @@ jersey-container-servlet-core-2.22.2.jar
jersey-guava-2.22.2.jar
jersey-media-jaxb-2.22.2.jar
jersey-server-2.22.2.jar
-jets3t-0.9.3.jar
+jets3t-0.9.4.jar
jetty-6.1.26.jar
jetty-util-6.1.26.jar
jline-2.12.1.jar
@@ -137,14 +137,12 @@ log4j-1.2.17.jar
lz4-java-1.4.0.jar
machinist_2.11-0.6.1.jar
macro-compat_2.11-1.1.1.jar
-mail-1.4.7.jar
mesos-1.4.0-shaded-protobuf.jar
metrics-core-3.1.5.jar
metrics-graphite-3.1.5.jar
metrics-json-3.1.5.jar
metrics-jvm-3.1.5.jar
minlog-1.3.0.jar
-mx4j-3.0.2.jar
netty-3.9.9.Final.jar
netty-all-4.0.47.Final.jar
objenesis-2.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index 68e937f50b39..2aaac600b3ec 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -21,7 +21,7 @@ avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
-bcprov-jdk15on-1.51.jar
+bcprov-jdk15on-1.58.jar
bonecp-0.8.0.RELEASE.jar
breeze-macros_2.11-0.13.2.jar
breeze_2.11-0.13.2.jar
@@ -97,7 +97,7 @@ jackson-module-paranamer-2.7.9.jar
jackson-module-scala_2.11-2.6.7.1.jar
jackson-xc-1.9.13.jar
janino-3.0.7.jar
-java-xmlbuilder-1.0.jar
+java-xmlbuilder-1.1.jar
javassist-3.18.1-GA.jar
javax.annotation-api-1.2.jar
javax.inject-1.jar
@@ -115,7 +115,7 @@ jersey-container-servlet-core-2.22.2.jar
jersey-guava-2.22.2.jar
jersey-media-jaxb-2.22.2.jar
jersey-server-2.22.2.jar
-jets3t-0.9.3.jar
+jets3t-0.9.4.jar
jetty-6.1.26.jar
jetty-util-6.1.26.jar
jline-2.12.1.jar
@@ -138,14 +138,12 @@ log4j-1.2.17.jar
lz4-java-1.4.0.jar
machinist_2.11-0.6.1.jar
macro-compat_2.11-1.1.1.jar
-mail-1.4.7.jar
mesos-1.4.0-shaded-protobuf.jar
metrics-core-3.1.5.jar
metrics-graphite-3.1.5.jar
metrics-json-3.1.5.jar
metrics-jvm-3.1.5.jar
minlog-1.3.0.jar
-mx4j-3.0.2.jar
netty-3.9.9.Final.jar
netty-all-4.0.47.Final.jar
objenesis-2.1.jar
diff --git a/pom.xml b/pom.xml
index 731ee86439ef..07bca9d267da 100644
--- a/pom.xml
+++ b/pom.xml
@@ -141,7 +141,7 @@
<codahale.metrics.version>3.1.5</codahale.metrics.version>
<avro.version>1.7.7</avro.version>
<avro.mapred.classifier>hadoop2</avro.mapred.classifier>
- <jets3t.version>0.9.3</jets3t.version>
+ <jets3t.version>0.9.4</jets3t.version>
<aws.kinesis.client.version>1.7.3</aws.kinesis.client.version>
<aws.java.sdk.version>1.11.76</aws.java.sdk.version>
@@ -985,6 +985,12 @@
+ <dependency>
+ <groupId>org.bouncycastle</groupId>
+ <artifactId>bcprov-jdk15on</artifactId>
+
+ <version>1.58</version>
+ </dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-api</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 173e171910b6..3b52a0efd404 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -75,23 +75,51 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ctx.addMutableState(ctx.javaType(dataType), ev.value)
+ // All the evals are meant to be wrapped in a do { ... } while (false); loop, so the generated
+ // "continue" can skip the remaining children once a non-null value is found.
val evals = children.map { e =>
val eval = e.genCode(ctx)
s"""
- if (${ev.isNull}) {
- ${eval.code}
- if (!${eval.isNull}) {
- ${ev.isNull} = false;
- ${ev.value} = ${eval.value};
- }
- }
- """
+ |${eval.code}
+ |if (!${eval.isNull}) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = ${eval.value};
+ | continue;
+ |}
+ """.stripMargin
}
+ val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+ evals.mkString("\n")
+ } else {
+ ctx.splitExpressions(evals, "coalesce",
+ ("InternalRow", ctx.INPUT_ROW) :: Nil,
+ makeSplitFunction = {
+ func =>
+ s"""
+ |do {
+ | $func
+ |} while (false);
+ """.stripMargin
+ },
+ foldFunctions = { funcCalls =>
+ funcCalls.map { funcCall =>
+ s"""
+ |$funcCall;
+ |if (!${ev.isNull}) {
+ | continue;
+ |}
+ """.stripMargin
+ }.mkString
+ })
+ }
- ev.copy(code = s"""
- ${ev.isNull} = true;
- ${ev.value} = ${ctx.defaultValue(dataType)};
- ${ctx.splitExpressions(evals)}""")
+ ev.copy(code =
+ s"""
+ |${ev.isNull} = true;
+ |${ev.value} = ${ctx.defaultValue(dataType)};
+ |do {
+ | $code
+ |} while (false);
+ """.stripMargin)
}
}
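
The new codegen wraps all child evaluations in a do { ... } while (false) loop so a `continue` can short-circuit once a non-null child is found, and routes wide expressions through `splitExpressions`. A hedged way to exercise the split path from the public API (assuming a running `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions._

// Several hundred children push the generated coalesce code over the method-size
// threshold, so CodegenContext.splitExpressions has to break it into helper functions.
val wide = spark.range(1).select((1 to 500).map(i => lit(null).cast("int").as(s"c$i")): _*)
wide.select(coalesce((1 to 500).map(i => col(s"c$i")) :+ lit(42): _*)).show()
```
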
@@ -358,53 +386,70 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val nonnull = ctx.freshName("nonnull")
+ // All evals are meant to be inside a do { ... } while (false); loop.
val evals = children.map { e =>
val eval = e.genCode(ctx)
e.dataType match {
case DoubleType | FloatType =>
s"""
- if ($nonnull < $n) {
- ${eval.code}
- if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
- $nonnull += 1;
- }
- }
- """
+ |if ($nonnull < $n) {
+ | ${eval.code}
+ | if (!${eval.isNull} && !Double.isNaN(${eval.value})) {
+ | $nonnull += 1;
+ | }
+ |} else {
+ | continue;
+ |}
+ """.stripMargin
case _ =>
s"""
- if ($nonnull < $n) {
- ${eval.code}
- if (!${eval.isNull}) {
- $nonnull += 1;
- }
- }
- """
+ |if ($nonnull < $n) {
+ | ${eval.code}
+ | if (!${eval.isNull}) {
+ | $nonnull += 1;
+ | }
+ |} else {
+ | continue;
+ |}
+ """.stripMargin
}
}
val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
- evals.mkString("\n")
- } else {
- ctx.splitExpressions(
- expressions = evals,
- funcName = "atLeastNNonNulls",
- arguments = ("InternalRow", ctx.INPUT_ROW) :: ("int", nonnull) :: Nil,
- returnType = "int",
- makeSplitFunction = { body =>
- s"""
- $body
- return $nonnull;
- """
- },
- foldFunctions = { funcCalls =>
- funcCalls.map(funcCall => s"$nonnull = $funcCall;").mkString("\n")
- }
- )
- }
+ evals.mkString("\n")
+ } else {
+ ctx.splitExpressions(
+ expressions = evals,
+ funcName = "atLeastNNonNulls",
+ arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil,
+ returnType = ctx.JAVA_INT,
+ makeSplitFunction = { body =>
+ s"""
+ |do {
+ | $body
+ |} while (false);
+ |return $nonnull;
+ """.stripMargin
+ },
+ foldFunctions = { funcCalls =>
+ funcCalls.map(funcCall =>
+ s"""
+ |$nonnull = $funcCall;
+ |if ($nonnull >= $n) {
+ | continue;
+ |}
+ """.stripMargin).mkString("\n")
+ }
+ )
+ }
- ev.copy(code = s"""
- int $nonnull = 0;
- $code
- boolean ${ev.value} = $nonnull >= $n;""", isNull = "false")
+ ev.copy(code =
+ s"""
+ |${ctx.JAVA_INT} $nonnull = 0;
+ |do {
+ | $code
+ |} while (false);
+ |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
+ """.stripMargin, isNull = "false")
}
}
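
`AtLeastNNonNulls` backs `DataFrameNaFunctions.drop` with a minimum count, so the split path above is reachable from user code; a minimal sketch (assuming `spark` and its implicits are in scope):

```scala
import spark.implicits._

val df = Seq(
  (Some(1.0), Some("a"), None: Option[Int]),
  (None: Option[Double], None: Option[String], Some(3))).toDF("x", "y", "z")

// Keep only rows with at least 2 non-null (and, for doubles, non-NaN) values.
df.na.drop(minNonNulls = 2, cols = Seq("x", "y", "z")).show()
```
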
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 1aaaaf1db48d..75cc9b3bd804 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -234,36 +234,62 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val javaDataType = ctx.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
val valueArg = ctx.freshName("valueArg")
+ // All the blocks are meant to be inside a do { ... } while (false); loop.
+ // The evaluation can stop (via "continue") as soon as a matching value is found.
val listCode = listGen.map(x =>
s"""
- if (!${ev.value}) {
- ${x.code}
- if (${x.isNull}) {
- ${ev.isNull} = true;
- } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
- ${ev.isNull} = false;
- ${ev.value} = true;
+ |${x.code}
+ |if (${x.isNull}) {
+ | ${ev.isNull} = true;
+ |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = true;
+ | continue;
+ |}
+ """.stripMargin)
+ val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+ listCode.mkString("\n")
+ } else {
+ ctx.splitExpressions(
+ expressions = listCode,
+ funcName = "valueIn",
+ arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil,
+ makeSplitFunction = { body =>
+ s"""
+ |do {
+ | $body
+ |} while (false);
+ """.stripMargin
+ },
+ foldFunctions = { funcCalls =>
+ funcCalls.map(funcCall =>
+ s"""
+ |$funcCall;
+ |if (${ev.value}) {
+ | continue;
+ |}
+ """.stripMargin).mkString("\n")
}
- }
- """)
- val listCodes = ctx.splitExpressions(
- expressions = listCode,
- funcName = "valueIn",
- extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil)
- ev.copy(code = s"""
- ${valueGen.code}
- ${ev.value} = false;
- ${ev.isNull} = ${valueGen.isNull};
- if (!${ev.isNull}) {
- ${ctx.javaType(value.dataType)} $valueArg = ${valueGen.value};
- $listCodes
+ )
}
- """)
+ ev.copy(code =
+ s"""
+ |${valueGen.code}
+ |${ev.value} = false;
+ |${ev.isNull} = ${valueGen.isNull};
+ |if (!${ev.isNull}) {
+ | $javaDataType $valueArg = ${valueGen.value};
+ | do {
+ | $code
+ | } while (false);
+ |}
+ """.stripMargin)
}
override def sql: String = {
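
As with `Coalesce`, the split path for `In` is easiest to hit with a very wide predicate; a hedged sketch (again assuming `spark` is in scope):

```scala
import org.apache.spark.sql.functions.col

// A few thousand literals make the generated "valueIn" code large enough that it is
// split into helper functions which receive the evaluated value as an argument.
spark.range(10000).filter(col("id").isin((1L to 5000L): _*)).count()
```
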
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
new file mode 100644
index 000000000000..4ecc54bd2fd9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.io._
+import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp}
+import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A deserializer to deserialize ORC structs to Spark rows.
+ */
+class OrcDeserializer(
+ dataSchema: StructType,
+ requiredSchema: StructType,
+ requestedColIds: Array[Int]) {
+
+ private val resultRow = new SpecificInternalRow(requiredSchema.map(_.dataType))
+
+ private val fieldWriters: Array[WritableComparable[_] => Unit] = {
+ requiredSchema.zipWithIndex
+ // The values of missing columns are always null, so they do not need writers.
+ .filterNot { case (_, index) => requestedColIds(index) == -1 }
+ .map { case (f, index) =>
+ val writer = newWriter(f.dataType, new RowUpdater(resultRow))
+ (value: WritableComparable[_]) => writer(index, value)
+ }.toArray
+ }
+
+ private val validColIds = requestedColIds.filterNot(_ == -1)
+
+ def deserialize(orcStruct: OrcStruct): InternalRow = {
+ var i = 0
+ while (i < validColIds.length) {
+ val value = orcStruct.getFieldValue(validColIds(i))
+ if (value == null) {
+ resultRow.setNullAt(i)
+ } else {
+ fieldWriters(i)(value)
+ }
+ i += 1
+ }
+ resultRow
+ }
+
+ /**
+ * Creates a writer to write ORC values into the Catalyst data structure at the given ordinal.
+ */
+ private def newWriter(
+ dataType: DataType, updater: CatalystDataUpdater): (Int, WritableComparable[_]) => Unit =
+ dataType match {
+ case NullType => (ordinal, _) =>
+ updater.setNullAt(ordinal)
+
+ case BooleanType => (ordinal, value) =>
+ updater.setBoolean(ordinal, value.asInstanceOf[BooleanWritable].get)
+
+ case ByteType => (ordinal, value) =>
+ updater.setByte(ordinal, value.asInstanceOf[ByteWritable].get)
+
+ case ShortType => (ordinal, value) =>
+ updater.setShort(ordinal, value.asInstanceOf[ShortWritable].get)
+
+ case IntegerType => (ordinal, value) =>
+ updater.setInt(ordinal, value.asInstanceOf[IntWritable].get)
+
+ case LongType => (ordinal, value) =>
+ updater.setLong(ordinal, value.asInstanceOf[LongWritable].get)
+
+ case FloatType => (ordinal, value) =>
+ updater.setFloat(ordinal, value.asInstanceOf[FloatWritable].get)
+
+ case DoubleType => (ordinal, value) =>
+ updater.setDouble(ordinal, value.asInstanceOf[DoubleWritable].get)
+
+ case StringType => (ordinal, value) =>
+ updater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Text].copyBytes))
+
+ case BinaryType => (ordinal, value) =>
+ val binary = value.asInstanceOf[BytesWritable]
+ val bytes = new Array[Byte](binary.getLength)
+ System.arraycopy(binary.getBytes, 0, bytes, 0, binary.getLength)
+ updater.set(ordinal, bytes)
+
+ case DateType => (ordinal, value) =>
+ updater.setInt(ordinal, DateTimeUtils.fromJavaDate(value.asInstanceOf[DateWritable].get))
+
+ case TimestampType => (ordinal, value) =>
+ updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp]))
+
+ case DecimalType.Fixed(precision, scale) => (ordinal, value) =>
+ val decimal = value.asInstanceOf[HiveDecimalWritable].getHiveDecimal()
+ val v = Decimal(decimal.bigDecimalValue, decimal.precision(), decimal.scale())
+ v.changePrecision(precision, scale)
+ updater.set(ordinal, v)
+
+ case st: StructType => (ordinal, value) =>
+ val result = new SpecificInternalRow(st)
+ val fieldUpdater = new RowUpdater(result)
+ val fieldConverters = st.map(_.dataType).map { dt =>
+ newWriter(dt, fieldUpdater)
+ }.toArray
+ val orcStruct = value.asInstanceOf[OrcStruct]
+
+ var i = 0
+ while (i < st.length) {
+ val value = orcStruct.getFieldValue(i)
+ if (value == null) {
+ result.setNullAt(i)
+ } else {
+ fieldConverters(i)(i, value)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case ArrayType(elementType, _) => (ordinal, value) =>
+ val orcArray = value.asInstanceOf[OrcList[WritableComparable[_]]]
+ val length = orcArray.size()
+ val result = createArrayData(elementType, length)
+ val elementUpdater = new ArrayDataUpdater(result)
+ val elementConverter = newWriter(elementType, elementUpdater)
+
+ var i = 0
+ while (i < length) {
+ val value = orcArray.get(i)
+ if (value == null) {
+ result.setNullAt(i)
+ } else {
+ elementConverter(i, value)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, result)
+
+ case MapType(keyType, valueType, _) => (ordinal, value) =>
+ val orcMap = value.asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]]
+ val length = orcMap.size()
+ val keyArray = createArrayData(keyType, length)
+ val keyUpdater = new ArrayDataUpdater(keyArray)
+ val keyConverter = newWriter(keyType, keyUpdater)
+ val valueArray = createArrayData(valueType, length)
+ val valueUpdater = new ArrayDataUpdater(valueArray)
+ val valueConverter = newWriter(valueType, valueUpdater)
+
+ var i = 0
+ val it = orcMap.entrySet().iterator()
+ while (it.hasNext) {
+ val entry = it.next()
+ keyConverter(i, entry.getKey)
+ val value = entry.getValue
+ if (value == null) {
+ valueArray.setNullAt(i)
+ } else {
+ valueConverter(i, value)
+ }
+ i += 1
+ }
+
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+
+ case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater)
+
+ case _ =>
+ throw new UnsupportedOperationException(s"$dataType is not supported yet.")
+ }
+
+ private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
+ case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
+ case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
+ case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
+ case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
+ case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
+ case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
+ case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
+ case _ => new GenericArrayData(new Array[Any](length))
+ }
+
+ /**
+ * A base interface for updating values inside catalyst data structure like `InternalRow` and
+ * `ArrayData`.
+ */
+ sealed trait CatalystDataUpdater {
+ def set(ordinal: Int, value: Any): Unit
+
+ def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+ def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+ def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+ def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+ def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+ def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+ def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+ def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+ }
+
+ final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+ override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+ override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
+ }
+
+ final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+ override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+ override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)
+
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
+ override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
+ override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
+ override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
+ override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
+ override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
+ override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
+ }
+}
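
A minimal sketch of how the deserializer treats missing columns (the schemas, column ids, and hand-built `OrcStruct` below are illustrative assumptions, not part of the patch): a `-1` entry in `requestedColIds` simply leaves the corresponding slot of the result row null.

```scala
import org.apache.hadoop.io.IntWritable
import org.apache.orc.TypeDescription
import org.apache.orc.mapred.OrcStruct
import org.apache.spark.sql.types._

val dataSchema = new StructType().add("a", IntegerType).add("b", IntegerType)
val requiredSchema = new StructType().add("b", IntegerType).add("c", IntegerType)
// "b" is physical field 1 in the file; "c" does not exist in the file, hence -1.
val requestedColIds = Array(1, -1)

val struct = OrcStruct.createValue(TypeDescription.fromString("struct<a:int,b:int>"))
  .asInstanceOf[OrcStruct]
struct.setFieldValue(0, new IntWritable(10))
struct.setFieldValue(1, new IntWritable(20))

val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)
val row = deserializer.deserialize(struct)
assert(row.getInt(0) == 20)  // column "b" read from physical field 1
assert(row.isNullAt(1))      // missing column "c" stays null
```
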
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 215740e90fe8..75c42213db3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -17,10 +17,29 @@
package org.apache.spark.sql.execution.datasources.orc
-import org.apache.orc.TypeDescription
+import java.io._
+import java.net.URI
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.orc._
+import org.apache.orc.OrcConf.{COMPRESS, MAPRED_OUTPUT_SCHEMA}
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce._
+
+import org.apache.spark.TaskContext
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.SerializableConfiguration
private[sql] object OrcFileFormat {
private def checkFieldName(name: String): Unit = {
@@ -39,3 +58,119 @@ private[sql] object OrcFileFormat {
names.foreach(checkFieldName)
}
}
+
+/**
+ * New ORC File Format based on Apache ORC.
+ */
+class OrcFileFormat
+ extends FileFormat
+ with DataSourceRegister
+ with Serializable {
+
+ override def shortName(): String = "orc"
+
+ override def toString: String = "ORC"
+
+ override def hashCode(): Int = getClass.hashCode()
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[OrcFileFormat]
+
+ override def inferSchema(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ OrcUtils.readSchema(sparkSession, files)
+ }
+
+ override def prepareWrite(
+ sparkSession: SparkSession,
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
+
+ val conf = job.getConfiguration
+
+ conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString)
+
+ conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
+
+ conf.asInstanceOf[JobConf]
+ .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]])
+
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new OrcOutputWriter(path, dataSchema, context)
+ }
+
+ override def getFileExtension(context: TaskAttemptContext): String = {
+ val compressionExtension: String = {
+ val name = context.getConfiguration.get(COMPRESS.getAttribute)
+ OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "")
+ }
+
+ compressionExtension + ".orc"
+ }
+ }
+ }
+
+ override def isSplitable(
+ sparkSession: SparkSession,
+ options: Map[String, String],
+ path: Path): Boolean = {
+ true
+ }
+
+ override def buildReader(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ if (sparkSession.sessionState.conf.orcFilterPushDown) {
+ OrcFilters.createFilter(dataSchema, filters).foreach { f =>
+ OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
+ }
+ }
+
+ val broadcastedConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+
+ (file: PartitionedFile) => {
+ val conf = broadcastedConf.value.value
+
+ val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
+ isCaseSensitive, dataSchema, requiredSchema, new Path(new URI(file.filePath)), conf)
+
+ if (requestedColIdsOrEmptyFile.isEmpty) {
+ Iterator.empty
+ } else {
+ val requestedColIds = requestedColIdsOrEmptyFile.get
+ assert(requestedColIds.length == requiredSchema.length,
+ "[BUG] requested column IDs do not match required schema")
+ conf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
+ requestedColIds.filter(_ != -1).sorted.mkString(","))
+
+ val fileSplit =
+ new FileSplit(new Path(new URI(file.filePath)), file.start, file.length, Array.empty)
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val taskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+ val orcRecordReader = new OrcInputFormat[OrcStruct]
+ .createRecordReader(fileSplit, taskAttemptContext)
+ val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
+ Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
+
+ val unsafeProjection = UnsafeProjection.create(requiredSchema)
+ val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)
+ iter.map(value => unsafeProjection(deserializer.deserialize(value)))
+ }
+ }
+ }
+}
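
The tests below select the new format explicitly by its fully qualified class name; a hedged usage sketch doing the same (assuming `spark` and a writable temporary path):

```scala
val newOrc =
  classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName

spark.range(10).write.mode("overwrite").format(newOrc).save("/tmp/new-orc-demo")
spark.read.format(newOrc).load("/tmp/new-orc-demo").show()
```
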
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
new file mode 100644
index 000000000000..cec256cc1b49
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory}
+import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
+import org.apache.orc.storage.serde2.io.HiveDecimalWritable
+
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types._
+
+/**
+ * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down.
+ *
+ * Due to limitations of the ORC `SearchArgument` builder, we end up with a fairly awkward
+ * double-checking pattern when converting `And`/`Or`/`Not` filters.
+ *
+ * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't
+ * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite
+ * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using
+ * existing simpler ones.
+ *
+ * The annoying part is that `SearchArgument` builder methods like `startAnd()`, `startOr()`, and
+ * `startNot()` mutate the internal state of the builder instance. This forces us to translate all
+ * convertible filters with a single builder instance. However, before actually converting a filter,
+ * we have no idea whether it can be recognized by ORC or not. Thus, when an inconvertible filter is
+ * found, we may already have a builder whose internal state is inconsistent.
+ *
+ * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then
+ * try to convert its children. Say the `left` child converts successfully, but the `right` child
+ * turns out to be inconvertible. Alas, the `b.startAnd()` call can't be rolled back, and `b` is now
+ * in an inconsistent state.
+ *
+ * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their
+ * children with brand new builders, and only do the actual conversion with the right builder
+ * instance when the children are proven to be convertible.
+ *
+ * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of
+ * builder methods mentioned above can only be found in test code, where all tested filters are
+ * known to be convertible.
+ */
+private[orc] object OrcFilters {
+
+ /**
+ * Create ORC filter as a SearchArgument instance.
+ */
+ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
+ val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
+
+ // First, tries to convert each filter individually to see whether it's convertible, and then
+ // collect all convertible ones to build the final `SearchArgument`.
+ val convertibleFilters = for {
+ filter <- filters
+ _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder())
+ } yield filter
+
+ for {
+ // Combines all convertible filters using `And` to produce a single conjunction
+ conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And)
+ // Then tries to build a single ORC `SearchArgument` for the conjunction predicate
+ builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder())
+ } yield builder.build()
+ }
+
+ /**
+ * Returns true if this is a type for which ORC filter push-down is supported.
+ * Both CharType and VarcharType have already been replaced with StringType by the AstBuilder.
+ */
+ private def isSearchableType(dataType: DataType) = dataType match {
+ // TODO: SPARK-21787 Support for pushing down filters for DateType in ORC
+ case BinaryType | DateType => false
+ case _: AtomicType => true
+ case _ => false
+ }
+
+ /**
+ * Gets the `PredicateLeaf.Type` corresponding to the given DataType.
+ */
+ private def getPredicateLeafType(dataType: DataType) = dataType match {
+ case BooleanType => PredicateLeaf.Type.BOOLEAN
+ case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
+ case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
+ case StringType => PredicateLeaf.Type.STRING
+ case DateType => PredicateLeaf.Type.DATE
+ case TimestampType => PredicateLeaf.Type.TIMESTAMP
+ case _: DecimalType => PredicateLeaf.Type.DECIMAL
+ case _ => throw new UnsupportedOperationException(s"DataType: $dataType")
+ }
+
+ /**
+ * Casts literal values for filters.
+ *
+ * We need to cast integral values to Long because otherwise ORC raises exceptions
+ * at 'checkLiteralType' of SearchArgumentImpl.java.
+ */
+ private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match {
+ case ByteType | ShortType | IntegerType | LongType =>
+ value.asInstanceOf[Number].longValue
+ case FloatType | DoubleType =>
+ value.asInstanceOf[Number].doubleValue()
+ case _: DecimalType =>
+ val decimal = value.asInstanceOf[java.math.BigDecimal]
+ val decimalWritable = new HiveDecimalWritable(decimal.longValue)
+ decimalWritable.mutateEnforcePrecisionScale(decimal.precision, decimal.scale)
+ decimalWritable
+ case _ => value
+ }
+
+ /**
+ * Build a SearchArgument and return the builder so far.
+ */
+ private def buildSearchArgument(
+ dataTypeMap: Map[String, DataType],
+ expression: Filter,
+ builder: Builder): Option[Builder] = {
+ def newBuilder = SearchArgumentFactory.newBuilder()
+
+ def getType(attribute: String): PredicateLeaf.Type =
+ getPredicateLeafType(dataTypeMap(attribute))
+
+ import org.apache.spark.sql.sources._
+
+ expression match {
+ case And(left, right) =>
+ // Here it is not safe to just convert one side if we do not understand the other
+ // side. For example, say we have NOT(a = 2 AND b in ('1')) and we do not know how
+ // to convert b in ('1'). If we only convert a = 2, we end up with the filter
+ // NOT(a = 2), which produces wrong results.
+ // Pushing down only one side of an AND is safe to do only at the top level.
+ // See ParquetRelation's initializeLocalJobFunc method for an example.
+ for {
+ _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
+ _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+ lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd())
+ rhs <- buildSearchArgument(dataTypeMap, right, lhs)
+ } yield rhs.end()
+
+ case Or(left, right) =>
+ for {
+ _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
+ _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+ lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr())
+ rhs <- buildSearchArgument(dataTypeMap, right, lhs)
+ } yield rhs.end()
+
+ case Not(child) =>
+ for {
+ _ <- buildSearchArgument(dataTypeMap, child, newBuilder)
+ negate <- buildSearchArgument(dataTypeMap, child, builder.startNot())
+ } yield negate.end()
+
+ // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
+ // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
+ // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
+
+ case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end())
+
+ case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end())
+
+ case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end())
+
+ case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end())
+
+ case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end())
+
+ case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(attribute))
+ Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end())
+
+ case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ Some(builder.startAnd().isNull(attribute, getType(attribute)).end())
+
+ case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
+ Some(builder.startNot().isNull(attribute, getType(attribute)).end())
+
+ case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+ val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
+ Some(builder.startAnd().in(attribute, getType(attribute),
+ castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
+
+ case _ => None
+ }
+ }
+}
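
A hedged sketch of the conversion flow described in the scaladoc above (the schema and filters are made up; `OrcFilters` is `private[orc]`, so this would live inside the `orc` package or a test for it):

```scala
import org.apache.spark.sql.sources.{EqualTo, GreaterThan, IsNotNull}
import org.apache.spark.sql.types._

val schema = new StructType().add("a", IntegerType).add("b", StringType)

// Each filter is first converted against a throwaway builder; only the survivors are
// AND-ed together and built into a single SearchArgument with one shared builder.
val sarg = OrcFilters.createFilter(
  schema, Seq(IsNotNull("a"), EqualTo("a", 1), GreaterThan("b", "x")))
sarg.foreach(println)  // e.g. an expression of the form (and leaf-0 leaf-1 leaf-2)
```
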
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
new file mode 100644
index 000000000000..84755bfa301f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.orc.mapred.OrcStruct
+import org.apache.orc.mapreduce.OrcOutputFormat
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.types._
+
+private[orc] class OrcOutputWriter(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext)
+ extends OutputWriter {
+
+ private[this] val serializer = new OrcSerializer(dataSchema)
+
+ private val recordWriter = {
+ new OrcOutputFormat[OrcStruct]() {
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ new Path(path)
+ }
+ }.getRecordWriter(context)
+ }
+
+ override def write(row: InternalRow): Unit = {
+ recordWriter.write(NullWritable.get(), serializer.serialize(row))
+ }
+
+ override def close(): Unit = {
+ recordWriter.close(context)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
new file mode 100644
index 000000000000..899af0750cad
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.io._
+import org.apache.orc.TypeDescription
+import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp}
+import org.apache.orc.storage.common.`type`.HiveDecimal
+import org.apache.orc.storage.serde2.io.{DateWritable, HiveDecimalWritable}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize Spark rows to ORC structs.
+ */
+class OrcSerializer(dataSchema: StructType) {
+
+ private val result = createOrcValue(dataSchema).asInstanceOf[OrcStruct]
+ private val converters = dataSchema.map(_.dataType).map(newConverter(_)).toArray
+
+ def serialize(row: InternalRow): OrcStruct = {
+ var i = 0
+ while (i < converters.length) {
+ if (row.isNullAt(i)) {
+ result.setFieldValue(i, null)
+ } else {
+ result.setFieldValue(i, converters(i)(row, i))
+ }
+ i += 1
+ }
+ result
+ }
+
+ private type Converter = (SpecializedGetters, Int) => WritableComparable[_]
+
+ /**
+ * Creates a converter to convert Catalyst data at the given ordinal to ORC values.
+ */
+ private def newConverter(
+ dataType: DataType,
+ reuseObj: Boolean = true): Converter = dataType match {
+ case NullType => (getter, ordinal) => null
+
+ case BooleanType =>
+ if (reuseObj) {
+ val result = new BooleanWritable()
+ (getter, ordinal) =>
+ result.set(getter.getBoolean(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new BooleanWritable(getter.getBoolean(ordinal))
+ }
+
+ case ByteType =>
+ if (reuseObj) {
+ val result = new ByteWritable()
+ (getter, ordinal) =>
+ result.set(getter.getByte(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new ByteWritable(getter.getByte(ordinal))
+ }
+
+ case ShortType =>
+ if (reuseObj) {
+ val result = new ShortWritable()
+ (getter, ordinal) =>
+ result.set(getter.getShort(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new ShortWritable(getter.getShort(ordinal))
+ }
+
+ case IntegerType =>
+ if (reuseObj) {
+ val result = new IntWritable()
+ (getter, ordinal) =>
+ result.set(getter.getInt(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new IntWritable(getter.getInt(ordinal))
+ }
+
+
+ case LongType =>
+ if (reuseObj) {
+ val result = new LongWritable()
+ (getter, ordinal) =>
+ result.set(getter.getLong(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new LongWritable(getter.getLong(ordinal))
+ }
+
+ case FloatType =>
+ if (reuseObj) {
+ val result = new FloatWritable()
+ (getter, ordinal) =>
+ result.set(getter.getFloat(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new FloatWritable(getter.getFloat(ordinal))
+ }
+
+ case DoubleType =>
+ if (reuseObj) {
+ val result = new DoubleWritable()
+ (getter, ordinal) =>
+ result.set(getter.getDouble(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new DoubleWritable(getter.getDouble(ordinal))
+ }
+
+
+ // Don't reuse the result object for string and binary as it would cause extra data copy.
+ case StringType => (getter, ordinal) =>
+ new Text(getter.getUTF8String(ordinal).getBytes)
+
+ case BinaryType => (getter, ordinal) =>
+ new BytesWritable(getter.getBinary(ordinal))
+
+ case DateType =>
+ if (reuseObj) {
+ val result = new DateWritable()
+ (getter, ordinal) =>
+ result.set(getter.getInt(ordinal))
+ result
+ } else {
+ (getter, ordinal) => new DateWritable(getter.getInt(ordinal))
+ }
+
+ // The following cases are already expensive, reusing object or not doesn't matter.
+
+ case TimestampType => (getter, ordinal) =>
+ val ts = DateTimeUtils.toJavaTimestamp(getter.getLong(ordinal))
+ val result = new OrcTimestamp(ts.getTime)
+ result.setNanos(ts.getNanos)
+ result
+
+ case DecimalType.Fixed(precision, scale) => (getter, ordinal) =>
+ val d = getter.getDecimal(ordinal, precision, scale)
+ new HiveDecimalWritable(HiveDecimal.create(d.toJavaBigDecimal))
+
+ case st: StructType => (getter, ordinal) =>
+ val result = createOrcValue(st).asInstanceOf[OrcStruct]
+ val fieldConverters = st.map(_.dataType).map(newConverter(_))
+ val numFields = st.length
+ val struct = getter.getStruct(ordinal, numFields)
+ var i = 0
+ while (i < numFields) {
+ if (struct.isNullAt(i)) {
+ result.setFieldValue(i, null)
+ } else {
+ result.setFieldValue(i, fieldConverters(i)(struct, i))
+ }
+ i += 1
+ }
+ result
+
+ case ArrayType(elementType, _) => (getter, ordinal) =>
+ val result = createOrcValue(dataType).asInstanceOf[OrcList[WritableComparable[_]]]
+ // All converted values need to be kept in the list, so the result object can't be reused.
+ val elementConverter = newConverter(elementType, reuseObj = false)
+ val array = getter.getArray(ordinal)
+ var i = 0
+ while (i < array.numElements()) {
+ if (array.isNullAt(i)) {
+ result.add(null)
+ } else {
+ result.add(elementConverter(array, i))
+ }
+ i += 1
+ }
+ result
+
+ case MapType(keyType, valueType, _) => (getter, ordinal) =>
+ val result = createOrcValue(dataType)
+ .asInstanceOf[OrcMap[WritableComparable[_], WritableComparable[_]]]
+ // All converted values need to be kept in the map, so the result object can't be reused.
+ val keyConverter = newConverter(keyType, reuseObj = false)
+ val valueConverter = newConverter(valueType, reuseObj = false)
+ val map = getter.getMap(ordinal)
+ val keyArray = map.keyArray()
+ val valueArray = map.valueArray()
+ var i = 0
+ while (i < map.numElements()) {
+ val key = keyConverter(keyArray, i)
+ if (valueArray.isNullAt(i)) {
+ result.put(key, null)
+ } else {
+ result.put(key, valueConverter(valueArray, i))
+ }
+ i += 1
+ }
+ result
+
+ case udt: UserDefinedType[_] => newConverter(udt.sqlType)
+
+ case _ =>
+ throw new UnsupportedOperationException(s"$dataType is not supported yet.")
+ }
+
+ /**
+ * Returns an ORC value object for the given Spark data type.
+ */
+ private def createOrcValue(dataType: DataType) = {
+ OrcStruct.createValue(TypeDescription.fromString(dataType.catalogString))
+ }
+}
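
A minimal round-trip sketch for the serializer (the schema and row are illustrative; note that string fields must be passed as `UTF8String`, matching what `InternalRow` stores):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

val schema = new StructType().add("id", LongType).add("name", StringType)
val serializer = new OrcSerializer(schema)

val struct = serializer.serialize(InternalRow(1L, UTF8String.fromString("spark")))
// struct.getFieldValue(0) is a LongWritable holding 1,
// struct.getFieldValue(1) is a Text holding "spark".
```
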
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
new file mode 100644
index 000000000000..b03ee06d04a1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.orc.{OrcFile, TypeDescription}
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types._
+
+object OrcUtils extends Logging {
+
+ // The extensions for ORC compression codecs
+ val extensionsForCompressionCodecNames = Map(
+ "NONE" -> "",
+ "SNAPPY" -> ".snappy",
+ "ZLIB" -> ".zlib",
+ "LZO" -> ".lzo")
+
+ def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = {
+ val origPath = new Path(pathStr)
+ val fs = origPath.getFileSystem(conf)
+ val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath)
+ .filterNot(_.isDirectory)
+ .map(_.getPath)
+ .filterNot(_.getName.startsWith("_"))
+ .filterNot(_.getName.startsWith("."))
+ paths
+ }
+
+ def readSchema(file: Path, conf: Configuration): Option[TypeDescription] = {
+ val fs = file.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ val reader = OrcFile.createReader(file, readerOptions)
+ val schema = reader.getSchema
+ if (schema.getFieldNames.size == 0) {
+ None
+ } else {
+ Some(schema)
+ }
+ }
+
+ def readSchema(sparkSession: SparkSession, files: Seq[FileStatus])
+ : Option[StructType] = {
+ val conf = sparkSession.sessionState.newHadoopConf()
+ // TODO: We need to support schema merging. See SPARK-11412.
+ files.map(_.getPath).flatMap(readSchema(_, conf)).headOption.map { schema =>
+ logDebug(s"Reading schema from file $files, got Hive schema string: $schema")
+ CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]
+ }
+ }
+
+ /**
+ * Returns the requested column ids from the given ORC file. Column id can be -1, which means the
+ * requested column doesn't exist in the ORC file. Returns None if the given ORC file is empty.
+ */
+ def requestedColumnIds(
+ isCaseSensitive: Boolean,
+ dataSchema: StructType,
+ requiredSchema: StructType,
+ file: Path,
+ conf: Configuration): Option[Array[Int]] = {
+ val fs = file.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ val reader = OrcFile.createReader(file, readerOptions)
+ val orcFieldNames = reader.getSchema.getFieldNames.asScala
+ if (orcFieldNames.isEmpty) {
+ // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer.
+ None
+ } else {
+ if (orcFieldNames.forall(_.startsWith("_col"))) {
+ // This is an ORC file written by Hive with no field names in the physical schema; assume
+ // the physical schema maps to the data schema by index.
+ assert(orcFieldNames.length <= dataSchema.length, "The given data schema " +
+ s"${dataSchema.simpleString} has fewer fields than the actual ORC physical schema, " +
+ "so there is no way to tell which columns were dropped; failing the read.")
+ Some(requiredSchema.fieldNames.map { name =>
+ val index = dataSchema.fieldIndex(name)
+ if (index < orcFieldNames.length) {
+ index
+ } else {
+ -1
+ }
+ })
+ } else {
+ val resolver = if (isCaseSensitive) caseSensitiveResolution else caseInsensitiveResolution
+ Some(requiredSchema.fieldNames.map { name => orcFieldNames.indexWhere(resolver(_, name)) })
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5d0bba69daca..31d9b909ad46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2757,4 +2757,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ // Only the new OrcFileFormat supports this (the Hive-based ORC implementation does not).
+ Seq(classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName,
+ "parquet").foreach { format =>
+ test(s"SPARK-15474 Write and read back non-emtpy schema with empty dataframe - $format") {
+ withTempPath { file =>
+ val path = file.getCanonicalPath
+ val emptyDf = Seq((true, 1, "str")).toDF.limit(0)
+ emptyDf.write.format(format).save(path)
+
+ val df = spark.read.format(format).load(path)
+ assert(df.schema.sameType(emptyDf.schema))
+ checkAnswer(df, emptyDf)
+ }
+ }
+ }
+
+ test("SPARK-21791 ORC should support column names with dot") {
+ val orc = classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat].getCanonicalName
+ withTempDir { dir =>
+ val path = new File(dir, "orc").getCanonicalPath
+ Seq(Some(1), None).toDF("col.dots").write.format(orc).save(path)
+ assert(spark.read.format(orc).load(path).collect().length == 2)
+ }
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 47ce6ba83866..77e836003b39 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -418,7 +418,7 @@ private[hive] class HiveClientImpl(
// Note that these statistics could be overridden by Spark's own statistics if they are available.
val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_))
val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_))
- val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)).filter(_ >= 0)
+ val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_))
// TODO: check if this estimate is valid for tables after partition pruning.
// NOTE: getting `totalSize` directly from params is kind of hacky, but this should be
// relatively cheap if parameters for the table are populated into the metastore.
@@ -430,9 +430,9 @@ private[hive] class HiveClientImpl(
// so when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero,
// return None. Later, we will use the other ways to estimate the statistics.
if (totalSize.isDefined && totalSize.get > 0L) {
- Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount))
+ Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount.filter(_ > 0)))
} else if (rawDataSize.isDefined && rawDataSize.get > 0) {
- Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount))
+ Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount.filter(_ > 0)))
} else {
// TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything?
None
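
A tiny sketch of the guard introduced above (the numbers mirror the new test below): a Hive-reported row count of zero is treated as unreliable and dropped, so the optimizer falls back to size-based estimation instead of a bogus zero-row estimate.

```scala
import org.apache.spark.sql.catalyst.catalog.CatalogStatistics

val totalSize = Some(BigInt("8000000000000"))
val rowCount = Some(BigInt(0))

val stats = totalSize.filter(_ > 0L).map { size =>
  CatalogStatistics(sizeInBytes = size, rowCount = rowCount.filter(_ > 0))
}
// stats: Some(CatalogStatistics(8000000000000, None, ...)) -- the zero rowCount is discarded.
```
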
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 0cdd9305c6b6..ee027e530826 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -1360,4 +1360,23 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
}
}
+
+ test("Deals with wrong Hive's statistics (zero rowCount)") {
+ withTable("maybe_big") {
+ sql("CREATE TABLE maybe_big (c1 bigint)" +
+ "TBLPROPERTIES ('numRows'='0', 'rawDataSize'='60000000000', 'totalSize'='8000000000000')")
+
+ val relation = spark.table("maybe_big").queryExecution.analyzed.children.head
+ .asInstanceOf[HiveTableRelation]
+
+ val properties = relation.tableMeta.ignoredProperties
+ assert(properties("totalSize").toLong > 0)
+ assert(properties("rawDataSize").toLong > 0)
+ assert(properties("numRows").toLong == 0)
+
+ assert(relation.stats.sizeInBytes > 0)
+ // A rowCount of 0 may cause OOM when CBO is enabled; see SPARK-22626 for details.
+ assert(relation.stats.rowCount.isEmpty)
+ }
+ }
}