
Commit 378c5a9

support array type in postgresql
1 parent b8ff688 commit 378c5a9

File tree

11 files changed, +222 -108 lines changed

docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala

Lines changed: 28 additions & 16 deletions
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
 import java.sql.Connection
 import java.util.Properties
 
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Literal, If}
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
   override def dataPreparation(conn: Connection): Unit = {
     conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
     conn.setCatalog("foo")
-    conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
-      + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+    conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+      + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+      + "c10 integer[], c11 text[])").executeUpdate()
     conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
-      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+      + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+      + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
   }
 
   test("Type mapping for various types") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
     val rows = df.collect()
     assert(rows.length == 1)
-    val types = rows(0).toSeq.map(x => x.getClass.toString)
-    assert(types.length == 10)
-    assert(types(0).equals("class java.lang.String"))
-    assert(types(1).equals("class java.lang.Integer"))
-    assert(types(2).equals("class java.lang.Double"))
-    assert(types(3).equals("class java.lang.Long"))
-    assert(types(4).equals("class java.lang.Boolean"))
-    assert(types(5).equals("class [B"))
-    assert(types(6).equals("class [B"))
-    assert(types(7).equals("class java.lang.Boolean"))
-    assert(types(8).equals("class java.lang.String"))
-    assert(types(9).equals("class java.lang.String"))
+    val types = rows(0).toSeq.map(x => x.getClass)
+    assert(types.length == 12)
+    assert(classOf[String].isAssignableFrom(types(0)))
+    assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
+    assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
+    assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
+    assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
+    assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
+    assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
+    assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
+    assert(classOf[String].isAssignableFrom(types(8)))
+    assert(classOf[String].isAssignableFrom(types(9)))
+    assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
+    assert(classOf[Seq[String]].isAssignableFrom(types(11)))
     assert(rows(0).getString(0).equals("hello"))
     assert(rows(0).getInt(1) == 42)
     assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(rows(0).getBoolean(7) == true)
     assert(rows(0).getString(8) == "172.16.0.42")
     assert(rows(0).getString(9) == "192.168.0.0/16")
+    assert(rows(0).getSeq(10) == Seq(1, 2))
+    assert(rows(0).getSeq(11) == Seq("a", null, "b"))
   }
 
   test("Basic write test") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
-    df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
     // Test only that it doesn't crash.
+    df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
+    // Test write null values.
+    df.select(df.queryExecution.analyzed.output.map { a =>
+      Column(If(Literal(true), Literal(null), a)).as(a.name)
+    }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
   }
 }

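For orientation, a minimal sketch of what the new mapping looks like from the DataFrame side, assuming a reachable PostgreSQL instance with the bar table from dataPreparation above and a sqlContext in scope (the connection URL is a placeholder):

import java.util.Properties

// Hypothetical connection URL; any PostgreSQL database holding the "bar"
// table created in dataPreparation above would do.
val jdbcUrl = "jdbc:postgresql://localhost:5432/foo?user=postgres&password=secret"

val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
df.printSchema()
// With this patch, c10 (integer[]) should show up as array<int> and
// c11 (text[]) as array<string> instead of failing to resolve.

val first = df.collect()(0)
val ints = first.getSeq[Int](10)       // Seq(1, 2)
val texts = first.getSeq[String](11)   // Seq("a", null, "b")
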
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 61 additions & 23 deletions
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -130,14 +130,14 @@ private[sql] object JDBCRDD extends Logging {
       val columnName = rsmd.getColumnLabel(i + 1)
       val dataType = rsmd.getColumnType(i + 1)
       val typeName = rsmd.getColumnTypeName(i + 1)
-      val fieldSize = rsmd.getPrecision(i + 1)
-      val fieldScale = rsmd.getScale(i + 1)
+      val precision = rsmd.getPrecision(i + 1)
+      val scale = rsmd.getScale(i + 1)
       val isSigned = rsmd.isSigned(i + 1)
       val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
       val metadata = new MetadataBuilder().putString("name", columnName)
       val columnType =
-        dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
-          getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+        dialect.getCatalystType(dataType, typeName, precision, scale, metadata).getOrElse(
+          getCatalystType(dataType, precision, scale, isSigned))
       fields(i) = StructField(columnName, columnType, nullable, metadata.build())
       i = i + 1
     }
@@ -324,25 +324,27 @@ private[sql] class JDBCRDD(
   case object StringConversion extends JDBCConversion
   case object TimestampConversion extends JDBCConversion
   case object BinaryConversion extends JDBCConversion
+  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
 
   /**
    * Maps a StructType to a type tag list.
    */
-  def getConversions(schema: StructType): Array[JDBCConversion] = {
-    schema.fields.map(sf => sf.dataType match {
-      case BooleanType => BooleanConversion
-      case DateType => DateConversion
-      case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
-      case DoubleType => DoubleConversion
-      case FloatType => FloatConversion
-      case IntegerType => IntegerConversion
-      case LongType =>
-        if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
-      case StringType => StringConversion
-      case TimestampType => TimestampConversion
-      case BinaryType => BinaryConversion
-      case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
-    }).toArray
+  def getConversions(schema: StructType): Array[JDBCConversion] =
+    schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
+
+  private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
+    case BooleanType => BooleanConversion
+    case DateType => DateConversion
+    case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
+    case DoubleType => DoubleConversion
+    case FloatType => FloatConversion
+    case IntegerType => IntegerConversion
+    case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
+    case StringType => StringConversion
+    case TimestampType => TimestampConversion
+    case BinaryType => BinaryConversion
+    case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
   }
 
   /**
@@ -420,16 +422,44 @@ private[sql] class JDBCRDD(
                mutableRow.update(i, null)
              }
           case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-          case BinaryLongConversion => {
+          case BinaryLongConversion =>
             val bytes = rs.getBytes(pos)
             var ans = 0L
             var j = 0
             while (j < bytes.size) {
               ans = 256 * ans + (255 & bytes(j))
-              j = j + 1;
+              j = j + 1
             }
             mutableRow.setLong(i, ans)
-          }
+          case ArrayConversion(elementConversion) =>
+            val array = rs.getArray(pos).getArray
+            if (array != null) {
+              val data = elementConversion match {
+                case TimestampConversion =>
+                  array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+                    nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+                  }
+                case StringConversion =>
+                  array.asInstanceOf[Array[java.lang.String]]
+                    .map(UTF8String.fromString)
+                case DateConversion =>
+                  array.asInstanceOf[Array[java.sql.Date]].map { date =>
+                    nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+                  }
+                case DecimalConversion(p, s) =>
+                  array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+                    nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
+                  }
+                case BinaryLongConversion =>
+                  throw new IllegalArgumentException(s"Unsupported array element conversion $i")
+                case _: ArrayConversion =>
+                  throw new IllegalArgumentException("Nested arrays unsupported")
+                case _ => array.asInstanceOf[Array[Any]]
+              }
+              mutableRow.update(i, new GenericArrayData(data))
+            } else {
+              mutableRow.update(i, null)
+            }
         }
         if (rs.wasNull) mutableRow.setNullAt(i)
         i = i + 1
@@ -488,4 +518,12 @@ private[sql] class JDBCRDD(
       nextValue
     }
   }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
 }

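The element-level handling above boils down to a null-safe map over the JDBC array followed by wrapping in GenericArrayData. A self-contained sketch of that pattern for a numeric(10,2)[] column (the sample input is made up; Decimal and GenericArrayData are the same classes the patch imports):

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.Decimal

// Same shape as the nullSafeConvert helper introduced above.
def nullSafeConvert[T](input: T, f: T => Any): Any =
  if (input == null) null else f(input)

// rs.getArray(pos).getArray for a numeric(10,2)[] column arrives as
// Array[java.math.BigDecimal]; nulls inside the array must survive conversion.
val raw: Array[java.math.BigDecimal] =
  Array(new java.math.BigDecimal("1.50"), null, new java.math.BigDecimal("2.25"))

val data = raw.map(d => nullSafeConvert[java.math.BigDecimal](d, x => Decimal(x, 10, 2)))
val catalystValue = new GenericArrayData(data) // what ends up in the mutable row
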
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala

Lines changed: 42 additions & 38 deletions
@@ -23,7 +23,7 @@ import java.util.Properties
 import scala.util.Try
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcType, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
 
@@ -72,6 +72,30 @@ object JdbcUtils extends Logging {
     conn.prepareStatement(sql.toString())
   }
 
+  /**
+   * Retrieve standard jdbc types.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database. This is done in
    * a single database transaction in order to avoid repeatedly inserting
@@ -91,7 +115,7 @@ object JdbcUtils extends Logging {
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
-      nullTypes: Array[Int],
+      jdbcTypes: Array[JdbcType],
       batchSize: Int): Iterator[Byte] = {
     val conn = getConnection()
     var committed = false
@@ -106,7 +130,7 @@ object JdbcUtils extends Logging {
         var i = 0
         while (i < numFields) {
           if (row.isNullAt(i)) {
-            stmt.setNull(i + 1, nullTypes(i))
+            stmt.setNull(i + 1, jdbcTypes(i).jdbcNullType)
           } else {
             rddSchema.fields(i).dataType match {
               case IntegerType => stmt.setInt(i + 1, row.getInt(i))
@@ -121,6 +145,12 @@ object JdbcUtils extends Logging {
               case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
               case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
               case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
+              case ArrayType(et, _) =>
+                assert(jdbcTypes(i).databaseTypeDefinition.endsWith("[]"))
+                val array = conn.createArrayOf(
+                  jdbcTypes(i).databaseTypeDefinition.dropRight(2).toLowerCase,
+                  row.getSeq[AnyRef](i).toArray)
+                stmt.setArray(i + 1, array)
               case _ => throw new IllegalArgumentException(
                 s"Can't translate non-null value for field $i")
             }
@@ -170,22 +200,9 @@ object JdbcUtils extends Logging {
     df.schema.fields foreach { field => {
       val name = field.name
       val typ: String =
-        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
-          field.dataType match {
-            case IntegerType => "INTEGER"
-            case LongType => "BIGINT"
-            case DoubleType => "DOUBLE PRECISION"
-            case FloatType => "REAL"
-            case ShortType => "INTEGER"
-            case ByteType => "BYTE"
-            case BooleanType => "BIT(1)"
-            case StringType => "TEXT"
-            case BinaryType => "BLOB"
-            case TimestampType => "TIMESTAMP"
-            case DateType => "DATE"
-            case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
-            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-          })
+        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition)
+          .orElse(getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition))
+          .getOrElse(throw new IllegalArgumentException(s"Don't know how to save $field to JDBC"))
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
     }}
@@ -201,32 +218,19 @@ object JdbcUtils extends Logging {
       table: String,
       properties: Properties = new Properties()) {
     val dialect = JdbcDialects.get(url)
-    val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
-        field.dataType match {
-          case IntegerType => java.sql.Types.INTEGER
-          case LongType => java.sql.Types.BIGINT
-          case DoubleType => java.sql.Types.DOUBLE
-          case FloatType => java.sql.Types.REAL
-          case ShortType => java.sql.Types.INTEGER
-          case ByteType => java.sql.Types.INTEGER
-          case BooleanType => java.sql.Types.BIT
-          case StringType => java.sql.Types.CLOB
-          case BinaryType => java.sql.Types.BLOB
-          case TimestampType => java.sql.Types.TIMESTAMP
-          case DateType => java.sql.Types.DATE
-          case t: DecimalType => java.sql.Types.DECIMAL
-          case _ => throw new IllegalArgumentException(
-            s"Can't translate null value for field $field")
-        })
+    val jdbcTypes: Array[JdbcType] = df.schema.fields.map { field =>
+      dialect.getJDBCType(field.dataType)
+        .orElse(getCommonJDBCType(field.dataType))
+        .getOrElse(
+          throw new IllegalArgumentException(s"Can't get JDBC type for field $field"))
     }
 
     val rddSchema = df.schema
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+      savePartition(getConnection, table, iterator, rddSchema, jdbcTypes, batchSize)
     }
   }
 

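The write path now resolves a field's JDBC type by asking the dialect first and falling back to these shared defaults. A sketch of that lookup for a text[] field (the URL is a placeholder; the exact PostgreSQL array definition comes from PostgresDialect, which is not shown in this excerpt):

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.{ArrayType, StringType, StructField}

// Hypothetical field and URL, just to show the resolution order.
val dialect = JdbcDialects.get("jdbc:postgresql://localhost:5432/foo")
val field = StructField("c11", ArrayType(StringType))

val jdbcType = dialect.getJDBCType(field.dataType)
  .orElse(JdbcUtils.getCommonJDBCType(field.dataType))
  .getOrElse(throw new IllegalArgumentException(s"Can't get JDBC type for field $field"))

// For a PostgreSQL text[] column the dialect is expected to answer with a
// definition ending in "[]" (e.g. "TEXT[]"); savePartition relies on that when
// it calls conn.createArrayOf(definition.dropRight(2).toLowerCase, values).
println(jdbcType.databaseTypeDefinition)
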
sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala

Lines changed: 6 additions & 2 deletions
@@ -34,8 +34,12 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
     dialects.map(_.canHandle(url)).reduce(_ && _)
 
   override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
-    dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption
+      sqlType: Int,
+      typeName: String,
+      precision: Int,
+      scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
+    dialects.flatMap(_.getCatalystType(sqlType, typeName, precision, scale, md)).headOption
   }
 
   override def getJDBCType(dt: DataType): Option[JdbcType] = {

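The aggregated dialect keeps its resolution rule under the wider signature: the first dialect that returns a mapping wins. A small self-contained sketch with two made-up dialects (AggregatedDialect itself is private, so this only reproduces its flatMap/headOption logic):

import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder}

// Two invented dialects: the first declines (default getCatalystType returns None),
// the second answers.
val silent = new JdbcDialect {
  override def canHandle(url: String): Boolean = true
}
val opinionated = new JdbcDialect {
  override def canHandle(url: String): Boolean = true
  override def getCatalystType(
      sqlType: Int, typeName: String, precision: Int, scale: Int,
      md: MetadataBuilder): Option[DataType] = Some(LongType)
}

val resolved = List(silent, opinionated)
  .flatMap(_.getCatalystType(java.sql.Types.BIGINT, "BIGINT", 20, 0, new MetadataBuilder()))
  .headOption  // Some(LongType): the first non-empty answer wins
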
sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala

Lines changed: 5 additions & 1 deletion
@@ -27,7 +27,11 @@ private object DerbyDialect extends JdbcDialect {
   override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")
 
   override def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+      sqlType: Int,
+      typeName: String,
+      precision: Int,
+      scale: Int,
+      md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.REAL) Option(FloatType) else None
   }
 

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 7 additions & 2 deletions
@@ -64,13 +64,18 @@ abstract class JdbcDialect {
    * Get the custom datatype mapping for the given jdbc meta information.
    * @param sqlType The sql type (see java.sql.Types)
    * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
-   * @param size The size of the type.
+   * @param precision The precision of the type.
+   * @param scale The scale of the type.
    * @param md Result metadata associated with this type.
    * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
    * or null if the default type mapping should be used.
    */
   def getCatalystType(
-    sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
+      sqlType: Int,
+      typeName: String,
+      precision: Int,
+      scale: Int,
+      md: MetadataBuilder): Option[DataType] = None
 
   /**
    * Retrieve the jdbc / sql type for a given datatype.

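The extra scale parameter is what lets a dialect tell integral NUMERIC columns apart from fractional ones. A sketch of a hypothetical third-party dialect written against the new signature (the jdbc:mydb prefix and the dialect itself are invented):

import java.sql.Types
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.{DataType, DecimalType, LongType, MetadataBuilder}

object MyNumericDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def getCatalystType(
      sqlType: Int,
      typeName: String,
      precision: Int,
      scale: Int,
      md: MetadataBuilder): Option[DataType] = {
    if (sqlType == Types.NUMERIC && scale == 0 && precision > 0 && precision <= 18) {
      // An exact integral NUMERIC fits comfortably in a LongType.
      Some(LongType)
    } else if (sqlType == Types.NUMERIC && precision > 0 && precision <= 38) {
      Some(DecimalType(precision, scale))
    } else {
      None  // fall back to the default mapping
    }
  }
}

JdbcDialects.registerDialect(MyNumericDialect)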