# [SPARK-10186][SQL] support postgre array type in JDBCRDD #9662
Changes from 1 commit
**JdbcUtils.scala**

```diff
@@ -23,7 +23,7 @@ import java.util.Properties
 import scala.util.Try
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcType, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
```
```diff
@@ -72,6 +72,30 @@ object JdbcUtils extends Logging {
     conn.prepareStatement(sql.toString())
   }
 
+  /**
+   * Retrieve standard jdbc types.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database. This is done in
    * a single database transaction in order to avoid repeatedly inserting
```
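The new `getCommonJDBCType` centralizes the default type mappings that were previously duplicated inline at two call sites (removed in hunks further down). A minimal usage sketch, assuming the method stays public and that `JdbcUtils` lives in `org.apache.spark.sql.execution.datasources.jdbc` (the package is not shown on this page):

```scala
// Hypothetical caller; the import path is an assumption, see the lead-in above.
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._

// Each supported Catalyst type maps to a database type definition plus the
// java.sql.Types constant that is later passed to PreparedStatement.setNull.
val str = JdbcUtils.getCommonJDBCType(StringType)
println(str.map(t => s"${t.databaseTypeDefinition} / ${t.jdbcNullType}"))
// Some(TEXT / 2005)   -- 2005 == java.sql.Types.CLOB

val dec = JdbcUtils.getCommonJDBCType(DecimalType(10, 2))
println(dec.map(_.databaseTypeDefinition))  // Some(DECIMAL(10,2))
```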
```diff
@@ -91,7 +115,7 @@ object JdbcUtils extends Logging {
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
-      nullTypes: Array[Int],
+      jdbcTypes: Array[JdbcType],
       batchSize: Int): Iterator[Byte] = {
     val conn = getConnection()
     var committed = false
```
```diff
@@ -106,7 +130,7 @@ object JdbcUtils extends Logging {
             var i = 0
             while (i < numFields) {
               if (row.isNullAt(i)) {
-                stmt.setNull(i + 1, nullTypes(i))
+                stmt.setNull(i + 1, jdbcTypes(i).jdbcNullType)
               } else {
                 rddSchema.fields(i).dataType match {
                   case IntegerType => stmt.setInt(i + 1, row.getInt(i))
```
```diff
@@ -121,6 +145,12 @@ object JdbcUtils extends Logging {
                   case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
                   case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
                   case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
+                  case ArrayType(et, _) =>
+                    assert(jdbcTypes(i).databaseTypeDefinition.endsWith("[]"))
+                    val array = conn.createArrayOf(
+                      jdbcTypes(i).databaseTypeDefinition.dropRight(2).toLowerCase,
+                      row.getSeq[AnyRef](i).toArray)
+                    stmt.setArray(i + 1, array)
                   case _ => throw new IllegalArgumentException(
                     s"Can't translate non-null value for field $i")
                 }
```

**Member** (on the `assert(...endsWith("[]"))` line): Is that the same in all backends that support arrays (Oracle etc)?

**Contributor (Author)**: You are right, it's not always working.
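The reviewer's concern is that the trailing `[]` is PostgreSQL syntax: Oracle, for instance, uses separately declared named array types (`VARRAY`), so deriving the element type name by dropping the last two characters only works for Postgres-style type definitions. A minimal sketch of the Postgres-only case this code path assumes, using plain JDBC (URL, table, and column names are placeholders):

```scala
import java.sql.DriverManager

// PostgreSQL writes an array column's type as "<element>[]", e.g. "text[]";
// JDBC's createArrayOf wants only the element type name ("text").
val conn = DriverManager.getConnection("jdbc:postgresql://localhost/testdb") // placeholder
try {
  val typeDef = "text[]"
  val elementType = typeDef.dropRight(2).toLowerCase  // same trick as the patch
  val arr = conn.createArrayOf(elementType, Array[AnyRef]("a", "b", "c"))
  val stmt = conn.prepareStatement("INSERT INTO people (tags) VALUES (?)")
  stmt.setArray(1, arr)
  stmt.executeUpdate()
} finally {
  conn.close()
}
```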
```diff
@@ -170,22 +200,9 @@ object JdbcUtils extends Logging {
     df.schema.fields foreach { field => {
       val name = field.name
       val typ: String =
-        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
-          field.dataType match {
-            case IntegerType => "INTEGER"
-            case LongType => "BIGINT"
-            case DoubleType => "DOUBLE PRECISION"
-            case FloatType => "REAL"
-            case ShortType => "INTEGER"
-            case ByteType => "BYTE"
-            case BooleanType => "BIT(1)"
-            case StringType => "TEXT"
-            case BinaryType => "BLOB"
-            case TimestampType => "TIMESTAMP"
-            case DateType => "DATE"
-            case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
-            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-          })
+        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition)
+          .orElse(getCommonJDBCType(field.dataType).map(_.databaseTypeDefinition))
+          .getOrElse(throw new IllegalArgumentException(s"Don't know how to save $field to JDBC"))
       val nullable = if (field.nullable) "" else "NOT NULL"
       sb.append(s", $name $typ $nullable")
     }}
```
```diff
@@ -201,32 +218,19 @@ object JdbcUtils extends Logging {
       table: String,
       properties: Properties = new Properties()) {
     val dialect = JdbcDialects.get(url)
-    val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
-        field.dataType match {
-          case IntegerType => java.sql.Types.INTEGER
-          case LongType => java.sql.Types.BIGINT
-          case DoubleType => java.sql.Types.DOUBLE
-          case FloatType => java.sql.Types.REAL
-          case ShortType => java.sql.Types.INTEGER
-          case ByteType => java.sql.Types.INTEGER
-          case BooleanType => java.sql.Types.BIT
-          case StringType => java.sql.Types.CLOB
-          case BinaryType => java.sql.Types.BLOB
-          case TimestampType => java.sql.Types.TIMESTAMP
-          case DateType => java.sql.Types.DATE
-          case t: DecimalType => java.sql.Types.DECIMAL
-          case _ => throw new IllegalArgumentException(
-            s"Can't translate null value for field $field")
-        })
+    val jdbcTypes: Array[JdbcType] = df.schema.fields.map { field =>
+      dialect.getJDBCType(field.dataType)
+        .orElse(getCommonJDBCType(field.dataType))
+        .getOrElse(
+          throw new IllegalArgumentException(s"Can't get JDBC type for field $field"))
     }
 
     val rddSchema = df.schema
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+      savePartition(getConnection, table, iterator, rddSchema, jdbcTypes, batchSize)
     }
   }
```
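Taken together, this is the code path behind a DataFrame write to JDBC. A hedged end-to-end sketch for the Spark 1.x API of this era, with placeholder connection details; note that `getCommonJDBCType` does not cover arrays, so writing an array column still requires a dialect that maps `ArrayType` to a `<element>[]` style definition:

```scala
import java.util.Properties
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("jdbc-write").setMaster("local[*]"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// A DataFrame with an array column (schema: name: string, tags: array<string>).
val df = sc.parallelize(Seq(("alice", Seq("scala", "sql")))).toDF("name", "tags")

val props = new Properties()
props.setProperty("user", "spark")      // placeholder credentials
props.setProperty("password", "secret")
props.setProperty("batchsize", "500")   // consumed by savePartition above

df.write.jdbc("jdbc:postgresql://localhost/testdb", "people", props)  // placeholder URL
```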
**JdbcDialects.scala**
```diff
@@ -64,13 +64,18 @@ abstract class JdbcDialect {
    * Get the custom datatype mapping for the given jdbc meta information.
    * @param sqlType The sql type (see java.sql.Types)
    * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
-   * @param size The size of the type.
+   * @param precision The precision of the type.
+   * @param scale The scale of the type.
    * @param md Result metadata associated with this type.
    * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
    *         or null if the default type mapping should be used.
    */
   def getCatalystType(
-      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
+      sqlType: Int,
+      typeName: String,
+      precision: Int,
+      scale: Int,
+      md: MetadataBuilder): Option[DataType] = None
 
   /**
    * Retrieve the jdbc / sql type for a given datatype.
```
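To illustrate the widened signature, here is a hedged sketch of a custom dialect overriding the new five-argument `getCatalystType`. The class name and mappings are illustrative only, not part of this patch; the underscore-prefixed `_text` is how PostgreSQL reports a `text[]` column's type name over JDBC:

```scala
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types._

// Illustrative dialect; a real PostgreSQL mapping would handle more types.
case object MyPostgresDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")

  override def getCatalystType(
      sqlType: Int,
      typeName: String,
      precision: Int,
      scale: Int,
      md: MetadataBuilder): Option[DataType] = {
    if (sqlType == java.sql.Types.NUMERIC && precision > 0) {
      // Now possible: the scale is available alongside the precision.
      Some(DecimalType(precision, scale))
    } else if (sqlType == java.sql.Types.ARRAY && typeName == "_text") {
      Some(ArrayType(StringType))
    } else {
      None  // fall back to the default mapping
    }
  }
}
```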
**Member**: Is this just `Literal.create(null, a.dataType)`?

**Contributor (Author)**: Ah yeah, we can simplify this.
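For reference on the suggestion above: `Literal.create(value, dataType)` builds a typed Catalyst literal, so a typed null needs no manual construction. A minimal sketch against the Catalyst expression API of this era (an internal API, so stability is not guaranteed):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.StringType

// A null literal that still carries a concrete Catalyst data type.
val nullLit = Literal.create(null, StringType)
assert(nullLit.dataType == StringType)
assert(nullLit.value == null)
```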