
Commit 88dadd9

cloud-fan authored and kiszk committed
[SPARK-10186][SQL] support postgre array type in JDBCRDD
Add ARRAY support to `PostgresDialect`. Nested ARRAY is not allowed for now because it's hard to get the array dimension info. See http://stackoverflow.com/questions/16619113/how-to-get-array-base-type-in-postgres-via-jdbc

Thanks for the initial work from mariusvniekerk! Close apache/spark#9137

Author: Wenchen Fan <[email protected]>

Closes #9662 from cloud-fan/postgre.
1 parent 2480086 commit 88dadd9
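
From the user's side, the change looks roughly like this (a minimal sketch, not part of the patch; the connection URL, credentials, and table layout are hypothetical, and a Spark 1.x-style `sqlContext` is assumed in scope):

```scala
import java.util.Properties
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// Hypothetical URL; assumes a reachable PostgreSQL with a table bar(..., c10 integer[]).
val url = "jdbc:postgresql://localhost:5432/foo?user=postgres&password=secret"
val df = sqlContext.read.jdbc(url, "bar", new Properties)

// With this patch an integer[] column resolves to ArrayType(IntegerType)
// instead of failing schema resolution; values come back as Seq.
assert(df.schema("c10").dataType == ArrayType(IntegerType))
df.collect().foreach(row => println(row.getSeq[Int](row.fieldIndex("c10"))))
```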

File tree

5 files changed: +157 −85 lines changed


docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala

Lines changed: 28 additions & 16 deletions
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
 import java.sql.Connection
 import java.util.Properties
 
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Literal, If}
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
   override def dataPreparation(conn: Connection): Unit = {
     conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
     conn.setCatalog("foo")
-    conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
-      + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+    conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+      + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+      + "c10 integer[], c11 text[])").executeUpdate()
     conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
-      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+      + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
   }
 
   test("Type mapping for various types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
-    val types = rows(0).toSeq.map(x => x.getClass.toString)
-    assert(types.length == 10)
-    assert(types(0).equals("class java.lang.String"))
-    assert(types(1).equals("class java.lang.Integer"))
-    assert(types(2).equals("class java.lang.Double"))
-    assert(types(3).equals("class java.lang.Long"))
-    assert(types(4).equals("class java.lang.Boolean"))
-    assert(types(5).equals("class [B"))
-    assert(types(6).equals("class [B"))
-    assert(types(7).equals("class java.lang.Boolean"))
-    assert(types(8).equals("class java.lang.String"))
-    assert(types(9).equals("class java.lang.String"))
+    val types = rows(0).toSeq.map(x => x.getClass)
+    assert(types.length == 12)
+    assert(classOf[String].isAssignableFrom(types(0)))
+    assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
+    assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
+    assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
+    assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
+    assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
+    assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
+    assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
+    assert(classOf[String].isAssignableFrom(types(8)))
+    assert(classOf[String].isAssignableFrom(types(9)))
+    assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
+    assert(classOf[Seq[String]].isAssignableFrom(types(11)))
     assert(rows(0).getString(0).equals("hello"))
     assert(rows(0).getInt(1) == 42)
     assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(rows(0).getBoolean(7) == true)
     assert(rows(0).getString(8) == "172.16.0.42")
     assert(rows(0).getString(9) == "192.168.0.0/16")
+    assert(rows(0).getSeq(10) == Seq(1, 2))
+    assert(rows(0).getSeq(11) == Seq("a", null, "b"))
   }
 
   test("Basic write test") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
-    df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
     // Test only that it doesn't crash.
+    df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
+    // Test write null values.
+    df.select(df.queryExecution.analyzed.output.map { a =>
+      Column(If(Literal(true), Literal(null), a)).as(a.name)
+    }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
   }
 }
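
Note on the null-write test above: `If(Literal(true), Literal(null), a)` relies on the analyzer's type coercion to turn every column into a null of its original type, exercising the null branch of the write path. Using only public API, roughly the same projection could be written as (a sketch):

```scala
import org.apache.spark.sql.functions.lit

// Replace every column with a typed null, keeping names and the schema shape.
val allNulls = df.select(df.schema.fields.map { f =>
  lit(null).cast(f.dataType).as(f.name)
}: _*)
```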

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 57 additions & 19 deletions
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -324,25 +324,27 @@
   case object StringConversion extends JDBCConversion
   case object TimestampConversion extends JDBCConversion
   case object BinaryConversion extends JDBCConversion
+  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
 
   /**
    * Maps a StructType to a type tag list.
    */
-  def getConversions(schema: StructType): Array[JDBCConversion] = {
-    schema.fields.map(sf => sf.dataType match {
-      case BooleanType => BooleanConversion
-      case DateType => DateConversion
-      case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
-      case DoubleType => DoubleConversion
-      case FloatType => FloatConversion
-      case IntegerType => IntegerConversion
-      case LongType =>
-        if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
-      case StringType => StringConversion
-      case TimestampType => TimestampConversion
-      case BinaryType => BinaryConversion
-      case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
-    }).toArray
+  def getConversions(schema: StructType): Array[JDBCConversion] =
+    schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
+
+  private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
+    case BooleanType => BooleanConversion
+    case DateType => DateConversion
+    case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
+    case DoubleType => DoubleConversion
+    case FloatType => FloatConversion
+    case IntegerType => IntegerConversion
+    case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
+    case StringType => StringConversion
+    case TimestampType => TimestampConversion
+    case BinaryType => BinaryConversion
+    case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
   }
 
   /**
@@ -420,16 +422,44 @@
             mutableRow.update(i, null)
           }
         case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-        case BinaryLongConversion => {
+        case BinaryLongConversion =>
          val bytes = rs.getBytes(pos)
          var ans = 0L
          var j = 0
          while (j < bytes.size) {
            ans = 256 * ans + (255 & bytes(j))
-            j = j + 1;
+            j = j + 1
          }
          mutableRow.setLong(i, ans)
-        }
+        case ArrayConversion(elementConversion) =>
+          val array = rs.getArray(pos).getArray
+          if (array != null) {
+            val data = elementConversion match {
+              case TimestampConversion =>
+                array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+                  nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+                }
+              case StringConversion =>
+                array.asInstanceOf[Array[java.lang.String]]
+                  .map(UTF8String.fromString)
+              case DateConversion =>
+                array.asInstanceOf[Array[java.sql.Date]].map { date =>
+                  nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+                }
+              case DecimalConversion(p, s) =>
+                array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+                  nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
+                }
+              case BinaryLongConversion =>
+                throw new IllegalArgumentException(s"Unsupported array element conversion $i")
+              case _: ArrayConversion =>
+                throw new IllegalArgumentException("Nested arrays unsupported")
+              case _ => array.asInstanceOf[Array[Any]]
+            }
+            mutableRow.update(i, new GenericArrayData(data))
+          } else {
+            mutableRow.update(i, null)
+          }
       }
       if (rs.wasNull) mutableRow.setNullAt(i)
       i = i + 1
@@ -488,4 +518,12 @@
       nextValue
     }
   }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
 }
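
To make the element-level conversion concrete (a standalone sketch, not patch code): for a `timestamp[]` column the driver hands back `Array[java.sql.Timestamp]`, and `ArrayConversion` maps each element null-safely into Catalyst's microsecond `Long` representation before wrapping the result in a `GenericArrayData`:

```scala
import org.apache.spark.sql.catalyst.util.DateTimeUtils

val raw: Array[java.sql.Timestamp] =
  Array(java.sql.Timestamp.valueOf("2015-11-17 10:30:00"), null)

// Mirrors nullSafeConvert + DateTimeUtils.fromJavaTimestamp from the patch:
// null elements stay null, others become microseconds since the epoch (Long).
val converted: Array[Any] =
  raw.map(t => if (t == null) null else DateTimeUtils.fromJavaTimestamp(t))
```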

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 40 additions & 37 deletions
@@ -23,7 +23,7 @@ import java.util.Properties
 import scala.util.Try
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -72,6 +72,35 @@
     conn.prepareStatement(sql.toString())
   }
 
+  /**
+   * Retrieve standard jdbc types.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
+
+  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
+      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database. This is done in
    * a single database transaction in order to avoid repeatedly inserting
@@ -92,7 +121,8 @@
       iterator: Iterator[Row],
       rddSchema: StructType,
       nullTypes: Array[Int],
-      batchSize: Int): Iterator[Byte] = {
+      batchSize: Int,
+      dialect: JdbcDialect): Iterator[Byte] = {
     val conn = getConnection()
     var committed = false
     try {
@@ -121,6 +151,11 @@
               case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
               case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
              case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
+              case ArrayType(et, _) =>
+                val array = conn.createArrayOf(
+                  getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
+                  row.getSeq[AnyRef](i).toArray)
+                stmt.setArray(i + 1, array)
               case _ => throw new IllegalArgumentException(
                 s"Can't translate non-null value for field $i")
             }
@@ -169,23 +204,7 @@
     val dialect = JdbcDialects.get(url)
     df.schema.fields foreach { field => {
       val name = field.name
-      val typ: String =
-        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
-          field.dataType match {
-            case IntegerType => "INTEGER"
-            case LongType => "BIGINT"
-            case DoubleType => "DOUBLE PRECISION"
-            case FloatType => "REAL"
-            case ShortType => "INTEGER"
-            case ByteType => "BYTE"
-            case BooleanType => "BIT(1)"
-            case StringType => "TEXT"
-            case BinaryType => "BLOB"
-            case TimestampType => "TIMESTAMP"
-            case DateType => "DATE"
-            case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
-            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-          })
+      val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
     }}
@@ -202,31 +221,15 @@
       properties: Properties = new Properties()) {
     val dialect = JdbcDialects.get(url)
     val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
-        field.dataType match {
-          case IntegerType => java.sql.Types.INTEGER
-          case LongType => java.sql.Types.BIGINT
-          case DoubleType => java.sql.Types.DOUBLE
-          case FloatType => java.sql.Types.REAL
-          case ShortType => java.sql.Types.INTEGER
-          case ByteType => java.sql.Types.INTEGER
-          case BooleanType => java.sql.Types.BIT
-          case StringType => java.sql.Types.CLOB
-          case BinaryType => java.sql.Types.BLOB
-          case TimestampType => java.sql.Types.TIMESTAMP
-          case DateType => java.sql.Types.DATE
-          case t: DecimalType => java.sql.Types.DECIMAL
-          case _ => throw new IllegalArgumentException(
-            s"Can't translate null value for field $field")
-        })
+      getJdbcType(field.dataType, dialect).jdbcNullType
     }
 
     val rddSchema = df.schema
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
     }
   }
 
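
The new lookup order on the write path is dialect first, then the common fallback; a quick sketch of the fallback itself:

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType}

// Default mapping used when the dialect has no override.
assert(JdbcUtils.getCommonJDBCType(IntegerType)
  .exists(_.databaseTypeDefinition == "INTEGER"))
// Unmapped types yield None; the private getJdbcType then throws
// IllegalArgumentException("Can't get JDBC type for ...").
assert(JdbcUtils.getCommonJDBCType(CalendarIntervalType).isEmpty)
```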

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
  * for the given Catalyst type.
  */
 @DeveloperApi
-abstract class JdbcDialect {
+abstract class JdbcDialect extends Serializable {
   /**
    * Check if this dialect instance can handle a certain jdbc url.
    * @param url the jdbc url.
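
This change is needed because `savePartition` now takes the dialect as a parameter and is invoked inside `df.foreachPartition`, so dialect instances are captured in closures that Spark serializes and ships to executors.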

sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala

Lines changed: 31 additions & 12 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
 
 import java.sql.Types
 
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types._
 
 
@@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect {
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
-      Option(BinaryType)
-    } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
-      Option(StringType)
-    } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
-      Option(StringType)
-    } else if (sqlType == Types.OTHER && typeName.equals("json")) {
-      Option(StringType)
-    } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
-      Option(StringType)
+      Some(BinaryType)
+    } else if (sqlType == Types.OTHER) {
+      toCatalystType(typeName).filter(_ == StringType)
+    } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
+      toCatalystType(typeName.drop(1)).map(ArrayType(_))
     } else None
   }
 
+  // TODO: support more type names.
+  private def toCatalystType(typeName: String): Option[DataType] = typeName match {
+    case "bool" => Some(BooleanType)
+    case "bit" => Some(BinaryType)
+    case "int2" => Some(ShortType)
+    case "int4" => Some(IntegerType)
+    case "int8" | "oid" => Some(LongType)
+    case "float4" => Some(FloatType)
+    case "money" | "float8" => Some(DoubleType)
+    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
+      Some(StringType)
+    case "bytea" => Some(BinaryType)
+    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
+    case "date" => Some(DateType)
+    case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
+    case _ => None
+  }
+
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
-    case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
-    case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
-    case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+    case StringType => Some(JdbcType("TEXT", Types.CHAR))
+    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
+    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
+    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
+      getJDBCType(et).map(_.databaseTypeDefinition)
+        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
+        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
     case _ => None
   }
 
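
A round-trip sketch of the new mappings (the PostgreSQL JDBC driver reports array columns with a `_`-prefixed element type name, e.g. `_int4` for `integer[]`):

```scala
import java.sql.Types
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.{ArrayType, IntegerType, MetadataBuilder}

val dialect = JdbcDialects.get("jdbc:postgresql://localhost/foo")

// Read side: the reported element type "_int4" becomes ArrayType(IntegerType).
assert(dialect.getCatalystType(Types.ARRAY, "_int4", 0, new MetadataBuilder()) ==
  Some(ArrayType(IntegerType)))

// Write side: the dialect has no Postgres-specific mapping for IntegerType,
// so the common fallback ("INTEGER") is suffixed with "[]".
assert(dialect.getJDBCType(ArrayType(IntegerType))
  .exists(_.databaseTypeDefinition == "INTEGER[]"))
```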
