diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index ff2c19034e0b..0d29e1e65837 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -3691,6 +3691,13 @@
     ],
     "sqlState" : "42K0G"
   },
+  "PROTOBUF_NOT_LOADED_SQL_FUNCTIONS_UNUSABLE" : {
+    "message" : [
+      "Cannot call the <functionName> SQL function because the Protobuf data source is not loaded.",
+      "Please restart your job or session with the 'spark-protobuf' package loaded, such as by using the --packages argument on the command line, and then retry your query or command again."
+    ],
+    "sqlState" : "22KD3"
+  },
   "PROTOBUF_TYPE_NOT_SUPPORT" : {
     "message" : [
       "Protobuf type not yet supported: <protobufType>."
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 5233e0688349..ddc9381dc8df 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -2042,6 +2042,166 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
     }
   }
 
+  test("SPARK-49121: from_protobuf and to_protobuf SQL functions") {
+    withTable("protobuf_test_table") {
+      sql(
+        """
+          |CREATE TABLE protobuf_test_table AS
+          |  SELECT named_struct(
+          |    'id', 1L,
+          |    'string_value', 'test_string',
+          |    'int32_value', 32,
+          |    'int64_value', 64L,
+          |    'double_value', CAST(123.456 AS DOUBLE),
+          |    'float_value', CAST(789.01 AS FLOAT),
+          |    'bool_value', true,
+          |    'bytes_value', CAST('sample_bytes' AS BINARY)
+          |  ) AS complex_struct
+          |""".stripMargin)
+
+      val toProtobufSql =
+        s"""
+           |SELECT
+           |  to_protobuf(
+           |    complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', map()
+           |  ) AS protobuf_data
+           |FROM protobuf_test_table
+           |""".stripMargin
+
+      val protobufResult = spark.sql(toProtobufSql).collect()
+      assert(protobufResult != null)
+
+      val fromProtobufSql =
+        s"""
+           |SELECT
+           |  from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', map())
+           |FROM
+           |  ($toProtobufSql)
+           |""".stripMargin
+
+      checkAnswer(
+        spark.sql(fromProtobufSql),
+        Seq(Row(Row(1L, "test_string", 32, 64L, 123.456, 789.01F, true, "sample_bytes".getBytes)))
+      )
+
+      // Negative tests for to_protobuf.
+      checkError(
+        exception = intercept[AnalysisException](sql(
+          s"""
+             |SELECT
+             |  to_protobuf(complex_struct, 42, '$testFileDescFile', map())
+             |FROM protobuf_test_table
+             |""".stripMargin)),
+        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+        parameters = Map(
+          "sqlExpr" -> s"""\"toprotobuf(complex_struct, 42, $testFileDescFile, map())\"""",
+          "msg" -> ("The second argument of the TO_PROTOBUF SQL function must be a constant " +
+            "string representing the Protobuf message name"),
+          "hint" -> ""),
+        queryContext = Array(ExpectedContext(
+          fragment = s"to_protobuf(complex_struct, 42, '$testFileDescFile', map())",
+          start = 10,
+          stop = 153))
+      )
+      checkError(
+        exception = intercept[AnalysisException](sql(
+          s"""
+             |SELECT
+             |  to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())
+             |FROM protobuf_test_table
+             |""".stripMargin)),
+        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+        parameters = Map(
+          "sqlExpr" -> "\"toprotobuf(complex_struct, SimpleMessageJavaTypes, 42, map())\"",
+          "msg" -> ("The third argument of the TO_PROTOBUF SQL function must be a constant " +
+            "string representing the Protobuf descriptor file path"),
+          "hint" -> ""),
+        queryContext = Array(ExpectedContext(
+          fragment = "to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())",
+          start = 10,
+          stop = 73))
+      )
+      checkError(
+        exception = intercept[AnalysisException](sql(
+          s"""
+             |SELECT
+             |  to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)
+             |FROM protobuf_test_table
+             |""".stripMargin)),
+        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+        parameters = Map(
+          "sqlExpr" ->
+            s"""\"toprotobuf(complex_struct, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""",
+          "msg" -> ("The fourth argument of the TO_PROTOBUF SQL function must be a constant " +
+            "map of strings to strings containing the options to use for converting the value " +
+            "to Protobuf format"),
+          "hint" -> ""),
+        queryContext = Array(ExpectedContext(
+          fragment =
+            s"to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)",
+          start = 10,
+          stop = 172))
+      )
+
+      // Negative tests for from_protobuf.
+      checkError(
+        exception = intercept[AnalysisException](sql(
+          s"""
+             |SELECT from_protobuf(protobuf_data, 42, '$testFileDescFile', map())
+             |FROM ($toProtobufSql)
+             |""".stripMargin)),
+        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+        parameters = Map(
+          "sqlExpr" -> s"""\"fromprotobuf(protobuf_data, 42, $testFileDescFile, map())\"""",
+          "msg" -> ("The second argument of the FROM_PROTOBUF SQL function must be a constant " +
+            "string representing the Protobuf message name"),
+          "hint" -> ""),
+        queryContext = Array(ExpectedContext(
+          fragment = s"from_protobuf(protobuf_data, 42, '$testFileDescFile', map())",
+          start = 8,
+          stop = 152))
+      )
+      checkError(
+        exception = intercept[AnalysisException](sql(
+          s"""
+             |SELECT from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())
+             |FROM ($toProtobufSql)
+             |""".stripMargin)),
+        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+        parameters = Map(
+          "sqlExpr" -> "\"fromprotobuf(protobuf_data, SimpleMessageJavaTypes, 42, map())\"",
+          "msg" -> ("The third argument of the FROM_PROTOBUF SQL function must be a constant " +
+            "string representing the Protobuf descriptor file path"),
+          "hint" -> ""),
+        queryContext = Array(ExpectedContext(
+          fragment = "from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())",
+          start = 8,
+          stop = 72))
+      )
+      checkError(
+        exception = intercept[AnalysisException](sql(
+          s"""
+             |SELECT
+             |  from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)
+             |FROM ($toProtobufSql)
+             |""".stripMargin)),
+        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+        parameters = Map(
+          "sqlExpr" ->
+            s"""\"fromprotobuf(protobuf_data, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""",
+          "msg" -> ("The fourth argument of the FROM_PROTOBUF SQL function must be a constant " +
+            "map of strings to strings containing the options to use for converting the value " +
+            "from Protobuf format"),
+          "hint" -> ""),
+        queryContext = Array(ExpectedContext(
+          fragment =
+            s"from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)",
+          start = 10,
+          stop = 173))
+      )
+    }
+  }
+
   def testFromProtobufWithOptions(
       df: DataFrame,
       expectedDf: DataFrame,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 7352d2bf94a0..75dffcf58eae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -869,7 +869,11 @@ object FunctionRegistry {
 
     // Avro
     expression[FromAvro]("from_avro"),
-    expression[ToAvro]("to_avro")
+    expression[ToAvro]("to_avro"),
+
+    // Protobuf
+    expression[FromProtobuf]("from_protobuf"),
+    expression[ToProtobuf]("to_protobuf")
   )
 
   val builtin: SimpleFunctionRegistry = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
new file mode 100644
index 000000000000..29351107a098
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.io.File
+import java.io.FileNotFoundException
+import java.nio.file.NoSuchFileException
+
+import scala.util.control.NonFatal
+
+import org.apache.commons.io.FileUtils
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.{MapType, NullType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+object ProtobufHelper {
+  def readDescriptorFileContent(filePath: String): Array[Byte] = {
+    try {
+      FileUtils.readFileToByteArray(new File(filePath))
+    } catch {
+      case ex: FileNotFoundException =>
+        throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
+      case ex: NoSuchFileException =>
+        throw new RuntimeException(s"Cannot find descriptor file at path: $filePath", ex)
+      case NonFatal(ex) =>
+        throw new RuntimeException(s"Failed to read the descriptor file: $filePath", ex)
+    }
+  }
+}
+
+/**
+ * Converts a binary column of Protobuf format into its corresponding catalyst value.
+ * The Protobuf definition is provided through Protobuf descriptor file.
+ *
+ * @param data
+ *   The Catalyst binary input column.
+ * @param messageName
+ *   The protobuf message name to look for in descriptor file.
+ * @param descFilePath
+ *   The Protobuf descriptor file. This file is usually created using `protoc` with
+ *   `--descriptor_set_out` and `--include_imports` options.
+ * @param options
+ *   the options to use when performing the conversion.
+ * @since 4.0.0
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(data, messageName, descFilePath, options) - Converts a binary Protobuf value into a Catalyst value.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(s, 'Person', '/path/to/descriptor.desc', map()) IS NULL AS result FROM (SELECT NAMED_STRUCT('name', name, 'id', id) AS s FROM VALUES ('John Doe', 1), (NULL, 2) tab(name, id));
+       [false]
+  """,
+  note = """
+    The specified Protobuf schema must match actual schema of the read data, otherwise the behavior
+    is undefined: it may fail or return arbitrary result.
+    To deserialize the data with a compatible and evolved schema, the expected Protobuf schema can be
+    set via the corresponding option.
+  """,
+  group = "misc_funcs",
+  since = "4.0.0"
+)
+// scalastyle:on line.size.limit
+case class FromProtobuf(
+    data: Expression,
+    messageName: Expression,
+    descFilePath: Expression,
+    options: Expression) extends QuaternaryExpression with RuntimeReplaceable {
+  override def first: Expression = data
+  override def second: Expression = messageName
+  override def third: Expression = descFilePath
+  override def fourth: Expression = options
+
+  override def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression,
+      newFourth: Expression): Expression = {
+    copy(data = newFirst, messageName = newSecond, descFilePath = newThird, options = newFourth)
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val messageNameCheck = messageName.dataType match {
+      case _: StringType if messageName.foldable => None
+      case _ =>
+        Some(TypeCheckResult.TypeCheckFailure(
+          "The second argument of the FROM_PROTOBUF SQL function must be a constant string " +
+            "representing the Protobuf message name"))
+    }
+    val descFilePathCheck = descFilePath.dataType match {
+      case _: StringType if descFilePath.foldable => None
+      case _ =>
+        Some(TypeCheckResult.TypeCheckFailure(
+          "The third argument of the FROM_PROTOBUF SQL function must be a constant string " +
+            "representing the Protobuf descriptor file path"))
+    }
+    val optionsCheck = options.dataType match {
+      case MapType(StringType, StringType, _) |
+           MapType(NullType, NullType, _) |
+           _: NullType if options.foldable => None
+      case _ =>
+        Some(TypeCheckResult.TypeCheckFailure(
+          "The fourth argument of the FROM_PROTOBUF SQL function must be a constant map of " +
+            "strings to strings containing the options to use for converting the value from " +
+            "Protobuf format"))
+    }
+    messageNameCheck.getOrElse(
+      descFilePathCheck.getOrElse(
+        optionsCheck.getOrElse(TypeCheckResult.TypeCheckSuccess)
+      )
+    )
+  }
+
+  override lazy val replacement: Expression = {
+    val messageNameValue: String = messageName.eval() match {
+      case null =>
+        throw new IllegalArgumentException("Message name cannot be null")
+      case s: UTF8String =>
+        s.toString
+    }
+    val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
+      case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+      case null => None
+    }
+    val optionsValue: Map[String, String] = options.eval() match {
+      case a: ArrayBasedMapData if a.keyArray.array.nonEmpty =>
+        val keys: Array[String] = a.keyArray.array.map(_.toString)
+        val values: Array[String] = a.valueArray.array.map(_.toString)
+        keys.zip(values).toMap
+      case _ => Map.empty
+    }
+    val constructor = try {
+      Utils.classForName(
+        "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst").getConstructors().head
+    } catch {
+      case _: java.lang.ClassNotFoundException =>
+        throw QueryCompilationErrors.protobufNotLoadedSqlFunctionsUnusable(
+          functionName = "FROM_PROTOBUF")
+    }
+    val expr = constructor.newInstance(data, messageNameValue, descFilePathValue, optionsValue)
+    expr.asInstanceOf[Expression]
+  }
+}
+
+/**
+ * Converts a Catalyst binary input value into its corresponding Protobuf format result.
+ * This is a thin wrapper over the [[CatalystDataToProtobuf]] class to create a SQL function.
+ *
+ * @param data
+ *   The Catalyst binary input column.
+ * @param messageName
+ *   The Protobuf message name.
+ * @param descFilePath
+ *   The Protobuf descriptor file path.
+ * @param options
+ *   The options to use when performing the conversion.
+ * @since 4.0.0
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(child, messageName, descFilePath, options) - Converts a Catalyst binary input value into its corresponding
+    Protobuf format result.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(s, 'Person', '/path/to/descriptor.desc', map('emitDefaultValues', 'true')) IS NULL FROM (SELECT NULL AS s);
+       [true]
+  """,
+  group = "misc_funcs",
+  since = "4.0.0"
+)
+// scalastyle:on line.size.limit
+case class ToProtobuf(
+    data: Expression,
+    messageName: Expression,
+    descFilePath: Expression,
+    options: Expression) extends QuaternaryExpression with RuntimeReplaceable {
+  override def first: Expression = data
+  override def second: Expression = messageName
+  override def third: Expression = descFilePath
+  override def fourth: Expression = options
+
+  override def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression,
+      newFourth: Expression): Expression = {
+    copy(data = newFirst, messageName = newSecond, descFilePath = newThird, options = newFourth)
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val messageNameCheck = messageName.dataType match {
+      case _: StringType if messageName.foldable => None
+      case _ =>
+        Some(TypeCheckResult.TypeCheckFailure(
+          "The second argument of the TO_PROTOBUF SQL function must be a constant string " +
+            "representing the Protobuf message name"))
+    }
+    val descFilePathCheck = descFilePath.dataType match {
+      case _: StringType if descFilePath.foldable => None
+      case _ =>
+        Some(TypeCheckResult.TypeCheckFailure(
+          "The third argument of the TO_PROTOBUF SQL function must be a constant string " +
+            "representing the Protobuf descriptor file path"))
+    }
+    val optionsCheck = options.dataType match {
+      case MapType(StringType, StringType, _) |
+           MapType(NullType, NullType, _) |
+           _: NullType if options.foldable => None
+      case _ =>
+        Some(TypeCheckResult.TypeCheckFailure(
+          "The fourth argument of the TO_PROTOBUF SQL function must be a constant map of " +
+            "strings to strings containing the options to use for converting the value to " +
+            "Protobuf format"))
+    }
+
+    messageNameCheck.getOrElse(
+      descFilePathCheck.getOrElse(
+        optionsCheck.getOrElse(TypeCheckResult.TypeCheckSuccess)
+      )
+    )
+  }
+
+  override lazy val replacement: Expression = {
+    val messageNameValue: String = messageName.eval() match {
+      case null =>
+        throw new IllegalArgumentException("Message name cannot be null")
+      case s: UTF8String =>
+        s.toString
+    }
+    val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
+      case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+      case null => None
+    }
+    val optionsValue: Map[String, String] = options.eval() match {
+      case a: ArrayBasedMapData if a.keyArray.array.nonEmpty =>
+        val keys: Array[String] = a.keyArray.array.map(_.toString)
+        val values: Array[String] = a.valueArray.array.map(_.toString)
+        keys.zip(values).toMap
+      case _ => Map.empty
+    }
+    val constructor = try {
+      Utils.classForName(
+        "org.apache.spark.sql.protobuf.CatalystDataToProtobuf").getConstructors().head
+    } catch {
+      case _: java.lang.ClassNotFoundException =>
+        throw QueryCompilationErrors.protobufNotLoadedSqlFunctionsUnusable(
+          functionName = "TO_PROTOBUF")
+    }
+    val expr = constructor.newInstance(data, messageNameValue, descFilePathValue, optionsValue)
+    expr.asInstanceOf[Expression]
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index e52ae63c4110..9db65a57016f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -4143,6 +4143,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
     )
   }
 
+  def protobufNotLoadedSqlFunctionsUnusable(functionName: String): Throwable = {
+    new AnalysisException(
+      errorClass = "PROTOBUF_NOT_LOADED_SQL_FUNCTIONS_UNUSABLE",
+      messageParameters = Map("functionName" -> functionName)
+    )
+  }
+
   def operationNotSupportClusteringError(operation: String): Throwable = {
     new AnalysisException(
       errorClass = "CLUSTERING_NOT_SUPPORTED",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
index 4bc6ab0f6e0f..2342722c0bb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
@@ -628,7 +628,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       "shuffle",
       // other functions which are not yet supported
       "to_avro",
-      "from_avro"
+      "from_avro",
+      "to_protobuf",
+      "from_protobuf"
     )
 
     for (funInfo <- funInfos.filter(f => !toSkip.contains(f.getName))) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala
index 443597f10056..8c0231fddf39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala
@@ -117,10 +117,12 @@ class ExpressionsSchemaSuite extends QueryTest with SharedSparkSession {
       // Note: We need to filter out the commands that set the parameters, such as:
       // SET spark.sql.parser.escapedStringLiterals=true
       example.split("  > ").tail.filterNot(_.trim.startsWith("SET")).take(1).foreach {
-        case _ if funcName == "from_avro" || funcName == "to_avro" =>
-          // Skip running the example queries for the from_avro and to_avro functions because
-          // these functions dynamically load the AvroDataToCatalyst or CatalystDataToAvro classes
-          // which are not available in this test.
+        case _ if funcName == "from_avro" || funcName == "to_avro" ||
+            funcName == "from_protobuf" || funcName == "to_protobuf" =>
+          // Skip running the example queries for the from_avro, to_avro, from_protobuf and
+          // to_protobuf functions because these functions dynamically load the
+          // AvroDataToCatalyst or CatalystDataToAvro classes which are not available in this
+          // test.
         case exampleRe(sql, _) =>
           val df = spark.sql(sql)
           val escapedSql = sql.replaceAll("\\|", "&#124;")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index bf5d1b24af21..20030908f445 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -228,6 +228,8 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
       // Requires dynamic class loading not available in this test suite.
       "org.apache.spark.sql.catalyst.expressions.FromAvro",
       "org.apache.spark.sql.catalyst.expressions.ToAvro",
+      "org.apache.spark.sql.catalyst.expressions.FromProtobuf",
+      "org.apache.spark.sql.catalyst.expressions.ToProtobuf",
       classOf[CurrentUser].getName,
       // The encrypt expression includes a random initialization vector to its encrypted result
       classOf[AesEncrypt].getName)
diff --git a/sql/gen-sql-functions-docs.py b/sql/gen-sql-functions-docs.py
index dc48a5a6155e..bb813cffb012 100644
--- a/sql/gen-sql-functions-docs.py
+++ b/sql/gen-sql-functions-docs.py
@@ -163,8 +163,8 @@ def _make_pretty_examples(jspark, infos):
 
     pretty_output = ""
     for info in infos:
-        if (info.examples.startswith("\n    Examples:")
-                and info.name.lower() not in ("from_avro", "to_avro")):
+        if (info.examples.startswith("\n    Examples:") and info.name.lower() not in
+                ("from_avro", "to_avro", "from_protobuf", "to_protobuf")):
            output = []
            output.append("-- %s" % info.name)
            query_examples = filter(lambda x: x.startswith("      > "), info.examples.split("\n"))
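For readers trying out the patch, here is a minimal sketch (not part of the diff) of how the newly registered SQL functions can be exercised from a Spark application once the spark-protobuf package is on the classpath. The descriptor path /tmp/messages.desc and the SimpleMessageJavaTypes message name are illustrative placeholders borrowed from the test above; the descriptor set is assumed to have been produced with protoc --descriptor_set_out --include_imports, as described in the new scaladoc.

// Hypothetical round-trip through the new SQL functions; paths and message names are placeholders.
import org.apache.spark.sql.SparkSession

object ProtobufSqlFunctionsExample {
  def main(args: Array[String]): Unit = {
    // Assumes the session was started with the spark-protobuf package loaded (e.g. via --packages);
    // otherwise the new PROTOBUF_NOT_LOADED_SQL_FUNCTIONS_UNUSABLE error is raised.
    val spark = SparkSession.builder().appName("protobuf-sql-functions-example").getOrCreate()

    // Serialize a struct to Protobuf bytes with to_protobuf, then decode it back with
    // from_protobuf, mirroring the positive test in ProtobufFunctionsSuite above.
    val roundTrip = spark.sql(
      """
        |SELECT from_protobuf(
        |         to_protobuf(named_struct('id', 1L, 'string_value', 'test_string'),
        |                     'SimpleMessageJavaTypes', '/tmp/messages.desc', map()),
        |         'SimpleMessageJavaTypes', '/tmp/messages.desc', map()) AS decoded
        |""".stripMargin)
    roundTrip.show(truncate = false)

    spark.stop()
  }
}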