From 378c5a9b414ff5e82f2c1bb0f3963279c2b4f929 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Nov 2015 21:57:07 +0800 Subject: [PATCH 1/3] support array type in postgresql --- .../sql/jdbc/PostgresIntegrationSuite.scala | 44 ++++++---- .../execution/datasources/jdbc/JDBCRDD.scala | 84 ++++++++++++++----- .../datasources/jdbc/JdbcUtils.scala | 80 +++++++++--------- .../spark/sql/jdbc/AggregatedDialect.scala | 8 +- .../apache/spark/sql/jdbc/DerbyDialect.scala | 6 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 9 +- .../spark/sql/jdbc/MsSqlServerDialect.scala | 6 +- .../apache/spark/sql/jdbc/MySQLDialect.scala | 8 +- .../apache/spark/sql/jdbc/OracleDialect.scala | 10 ++- .../spark/sql/jdbc/PostgresDialect.scala | 55 ++++++++---- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 20 +++-- 11 files changed, 222 insertions(+), 108 deletions(-) diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 164a7f396280c..2e18d0a2baa1c 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.Connection import java.util.Properties +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Literal, If} import org.apache.spark.tags.DockerTest @DockerTest @@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override def dataPreparation(conn: Connection): Unit = { conn.prepareStatement("CREATE DATABASE foo").executeUpdate() conn.setCatalog("foo") - conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " - + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() + conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " + + "c10 integer[], c11 text[])").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " - + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " + + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate() } test("Type mapping for various types") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) val rows = df.collect() assert(rows.length == 1) - val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 10) - assert(types(0).equals("class java.lang.String")) - assert(types(1).equals("class java.lang.Integer")) - assert(types(2).equals("class java.lang.Double")) - assert(types(3).equals("class java.lang.Long")) - assert(types(4).equals("class java.lang.Boolean")) - assert(types(5).equals("class [B")) - assert(types(6).equals("class [B")) - assert(types(7).equals("class java.lang.Boolean")) - assert(types(8).equals("class java.lang.String")) - assert(types(9).equals("class java.lang.String")) + val types = rows(0).toSeq.map(x => x.getClass) + assert(types.length == 12) + assert(classOf[String].isAssignableFrom(types(0))) + assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) + assert(classOf[java.lang.Double].isAssignableFrom(types(2))) + assert(classOf[java.lang.Long].isAssignableFrom(types(3))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(4))) + assert(classOf[Array[Byte]].isAssignableFrom(types(5))) + assert(classOf[Array[Byte]].isAssignableFrom(types(6))) + assert(classOf[java.lang.Boolean].isAssignableFrom(types(7))) + assert(classOf[String].isAssignableFrom(types(8))) + assert(classOf[String].isAssignableFrom(types(9))) + assert(classOf[Seq[Int]].isAssignableFrom(types(10))) + assert(classOf[Seq[String]].isAssignableFrom(types(11))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) @@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getBoolean(7) == true) assert(rows(0).getString(8) == "172.16.0.42") assert(rows(0).getString(9) == "192.168.0.0/16") + assert(rows(0).getSeq(10) == Seq(1, 2)) + assert(rows(0).getSeq(11) == Seq("a", null, "b")) } test("Basic write test") { val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) - df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) // Test only that it doesn't crash. + df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test write null values. + df.select(df.queryExecution.analyzed.output.map { a => + Column(If(Literal(true), Literal(null), a)).as(a.name) + }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 018a009fbda6d..d944a5e1033dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -130,14 +130,14 @@ private[sql] object JDBCRDD extends Logging { val columnName = rsmd.getColumnLabel(i + 1) val dataType = rsmd.getColumnType(i + 1) val typeName = rsmd.getColumnTypeName(i + 1) - val fieldSize = rsmd.getPrecision(i + 1) - val fieldScale = rsmd.getScale(i + 1) + val precision = rsmd.getPrecision(i + 1) + val scale = rsmd.getScale(i + 1) val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) val columnType = - dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale, isSigned)) + dialect.getCatalystType(dataType, typeName, precision, scale, metadata).getOrElse( + getCatalystType(dataType, precision, scale, isSigned)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } @@ -324,25 +324,27 @@ private[sql] class JDBCRDD( case object StringConversion extends JDBCConversion case object TimestampConversion extends JDBCConversion case object BinaryConversion extends JDBCConversion + case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion /** * Maps a StructType to a type tag list. */ - def getConversions(schema: StructType): Array[JDBCConversion] = { - schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Fixed(p, s) => DecimalConversion(p, s) - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => - if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") - }).toArray + def getConversions(schema: StructType): Array[JDBCConversion] = + schema.fields.map(sf => getConversions(sf.dataType, sf.metadata)) + + private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match { + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } /** @@ -420,16 +422,44 @@ private[sql] class JDBCRDD( mutableRow.update(i, null) } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) - case BinaryLongConversion => { + case BinaryLongConversion => val bytes = rs.getBytes(pos) var ans = 0L var j = 0 while (j < bytes.size) { ans = 256 * ans + (255 & bytes(j)) - j = j + 1; + j = j + 1 } mutableRow.setLong(i, ans) - } + case ArrayConversion(elementConversion) => + val array = rs.getArray(pos).getArray + if (array != null) { + val data = elementConversion match { + case TimestampConversion => + array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => + nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) + } + case StringConversion => + array.asInstanceOf[Array[java.lang.String]] + .map(UTF8String.fromString) + case DateConversion => + array.asInstanceOf[Array[java.sql.Date]].map { date => + nullSafeConvert(date, DateTimeUtils.fromJavaDate) + } + case DecimalConversion(p, s) => + array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal => + nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s)) + } + case BinaryLongConversion => + throw new IllegalArgumentException(s"Unsupported array element conversion $i") + case _: ArrayConversion => + throw new IllegalArgumentException("Nested arrays unsupported") + case _ => array.asInstanceOf[Array[Any]] + } + mutableRow.update(i, new GenericArrayData(data)) + } else { + mutableRow.update(i, null) + } } if (rs.wasNull) mutableRow.setNullAt(i) i = i + 1 @@ -488,4 +518,12 @@ private[sql] class JDBCRDD( nextValue } } + + private def nullSafeConvert[T](input: T, f: T => Any): Any = { + if (input == null) { + null + } else { + f(input) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index f89d55b20e212..a14457a8d2a60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.util.Try import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.jdbc.{JdbcType, JdbcDialects} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} @@ -72,6 +72,30 @@ object JdbcUtils extends Logging { conn.prepareStatement(sql.toString()) } + /** + * Retrieve standard jdbc types. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The default JdbcType for this DataType + */ + def getCommonJDBCType(dt: DataType): Option[JdbcType] = { + dt match { + case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) + case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) + case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) + case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) + case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) + case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) + case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) + case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) + case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) + case t: DecimalType => Option( + JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) + case _ => None + } + } + /** * Saves a partition of a DataFrame to the JDBC database. This is done in * a single database transaction in order to avoid repeatedly inserting @@ -91,7 +115,7 @@ object JdbcUtils extends Logging { table: String, iterator: Iterator[Row], rddSchema: StructType, - nullTypes: Array[Int], + jdbcTypes: Array[JdbcType], batchSize: Int): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -106,7 +130,7 @@ object JdbcUtils extends Logging { var i = 0 while (i < numFields) { if (row.isNullAt(i)) { - stmt.setNull(i + 1, nullTypes(i)) + stmt.setNull(i + 1, jdbcTypes(i).jdbcNullType) } else { rddSchema.fields(i).dataType match { case IntegerType => stmt.setInt(i + 1, row.getInt(i)) @@ -121,6 +145,12 @@ object JdbcUtils extends Logging { case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case ArrayType(et, _) => + assert(jdbcTypes(i).databaseTypeDefinition.endsWith("[]")) + val array = conn.createArrayOf( + jdbcTypes(i).databaseTypeDefinition.dropRight(2).toLowerCase, + row.getSeq[AnyRef](i).toArray) + stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( s"Can't translate non-null value for field $i") } @@ -170,22 +200,9 @@ object JdbcUtils extends Logging { df.schema.fields foreach { field => { val name = field.name val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition) + .orElse(getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition)) + .getOrElse(throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -201,24 +218,11 @@ object JdbcUtils extends Logging { table: String, properties: Properties = new Properties()) { val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) + val jdbcTypes: Array[JdbcType] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType) + .orElse(getCommonJDBCType(field.dataType)) + .getOrElse( + throw new IllegalArgumentException(s"Can't get JDBC type for field $field")) } val rddSchema = df.schema @@ -226,7 +230,7 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) + savePartition(getConnection, table, iterator, rddSchema, jdbcTypes, batchSize) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 467d8d62d1b7f..1b6f90e3a1119 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -34,8 +34,12 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, precision, scale, md)).headOption } override def getJDBCType(dt: DataType): Option[JdbcType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c38c..27b61f1582399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -27,7 +27,11 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) Option(FloatType) else None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 14bfea4e3e287..ba0f4dff4aa6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -64,13 +64,18 @@ abstract class JdbcDialect { * Get the custom datatype mapping for the given jdbc meta information. * @param sqlType The sql type (see java.sql.Types) * @param typeName The sql type name (e.g. "BIGINT UNSIGNED") - * @param size The size of the type. + * @param precision The precision of the type. + * @param scale The scale of the type. * @param md Result metadata associated with this type. * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]]) * or null if the default type mapping should be used. */ def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = None /** * Retrieve the jdbc / sql type for a given datatype. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 3eb722b070d5d..c3751f9a5459d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -25,7 +25,11 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = { if (typeName.contains("datetimeoffset")) { // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients Option(StringType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index da413ed1f08b5..50de79fbae1c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -27,8 +27,12 @@ private case object MySQLDialect extends JdbcDialect { override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && precision != 1) { // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as // byte arrays instead of longs. md.putLong("binarylong", 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 4165c382689f9..da77749acf5ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -27,16 +27,20 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = { // Handle NUMBER fields that have no precision/scale in special way // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale // For more details, please see // https://github.com/apache/spark/pull/8780#issuecomment-145598968 // and // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - if (sqlType == Types.NUMERIC && size == 0) { + if (sqlType == Types.NUMERIC && precision == 0) { // This is sub-optimal as we have to pick a precision/scale in advance whereas the data - // in Oracle is allowed to have different precision/scale for each value. + // in Oracle is allowed to have different precision/scale for each value. Option(DecimalType(DecimalType.MAX_PRECISION, 10)) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index e701a7fcd9e16..0851917b7fcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types._ @@ -27,24 +28,50 @@ private object PostgresDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Option(BinaryType) - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("json")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { - Option(StringType) + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && precision != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER) { + toCatalystType(typeName, precision, scale).filter(_ == StringType) + } else if (sqlType == Types.ARRAY && typeName(0) == '_') { + toCatalystType(typeName.drop(1), precision, scale).map(ArrayType(_)) } else None } + // TODO: support more type names. + private def toCatalystType( + typeName: String, + precision: Int, + scale: Int): Option[DataType] = typeName match { + case "bool" => Some(BooleanType) + case "bit" => Some(BinaryType) + case "int2" => Some(ShortType) + case "int4" => Some(IntegerType) + case "int8" | "oid" => Some(LongType) + case "float4" => Some(FloatType) + case "money" | "float8" => Some(DoubleType) + case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" => + Some(StringType) + case "bytea" => Some(BinaryType) + case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) + case "date" => Some(DateType) + case "numeric" if precision != 0 || scale != 0 => Some(DecimalType(precision, scale)) + case "numeric" => Some(DecimalType.SYSTEM_DEFAULT) + case _ => None + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) - case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case StringType => Some(JdbcType("TEXT", Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) + case ArrayType(et, _) if et.isInstanceOf[AtomicType] => + getJDBCType(et).map(_.databaseTypeDefinition) + .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) + .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d530b1a469ce2..99c5055df0152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -41,7 +41,11 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val testH2Dialect = new JdbcDialect { override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = Some(StringType) } @@ -437,7 +441,11 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val agg = new AggregatedDialect(List(new JdbcDialect { override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + sqlType: Int, + typeName: String, + precision: Int, + scale: Int, + md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { Some(LongType) } else { @@ -446,8 +454,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) - assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + assert(agg.getCatalystType(0, "", 1, 0, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, 0, null) === Some(StringType)) } test("DB2Dialect type mapping") { @@ -458,8 +466,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("PostgresDialect type mapping") { val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) - assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, 0, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, 0, null) === Some(StringType)) } test("DerbyDialect jdbc type mapping") { From fbaa5438625548a9d2e0eb84af0501dce497288d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Nov 2015 23:31:18 +0800 Subject: [PATCH 2/3] fix mima and address comments --- project/MimaExcludes.scala | 8 ++++- .../datasources/jdbc/JdbcUtils.scala | 31 +++++++++---------- .../apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 50220790d1f84..5d84d24351c6d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -106,7 +106,13 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.sql.SQLContext.setSession"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.createSession") + "org.apache.spark.sql.SQLContext.createSession"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.jdbc.MySQLDialect.getCatalystType"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.jdbc.AggregatedDialect.getCatalystType"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.jdbc.PostgresDialect.getCatalystType") ) ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.preferredNodeLocationData_="), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index a14457a8d2a60..32d28e59377a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.util.Try import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.{JdbcType, JdbcDialects} +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} @@ -96,6 +96,11 @@ object JdbcUtils extends Logging { } } + private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { + dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( + throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) + } + /** * Saves a partition of a DataFrame to the JDBC database. This is done in * a single database transaction in order to avoid repeatedly inserting @@ -115,8 +120,9 @@ object JdbcUtils extends Logging { table: String, iterator: Iterator[Row], rddSchema: StructType, - jdbcTypes: Array[JdbcType], - batchSize: Int): Iterator[Byte] = { + nullTypes: Array[Int], + batchSize: Int, + dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false try { @@ -130,7 +136,7 @@ object JdbcUtils extends Logging { var i = 0 while (i < numFields) { if (row.isNullAt(i)) { - stmt.setNull(i + 1, jdbcTypes(i).jdbcNullType) + stmt.setNull(i + 1, nullTypes(i)) } else { rddSchema.fields(i).dataType match { case IntegerType => stmt.setInt(i + 1, row.getInt(i)) @@ -146,9 +152,8 @@ object JdbcUtils extends Logging { case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) case ArrayType(et, _) => - assert(jdbcTypes(i).databaseTypeDefinition.endsWith("[]")) val array = conn.createArrayOf( - jdbcTypes(i).databaseTypeDefinition.dropRight(2).toLowerCase, + getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase, row.getSeq[AnyRef](i).toArray) stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( @@ -199,10 +204,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => { val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition) - .orElse(getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition)) - .getOrElse(throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")) + val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -218,11 +220,8 @@ object JdbcUtils extends Logging { table: String, properties: Properties = new Properties()) { val dialect = JdbcDialects.get(url) - val jdbcTypes: Array[JdbcType] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType) - .orElse(getCommonJDBCType(field.dataType)) - .getOrElse( - throw new IllegalArgumentException(s"Can't get JDBC type for field $field")) + val nullTypes: Array[Int] = df.schema.fields.map { field => + getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema @@ -230,7 +229,7 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, jdbcTypes, batchSize) + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index ba0f4dff4aa6c..9e7fad9dc4498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) * for the given Catalyst type. */ @DeveloperApi -abstract class JdbcDialect { +abstract class JdbcDialect extends Serializable { /** * Check if this dialect instance can handle a certain jdbc url. * @param url the jdbc url. From ad52183193aee624952f961f831a0c569818b0e2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 17 Nov 2015 11:36:28 +0800 Subject: [PATCH 3/3] keep the public API unchanged --- project/MimaExcludes.scala | 8 +------- .../execution/datasources/jdbc/JDBCRDD.scala | 8 ++++---- .../spark/sql/jdbc/AggregatedDialect.scala | 8 ++------ .../apache/spark/sql/jdbc/DerbyDialect.scala | 6 +----- .../apache/spark/sql/jdbc/JdbcDialects.scala | 9 ++------- .../spark/sql/jdbc/MsSqlServerDialect.scala | 6 +----- .../apache/spark/sql/jdbc/MySQLDialect.scala | 8 ++------ .../apache/spark/sql/jdbc/OracleDialect.scala | 10 +++------- .../spark/sql/jdbc/PostgresDialect.scala | 20 ++++++------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 20 ++++++------------- 10 files changed, 28 insertions(+), 75 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 5d84d24351c6d..50220790d1f84 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -106,13 +106,7 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.sql.SQLContext.setSession"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.createSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.jdbc.MySQLDialect.getCatalystType"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.jdbc.AggregatedDialect.getCatalystType"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.jdbc.PostgresDialect.getCatalystType") + "org.apache.spark.sql.SQLContext.createSession") ) ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.preferredNodeLocationData_="), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index d944a5e1033dc..89c850ce238d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -130,14 +130,14 @@ private[sql] object JDBCRDD extends Logging { val columnName = rsmd.getColumnLabel(i + 1) val dataType = rsmd.getColumnType(i + 1) val typeName = rsmd.getColumnTypeName(i + 1) - val precision = rsmd.getPrecision(i + 1) - val scale = rsmd.getScale(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) val columnType = - dialect.getCatalystType(dataType, typeName, precision, scale, metadata).getOrElse( - getCatalystType(dataType, precision, scale, isSigned)) + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 1b6f90e3a1119..467d8d62d1b7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -34,12 +34,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = { - dialects.flatMap(_.getCatalystType(sqlType, typeName, precision, scale, md)).headOption + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption } override def getJDBCType(dt: DataType): Option[JdbcType] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 27b61f1582399..84f68e779c38c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -27,11 +27,7 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = { + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) Option(FloatType) else None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 9e7fad9dc4498..b3b2cb6178c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -64,18 +64,13 @@ abstract class JdbcDialect extends Serializable { * Get the custom datatype mapping for the given jdbc meta information. * @param sqlType The sql type (see java.sql.Types) * @param typeName The sql type name (e.g. "BIGINT UNSIGNED") - * @param precision The precision of the type. - * @param scale The scale of the type. + * @param size The size of the type. * @param md Result metadata associated with this type. * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]]) * or null if the default type mapping should be used. */ def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = None + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None /** * Retrieve the jdbc / sql type for a given datatype. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index c3751f9a5459d..3eb722b070d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -25,11 +25,7 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = { + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (typeName.contains("datetimeoffset")) { // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients Option(StringType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index 50de79fbae1c8..da413ed1f08b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -27,12 +27,8 @@ private case object MySQLDialect extends JdbcDialect { override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && precision != 1) { + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as // byte arrays instead of longs. md.putLong("binarylong", 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index da77749acf5ba..4165c382689f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -27,20 +27,16 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = { + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { // Handle NUMBER fields that have no precision/scale in special way // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale // For more details, please see // https://github.com/apache/spark/pull/8780#issuecomment-145598968 // and // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - if (sqlType == Types.NUMERIC && precision == 0) { + if (sqlType == Types.NUMERIC && size == 0) { // This is sub-optimal as we have to pick a precision/scale in advance whereas the data - // in Oracle is allowed to have different precision/scale for each value. + // in Oracle is allowed to have different precision/scale for each value. Option(DecimalType(DecimalType.MAX_PRECISION, 10)) } else { None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 0851917b7fcce..ed3faa1268635 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -28,25 +28,18 @@ private object PostgresDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && precision != 1) { + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { Some(BinaryType) } else if (sqlType == Types.OTHER) { - toCatalystType(typeName, precision, scale).filter(_ == StringType) - } else if (sqlType == Types.ARRAY && typeName(0) == '_') { - toCatalystType(typeName.drop(1), precision, scale).map(ArrayType(_)) + toCatalystType(typeName).filter(_ == StringType) + } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') { + toCatalystType(typeName.drop(1)).map(ArrayType(_)) } else None } // TODO: support more type names. - private def toCatalystType( - typeName: String, - precision: Int, - scale: Int): Option[DataType] = typeName match { + private def toCatalystType(typeName: String): Option[DataType] = typeName match { case "bool" => Some(BooleanType) case "bit" => Some(BinaryType) case "int2" => Some(ShortType) @@ -59,7 +52,6 @@ private object PostgresDialect extends JdbcDialect { case "bytea" => Some(BinaryType) case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) case "date" => Some(DateType) - case "numeric" if precision != 0 || scale != 0 => Some(DecimalType(precision, scale)) case "numeric" => Some(DecimalType.SYSTEM_DEFAULT) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 99c5055df0152..d530b1a469ce2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -41,11 +41,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val testH2Dialect = new JdbcDialect { override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) } @@ -441,11 +437,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val agg = new AggregatedDialect(List(new JdbcDialect { override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def getCatalystType( - sqlType: Int, - typeName: String, - precision: Int, - scale: Int, - md: MetadataBuilder): Option[DataType] = + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { Some(LongType) } else { @@ -454,8 +446,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0, "", 1, 0, null) === Some(LongType)) - assert(agg.getCatalystType(1, "", 1, 0, null) === Some(StringType)) + assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } test("DB2Dialect type mapping") { @@ -466,8 +458,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("PostgresDialect type mapping") { val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") - assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, 0, null) === Some(StringType)) - assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, 0, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) + assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) } test("DerbyDialect jdbc type mapping") {