Incorporate review comments
Cheolsoo Park committed Jul 13, 2015
commit c212c4d975d570604eee0e0bfdd738fb5e9cd213
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.HiveTableScan
import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
import org.apache.spark.sql.types._
@@ -302,9 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
// We're converting the entire table into a ParquetRelation, so the filter to Hive metastore
// is None.
val partitions = metastoreRelation.getHiveQlPartitions(None).map { p =>
// We're converting the entire table into a ParquetRelation, so the predicates passed to the
// Hive metastore are empty.
val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -667,8 +666,8 @@ private[hive] case class MetastoreRelation
}
)

def getHiveQlPartitions(filter: Option[String]): Seq[Partition] = {
table.getPartitions(filter).map { p =>
def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
table.getPartitions(predicates).map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
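As a quick orientation, here is a minimal sketch (not part of the patch) of the two call shapes the reworked getHiveQlPartitions supports; relation and pruningPredicates are assumed bindings for a resolved MetastoreRelation and the Catalyst predicates over its partition keys:

// Assumed bindings: relation: MetastoreRelation, pruningPredicates: Seq[Expression].
val allPartitions = relation.getHiveQlPartitions()                      // predicates default to Nil: list every partition
val prunedPartitions = relation.getHiveQlPartitions(pruningPredicates)  // let the metastore filter where it can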
56 changes: 1 addition & 55 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive

import java.io.{InputStream, OutputStream}
import java.rmi.server.UID
import java.util.List

/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -31,18 +30,15 @@ import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
import org.apache.hadoop.io.Writable

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Expression}
import org.apache.spark.sql.types.{StringType, IntegralType, Decimal}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.Utils

private[hive] object HiveShim {
@@ -104,56 +100,6 @@ private[hive] object HiveShim {
}
}

def toMetastoreFilter(
predicates: Seq[Expression],
partitionKeys: List[FieldSchema],
hiveMetastoreVersion: String): Option[String] = {

// Binary comparison has been supported in getPartitionsByFilter() since Hive 0.13.
// So if the Hive metastore version is older than 0.13, predicates cannot be pushed down.
// See HIVE-4888.
val versionPattern = "([\\d]+\\.[\\d]+).*".r
hiveMetastoreVersion match {
case versionPattern(version) if (version.toDouble < 0.13) => return None
case _ => // continue
}

// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = partitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
Option(predicates.foldLeft("") {
(prevStr, expr) => {
expr match {
case op @ BinaryComparison(lhs, rhs) => {
val curr: Option[String] =
lhs match {
case AttributeReference(_, _, _, _) => {
rhs.dataType match {
case _: IntegralType =>
Some(lhs.prettyString + op.symbol + rhs.prettyString)
case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
case _ => None
}
}
case _ => None
}
curr match {
case Some(currStr) if (prevStr.nonEmpty) => s"$prevStr and $currStr"
case Some(currStr) if (prevStr.isEmpty) => currStr
case None => prevStr
}
}
case _ => prevStr
}
}
}).filter(_.nonEmpty)
}

/**
* This class provides the UDF creation and also the UDF instance serialization and
* de-serialization cross process boundary.
@@ -107,12 +107,6 @@ private[hive] trait HiveStrategies {

try {
if (relation.hiveQlTable.isPartitioned) {
val metastoreFilter =
HiveShim.toMetastoreFilter(
pruningPredicates,
relation.hiveQlTable.getPartitionKeys,
hiveContext.hiveMetastoreVersion)

val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true))
// Translate the predicate so that it automatically casts the input values to the
// correct data types during evaluation.
@@ -131,9 +125,7 @@
InterpretedPredicate.create(castedPredicate)
}

logDebug(s"Hive metastore filter is $metastoreFilter")

val partitions = relation.getHiveQlPartitions(metastoreFilter).filter { part =>
val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -21,6 +21,7 @@ import java.io.PrintStream
import java.util.{Map => JMap}

import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression

private[hive] case class HiveDatabase(
name: String,
@@ -71,10 +72,10 @@ private[hive] case class HiveTable(

def isPartitioned: Boolean = partitionColumns.nonEmpty

def getPartitions(filter: Option[String]): Seq[HivePartition] = {
filter match {
case None => client.getAllPartitions(this)
case Some(expr) => client.getPartitionsByFilter(this, expr)
def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
predicates match {
case Nil => client.getAllPartitions(this)
case _ => client.getPartitionsByFilter(this, predicates)
}
}

@@ -138,7 +139,7 @@ private[hive] trait ClientInterface {
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

/** Returns partitions filtered by predicates for the given table. */
def getPartitionsByFilter(hTable: HiveTable, filter: String): Seq[HivePartition]
def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]

/** Loads a static partition into an existing table. */
def loadPartition(
@@ -17,11 +17,11 @@

package org.apache.spark.sql.hive.client

import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.net.URI
import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
import java.io.{File, PrintStream}
import java.util.{Map => JMap}
import javax.annotation.concurrent.GuardedBy

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.util.CircularBuffer

import scala.collection.JavaConversions._
@@ -315,9 +315,9 @@ private[hive] class ClientWrapper(

override def getPartitionsByFilter(
hTable: HiveTable,
filter: String): Seq[HivePartition] = withHiveState {
predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
shim.getPartitionsByFilter(client, qlTable, filter).map(toHivePartition)
shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
}

override def listTables(dbName: String): Seq[String] = withHiveState {
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
import org.apache.spark.sql.types.{StringType, IntegralType}

/**
* A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,7 +66,7 @@ private[client] sealed abstract class Shim {

def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition]
def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

@@ -111,7 +116,7 @@ private[client] sealed abstract class Shim {

}

private[client] class Shim_v0_12 extends Shim {
private[client] class Shim_v0_12 extends Shim with Logging {

private lazy val startMethod =
findStaticMethod(
@@ -129,12 +134,6 @@ private[client] class Shim_v0_12 extends Shim {
classOf[Hive],
"getAllPartitionsForPruner",
classOf[Table])
private lazy val getPartitionsByFilterMethod =
findMethod(
classOf[Hive],
"getPartitionsByFilter",
classOf[Table],
classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -204,9 +203,16 @@ private[client] class Shim_v0_12 extends Shim {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition] =
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
.toSeq
override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
// See HIVE-4888.
logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
"Please use Hive 0.13 or higher.")
getAllPartitions(hive, table)
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
@@ -306,9 +312,47 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition] =
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
.toSeq
override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = table.getPartitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
val filter = predicates.flatMap { expr =>
expr match {
case op @ BinaryComparison(lhs, rhs) => {
lhs match {
case AttributeReference(_, _, _, _) => {
rhs.dataType match {
case _: IntegralType =>
Some(lhs.prettyString + op.symbol + rhs.prettyString)
case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
case _ => None
}
}
case _ => None
}
}
case _ => None
}
}.mkString(" and ")

val partitions =
if (filter.isEmpty) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
} else {
logDebug(s"Hive metastore filter is '$filter'.")
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
}

partitions.toSeq
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
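To make the filter-string format concrete, here is a rough, hypothetical sketch (not part of the patch) of what the Shim_v0_13 conversion above produces for two simple equality predicates, assuming an integer partition key int_key and a string (non-varchar) partition key str_key:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.{IntegerType, StringType}

// Hypothetical pruning predicates: int_key = 1 AND str_key = 'value'.
val predicates = Seq(
  EqualTo(AttributeReference("int_key", IntegerType, nullable = false)(), Literal(1)),
  EqualTo(AttributeReference("str_key", StringType, nullable = false)(), Literal("value")))

// Integral values are rendered unquoted and string values quoted, so the filter string handed
// to Hive's getPartitionsByFilter() comes out roughly as:
//   int_key=1 and str_key="value"
// Varchar keys and predicates that are not binary comparisons on a plain attribute are dropped
// from the string; the corresponding partitions are filtered later by Spark-side pruning.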
@@ -54,12 +54,6 @@ case class HiveTableScan(
// Retrieve the original attributes based on expression ID so that capitalization matches.
val attributes = requestedAttributes.map(relation.attributeMap)

val metastoreFilter: Option[String] =
HiveShim.toMetastoreFilter(
partitionPruningPred,
relation.hiveQlTable.getPartitionKeys,
context.hiveMetastoreVersion)

// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
@@ -139,9 +133,8 @@ case class HiveTableScan(
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
logDebug(s"Hive metastore filter is $metastoreFilter")
hadoopReader.makeRDDForPartitionedTable(
prunePartitions(relation.getHiveQlPartitions(metastoreFilter)))
prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
}

override def output: Seq[Attribute] = attributes
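To spell out how the two layers interact: pushing predicates to the metastore is best-effort, and prunePartitions still re-evaluates the pruning predicates on whatever comes back. A rough sketch of the partitioned-table path, using the names from the patch:

// 1. getHiveQlPartitions(partitionPruningPred) forwards the raw Catalyst predicates; the active
//    shim pushes down what it can express (Hive 0.13+) or simply lists all partitions (Hive 0.12).
// 2. prunePartitions(...) then applies the full pruning predicate on the Spark side, so any
//    predicate the metastore ignored (e.g. on a varchar key) is still enforced.
hadoopReader.makeRDDForPartitionedTable(
  prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))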
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.client

import java.io.File

import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.util.Utils
@@ -152,7 +154,9 @@ class VersionsSuite extends SparkFunSuite with Logging {
}

test(s"$version: getPartitionsByFilter") {
client.getPartitionsByFilter(client.getTable("default", "src_part"), "key = 1")
client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
Literal(1))))
}

test(s"$version: loadPartition") {
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
p.prunePartitions(relation.getHiveQlPartitions(None)).map(_.getValues)
p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
} else {
Seq.empty
}