rxins comments
marmbrus committed May 2, 2015
commit c72f6acbca3f7766892ddb735dfcccdb494d1bfd
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
*/
class NoSuchTableException extends Exception

class NoSuchDatabaseException extends Exception

/**
* An interface for looking up relations by name. Used by an [[Analyzer]].
*/
@@ -17,7 +17,7 @@

package org.apache.spark.sql.hive.client

import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}

case class HiveDatabase(
name: String,
@@ -91,7 +91,7 @@ trait ClientInterface {

/** Returns the metadata for specified database, throwing an exception if it doesn't exist */
def getDatabase(name: String): HiveDatabase = {
getDatabaseOption(name).getOrElse(sys.error(s"No such database $name"))
getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException)
}

/** Returns the metadata for a given database, or None if it doesn't exist. */
@@ -112,7 +112,7 @@
def alterTable(table: HiveTable): Unit

/** Creates a new database with the given name. */
def createDatabase(databaseName: String): Unit
def createDatabase(database: HiveDatabase): Unit

/** Returns all partitions for the given table. */
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]
@@ -121,7 +121,7 @@
def loadPartition(
loadPath: String,
tableName: String,
partSpec: java.util.LinkedHashMap[String, String],
partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering
replace: Boolean,
holdDDLTime: Boolean,
inheritTableSpecs: Boolean,
@@ -138,7 +138,7 @@
def loadDynamicPartitions(
loadPath: String,
tableName: String,
partSpec: java.util.LinkedHashMap[String, String],
partSpec: java.util.LinkedHashMap[String, String], // Hive relies on LinkedHashMap ordering
replace: Boolean,
numDP: Int,
holdDDLTime: Boolean,
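A quick illustration of the `// Hive relies on LinkedHashMap ordering` comments added above (the column names `ds`/`hr` are hypothetical): the comment suggests Hive consumes the partition spec in iteration order when it builds the partition path, so an order-preserving map is required. A minimal sketch:

    import java.util.LinkedHashMap

    // Sketch only: insertion order is what Hive sees when it iterates the spec.
    val partSpec = new LinkedHashMap[String, String]
    partSpec.put("ds", "2015-05-02") // first partition column
    partSpec.put("hr", "11")         // second partition column
    // Iterating yields ds then hr, matching a path like .../ds=2015-05-02/hr=11;
    // a plain HashMap could reorder the columns and produce the wrong path.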
@@ -27,8 +27,10 @@ import scala.language.reflectiveCalls
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata._
import org.apache.hadoop.hive.ql.metadata
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.Driver
@@ -65,13 +67,6 @@ class ClientWrapper(
conf.set(k, v)
}

private def properties = Seq(
"javax.jdo.option.ConnectionURL",
"javax.jdo.option.ConnectionDriverName",
"javax.jdo.option.ConnectionUserName")

properties.foreach(p => logInfo(s"Hive Configuration: $p = ${conf.get(p)}"))

// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
private val outputBuffer = new java.io.OutputStream {
var pos: Int = 0
@@ -117,7 +112,10 @@

private val client = Hive.get(conf)

private def withClassLoader[A](f: => A): A = synchronized {
/**
* Runs `f` with ThreadLocal session state and classloaders configured for this version of hive.
*/
private def withHiveState[A](f: => A): A = synchronized {
val original = Thread.currentThread().getContextClassLoader
Thread.currentThread().setContextClassLoader(getClass.getClassLoader)
Hive.set(client)
@@ -135,25 +133,32 @@
ret
}

def currentDatabase: String = withClassLoader {
override def currentDatabase: String = withHiveState {
state.getCurrentDatabase
}

def createDatabase(tableName: String): Unit = withClassLoader {
val table = new Table("default", tableName)
override def createDatabase(database: HiveDatabase): Unit = withHiveState {
client.createDatabase(
new Database("default", "", new File("").toURI.toString, new java.util.HashMap), true)
new Database(
database.name,
"",
new File(database.location).toURI.toString,
new java.util.HashMap),
true)
}

def getDatabaseOption(name: String): Option[HiveDatabase] = withClassLoader {
override def getDatabaseOption(name: String): Option[HiveDatabase] = withHiveState {
Option(client.getDatabase(name)).map { d =>
HiveDatabase(
name = d.getName,
location = d.getLocationUri)
}
}

def getTableOption(dbName: String, tableName: String): Option[HiveTable] = withClassLoader {
override def getTableOption(
dbName: String,
tableName: String): Option[HiveTable] = withHiveState {

logDebug(s"Looking up $dbName.$tableName")

val hiveTable = Option(client.getTable(dbName, tableName, false))
@@ -185,8 +190,8 @@
Class.forName(name)
.asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]]

private def toQlTable(table: HiveTable): Table = {
val qlTable = new Table(table.database, table.name)
private def toQlTable(table: HiveTable): metadata.Table = {
val qlTable = new metadata.Table(table.database, table.name)

qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))
qlTable.setPartCols(
@@ -208,25 +213,23 @@
qlTable
}

def createTable(table: HiveTable): Unit = withClassLoader {
override def createTable(table: HiveTable): Unit = withHiveState {
val qlTable = toQlTable(table)
client.createTable(qlTable)
}

def alterTable(table: HiveTable): Unit = withClassLoader {
override def alterTable(table: HiveTable): Unit = withHiveState {
val qlTable = toQlTable(table)
client.alterTable(table.qualifiedName, qlTable)
}

def getTables(dbName: String): Seq[String] = withClassLoader {
client.getAllTables(dbName).toSeq
}

def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withClassLoader {
override def getAllPartitions(hTable: HiveTable): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
val qlPartitions = version match {
case hive.v12 => client.call[Table, Set[Partition]]("getAllPartitionsForPruner", qlTable)
case hive.v13 => client.call[Table, Set[Partition]]("getAllPartitionsOf", qlTable)
case hive.v12 =>
client.call[metadata.Table, Set[metadata.Partition]]("getAllPartitionsForPruner", qlTable)
case hive.v13 =>
client.call[metadata.Table, Set[metadata.Partition]]("getAllPartitionsOf", qlTable)
}
qlPartitions.map(_.getTPartition).map { p =>
HivePartition(
@@ -239,14 +242,14 @@
}.toSeq
}

def listTables(dbName: String): Seq[String] = withClassLoader {
override def listTables(dbName: String): Seq[String] = withHiveState {
client.getAllTables
}

/**
* Runs the specified SQL query using Hive.
*/
def runSqlHive(sql: String): Seq[String] = {
override def runSqlHive(sql: String): Seq[String] = {
val maxResults = 100000
val results = runHive(sql, maxResults)
// It is very confusing when you only get back some of the results...
@@ -258,7 +261,7 @@
* Execute the command using Hive and return the results as a sequence. Each element
* in the sequence is one row.
*/
protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withClassLoader {
protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = withHiveState {
logDebug(s"Running hiveql '$cmd'")
if (cmd.toLowerCase.startsWith("set")) { logDebug(s"Changing config: $cmd") }
try {
@@ -331,7 +334,7 @@
replace: Boolean,
holdDDLTime: Boolean,
inheritTableSpecs: Boolean,
isSkewedStoreAsSubdir: Boolean): Unit = withClassLoader {
isSkewedStoreAsSubdir: Boolean): Unit = withHiveState {

client.loadPartition(
new Path(loadPath), // TODO: Use URI
@@ -347,7 +350,7 @@
loadPath: String, // TODO URI
tableName: String,
replace: Boolean,
holdDDLTime: Boolean): Unit = withClassLoader {
holdDDLTime: Boolean): Unit = withHiveState {
client.loadTable(
new Path(loadPath),
tableName,
@@ -362,7 +365,7 @@
replace: Boolean,
numDP: Int,
holdDDLTime: Boolean,
listBucketingEnabled: Boolean): Unit = withClassLoader {
listBucketingEnabled: Boolean): Unit = withHiveState {
client.loadDynamicPartitions(
new Path(loadPath),
tableName,
Expand All @@ -373,7 +376,7 @@ class ClientWrapper(
listBucketingEnabled)
}

def reset(): Unit = withClassLoader {
def reset(): Unit = withHiveState {
client.getAllTables("default").foreach { t =>
logDebug(s"Deleting table $t")
val table = client.getTable("default", t)
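The `withHiveState` wrapper introduced above follows the classic save/set/restore pattern for the context classloader; a generic, standalone sketch of that pattern (the name `withContextClassLoader` is illustrative, not part of this PR):

    // Sketch: run f with a specific context classloader, always restoring the
    // original afterwards, even if f throws.
    def withContextClassLoader[A](loader: ClassLoader)(f: => A): A = {
      val original = Thread.currentThread().getContextClassLoader
      Thread.currentThread().setContextClassLoader(loader)
      try f finally Thread.currentThread().setContextClassLoader(original)
    }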
@@ -37,23 +37,23 @@ object IsolatedClientLoader {
* Creates isolated Hive client loaders by downloading the requested version from maven.
*/
def forVersion(
version: Int,
version: String,
config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized {
val files = resolvedVersions.getOrElseUpdate(version, downloadVersion(version))
val resolvedVersion = hiveVersion(version)
val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion))
new IsolatedClientLoader(hiveVersion(version), files, config)
}

def hiveVersion(version: Int): HiveVersion = version match {
case 12 => hive.v12
case 13 => hive.v13
def hiveVersion(version: String): HiveVersion = version match {
case "12" | "0.12" | "0.12.0" => hive.v12
case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13
}

private def downloadVersion(version: Int): Seq[File] = {
val v = hiveVersion(version).fullVersion
private def downloadVersion(version: HiveVersion): Seq[File] = {
val hiveArtifacts =
(Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") ++
(if (version <= 10) "hive-builtins" :: Nil else Nil))
.map(a => s"org.apache.hive:$a:$v") :+
(if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil))
.map(a => s"org.apache.hive:$a:${version.fullVersion}") :+
"com.google.guava:guava:14.0.1" :+
"org.apache.hadoop:hadoop-client:2.4.0" :+
"mysql:mysql-connector-java:5.1.12"
Expand All @@ -75,7 +75,7 @@ object IsolatedClientLoader {
tempDir.listFiles()
}

private def resolvedVersions = new scala.collection.mutable.HashMap[Int, Seq[File]]
private def resolvedVersions = new scala.collection.mutable.HashMap[HiveVersion, Seq[File]]
}

/**
@@ -136,7 +136,7 @@ class IsolatedClientLoader(
protected val classLoader: ClassLoader = new URLClassLoader(allJars, rootClassLoader) {
override def loadClass(name: String, resolve: Boolean): Class[_] = {
val loaded = findLoadedClass(name)
if(loaded == null) doLoadClass(name, resolve) else loaded
if (loaded == null) doLoadClass(name, resolve) else loaded
}

def doLoadClass(name: String, resolve: Boolean): Class[_] = {
Contributor:

I am not sure this is the correct way to implement the new classloader. In my understanding, we isolate classes by using a leaf classloader (the custom classloader) that delegates to its parents, e.g.:

                      Bootstrap
                          |
                        System
                          |
                Shared (non-Hive jars)
               /                      \
    HiveV1 (barrier jars)    HiveV2 (barrier jars) ...

Here is a sample, the Tomcat classloader:
https://tomcat.apache.org/tomcat-6.0-doc/class-loader-howto.html

But we need to handle jar-file visibility across the different classloaders very carefully; otherwise it may cause memory leaks or weird ClassCastExceptions.
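
A minimal sketch of the delegation scheme pictured above (names like `BarrierClassLoader` and `isBarrierClass` are illustrative, not this PR's code): a per-version leaf loader resolves Hive classes from its own jars and delegates everything else to the shared parent.

    import java.net.{URL, URLClassLoader}

    // Sketch only: one of these would exist per Hive version, sharing a parent
    // that holds the non-Hive jars (Spark, scala-library, JDK classes).
    class BarrierClassLoader(hiveJars: Array[URL], shared: ClassLoader)
      extends URLClassLoader(hiveJars, shared) {

      // Classes that must stay isolated per Hive version.
      private def isBarrierClass(name: String): Boolean =
        name.startsWith("org.apache.hadoop.hive.")

      override def loadClass(name: String, resolve: Boolean): Class[_] = synchronized {
        if (isBarrierClass(name)) {
          // Self-first for Hive classes, bypassing the parent-first default.
          val c = Option(findLoadedClass(name)).getOrElse(findClass(name))
          if (resolve) resolveClass(c)
          c
        } else {
          // Parent-first for everything else.
          super.loadClass(name, resolve)
        }
      }
    }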

Contributor Author (marmbrus):

I don't think it's incorrect this way (I already have nearly all the tests passing using this isolated classloader), but it does seem that we could eliminate the findLoadedClass logic by overriding findClass instead. I'll simplify in my next PR.
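
The simplification mentioned here might look like the following sketch (hypothetical, not the follow-up PR): the JDK's default `ClassLoader.loadClass` already runs findLoadedClass, then parent delegation, then `findClass`, so overriding only `findClass` removes the explicit findLoadedClass check, with the caveat that the default order is parent-first.

    import java.net.{URL, URLClassLoader}

    // Sketch: let the built-in loadClass sequence (findLoadedClass -> parent
    // -> findClass) handle the bookkeeping; only the final step is customized.
    class SimplifiedLoader(jars: Array[URL], parent: ClassLoader)
      extends URLClassLoader(jars, parent) {
      override def findClass(name: String): Class[_] =
        super.findClass(name) // custom barrier-class resolution would go here
    }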

@@ -19,12 +19,15 @@ package org.apache.spark.sql.hive

/** Support for interacting with different versions of the HiveMetastoreClient */
package object client {
private[client] abstract class HiveVersion(val fullVersion: String)
private[client] abstract class HiveVersion(val fullVersion: String, val hasBuiltinsJar: Boolean)

// scalastyle:off
private[client] object hive {
case object v12 extends HiveVersion("0.12.0")
case object v13 extends HiveVersion("0.13.1")
case object v10 extends HiveVersion("0.10.0", true)
case object v11 extends HiveVersion("0.11.0", false)
case object v12 extends HiveVersion("0.12.0", false)
case object v13 extends HiveVersion("0.13.1", false)
}
// scalastyle:on

}
Contributor:

Nit: missing newline at end of file.

@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
class VersionsSuite extends FunSuite with Logging {
val testType = "derby"

private def buildConf(version: Int) = {
private def buildConf() = {
lazy val warehousePath = Utils.createTempDir()
lazy val metastorePath = Utils.createTempDir()
metastorePath.delete()
@@ -35,8 +35,9 @@
}

test("success sanity check") {
val badClient = IsolatedClientLoader.forVersion(13, buildConf(13)).client
badClient.createDatabase("default")
val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client
val db = new HiveDatabase("default", "")
badClient.createDatabase(db)
}

private def getNestedMessages(e: Throwable): String = {
@@ -55,24 +56,25 @@
// TODO: currently only works on mysql where we manually create the schema...
ignore("failure sanity check") {
val e = intercept[Throwable] {
val badClient = quietly { IsolatedClientLoader.forVersion(13, buildConf(12)).client }
val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client }
}
assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'")
}

private val versions = Seq(12, 13)
private val versions = Seq("12", "13")

private var client: ClientInterface = null

versions.foreach { version =>
test(s"$version: listTables") {
client = null
client = IsolatedClientLoader.forVersion(version, buildConf(version)).client
client = IsolatedClientLoader.forVersion(version, buildConf()).client
client.listTables("default")
}

test(s"$version: createDatabase") {
client.createDatabase("default")
val db = HiveDatabase("default", "")
client.createDatabase(db)
}

test(s"$version: createTable") {
@@ -85,7 +87,7 @@
properties = Map.empty,
serdeProperties = Map.empty,
tableType = ManagedTable,
location = Some("/user/hive/src"),
location = None,
inputFormat =
Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName),
outputFormat =