Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ private[sql] object JDBCRDD extends Logging {
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
if (driver != null) DriverRegistry.register(driver)
} catch {
case e: ClassNotFoundException => {
logWarning(s"Couldn't find class $driver", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider {
val upperBound = parameters.getOrElse("upperBound", null)
val numPartitions = parameters.getOrElse("numPartitions", null)

if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
if (driver != null) DriverRegistry.register(driver)

if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
Expand Down Expand Up @@ -136,7 +136,7 @@ private[sql] case class JDBCRelation(
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
val driver: String = DriverRegistry.getDriverClassName(url)
JDBCRDD.scanTable(
sqlContext.sparkContext,
schema,
Expand Down
60 changes: 59 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

package org.apache.spark.sql

import java.sql.{Connection, DriverManager, PreparedStatement}
import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement}
import java.util.Properties

import scala.collection.mutable

import org.apache.spark.Logging
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

package object jdbc {
private[sql] object JDBCWriteDetails extends Logging {
Expand Down Expand Up @@ -179,4 +183,58 @@ package object jdbc {
}

}

/**
 * A thin delegating wrapper around a JDBC [[java.sql.Driver]]. Every method
 * forwards directly to the wrapped driver instance; the wrapper itself adds
 * no behavior. It exists so that a driver loaded by Spark's ClassLoader can
 * be registered with java.sql.DriverManager (which only sees classes loaded
 * by the bootstrap ClassLoader) through this bootstrap-visible shim.
 */
private [sql] class DriverWrapper(val wrapped: Driver) extends Driver {
  override def connect(url: String, info: Properties): Connection =
    wrapped.connect(url, info)

  override def acceptsURL(url: String): Boolean =
    wrapped.acceptsURL(url)

  override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
    wrapped.getPropertyInfo(url, info)

  override def getMajorVersion: Int = wrapped.getMajorVersion

  override def getMinorVersion: Int = wrapped.getMinorVersion

  override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()

  override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger
}

/**
 * java.sql.DriverManager is always loaded by the bootstrap ClassLoader,
 * so it cannot see JDBC drivers that are only reachable through Spark's
 * ClassLoader (e.g. drivers shipped in user-supplied jars).
 *
 * To work around this, such drivers are wrapped in a thin [[DriverWrapper]]
 * (itself visible to the bootstrap ClassLoader) and the wrapper is
 * registered with DriverManager in the driver's place.
 */
private [sql] object DriverRegistry extends Logging {

  // Maps driver class name -> registered wrapper. All reads and writes must
  // hold DriverRegistry's monitor: scala.collection.mutable.Map is not safe
  // for concurrent access.
  private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty

  /**
   * Loads `className` via the context/Spark ClassLoader and, unless the
   * class came from the bootstrap ClassLoader (in which case DriverManager
   * can already see it), registers a [[DriverWrapper]] around a new instance
   * with java.sql.DriverManager. Idempotent: repeated calls for the same
   * class name register exactly one wrapper.
   *
   * NOTE(fix): the original checked `wrapperMap` outside `synchronized`
   * while writers mutated it inside `synchronized` — a data race on a
   * non-thread-safe mutable.Map. The check-and-insert now happens entirely
   * under the lock.
   */
  def register(className: String): Unit = {
    val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
    if (cls.getClassLoader == null) {
      logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
    } else synchronized {
      if (wrapperMap.get(className).isDefined) {
        logTrace(s"Wrapper for $className already exists")
      } else {
        val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
        DriverManager.registerDriver(wrapper)
        wrapperMap(className) = wrapper
        logTrace(s"Wrapper for $className registered")
      }
    }
  }

  /**
   * Returns the canonical class name of the driver that DriverManager
   * selects for `url`, unwrapping a [[DriverWrapper]] so callers see the
   * real driver class rather than the wrapper.
   */
  def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
    case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
    case driver => driver.getClass.getCanonicalName
  }
}

} // package object jdbc