2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -100,7 +100,7 @@ private[spark] object Utils extends Logging {
*/
val DEFAULT_MAX_TO_STRING_FIELDS = 25

private def maxNumToStringFields = {
private[spark] def maxNumToStringFields = {
if (SparkEnv.get != null) {
SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS)
} else {
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -22,10 +22,12 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Partition
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

/**
* Instructions on how to partition the table among workers.
@@ -48,10 +50,17 @@ private[sql] object JDBCRelation extends Logging {
* A null value predicate is added to the first partition's where clause to include
* the rows with null values in the partition column.
*
* @param schema resolved schema of a JDBC table
* @param partitioning partition information to generate the where clause for each partition
Member: Add the other two @param for the new parameters?

Member Author: ok

* @param resolver function used to determine if two identifiers are equal
* @param jdbcOptions JDBC options that contain the url
* @return an array of partitions with where clause for each partition
*/
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
def columnPartition(
schema: StructType,
partitioning: JDBCPartitioningInfo,
resolver: Resolver,
jdbcOptions: JDBCOptions): Array[Partition] = {
if (partitioning == null || partitioning.numPartitions <= 1 ||
partitioning.lowerBound == partitioning.upperBound) {
return Array[Partition](JDBCPartition(null, 0))
@@ -78,7 +87,10 @@ private[sql] object JDBCRelation extends Logging {
// Overflow and silliness can happen if you subtract then divide.
// Here we get a little roundoff, but that's (hopefully) OK.
val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
val column = partitioning.column

val column = verifyAndGetNormalizedColumnName(
schema, partitioning.column, resolver, jdbcOptions)

var i: Int = 0
var currentValue: Long = lowerBound
val ans = new ArrayBuffer[Partition]()
@@ -99,10 +111,57 @@ }
}
ans.toArray
}

// Verify column name based on the JDBC resolved schema
private def verifyAndGetNormalizedColumnName(
schema: StructType,
columnName: String,
resolver: Resolver,
jdbcOptions: JDBCOptions): String = {
val dialect = JdbcDialects.get(jdbcOptions.url)
schema.map(_.name).find { fieldName =>
resolver(fieldName, columnName) ||
resolver(dialect.quoteIdentifier(fieldName), columnName)
}.map(dialect.quoteIdentifier).getOrElse {
throw new AnalysisException(s"User-defined partition column $columnName not " +
s"found in the JDBC relation: ${schema.simpleString(Utils.maxNumToStringFields)}")
}
}

/**
* Takes a (schema, table) specification and returns the table's Catalyst schema.
* If `customSchema` is defined in the JDBC options, replaces the schema's dataType with the
* custom schema's type.
*
* @param resolver function used to determine if two identifiers are equal
* @param jdbcOptions JDBC options that contain the url, table and other information.
* @return resolved Catalyst schema of a JDBC table
*/
def getSchema(resolver: Resolver, jdbcOptions: JDBCOptions): StructType = {
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
jdbcOptions.customSchema match {
case Some(customSchema) => JdbcUtils.getCustomSchema(
tableSchema, customSchema, resolver)
case None => tableSchema
}
}

/**
* Resolves a Catalyst schema of a JDBC table and returns [[JDBCRelation]] with the schema.
*/
def apply(
parts: Array[Partition],
jdbcOptions: JDBCOptions)(
sparkSession: SparkSession): JDBCRelation = {
val schema = JDBCRelation.getSchema(sparkSession.sessionState.conf.resolver, jdbcOptions)
JDBCRelation(schema, parts, jdbcOptions)(sparkSession)
}
}

private[sql] case class JDBCRelation(
parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
override val schema: StructType,
parts: Array[Partition],
jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan
with InsertableRelation {
@@ -111,15 +170,6 @@ private[sql] case class JDBCRelation(

override val needConversion: Boolean = false

override val schema: StructType = {
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
jdbcOptions.customSchema match {
case Some(customSchema) => JdbcUtils.getCustomSchema(
tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
case None => tableSchema
}
}

// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
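To make the matching rules in verifyAndGetNormalizedColumnName easier to see at a glance, here is a minimal standalone sketch of that logic under two stated assumptions: a case-insensitive resolver (what spark.sql.caseSensitive=false gives you) and H2-style double-quote identifier quoting standing in for the JdbcDialect. The real code throws AnalysisException and obtains both collaborators from the session state and the JDBC url, so treat this purely as an illustration.

// Illustrative sketch only: mirrors the matching in verifyAndGetNormalizedColumnName
// above, without any Spark dependencies.
object PartitionColumnNormalizationSketch {
  type Resolver = (String, String) => Boolean

  // Assumption: case-insensitive matching, i.e. spark.sql.caseSensitive=false.
  val caseInsensitiveResolver: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // Assumption: H2-style quoting; the dialect wraps identifiers in double quotes.
  def quoteIdentifier(colName: String): String = s""""$colName""""

  def verifyAndGetNormalizedColumnName(
      fieldNames: Seq[String],
      columnName: String,
      resolver: Resolver): String = {
    fieldNames.find { fieldName =>
      // Accept either the bare field name or its already-quoted form.
      resolver(fieldName, columnName) || resolver(quoteIdentifier(fieldName), columnName)
    }.map(quoteIdentifier).getOrElse {
      // The patch itself raises AnalysisException at this point.
      throw new IllegalArgumentException(
        s"User-defined partition column $columnName not found in the JDBC relation")
    }
  }

  def main(args: Array[String]): Unit = {
    val fieldNames = Seq("THEID", "THE ID")
    // All three user inputs normalize to a quoted name taken from the schema.
    println(verifyAndGetNormalizedColumnName(fieldNames, "ThEiD", caseInsensitiveResolver))     // "THEID"
    println(verifyAndGetNormalizedColumnName(fieldNames, "\"THEID\"", caseInsensitiveResolver)) // "THEID"
    println(verifyAndGetNormalizedColumnName(fieldNames, "THE ID", caseInsensitiveResolver))    // "THE ID"
  }
}

The extracted getSchema companion keeps the behavior previously inlined in JDBCRelation.schema: it resolves the table schema over JDBC and, when the customSchema option is set, overrides the matching fields' data types before the schema is handed to columnPartition.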
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -48,8 +48,10 @@ class JdbcRelationProvider extends CreatableRelationProvider
JDBCPartitioningInfo(
partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
val resolver = sqlContext.conf.resolver
val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
val parts = JDBCRelation.columnPartition(schema, partitionInfo, resolver, jdbcOptions)
JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
}

override def createRelation(
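For context, this is roughly how the rewritten createRelation path above is exercised from the user side, modeled on the test added below; the URL, table name and master are placeholders rather than anything prescribed by this patch.

// Usage sketch (assumptions: a local session and an H2 table TEST.PARTITION with
// columns THEID and "THE ID", mirroring the fixture created in JDBCSuite below).
import org.apache.spark.sql.SparkSession

object PartitionedJdbcReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("partitioned-jdbc-read")
      .master("local[*]")
      .getOrCreate()

    // With this change, a partition column that differs in case or needs quoting
    // (for example "THE ID") is resolved against the JDBC schema and quoted before
    // it is embedded in the generated per-partition WHERE clauses.
    val df = spark.read.format("jdbc")
      .option("url", "jdbc:h2:mem:testdb;user=testUser;password=testPass") // placeholder
      .option("dbtable", "TEST.PARTITION")
      .option("partitionColumn", "THE ID")
      .option("lowerBound", 1)
      .option("upperBound", 4)
      .option("numPartitions", 3)
      .load()

    df.show()
    spark.stop()
  }
}

With numPartitions = 3 and bounds 1 to 4 the stride is 1, so the three partitions carry WHERE clauses of the form "THE ID" < 2 or "THE ID" is null, "THE ID" >= 2 AND "THE ID" < 3, and "THE ID" >= 3, which is exactly what the new test asserts.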
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -238,6 +239,11 @@ class JDBCSuite extends SparkFunSuite
|OPTIONS (url '$url', dbtable 'TEST."mixedCaseCols"', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))

conn.prepareStatement("CREATE TABLE test.partition (THEID INTEGER, `THE ID` INTEGER) " +
"AS SELECT 1, 1")
.executeUpdate()
conn.commit()

// Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
}

@@ -1206,4 +1212,47 @@ class JDBCSuite extends SparkFunSuite
}.getMessage
assert(errMsg.contains("Statement was canceled or the session timed out"))
}

test("SPARK-24327 verify and normalize a partition column based on a JDBC resolved schema") {
def testJdbcPartitionColumn(partColName: String, expectedColumnName: String): Unit = {
val df = spark.read.format("jdbc")
.option("url", urlWithUserAndPass)
.option("dbtable", "TEST.PARTITION")
.option("partitionColumn", partColName)
.option("lowerBound", 1)
.option("upperBound", 4)
.option("numPartitions", 3)
.load()

val quotedPrtColName = testH2Dialect.quoteIdentifier(expectedColumnName)
df.logicalPlan match {
case LogicalRelation(JDBCRelation(_, parts, _), _, _, _) =>
val whereClauses = parts.map(_.asInstanceOf[JDBCPartition].whereClause).toSet
assert(whereClauses === Set(
s"$quotedPrtColName < 2 or $quotedPrtColName is null",
s"$quotedPrtColName >= 2 AND $quotedPrtColName < 3",
s"$quotedPrtColName >= 3"))
}
}

testJdbcParitionColumn("THEID", "THEID")
testJdbcParitionColumn("\"THEID\"", "THEID")
withSQLConf("spark.sql.caseSensitive" -> "false") {
testJdbcParitionColumn("ThEiD", "THEID")
}
testJdbcParitionColumn("THE ID", "THE ID")

def testIncorrectJdbcPartitionColumn(partColName: String): Unit = {
val errMsg = intercept[AnalysisException] {
testJdbcPartitionColumn(partColName, "THEID")
}.getMessage
assert(errMsg.contains(s"User-defined partition column $partColName not found " +
"in the JDBC relation:"))
}

testIncorrectJdbcPartitionColumn("NoExistingColumn")
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
testIncorrectJdbcPartitionColumn(testH2Dialect.quoteIdentifier("ThEiD"))
}
}
}