From 299cebb2be933c689f33c0cd96e1a0f7e0f91623 Mon Sep 17 00:00:00 2001
From: gatorsmile
Date: Wed, 13 Sep 2017 15:44:21 -0700
Subject: [PATCH] fix

---
 .../scala/org/apache/spark/sql/execution/command/ddl.scala | 6 +++++-
 .../scala/org/apache/spark/sql/hive/HiveStrategies.scala   | 3 +--
 .../apache/spark/sql/hive/execution/SaveAsHiveFile.scala   | 6 ++++--
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index b06f4ccaa3bb..162e1d5be293 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -801,7 +801,11 @@ object DDLUtils {
   val HIVE_PROVIDER = "hive"
 
   def isHiveTable(table: CatalogTable): Boolean = {
-    table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER
+    isHiveTable(table.provider)
+  }
+
+  def isHiveTable(provider: Option[String]): Boolean = {
+    provider.isDefined && provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER
   }
 
   def isDatasourceTable(table: CatalogTable): Boolean = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index caf554d9ea51..805b3171cdaa 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -160,8 +160,7 @@ object HiveAnalysis extends Rule[LogicalPlan] {
       CreateHiveTableAsSelectCommand(tableDesc, query, mode)
 
     case InsertIntoDir(isLocal, storage, provider, child, overwrite)
-        if provider.isDefined && provider.get.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER =>
-
+        if DDLUtils.isHiveTable(provider) =>
       val outputPath = new Path(storage.locationUri.get)
       if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath)
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
index 7de9b421245f..da23c9930b59 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.DataWritingCommand
@@ -36,7 +37,8 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand {
       hadoopConf: Configuration,
       fileSinkConf: FileSinkDesc,
       outputLocation: String,
-      partitionAttributes: Seq[Attribute] = Nil): Unit = {
+      customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty,
+      partitionAttributes: Seq[Attribute] = Nil): Set[String] = {
 
     val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean
     if (isCompressed) {
@@ -62,7 +64,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand {
       plan = plan,
       fileFormat = new HiveFileFormat(fileSinkConf),
       committer = committer,
-      outputSpec = FileFormatWriter.OutputSpec(outputLocation, Map.empty),
+      outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations),
       hadoopConf = hadoopConf,
       partitionColumns = partitionAttributes,
       bucketSpec = None,
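
For reference, a minimal standalone sketch of the isHiveTable refactoring in the ddl.scala hunk above. The CatalogTable case class here is only a stub standing in for Spark's org.apache.spark.sql.catalyst.catalog.CatalogTable so the example compiles on its own; the two overloads mirror the logic added to DDLUtils, and the object name and main method are illustrative only. The SaveAsHiveFile changes (customPartitionLocations and the Set[String] return type) touch Spark-internal writer APIs and are not reproduced here.

import java.util.Locale

// Standalone sketch, not Spark source: mirrors the two isHiveTable overloads.
object IsHiveTableSketch {

  // Stub for org.apache.spark.sql.catalyst.catalog.CatalogTable (illustrative only).
  case class CatalogTable(provider: Option[String])

  val HIVE_PROVIDER = "hive"

  // The table-based overload now delegates to the provider-based one, as in the patch.
  def isHiveTable(table: CatalogTable): Boolean = {
    isHiveTable(table.provider)
  }

  // Provider-based overload: case-insensitive match against "hive".
  def isHiveTable(provider: Option[String]): Boolean = {
    provider.isDefined && provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER
  }

  def main(args: Array[String]): Unit = {
    println(isHiveTable(Some("HIVE")))                   // true
    println(isHiveTable(Some("parquet")))                // false
    println(isHiveTable(CatalogTable(provider = None)))  // false
  }
}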