@@ -273,10 +273,11 @@ private[sql] object PartitioningUtils {
 
   def validatePartitionColumnDataTypes(
       schema: StructType,
-      partitionColumns: Array[String]): Unit = {
+      partitionColumns: Array[String],
+      caseSensitive: Boolean): Unit = {
 
-    ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field =>
-      field.dataType match {
+    ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach {
+      field => field.dataType match {
         case _: AtomicType => // OK
         case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column")
       }
@@ -99,7 +99,8 @@ object ResolvedDataSource extends Logging {
     val maybePartitionsSchema = if (partitionColumns.isEmpty) {
       None
     } else {
-      Some(partitionColumnsSchema(schema, partitionColumns))
+      Some(partitionColumnsSchema(
+        schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis))
     }
 
     val caseInsensitiveOptions = new CaseInsensitiveMap(options)
@@ -172,14 +173,24 @@
 
   def partitionColumnsSchema(
       schema: StructType,
-      partitionColumns: Array[String]): StructType = {
+      partitionColumns: Array[String],
+      caseSensitive: Boolean): StructType = {
+    val equality = columnNameEquality(caseSensitive)
     StructType(partitionColumns.map { col =>
-      schema.find(_.name == col).getOrElse {
+      schema.find(f => equality(f.name, col)).getOrElse {
         throw new RuntimeException(s"Partition column $col not found in schema $schema")
       }
     }).asNullable
   }
 
+  private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
+    if (caseSensitive) {
+      org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+    } else {
+      org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+    }
+  }
+
   /** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
   def apply(
       sqlContext: SQLContext,
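
The new columnNameEquality helper simply chooses between catalyst's two standard resolvers, which boil down to == and equalsIgnoreCase. Below is a minimal, self-contained sketch of that selection and of the find-based lookup in partitionColumnsSchema; plain strings stand in for StructType fields, and the two lambdas are stand-ins for caseSensitiveResolution / caseInsensitiveResolution, not the Spark source:

// Sketch (assumptions noted above): resolver selection plus schema lookup.
object ColumnEqualitySketch {
  // Stand-ins for catalyst's caseSensitiveResolution / caseInsensitiveResolution.
  private val caseSensitiveResolution: (String, String) => Boolean = _ == _
  private val caseInsensitiveResolution: (String, String) => Boolean = _ equalsIgnoreCase _

  def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean =
    if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

  def main(args: Array[String]): Unit = {
    val schemaFieldNames = Seq("year", "val") // simplified stand-in for a StructType
    val relaxed = columnNameEquality(caseSensitive = false)
    val strict = columnNameEquality(caseSensitive = true)
    // "yEAr" resolves to the "year" field when analysis is case-insensitive:
    println(schemaFieldNames.find(f => relaxed(f, "yEAr"))) // Some(year)
    // ...and fails to resolve when analysis is case-sensitive:
    println(schemaFieldNames.find(f => strict(f, "yEAr"))) // None
  }
}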
@@ -207,14 +218,18 @@
       path.makeQualified(fs.getUri, fs.getWorkingDirectory)
     }
 
-    PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns)
+    val caseSensitive = sqlContext.conf.caseSensitiveAnalysis
+    PartitioningUtils.validatePartitionColumnDataTypes(
+      data.schema, partitionColumns, caseSensitive)
 
-    val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
+    val equality = columnNameEquality(caseSensitive)
+    val dataSchema = StructType(
+      data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
     val r = dataSource.createRelation(
       sqlContext,
       Array(outputPath.toString),
       Some(dataSchema.asNullable),
-      Some(partitionColumnsSchema(data.schema, partitionColumns)),
+      Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)),
       caseInsensitiveOptions)
 
     // For partitioned relation r, r.schema's column ordering can be different from the column
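
With the equality threaded through, the write path splits the schema the same way: any field claimed by partitionBy is dropped from dataSchema regardless of casing. A plain-Scala sketch of that filterNot/exists split, using (name, type) tuples as stand-ins for StructFields:

// Sketch: separating data columns from partition columns with the selected
// equality, mirroring the filterNot(...exists...) change above.
object SchemaSplitSketch extends App {
  val equality: (String, String) => Boolean = _ equalsIgnoreCase _ // case-insensitive config
  val schema = Seq("year" -> "int", "val" -> "string")             // (name, type) stand-ins
  val partitionColumns = Array("yEAr")                             // user-supplied casing

  val dataSchema = schema.filterNot { case (name, _) =>
    partitionColumns.exists(equality(_, name))
  }
  println(dataSchema) // List((val,string)): "year" was claimed by partitionBy("yEAr")
}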
@@ -140,7 +140,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
           // OK
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray)
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis)
 
         // Get all input data source relations of the query.
         val srcRelations = query.collect {
@@ -190,7 +191,8 @@
           // OK
         }
 
-        PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns)
+        PartitioningUtils.validatePartitionColumnDataTypes(
+          query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis)
 
       case _ => // OK
     }
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (10 additions, 0 deletions)
@@ -1113,4 +1113,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       if (!allSequential) throw new SparkException("Partition should contain all sequential values")
     })
   }
+
+  test("fix case sensitivity of partition by") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        val p = path.getAbsolutePath
+        Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p)
+        checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012))
+      }
+    }
+  }
 }
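
For reference, a hedged sketch of the user-visible behavior the new test pins down, written against the Spark 1.x API; it assumes a live sqlContext and an empty scratch path (both are illustrative, not part of the patch):

// Assumes an existing SQLContext named sqlContext and a writable scratch path.
import sqlContext.implicits._
sqlContext.setConf("spark.sql.caseSensitive", "false")

Seq(2012 -> "a").toDF("year", "val")
  .write.partitionBy("yEAr").parquet("/tmp/partition-case-demo") // "yEAr" matches "year"
sqlContext.read.parquet("/tmp/partition-case-demo")
  .select("YeaR").show()                                         // resolves, prints 2012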