@@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -65,7 +67,7 @@ class HadoopTableReader(
@transient private val tableDesc: TableDesc,
@transient private val sparkSession: SparkSession,
hadoopConf: Configuration)
-extends TableReader with Logging {
+extends TableReader with CastSupport with Logging {

// Hadoop honors "mapreduce.job.maps" as hint,
// but will ignore when mapreduce.jobtracker.address is "local".
@@ -86,6 +88,8 @@ class HadoopTableReader(
private val _broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

+override def conf: SQLConf = sparkSession.sessionState.conf

override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
makeRDDForTable(
hiveTable,
@@ -227,7 +231,7 @@ class HadoopTableReader(
def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = {
partitionKeyAttrs.foreach { case (attr, ordinal) =>
val partOrdinal = partitionKeys.indexOf(attr)
-row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
+row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
}
}

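Why `Cast` becomes `cast` in HadoopTableReader: `Cast` is timezone-aware, and constructing it directly leaves `timeZoneId` unset, which breaks when a partition value string has to be cast to a timestamp. Mixing in `CastSupport` and calling its `cast(...)` helper builds the `Cast` with the session-local timezone taken from `conf`, which is also why the `override def conf: SQLConf` is added. A minimal sketch of what the mixin provides (the real trait lives in `org.apache.spark.sql.catalyst.analysis`; treat this as an approximation, not the exact source):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

trait CastSupport {
  // Implementors expose the active session configuration.
  def conf: SQLConf

  // Build a Cast that is already bound to the session-local timezone,
  // so timezone-aware casts (e.g. String -> Timestamp) can be evaluated eagerly.
  def cast(child: Expression, dataType: DataType): Cast = {
    Cast(child, dataType, Option(conf.sessionLocalTimeZone))
  }
}
```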
@@ -30,13 +30,15 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.client.HiveClientImpl
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.util.Utils

@@ -53,11 +55,13 @@ case class HiveTableScanExec(
relation: HiveTableRelation,
partitionPruningPred: Seq[Expression])(
@transient private val sparkSession: SparkSession)
-extends LeafExecNode {
+extends LeafExecNode with CastSupport {

require(partitionPruningPred.isEmpty || relation.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")

+override def conf: SQLConf = sparkSession.sessionState.conf

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

@@ -104,7 +108,7 @@ case class HiveTableScanExec(
hadoopConf)

private def castFromString(value: String, dataType: DataType) = {
-Cast(Literal(value), dataType).eval(null)
+cast(Literal(value), dataType).eval(null)
}

private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
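The `castFromString` change matters for partition pruning in HiveTableScanExec: without a timezone, eagerly evaluating a string-to-timestamp `Cast` fails (this is the symptom SPARK-21739 fixes, typically surfacing as a `None.get` error), while a `Cast` that carries the session timezone evaluates fine. A rough illustration, assuming this Spark version's `Cast(child, dataType, timeZoneId)` constructor:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.TimestampType

// No timezone: evaluating the string-to-timestamp cast needs timeZoneId,
// so this is expected to fail (roughly java.util.NoSuchElementException: None.get).
val broken = Cast(Literal("2010-01-01 00:00:00"), TimestampType)
// broken.eval(null)

// With a timezone (what CastSupport.cast supplies from the session conf),
// the cast evaluates to microseconds since the epoch as a java.lang.Long.
val working = Cast(Literal("2010-01-01 00:00:00"), TimestampType, Some("UTC"))
val micros = working.eval(null)
```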
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive

import java.io.File
+import java.sql.Timestamp

import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem
@@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl
sql("DROP TABLE IF EXISTS createAndInsertTest")
}
}

test("SPARK-21739: Cast expression should initialize timezoneId") {
withTable("table_with_timestamp_partition") {
sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
"PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")

// test for Cast expression in TableReader
checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))

// test for Cast expression in HiveTableScanExec
checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
"WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
}
}
}