Implement predicate pushdown for hive metastore catalog
Cheolsoo Park committed Jun 29, 2015
commit ca460429692191d75fe077a598418e59ba4aca8e
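At a high level, this commit pushes partition-pruning predicates down to the Hive metastore: instead of always fetching every partition via getAllPartitions and filtering on the Spark side, qualifying predicates are translated into a metastore filter string and passed to Hive's getPartitionsByFilter. A minimal sketch of the intended flow in Scala, where relation and pruningPredicates are illustrative names rather than identifiers taken from this diff:

  // Sketch only: `relation` stands for a MetastoreRelation and
  // `pruningPredicates` for the partition-key predicates the planner collected.
  val metastoreFilter: Option[String] = HiveShim.toMetastoreFilter(pruningPredicates)
  // e.g. for ds = '2015-06-29' this yields Some("ds=\"2015-06-29\"")

  // Some(filter) asks the metastore only for matching partitions;
  // None preserves the old behavior of fetching all of them.
  val partitions = relation.getHiveQlPartitions(metastoreFilter)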
HiveMetastoreCatalog.scala
@@ -300,7 +300,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
val partitions = metastoreRelation.hiveQlPartitions.map { p =>
val partitions = metastoreRelation.getHiveQlPartitions(None).map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -643,32 +643,6 @@ private[hive] case class MetastoreRelation
new Table(tTable)
}

@transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
tPartition.setValues(p.values)

val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tPartition.setSd(sd)
sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))

sd.setLocation(p.storage.location)
sd.setInputFormat(p.storage.inputFormat)
sd.setOutputFormat(p.storage.outputFormat)

val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
sd.setSerdeInfo(serdeInfo)
serdeInfo.setSerializationLib(p.storage.serde)

val serdeParameters = new java.util.HashMap[String, String]()
serdeInfo.setParameters(serdeParameters)
table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }

new Partition(hiveQlTable, tPartition)
}

@transient override lazy val statistics: Statistics = Statistics(
sizeInBytes = {
val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -689,6 +663,34 @@ private[hive] case class MetastoreRelation
}
)

def getHiveQlPartitions(filter: Option[String]): Seq[Partition] = {
table.getPartitions(filter).map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
tPartition.setValues(p.values)

val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tPartition.setSd(sd)
sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))

sd.setLocation(p.storage.location)
sd.setInputFormat(p.storage.inputFormat)
sd.setOutputFormat(p.storage.outputFormat)

val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
sd.setSerdeInfo(serdeInfo)
serdeInfo.setSerializationLib(p.storage.serde)

val serdeParameters = new java.util.HashMap[String, String]()
serdeInfo.setParameters(serdeParameters)
table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }

new Partition(hiveQlTable, tPartition)
}
}

/** Only compare database and tablename, not alias. */
override def sameResult(plan: LogicalPlan): Boolean = {
plan match {
29 changes: 29 additions & 0 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -27,6 +27,7 @@ import scala.reflect.ClassTag

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
@@ -37,6 +38,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObject
import org.apache.hadoop.io.Writable

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.Utils

@@ -99,6 +101,33 @@ private[hive] object HiveShim {
}
}

def toMetastoreFilter(predicates: Seq[Expression]): Option[String] = {
if (predicates.nonEmpty) {
// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key_1=\"value_1\" and int_key_2=value_2 ..."
Some(predicates.foldLeft("") {
(str, expr) => {
expr match {
case op @ BinaryComparison(lhs, rhs) => {
val hiveFriendlyExpr =
lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\""
if (str.isEmpty) {
s"$hiveFriendlyExpr"
} else {
s"$str and $hiveFriendlyExpr"
}
}
case _ => {
str
}
}
}
})
} else {
None
}
}

/**
* This class provides the UDF creation and also the UDF instance serialization and
* de-serialization cross process boundary.
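To make the new helper concrete, here is a hedged example of calling toMetastoreFilter directly; the partition column ds and its value are invented for illustration, and the expected string follows from prettyString plus the comparison symbol:

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
  import org.apache.spark.sql.types.StringType

  // Hypothetical partition column. Only BinaryComparison predicates are translated;
  // anything else (In, IsNotNull, ...) is skipped by the fold above.
  val ds = AttributeReference("ds", StringType)()
  val filter = HiveShim.toMetastoreFilter(Seq(EqualTo(ds, Literal("2015-06-29"))))
  // expected: Some("ds=\"2015-06-29\""); an empty predicate list yields None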
HiveStrategies.scala
@@ -124,7 +124,7 @@ private[hive] trait HiveStrategies {
InterpretedPredicate.create(castedPredicate)
}

val partitions = relation.hiveQlPartitions.filter { part =>
val partitions = relation.getHiveQlPartitions(None).filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -212,7 +212,7 @@ private[hive] trait HiveStrategies {
projectList,
otherPredicates,
identity[Seq[Expression]],
HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
case _ =>
Nil
}
ClientInterface.scala
@@ -71,7 +71,12 @@ private[hive] case class HiveTable(

def isPartitioned: Boolean = partitionColumns.nonEmpty

def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
def getPartitions(filter: Option[String]): Seq[HivePartition] = {
filter match {
case None => client.getAllPartitions(this)
case Some(expr) => client.getPartitionsByFilter(this, expr)
}
}

// Hive does not support backticks when passing names to the client.
def qualifiedName: String = s"$database.$name"
@@ -132,6 +137,9 @@ private[hive] trait ClientInterface {
/** Returns all partitions for the given table. */
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

/** Returns partitions filtered by predicates for the given table. */
def getPartitionsByFilter(hTable: HiveTable, filter: String): Seq[HivePartition]

/** Loads a static partition into an existing table. */
def loadPartition(
loadPath: String,
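The Option[String] threaded through getPartitions is the raw Hive partition-filter string, not a Catalyst expression; None falls back to listing every partition. A small usage sketch, with the table and filter value assumed for illustration:

  // `table` is a HiveTable for a partitioned table; the string uses Hive's
  // partition-filter syntax, e.g. as produced by HiveShim.toMetastoreFilter.
  val all      = table.getPartitions(None)                       // old behavior
  val filtered = table.getPartitions(Some("ds=\"2015-06-29\""))  // pushdown path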
ClientWrapper.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.hive.client

import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.net.URI
import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
import javax.annotation.concurrent.GuardedBy
@@ -28,16 +28,13 @@ import scala.collection.JavaConversions._
import scala.language.reflectiveCalls

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
import org.apache.hadoop.hive.metastore.{TableType => HTableType}
import org.apache.hadoop.hive.metastore.api
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.{Driver, metadata}

import org.apache.spark.Logging
import org.apache.spark.sql.execution.QueryExecutionException
@@ -316,6 +313,13 @@ private[hive] class ClientWrapper(
shim.getAllPartitions(client, qlTable).map(toHivePartition)
}

override def getPartitionsByFilter(
hTable: HiveTable,
filter: String): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
shim.getPartitionsByFilter(client, qlTable, filter).map(toHivePartition)
}

override def listTables(dbName: String): Seq[String] = withHiveState {
client.getAllTables(dbName)
}
client/HiveShim.scala
@@ -61,6 +61,8 @@ private[client] sealed abstract class Shim {

def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

def getDriverResults(driver: Driver): Seq[String]
@@ -127,6 +129,12 @@ private[client] class Shim_v0_12 extends Shim {
classOf[Hive],
"getAllPartitionsForPruner",
classOf[Table])
private lazy val getPartitionsByFilterMethod =
findMethod(
classOf[Hive],
"getPartitionsByFilter",
classOf[Table],
classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -196,6 +204,10 @@ private[client] class Shim_v0_12 extends Shim {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition] =
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
.toSeq

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]

@@ -267,6 +279,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
classOf[Hive],
"getAllPartitionsOf",
classOf[Table])
private lazy val getPartitionsByFilterMethod =
findMethod(
classOf[Hive],
"getPartitionsByFilter",
classOf[Table],
classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -288,6 +306,10 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition] =
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
.toSeq

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]

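Both Shim_v0_12 and Shim_v0_13 resolve Hive's getPartitionsByFilter(Table, String) reflectively, so the same Spark build can work against whichever Hive client version is loaded at runtime. findMethod is defined elsewhere in this file; a rough sketch of what such a lookup amounts to, under the assumption that it is a thin wrapper over Class#getMethod, not the actual helper:

  // Assumed shape of the reflective lookup used by the shims above.
  def findMethod(klass: Class[_], name: String, args: Class[_]*): java.lang.reflect.Method =
    klass.getMethod(name, args: _*)

  // The shim then invokes the resolved method and casts the untyped result, e.g.:
  // getPartitionsByFilterMethod.invoke(hive, table, filter)
  //   .asInstanceOf[JArrayList[Partition]].toSeq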
HiveTableScan.scala
@@ -43,7 +43,7 @@ private[hive]
case class HiveTableScan(
requestedAttributes: Seq[Attribute],
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
partitionPruningPred: Seq[Expression])(
@transient val context: HiveContext)
extends LeafNode {

@@ -53,9 +53,11 @@ case class HiveTableScan(
// Retrieve the original attributes based on expression ID so that capitalization matches.
val attributes = requestedAttributes.map(relation.attributeMap)

private[this] val metastoreFilter = HiveShim.toMetastoreFilter(partitionPruningPred)

// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
private[this] val boundPruningPred = partitionPruningPred.map { pred =>
private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
require(
pred.dataType == BooleanType,
s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -132,7 +134,9 @@ case class HiveTableScan(
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
logDebug(s"Hive metastore filter is $metastoreFilter")
hadoopReader.makeRDDForPartitionedTable(
prunePartitions(relation.getHiveQlPartitions(metastoreFilter)))
}

override def output: Seq[Attribute] = attributes
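Note how the two pruning mechanisms coexist in the scan: the metastore filter narrows which partitions are fetched at all, while boundPruningPred still prunes client-side, covering predicates that toMetastoreFilter cannot translate. A worked example with invented partition columns:

  // Suppose partitionPruningPred is: ds = "2015-06-29" AND hr IN (1, 2)
  // toMetastoreFilter only handles BinaryComparison, so the In is skipped:
  //   metastoreFilter == Some("ds=\"2015-06-29\"")
  // getHiveQlPartitions(metastoreFilter) fetches only the ds=2015-06-29 partitions,
  // and prunePartitions then drops the hr values other than 1 and 2.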
VersionsSuite.scala
@@ -151,6 +151,10 @@ class VersionsSuite extends SparkFunSuite with Logging {
client.getAllPartitions(client.getTable("default", "src_part"))
}

test(s"$version: getPartitionsByFilter") {
client.getPartitionsByFilter(client.getTable("default", "src_part"), "key = 1")
}

test(s"$version: loadPartition") {
client.loadPartition(
emptyDir,
PruningSuite.scala
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
p.prunePartitions(relation.getHiveQlPartitions(None)).map(_.getValues)
} else {
Seq.empty
}