diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index f89d55b20e21..f2b30f025222 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -169,8 +169,10 @@ object JdbcUtils extends Logging {
     val dialect = JdbcDialects.get(url)
     df.schema.fields foreach { field => {
       val name = field.name
-      val typ: String =
-        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+      // Try the metadata-aware getJDBCType first, then fall back to the type-only overload
+      val typ: String = dialect.getJDBCType(field.dataType, field.metadata)
+        .map(_.databaseTypeDefinition).orElse(dialect.getJDBCType(field.dataType)
+        .map(_.databaseTypeDefinition)).getOrElse(
         field.dataType match {
           case IntegerType => "INTEGER"
           case LongType => "BIGINT"
@@ -202,23 +204,27 @@ object JdbcUtils extends Logging {
       properties: Properties = new Properties()) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
-        field.dataType match {
-          case IntegerType => java.sql.Types.INTEGER
-          case LongType => java.sql.Types.BIGINT
-          case DoubleType => java.sql.Types.DOUBLE
-          case FloatType => java.sql.Types.REAL
-          case ShortType => java.sql.Types.INTEGER
-          case ByteType => java.sql.Types.INTEGER
-          case BooleanType => java.sql.Types.BIT
-          case StringType => java.sql.Types.CLOB
-          case BinaryType => java.sql.Types.BLOB
-          case TimestampType => java.sql.Types.TIMESTAMP
-          case DateType => java.sql.Types.DATE
-          case t: DecimalType => java.sql.Types.DECIMAL
-          case _ => throw new IllegalArgumentException(
+      // Try the metadata-aware getJDBCType first, then fall back to the type-only overload
+      dialect.getJDBCType(field.dataType, field.metadata)
+        .map(_.jdbcNullType).orElse(dialect.getJDBCType(field.dataType)
+        .map(_.jdbcNullType)).getOrElse(
+        field.dataType match {
+          case IntegerType => java.sql.Types.INTEGER
+          case LongType => java.sql.Types.BIGINT
+          case DoubleType => java.sql.Types.DOUBLE
+          case FloatType => java.sql.Types.REAL
+          case ShortType => java.sql.Types.INTEGER
+          case ByteType => java.sql.Types.INTEGER
+          case BooleanType => java.sql.Types.BIT
+          case StringType => java.sql.Types.CLOB
+          case BinaryType => java.sql.Types.BLOB
+          case TimestampType => java.sql.Types.TIMESTAMP
+          case DateType => java.sql.Types.DATE
+          case t: DecimalType => java.sql.Types.DECIMAL
+          case _ => throw new IllegalArgumentException(
             s"Can't translate null value for field $field")
-        })
+        }
+      )
     }
 
     val rddSchema = df.schema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index c70fea1c3f50..2a822b997471 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -81,6 +81,14 @@ abstract class JdbcDialect {
    */
   def getJDBCType(dt: DataType): Option[JdbcType] = None
 
+  /**
+   * Retrieve the jdbc / sql type for a given datatype.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @param md The metadata of the field (e.g. a "maxlength" saved by getCatalystType)
+   * @return The new JdbcType if there is an override for this DataType
+   */
+  def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = None
+
   /**
    * Quotes the identifier. This is used to put quotes around the identifier in case the column
    * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
@@ -138,7 +146,8 @@ object JdbcDialects {
   registerDialect(PostgresDialect)
   registerDialect(DB2Dialect)
   registerDialect(MsSqlServerDialect)
-
+  registerDialect(OracleDialect)
+  registerDialect(NetezzaDialect)
 
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
@@ -173,8 +182,8 @@ class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
     dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption
   }
 
-  override def getJDBCType(dt: DataType): Option[JdbcType] = {
-    dialects.flatMap(_.getJDBCType(dt)).headOption
+  override def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = {
+    dialects.flatMap(_.getJDBCType(dt, md)).headOption
   }
 }
 
@@ -205,7 +214,7 @@ case object PostgresDialect extends JdbcDialect {
     } else None
   }
 
-  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+  override def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = dt match {
     case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
     case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
     case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
@@ -253,10 +262,8 @@ case object MySQLDialect extends JdbcDialect {
  */
 @DeveloperApi
 case object DB2Dialect extends JdbcDialect {
-
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2")
-
-  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+  override def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = dt match {
     case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB))
     case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR))
     case _ => None
@@ -278,3 +285,59 @@ case object MsSqlServerDialect extends JdbcDialect {
     } else None
   }
 }
+
+/**
+ * :: DeveloperApi ::
+ * Default Oracle dialect, mapping string/boolean on write to valid Oracle types.
+ */
+@DeveloperApi
+case object OracleDialect extends JdbcDialect {
+  override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle")
+  override def getCatalystType(
+      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+    if (sqlType == java.sql.Types.VARCHAR && typeName.equals("VARCHAR")) {
+      // Save the varchar size to metadata so writes can round-trip it
+      md.putLong("maxlength", size)
+      Some(StringType)
+    } else None
+  }
+
+  override def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = {
+    if (dt == StringType && md.contains("maxlength")) {
+      Some(JdbcType(s"VARCHAR(${md.getLong("maxlength")})", java.sql.Types.VARCHAR))
+    } else if (dt == StringType) {
+      Some(JdbcType("CLOB", java.sql.Types.CLOB))
+    } else if (dt == BooleanType) {
+      Some(JdbcType("CHAR(1)", java.sql.Types.CHAR))
+    } else None
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default Netezza dialect, mapping string/boolean on write to valid Netezza types.
+ */
+@DeveloperApi
+case object NetezzaDialect extends JdbcDialect {
+  override def canHandle(url: String): Boolean = url.startsWith("jdbc:netezza")
+  override def getCatalystType(
+      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+    if (sqlType == java.sql.Types.VARCHAR && typeName.equals("VARCHAR")) {
+      // Save the varchar size to metadata so writes can round-trip it
+      md.putLong("maxlength", size)
+      Some(StringType)
+    } else None
+  }
+
+  override def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = {
+    if (dt == StringType && md.contains("maxlength")) {
+      Some(JdbcType(s"VARCHAR(${md.getLong("maxlength")})", java.sql.Types.VARCHAR))
+    } else if (dt == StringType) {
+      Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR))
+    } else if (dt == ByteType) {
+      Some(JdbcType("BYTEINT", java.sql.Types.TINYINT))
+    } else if (dt == BooleanType) {
+      Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+    } else None
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index c4b039a9c535..1bf303e75dbb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -409,6 +409,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect)
     assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect)
     assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect)
+    assert(JdbcDialects.get("jdbc:oracle://127.0.0.1/db") == OracleDialect)
+    assert(JdbcDialects.get("jdbc:netezza://127.0.0.1/db") == NetezzaDialect)
     assert(JdbcDialects.get("test.invalid") == NoopDialect)
   }
 
@@ -448,8 +450,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
 
   test("DB2Dialect type mapping") {
     val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
-    assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
-    assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)")
+    assert(db2Dialect.getJDBCType(StringType, Metadata.empty)
+      .map(_.databaseTypeDefinition).get == "CLOB")
+    assert(db2Dialect.getJDBCType(BooleanType, Metadata.empty)
+      .map(_.databaseTypeDefinition).get == "CHAR(1)")
   }
 
   test("table exists query by jdbc dialect") {
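With this patch, a third-party dialect can honor per-column metadata by overriding the new two-argument `getJDBCType`. A minimal sketch of such a dialect (the `ExampleDialect` name and the `jdbc:example` URL prefix are hypothetical, not part of this patch):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

// Hypothetical dialect: honors a saved "maxlength" when writing strings,
// mirroring the Oracle/Netezza dialects added above.
case object ExampleDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:example")

  override def getJDBCType(dt: DataType, md: Metadata): Option[JdbcType] = dt match {
    case StringType if md.contains("maxlength") =>
      Some(JdbcType(s"VARCHAR(${md.getLong("maxlength")})", java.sql.Types.VARCHAR))
    case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR))
    case _ => None
  }
}

JdbcDialects.registerDialect(ExampleDialect)
```

Because `JdbcUtils` tries `getJDBCType(dt, md)` before falling back to the one-argument `getJDBCType(dt)`, existing dialects that only implement the old overload keep working unchanged.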
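The metadata itself round-trips through `getCatalystType` on read, but it can also be attached by hand before a write so the generated DDL uses a bounded `VARCHAR`. A sketch assuming this patch is applied, with a DataFrame `df` holding a string column `name` (the column name, table name, and width are illustrative):

```scala
import org.apache.spark.sql.types.MetadataBuilder

// Cap the "name" column at 32 characters; on write, getJDBCType(StringType, md)
// sees "maxlength" in the field metadata and emits VARCHAR(32) instead of the
// dialect's unbounded default.
val md = new MetadataBuilder().putLong("maxlength", 32).build()
val bounded = df.withColumn("name", df("name").as("name", md))
bounded.write.jdbc(url, "people", new java.util.Properties())
```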