diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 3f20c6142e59..f00c6dbb5bf2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -38,7 +38,7 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim) + case statement => HiveQlConverter.createPlan(statement.trim) } protected lazy val dfs: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveASTNodeUtil.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveASTNodeUtil.scala new file mode 100644 index 000000000000..aa1b8abd52cd --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveASTNodeUtil.scala @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.hive.ql.parse.{ParseDriver, ParseUtils, ASTNode}
+import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.lib.Node
+
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+
+/* Implicit conversions */
+import scala.collection.JavaConversions._
+
+private[hive] object HiveASTNodeUtil {
+  val nativeCommands = Seq(
+    "TOK_ALTERDATABASE_OWNER",
+    "TOK_ALTERDATABASE_PROPERTIES",
+    "TOK_ALTERINDEX_PROPERTIES",
+    "TOK_ALTERINDEX_REBUILD",
+    "TOK_ALTERTABLE_ADDCOLS",
+    "TOK_ALTERTABLE_ADDPARTS",
+    "TOK_ALTERTABLE_ALTERPARTS",
+    "TOK_ALTERTABLE_ARCHIVE",
+    "TOK_ALTERTABLE_CLUSTER_SORT",
+    "TOK_ALTERTABLE_DROPPARTS",
+    "TOK_ALTERTABLE_PARTITION",
+    "TOK_ALTERTABLE_PROPERTIES",
+    "TOK_ALTERTABLE_RENAME",
+    "TOK_ALTERTABLE_RENAMECOL",
+    "TOK_ALTERTABLE_REPLACECOLS",
+    "TOK_ALTERTABLE_SKEWED",
+    "TOK_ALTERTABLE_TOUCH",
+    "TOK_ALTERTABLE_UNARCHIVE",
+    "TOK_ALTERVIEW_ADDPARTS",
+    "TOK_ALTERVIEW_AS",
+    "TOK_ALTERVIEW_DROPPARTS",
+    "TOK_ALTERVIEW_PROPERTIES",
+    "TOK_ALTERVIEW_RENAME",
+
+    "TOK_CREATEDATABASE",
+    "TOK_CREATEFUNCTION",
+    "TOK_CREATEINDEX",
+    "TOK_CREATEROLE",
+    "TOK_CREATEVIEW",
+
+    "TOK_DESCDATABASE",
+    "TOK_DESCFUNCTION",
+
+    "TOK_DROPDATABASE",
+    "TOK_DROPFUNCTION",
+    "TOK_DROPINDEX",
+    "TOK_DROPROLE",
+    "TOK_DROPTABLE_PROPERTIES",
+    "TOK_DROPVIEW",
+    "TOK_DROPVIEW_PROPERTIES",
+
+    "TOK_EXPORT",
+
+    "TOK_GRANT",
+    "TOK_GRANT_ROLE",
+
+    "TOK_IMPORT",
+
+    "TOK_LOAD",
+
+    "TOK_LOCKTABLE",
+
+    "TOK_MSCK",
+
+    "TOK_REVOKE",
+
+    "TOK_SHOW_COMPACTIONS",
+    "TOK_SHOW_CREATETABLE",
+    "TOK_SHOW_GRANT",
+    "TOK_SHOW_ROLE_GRANT",
+    "TOK_SHOW_ROLE_PRINCIPALS",
+    "TOK_SHOW_ROLES",
+    "TOK_SHOW_SET_ROLE",
+    "TOK_SHOW_TABLESTATUS",
+    "TOK_SHOW_TBLPROPERTIES",
+    "TOK_SHOW_TRANSACTIONS",
+    "TOK_SHOWCOLUMNS",
+    "TOK_SHOWDATABASES",
+    "TOK_SHOWFUNCTIONS",
+    "TOK_SHOWINDEXES",
+    "TOK_SHOWLOCKS",
+    "TOK_SHOWPARTITIONS",
+
+    "TOK_SWITCHDATABASE",
+
+    "TOK_UNLOCKTABLE"
+  )
+
+  // Commands that we do not need to explain.
+  val noExplainCommands = Seq(
+    "TOK_DESCTABLE",
+    "TOK_SHOWTABLES",
+    "TOK_TRUNCATETABLE" // "truncate table" is a NativeCommand and does not need to be explained.
+  ) ++ nativeCommands
+
+  /**
+   * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
+   * similar to [[catalyst.trees.TreeNode]].
+   *
+   * Note that this should be considered very experimental and is not intended as a replacement
+   * for TreeNode. Primarily it should be noted that ASTNodes are not immutable and do not appear
+   * to have clean copy semantics. Therefore, users of this class should take care when
+   * copying/modifying trees that might be used elsewhere.
+   */
+  implicit class TransformableNode(n: ASTNode) {
+    /**
+     * Returns a copy of this node where `rule` has been recursively applied to it and all of its
+     * children. When `rule` does not apply to a given node it is left unchanged.
+     * @param rule the function used to transform this node's children
+     */
+    def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = {
+      try {
+        val afterRule = rule.applyOrElse(n, identity[ASTNode])
+        afterRule.withChildren(
+          nilIfEmpty(afterRule.getChildren)
+            .asInstanceOf[Seq[ASTNode]]
+            .map(ast => Option(ast).map(_.transform(rule)).orNull))
+      } catch {
+        case e: Exception =>
+          println(dumpTree(n))
+          throw e
+      }
+    }
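Since `transform` mutates nodes in place rather than producing immutable copies, a minimal usage sketch may help future readers. The rewrite rule below is hypothetical; `getAst`, `withText`, and the implicit `TransformableNode` are the helpers defined in this file:

```scala
import org.apache.spark.sql.hive.HiveASTNodeUtil._

// Hypothetical rule: upper-case every `count` token in the tree. Note that,
// per the scaladoc above, this mutates the underlying ASTNode rather than
// returning an independent copy.
val ast = getAst("SELECT count(1) FROM src")
val rewritten = ast.transform {
  case node if node.getText.equalsIgnoreCase("count") => node.withText("COUNT")
}
```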
+
+    /**
+     * Returns a scala.Seq equivalent to `s`, or Nil if `s` is null.
+     */
+    private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] =
+      Option(s).map(_.toSeq).getOrElse(Nil)
+
+    /**
+     * Returns this ASTNode with the text changed to `newText`.
+     */
+    def withText(newText: String): ASTNode = {
+      n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText)
+      n
+    }
+
+    /**
+     * Returns this ASTNode with the children changed to `newChildren`.
+     */
+    def withChildren(newChildren: Seq[ASTNode]): ASTNode = {
+      (1 to n.getChildCount).foreach(_ => n.deleteChild(0))
+      n.addChildren(newChildren)
+      n
+    }
+
+    /**
+     * Throws an error if this is not equal to other.
+     *
+     * Right now this function only checks the name, type, text and children of the node
+     * for equality.
+     */
+    def checkEquals(other: ASTNode): Unit = {
+      def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) {
+        sys.error(s"$field does not match for trees. " +
+          s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}")
+      }
+      check("name", _.getName)
+      check("type", _.getType)
+      check("text", _.getText)
+      check("numChildren", n => nilIfEmpty(n.getChildren).size)
+
+      val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]]
+      val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]]
+      leftChildren zip rightChildren foreach {
+        case (l, r) => l checkEquals r
+      }
+    }
+  }
+
+  /** Extractor for matching Hive's AST Tokens. */
+  object Token {
+    /** @return matches of the form (tokenName, children). */
+    def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
+      case t: ASTNode =>
+        CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine)
+        Some((t.getText,
+          Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
+      case _ => None
+    }
+  }
+
+  val escapedIdentifier = "`([^`]+)`".r
+  /** Strips backticks from `ident` if present. */
+  def cleanIdentifier(ident: String): String = ident match {
+    case escapedIdentifier(i) => i
+    case plainIdent => plainIdent
+  }
+
+  /**
+   * Returns the AST for the given SQL string.
+   */
+  def getAst(sql: String): ASTNode = {
+    /*
+     * A Context has to be passed in for Hive 0.13.1; otherwise a
+     * NullPointerException is thrown when retrieving properties from HiveConf.
+     */
+    val hContext = new Context(new HiveConf())
+    val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext))
+    hContext.clear()
+    node
+  }
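For a quick feel of these entry points, here is a small sketch combining `getAst`, the `Token` extractor, and `dumpTree` (defined just below). The query text is arbitrary, and the exact token tree depends on the Hive version:

```scala
import org.apache.spark.sql.hive.HiveASTNodeUtil._

val ast = getAst("SELECT key FROM src WHERE key > 10")

// Pattern-match the root token using the Token extractor.
ast match {
  case Token("TOK_QUERY", children) =>
    println(s"query with ${children.size} top-level clauses")
  case Token(other, _) =>
    println(s"unexpected root token: $other")
}

// dumpTree renders one token per line, indented by depth, with
// line/offset positions appended.
println(dumpTree(ast))
```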
+
+  def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = {
+    var remainingNodes = nodeList
+    val clauses = clauseNames.map { clauseName =>
+      val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName)
+      remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
+      matches.headOption
+    }
+
+    if (remainingNodes.nonEmpty) {
+      sys.error(
+        s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}.
+           |You are likely trying to use an unsupported Hive feature.""".stripMargin)
+    }
+    clauses
+  }
+
+  def getClause(clauseName: String, nodeList: Seq[Node]): Node =
+    getClauseOption(clauseName, nodeList).getOrElse(sys.error(
+      s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}"))
+
+  def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = {
+    nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match {
+      case Seq(oneMatch) => Some(oneMatch)
+      case Seq() => None
+      case _ => sys.error(s"Found multiple instances of clause $clauseName")
+    }
+  }
+
+  def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = {
+    val (db, tableName) =
+      tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
+        case Seq(tableOnly) => (None, tableOnly)
+        case Seq(databaseName, table) => (Some(databaseName), table)
+      }
+
+    (db, tableName)
+  }
+
+  def extractTableIdent(tableNameParts: Node): Seq[String] = {
+    tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
+      case Seq(tableOnly) => Seq(tableOnly)
+      case Seq(databaseName, table) => Seq(databaseName, table)
+      case other => sys.error("Hive only supports table names like 'tableName' " +
+        s"or 'databaseName.tableName', found '$other'")
+    }
+  }
+
+  def dumpTree(
+      node: Node,
+      builder: StringBuilder = new StringBuilder,
+      indent: Int = 0): StringBuilder = {
+    node match {
+      case a: ASTNode => builder.append(
+        (" " * indent) + a.getText + " " +
+          a.getLine + ", " +
+          a.getTokenStartIndex + "," +
+          a.getTokenStopIndex + ", " +
+          a.getCharPositionInLine + "\n")
+      case other => sys.error(s"Non ASTNode encountered: $other")
+    }
+
+    Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1))
+    builder
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 1d8d0b5c322a..059e89618d3a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -20,9 +20,6 @@ package org.apache.spark.sql.hive
 import java.io.{BufferedReader, InputStreamReader, PrintStream}
 import java.sql.Timestamp
 
-import org.apache.hadoop.hive.ql.parse.VariableSubstitution
-import org.apache.spark.sql.catalyst.Dialect
-
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 
@@ -39,6 +36,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry}
+import org.apache.spark.sql.catalyst.Dialect
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand}
 import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
@@ -49,8 +47,14 @@ import org.apache.spark.sql.types._
  * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
  */
 private[hive] class HiveQLDialect extends Dialect {
+  @transient
+  protected val sqlParser = {
+    val hiveParser = new ExtendedHiveQlParser
+    new SparkSQLParser(hiveParser.parse)
+  }
+
   override def parse(sqlText: String): LogicalPlan = {
-    HiveQl.parseSql(sqlText)
+    sqlParser.parse(sqlText)
   }
 }
 
@@ -93,6 +97,33 @@ class HiveContext(sc: 
SparkContext) extends SQLContext(sc) { protected[sql] def convertCTAS: Boolean = getConf("spark.sql.hive.convertCTAS", "false").toBoolean + /* A catalyst metadata catalog that points to the Hive Metastore. */ + @transient + override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog + + // Note that HiveUDFs will be overridden by functions registered in this context. + @transient + override protected[sql] lazy val functionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry { + def caseSensitive: Boolean = false + } + + /* An analyzer that uses the Hive metastore. */ + @transient + override protected[sql] lazy val analyzer = + new Analyzer(catalog, functionRegistry, caseSensitive = false) { + override val extendedResolutionRules = + catalog.ParquetConversions :: + catalog.CreateTables :: + catalog.PreInsertionCasts :: + ExtractPythonUdfs :: + sources.PreInsertCastAndRename :: + Nil + } + + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution(plan) + @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -100,9 +131,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { super.parseSql(substitutor.substitute(hiveconf, sql)) } - override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - /** * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, * Spark SQL or the external data source library it uses might cache certain metadata about a @@ -232,30 +260,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { runSqlHive(s"SET $key=$value") } - /* A catalyst metadata catalog that points to the Hive Metastore. */ - @transient - override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog - - // Note that HiveUDFs will be overridden by functions registered in this context. - @transient - override protected[sql] lazy val functionRegistry = - new HiveFunctionRegistry with OverrideFunctionRegistry { - def caseSensitive: Boolean = false - } - - /* An analyzer that uses the Hive metastore. 
*/ - @transient - override protected[sql] lazy val analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) { - override val extendedResolutionRules = - catalog.ParquetConversions :: - catalog.CreateTables :: - catalog.PreInsertionCasts :: - ExtractPythonUdfs :: - sources.PreInsertCastAndRename :: - Nil - } - override protected[sql] def createSession(): SQLSession = { new this.SQLSession() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4d222cf88e5e..50fd2690c3d0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec} -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, DDLParser, LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -218,7 +218,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } else if (table.isView) { // if the unresolved relation is from hive view // parse the text into logic node. - HiveQl.createPlanForView(table, alias) + HiveQlConverter.createPlanForView(table, alias) } else { val partitions: Seq[Partition] = if (table.isPartitioned) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQlConverter.scala similarity index 81% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQlConverter.scala index 0a86519e1412..f74225379dd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQlConverter.scala @@ -19,30 +19,26 @@ package org.apache.spark.sql.hive import java.sql.Date - -import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} - import scala.collection.mutable.ArrayBuffer -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils -import org.apache.spark.sql.{AnalysisException, SparkSQLParser} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin import org.apache.spark.sql.execution.ExplainCommand import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema} +import org.apache.spark.sql.hive.HiveASTNodeUtil._ import org.apache.spark.sql.types._ import org.apache.spark.util.random.RandomSampler +import 
org.apache.hadoop.hive.ql.metadata.Table /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -55,192 +51,7 @@ import scala.collection.JavaConversions._ private[hive] case object NativePlaceholder extends Command /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl { - protected val nativeCommands = Seq( - "TOK_ALTERDATABASE_OWNER", - "TOK_ALTERDATABASE_PROPERTIES", - "TOK_ALTERINDEX_PROPERTIES", - "TOK_ALTERINDEX_REBUILD", - "TOK_ALTERTABLE_ADDCOLS", - "TOK_ALTERTABLE_ADDPARTS", - "TOK_ALTERTABLE_ALTERPARTS", - "TOK_ALTERTABLE_ARCHIVE", - "TOK_ALTERTABLE_CLUSTER_SORT", - "TOK_ALTERTABLE_DROPPARTS", - "TOK_ALTERTABLE_PARTITION", - "TOK_ALTERTABLE_PROPERTIES", - "TOK_ALTERTABLE_RENAME", - "TOK_ALTERTABLE_RENAMECOL", - "TOK_ALTERTABLE_REPLACECOLS", - "TOK_ALTERTABLE_SKEWED", - "TOK_ALTERTABLE_TOUCH", - "TOK_ALTERTABLE_UNARCHIVE", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_AS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME", - - "TOK_CREATEDATABASE", - "TOK_CREATEFUNCTION", - "TOK_CREATEINDEX", - "TOK_CREATEROLE", - "TOK_CREATEVIEW", - - "TOK_DESCDATABASE", - "TOK_DESCFUNCTION", - - "TOK_DROPDATABASE", - "TOK_DROPFUNCTION", - "TOK_DROPINDEX", - "TOK_DROPROLE", - "TOK_DROPTABLE_PROPERTIES", - "TOK_DROPVIEW", - "TOK_DROPVIEW_PROPERTIES", - - "TOK_EXPORT", - - "TOK_GRANT", - "TOK_GRANT_ROLE", - - "TOK_IMPORT", - - "TOK_LOAD", - - "TOK_LOCKTABLE", - - "TOK_MSCK", - - "TOK_REVOKE", - - "TOK_SHOW_COMPACTIONS", - "TOK_SHOW_CREATETABLE", - "TOK_SHOW_GRANT", - "TOK_SHOW_ROLE_GRANT", - "TOK_SHOW_ROLE_PRINCIPALS", - "TOK_SHOW_ROLES", - "TOK_SHOW_SET_ROLE", - "TOK_SHOW_TABLESTATUS", - "TOK_SHOW_TBLPROPERTIES", - "TOK_SHOW_TRANSACTIONS", - "TOK_SHOWCOLUMNS", - "TOK_SHOWDATABASES", - "TOK_SHOWFUNCTIONS", - "TOK_SHOWINDEXES", - "TOK_SHOWLOCKS", - "TOK_SHOWPARTITIONS", - - "TOK_SWITCHDATABASE", - - "TOK_UNLOCKTABLE" - ) - - // Commands that we do not need to explain. - protected val noExplainCommands = Seq( - "TOK_DESCTABLE", - "TOK_SHOWTABLES", - "TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain. - ) ++ nativeCommands - - protected val hqlParser = { - val fallback = new ExtendedHiveQlParser - new SparkSQLParser(fallback.parse(_)) - } - - /** - * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations - * similar to [[catalyst.trees.TreeNode]]. - * - * Note that this should be considered very experimental and is not indented as a replacement - * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to - * have clean copy semantics. Therefore, users of this class should take care when - * copying/modifying trees that might be used elsewhere. - */ - implicit class TransformableNode(n: ASTNode) { - /** - * Returns a copy of this node where `rule` has been recursively applied to it and all of its - * children. When `rule` does not apply to a given node it is left unchanged. - * @param rule the function use to transform this nodes children - */ - def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = { - try { - val afterRule = rule.applyOrElse(n, identity[ASTNode]) - afterRule.withChildren( - nilIfEmpty(afterRule.getChildren) - .asInstanceOf[Seq[ASTNode]] - .map(ast => Option(ast).map(_.transform(rule)).orNull)) - } catch { - case e: Exception => - println(dumpTree(n)) - throw e - } - } - - /** - * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. 
- */ - private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.toSeq).getOrElse(Nil) - - /** - * Returns this ASTNode with the text changed to `newText`. - */ - def withText(newText: String): ASTNode = { - n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText) - n - } - - /** - * Returns this ASTNode with the children changed to `newChildren`. - */ - def withChildren(newChildren: Seq[ASTNode]): ASTNode = { - (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren) - n - } - - /** - * Throws an error if this is not equal to other. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. - */ - def checkEquals(other: ASTNode): Unit = { - def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { - sys.error(s"$field does not match for trees. " + - s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") - } - check("name", _.getName) - check("type", _.getType) - check("text", _.getText) - check("numChildren", n => nilIfEmpty(n.getChildren).size) - - val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] - val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] - leftChildren zip rightChildren foreach { - case (l, r) => l checkEquals r - } - } - } - - /** - * Returns the AST for the given SQL string. - */ - def getAst(sql: String): ASTNode = { - /* - * Context has to be passed in hive0.13.1. - * Otherwise, there will be Null pointer exception, - * when retrieving properties form HiveConf. - */ - val hContext = new Context(new HiveConf()) - val node = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, hContext)) - hContext.clear() - node - } - - - /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) +private[hive] object HiveQlConverter { val errorRegEx = "line (\\d+):(\\d+) (.*)".r @@ -285,179 +96,6 @@ private[hive] object HiveQl { case Some(aliasText) => Subquery(aliasText, createPlan(view.getViewExpandedText)) } - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = - try { - ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) - } - assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.getChildren - val colList = - tableOps - .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")).getChildren - - colList.map(nodeToAttribute) - } - - /** Extractor for matching Hive's AST Tokens. */ - object Token { - /** @return matches of the form (tokenName, children). 
*/ - def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { - case t: ASTNode => - CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) - Some((t.getText, - Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) - case _ => None - } - } - - protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } - - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. - |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses - } - - def getClause(clauseName: String, nodeList: Seq[Node]): Node = - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) - - def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { - nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } - - protected def nodeToAttribute(node: Node): Attribute = node match { - case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => - AttributeReference(colName, nodeToDataType(dataType), true)() - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", - Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", - keyType :: - valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") - } - - protected def nodeToStructField(node: Node): StructField = node match { - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: - _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for StructField:\n 
${dumpTree(a).toString} ") - } - - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { - val (db, tableName) = - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { - case Seq(tableOnly) => (None, tableOnly) - case Seq(databaseName, table) => (Some(databaseName), table) - } - - (db, tableName) - } - - protected def extractTableIdent(tableNameParts: Node): Seq[String] = { - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { - case Seq(tableOnly) => Seq(tableOnly) - case Seq(databaseName, table) => Seq(databaseName, table) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } - } - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition( n => n match { - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets - case _ => true // grouping keys - }) - - val keys = keyASTs.map(nodeToExpr).toSeq - val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap - - val bitmasks: Seq[Int] = setASTs.map(set => set match { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => - children.foldLeft(0)((bitmap, col) => { - val colString = col.asInstanceOf[ASTNode].toStringTree() - require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") - bitmap | 1 << keyMap(colString) - }) - case _ => sys.error("Expect GROUPING SETS clause") - }) - - (keys, bitmasks) - } - protected def nodeToPlan(node: Node): LogicalPlan = node match { // Special drop table that also uncaches. 
case Token("TOK_DROPTABLE", @@ -980,19 +618,71 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") } + protected def nodeToAttribute(node: Node): Attribute = node match { + case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => + AttributeReference(colName, nodeToDataType(dataType), true)() + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + protected def nodeToDataType(node: Node): DataType = node match { + case Token("TOK_DECIMAL", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case Token("TOK_DECIMAL", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited + case Token("TOK_BIGINT", Nil) => LongType + case Token("TOK_INT", Nil) => IntegerType + case Token("TOK_TINYINT", Nil) => ByteType + case Token("TOK_SMALLINT", Nil) => ShortType + case Token("TOK_BOOLEAN", Nil) => BooleanType + case Token("TOK_STRING", Nil) => StringType + case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType + case Token("TOK_FLOAT", Nil) => FloatType + case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_DATE", Nil) => DateType + case Token("TOK_TIMESTAMP", Nil) => TimestampType + case Token("TOK_BINARY", Nil) => BinaryType + case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) + case Token("TOK_STRUCT", + Token("TOK_TABCOLLIST", fields) :: Nil) => + StructType(fields.map(nodeToStructField)) + case Token("TOK_MAP", + keyType :: + valueType :: Nil) => + MapType(nodeToDataType(keyType), nodeToDataType(valueType)) + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") + } + + protected def nodeToStructField(node: Node): StructField = node match { + case Token("TOK_TABCOL", + Token(fieldName, Nil) :: + dataType :: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", + Token(fieldName, Nil) :: + dataType :: + _ /* comment */:: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") + } + val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r protected def nodeToDest( node: Node, query: LogicalPlan, overwrite: Boolean): LogicalPlan = node match { case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => + Token("TOK_DIR", + Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => query case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => + Token("TOK_TAB", + tableArgs) :: Nil) => val Some(tableNameParts) :: partitionClause :: Nil = getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) @@ -1009,10 +699,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => + Token("TOK_TAB", + tableArgs) :: + Token("TOK_IFNOTEXISTS", + ifNotExists) :: Nil) => val Some(tableNameParts) :: partitionClause :: Nil = getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) @@ -1055,12 +745,35 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
   }
 
+  val explode = "(?i)explode".r
+  def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = {
+    val function = nodes.head
+
+    val attributes = nodes.flatMap {
+      case Token(a, Nil) => a.toLowerCase :: Nil
+      case _ => Nil
+    }
+
+    function match {
+      case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
+        (Explode(nodeToExpr(child)), attributes)
+
+      case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
+        val functionInfo: FunctionInfo =
+          Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
+            sys.error(s"Couldn't find function $functionName"))
+        val functionClassName = functionInfo.getFunctionClass.getName
+
+        (HiveGenericUdtf(
+          new HiveFunctionWrapper(functionClassName),
+          children.map(nodeToExpr)), attributes)
 
-  protected val escapedIdentifier = "`([^`]+)`".r
-  /** Strips backticks from ident if present */
-  protected def cleanIdentifier(ident: String): String = ident match {
-    case escapedIdentifier(i) => i
-    case plainIdent => plainIdent
+      case a: ASTNode =>
+        throw new NotImplementedError(
+          s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree:
+             |${dumpTree(a).toString}
+           """.stripMargin)
+    }
   }
 
   val numericAstTypes = Seq(
@@ -1313,51 +1026,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       """.stripMargin)
   }
 
-
-  val explode = "(?i)explode".r
-  def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = {
-    val function = nodes.head
-
-    val attributes = nodes.flatMap {
-      case Token(a, Nil) => a.toLowerCase :: Nil
-      case _ => Nil
+  protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
+    exprs.zipWithIndex.map {
+      case (ne: NamedExpression, _) => ne
+      case (e, i) => Alias(e, s"_c$i")()
     }
+  }
 
-    function match {
-      case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
-        (Explode(nodeToExpr(child)), attributes)
-
-      case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
-        val functionInfo: FunctionInfo =
-          Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
-            sys.error(s"Couldn't find function $functionName"))
-        val functionClassName = functionInfo.getFunctionClass.getName
-
-        (HiveGenericUdtf(
-          new HiveFunctionWrapper(functionClassName),
-          children.map(nodeToExpr)), attributes)
+  /**
+   * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
+   * is equivalent to
+   * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2.
+   * Check the following link for details:
+   *
+https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup
+   *
+   * The bitmask denotes the validity of each grouping expression in a grouping set;
+   * it is also called the grouping id (`GROUPING__ID`, the virtual column in Hive).
+   * e.g. In the superset (k1, k2, k3) (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping ids of
+   * GROUPING SETS (k1, k2) and (k2) are 3 and 2 respectively.
+   */
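To make the bitmask arithmetic concrete, this is the computation the fold in `extractGroupingSet` below performs for the example in the comment (a self-contained sketch, not part of the patch):

```scala
// GROUP BY k1, k2, k3 gives keyMap = Map("k1" -> 0, "k2" -> 1, "k3" -> 2).
val keyMap = Map("k1" -> 0, "k2" -> 1, "k3" -> 2)
def mask(set: Seq[String]): Int =
  set.foldLeft(0)((bitmap, key) => bitmap | 1 << keyMap(key))

assert(mask(Seq("k1", "k2")) == 3) // bits 0 and 1 set
assert(mask(Seq("k2")) == 2)       // bit 1 set
assert(mask(Seq.empty) == 0)       // empty grouping set: no bits set
```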
+  protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
+    val (keyASTs, setASTs) = children.partition( n => n match {
+      case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets
+      case _ => true // grouping keys
+    })
 
-      case a: ASTNode =>
-        throw new NotImplementedError(
-          s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree:
-             |${dumpTree(a).toString}
-           """.stripMargin)
-    }
-  }
+    val keys = keyASTs.map(nodeToExpr).toSeq
+    val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap
 
-  def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
-    : StringBuilder = {
-    node match {
-      case a: ASTNode => builder.append(
-        (" " * indent) + a.getText + " " +
-          a.getLine + ", " +
-          a.getTokenStartIndex + "," +
-          a.getTokenStopIndex + ", " +
-          a.getCharPositionInLine + "\n")
-      case other => sys.error(s"Non ASTNode encountered: $other")
-    }
+    val bitmasks: Seq[Int] = setASTs.map(set => set match {
+      case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
+      case Token("TOK_GROUPING_SETS_EXPRESSION", children) =>
+        children.foldLeft(0)((bitmap, col) => {
+          val colString = col.asInstanceOf[ASTNode].toStringTree()
+          require(keyMap.contains(colString), s"$colString doesn't show up in the GROUP BY list")
+          bitmap | 1 << keyMap(colString)
+        })
+      case _ => sys.error("Expect GROUPING SETS clause")
+    })
 
-    Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1))
-    builder
+    (keys, bitmasks)
   }
 }
+
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index edeab5158df6..c959c349e660 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -163,7 +163,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
   // Hive parser need substituted text. HiveContext.sql() does this but return a DataFrame,
   // while we need a logicalPlan so we cannot reuse that.
 
protected[hive] class HiveQLQueryExecution(hql: String) - extends this.QueryExecution(HiveQl.parseSql(vs.substitute(hiveconf, hql))) { + extends this.QueryExecution(getSQLDialect().parse(vs.substitute(hiveconf, hql))) { def hiveExec(): Seq[String] = runSqlHive(hql) override def toString: String = hql + "\n" + super.toString } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index d960a30e0073..bbf10196c511 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -138,7 +138,8 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { */ def positionTest(name: String, query: String, token: String): Unit = { def parseTree = - Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + Try(quietly(HiveASTNodeUtil.dumpTree(HiveASTNodeUtil.getAst(query)))) + .getOrElse("") test(name) { val error = intercept[AnalysisException] { @@ -162,7 +163,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { val actualStart = error.startPosition.getOrElse { fail( s"start not returned for error on token $token\n" + - HiveQl.dumpTree(HiveQl.getAst(query)) + HiveASTNodeUtil.dumpTree(HiveASTNodeUtil.getAst(query)) ) } assert(expectedStart === actualStart, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 00a69de9e426..7a8308edee57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,7 +33,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = HiveQl.parseSql(analyzeCommand) + val parsed = getSQLDialect().parse(analyzeCommand) val operators = parsed.collect { case a: AnalyzeTable => a case o => o
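
For reviewers tracing the new call path end to end: after this change, parsing flows through `HiveQLDialect` into the renamed converter. A minimal sketch of that flow (illustrative only; `HiveQLDialect` is `private[hive]`, so this would have to live in the `org.apache.spark.sql.hive` package, and the object name is hypothetical):

```scala
package org.apache.spark.sql.hive

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

object ParsePathSketch {
  def demo(): Unit = {
    // SparkSQLParser handles SET/cache statements itself and falls back to
    // ExtendedHiveQlParser, which delegates to HiveQlConverter.createPlan
    // (which in turn uses HiveASTNodeUtil.getAst to produce the Hive AST).
    val dialect = new HiveQLDialect
    val plan: LogicalPlan = dialect.parse("SELECT key, COUNT(*) FROM src GROUP BY key")
    println(plan.treeString)
  }
}
```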