9 changes: 8 additions & 1 deletion docs/sql-programming-guide.md
@@ -1334,7 +1334,14 @@ the following case-insensitive options:
<td>
    The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g.: <code>"name CHAR(64), comments VARCHAR(1024)"</code>). The specified types should be valid Spark SQL data types. This option applies only to writing.
</td>
  </tr>

<tr>
<td><code>customSchema</code></td>
<td>
    The custom schema to use for reading data from JDBC connectors, e.g. <code>"id DECIMAL(38, 0), name STRING"</code>. The column names should be identical to the corresponding column names of the JDBC table. Users can specify the corresponding Spark SQL data types instead of using the defaults. This option applies only to reading.
</td>
</tr>
</table>

<div class="codetabs">
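As a quick illustration of the two options described in the table above (a minimal sketch: the URL, table names, and credentials are placeholders, and a SparkSession named spark is assumed):

// Read: override the column types inferred from the JDBC source (customSchema applies only to reads).
val jdbcCustomDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("user", "username")
  .option("password", "password")
  .option("customSchema", "id DECIMAL(38, 0), name STRING")
  .load()

// Write: control the database column type used for the "name" column when the target
// table is created (createTableColumnTypes applies only to writes).
jdbcCustomDF.write
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename_copy")
  .option("user", "username")
  .option("password", "password")
  .option("createTableColumnTypes", "name VARCHAR(1024)")
  .save()
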
10 changes: 10 additions & 0 deletions examples/src/main/python/sql/datasource.py
@@ -177,6 +177,16 @@ def jdbc_dataset_example(spark):
.jdbc("jdbc:postgresql:dbserver", "schema.tablename",
properties={"user": "username", "password": "password"})

# Specifying dataframe column data types on read
jdbcDF3 = spark.read \
.format("jdbc") \
.option("url", "jdbc:postgresql:dbserver") \
.option("dbtable", "schema.tablename") \
.option("user", "username") \
.option("password", "password") \
.option("customSchema", "id DECIMAL(38, 0), name STRING") \
.load()

# Saving data to a JDBC source
jdbcDF.write \
.format("jdbc") \
@@ -185,6 +185,10 @@ object SQLDataSourceExample {
connectionProperties.put("password", "password")
val jdbcDF2 = spark.read
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
// Specifying the custom data types of the read schema
connectionProperties.put("customSchema", "id DECIMAL(38, 0), name STRING")
val jdbcDF3 = spark.read
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)

// Saving data to a JDBC source
jdbcDF.write
@@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
import java.util.Properties
import java.math.BigDecimal

import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -72,10 +72,17 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLContext {
""".stripMargin.replaceAll("\n", " ")).executeUpdate()
conn.commit()

conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)")
.executeUpdate()
conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))")
.executeUpdate()
conn.prepareStatement(
"CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)").executeUpdate()
conn.prepareStatement(
"INSERT INTO ts_with_timezone VALUES " +
"(1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))").executeUpdate()
conn.commit()

conn.prepareStatement(
"CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate()
conn.prepareStatement(
"INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate()
conn.commit()

sql(
@@ -104,7 +111,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLContext {
}


test("SPARK-16625 : Importing Oracle numeric types") {
test("SPARK-16625 : Importing Oracle numeric types") {
val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
val rows = df.collect()
assert(rows.size == 1)
@@ -272,4 +279,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLContext {
assert(row.getDate(0).equals(dateVal))
assert(row.getTimestamp(1).equals(timestampVal))
}

test("SPARK-20427/SPARK-20921: read table use custom schema by jdbc api") {
Member: Move these newly added test cases to JDBCSuite.scala. Only add a single test case for Oracle data type mappings requested in the JIRA tickets.

// Reading with the default schema fails: the value's decimal precision (39) exceeds Spark's max precision (38)
val e = intercept[org.apache.spark.SparkException] {
spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", new Properties()).collect()
}
assert(e.getMessage.contains(
"requirement failed: Decimal precision 39 exceeds max precision 38"))

// custom schema can read data
val props = new Properties()
props.put("customSchema",
s"ID DECIMAL(${DecimalType.MAX_PRECISION}, 0), N1 INT, N2 BOOLEAN")
val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)

val rows = dfRead.collect()
// verify the data type
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types(0).equals("class java.math.BigDecimal"))
assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Boolean"))

// verify the value
val values = rows(0)
assert(values.getDecimal(0).equals(new java.math.BigDecimal("12312321321321312312312312123")))
assert(values.getInt(1).equals(1))
assert(values.getBoolean(2).equals(false))
}
}
@@ -21,6 +21,7 @@ import java.sql.{Connection, DriverManager}
import java.util.{Locale, Properties}

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.StructType

/**
* Options for the JDBC data source.
@@ -123,6 +124,8 @@ class JDBCOptions(
// TODO: to reuse the existing partition parameters for those partition specific options
val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
val customSchema = parameters.get(JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES)

val batchSize = {
val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
require(size >= 1,
@@ -161,6 +164,7 @@ object JDBCOptions {
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
@@ -80,7 +80,7 @@ object JDBCRDD extends Logging {
* @return A Catalyst schema corresponding to columns in the given order.
*/
private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
Member: This is not a related change. Could you revert it back?

Member Author: CatalystSqlParser.parseTableSchema(columnTypes) constructs a StructType without metadata, so the old lookup fails with this error:

key not found: name
java.util.NoSuchElementException: key not found: name
	at scala.collection.MapLike$class.default(MapLike.scala:228)
	at scala.collection.AbstractMap.default(Map.scala:59)
	at scala.collection.MapLike$class.apply(MapLike.scala:141)
	at scala.collection.AbstractMap.apply(Map.scala:59)
	at org.apache.spark.sql.types.Metadata.get(Metadata.scala:111)
	at org.apache.spark.sql.types.Metadata.getString(Metadata.scala:60)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$$anonfun$1.apply(JDBCRDD.scala:83)
	at org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD$$anonfun$1.apply(JDBCRDD.scala:83)

Member: Sorry, I did not get your point. Could you show me an example? Is it a behavior breaking change?

Member Author:
scala> org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parseTableSchema("id int, name string").fields.map(x => x.metadata.getString("name") -> x)
java.util.NoSuchElementException: key not found: name
  at scala.collection.MapLike$class.default(MapLike.scala:228)
  at scala.collection.AbstractMap.default(Map.scala:59)
  at scala.collection.MapLike$class.apply(MapLike.scala:141)
  at scala.collection.AbstractMap.apply(Map.scala:59)
  at org.apache.spark.sql.types.Metadata.get(Metadata.scala:111)
  at org.apache.spark.sql.types.Metadata.getString(Metadata.scala:60)
  at $anonfun$1.apply(<console>:24)
  at $anonfun$1.apply(<console>:24)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
  at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
  at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
  at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
  at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
  ... 48 elided

Member Author: It seems safe to remove this line.
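
For contrast with the failing snippet above, a minimal illustrative check of the replacement expression, which reads the field name directly instead of looking up the removed "name" metadata key:

org.apache.spark.sql.catalyst.parser.CatalystSqlParser
  .parseTableSchema("id int, name string").fields.map(x => x.name -> x)
// expected: Array((id, StructField(id,IntegerType,true)), (name, StructField(name,StringType,true)))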

new StructType(columns.map(name => fieldMap(name)))
}

@@ -111,7 +111,14 @@ private[sql] case class JDBCRelation(

override val needConversion: Boolean = false

override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
override val schema: StructType = {
val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
jdbcOptions.customSchema match {
case Some(customSchema) => JdbcUtils.getCustomSchema(
tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
case None => tableSchema
}
}

// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
@@ -29,6 +29,7 @@ import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -301,7 +302,6 @@ object JdbcUtils extends Logging {
rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
}
val metadata = new MetadataBuilder()
.putString("name", columnName)
.putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
@@ -767,6 +767,34 @@ object JdbcUtils extends Logging {
if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
}

/**
 * Parses the user-specified customSchema option value into a DataFrame schema, and returns it
 * only if its columns cover all of the columns in the default (resolved) table schema;
 * otherwise an AnalysisException is thrown.
 */
def getCustomSchema(
tableSchema: StructType,
customSchema: String,
nameEquality: Resolver): StructType = {
val userSchema = CatalystSqlParser.parseTableSchema(customSchema)

SchemaUtils.checkColumnNameDuplication(
userSchema.map(_.name), "in the customSchema option value", nameEquality)

val colNames = tableSchema.fieldNames.mkString(",")
val errorMsg = s"Please provide all the columns, all columns are: $colNames"
if (userSchema.size != tableSchema.size) {
throw new AnalysisException(errorMsg)
}

// Resolution is done by name, so only the column names are checked.
userSchema.fieldNames.foreach { col =>
tableSchema.find(f => nameEquality(f.name, col)).getOrElse {
throw new AnalysisException(errorMsg)
}
}
userSchema
}
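
For reference, a hypothetical invocation of the helper above (the table schema and types here are illustrative; the new JdbcUtilsSuite below exercises the same behavior):

import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.types._

val tableSchema = StructType(Seq(
  StructField("ID", DecimalType(38, 10)), StructField("NAME", StringType)))
// Overrides the resolved types by name: returns ID as DECIMAL(38, 0) and NAME as STRING.
JdbcUtils.getCustomSchema(tableSchema, "ID DECIMAL(38, 0), NAME STRING", caseInsensitiveResolution)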

/**
* Saves the RDD to the database in a single transaction.
*/
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.types._

class JdbcUtilsSuite extends SparkFunSuite {

val tableSchema = StructType(Seq(
StructField("C1", StringType, false), StructField("C2", IntegerType, false)))
val caseSensitive = org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution

test("Parse user specified column types") {
assert(
JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseInsensitive) ===
StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true))))
assert(JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseSensitive) ===
StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true))))
assert(
JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) ===
StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true))))
assert(JdbcUtils.getCustomSchema(
tableSchema, "c1 DECIMAL(38, 0), C2 STRING", caseInsensitive) ===
StructType(Seq(StructField("c1", DecimalType(38, 0), true),
StructField("C2", StringType, true))))

// Throw AnalysisException
val duplicate = intercept[AnalysisException]{
JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) ===
StructType(Seq(StructField("c1", DateType, true), StructField("c1", StringType, true)))
}
assert(duplicate.getMessage.contains(
"Found duplicate column(s) in the customSchema option value"))

val allColumns = intercept[AnalysisException]{
JdbcUtils.getCustomSchema(tableSchema, "C1 STRING", caseSensitive) ===
StructType(Seq(StructField("C1", DateType, true)))
}
assert(allColumns.getMessage.contains("Please provide all the columns,"))

val caseSensitiveColumnNotFound = intercept[AnalysisException]{
JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) ===
StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true)))
}
assert(caseSensitiveColumnNotFound.getMessage.contains(
"Please provide all the columns, all columns are: C1,C2;"))

val caseInsensitiveColumnNotFound = intercept[AnalysisException]{
JdbcUtils.getCustomSchema(tableSchema, "c3 DATE, C2 STRING", caseInsensitive) ===
StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
}
assert(caseInsensitiveColumnNotFound.getMessage.contains(
"Please provide all the columns, all columns are: C1,C2;"))

// Throw ParseException
val dataTypeNotSupported = intercept[ParseException]{
JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) ===
StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
}
assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported"))

val mismatchedInput = intercept[ParseException]{
JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) ===
StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
}
assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting"))
}
}
30 changes: 30 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -968,6 +968,36 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}

test("jdbc API support custom schema") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
val props = new Properties()
props.put("customSchema", "NAME STRING, THEID BIGINT")
val schema = StructType(Seq(
StructField("NAME", StringType, true), StructField("THEID", LongType, true)))
val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props)
assert(df.schema.size === 2)
assert(df.schema === schema)
assert(df.count() === 3)
}

test("jdbc API custom schema DDL-like strings.") {
withTempView("people_view") {
sql(
s"""
|CREATE TEMPORARY VIEW people_view
|USING org.apache.spark.sql.jdbc
|OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass',
|customSchema 'NAME STRING, THEID INT')
""".stripMargin.replaceAll("\n", " "))
val schema = StructType(
Seq(StructField("NAME", StringType, true), StructField("THEID", IntegerType, true)))
val df = sql("select * from people_view")
assert(df.schema.size === 2)
assert(df.schema === schema)
assert(df.count() === 3)
}
}

test("SPARK-15648: teradataDialect StringType data mapping") {
val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
assert(teradataDialect.getJDBCType(StringType).