Add inline documentation to MetaSchema object
Aerlinger committed Aug 4, 2015
commit 4ce2d5ca18b359d0c92b67e1d2cfb5e071b428f6
50 changes: 43 additions & 7 deletions src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -35,11 +35,28 @@ private object RedshiftBooleanParser extends JavaTokenParsers {
def parseRedshiftBoolean(s: String): Boolean = parse(TRUE | FALSE, s).get
}

/**
* Utility methods responsible for extracting information from the data contained within a DataFrame in order to
* generate a schema compatible with Redshift.
*/
object MetaSchema {
/**
* Map-reduce task to calculate the longest string length in each string column of the DataFrame.
*
* Note: This is used to generate N for the VARCHAR(N) field in the table schema to be loaded into Redshift.
*
* TODO: This should only be called once per load into Redshift. A cache, TraversableOnce, or some similar
* structure should be used to enforce this function only being called once.
*
* @param df DataFrame to be processed
* @return A Map[String, Int] representing an association between each column name and the length of that column's
* longest string
*/
private[redshift] def mapStrLengths(df: DataFrame): Map[String, Int] = {
val schema: StructType = df.schema

// For each row, filter the string columns and calculate the string length
// TODO: Other optimization strategies may be possible
val stringLengths = df.flatMap(row =>
schema.collect {
case StructField(columnName, StringType, _, _) => (columnName, getStrLength(row, columnName))
@@ -49,24 +66,43 @@ object StringMetaSchema {
stringLengths.collect().toMap
}
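
The reduce step that folds the per-row (columnName, length) pairs into a per-column maximum is collapsed in the diff above. As a rough standalone sketch of the same flatMap-then-reduce shape, with invented column names and data, and plain Maps standing in for Rows:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical sketch: emit (columnName, length) pairs for every row,
// then keep the maximum length seen for each column.
object MaxStrLengthSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val rows = sc.parallelize(Seq(
      Map("city" -> "Boston", "state" -> "MA"),
      Map("city" -> "San Francisco", "state" -> "CA")))

    val maxLengths: Map[String, Int] = rows
      .flatMap(row => row.map { case (col, value) => (col, value.length) })
      .reduceByKey(math.max(_, _))   // keep the longest length per column
      .collect()
      .toMap

    println(maxLengths) // Map(city -> 13, state -> 2)
    sc.stop()
  }
}
```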

/**
* Calculate the length of the string in columnName for the provided Row. Defensively returns 0 if the provided
* columnName is not a string column.
*
* This is a helper method that exists to make mapStrLengths more readable, and should not be used elsewhere.
*
* @param row Reference to a row of a DataFrame
* @param columnName Name of the column
* @return Length of the string in the cell, falling back to 0 if it is null or not a string
*/
private[redshift] def getStrLength(row: Row, columnName: String): Int = {
row.getAs[String](columnName) match {
case field: String => field.length()
case _ => 0
}
}
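
For illustration, the reason the type pattern above is safe: Scala type patterns never match null, so a null cell falls through to the default case instead of throwing. A tiny hypothetical example:

```scala
// A null String does not match the `field: String` type pattern,
// so the fallback case returns 0 rather than raising a NullPointerException.
val cell: String = null
val length = cell match {
  case field: String => field.length()
  case _ => 0
}
assert(length == 0)
```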

/**
* Adds a "maxLength" -> Int field to a column's metadata.
*
* @param metadata Metadata for a DataFrame column
* @param length Length limit for content within that column
* @return A new Metadata object with the added field
*/
private[redshift] def setStrLength(metadata: Metadata, length: Int): Metadata = {
new MetadataBuilder().withMetadata(metadata).putLong("maxLength", length).build()
}
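
Metadata is immutable, so the builder copies the existing fields and returns a fresh object. A quick round-trip sketch:

```scala
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

// Copy any existing metadata, attach a maxLength field, and read it back.
val original: Metadata = Metadata.empty
val withLength: Metadata = new MetadataBuilder()
  .withMetadata(original)
  .putLong("maxLength", 10L)
  .build()

assert(withLength.getLong("maxLength") == 10L)
```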

/**
* Iterate through each string column in the schema, storing the longest string length in that column's
* metadata for later usage.
*/
def computeEnhancedDf(df: DataFrame): DataFrame = {
// 1. Perform a full scan of each string column, storing its maximum string length within a Map
val stringLengthsByColumn = mapStrLengths(df)

// 2. Generate an enhanced schema, with the metadata for each string column
val enhancedSchema = StructType(
df.schema map {
case StructField(name, StringType, nullable, meta) =>
@@ -75,7 +111,7 @@ object StringMetaSchema {
}
)

// 3. Construct a new DataFrame with a schema containing metadata with string lengths
df.sqlContext.createDataFrame(df.rdd, enhancedSchema)
}
}
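
Taken together, a hypothetical end-to-end use of computeEnhancedDf would leave each string column's metadata carrying the length of its longest value. The sqlContext and column names below are assumed for illustration:

```scala
// Hypothetical usage sketch; `sqlContext` and the column names are invented.
val df = sqlContext.createDataFrame(Seq(
  ("Boston", 1),
  ("San Francisco", 2))).toDF("city", "id")

val enhanced = MetaSchema.computeEnhancedDf(df)

// "San Francisco" is the longest city value, at 13 characters
assert(enhanced.schema("city").metadata.getLong("maxLength") == 13)
```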
@@ -35,11 +35,11 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {

def varcharStr(meta: Metadata): String = {
// Fall back to a default length when no maxLength metadata is present.
// (Matching a Long against `case _: Long` always succeeds, and getLong
// throws when the key is absent, so an explicit contains check is needed.)
if (meta.contains("maxLength")) {
val maxLength: Long = meta.getLong("maxLength")
s"VARCHAR($maxLength)"
} else {
"VARCHAR(255)"
}
}
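
A quick check of both branches, assuming the contains-based fallback sketched above:

```scala
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

// With maxLength present the column gets a sized VARCHAR; otherwise VARCHAR(255).
val sized = new MetadataBuilder().putLong("maxLength", 10L).build()
assert(varcharStr(sized) == "VARCHAR(10)")
assert(varcharStr(Metadata.empty) == "VARCHAR(255)")
```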

@@ -79,7 +79,7 @@ class RedshiftWriter(jdbcWrapper: JDBCWrapper) extends Logging {
* Generate CREATE TABLE statement for Redshift
*/
def createTableSql(data: DataFrame, params: MergedParameters): String = {
var schemaSql = schemaString(MetaSchema.computeEnhancedDf(data))

val distStyleDef = params.distStyle match {
case Some(style) => s"DISTSTYLE $style"
@@ -91,7 +91,7 @@
}
val sortKeyDef = params.sortKeySpec.getOrElse("")

s"CREATE TABLE IF NOT EXISTS ${params.table} ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef"
s"CREATE TABLE IF NOT EXISTS ${params.table} ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef".trim
}
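
To see why the trailing .trim matters, here is a rough sketch of how the optional clauses assemble when the dist key and sort key are absent (parameter values invented):

```scala
// Invented parameter values for illustration.
val distStyle: Option[String] = Some("EVEN")
val distKeyDef = ""  // no DISTKEY configured
val sortKeyDef = ""  // no SORTKEY configured
val schemaSql = "id INTEGER, city VARCHAR(13)"

val distStyleDef = distStyle.map(style => s"DISTSTYLE $style").getOrElse("")

// Without .trim the statement would end in stray spaces
// whenever the optional clauses are empty.
val ddl =
  s"CREATE TABLE IF NOT EXISTS my_table ($schemaSql) $distStyleDef $distKeyDef $sortKeyDef".trim
// "CREATE TABLE IF NOT EXISTS my_table (id INTEGER, city VARCHAR(13)) DISTSTYLE EVEN"
```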

/**
@@ -412,7 +412,7 @@ class RedshiftSourceSuite
val testSqlContext = new SQLContext(sc)
val df = testSqlContext.createDataFrame(rdd, TestUtils.testSchema)

val dfMetaSchema = MetaSchema.computeEnhancedDf(df)

assert(dfMetaSchema.schema("testString").metadata.getLong("maxLength") == 10)
}