initial functionality and specs for Schema inference
Aerlinger committed Aug 3, 2015
commit 1fec19044d8f0e801d6ff712fcd529eec65c597b
49 changes: 46 additions & 3 deletions src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -38,7 +38,7 @@ private object RedshiftBooleanParser extends JavaTokenParsers {
/**
* Data type conversions for Redshift unloaded data
*/
private [redshift] object Conversions {
private[redshift] object Conversions {

// Imports and exports with Redshift require that timestamps are represented
// as strings, using the following formats
@@ -58,7 +58,7 @@ private [redshift] object Conversions {
}

override def parse(source: String, pos: ParsePosition): Date = {
if(source.length < PATTERN_WITH_MILLIS.length) {
if (source.length < PATTERN_WITH_MILLIS.length) {
redshiftTimestampFormatWithoutMillis.parse(source, pos)
} else {
redshiftTimestampFormatWithMillis.parse(source, pos)
@@ -127,4 +127,47 @@ private [redshift] object Conversions {

sqlContext.createDataFrame(df.rdd, schema)
}
}

def mapStrLengths(df: DataFrame): Map[String, Int] = {
  // Capture the schema on the driver so the closure below doesn't reference the DataFrame itself
  val schema = df.schema

  // Calculate the maximum string length observed in each string column
  val stringLengths = df.flatMap(row =>
    schema collect {
      case StructField(columnName, StringType, _, _) => (columnName, getStrLength(row, columnName))
    }
  ).reduceByKey(Math.max(_, _))

  stringLengths.collect().toMap
}

/** Length of the string in the named column, or 0 if the value is null. */
def getStrLength(row: Row, columnName: String): Int = {
  row.getAs[String](columnName) match {
    case field: String => field.length
    case _ => 0
  }
}

/** Record the maximum string length in the column's metadata under the "maxLength" key. */
def setStrLength(metadata: Metadata, length: Int): Metadata = {
  new MetadataBuilder()
    .withMetadata(metadata)
    .putDouble("maxLength", length)
    .build()
}

/**
 * For each string column in the schema, store the longest string length observed in that
 * column's metadata.
 */
def injectMetaSchema(sqlContext: SQLContext, df: DataFrame): DataFrame = {
val stringLengthsByColumn = mapStrLengths(df)

val schema = StructType(
df.schema map {
case StructField(name, StringType, nullable, meta) =>
StructField(name, StringType, nullable, setStrLength(meta, stringLengthsByColumn(name)))
case other => other
}
)

sqlContext.createDataFrame(df.rdd, schema)
}
}
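A rough usage sketch of the new helpers (illustrative only: the column name, the sample data, and the surrounding SparkContext `sc` are assumptions, and the call has to live inside the com.databricks.spark.redshift package because Conversions is package-private):

import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val sqlContext = new SQLContext(sc)
val schema = StructType(Seq(StructField("city", StringType, nullable = true)))
val rows = sc.parallelize(Seq(Row("Berlin"), Row("Boston"), Row(null)))
val df = sqlContext.createDataFrame(rows, schema)

// mapStrLengths yields Map("city" -> 6); injectMetaSchema stores that value under
// the "maxLength" metadata key of the "city" column.
val enriched = Conversions.injectMetaSchema(sqlContext, df)
enriched.schema("city").metadata.getDouble("maxLength") // 6.0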
@@ -124,7 +124,9 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Serialize temporary data to S3, ready for Redshift COPY
*/
def unloadData(sqlContext: SQLContext, data: DataFrame, tempPath: String): Unit = {
Conversions.datesToTimestamps(sqlContext, data).write.format("com.databricks.spark.avro").save(tempPath)
val enrichedData = Conversions.datesToTimestamps(sqlContext, data) // TODO .extractStringColumnLengths

enrichedData.write.format("com.databricks.spark.avro").save(tempPath)
}
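One way the TODO above might eventually be wired up (a sketch only, not part of this commit) is to chain the schema enrichment from Conversions.injectMetaSchema in front of the Avro write:

// Hypothetical follow-up to the TODO: inject string-length metadata before the Avro write.
def unloadData(sqlContext: SQLContext, data: DataFrame, tempPath: String): Unit = {
  val withTimestamps = Conversions.datesToTimestamps(sqlContext, data)
  val enrichedData = Conversions.injectMetaSchema(sqlContext, withTimestamps)

  enrichedData.write.format("com.databricks.spark.avro").save(tempPath)
}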

/**
@@ -172,7 +172,7 @@ class RedshiftSourceSuite
}

test("DefaultSource supports simple column filtering") {

// TODO: DRY ME
val params = Map("url" -> "jdbc:postgresql://foo/bar",
"tempdir" -> "tmp",
"dbtable" -> "test_table",
@@ -201,6 +201,7 @@

test("DefaultSource supports user schema, pruned and filtered scans") {

// TODO: DRY ME
val params = Map("url" -> "jdbc:postgresql://foo/bar",
"tempdir" -> "tmp",
"dbtable" -> "test_table",
@@ -478,9 +479,17 @@
}
}

test("Basic string field extraction") {
val rdd = sc.parallelize(expectedData)
val testSqlContext = new SQLContext(sc)
val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema)

val dfMetaSchema = Conversions.injectMetaSchema(testSqlContext, df)

// assert(dfMetaSchema.schema("testString").metadata.getDouble("maxLength") == 1.0)
}

test("DefaultSource has default constructor, required by Data Source API") {
new DefaultSource()
}


}