fix mima and address comments
cloud-fan committed Nov 12, 2015
commit fbaa5438625548a9d2e0eb84af0501dce497288d
project/MimaExcludes.scala (7 additions & 1 deletion)
@@ -106,7 +106,13 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.sql.SQLContext.setSession"),
         ProblemFilters.exclude[MissingMethodProblem](
-          "org.apache.spark.sql.SQLContext.createSession")
+          "org.apache.spark.sql.SQLContext.createSession"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.jdbc.MySQLDialect.getCatalystType"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.jdbc.AggregatedDialect.getCatalystType"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.jdbc.PostgresDialect.getCatalystType")
       ) ++ Seq(
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.SparkContext.preferredNodeLocationData_="),
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 import scala.util.Try

 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.{JdbcType, JdbcDialects}
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -96,6 +96,11 @@ object JdbcUtils extends Logging {
     }
   }

+  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
+      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database. This is done in
    * a single database transaction in order to avoid repeatedly inserting
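
A note on the new helper: getJdbcType consults the dialect-specific mapping first, falls back to the shared getCommonJDBCType mapping, and fails fast on unmapped types. A minimal sketch of that lookup order, assuming a Postgres-style connection URL (the URL and the TEXT mapping are illustrative, not part of this patch):

    // Hypothetical usage inside JdbcUtils:
    val dialect = JdbcDialects.get("jdbc:postgresql://localhost/test")
    // 1. dialect.getJDBCType(StringType)  -- dialect override is consulted first
    // 2. getCommonJDBCType(StringType)    -- shared fallback if the dialect returns None
    // 3. getOrElse(throw ...)             -- IllegalArgumentException if both are None
    val typ: JdbcType = getJdbcType(StringType, dialect)
    typ.databaseTypeDefinition  // e.g. "TEXT" under the Postgres dialect
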
@@ -115,8 +120,9 @@
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
-      jdbcTypes: Array[JdbcType],
-      batchSize: Int): Iterator[Byte] = {
+      nullTypes: Array[Int],
+      batchSize: Int,
+      dialect: JdbcDialect): Iterator[Byte] = {
     val conn = getConnection()
     var committed = false
     try {
@@ -130,7 +136,7 @@
         var i = 0
         while (i < numFields) {
           if (row.isNullAt(i)) {
-            stmt.setNull(i + 1, jdbcTypes(i).jdbcNullType)
+            stmt.setNull(i + 1, nullTypes(i))
           } else {
             rddSchema.fields(i).dataType match {
               case IntegerType => stmt.setInt(i + 1, row.getInt(i))
@@ -146,9 +152,8 @@
               case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
               case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
               case ArrayType(et, _) =>
-                assert(jdbcTypes(i).databaseTypeDefinition.endsWith("[]"))
                 val array = conn.createArrayOf(
-                  jdbcTypes(i).databaseTypeDefinition.dropRight(2).toLowerCase,
+                  getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
                   row.getSeq[AnyRef](i).toArray)
                 stmt.setArray(i + 1, array)
               case _ => throw new IllegalArgumentException(
@@ -199,10 +204,7 @@
     val dialect = JdbcDialects.get(url)
     df.schema.fields foreach { field => {
      val name = field.name
-      val typ: String =
-        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition)
-          .orElse(getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition))
-          .getOrElse(throw new IllegalArgumentException(s"Don't know how to save $field to JDBC"))
+      val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
     }}
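
To make the simplified schemaString loop concrete, a hedged sketch of the column-definition fragment it builds (the schema below is illustrative; exact type names depend on the dialect and getCommonJDBCType):

    import org.apache.spark.sql.types._
    // Hypothetical input schema for the loop above:
    val schema = StructType(Seq(
      StructField("name", StringType, nullable = true),
      StructField("id", IntegerType, nullable = false)))
    // The foreach would append roughly ", name TEXT , id INTEGER NOT NULL":
    // getJdbcType supplies the type definition; NOT NULL appears only when
    // field.nullable is false.
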
@@ -218,19 +220,16 @@
       table: String,
       properties: Properties = new Properties()) {
     val dialect = JdbcDialects.get(url)
-    val jdbcTypes: Array[JdbcType] = df.schema.fields.map { field =>
-      dialect.getJDBCType(field.dataType)
-        .orElse(getCommonJDBCType(field.dataType))
-        .getOrElse(
-          throw new IllegalArgumentException(s"Can't get JDBC type for field $field"))
+    val nullTypes: Array[Int] = df.schema.fields.map { field =>
+      getJdbcType(field.dataType, dialect).jdbcNullType
     }

     val rddSchema = df.schema
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, jdbcTypes, batchSize)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
     }
   }
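Stepping back, a hedged sketch of the public write path that exercises saveTable and the new dialect plumbing (the URL, table name, and DataFrame are illustrative):

    import java.util.Properties
    // df is assumed to be an existing DataFrame, possibly with an array column.
    // DataFrameWriter.jdbc routes through JdbcUtils.saveTable, which now passes
    // the resolved dialect into every savePartition call, so array element
    // types are derived via getJdbcType(et, dialect) on the executors.
    val props = new Properties()
    props.setProperty("batchsize", "1000")  // consumed by saveTable above
    df.write.jdbc("jdbc:postgresql://localhost/test", "people", props)
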
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
  * for the given Catalyst type.
  */
 @DeveloperApi
-abstract class JdbcDialect {
+abstract class JdbcDialect extends Serializable {
   /**
    * Check if this dialect instance can handle a certain jdbc url.
    * @param url the jdbc url.
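
Why Serializable: savePartition now receives the dialect inside the df.foreachPartition closure, so the dialect object is serialized with the task and shipped to executors. A minimal sketch of a custom dialect that this change keeps working (MyDialect and its URL prefix are hypothetical):

    // Without JdbcDialect extending Serializable, capturing a dialect like this
    // in the foreachPartition closure would fail with
    // java.io.NotSerializableException when Spark serializes the task.
    object MyDialect extends JdbcDialect {
      override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
    }
    JdbcDialects.registerDialect(MyDialect)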