diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index ee231a934a3a..8713ff4df618 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1334,7 +1334,14 @@ the following case-insensitive options:
The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: "name CHAR(64), comments VARCHAR(1024)"). The specified types should be valid spark sql data types. This option applies only to writing.
|
-
+
+
+
+ customSchema |
+
+ The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING". The column names should be identical to the corresponding column names of the JDBC table. Users can specify the corresponding Spark SQL data types instead of using the defaults. This option applies only to reading.
+ |
+
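For the Scala side of the docs, a minimal sketch (not part of this patch) mirroring the Python example added below; the URL, table, and credentials are placeholders, and `spark` is assumed to be an existing SparkSession:

```scala
// Reading over JDBC with the new customSchema option via the option-based
// DataFrameReader API; connection details are placeholders.
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("user", "username")
  .option("password", "password")
  .option("customSchema", "id DECIMAL(38, 0), name STRING")
  .load()
```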
diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py
index 8777cca66bfe..f86012ea382e 100644
--- a/examples/src/main/python/sql/datasource.py
+++ b/examples/src/main/python/sql/datasource.py
@@ -177,6 +177,16 @@ def jdbc_dataset_example(spark):
.jdbc("jdbc:postgresql:dbserver", "schema.tablename",
properties={"user": "username", "password": "password"})
+ # Specifying DataFrame column data types on read
+ jdbcDF3 = spark.read \
+ .format("jdbc") \
+ .option("url", "jdbc:postgresql:dbserver") \
+ .option("dbtable", "schema.tablename") \
+ .option("user", "username") \
+ .option("password", "password") \
+ .option("customSchema", "id DECIMAL(38, 0), name STRING") \
+ .load()
+
# Saving data to a JDBC source
jdbcDF.write \
.format("jdbc") \
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index 6ff03bdb2212..86b3dc4a84f5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -185,6 +185,10 @@ object SQLDataSourceExample {
connectionProperties.put("password", "password")
val jdbcDF2 = spark.read
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
+ // Specifying the custom data types of the read schema
+ connectionProperties.put("customSchema", "id DECIMAL(38, 0), name STRING")
+ val jdbcDF3 = spark.read
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
// Saving data to a JDBC source
jdbcDF.write
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index 1b2c1b9e800a..7680ae383513 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp}
import java.util.Properties
import java.math.BigDecimal
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -72,10 +72,17 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
""".stripMargin.replaceAll("\n", " ")).executeUpdate()
conn.commit()
- conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))")
- .executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO ts_with_timezone VALUES " +
+ "(1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))").executeUpdate()
+ conn.commit()
+
+ conn.prepareStatement(
+ "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate()
conn.commit()
sql(
@@ -104,7 +111,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
}
- test("SPARK-16625 : Importing Oracle numeric types") {
+ test("SPARK-16625 : Importing Oracle numeric types") {
val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
val rows = df.collect()
assert(rows.size == 1)
@@ -272,4 +279,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
assert(row.getDate(0).equals(dateVal))
assert(row.getTimestamp(1).equals(timestampVal))
}
+
+ test("SPARK-20427/SPARK-20921: read table using a custom schema via the JDBC API") {
+ // By default the read fails: the IllegalArgumentException from the Decimal
+ // precision check surfaces wrapped in a SparkException
+ val e = intercept[org.apache.spark.SparkException] {
+ spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", new Properties()).collect()
+ }
+ assert(e.getMessage.contains(
+ "requirement failed: Decimal precision 39 exceeds max precision 38"))
+
+ // With a custom schema, the read succeeds
+ val props = new Properties()
+ props.put("customSchema",
+ s"ID DECIMAL(${DecimalType.MAX_PRECISION}, 0), N1 INT, N2 BOOLEAN")
+ val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)
+
+ val rows = dfRead.collect()
+ // verify the data type
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types(0).equals("class java.math.BigDecimal"))
+ assert(types(1).equals("class java.lang.Integer"))
+ assert(types(2).equals("class java.lang.Boolean"))
+
+ // verify the value
+ val values = rows(0)
+ assert(values.getDecimal(0).equals(new java.math.BigDecimal("12312321321321312312312312123")))
+ assert(values.getInt(1) == 1)
+ assert(!values.getBoolean(2))
+ }
}
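A standalone sketch of the failure the first half of the new test expects, assuming the Oracle dialect maps an unbounded NUMBER to DECIMAL(38, 10): rescaling the 29-digit value to scale 10 needs 39 digits of precision, which is exactly what the customSchema override avoids.

```scala
import org.apache.spark.sql.types.{Decimal, DecimalType}

// The stored value has 29 significant digits, so it fits in DECIMAL(38, 0)...
val v = BigDecimal("12312321321321312312312312123")
Decimal(v, DecimalType.MAX_PRECISION, 0)

// ...but rescaling to scale 10 (the assumed default mapping) pushes the
// required precision to 39 and trips the same check that the test's
// SparkException wraps.
try {
  Decimal(v, DecimalType.MAX_PRECISION, 10)
} catch {
  case e: IllegalArgumentException =>
    // requirement failed: Decimal precision 39 exceeds max precision 38
    println(e.getMessage)
}
```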
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 05b00058618a..b4e5d169066d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -21,6 +21,7 @@ import java.sql.{Connection, DriverManager}
import java.util.{Locale, Properties}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.types.StructType
/**
* Options for the JDBC data source.
@@ -123,6 +124,8 @@ class JDBCOptions(
// TODO: to reuse the existing partition parameters for those partition specific options
val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
+ val customSchema = parameters.get(JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES)
+
val batchSize = {
val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
require(size >= 1,
@@ -161,6 +164,7 @@ object JDBCOptions {
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
+ val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
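Since JDBC options are wrapped in a CaseInsensitiveMap, the new key is picked up however the caller capitalizes it (the DDL-based test added to JDBCSuite below relies on the same behavior for the other option names). A minimal sketch:

```scala
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

// Option keys are matched case-insensitively, so both lookups hit the entry.
val params = CaseInsensitiveMap(Map("CUSTOMSCHEMA" -> "ID DECIMAL(38, 0), NAME STRING"))
assert(params.get("customSchema") == Some("ID DECIMAL(38, 0), NAME STRING"))
assert(params.get("customschema") == Some("ID DECIMAL(38, 0), NAME STRING"))
```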
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 3274be91d481..05326210f324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -80,7 +80,7 @@ object JDBCRDD extends Logging {
* @return A Catalyst schema corresponding to columns in the given order.
*/
private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
- val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
+ val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
new StructType(columns.map(name => fieldMap(name)))
}
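Context for this one-line change: resolveTable used to stash each column name in the field metadata under "name" (the putString call removed in JdbcUtils below), and pruneSchema keyed on that entry. Fields coming from a user-supplied customSchema carry no such metadata, so pruning now keys on StructField.name directly. A self-contained sketch of the new lookup:

```scala
import org.apache.spark.sql.types._

// Prune a schema down to the requested columns, looking fields up by name.
def pruneByName(schema: StructType, columns: Array[String]): StructType = {
  val fieldMap = schema.fields.map(f => f.name -> f).toMap
  new StructType(columns.map(fieldMap))
}

val schema = new StructType()
  .add("ID", DecimalType(38, 0))
  .add("NAME", StringType)

pruneByName(schema, Array("NAME"))
// => StructType(StructField(NAME,StringType,true))
```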
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 17405f550d25..b23e5a772200 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -111,7 +111,14 @@ private[sql] case class JDBCRelation(
override val needConversion: Boolean = false
- override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
+ override val schema: StructType = {
+ val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+ jdbcOptions.customSchema match {
+ case Some(customSchema) => JdbcUtils.getCustomSchema(
+ tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
+ case None => tableSchema
+ }
+ }
// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index bbe9024f13a4..75327f0d38c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -301,7 +302,6 @@ object JdbcUtils extends Logging {
rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
}
val metadata = new MetadataBuilder()
- .putString("name", columnName)
.putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
@@ -767,6 +767,34 @@ object JdbcUtils extends Logging {
if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
}
+ /**
+ * Parses the user-specified customSchema option value into a DataFrame schema, and returns it
+ * only if every column name matches a column of the default (table) schema; otherwise an
+ * AnalysisException is thrown.
+ */
+ def getCustomSchema(
+ tableSchema: StructType,
+ customSchema: String,
+ nameEquality: Resolver): StructType = {
+ val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
+
+ SchemaUtils.checkColumnNameDuplication(
+ userSchema.map(_.name), "in the customSchema option value", nameEquality)
+
+ val colNames = tableSchema.fieldNames.mkString(",")
+ val errorMsg = s"Please provide all the columns, all columns are: $colNames"
+ if (userSchema.size != tableSchema.size) {
+ throw new AnalysisException(errorMsg)
+ }
+
+ // Resolution is by name, so only the column names need to be checked.
+ userSchema.fieldNames.foreach { col =>
+ tableSchema.find(f => nameEquality(f.name, col)).getOrElse {
+ throw new AnalysisException(errorMsg)
+ }
+ }
+ userSchema
+ }
+
/**
* Saves the RDD to the database in a single transaction.
*/
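A quick sketch of the parse-then-validate flow that getCustomSchema adds, using the same internal helpers the new JdbcUtilsSuite exercises (so illustrative rather than user-facing API); the table schema here stands in for whatever JDBCRDD.resolveTable would return:

```scala
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._

// Stand-in for the schema resolved from JDBC metadata.
val tableSchema = StructType(Seq(
  StructField("ID", DecimalType(38, 10)),
  StructField("NAME", StringType)))

// The option value is parsed with the regular DDL column-list parser...
val parsed = CatalystSqlParser.parseTableSchema("id DECIMAL(38, 0), name STRING")

// ...and accepted only if it covers every table column (matched via the
// resolver), in which case the user-supplied types win.
val schema = JdbcUtils.getCustomSchema(
  tableSchema, "id DECIMAL(38, 0), name STRING", caseInsensitiveResolution)
// schema: id DECIMAL(38,0), name STRING
```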
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
new file mode 100644
index 000000000000..1255f262bce9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.types._
+
+class JdbcUtilsSuite extends SparkFunSuite {
+
+ val tableSchema = StructType(Seq(
+ StructField("C1", StringType, false), StructField("C2", IntegerType, false)))
+ val caseSensitive = org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+ val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+
+ test("Parse user specified column types") {
+ assert(
+ JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true))))
+ assert(JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseSensitive) ===
+ StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true))))
+ assert(
+ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true))))
+ assert(JdbcUtils.getCustomSchema(
+ tableSchema, "c1 DECIMAL(38, 0), C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c1", DecimalType(38, 0), true),
+ StructField("C2", StringType, true))))
+
+ // Throw AnalysisException
+ val duplicate = intercept[AnalysisException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c1", DateType, true), StructField("c1", StringType, true)))
+ }
+ assert(duplicate.getMessage.contains(
+ "Found duplicate column(s) in the customSchema option value"))
+
+ val allColumns = intercept[AnalysisException]{
+ JdbcUtils.getCustomSchema(tableSchema, "C1 STRING", caseSensitive) ===
+ StructType(Seq(StructField("C1", DateType, true)))
+ }
+ assert(allColumns.getMessage.contains("Please provide all the columns,"))
+
+ val caseSensitiveColumnNotFound = intercept[AnalysisException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) ===
+ StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true)))
+ }
+ assert(caseSensitiveColumnNotFound.getMessage.contains(
+ "Please provide all the columns, all columns are: C1,C2;"))
+
+ val caseInsensitiveColumnNotFound = intercept[AnalysisException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c3 DATE, C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
+ }
+ assert(caseInsensitiveColumnNotFound.getMessage.contains(
+ "Please provide all the columns, all columns are: C1,C2;"))
+
+ // Throw ParseException
+ val dataTypeNotSupported = intercept[ParseException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
+ }
+ assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported"))
+
+ val mismatchedInput = intercept[ParseException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true)))
+ }
+ assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting"))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 8dc11d80c306..5f3148d74339 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -968,6 +968,36 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
+ test("jdbc API supports custom schema") {
+ val parts = Array[String]("THEID < 2", "THEID >= 2")
+ val props = new Properties()
+ props.put("customSchema", "NAME STRING, THEID BIGINT")
+ val schema = StructType(Seq(
+ StructField("NAME", StringType, true), StructField("THEID", LongType, true)))
+ val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props)
+ assert(df.schema.size === 2)
+ assert(df.schema === schema)
+ assert(df.count() === 3)
+ }
+
+ test("jdbc API supports custom schema as DDL-like strings") {
+ withTempView("people_view") {
+ sql(
+ s"""
+ |CREATE TEMPORARY VIEW people_view
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass',
+ |customSchema 'NAME STRING, THEID INT')
+ """.stripMargin.replaceAll("\n", " "))
+ val schema = StructType(
+ Seq(StructField("NAME", StringType, true), StructField("THEID", IntegerType, true)))
+ val df = sql("select * from people_view")
+ assert(df.schema.size === 2)
+ assert(df.schema === schema)
+ assert(df.count() === 3)
+ }
+ }
+
test("SPARK-15648: teradataDialect StringType data mapping") {
val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
assert(teradataDialect.getJDBCType(StringType).