diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
new file mode 100644
index 000000000000..7a6a8c59b721
--- /dev/null
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTypeInfoOperation.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.util.UUID
+
+import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
+import org.apache.hive.service.cli.{HiveSQLException, OperationState}
+import org.apache.hive.service.cli.operation.GetTypeInfoOperation
+import org.apache.hive.service.cli.session.HiveSession
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.{Utils => SparkUtils}
+
+/**
+ * Spark's own GetTypeInfoOperation
+ *
+ * @param sqlContext SQLContext to use
+ * @param parentSession a HiveSession from SessionManager
+ */
+private[hive] class SparkGetTypeInfoOperation(
+    sqlContext: SQLContext,
+    parentSession: HiveSession)
+  extends GetTypeInfoOperation(parentSession) with Logging {
+
+  private var statementId: String = _
+
+  override def close(): Unit = {
+    super.close()
+    HiveThriftServer2.listener.onOperationClosed(statementId)
+  }
+
+  override def runInternal(): Unit = {
+    statementId = UUID.randomUUID().toString
+    val logMsg = "Listing type info"
+    logInfo(s"$logMsg with $statementId")
+    setState(OperationState.RUNNING)
+    // Always use the latest class loader provided by executionHive's state.
+    val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
+    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
+
+    if (isAuthV2Enabled) {
+      authorizeMetaGets(HiveOperationType.GET_TYPEINFO, null)
+    }
+
+    HiveThriftServer2.listener.onStatementStart(
+      statementId,
+      parentSession.getSessionHandle.getSessionId.toString,
+      logMsg,
+      statementId,
+      parentSession.getUsername)
+
+    try {
+      ThriftserverShimUtils.supportedType().foreach(typeInfo => {
+        val rowData = Array[AnyRef](
+          typeInfo.getName, // TYPE_NAME
+          typeInfo.toJavaSQLType.asInstanceOf[AnyRef], // DATA_TYPE
+          typeInfo.getMaxPrecision.asInstanceOf[AnyRef], // PRECISION
+          typeInfo.getLiteralPrefix, // LITERAL_PREFIX
+          typeInfo.getLiteralSuffix, // LITERAL_SUFFIX
+          typeInfo.getCreateParams, // CREATE_PARAMS
+          typeInfo.getNullable.asInstanceOf[AnyRef], // NULLABLE
+          typeInfo.isCaseSensitive.asInstanceOf[AnyRef], // CASE_SENSITIVE
+          typeInfo.getSearchable.asInstanceOf[AnyRef], // SEARCHABLE
+          typeInfo.isUnsignedAttribute.asInstanceOf[AnyRef], // UNSIGNED_ATTRIBUTE
+          typeInfo.isFixedPrecScale.asInstanceOf[AnyRef], // FIXED_PREC_SCALE
+          typeInfo.isAutoIncrement.asInstanceOf[AnyRef], // AUTO_INCREMENT
+          typeInfo.getLocalizedName, // LOCAL_TYPE_NAME
+          typeInfo.getMinimumScale.asInstanceOf[AnyRef], // MINIMUM_SCALE
+          typeInfo.getMaximumScale.asInstanceOf[AnyRef], // MAXIMUM_SCALE
+          null, // SQL_DATA_TYPE, unused
+          null, // SQL_DATETIME_SUB, unused
+          typeInfo.getNumPrecRadix // NUM_PREC_RADIX
+        )
+        rowSet.addRow(rowData)
+      })
+      setState(OperationState.FINISHED)
+    } catch {
+      case e: HiveSQLException =>
+        setState(OperationState.ERROR)
+        HiveThriftServer2.listener.onStatementError(
+          statementId, e.getMessage, SparkUtils.exceptionString(e))
+        throw e
+    }
+    HiveThriftServer2.listener.onStatementFinish(statementId)
+  }
}
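The runInternal above follows the same listener bookkeeping as the other Spark*Operation classes: onStatementStart before the work, onStatementError on failure, onStatementFinish on success. A minimal sketch of that shape, where OperationListener and runWithListener are illustrative stand-ins, not Spark or Hive APIs:

object ListenerPatternSketch {
  trait OperationListener {
    def onStatementStart(id: String): Unit
    def onStatementError(id: String, message: String): Unit
    def onStatementFinish(id: String): Unit
  }

  def runWithListener[T](id: String, listener: OperationListener)(body: => T): T = {
    listener.onStatementStart(id)
    try {
      val result = body              // e.g. populate the row set
      listener.onStatementFinish(id) // only reached on success
      result
    } catch {
      case e: Exception =>
        listener.onStatementError(id, e.getMessage)
        throw e                      // rethrow, so the finish callback is skipped
    }
  }
}

Rethrowing inside the catch is what keeps the finish callback from firing after an error, which matches the control flow in runInternal above.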
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index a37894fe908e..3396560f4350 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -140,4 +140,15 @@ private[thriftserver] class SparkSQLOperationManager()
     logDebug(s"Created GetFunctionsOperation with session=$parentSession.")
     operation
   }
+
+  override def newGetTypeInfoOperation(
+      parentSession: HiveSession): GetTypeInfoOperation = synchronized {
+    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
+    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
+      " initialized or has already been closed.")
+    val operation = new SparkGetTypeInfoOperation(sqlContext, parentSession)
+    handleToOperation.put(operation.getHandle, operation)
+    logDebug(s"Created GetTypeInfoOperation with session=$parentSession.")
+    operation
+  }
 }
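This manager hook is what routes java.sql.DatabaseMetaData.getTypeInfo calls from a JDBC client to SparkGetTypeInfoOperation. A minimal client-side sketch, assuming a Thrift server already running and the Hive JDBC driver on the classpath; the URL, port, and user are illustrative:

import java.sql.DriverManager

object TypeInfoClient {
  def main(args: Array[String]): Unit = {
    // Assumed endpoint; adjust host/port/credentials for your deployment.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/default", "user", "")
    try {
      val rs = conn.getMetaData.getTypeInfo // served by SparkGetTypeInfoOperation
      while (rs.next()) {
        println(s"${rs.getString("TYPE_NAME")} -> java.sql.Types code ${rs.getInt("DATA_TYPE")}")
      }
    } finally {
      conn.close()
    }
  }
}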
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
index 21870ffd463e..f7ee3e0a46cd 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala
@@ -231,4 +231,20 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
       assert(!rs.next())
     }
   }
+
+  test("GetTypeInfo Thrift API") {
+    def checkResult(rs: ResultSet, typeNames: Seq[String]): Unit = {
+      for (i <- typeNames.indices) {
+        assert(rs.next())
+        assert(rs.getString("TYPE_NAME") === typeNames(i))
+      }
+      // Make sure there are no more elements
+      assert(!rs.next())
+    }
+
+    withJdbcStatement() { statement =>
+      val metaData = statement.getConnection.getMetaData
+      checkResult(metaData.getTypeInfo, ThriftserverShimUtils.supportedType().map(_.getName))
+    }
+  }
 }
diff --git a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
index 0f72071d7e7d..3e81f8afbd85 100644
--- a/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
+++ b/sql/hive-thriftserver/v1.2.1/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java
@@ -73,7 +73,7 @@ public class GetTypeInfoOperation extends MetadataOperation {
   .addPrimitiveColumn("NUM_PREC_RADIX", Type.INT_TYPE,
   "Usually 2 or 10");
 
-  private final RowSet rowSet;
+  protected final RowSet rowSet;
 
   protected GetTypeInfoOperation(HiveSession parentSession) {
     super(parentSession, OperationType.GET_TYPE_INFO);
diff --git a/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala b/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
index 87c0f8f6a571..837861a77bf5 100644
--- a/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
+++ b/sql/hive-thriftserver/v1.2.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala
@@ -21,6 +21,7 @@ import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.hive.ql.exec.Utilities
 import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema, Type}
+import org.apache.hive.service.cli.Type._
 import org.apache.hive.service.cli.thrift.TProtocolVersion._
 
 /**
@@ -51,6 +52,14 @@ private[thriftserver] object ThriftserverShimUtils {
 
   private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType
 
+  private[thriftserver] def supportedType(): Seq[Type] = {
+    Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE,
+      TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE,
+      FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE,
+      DATE_TYPE, TIMESTAMP_TYPE,
+      ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE)
+  }
+
   private[thriftserver] def addToClassPath(
       loader: ClassLoader,
       auxJars: Array[String]): ClassLoader = {
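Note that the test above asserts order as well as membership: getTypeInfo must return rows in exactly the sequence supportedType() declares them. A quick sketch to dump that list, using only methods this patch already relies on (getName, toJavaSQLType); it would have to sit in the org.apache.spark.sql.hive.thriftserver package, since supportedType() is private[thriftserver]:

// Dump the advertised types in declaration order.
ThriftserverShimUtils.supportedType().foreach { t =>
  println(f"${t.getName}%-15s java.sql.Types=${t.toJavaSQLType}")
}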
"Usually 2 or 10"); - private final RowSet rowSet; + protected final RowSet rowSet; protected GetTypeInfoOperation(HiveSession parentSession) { super(parentSession, OperationType.GET_TYPE_INFO); diff --git a/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala b/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala index 124c9937c0fc..cb32ebed0ac1 100644 --- a/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala +++ b/sql/hive-thriftserver/v2.3.5/src/main/scala/org/apache/spark/sql/hive/thriftserver/ThriftserverShimUtils.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.ql.exec.AddToClassPathAction import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.thrift.Type +import org.apache.hadoop.hive.serde2.thrift.Type._ import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema} import org.apache.hive.service.rpc.thrift.TProtocolVersion._ import org.slf4j.LoggerFactory @@ -56,6 +57,14 @@ private[thriftserver] object ThriftserverShimUtils { private[thriftserver] def toJavaSQLType(s: String): Int = Type.getType(s).toJavaSQLType + private[thriftserver] def supportedType(): Seq[Type] = { + Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE, + TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE, + FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE, + DATE_TYPE, TIMESTAMP_TYPE, + ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE) + } + private[thriftserver] def addToClassPath( loader: ClassLoader, auxJars: Array[String]): ClassLoader = {