support array type in postgresql
cloud-fan committed Nov 12, 2015
commit 378c5a9b414ff5e82f2c1bb0f3963279c2b4f929
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.Connection
import java.util.Properties

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Literal, If}
import org.apache.spark.tags.DockerTest

@DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+ "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+ "c10 integer[], c11 text[])").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+ """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
}

test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 10)
assert(types(0).equals("class java.lang.String"))
assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Double"))
assert(types(3).equals("class java.lang.Long"))
assert(types(4).equals("class java.lang.Boolean"))
assert(types(5).equals("class [B"))
assert(types(6).equals("class [B"))
assert(types(7).equals("class java.lang.Boolean"))
assert(types(8).equals("class java.lang.String"))
assert(types(9).equals("class java.lang.String"))
val types = rows(0).toSeq.map(x => x.getClass)
assert(types.length == 12)
assert(classOf[String].isAssignableFrom(types(0)))
assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
assert(classOf[String].isAssignableFrom(types(8)))
assert(classOf[String].isAssignableFrom(types(9)))
assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getBoolean(7) == true)
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
assert(rows(0).getSeq(10) == Seq(1, 2))
assert(rows(0).getSeq(11) == Seq("a", null, "b"))
}

test("Basic write test") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test only that it doesn't crash.
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test write null values.
df.select(df.queryExecution.analyzed.output.map { a =>
Column(If(Literal(true), Literal(null), a)).as(a.name)

Contributor:
Is this just Literal.create(null, a.dataType)?

Contributor (Author):
ah yeah, we can simplify this.

}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
}
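
For reference, a minimal sketch of the simplification suggested in the review thread above: projecting each column as a typed null via Literal.create(null, a.dataType) instead of If(Literal(true), Literal(null), a). It is illustrative only (not part of this commit) and assumes the same df, jdbcUrl and package-private access as the test itself.

import java.util.Properties
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Literal

// Build a frame with the same schema as df but every value null, then write it.
val nullDf = df.select(df.queryExecution.analyzed.output.map { a =>
  // Literal.create(null, a.dataType) is a null literal that keeps the column's type,
  // so the JDBC writer still sees the original schema.
  Column(Literal.create(null, a.dataType)).as(a.name)
}: _*)
nullDf.write.jdbc(jdbcUrl, "public.barcopy2", new Properties)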
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -130,14 +130,14 @@ private[sql] object JDBCRDD extends Logging {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val precision = rsmd.getPrecision(i + 1)
val scale = rsmd.getScale(i + 1)
val isSigned = rsmd.isSigned(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
dialect.getCatalystType(dataType, typeName, precision, scale, metadata).getOrElse(
getCatalystType(dataType, precision, scale, isSigned))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
@@ -324,25 +324,27 @@ private[sql] class JDBCRDD(
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion

/**
* Maps a StructType to a type tag list.
*/
def getConversions(schema: StructType): Array[JDBCConversion] = {
schema.fields.map(sf => sf.dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType =>
if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
}).toArray
def getConversions(schema: StructType): Array[JDBCConversion] =
schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))

private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
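
A small illustrative sketch (assuming it runs inside JDBCRDD, where getConversions and the conversion objects are in scope) of how the array columns from the PostgreSQL test above map onto the new ArrayConversion:

import org.apache.spark.sql.types._

// Schema as produced for the Postgres columns c10 integer[] and c11 text[].
val schema = StructType(Seq(
  StructField("c10", ArrayType(IntegerType)),
  StructField("c11", ArrayType(StringType))))

// getConversions recurses into the element type, yielding
// Array(ArrayConversion(IntegerConversion), ArrayConversion(StringConversion)).
val conversions = getConversions(schema)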

/**
@@ -420,16 +422,44 @@ private[sql] class JDBCRDD(
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion => {
case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1;
j = j + 1
}
mutableRow.setLong(i, ans)
}
case ArrayConversion(elementConversion) =>
val array = rs.getArray(pos).getArray
if (array != null) {
val data = elementConversion match {
case TimestampConversion =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringConversion =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateConversion =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case DecimalConversion(p, s) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
}
case BinaryLongConversion =>
throw new IllegalArgumentException(s"Unsupported array element conversion $i")
case _: ArrayConversion =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => array.asInstanceOf[Array[Any]]
}
mutableRow.update(i, new GenericArrayData(data))
} else {
mutableRow.update(i, null)
}
}
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
@@ -488,4 +518,12 @@ private[sql] class JDBCRDD(
nextValue
}
}

private def nullSafeConvert[T](input: T, f: T => Any): Any = {
if (input == null) {
null
} else {
f(input)
}
}
}
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try

import org.apache.spark.Logging
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.jdbc.{JdbcType, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

@@ -72,6 +72,30 @@ object JdbcUtils extends Logging {
conn.prepareStatement(sql.toString())
}

/**
* Retrieve standard jdbc types.
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The default JdbcType for this DataType
*/
def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
dt match {
case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
case t: DecimalType => Option(
JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
case _ => None
}
}

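A quick usage sketch of the new helper; everything except getCommonJDBCType itself is illustrative:

import org.apache.spark.sql.types.{DecimalType, StringType}

// Default mappings used when a dialect provides no override.
val decimalDdl = getCommonJDBCType(DecimalType(10, 2)).map(_.databaseTypeDefinition)
// decimalDdl == Some("DECIMAL(10,2)")
val stringNullType = getCommonJDBCType(StringType).map(_.jdbcNullType)
// stringNullType == Some(java.sql.Types.CLOB)
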
/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction in order to avoid repeatedly inserting
@@ -91,7 +115,7 @@ object JdbcUtils extends Logging {
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
jdbcTypes: Array[JdbcType],
batchSize: Int): Iterator[Byte] = {
val conn = getConnection()
var committed = false
@@ -106,7 +130,7 @@ object JdbcUtils extends Logging {
var i = 0
while (i < numFields) {
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
stmt.setNull(i + 1, jdbcTypes(i).jdbcNullType)
} else {
rddSchema.fields(i).dataType match {
case IntegerType => stmt.setInt(i + 1, row.getInt(i))
@@ -121,6 +145,12 @@ object JdbcUtils extends Logging {
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case ArrayType(et, _) =>
assert(jdbcTypes(i).databaseTypeDefinition.endsWith("[]"))

Member:
Is that the same in all backends that support arrays (Oracle etc)?

Contributor (Author):
you are right, it's not always working

val array = conn.createArrayOf(
jdbcTypes(i).databaseTypeDefinition.dropRight(2).toLowerCase,
row.getSeq[AnyRef](i).toArray)
stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
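
As the thread above notes, deriving the element type by dropping a trailing "[]" only happens to work for PostgreSQL. A hedged sketch of an alternative that resolves the element type through the dialect instead, assuming the dialect (or a precomputed element JdbcType) were made available to savePartition and that this helper lived next to getCommonJDBCType in JdbcUtils; the name setArrayValue is hypothetical:

import java.sql.{Connection, PreparedStatement}
import org.apache.spark.sql.Row
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.DataType

// Resolve the backend's element type name from the element DataType rather than
// string-stripping "[]" from the array type definition.
def setArrayValue(
    conn: Connection,
    stmt: PreparedStatement,
    dialect: JdbcDialect,
    row: Row,
    i: Int,
    elementType: DataType): Unit = {
  val elementTypeName = dialect.getJDBCType(elementType)
    .orElse(getCommonJDBCType(elementType))
    .map(_.databaseTypeDefinition)
    .getOrElse(throw new IllegalArgumentException(
      s"Can't get JDBC type for array element type $elementType"))
  // java.sql.Connection.createArrayOf takes the backend-specific element type name.
  val array = conn.createArrayOf(elementTypeName.toLowerCase, row.getSeq[AnyRef](i).toArray)
  stmt.setArray(i + 1, array)
}
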
@@ -170,22 +200,9 @@ object JdbcUtils extends Logging {
df.schema.fields foreach { field => {
val name = field.name
val typ: String =
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
case DoubleType => "DOUBLE PRECISION"
case FloatType => "REAL"
case ShortType => "INTEGER"
case ByteType => "BYTE"
case BooleanType => "BIT(1)"
case StringType => "TEXT"
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
})
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition)
.orElse(getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition))
.getOrElse(throw new IllegalArgumentException(s"Don't know how to save $field to JDBC"))
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -201,32 +218,19 @@ object JdbcUtils extends Logging {
table: String,
properties: Properties = new Properties()) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case DoubleType => java.sql.Types.DOUBLE
case FloatType => java.sql.Types.REAL
case ShortType => java.sql.Types.INTEGER
case ByteType => java.sql.Types.INTEGER
case BooleanType => java.sql.Types.BIT
case StringType => java.sql.Types.CLOB
case BinaryType => java.sql.Types.BLOB
case TimestampType => java.sql.Types.TIMESTAMP
case DateType => java.sql.Types.DATE
case t: DecimalType => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
})
val jdbcTypes: Array[JdbcType] = df.schema.fields.map { field =>
dialect.getJDBCType(field.dataType)
.orElse(getCommonJDBCType(field.dataType))
.getOrElse(
throw new IllegalArgumentException(s"Can't get JDBC type for field $field"))
}

val rddSchema = df.schema
val driver: String = DriverRegistry.getDriverClassName(url)
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
savePartition(getConnection, table, iterator, rddSchema, jdbcTypes, batchSize)
}
}

@@ -34,8 +34,12 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect
dialects.map(_.canHandle(url)).reduce(_ && _)

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption
sqlType: Int,
typeName: String,
precision: Int,
scale: Int,
md: MetadataBuilder): Option[DataType] = {
dialects.flatMap(_.getCatalystType(sqlType, typeName, precision, scale, md)).headOption
}

override def getJDBCType(dt: DataType): Option[JdbcType] = {
@@ -27,7 +27,11 @@ private object DerbyDialect extends JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby")

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
sqlType: Int,
typeName: String,
precision: Int,
scale: Int,
md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.REAL) Option(FloatType) else None
}

@@ -64,13 +64,18 @@ abstract class JdbcDialect {
* Get the custom datatype mapping for the given jdbc meta information.
* @param sqlType The sql type (see java.sql.Types)
* @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
* @param size The size of the type.
* @param precision The precision of the type.
* @param scale The scale of the type.
* @param md Result metadata associated with this type.
* @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
* or null if the default type mapping should be used.
*/
def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
sqlType: Int,
typeName: String,
precision: Int,
scale: Int,
md: MetadataBuilder): Option[DataType] = None

Contributor:
This is a public API; if we add fields here we are going to break any existing implementations. Do we have to do this?


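One hedged sketch of how the API concern above could be addressed without breaking third-party dialects: keep the existing four-argument getCatalystType as the extension point (here marked deprecated) and let the new overload delegate to it by default. This is illustrative only, not what this commit does:

import org.apache.spark.sql.types.{DataType, MetadataBuilder}

abstract class JdbcDialect {
  def canHandle(url: String): Boolean

  // Existing public signature, preserved so current dialect implementations
  // still compile and their overrides are still honoured.
  @deprecated("Override the overload that also takes a scale", "1.6.0")
  def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None

  // New signature; by default it delegates to the old one, so a dialect that
  // only overrides the four-argument version keeps working unchanged.
  def getCatalystType(
      sqlType: Int,
      typeName: String,
      precision: Int,
      scale: Int,
      md: MetadataBuilder): Option[DataType] =
    getCatalystType(sqlType, typeName, precision, md)
}
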
/**
* Retrieve the jdbc / sql type for a given datatype.