diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index c50835dd8f11..9929f318c1e3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -271,11 +271,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
    * @since 1.3.0
    */
   def refreshTable(tableName: String): Unit = {
-    catalog.refreshTable(catalog.client.currentDatabase, tableName)
+    val dbAndTableName = tableName.split("\\.")
+    catalog.refreshTable(dbAndTableName.lift(dbAndTableName.size - 2)
+      .getOrElse(catalog.client.currentDatabase), dbAndTableName.last)
   }
 
   protected[hive] def invalidateTable(tableName: String): Unit = {
-    catalog.invalidateTable(catalog.client.currentDatabase, tableName)
+    val dbAndTableName = tableName.split("\\.")
+    catalog.invalidateTable(dbAndTableName.lift(dbAndTableName.size - 2)
+      .getOrElse(catalog.client.currentDatabase), dbAndTableName.last)
   }
 
   /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index f35ae96ee0b5..03d544d070d4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -143,7 +143,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
       provider: String,
       options: Map[String, String],
       isExternal: Boolean): Unit = {
-    val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName)
+    val dbAndTableName = tableName.split("\\.")
+    val (dbName, tblName) = processDatabaseAndTableName(
+      dbAndTableName
+        .lift(dbAndTableName.size - 2)
+        .getOrElse(client.currentDatabase), dbAndTableName.last)
     val tableProperties = new scala.collection.mutable.HashMap[String, String]
     tableProperties.put("spark.sql.sources.provider", provider)
 
@@ -203,9 +207,11 @@
 
   def hiveDefaultTableFilePath(tableName: String): String = {
     // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName)
+    val dbAndTableName = tableName.split("\\.")
     new Path(
-      new Path(client.getDatabase(client.currentDatabase).location),
-      tableName.toLowerCase).toString
+      new Path(client.getDatabase(dbAndTableName.lift(dbAndTableName.size - 2)
+        .getOrElse(client.currentDatabase)).location),
+      dbAndTableName.last.toLowerCase).toString
   }
 
   def tableExists(tableIdentifier: Seq[String]): Boolean = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 76469d7a3d6a..8648a91cbb99 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
 
 abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
   override lazy val sqlContext: SQLContext = TestHive
@@ -609,4 +610,56 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
       }
     }
   }
+
+  test("SPARK-7943: DF created by HiveContext can be saved to a specific db by saveAsTable") {
+    val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+    // Use dbname.tablename to save into a specific database.
+    sqlContext.sql("""create database if not exists testdb7943""")
+    df.write
+      .format("parquet")
+      .mode(SaveMode.Overwrite)
+      .saveAsTable("testdb7943.tbl7943_1")
+
+    df.write
+      .format("parquet")
+      .mode(SaveMode.Overwrite)
+      .saveAsTable("tbl7943_2")
+
+    // Qualifying with a database that does not exist should fail.
+    intercept[NoSuchDatabaseException] {
+      df.write
+        .format("parquet")
+        .mode(SaveMode.Overwrite)
+        .saveAsTable("testdb7943-2.tbl1")
+    }
+
+    sqlContext.sql("""use testdb7943""")
+
+    df.write
+      .format("parquet")
+      .mode(SaveMode.Overwrite)
+      .saveAsTable("tbl7943_3")
+    df.write
+      .format("parquet")
+      .mode(SaveMode.Overwrite)
+      .saveAsTable("default.tbl7943_4")
+
+    checkAnswer(
+      sqlContext.sql("show TABLES in testdb7943"),
+      Seq(Row("tbl7943_1", false), Row("tbl7943_3", false)))
+
+    val result = sqlContext.sql("show TABLES in default")
+    checkAnswer(
+      result.filter("tableName = 'tbl7943_2'"),
+      Row("tbl7943_2", false))
+
+    checkAnswer(
+      result.filter("tableName = 'tbl7943_4'"),
+      Row("tbl7943_4", false))
+
+    sqlContext.sql("""use default""")
+    sqlContext.sql("""drop table if exists tbl7943_2""")
+    sqlContext.sql("""drop table if exists tbl7943_4""")
+    sqlContext.sql("""drop database if exists testdb7943 CASCADE""")
+  }
 }
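
For reviewers: every hunk above inlines the same two-step parse of the user-supplied table name. Below is a minimal standalone sketch of that logic, assuming only the Scala standard library; the DbTableName object and splitDbAndTableName helper are illustrative names only, not part of the patch, which repeats the parse at each call site instead of factoring it out.

    // Illustrative sketch only; names here are hypothetical, not from the patch.
    object DbTableName {
      // "db.tbl"  => (Some("db"), "tbl")
      // "tbl"     => (None, "tbl"), so callers fall back to currentDatabase
      // "a.b.tbl" => (Some("b"), "tbl"): only the last two parts survive,
      //              mirroring the lift(parts.size - 2) / last pattern above
      def splitDbAndTableName(tableName: String): (Option[String], String) = {
        val parts = tableName.split("\\.")
        (parts.lift(parts.size - 2), parts.last)
      }

      def main(args: Array[String]): Unit = {
        assert(splitDbAndTableName("testdb7943.tbl7943_1") == (Some("testdb7943"), "tbl7943_1"))
        assert(splitDbAndTableName("tbl7943_2") == (None, "tbl7943_2"))
      }
    }

One consequence worth noting: split("\\.") treats every dot as a separator, so a name such as "a.b.tbl" silently drops the leading "a", and a table name containing a literal dot cannot be expressed.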