diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java
index 1a7ca79163d7..2af17a662a29 100644
--- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java
+++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java
@@ -46,7 +46,7 @@ public class GetTablesOperation extends MetadataOperation {
   private final String schemaName;
   private final String tableName;
   private final List<String> tableTypes = new ArrayList<String>();
-  private final RowSet rowSet;
+  protected final RowSet rowSet;
   private final TableTypeMapping tableTypeMapping;
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
new file mode 100644
index 000000000000..369650047b10
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters.seqAsJavaListConverter
+
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObjectUtils
+import org.apache.hive.service.cli._
+import org.apache.hive.service.cli.operation.GetTablesOperation
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
+
+/**
+ * Spark's own GetTablesOperation
+ *
+ * @param sqlContext SQLContext to use
+ * @param parentSession a HiveSession from SessionManager
+ * @param catalogName catalog name. null if not applicable
+ * @param schemaName database name, null or a concrete database name
+ * @param tableName table name pattern
+ * @param tableTypes list of allowed table types, e.g. "TABLE", "VIEW"
+ */
+private[hive] class SparkGetTablesOperation(
+    sqlContext: SQLContext,
+    parentSession: HiveSession,
+    catalogName: String,
+    schemaName: String,
+    tableName: String,
+    tableTypes: JList[String])
+  extends GetTablesOperation(parentSession, catalogName, schemaName, tableName, tableTypes) {
+
+  if (tableTypes != null) {
+    this.tableTypes.addAll(tableTypes)
+  }
+
+  override def runInternal(): Unit = {
+    setState(OperationState.RUNNING)
+    // Always use the latest class loader provided by executionHive's state.
+    val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
+    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
+    val catalog = sqlContext.sessionState.catalog
+    val schemaPattern = convertSchemaPattern(schemaName)
+    val matchingDbs = catalog.listDatabases(schemaPattern)
+
+    if (isAuthV2Enabled) {
+      val privObjs =
+        HivePrivilegeObjectUtils.getHivePrivDbObjects(seqAsJavaListConverter(matchingDbs).asJava)
+      val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName"
+      authorizeMetaGets(HiveOperationType.GET_TABLES, privObjs, cmdStr)
+    }
+
+    val tablePattern = convertIdentifierPattern(tableName, true)
+    matchingDbs.foreach { dbName =>
+      catalog.listTables(dbName, tablePattern).foreach { tableIdentifier =>
+        val catalogTable = catalog.getTableMetadata(tableIdentifier)
+        val tableType = tableTypeString(catalogTable.tableType)
+        if (tableTypes == null || tableTypes.isEmpty || tableTypes.contains(tableType)) {
+          val rowData = Array[AnyRef](
+            "",
+            catalogTable.database,
+            catalogTable.identifier.table,
+            tableType,
+            catalogTable.comment.getOrElse(""))
+          rowSet.addRow(rowData)
+        }
+      }
+    }
+    setState(OperationState.FINISHED)
+  }
+
+  private def tableTypeString(tableType: CatalogTableType): String = tableType match {
+    case EXTERNAL | MANAGED => "TABLE"
+    case VIEW => "VIEW"
+    case t =>
+      throw new IllegalArgumentException(s"Unknown table type is found: $t")
+  }
+}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 85b6c7134755..7947d1785a8f 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -17,17 +17,17 @@
 
 package org.apache.spark.sql.hive.thriftserver.server
 
-import java.util.{Map => JMap}
+import java.util.{List => JList, Map => JMap}
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.hive.service.cli._
-import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager}
+import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, MetadataOperation, Operation, OperationManager}
 import org.apache.hive.service.cli.session.HiveSession
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.hive.HiveUtils
-import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation}
+import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation, SparkGetTablesOperation}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -76,6 +76,22 @@ private[thriftserver] class SparkSQLOperationManager()
     operation
   }
 
+  override def newGetTablesOperation(
+      parentSession: HiveSession,
+      catalogName: String,
+      schemaName: String,
+      tableName: String,
+      tableTypes: JList[String]): MetadataOperation = synchronized {
+    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
+    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
+      " initialized or has already been closed.")
+    val operation = new SparkGetTablesOperation(sqlContext, parentSession,
+      catalogName, schemaName, tableName, tableTypes)
+    handleToOperation.put(operation.getHandle, operation)
+    logDebug(s"Created GetTablesOperation with session=$parentSession.")
+    operation
+  }
+
   def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
     val iterator = confMap.entrySet().iterator()
     while (iterator.hasNext) {
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index f9509aed4aaa..0f53fcd327f1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -280,7 +280,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
     var defaultV2: String = null
     var data: ArrayBuffer[Int] = null
 
-    withMultipleConnectionJdbcStatement("test_map")(
+    withMultipleConnectionJdbcStatement("test_map", "db1.test_map2")(
       // create table
       { statement =>
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
index 9a997ae01df9..bf9982388d6b 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.hive.thriftserver
 
-import java.util.Properties
+import java.util.{Arrays => JArrays, List => JList, Properties}
 
 import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils}
 import org.apache.hive.service.auth.PlainSaslHelper
@@ -100,4 +100,89 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
       }
     }
   }
+
+  test("Spark's own GetTablesOperation(SparkGetTablesOperation)") {
+    def testGetTablesOperation(
+        schema: String,
+        tableNamePattern: String,
+        tableTypes: JList[String])(f: HiveQueryResultSet => Unit): Unit = {
+      val rawTransport = new TSocket("localhost", serverPort)
+      val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
+      val user = System.getProperty("user.name")
+      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+      val client = new TCLIService.Client(new TBinaryProtocol(transport))
+      transport.open()
+
+      var rs: HiveQueryResultSet = null
+
+      try {
+        val openResp = client.OpenSession(new TOpenSessionReq)
+        val sessHandle = openResp.getSessionHandle
+
+        val getTableReq = new TGetTablesReq(sessHandle)
+        getTableReq.setSchemaName(schema)
+        getTableReq.setTableName(tableNamePattern)
+        getTableReq.setTableTypes(tableTypes)
+
+        val getTableResp = client.GetTables(getTableReq)
+
+        JdbcUtils.verifySuccess(getTableResp.getStatus)
+
+        rs = new HiveQueryResultSet.Builder(connection)
+          .setClient(client)
+          .setSessionHandle(sessHandle)
+          .setStmtHandle(getTableResp.getOperationHandle)
+          .build()
+
+        f(rs)
+      } finally {
+        rs.close()
+        connection.close()
+        transport.close()
+        rawTransport.close()
+      }
+    }
+
+    def checkResult(tableNames: Seq[String], rs: HiveQueryResultSet): Unit = {
+      if (tableNames.nonEmpty) {
+        for (i <- tableNames.indices) {
+          assert(rs.next())
+          assert(rs.getString("TABLE_NAME") === tableNames(i))
+        }
+      } else {
+        assert(!rs.next())
+      }
+    }
+
+    withJdbcStatement("table1", "table2", "view1") { statement =>
+      Seq(
+        "CREATE TABLE table1(key INT, val STRING)",
+        "CREATE TABLE table2(key INT, val STRING)",
+        "CREATE VIEW view1 AS SELECT * FROM table2").foreach(statement.execute)
+
+      testGetTablesOperation("%", "%", null) { rs =>
+        checkResult(Seq("table1", "table2", "view1"), rs)
+      }
+
+      testGetTablesOperation("%", "table1", null) { rs =>
+        checkResult(Seq("table1"), rs)
+      }
+
+      testGetTablesOperation("%", "table_not_exist", null) { rs =>
+        checkResult(Seq.empty, rs)
+      }
+
+      testGetTablesOperation("%", "%", JArrays.asList("TABLE")) { rs =>
+        checkResult(Seq("table1", "table2"), rs)
+      }
+
+      testGetTablesOperation("%", "%", JArrays.asList("VIEW")) { rs =>
+        checkResult(Seq("view1"), rs)
+      }
+
+      testGetTablesOperation("%", "%", JArrays.asList("TABLE", "VIEW")) { rs =>
+        checkResult(Seq("table1", "table2", "view1"), rs)
+      }
+    }
+  }
 }
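
For context on what this patch enables client-side: `SparkGetTablesOperation` is what the Thrift server now runs when a JDBC client calls the standard `java.sql.DatabaseMetaData.getTables` API. Below is a minimal, hypothetical sketch (not part of the patch) of such a client; it assumes a Thrift server listening at `localhost:10000` and the Hive JDBC driver on the classpath.

```scala
import java.sql.DriverManager

// Minimal client sketch (assumptions: Thrift server at localhost:10000,
// hive-jdbc on the classpath; older drivers may additionally need
// Class.forName("org.apache.hive.jdbc.HiveDriver") before connecting).
object GetTablesClientExample {
  def main(args: Array[String]): Unit = {
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000", "user", "")
    try {
      // DatabaseMetaData.getTables is answered server-side by GetTablesOperation,
      // which this patch overrides with SparkGetTablesOperation.
      val rs = conn.getMetaData.getTables(null, "%", "%", Array("TABLE", "VIEW"))
      while (rs.next()) {
        // TABLE_SCHEM / TABLE_NAME / TABLE_TYPE are standard JDBC getTables columns.
        println(s"${rs.getString("TABLE_SCHEM")}.${rs.getString("TABLE_NAME")}" +
          s" (${rs.getString("TABLE_TYPE")})")
      }
      rs.close()
    } finally {
      conn.close()
    }
  }
}
```

Before this change, the default Hive `GetTablesOperation` answered such calls from the Hive metastore directly, bypassing Spark's session catalog; with the override installed by `SparkSQLOperationManager.newGetTablesOperation`, the results come from Spark's own catalog and respect its table/view types.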