Skip to content
Closed
Prev Previous commit
Next Next commit
WIP: First pass at function to generate shema for "VARCHAR(N)"
  • Loading branch information
Aerlinger committed Aug 3, 2015
commit 1f75db89d5aee8d3451f46ad532a86f8e2faa3e2
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private[redshift] object Conversions {
def setStrLength(metadata:Metadata, length:Int) : Metadata = {
new MetadataBuilder()
.withMetadata(metadata)
.putDouble("maxLength", length)
.putLong("maxLength", length)
.build()
}

Expand Down
20 changes: 11 additions & 9 deletions src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package com.databricks.spark.redshift
import java.sql.{Connection, SQLException}
import java.util.Properties

import org.apache.spark.sql.types._

import scala.util.Random

import com.databricks.spark.redshift.Parameters.MergedParameters
Expand All @@ -35,8 +37,8 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
/**
* Generate CREATE TABLE statement for Redshift
*/
def createTableSql(data: DataFrame, params: MergedParameters) : String = {
val schemaSql = jdbcWrapper.schemaString(data, params.jdbcUrl)
def createTableSql(data: DataFrame, params: MergedParameters): String = {
val schemaSql = jdbcWrapper.schemaString(data, params.jdbcUrl) // TODO: Replace
val distStyleDef = params.distStyle match {
case Some(style) => s"DISTSTYLE $style"
case None => ""
Expand All @@ -63,7 +65,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Sets up a staging table then runs the given action, passing the temporary table name
* as a parameter.
*/
def withStagingTable(conn:Connection, params: MergedParameters, action: (String) => Unit) {
def withStagingTable(conn: Connection, params: MergedParameters, action: (String) => Unit) {
val randomSuffix = Math.abs(Random.nextInt()).toString
val tempTable = s"${params.table}_staging_$randomSuffix"
val backupTable = s"${params.table}_backup_$randomSuffix"
Expand Down Expand Up @@ -93,10 +95,10 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Perform the Redshift load, including deletion of existing data in the case of an overwrite,
* and creating the table if it doesn't already exist.
*/
def doRedshiftLoad(conn: Connection, data: DataFrame, params: MergedParameters) : Unit = {
def doRedshiftLoad(conn: Connection, data: DataFrame, params: MergedParameters): Unit = {

// Overwrites must drop the table, in case there has been a schema update
if(params.overwrite) {
if (params.overwrite) {
val deleteExisting = conn.prepareStatement(s"DROP TABLE IF EXISTS ${params.table}")
deleteExisting.execute()
}
Expand All @@ -114,7 +116,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {

// Execute postActions
params.postActions.foreach(action => {
val actionSql = if(action.contains("%s")) action.format(params.table) else action
val actionSql = if (action.contains("%s")) action.format(params.table) else action
log.info("Executing postAction: " + actionSql)
conn.prepareStatement(actionSql).execute()
})
Expand All @@ -132,13 +134,13 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
/**
* Write a DataFrame to a Redshift table, using S3 and Avro serialization
*/
def saveToRedshift(sqlContext: SQLContext, data: DataFrame, params: MergedParameters) : Unit = {
def saveToRedshift(sqlContext: SQLContext, data: DataFrame, params: MergedParameters): Unit = {
val conn = jdbcWrapper.getConnector(params.jdbcDriver, params.jdbcUrl, new Properties()).apply()

try {
if(params.overwrite && params.useStagingTable) {
if (params.overwrite && params.useStagingTable) {
withStagingTable(conn, params, table => {
val updatedParams = MergedParameters(params.parameters updated ("dbtable", table))
val updatedParams = MergedParameters(params.parameters updated("dbtable", table))
unloadData(sqlContext, data, updatedParams.tempPath)
doRedshiftLoad(conn, data, updatedParams)
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ class RedshiftSourceSuite

val dfMetaSchema = Conversions.injectMetaSchema(testSqlContext, df)

// assert(dfMetaSchema.schema("testString").metadata.getDouble("maxLength") == 1.0)
assert(dfMetaSchema.schema("testString").metadata.getLong("maxLength") == 10)
}

test("DefaultSource has default constructor, required by Data Source API") {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package com.databricks.spark.redshift

import java.io.File

import org.apache.spark.SparkContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SQLContext, Row}
import org.scalamock.scalatest.MockFactory
import org.scalatest.{BeforeAndAfterAll, Matchers, FunSuite}

class SchemaGenerationSuite extends FunSuite with Matchers with MockFactory with BeforeAndAfterAll {
/**
* Expected parsed output corresponding to the output of testData.
*/
val expectedData =
Array(
Row(1.toByte, true, TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0), 1234152.123124981,
1.0f, 42, 1239012341823719L, 23, "Unicode是樂趣", TestUtils.toTimestamp(2015, 6, 1, 0, 0, 0, 1)),
Row(1.toByte, false, TestUtils.toTimestamp(2015, 6, 2, 0, 0, 0), 0.0, 0.0f, 42, 1239012341823719L, -13, "asdf",
TestUtils.toTimestamp(2015, 6, 2, 0, 0, 0, 0)),
Row(0.toByte, null, TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0), 0.0, -1.0f, 4141214, 1239012341823719L, null, "f",
TestUtils.toTimestamp(2015, 6, 3, 0, 0, 0)),
Row(0.toByte, false, null, -1234152.123124981, 100000.0f, null, 1239012341823719L, 24, "___|_123", null),
Row(List.fill(10)(null): _*))

var sc: SparkContext = _
var testSqlContext: SQLContext = _
var df: DataFrame = _

def varcharCol(meta:Metadata): String = {
val maxLength:Long = meta.getLong("maxLength")

maxLength match {
case _:Long => s"VARCHAR($maxLength)"
case _ => "VARCHAR(255)"
}
}

/**
* Compute the schema string for this RDD.
*/
def schemaString(df: DataFrame): String = {
val sb = new StringBuilder()

df.schema.fields foreach { field => {
val name = field.name
val typ: String =
field match {
case StructField(_, IntegerType, _, _) => "INTEGER"
case StructField(_, LongType, _, _) => "BIGINT"
case StructField(_, DoubleType, _, _) => "DOUBLE PRECISION"
case StructField(_, FloatType, _, _) => "REAL"
case StructField(_, ShortType, _, _) => "INTEGER"
case StructField(_, ByteType, _, _) => "BYTE"
case StructField(_, BooleanType, _, _) => "BOOLEAN"
case StructField(_, StringType, _, metadata) => varcharCol(metadata)
case StructField(_, BinaryType, _, _) => "BLOB"
case StructField(_, TimestampType, _, _) => "TIMESTAMP"
case StructField(_, DateType, _, _) => "DATE"
case StructField(_, t:DecimalType, _, _) => s"DECIMAL(${t.precision}},${t.scale}})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
}
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}
}
if (sb.length < 2) "" else sb.substring(2)
}

override def beforeAll(): Unit = {
super.beforeAll()

sc = new TestContext
testSqlContext = new SQLContext(sc)

df = testSqlContext.createDataFrame(sc.parallelize(expectedData), TestUtils.testSchema)
}

override def afterAll(): Unit = {
sc.stop()
super.afterAll()
}

test("Schema inference") {
val enhancedDf:DataFrame = Conversions.injectMetaSchema(testSqlContext, df)

schemaString(enhancedDf) should equal("testByte BYTE , testBool BOOLEAN , testDate DATE , testDouble DOUBLE PRECISION , testFloat REAL , testInt INTEGER , testLong BIGINT , testShort INTEGER , testString VARCHAR(10) , testTimestamp TIMESTAMP ")
}
}