[SPARK-6913] Fixed "No suitable driver found" when using JDBC driver added with SparkContext.addJar
SlavikBaranov committed Apr 29, 2015
commit c8294aea77daa51090f3403a5ac5056435b7ecd4
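For context, a minimal reproduction of the failure this commit fixes, sketched for spark-shell (the jar path, driver class, and URL below are placeholders, and the Spark 1.3-era sqlContext.load API is assumed): the driver jar arrives via SparkContext.addJar, so the driver class is visible only to Spark's context class loader, while java.sql.DriverManager, loaded by the bootstrap class loader, cannot see it.

// Hypothetical spark-shell session (placeholders throughout):
sc.addJar("/path/to/postgresql-jdbc.jar")
val df = sqlContext.load("jdbc", Map(
  "url" -> "jdbc:postgresql://dbhost/testdb",
  "driver" -> "org.postgresql.Driver",
  "dbtable" -> "my_table"))
df.count()
// Before this patch: java.sql.SQLException: No suitable driver found for jdbc:postgresql://dbhost/testdb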
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -154,7 +154,7 @@ private[sql] object JDBCRDD extends Logging {
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
- if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
+ if (driver != null) DriverRegistry.register(driver)
} catch {
case e: ClassNotFoundException => {
logWarning(s"Couldn't find class $driver", e);
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -100,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider {
val upperBound = parameters.getOrElse("upperBound", null)
val numPartitions = parameters.getOrElse("numPartitions", null)

- if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
+ if (driver != null) DriverRegistry.register(driver)

if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
@@ -136,7 +136,7 @@ private[sql] case class JDBCRelation(
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
- val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
+ val driver: String = DriverRegistry.getDriverClassName(url)
JDBCRDD.scanTable(
sqlContext.sparkContext,
schema,
57 changes: 56 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -17,10 +17,14 @@

package org.apache.spark.sql

- import java.sql.{Connection, DriverManager, PreparedStatement}
+ import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement}
import java.util.Properties

import scala.collection.concurrent.TrieMap

import org.apache.spark.Logging
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

package object jdbc {
private[sql] object JDBCWriteDetails extends Logging {
@@ -179,4 +183,55 @@ package object jdbc {
}

}

private [sql] case class DriverWrapper(wrapped: Driver) extends Driver {
Review comment (Contributor):
this doesn't need to be a case class, does it? Looks like you can just slightly change the pattern matching down below.
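A minimal sketch of what the reviewer suggests, reusing the imports already added in this file (not part of this commit): DriverWrapper becomes a plain class exposing the wrapped driver as a val, and getDriverClassName matches on the type instead of the case-class extractor.

private[sql] class DriverWrapper(val wrapped: Driver) extends Driver {
  // Same delegating overrides as in the committed case class.
  override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)
  override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()
  override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
    wrapped.getPropertyInfo(url, info)
  override def getMinorVersion: Int = wrapped.getMinorVersion
  override def getMajorVersion: Int = wrapped.getMajorVersion
  override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger
  override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)
}

// The match in DriverRegistry.getDriverClassName then becomes a type pattern:
def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
  case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
  case driver => driver.getClass.getCanonicalName
}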

override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)

override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()

override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = {
wrapped.getPropertyInfo(url, info)
}

override def getMinorVersion: Int = wrapped.getMinorVersion

override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger

override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)

override def getMajorVersion: Int = wrapped.getMajorVersion
}

/**
 * java.sql.DriverManager is always loaded by the bootstrap classloader,
 * so it can't load JDBC drivers accessible by Spark's ClassLoader.
 *
 * To solve the problem, drivers from user-supplied jars are wrapped
 * into a thin wrapper.
 */
private [sql] object DriverRegistry extends Logging {

val wrapperMap: TrieMap[String, DriverWrapper] = TrieMap.empty

def register(className: String): Unit = {
val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
if (cls.getClassLoader == null) {
logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
} else if (wrapperMap.get(className).isDefined) {
logTrace(s"Wrapper for $className already exists")
} else {
val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
if (wrapperMap.putIfAbsent(className, wrapper).isEmpty) {
DriverManager.registerDriver(wrapper)
logTrace(s"Wrapper for $className registered")
}
}
}

def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
case DriverWrapper(wrapped) => wrapped.getClass.getCanonicalName
case driver => driver.getClass.getCanonicalName
}
}

} // package object jdbc
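To illustrate the class-loader point made in the comment above, here is a hedged usage sketch (a hypothetical example object placed inside org.apache.spark.sql, since DriverRegistry is private[sql]; the driver class and URL are placeholders): register loads the driver through Spark's class loader and publishes a DriverWrapper to DriverManager, and because the wrapper's own class is visible to DriverManager, the URL now resolves.

package org.apache.spark.sql

import java.sql.DriverManager
import java.util.Properties

// Hypothetical illustration only; not part of the commit.
object DriverRegistryExample {
  def main(args: Array[String]): Unit = {
    // Loads the placeholder driver class via Spark's class loader; since that is not the
    // bootstrap loader, a DriverWrapper is registered with java.sql.DriverManager.
    jdbc.DriverRegistry.register("org.postgresql.Driver")

    // DriverManager now finds the wrapper, which delegates connect() to the real driver.
    val conn = DriverManager.getConnection("jdbc:postgresql://dbhost/testdb", new Properties())
    conn.close()
  }
}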