diff --git a/.github/labeler.yml b/.github/labeler.yml index cf1d2a7117203..84dfa35f2627e 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -155,3 +155,6 @@ CONNECT: - "connector/connect/**/*" - "**/sql/sparkconnect/**/*" - "python/pyspark/sql/**/connect/**/*" +PROTOBUF: + - "connector/protobuf/**/*" + - "python/pyspark/sql/protobuf/**/*" \ No newline at end of file diff --git a/R/check-cran.sh b/R/check-cran.sh index 22c8f423cfd12..4123361f5e285 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/R/create-docs.sh b/R/create-docs.sh index 4867fd99e647c..3deaefd0659dc 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/R/create-rd.sh b/R/create-rd.sh index 72a932c175c95..1f0527458f2f0 100755 --- a/R/create-rd.sh +++ b/R/create-rd.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/R/find-r.sh b/R/find-r.sh index 690acc083af91..f1a5026911a7f 100755 --- a/R/find-r.sh +++ b/R/find-r.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/R/install-dev.sh b/R/install-dev.sh index 9fbc999f2e805..7df21c6c5ec9a 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/R/install-source-package.sh b/R/install-source-package.sh index 8de3569d1d482..0a2a5fe00f31f 100755 --- a/R/install-source-package.sh +++ b/R/install-source-package.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/R/run-tests.sh b/R/run-tests.sh index ca5b661127b53..90a60eda03871 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 04adaeed7ac61..fc5e881dd0dfd 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -63,3 +63,8 @@ if [ -z "$SPARK_SCALA_VERSION" ]; then export SPARK_SCALA_VERSION=${SCALA_VERSION_2} fi fi + +# Append jline option to enable the Beeline process to run in background. +if [[ ( ! $(ps -o stat= -p $$) =~ "+" ) && ! 
( -p /dev/stdin ) ]]; then + export SPARK_BEELINE_OPTS="$SPARK_BEELINE_OPTS -Djline.terminal=jline.UnsupportedTerminal" +fi diff --git a/bin/sparkR b/bin/sparkR index 29ab10df8ab6d..8ecc755839fe3 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/binder/postBuild b/binder/postBuild index 733eafe175ef0..34ead09f692f9 100644 --- a/binder/postBuild +++ b/binder/postBuild @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java b/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java index df5d6d73f2f14..b2a57060fc2d9 100644 --- a/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java +++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/SparkAvroKeyOutputFormat.java @@ -25,6 +25,7 @@ import org.apache.avro.file.CodecFactory; import org.apache.avro.file.DataFileWriter; import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; import org.apache.avro.mapred.AvroKey; import org.apache.avro.mapreduce.AvroKeyOutputFormat; @@ -53,7 +54,7 @@ protected RecordWriter, NullWritable> create( CodecFactory compressionCodec, OutputStream outputStream, int syncInterval) throws IOException { - return new SparkAvroKeyRecordWriter( + return new SparkAvroKeyRecordWriter<>( writerSchema, dataModel, compressionCodec, outputStream, syncInterval, metadata); } } @@ -72,7 +73,7 @@ class SparkAvroKeyRecordWriter extends RecordWriter, NullWritable> OutputStream outputStream, int syncInterval, Map metadata) throws IOException { - this.mAvroFileWriter = new DataFileWriter(dataModel.createDatumWriter(writerSchema)); + this.mAvroFileWriter = new DataFileWriter<>(new GenericDatumWriter<>(writerSchema, dataModel)); for (Map.Entry entry : metadata.entrySet()) { this.mAvroFileWriter.setMeta(entry.getKey(), entry.getValue()); } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index a02bb067dcc4b..f8d0ac08d0073 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1075,7 +1075,7 @@ abstract class AvroSuite .save(s"$tempDir/${UUID.randomUUID()}") }.getMessage assert(message.contains("Caused by: java.lang.NullPointerException: ")) - assert(message.contains("null in string in field Name")) + assert(message.contains("null value for (non-nullable) string at test_schema.Name")) } } diff --git a/connector/connect/dev/generate_protos.sh b/connector/connect/dev/generate_protos.sh index 204beda6aa971..9457e7b33edd5 100755 --- a/connector/connect/dev/generate_protos.sh +++ b/connector/connect/dev/generate_protos.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash + # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. 
See the NOTICE file distributed with diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto index 791b1b5887b74..4b5a81d2a568c 100644 --- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto @@ -35,6 +35,7 @@ message Expression { UnresolvedFunction unresolved_function = 3; ExpressionString expression_string = 4; UnresolvedStar unresolved_star = 5; + Alias alias = 6; } message Literal { @@ -166,4 +167,9 @@ message Expression { string name = 1; DataType type = 2; } + + message Alias { + Expression expr = 1; + string name = 2; + } } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 3ccf71c26b744..80d6e77c9fc45 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -40,6 +40,11 @@ package object dsl { .build()) .build() } + + implicit class DslExpression(val expr: proto.Expression) { + def as(alias: String): proto.Expression = proto.Expression.newBuilder().setAlias( + proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr)).build() + } } object plans { // scalastyle:ignore diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 66560f5e62f6f..5ad95a6b516ab 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -24,7 +24,7 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.types._ @@ -132,6 +132,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { transformUnresolvedExpression(exp) case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION => transformScalarFunction(exp.getUnresolvedFunction) + case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias) case _ => throw InvalidPlanInput() } } @@ -208,6 +209,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { } } + private def transformAlias(alias: proto.Expression.Alias): Expression = { + Alias(transformExpression(alias.getExpr), alias.getName)() + } + private def transformUnion(u: proto.Union): LogicalPlan = { assert(u.getInputsCount == 2, "Union must have 2 inputs") val plan = logical.Union(transformRelation(u.getInputs(0)), transformRelation(u.getInputs(1))) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 441a3a9f1e41f..510b54cd25084 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -81,6 +81,15 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } } + test("column alias") { + val connectPlan = { + import org.apache.spark.sql.connect.dsl.expressions._ + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.select("id".protoAttr.as("id2"))) + } + val sparkPlan = sparkTestRelation.select($"id".as("id2")) + } + test("Aggregate with more than 1 grouping expressions") { val connectPlan = { import org.apache.spark.sql.connect.dsl.expressions._ diff --git a/connector/docker/build b/connector/docker/build index 253a2fc8dd8e7..de83c7d7611dc 100755 --- a/connector/docker/build +++ b/connector/docker/build @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/docker/spark-test/build b/connector/docker/spark-test/build index 6f9e19743370b..55dff4754b000 100755 --- a/connector/docker/spark-test/build +++ b/connector/docker/spark-test/build @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/docker/spark-test/master/default_cmd b/connector/docker/spark-test/master/default_cmd index 96a36cd0bb682..6865ca41b894f 100755 --- a/connector/docker/spark-test/master/default_cmd +++ b/connector/docker/spark-test/master/default_cmd @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/docker/spark-test/worker/default_cmd b/connector/docker/spark-test/worker/default_cmd index 2401f5565aa0b..1f2aac95ed699 100755 --- a/connector/docker/spark-test/worker/default_cmd +++ b/connector/docker/spark-test/worker/default_cmd @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml new file mode 100644 index 0000000000000..0515f128b8d63 --- /dev/null +++ b/connector/protobuf/pom.xml @@ -0,0 +1,115 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.12 + 3.4.0-SNAPSHOT + ../../pom.xml + + + spark-protobuf_2.12 + + protobuf + 3.21.1 + + jar + Spark Protobuf + https://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + compile + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + com.google.protobuf:* + + + + + com.google.protobuf + ${spark.shade.packageName}.spark-protobuf.protobuf + + com.google.protobuf.** + + + + + + + + diff --git 
a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala new file mode 100644 index 0000000000000..145100268c232 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import com.google.protobuf.DynamicMessage + +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.types.{BinaryType, DataType} + +private[protobuf] case class CatalystDataToProtobuf( + child: Expression, + descFilePath: String, + messageName: String) + extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val protoType = + ProtobufUtils.buildDescriptor(descFilePath, messageName) + + @transient private lazy val serializer = + new ProtobufSerializer(child.dataType, protoType, child.nullable) + + override def nullSafeEval(input: Any): Any = { + val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage] + dynamicMessage.toByteArray + } + + override def prettyName: String = "to_protobuf" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => s"(byte[]) $expr.nullSafeEval($input)") + } + + override protected def withNewChildInternal(newChild: Expression): CatalystDataToProtobuf = + copy(child = newChild) +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala new file mode 100644 index 0000000000000..f08f876799723 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.google.protobuf.DynamicMessage + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.protobuf.utils.{ProtobufOptions, ProtobufUtils, SchemaConverters} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType} + +private[protobuf] case class ProtobufDataToCatalyst( + child: Expression, + descFilePath: String, + messageName: String, + options: Map[String, String]) + extends UnaryExpression + with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = { + val dt = SchemaConverters.toSqlType(messageDescriptor).dataType + parseMode match { + // With PermissiveMode, the output Catalyst row might contain columns of null values for + // corrupt records, even if some of the columns are not nullable in the user-provided schema. + // Therefore we force the schema to be all nullable here. + case PermissiveMode => dt.asNullable + case _ => dt + } + } + + override def nullable: Boolean = true + + private lazy val protobufOptions = ProtobufOptions(options) + + @transient private lazy val messageDescriptor = + ProtobufUtils.buildDescriptor(descFilePath, messageName) + + @transient private lazy val fieldsNumbers = + messageDescriptor.getFields.asScala.map(f => f.getNumber) + + @transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType) + + @transient private var result: DynamicMessage = _ + + @transient private lazy val parseMode: ParseMode = { + val mode = protobufOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(unacceptableModeMessage(mode.name)) + } + mode + } + + private def unacceptableModeMessage(name: String): String = { + s"from_protobuf() doesn't support the $name mode. " + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}." + } + + @transient private lazy val nullResultRow: Any = dataType match { + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + for (i <- 0 until st.length) { + resultRow.setNullAt(i) + } + resultRow + + case _ => + null + } + + private def handleException(e: Throwable): Any = { + parseMode match { + case PermissiveMode => + nullResultRow + case FailFastMode => + throw new SparkException( + "Malformed records are detected in record parsing. " + + s"Current parse Mode: ${FailFastMode.name}. 
To process malformed records as null " + + "result, try setting the option 'mode' as 'PERMISSIVE'.", + e) + case _ => + throw new AnalysisException(unacceptableModeMessage(parseMode.name)) + } + } + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + try { + result = DynamicMessage.parseFrom(messageDescriptor, binary) + val unknownFields = result.getUnknownFields + if (!unknownFields.asMap().isEmpty) { + unknownFields.asMap().keySet().asScala.map { number => + { + if (fieldsNumbers.contains(number)) { + return handleException( + new Throwable(s"Type mismatch encountered for field:" + + s" ${messageDescriptor.getFields.get(number)}")) + } + } + } + } + val deserialized = deserializer.deserialize(result) + assert( + deserialized.isDefined, + "Protobuf deserializer cannot return an empty result because filters are not pushed down") + deserialized.get + } catch { + // There could be multiple possible exceptions here, e.g. java.io.IOException, + // ProtoRuntimeException, ArrayIndexOutOfBoundsException, etc. + // To make it simple, catch all the exceptions here. + case NonFatal(e) => + handleException(e) + } + } + + override def prettyName: String = "from_protobuf" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + nullSafeCodeGen( + ctx, + ev, + eval => { + val result = ctx.freshName("result") + val dt = CodeGenerator.boxedType(dataType) + s""" + $dt $result = ($dt) $expr.nullSafeEval($eval); + if ($result == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = $result; + } + """ + }) + } + + override protected def withNewChildInternal(newChild: Expression): ProtobufDataToCatalyst = + copy(child = newChild) +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala new file mode 100644 index 0000000000000..0403b741ebfa7 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.protobuf + +import java.util.concurrent.TimeUnit + +import com.google.protobuf.{ByteString, DynamicMessage, Message} +import com.google.protobuf.Descriptors._ +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.ProtobufUtils.ProtoMatchedField +import org.apache.spark.sql.protobuf.utils.ProtobufUtils.toFieldStr +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +private[sql] class ProtobufDeserializer( + rootDescriptor: Descriptor, + rootCatalystType: DataType, + filters: StructFilters) { + + def this(rootDescriptor: Descriptor, rootCatalystType: DataType) = { + this(rootDescriptor, rootCatalystType, new NoopFilters) + } + + private val converter: Any => Option[InternalRow] = + try { + rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (_: Any) => Some(InternalRow.empty) + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val applyFilters = filters.skipRow(resultRow, _) + val writer = getRecordWriter(rootDescriptor, st, Nil, Nil, applyFilters) + (data: Any) => { + val record = data.asInstanceOf[DynamicMessage] + val skipRow = writer(fieldUpdater, record) + if (skipRow) None else Some(resultRow) + } + } + } catch { + case ise: IncompatibleSchemaException => + throw new IncompatibleSchemaException( + s"Cannot convert Protobuf type ${rootDescriptor.getName} " + + s"to SQL type ${rootCatalystType.sql}.", + ise) + } + + def deserialize(data: Message): Option[InternalRow] = converter(data) + + private def newArrayWriter( + protoField: FieldDescriptor, + protoPath: Seq[String], + catalystPath: Seq[String], + elementType: DataType, + containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = { + + val protoElementPath = protoPath :+ "element" + val elementWriter = + newWriter(protoField, elementType, protoElementPath, catalystPath :+ "element") + (updater, ordinal, value) => + val collection = value.asInstanceOf[java.util.Collection[Any]] + val result = createArrayData(elementType, collection.size()) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + val iterator = collection.iterator() + while (iterator.hasNext) { + val element = iterator.next() + if (element == null) { + if (!containsNull) { + throw QueryCompilationErrors.nullableArrayOrMapElementError(protoElementPath) + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + } + + private def newMapWriter( + protoType: FieldDescriptor, + protoPath: Seq[String], + catalystPath: Seq[String], + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = { + val keyField = protoType.getMessageType.getFields.get(0) + val valueField = protoType.getMessageType.getFields.get(1) + val keyWriter = newWriter(keyField, keyType, protoPath :+ "key", 
catalystPath :+ "key") + val valueWriter = + newWriter(valueField, valueType, protoPath :+ "value", catalystPath :+ "value") + (updater, ordinal, value) => + if (value != null) { + val messageList = value.asInstanceOf[java.util.List[com.google.protobuf.Message]] + val valueArray = createArrayData(valueType, messageList.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val keyArray = createArrayData(keyType, messageList.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + var i = 0 + messageList.forEach { field => + { + keyWriter(keyUpdater, i, field.getField(keyField)) + if (field.getField(valueField) == null) { + if (!valueContainsNull) { + throw QueryCompilationErrors.nullableArrayOrMapElementError(protoPath) + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, field.getField(valueField)) + } + } + i += 1 + } + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + } + } + + /** + * Creates a writer to write Protobuf values to Catalyst values at the given ordinal with the + * given updater. + */ + private def newWriter( + protoType: FieldDescriptor, + catalystType: DataType, + protoPath: Seq[String], + catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = { + val errorPrefix = s"Cannot convert Protobuf ${toFieldStr(protoPath)} to " + + s"SQL ${toFieldStr(catalystPath)} because " + val incompatibleMsg = errorPrefix + + s"schema is incompatible (protoType = ${protoType} ${protoType.toProto.getLabel} " + + s"${protoType.getJavaType} ${protoType.getType}, sqlType = ${catalystType.sql})" + + (protoType.getJavaType, catalystType) match { + + case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of Protobuf provide primitive accessors. 
+ case (BOOLEAN, BooleanType) => + (updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => + (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, ByteType) => + (updater, ordinal, value) => updater.setByte(ordinal, value.asInstanceOf[Byte]) + + case (INT, ShortType) => + (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short]) + + case (BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING, + ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated => + newArrayWriter(protoType, protoPath, catalystPath, dataType, containsNull) + + case (LONG, LongType) => + (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (FLOAT, FloatType) => + (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => + (updater, ordinal, value) => updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => + (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + } + updater.set(ordinal, str) + + case (BYTE_STRING, BinaryType) => + (updater, ordinal, value) => + val byte_array = value match { + case s: ByteString => s.toByteArray + case _ => throw new Exception("Invalid ByteString format") + } + updater.set(ordinal, byte_array) + + case (MESSAGE, MapType(keyType, valueType, valueContainsNull)) => + newMapWriter(protoType, protoPath, catalystPath, keyType, valueType, valueContainsNull) + + case (MESSAGE, TimestampType) => + (updater, ordinal, value) => + val secondsField = protoType.getMessageType.getFields.get(0) + val nanoSecondsField = protoType.getMessageType.getFields.get(1) + val message = value.asInstanceOf[DynamicMessage] + val seconds = message.getField(secondsField).asInstanceOf[Long] + val nanoSeconds = message.getField(nanoSecondsField).asInstanceOf[Int] + val micros = DateTimeUtils.millisToMicros(seconds * 1000) + updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds)) + + case (MESSAGE, DayTimeIntervalType(startField, endField)) => + (updater, ordinal, value) => + val secondsField = protoType.getMessageType.getFields.get(0) + val nanoSecondsField = protoType.getMessageType.getFields.get(1) + val message = value.asInstanceOf[DynamicMessage] + val seconds = message.getField(secondsField).asInstanceOf[Long] + val nanoSeconds = message.getField(nanoSecondsField).asInstanceOf[Int] + val micros = DateTimeUtils.millisToMicros(seconds * 1000) + updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds)) + + case (MESSAGE, st: StructType) => + val writeRecord = getRecordWriter( + protoType.getMessageType, + st, + protoPath, + catalystPath, + applyFilters = _ => false) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) + updater.set(ordinal, row) + + case (MESSAGE, ArrayType(st: StructType, containsNull)) => + newArrayWriter(protoType, protoPath, catalystPath, st, containsNull) + + case (ENUM, StringType) => + (updater, ordinal, value) => updater.set(ordinal, UTF8String.fromString(value.toString)) + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + + private def getRecordWriter( + protoType: Descriptor, + catalystType: StructType, + protoPath: Seq[String], + catalystPath: Seq[String], + applyFilters: Int => Boolean): (CatalystDataUpdater, 
DynamicMessage) => Boolean = { + + val protoSchemaHelper = + new ProtobufUtils.ProtoSchemaHelper(protoType, catalystType, protoPath, catalystPath) + + // TODO revisit validation of protobuf-catalyst fields. + // protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + + var i = 0 + val (validFieldIndexes, fieldWriters) = protoSchemaHelper.matchedFields + .map { case ProtoMatchedField(catalystField, ordinal, protoField) => + val baseWriter = newWriter( + protoField, + catalystField.dataType, + protoPath :+ protoField.getName, + catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + i += 1 + (protoField, fieldWriter) + } + .toArray + .unzip + + (fieldUpdater, record) => { + var i = 0 + var skipRow = false + while (i < validFieldIndexes.length && !skipRow) { + val field = validFieldIndexes(i) + val value = if (field.isRepeated || field.hasDefaultValue || record.hasField(field)) { + record.getField(field) + } else null + fieldWriters(i)(fieldUpdater, value) + skipRow = applyFilters(i) + i += 1 + } + skipRow + } + } + + // TODO: All of the code below this line is same between protobuf and avro, it can be shared. + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } + +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala new file mode 100644 index 0000000000000..5d9af92c5c077 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{Duration, DynamicMessage, Timestamp} +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.ProtobufUtils.{toFieldStr, ProtoMatchedField} +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in Protobuf format. + */ +private[sql] class ProtobufSerializer( + rootCatalystType: DataType, + rootDescriptor: Descriptor, + nullable: Boolean) + extends Logging { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val baseConverter = + try { + rootCatalystType match { + case st: StructType => + newStructConverter(st, rootDescriptor, Nil, Nil).asInstanceOf[Any => Any] + } + } catch { + case ise: IncompatibleSchemaException => + throw new IncompatibleSchemaException( + s"Cannot convert SQL type ${rootCatalystType.sql} to Protobuf type " + + s"${rootDescriptor.getName}.", + ise) + } + if (nullable) { (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private def newConverter( + catalystType: DataType, + fieldDescriptor: FieldDescriptor, + catalystPath: Seq[String], + protoPath: Seq[String]): Converter = { + val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " + + s"to Protobuf ${toFieldStr(protoPath)} because " + (catalystType, fieldDescriptor.getJavaType) match { + case (NullType, _) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => { + getter.getInt(ordinal) + } + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (StringType, ENUM) => + val enumSymbols: Set[String] = + fieldDescriptor.getEnumType.getValues.asScala.map(e => e.toString).toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + errorPrefix + + s""""$data" cannot be written since it's not defined in enum 
""" + + enumSymbols.mkString("\"", "\", \"", "\"")) + } + fieldDescriptor.getEnumType.findValueByName(data) + case (StringType, STRING) => + (getter, ordinal) => { + String.valueOf(getter.getUTF8String(ordinal)) + } + + case (BinaryType, BYTE_STRING) => + (getter, ordinal) => getter.getBinary(ordinal) + + case (DateType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (TimestampType, MESSAGE) => + (getter, ordinal) => + val millis = DateTimeUtils.microsToMillis(getter.getLong(ordinal)) + Timestamp.newBuilder() + .setSeconds((millis / 1000)) + .setNanos(((millis % 1000) * 1000000).toInt) + .build() + + case (ArrayType(et, containsNull), _) => + val elementConverter = + newConverter(et, fieldDescriptor, catalystPath :+ "element", protoPath :+ "element") + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // Protobuf writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (st: StructType, MESSAGE) => + val structConverter = + newStructConverter(st, fieldDescriptor.getMessageType, catalystPath, protoPath) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case (MapType(kt, vt, valueContainsNull), MESSAGE) => + var keyField: FieldDescriptor = null + var valueField: FieldDescriptor = null + fieldDescriptor.getMessageType.getFields.asScala.map { field => + field.getName match { + case "key" => + keyField = field + case "value" => + valueField = field + } + } + + val keyConverter = newConverter(kt, keyField, catalystPath :+ "key", protoPath :+ "key") + val valueConverter = + newConverter(vt, valueField, catalystPath :+ "value", protoPath :+ "value") + + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val list = new java.util.ArrayList[DynamicMessage]() + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < len) { + val result = DynamicMessage.newBuilder(fieldDescriptor.getMessageType) + if (valueContainsNull && valueArray.isNullAt(i)) { + result.setField(keyField, keyConverter(keyArray, i)) + result.setField(valueField, valueField.getDefaultValue) + } else { + result.setField(keyField, keyConverter(keyArray, i)) + result.setField(valueField, valueConverter(valueArray, i)) + } + list.add(result.build()) + i += 1 + } + list + + case (DayTimeIntervalType(startField, endField), MESSAGE) => + (getter, ordinal) => + val dayTimeIntervalString = + IntervalUtils.toDayTimeIntervalString(getter.getLong(ordinal) + , ANSI_STYLE, startField, endField) + val calendarInterval = IntervalUtils.fromIntervalString(dayTimeIntervalString) + + val millis = DateTimeUtils.microsToMillis(calendarInterval.microseconds) + val duration = Duration.newBuilder() + .setSeconds((millis / 1000)) + .setNanos(((millis % 1000) * 1000000).toInt) + + if (duration.getSeconds < 0 && duration.getNanos > 0) { + duration.setSeconds(duration.getSeconds + 1) + duration.setNanos(duration.getNanos - 1000000000) + } else if (duration.getSeconds > 0 && duration.getNanos < 0) { + duration.setSeconds(duration.getSeconds - 1) + duration.setNanos(duration.getNanos + 1000000000) + } + duration.build() + + case 
_ => + throw new IncompatibleSchemaException( + errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, " + + s"protoType = ${fieldDescriptor.getJavaType})") + } + } + + private def newStructConverter( + catalystStruct: StructType, + descriptor: Descriptor, + catalystPath: Seq[String], + protoPath: Seq[String]): InternalRow => DynamicMessage = { + + val protoSchemaHelper = + new ProtobufUtils.ProtoSchemaHelper(descriptor, catalystStruct, protoPath, catalystPath) + + protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + protoSchemaHelper.validateNoExtraRequiredProtoFields() + + val (protoIndices, fieldConverters: Array[Converter]) = protoSchemaHelper.matchedFields + .map { case ProtoMatchedField(catalystField, _, protoField) => + val converter = newConverter( + catalystField.dataType, + protoField, + catalystPath :+ catalystField.name, + protoPath :+ protoField.getName) + (protoField, converter) + } + .toArray + .unzip + + val numFields = catalystStruct.length + row: InternalRow => + val result = DynamicMessage.newBuilder(descriptor) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + if (!protoIndices(i).isRepeated() && + protoIndices(i).getJavaType() != FieldDescriptor.JavaType.MESSAGE && + protoIndices(i).isRequired()) { + result.setField(protoIndices(i), protoIndices(i).getDefaultValue()) + } + } else { + result.setField(protoIndices(i), fieldConverters(i).apply(row, i)) + } + i += 1 + } + result.build() + } +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala new file mode 100644 index 0000000000000..283d1ca8c412c --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column + +// scalastyle:off: object.name +object functions { +// scalastyle:on: object.name + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. The + * specified schema must match actual schema of the read data, otherwise the behavior is + * undefined: it may fail or return arbitrary result. To deserialize the data with a compatible + * and evolved schema, the expected Protobuf schema can be set via the option protoSchema. + * + * @param data + * the binary column. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @param messageName + * the protobuf message name to look for in descriptorFile. 
+ * @since 3.4.0 + */ + @Experimental + def from_protobuf( + data: Column, + descFilePath: String, + messageName: String, + options: java.util.Map[String, String]): Column = { + new Column( + ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap)) + } + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. The + * specified schema must match actual schema of the read data, otherwise the behavior is + * undefined: it may fail or return arbitrary result. To deserialize the data with a compatible + * and evolved schema, the expected Protobuf schema can be set via the option protoSchema. + * + * @param data + * the binary column. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @param messageName + * the protobuf MessageName to look for in descriptorFile. + * @since 3.4.0 + */ + @Experimental + def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = { + new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty)) + } + + /** + * Converts a column into binary of protobuf format. + * + * @param data + * the data column. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @param messageName + * the protobuf MessageName to look for in descriptorFile. + * @since 3.4.0 + */ + @Experimental + def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = { + new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName)) + } +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala new file mode 100644 index 0000000000000..82cdc6b9c5816 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +package object protobuf { + protected[protobuf] object ScalaReflectionLock +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala new file mode 100644 index 0000000000000..1cece0d7966e5 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf.utils + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} + +/** + * Options for Protobuf Reader and Writer stored in case insensitive manner. + */ +private[sql] class ProtobufOptions( + @transient val parameters: CaseInsensitiveMap[String], + @transient val conf: Configuration) + extends FileSourceOptions(parameters) + with Logging { + + def this(parameters: Map[String, String], conf: Configuration) = { + this(CaseInsensitiveMap(parameters), conf) + } + + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) +} + +private[sql] object ProtobufOptions { + def apply(parameters: Map[String, String]): ProtobufOptions = { + val hadoopConf = SparkSession.getActiveSession + .map(_.sessionState.newHadoopConf()) + .getOrElse(new Configuration()) + new ProtobufOptions(CaseInsensitiveMap(parameters), hadoopConf) + } +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala new file mode 100644 index 0000000000000..5ad043142a2d2 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.protobuf.utils + +import java.io.{BufferedInputStream, FileInputStream, IOException} +import java.util.Locale + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException} +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ + +private[sql] object ProtobufUtils extends Logging { + + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Protobuf field. */ + private[sql] case class ProtoMatchedField( + catalystField: StructField, + catalystPosition: Int, + fieldDescriptor: FieldDescriptor) + + /** + * Helper class to perform field lookup/matching on Protobuf schemas. + * + * This will match `descriptor` against `catalystSchema`, attempting to find a matching field in + * the Protobuf descriptor for each field in the Catalyst schema and vice-versa, respecting + * settings for case sensitivity. The match results can be accessed using the getter methods. + * + * @param descriptor + * The descriptor in which to search for fields. Must be of type Descriptor. + * @param catalystSchema + * The Catalyst schema to use for matching. + * @param protoPath + * The seq of parent field names leading to `protoSchema`. + * @param catalystPath + * The seq of parent field names leading to `catalystSchema`. + */ + class ProtoSchemaHelper( + descriptor: Descriptor, + catalystSchema: StructType, + protoPath: Seq[String], + catalystPath: Seq[String]) { + if (descriptor.getName == null) { + throw new IncompatibleSchemaException( + s"Attempting to treat ${descriptor.getName} as a RECORD, " + + s"but it was: ${descriptor.getContainingType}") + } + + private[this] val protoFieldArray = descriptor.getFields.asScala.toArray + private[this] val fieldMap = descriptor.getFields.asScala + .groupBy(_.getName.toLowerCase(Locale.ROOT)) + .mapValues(_.toSeq) // toSeq needed for scala 2.13 + + /** The fields which have matching equivalents in both Protobuf and Catalyst schemas. */ + val matchedFields: Seq[ProtoMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getFieldByName(sqlField.name).map(ProtoMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Protobuf field, + * throwing [[IncompatibleSchemaException]] if such extra fields are found. If + * `ignoreNullable` is false, consider nullable Catalyst fields to be eligible to be an extra + * field; otherwise, ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.fields.foreach { sqlField => + if (getFieldByName(sqlField.name).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Protobuf schema") + } + } + + /** + * Validate that there are no Protobuf fields which don't have a matching Catalyst field, + * throwing [[IncompatibleSchemaException]] if such extra fields are found. Only required + * (non-nullable) fields are checked; nullable fields are ignored. 
+ */ + def validateNoExtraRequiredProtoFields(): Unit = { + val extraFields = protoFieldArray.toSet -- matchedFields.map(_.fieldDescriptor) + extraFields.filterNot(isNullable).foreach { extraField => + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(protoPath :+ extraField.getName())} in Protobuf schema " + + "but there is no match in the SQL schema") + } + } + + /** + * Extract a single field from the contained Protobuf schema which has the desired field name, + * performing the matching with proper case sensitivity according to SQLConf.resolver. + * + * @param name + * The name of the field to search for. + * @return + * `Some(match)` if a matching Protobuf field is found, otherwise `None`. + */ + private[protobuf] def getFieldByName(name: String): Option[FieldDescriptor] = { + + // get candidates, ignoring case of field name + val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty) + + // search candidates, taking into account case sensitivity settings + candidates.filter(f => SQLConf.get.resolver(f.getName(), name)) match { + case Seq(protoField) => Some(protoField) + case Seq() => None + case matches => + throw new IncompatibleSchemaException( + s"Searching for '$name' in " + + s"Protobuf schema at ${toFieldStr(protoPath)} gave ${matches.size} matches. " + + s"Candidates: " + matches.map(_.getName()).mkString("[", ", ", "]")) + } + } + } + + def buildDescriptor(descFilePath: String, messageName: String): Descriptor = { + val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath) + var result: Descriptors.Descriptor = null; + + for (descriptor <- fileDescriptor.getMessageTypes.asScala) { + if (descriptor.getName().equals(messageName)) { + result = descriptor + } + } + + if (null == result) { + throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor"); + } + result + } + + def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = { + var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null + try { + val dscFile = new BufferedInputStream(new FileInputStream(descFilePath)) + fileDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(dscFile) + } catch { + case ex: InvalidProtocolBufferException => + // TODO move all the exceptions to core/src/main/resources/error/error-classes.json + throw new RuntimeException("Error parsing descriptor byte[] into Descriptor object", ex) + case ex: IOException => + throw new RuntimeException( + "Error reading Protobuf descriptor file at path: " + + descFilePath, + ex) + } + + val descriptorProto: DescriptorProtos.FileDescriptorProto = fileDescriptorSet.getFile(0) + try { + val fileDescriptor: Descriptors.FileDescriptor = Descriptors.FileDescriptor.buildFrom( + descriptorProto, + new Array[Descriptors.FileDescriptor](0)) + if (fileDescriptor.getMessageTypes().isEmpty()) { + throw new RuntimeException("No MessageTypes returned, " + fileDescriptor.getName()); + } + fileDescriptor + } catch { + case e: Descriptors.DescriptorValidationException => + throw new RuntimeException("Error constructing FileDescriptor", e) + } + } + + /** + * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable + * string representing the field, like "field 'foo.bar'". If `names` is empty, the string + * "top-level record" is returned. 
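+   *
+   * For example, per the behaviour described above:
+   * {{{
+   *   toFieldStr(Seq("foo", "bar"))  // "field 'foo.bar'"
+   *   toFieldStr(Seq())              // "top-level record"
+   * }}}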
+ */ + private[protobuf] def toFieldStr(names: Seq[String]): String = names match { + case Seq() => "top-level record" + case n => s"field '${n.mkString(".")}'" + } + + /** Return true if `fieldDescriptor` is optional. */ + private[protobuf] def isNullable(fieldDescriptor: FieldDescriptor): Boolean = + !fieldDescriptor.isOptional + +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala new file mode 100644 index 0000000000000..e385b816abe70 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf.utils + +import scala.collection.JavaConverters._ + +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.protobuf.ScalaReflectionLock +import org.apache.spark.sql.types._ + +@DeveloperApi +object SchemaConverters { + + /** + * Internal wrapper for SQL data type and nullability. + * + * @since 3.4.0 + */ + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * Converts a Protobuf schema to a corresponding Spark SQL schema.
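+   *
+   * A minimal usage sketch (the descriptor file path and message name below are hypothetical):
+   * {{{
+   *   val descriptor = ProtobufUtils.buildDescriptor("/path/to/messages.desc", "Person")
+   *   val SchemaType(dataType, nullable) = SchemaConverters.toSqlType(descriptor)
+   *   // dataType is a StructType with one field per Protobuf field; nullable is true
+   * }}}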
+ * + * @since 3.4.0 + */ + def toSqlType(descriptor: Descriptor): SchemaType = { + toSqlTypeHelper(descriptor) + } + + def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized { + SchemaType( + StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toSeq), + nullable = true) + } + + def structFieldFor( + fd: FieldDescriptor, + existingRecordNames: Set[String]): Option[StructField] = { + import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + val dataType = fd.getJavaType match { + case INT => Some(IntegerType) + case LONG => Some(LongType) + case FLOAT => Some(FloatType) + case DOUBLE => Some(DoubleType) + case BOOLEAN => Some(BooleanType) + case STRING => Some(StringType) + case BYTE_STRING => Some(BinaryType) + case ENUM => Some(StringType) + case MESSAGE if fd.getMessageType.getName == "Duration" => + Some(DayTimeIntervalType.defaultConcreteType) + case MESSAGE if fd.getMessageType.getName == "Timestamp" => + Some(TimestampType) + case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry => + var keyType: DataType = NullType + var valueType: DataType = NullType + fd.getMessageType.getFields.forEach { field => + field.getName match { + case "key" => + keyType = structFieldFor(field, existingRecordNames).get.dataType + case "value" => + valueType = structFieldFor(field, existingRecordNames).get.dataType + } + } + return Option( + StructField( + fd.getName, + MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, + nullable = false)) + case MESSAGE => + if (existingRecordNames.contains(fd.getFullName)) { + throw new IncompatibleSchemaException(s""" + |Found recursive reference in Protobuf schema, which can not be processed by Spark: + |${fd.toString()}""".stripMargin) + } + val newRecordNames = existingRecordNames + fd.getFullName + + Option( + fd.getMessageType.getFields.asScala + .flatMap(structFieldFor(_, newRecordNames.toSet)) + .toSeq) + .filter(_.nonEmpty) + .map(StructType.apply) + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Protobuf type" + + s" ${fd.getJavaType}") + } + dataType.map(dt => + StructField( + fd.getName, + if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt, + nullable = !fd.isRequired && !fd.isRepeated)) + } + + private[protobuf] class IncompatibleSchemaException(msg: String, ex: Throwable = null) + extends Exception(msg, ex) +} diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc new file mode 100644 index 0000000000000..59255b488a03d --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc @@ -0,0 +1,48 @@ + +� +Cconnector/protobuf/src/test/resources/protobuf/catalyst_types.protoorg.apache.spark.sql.protobuf") + +BooleanMsg + bool_type (RboolType"+ + +IntegerMsg + +int32_type (R int32Type", + DoubleMsg + double_type (R +doubleType") +FloatMsg + +float_type (R floatType") +BytesMsg + +bytes_type ( R bytesType", + StringMsg + string_type ( R +stringType". 
+Person +name ( Rname +age (Rage"n +Bad +col_0 ( Rcol0 +col_1 (Rcol1 +col_2 ( Rcol2 +col_3 (Rcol3 +col_4 (Rcol4"q +Actual +col_0 ( Rcol0 +col_1 (Rcol1 +col_2 (Rcol2 +col_3 (Rcol3 +col_4 (Rcol4" + oldConsumer +key ( Rkey"5 + newProducer +key ( Rkey +value (Rvalue"t + newConsumer +key ( Rkey +value (Rvalue= +actual ( 2%.org.apache.spark.sql.protobuf.ActualRactual" + oldProducer +key ( RkeyBB CatalystTypesbproto3 \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto new file mode 100644 index 0000000000000..54e6bc18df153 --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto +// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto + +syntax = "proto3"; + +package org.apache.spark.sql.protobuf; +option java_outer_classname = "CatalystTypes"; + +message BooleanMsg { + bool bool_type = 1; +} +message IntegerMsg { + int32 int32_type = 1; +} +message DoubleMsg { + double double_type = 1; +} +message FloatMsg { + float float_type = 1; +} +message BytesMsg { + bytes bytes_type = 1; +} +message StringMsg { + string string_type = 1; +} + +message Person { + string name = 1; + int32 age = 2; +} + +message Bad { + bytes col_0 = 1; + double col_1 = 2; + string col_2 = 3; + float col_3 = 4; + int64 col_4 = 5; +} + +message Actual { + string col_0 = 1; + int32 col_1 = 2; + float col_2 = 3; + bool col_3 = 4; + double col_4 = 5; +} + +message oldConsumer { + string key = 1; +} + +message newProducer { + string key = 1; + int32 value = 2; +} + +message newConsumer { + string key = 1; + int32 value = 2; + Actual actual = 3; +} + +message oldProducer { + string key = 1; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc new file mode 100644 index 0000000000000..6e3a396727729 Binary files /dev/null and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto new file mode 100644 index 0000000000000..f38c041b799ec --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -0,0 +1,190 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// To compile and create test class: +// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto +// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/functions_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto + +syntax = "proto3"; + +package org.apache.spark.sql.protobuf; + +option java_outer_classname = "SimpleMessageProtos"; + +message SimpleMessageJavaTypes { + int64 id = 1; + string string_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + double double_value = 5; + float float_value = 6; + bool bool_value = 7; + bytes bytes_value = 8; +} + +message SimpleMessage { + int64 id = 1; + string string_value = 2; + int32 int32_value = 3; + uint32 uint32_value = 4; + sint32 sint32_value = 5; + fixed32 fixed32_value = 6; + sfixed32 sfixed32_value = 7; + int64 int64_value = 8; + uint64 uint64_value = 9; + sint64 sint64_value = 10; + fixed64 fixed64_value = 11; + sfixed64 sfixed64_value = 12; + double double_value = 13; + float float_value = 14; + bool bool_value = 15; + bytes bytes_value = 16; +} + +message SimpleMessageRepeated { + string key = 1; + string value = 2; + enum NestedEnum { + ESTED_NOTHING = 0; + NESTED_FIRST = 1; + NESTED_SECOND = 2; + } + repeated string rstring_value = 3; + repeated int32 rint32_value = 4; + repeated bool rbool_value = 5; + repeated int64 rint64_value = 6; + repeated float rfloat_value = 7; + repeated double rdouble_value = 8; + repeated bytes rbytes_value = 9; + repeated NestedEnum rnested_enum = 10; +} + +message BasicMessage { + int64 id = 1; + string string_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + double double_value = 5; + float float_value = 6; + bool bool_value = 7; + bytes bytes_value = 8; +} + +message RepeatedMessage { + repeated BasicMessage basic_message = 1; +} + +message SimpleMessageMap { + string key = 1; + string value = 2; + map string_mapdata = 3; + map int32_mapdata = 4; + map uint32_mapdata = 5; + map sint32_mapdata = 6; + map float32_mapdata = 7; + map sfixed32_mapdata = 8; + map int64_mapdata = 9; + map uint64_mapdata = 10; + map sint64_mapdata = 11; + map fixed64_mapdata = 12; + map sfixed64_mapdata = 13; + map double_mapdata = 14; + map float_mapdata = 15; + map bool_mapdata = 16; + map bytes_mapdata = 17; +} + +message BasicEnumMessage { + enum BasicEnum { + NOTHING = 0; + FIRST = 1; + SECOND = 2; + } +} + +message SimpleMessageEnum { + string key = 1; + string value = 2; + enum NestedEnum { + ESTED_NOTHING = 0; + NESTED_FIRST = 1; + NESTED_SECOND = 2; + } + 
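+  // Note: enum fields such as the two below are surfaced as Spark SQL strings by
+  // from_protobuf (SchemaConverters maps the protobuf ENUM type to StringType).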
BasicEnumMessage.BasicEnum basic_enum = 3; + NestedEnum nested_enum = 4; +} + + +message OtherExample { + string other = 1; +} + +message IncludedExample { + string included = 1; + OtherExample other = 2; +} + +message MultipleExample { + IncludedExample included_example = 1; +} + +message recursiveA { + string keyA = 1; + recursiveB messageB = 2; +} + +message recursiveB { + string keyB = 1; + recursiveA messageA = 2; +} + +message recursiveC { + string keyC = 1; + recursiveD messageD = 2; +} + +message recursiveD { + string keyD = 1; + repeated recursiveC messageC = 2; +} + +message requiredMsg { + string key = 1; + int32 col_1 = 2; + string col_2 = 3; + int32 col_3 = 4; +} + +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto +message Timestamp { + int64 seconds = 1; + int32 nanos = 2; +} + +message timeStampMsg { + string key = 1; + Timestamp stmp = 2; +} +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/duration.proto +message Duration { + int64 seconds = 1; + int32 nanos = 2; +} + +message durationMsg { + string key = 1; + Duration duration = 2; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc new file mode 100644 index 0000000000000..3d1847eecc5c3 --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc @@ -0,0 +1,27 @@ + +� +Fconnector/protobuf/src/test/resources/protobuf/proto_serde_suite.protoorg.apache.spark.sql.protobuf"D + BasicMessage4 +foo ( 2".org.apache.spark.sql.protobuf.FooRfoo" +Foo +bar (Rbar"' +MissMatchTypeInRoot +foo (Rfoo"T +FieldMissingInProto= +foo ( 2+.org.apache.spark.sql.protobuf.MissingFieldRfoo"& + MissingField +barFoo (RbarFoo"\ +MissMatchTypeInDeepNested? +top ( 2-.org.apache.spark.sql.protobuf.TypeMissNestedRtop"K +TypeMissNested9 +foo ( 2'.org.apache.spark.sql.protobuf.TypeMissRfoo" +TypeMiss +bar (Rbar"_ +FieldMissingInSQLRoot4 +foo ( 2".org.apache.spark.sql.protobuf.FooRfoo +boo (Rboo"O +FieldMissingInSQLNested4 +foo ( 2".org.apache.spark.sql.protobuf.BazRfoo") +Baz +bar (Rbar +baz (RbazBBSimpleMessageProtosbproto3 \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto new file mode 100644 index 0000000000000..1e3065259aa02 --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// To compile and create test class: +// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto +// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/serde_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto + +syntax = "proto3"; + +package org.apache.spark.sql.protobuf; +option java_outer_classname = "SimpleMessageProtos"; + +/* Clean Message*/ +message BasicMessage { + Foo foo = 1; +} + +message Foo { + int32 bar = 1; +} + +/* Field Type missMatch in root Message*/ +message MissMatchTypeInRoot { + int64 foo = 1; +} + +/* Field bar missing from protobuf and Available in SQL*/ +message FieldMissingInProto { + MissingField foo = 1; +} + +message MissingField { + int64 barFoo = 1; +} + +/* Deep-nested field bar type missMatch Message*/ +message MissMatchTypeInDeepNested { + TypeMissNested top = 1; +} + +message TypeMissNested { + TypeMiss foo = 1; +} + +message TypeMiss { + int64 bar = 1; +} + +/* Field boo missing from SQL root, but available in Protobuf root*/ +message FieldMissingInSQLRoot { + Foo foo = 1; + int32 boo = 2; +} + +/* Field baz missing from SQL nested and available in Protobuf nested*/ +message FieldMissingInSQLNested { + Baz foo = 1; +} + +message Baz { + int32 bar = 1; + int32 baz = 2; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..b730ebb4fea80 --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.protobuf + +import com.google.protobuf.{ByteString, DynamicMessage, Message} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters} +import org.apache.spark.sql.sources.{EqualTo, Not} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ProtobufCatalystDataConversionSuite + extends SparkFunSuite + with SharedSparkSession + with ExpressionEvalHelper { + + private def checkResult( + data: Literal, + descFilePath: String, + messageName: String, + expected: Any): Unit = { + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, descFilePath, messageName), + descFilePath, + messageName, + Map.empty), + prepareExpectedResult(expected)) + } + + protected def checkUnsupportedRead( + data: Literal, + descFilePath: String, + actualSchema: String, + badSchema: String): Unit = { + + val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema) + + intercept[Exception] { + ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "FAILFAST")).eval() + } + + val expected = { + val expectedSchema = ProtobufUtils.buildDescriptor(descFilePath, badSchema) + SchemaConverters.toSqlType(expectedSchema).dataType match { + case st: StructType => + Row.fromSeq((0 until st.length).map { _ => + null + }) + case _ => null + } + } + + checkEvaluation( + ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "PERMISSIVE")), + expected) + } + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark byte and short both map to Protobuf int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys = new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new ArrayBasedMapData(keys, values) + case other => other + } + + private val testingTypes = Seq( + StructType(StructField("int32_type", IntegerType, nullable = true) :: Nil), + StructType(StructField("double_type", DoubleType, nullable = true) :: Nil), + StructType(StructField("float_type", FloatType, nullable = true) :: Nil), + StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil), + StructType(StructField("string_type", StringType, nullable = true) :: Nil)) + + private val catalystTypesToProtoMessages: Map[DataType, String] = Map( + IntegerType -> "IntegerMsg", + DoubleType -> "DoubleMsg", + FloatType -> "FloatMsg", + BinaryType -> "BytesMsg", + StringType -> "StringMsg") + + testingTypes.foreach { dt => + val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1) + val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data 
= RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + + checkResult( + input, + filePath, + catalystTypesToProtoMessages(dt.fields(0).dataType), + input.eval()) + } + } + + private def checkDeserialization( + descFilePath: String, + messageName: String, + data: Message, + expected: Option[Any], + filters: StructFilters = new NoopFilters): Unit = { + + val descriptor = ProtobufUtils.buildDescriptor(descFilePath, messageName) + val dataType = SchemaConverters.toSqlType(descriptor).dataType + + val deserializer = new ProtobufDeserializer(descriptor, dataType, filters) + + val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray) + val deserialized = deserializer.deserialize(dynMsg) + expected match { + case None => assert(deserialized.isEmpty) + case Some(d) => + assert(checkResult(d, deserialized.get, dataType, exprNullable = false)) + } + } + + test("Handle unsupported input of message type") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val actualSchema = StructType( + Seq( + StructField("col_0", StringType, nullable = false), + StructField("col_1", IntegerType, nullable = false), + StructField("col_2", FloatType, nullable = false), + StructField("col_3", BooleanType, nullable = false), + StructField("col_4", DoubleType, nullable = false))) + + val seed = scala.util.Random.nextLong() + withClue(s"create random record with seed $seed") { + val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema) + val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema) + val input = Literal.create(converter(data), actualSchema) + checkUnsupportedRead(input, testFileDesc, "Actual", "Bad") + } + } + + test("filter push-down to Protobuf deserializer") { + + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val sqlSchema = new StructType() + .add("name", "string") + .add("age", "int") + + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person") + val dynamicMessage = DynamicMessage + .newBuilder(descriptor) + .setField(descriptor.findFieldByName("name"), "Maxim") + .setField(descriptor.findFieldByName("age"), 39) + .build() + + val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39)) + checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow) + checkDeserialization( + testFileDesc, + "Person", + dynamicMessage, + expectedRow, + new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema)) + + checkDeserialization( + testFileDesc, + "Person", + dynamicMessage, + None, + new OrderedFilters(Seq(Not(EqualTo("name", "Maxim"))), sqlSchema)) + } + + test("ProtobufDeserializer with binary type") { + + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53)) + + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") + + val dynamicMessage = DynamicMessage + .newBuilder(descriptor) + .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb)) + .build() + + val expected = InternalRow(Array[Byte](97, 48, 53)) + checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected)) + } +} diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala new file 
mode 100644 index 0000000000000..4e9bc1c1c287a --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -0,0 +1,615 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import java.sql.Timestamp +import java.time.Duration + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{ByteString, DynamicMessage} + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions.{lit, struct} +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType} + +class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Serializable { + + import testImplicits._ + + val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/") + + test("roundtrip in to_protobuf and from_protobuf - struct") { + val df = spark + .range(1, 10) + .select(struct( + $"id", + $"id".cast("string").as("string_value"), + $"id".cast("int").as("int32_value"), + $"id".cast("int").as("uint32_value"), + $"id".cast("int").as("sint32_value"), + $"id".cast("int").as("fixed32_value"), + $"id".cast("int").as("sfixed32_value"), + $"id".cast("long").as("int64_value"), + $"id".cast("long").as("uint64_value"), + $"id".cast("long").as("sint64_value"), + $"id".cast("long").as("fixed64_value"), + $"id".cast("long").as("sfixed64_value"), + $"id".cast("double").as("double_value"), + lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), + lit(true).as("bool_value"), + lit("0".getBytes).as("bytes_value")).as("SimpleMessage")) + val protoStructDF = df.select( + functions.to_protobuf($"SimpleMessage", testFileDesc, "SimpleMessage").as("proto")) + val actualDf = protoStructDF.select( + functions.from_protobuf($"proto", testFileDesc, "SimpleMessage").as("proto.*")) + checkAnswer(actualDf, df) + } + + test("roundtrip in from_protobuf and to_protobuf - Repeated") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated") + + val dynamicMessage = DynamicMessage + .newBuilder(descriptor) + .setField(descriptor.findFieldByName("key"), "key") + .setField(descriptor.findFieldByName("value"), "value") + .addRepeatedField(descriptor.findFieldByName("rbool_value"), false) + .addRepeatedField(descriptor.findFieldByName("rbool_value"), true) + .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654d) + .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654d) + 
.addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f) + .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f) + .addRepeatedField( + descriptor.findFieldByName("rnested_enum"), + descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + .addRepeatedField( + descriptor.findFieldByName("rnested_enum"), + descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "SimpleMessageRepeated").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageRepeated").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions + .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated") + .as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") { + val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") { + val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage1 = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value1") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + 
ByteString.copyFromUtf8("ProtobufDeserializer1")) + .build() + val basicMessage2 = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1112L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value2") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12346) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), false) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer2")) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Map") { + val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageMap") + + val mapStr1 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "string_key") + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value1") + .build() + val mapStr2 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "string_key") + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value2") + .build() + val mapInt64 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("Int64MapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"), + 0x90000000000L) + .setField( + messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"), + 0x90000000001L) + .build() + val mapInt32 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("Int32MapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"), + 12345) + .setField( + messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"), + 54321) + .build() + val mapFloat = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("FloatMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"), + "float_key") + .setField( + messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"), + 109202.234f) + .build() + val mapDouble = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry")) + .setField( + 
messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key"), + "double_key") + .setField( + messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value"), + 109202.12d) + .build() + val mapBool = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"), + true) + .setField( + messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"), + false) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(messageMapDesc) + .setField(messageMapDesc.findFieldByName("key"), "key") + .setField(messageMapDesc.findFieldByName("value"), "value") + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1) + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2) + .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64) + .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32) + .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat) + .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble) + .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "SimpleMessageMap").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageMap").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageMap").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Enum") { + val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum") + val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicEnumMessage") + + val dynamicMessage = DynamicMessage + .newBuilder(messageEnumDesc) + .setField(messageEnumDesc.findFieldByName("key"), "key") + .setField(messageEnumDesc.findFieldByName("value"), "value") + .setField( + messageEnumDesc.findFieldByName("nested_enum"), + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + .setField( + messageEnumDesc.findFieldByName("nested_enum"), + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + .setField( + messageEnumDesc.findFieldByName("basic_enum"), + basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("FIRST")) + .setField( + messageEnumDesc.findFieldByName("basic_enum"), + basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("NOTHING")) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "SimpleMessageEnum").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageEnum").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageEnum").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Multiple Message") { + val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, "MultipleExample") + val 
messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, "IncludedExample") + val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, "OtherExample") + + val otherMessage = DynamicMessage + .newBuilder(messageOtherDesc) + .setField(messageOtherDesc.findFieldByName("other"), "other value") + .build() + + val includeMessage = DynamicMessage + .newBuilder(messageIncludeDesc) + .setField(messageIncludeDesc.findFieldByName("included"), "included value") + .setField(messageIncludeDesc.findFieldByName("other"), otherMessage) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(messageMultiDesc) + .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "MultipleExample").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "MultipleExample").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "MultipleExample").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("Handle recursive fields in Protobuf schema, A->B->A") { + val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA") + val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB") + + val messageBForA = DynamicMessage + .newBuilder(schemaB) + .setField(schemaB.findFieldByName("keyB"), "key") + .build() + + val messageA = DynamicMessage + .newBuilder(schemaA) + .setField(schemaA.findFieldByName("keyA"), "key") + .setField(schemaA.findFieldByName("messageB"), messageBForA) + .build() + + val messageB = DynamicMessage + .newBuilder(schemaB) + .setField(schemaB.findFieldByName("keyB"), "key") + .setField(schemaB.findFieldByName("messageA"), messageA) + .build() + + val df = Seq(messageB.toByteArray).toDF("messageB") + + val e = intercept[IncompatibleSchemaException] { + df.select( + functions.from_protobuf($"messageB", testFileDesc, "recursiveB").as("messageFromProto")) + .show() + } + val expectedMessage = s""" + |Found recursive reference in Protobuf schema, which can not be processed by Spark: + |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin + assert(e.getMessage == expectedMessage) + } + + test("Handle recursive fields in Protobuf schema, C->D->Array(C)") { + val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC") + val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD") + + val messageDForC = DynamicMessage + .newBuilder(schemaD) + .setField(schemaD.findFieldByName("keyD"), "key") + .build() + + val messageC = DynamicMessage + .newBuilder(schemaC) + .setField(schemaC.findFieldByName("keyC"), "key") + .setField(schemaC.findFieldByName("messageD"), messageDForC) + .build() + + val messageD = DynamicMessage + .newBuilder(schemaD) + .setField(schemaD.findFieldByName("keyD"), "key") + .addRepeatedField(schemaD.findFieldByName("messageC"), messageC) + .build() + + val df = Seq(messageD.toByteArray).toDF("messageD") + + val e = intercept[IncompatibleSchemaException] { + df.select( + functions.from_protobuf($"messageD", testFileDesc, "recursiveD").as("messageFromProto")) + .show() + } + val expectedMessage = + s""" + |Found recursive reference in Protobuf schema, which can not be processed by Spark: + |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin + 
assert(e.getMessage == expectedMessage) + } + + test("Handle extra fields : oldProducer -> newConsumer") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer") + val newConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "newConsumer") + + val oldProducerMessage = DynamicMessage + .newBuilder(oldProducer) + .setField(oldProducer.findFieldByName("key"), "key") + .build() + + val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData") + val fromProtoDf = df.select( + functions + .from_protobuf($"oldProducerData", testFileDesc, "newConsumer") + .as("fromProto")) + + val toProtoDf = fromProtoDf.select( + functions + .to_protobuf($"fromProto", testFileDesc, "newConsumer") + .as("toProto")) + + val toProtoDfToFromProtoDf = toProtoDf.select( + functions + .from_protobuf($"toProto", testFileDesc, "newConsumer") + .as("toProtoToFromProto")) + + val actualFieldNames = + toProtoDfToFromProtoDf.select("toProtoToFromProto.*").schema.fields.toSeq.map(f => f.name) + newConsumer.getFields.asScala.map { f => + { + assert(actualFieldNames.contains(f.getName)) + + } + } + assert( + toProtoDfToFromProtoDf.select("toProtoToFromProto.value").take(1).toSeq(0).get(0) == null) + assert( + toProtoDfToFromProtoDf.select("toProtoToFromProto.actual.*").take(1).toSeq(0).get(0) == null) + } + + test("Handle extra fields : newProducer -> oldConsumer") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val newProducer = ProtobufUtils.buildDescriptor(testFileDesc, "newProducer") + val oldConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "oldConsumer") + + val newProducerMessage = DynamicMessage + .newBuilder(newProducer) + .setField(newProducer.findFieldByName("key"), "key") + .setField(newProducer.findFieldByName("value"), 1) + .build() + + val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData") + val fromProtoDf = df.select( + functions + .from_protobuf($"newProducerData", testFileDesc, "oldConsumer") + .as("oldConsumerProto")) + + val expectedFieldNames = oldConsumer.getFields.asScala.map(f => f.getName) + fromProtoDf.select("oldConsumerProto.*").schema.fields.toSeq.map { f => + { + assert(expectedFieldNames.contains(f.name)) + } + } + } + + test("roundtrip in to_protobuf and from_protobuf - with nulls") { + val schema = StructType( + StructField("requiredMsg", + StructType( + StructField("key", StringType, nullable = false) :: + StructField("col_1", IntegerType, nullable = true) :: + StructField("col_2", StringType, nullable = false) :: + StructField("col_3", IntegerType, nullable = true) :: Nil + ), + nullable = true + ) :: Nil + ) + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(Row("key1", null, "value2", null)) + )), + schema + ) + val toProtobuf = inputDf.select( + functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg") + .as("to_proto")) + + val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]] + + val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc, "requiredMsg") + val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary) + + assert(actualMessage.getField(messageDescriptor.findFieldByName("key")) + == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2")) + == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) + 
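+    // col_1 and col_3 were null in the input; to_protobuf writes them as the proto3
+    // default value (0 for int32), which the two assertions below verify.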
assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0) + + val fromProtoDf = toProtobuf.select( + functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 'from_proto) + + assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0) + == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0) + == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null) + assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null) + } + + test("from_protobuf filter to_protobuf") { + val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "slam") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + val df = Seq(basicMessage.toByteArray).toDF("value") + val resultFrom = df + .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) + .where("sample.string_value == \"slam\"") + + val resultToFrom = resultFrom + .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") as 'value) + .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) + .where("sample.string_value == \"slam\"") + + assert(resultFrom.except(resultToFrom).isEmpty) + } + + test("Handle TimestampType between to_protobuf and from_protobuf") { + val schema = StructType( + StructField("timeStampMsg", + StructType( + StructField("key", StringType, nullable = true) :: + StructField("stmp", TimestampType, nullable = true) :: Nil + ), + nullable = true + ) :: Nil + ) + + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(Row("key1", Timestamp.valueOf("2016-05-09 10:12:43.999"))) + )), + schema + ) + + val toProtoDf = inputDf + .select(functions.to_protobuf($"timeStampMsg", testFileDesc, "timeStampMsg") as 'to_proto) + + val fromProtoDf = toProtoDf + .select(functions.from_protobuf($"to_proto", testFileDesc, "timeStampMsg") as 'timeStampMsg) + fromProtoDf.show(truncate = false) + + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList + + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)) + } + + test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") { + val schema = StructType( + StructField("durationMsg", + StructType( + StructField("key", StringType, nullable = true) :: + StructField("duration", + 
DayTimeIntervalType.defaultConcreteType, nullable = true) :: Nil + ), + nullable = true + ) :: Nil + ) + + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(Row("key1", + Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4) + )) + )), + schema + ) + + val toProtoDf = inputDf + .select(functions.to_protobuf($"durationMsg", testFileDesc, "durationMsg") as 'to_proto) + + val fromProtoDf = toProtoDf + .select(functions.from_protobuf($"to_proto", testFileDesc, "durationMsg") as 'durationMsg) + + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList + + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) + + } +} diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala new file mode 100644 index 0000000000000..37c59743e7714 --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.protobuf + +import com.google.protobuf.Descriptors.Descriptor +import com.google.protobuf.DynamicMessage + +import org.apache.spark.sql.catalyst.NoopFilters +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructType} + +/** + * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on + * those classes. 
+ */ +class ProtobufSerdeSuite extends SharedSparkSession { + + import ProtoSerdeSuite._ + import ProtoSerdeSuite.MatchType._ + + val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/") + + test("Test basic conversion") { + withFieldMatchType { fieldMatch => + val (top, nest) = fieldMatch match { + case BY_NAME => ("foo", "bar") + } + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val dynamicMessageFoo = DynamicMessage + .newBuilder(protoFile.getFile.findMessageTypeByName("Foo")) + .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(protoFile) + .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo) + .build() + + val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + + assert( + serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage) + } + } + + test("Fail to convert with field type mismatch") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot") + + withFieldMatchType { fieldMatch => + assertFailedConversionMessage( + protoFile, + Deserializer, + fieldMatch, + "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is incompatible " + + s"(protoType = org.apache.spark.sql.protobuf.MissMatchTypeInRoot.foo " + + s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})".stripMargin) + + assertFailedConversionMessage( + protoFile, + Serializer, + fieldMatch, + s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is incompatible " + + s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""") + } + } + + test("Fail to convert with missing nested Protobuf fields for serializer") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInProto") + + val nonnullCatalyst = new StructType() + .add("foo", new StructType().add("bar", IntegerType, nullable = false)) + + // serialize fails whether or not 'bar' is nullable + val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema" + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg) + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) + } + + test("Fail to convert with deeply nested field type mismatch") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested") + val catalyst = new StructType().add("top", CATALYST_STRUCT) + + withFieldMatchType { fieldMatch => + assertFailedConversionMessage( + protoFile, + Deserializer, + fieldMatch, + s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " + + s"is incompatible (protoType = org.apache.spark.sql.protobuf.TypeMiss.bar " + + s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin, + catalyst) + + assertFailedConversionMessage( + protoFile, + Serializer, + fieldMatch, + "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because schema " + + """is incompatible (sqlType = INT, protoType = LONG)""", + catalyst) + } + } + + test("Fail to convert with missing Catalyst fields") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot") + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage( + protoFile, + Serializer, + BY_NAME, + "Found field 'boo' in 
Protobuf schema but there is no match in the SQL schema") + + /* deserializing should work regardless of whether the extra field is missing + in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + + val protoNestedFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested") + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage( + protoNestedFile, + Serializer, + BY_NAME, + "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema") + + /* deserializing should work regardless of whether the extra field is missing + in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + } + + /** + * Attempt to convert `catalystSchema` to `protoSchema` (or vice-versa if `deserialize` is + * true), assert that it fails, and assert that the _cause_ of the thrown exception has a + * message matching `expectedCauseMessage`. + */ + private def assertFailedConversionMessage( + protoSchema: Descriptor, + serdeFactory: SerdeFactory[_], + fieldMatchType: MatchType, + expectedCauseMessage: String, + catalystSchema: StructType = CATALYST_STRUCT): Unit = { + val e = intercept[IncompatibleSchemaException] { + serdeFactory.create(catalystSchema, protoSchema, fieldMatchType) + } + val expectMsg = serdeFactory match { + case Deserializer => + s"Cannot convert Protobuf type ${protoSchema.getName} to SQL type ${catalystSchema.sql}." + case Serializer => + s"Cannot convert SQL type ${catalystSchema.sql} to Protobuf type ${protoSchema.getName}." + } + + assert(e.getMessage === expectMsg) + assert(e.getCause.getMessage === expectedCauseMessage) + } + + def withFieldMatchType(f: MatchType => Unit): Unit = { + MatchType.values.foreach { fieldMatchType => + withClue(s"fieldMatchType == $fieldMatchType") { + f(fieldMatchType) + } + } + } +} + +object ProtoSerdeSuite { + + val CATALYST_STRUCT = + new StructType().add("foo", new StructType().add("bar", IntegerType)) + + /** + * Specifier for type of field matching to be used for easy creation of tests that do by-name + * field matching. + */ + object MatchType extends Enumeration { + type MatchType = Value + val BY_NAME = Value + } + + import MatchType._ + + /** + * Specifier for type of serde to be used for easy creation of tests that do both serialization + * and deserialization. 
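+ *
+ * This lets a test exercise both directions through the same factory call, e.g. (sketch, where
+ * `descriptor` is any protobuf [[Descriptor]] built from `testFileDesc`):
+ * {{{
+ *   val serializer   = Serializer.create(CATALYST_STRUCT, descriptor, BY_NAME)
+ *   val deserializer = Deserializer.create(CATALYST_STRUCT, descriptor, BY_NAME)
+ * }}}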
+ */ + sealed trait SerdeFactory[T] { + def create(sqlSchema: StructType, descriptor: Descriptor, fieldMatchType: MatchType): T + } + + object Serializer extends SerdeFactory[ProtobufSerializer] { + override def create( + sql: StructType, + descriptor: Descriptor, + matchType: MatchType): ProtobufSerializer = new ProtobufSerializer(sql, descriptor, false) + } + + object Deserializer extends SerdeFactory[ProtobufDeserializer] { + override def create( + sql: StructType, + descriptor: Descriptor, + matchType: MatchType): ProtobufDeserializer = + new ProtobufDeserializer(descriptor, sql, new NoopFilters) + } +} diff --git a/core/src/main/java/org/apache/spark/SparkThrowable.java b/core/src/main/java/org/apache/spark/SparkThrowable.java index 7fb693d9c5569..e1235b2982ba0 100644 --- a/core/src/main/java/org/apache/spark/SparkThrowable.java +++ b/core/src/main/java/org/apache/spark/SparkThrowable.java @@ -51,7 +51,7 @@ default boolean isInternalError() { } default Map getMessageParameters() { - return new HashMap(); + return new HashMap<>(); } default QueryContext[] getQueryContext() { return new QueryContext[0]; } diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index a3a2dff0e2744..dd95c0f83d1f2 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -2833,11 +2833,6 @@ "Table or view '' not found." ] }, - "_LEGACY_ERROR_TEMP_1304" : { - "message" : [ - "Unexpected type of the relation ." - ] - }, "_LEGACY_ERROR_TEMP_1305" : { "message" : [ "Unsupported TableChange in JDBC catalog." @@ -3433,5 +3428,382 @@ "message" : [ "Write is not supported for binary file data source" ] + }, + "_LEGACY_ERROR_TEMP_2076" : { + "message" : [ + "The length of is , which exceeds the max length allowed: ." + ] + }, + "_LEGACY_ERROR_TEMP_2077" : { + "message" : [ + "Unsupported field name: " + ] + }, + "_LEGACY_ERROR_TEMP_2078" : { + "message" : [ + "Both '' and '' can not be specified at the same time." + ] + }, + "_LEGACY_ERROR_TEMP_2079" : { + "message" : [ + "Option '' or '' is required." + ] + }, + "_LEGACY_ERROR_TEMP_2080" : { + "message" : [ + "Option `` can not be empty." + ] + }, + "_LEGACY_ERROR_TEMP_2081" : { + "message" : [ + "Invalid value `` for parameter ``. This can be `NONE`, `READ_UNCOMMITTED`, `READ_COMMITTED`, `REPEATABLE_READ` or `SERIALIZABLE`." + ] + }, + "_LEGACY_ERROR_TEMP_2082" : { + "message" : [ + "Can't get JDBC type for " + ] + }, + "_LEGACY_ERROR_TEMP_2083" : { + "message" : [ + "Unsupported type " + ] + }, + "_LEGACY_ERROR_TEMP_2084" : { + "message" : [ + "Unsupported array element type based on binary" + ] + }, + "_LEGACY_ERROR_TEMP_2085" : { + "message" : [ + "Nested arrays unsupported" + ] + }, + "_LEGACY_ERROR_TEMP_2086" : { + "message" : [ + "Can't translate non-null value for field " + ] + }, + "_LEGACY_ERROR_TEMP_2087" : { + "message" : [ + "Invalid value `` for parameter `` in table writing via JDBC. The minimum value is 1." + ] + }, + "_LEGACY_ERROR_TEMP_2088" : { + "message" : [ + " is not supported yet." + ] + }, + "_LEGACY_ERROR_TEMP_2089" : { + "message" : [ + "DataType: " + ] + }, + "_LEGACY_ERROR_TEMP_2090" : { + "message" : [ + "The input filter of should be fully convertible." 
+ ] + }, + "_LEGACY_ERROR_TEMP_2091" : { + "message" : [ + "Could not read footer for file: " + ] + }, + "_LEGACY_ERROR_TEMP_2092" : { + "message" : [ + "Could not read footer for file: " + ] + }, + "_LEGACY_ERROR_TEMP_2093" : { + "message" : [ + "Found duplicate field(s) \"\": in case-insensitive mode" + ] + }, + "_LEGACY_ERROR_TEMP_2094" : { + "message" : [ + "Found duplicate field(s) \"\": in id mapping mode" + ] + }, + "_LEGACY_ERROR_TEMP_2095" : { + "message" : [ + "Failed to merge incompatible schemas and " + ] + }, + "_LEGACY_ERROR_TEMP_2096" : { + "message" : [ + " is not supported temporarily." + ] + }, + "_LEGACY_ERROR_TEMP_2097" : { + "message" : [ + "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1" + ] + }, + "_LEGACY_ERROR_TEMP_2098" : { + "message" : [ + "Could not compare cost with " + ] + }, + "_LEGACY_ERROR_TEMP_2099" : { + "message" : [ + "Unsupported data type:
" + ] + }, + "_LEGACY_ERROR_TEMP_2100" : { + "message" : [ + "not support type: " + ] + }, + "_LEGACY_ERROR_TEMP_2101" : { + "message" : [ + "Not support non-primitive type now" + ] + }, + "_LEGACY_ERROR_TEMP_2102" : { + "message" : [ + "Unsupported type: " + ] + }, + "_LEGACY_ERROR_TEMP_2103" : { + "message" : [ + "Dictionary encoding should not be used because of dictionary overflow." + ] + }, + "_LEGACY_ERROR_TEMP_2104" : { + "message" : [ + "End of the iterator" + ] + }, + "_LEGACY_ERROR_TEMP_2105" : { + "message" : [ + "Could not allocate memory to grow BytesToBytesMap" + ] + }, + "_LEGACY_ERROR_TEMP_2106" : { + "message" : [ + "Can't acquire bytes memory to build hash relation, got bytes" + ] + }, + "_LEGACY_ERROR_TEMP_2107" : { + "message" : [ + "There is not enough memory to build hash map" + ] + }, + "_LEGACY_ERROR_TEMP_2108" : { + "message" : [ + "Does not support row that is larger than 256M" + ] + }, + "_LEGACY_ERROR_TEMP_2109" : { + "message" : [ + "Cannot build HashedRelation with more than 1/3 billions unique keys" + ] + }, + "_LEGACY_ERROR_TEMP_2110" : { + "message" : [ + "Can not build a HashedRelation that is larger than 8G" + ] + }, + "_LEGACY_ERROR_TEMP_2111" : { + "message" : [ + "failed to push a row into " + ] + }, + "_LEGACY_ERROR_TEMP_2112" : { + "message" : [ + "Unexpected window function frame ." + ] + }, + "_LEGACY_ERROR_TEMP_2113" : { + "message" : [ + "Unable to parse as a percentile" + ] + }, + "_LEGACY_ERROR_TEMP_2114" : { + "message" : [ + " is not a recognised statistic" + ] + }, + "_LEGACY_ERROR_TEMP_2115" : { + "message" : [ + "Unknown column: " + ] + }, + "_LEGACY_ERROR_TEMP_2116" : { + "message" : [ + "Unexpected: " + ] + }, + "_LEGACY_ERROR_TEMP_2117" : { + "message" : [ + "Unscaled value too large for precision. If necessary set to false to bypass this error." + ] + }, + "_LEGACY_ERROR_TEMP_2118" : { + "message" : [ + "Decimal precision exceeds max precision " + ] + }, + "_LEGACY_ERROR_TEMP_2119" : { + "message" : [ + "out of decimal type range: " + ] + }, + "_LEGACY_ERROR_TEMP_2120" : { + "message" : [ + "Do not support array of type ." + ] + }, + "_LEGACY_ERROR_TEMP_2121" : { + "message" : [ + "Do not support type ." + ] + }, + "_LEGACY_ERROR_TEMP_2122" : { + "message" : [ + "Failed parsing : " + ] + }, + "_LEGACY_ERROR_TEMP_2123" : { + "message" : [ + "Failed to merge fields '' and ''. " + ] + }, + "_LEGACY_ERROR_TEMP_2124" : { + "message" : [ + "Failed to merge decimal types with incompatible scale and " + ] + }, + "_LEGACY_ERROR_TEMP_2125" : { + "message" : [ + "Failed to merge incompatible data types ${leftCatalogString} and ${rightCatalogString}" + ] + }, + "_LEGACY_ERROR_TEMP_2126" : { + "message" : [ + "Unsuccessful attempt to build maps with elements due to exceeding the map size limit ." + ] + }, + "_LEGACY_ERROR_TEMP_2127" : { + "message" : [ + "Duplicate map key was found, please check the input data. If you want to remove the duplicated keys, you can set to so that the key inserted at last takes precedence." + ] + }, + "_LEGACY_ERROR_TEMP_2128" : { + "message" : [ + "The key array and value array of MapData must have the same length." + ] + }, + "_LEGACY_ERROR_TEMP_2129" : { + "message" : [ + "Conflict found: Field differs from derived from " + ] + }, + "_LEGACY_ERROR_TEMP_2130" : { + "message" : [ + "Fail to recognize '' pattern in the DateTimeFormatter. 
You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html" + ] + }, + "_LEGACY_ERROR_TEMP_2131" : { + "message" : [ + "Exception when registering StreamingQueryListener" + ] + }, + "_LEGACY_ERROR_TEMP_2132" : { + "message" : [ + "Parsing JSON arrays as structs is forbidden." + ] + }, + "_LEGACY_ERROR_TEMP_2133" : { + "message" : [ + "Cannot parse field name , field value , [] as target spark data type []." + ] + }, + "_LEGACY_ERROR_TEMP_2134" : { + "message" : [ + "Cannot parse field value for pattern as target spark data type []." + ] + }, + "_LEGACY_ERROR_TEMP_2135" : { + "message" : [ + "Failed to parse an empty string for data type " + ] + }, + "_LEGACY_ERROR_TEMP_2136" : { + "message" : [ + "Failed to parse field name , field value , [] to target spark data type []." + ] + }, + "_LEGACY_ERROR_TEMP_2137" : { + "message" : [ + "Root converter returned null" + ] + }, + "_LEGACY_ERROR_TEMP_2138" : { + "message" : [ + "Cannot have circular references in bean class, but got the circular reference of class " + ] + }, + "_LEGACY_ERROR_TEMP_2139" : { + "message" : [ + "cannot have circular references in class, but got the circular reference of class " + ] + }, + "_LEGACY_ERROR_TEMP_2140" : { + "message" : [ + "`` is not a valid identifier of Java and cannot be used as field name", + "" + ] + }, + "_LEGACY_ERROR_TEMP_2141" : { + "message" : [ + "No Encoder found for ", + "" + ] + }, + "_LEGACY_ERROR_TEMP_2142" : { + "message" : [ + "Attributes for type is not supported" + ] + }, + "_LEGACY_ERROR_TEMP_2143" : { + "message" : [ + "Schema for type is not supported" + ] + }, + "_LEGACY_ERROR_TEMP_2144" : { + "message" : [ + "Unable to find constructor for . This could happen if is an interface, or a trait without companion object constructor." + ] + }, + "_LEGACY_ERROR_TEMP_2145" : { + "message" : [ + " cannot be more than one character" + ] + }, + "_LEGACY_ERROR_TEMP_2146" : { + "message" : [ + " should be an integer. Found " + ] + }, + "_LEGACY_ERROR_TEMP_2147" : { + "message" : [ + " flag can be true or false" + ] + }, + "_LEGACY_ERROR_TEMP_2148" : { + "message" : [ + "null value found but field is not nullable." + ] + }, + "_LEGACY_ERROR_TEMP_2149" : { + "message" : [ + "Malformed CSV record" + ] + }, + "_LEGACY_ERROR_TEMP_2150" : { + "message" : [ + "Due to Scala's limited support of tuple, tuple with more than 22 elements are not supported." 
+ ] } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 3a1f0862e94e1..56fb5bf6c6cfe 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1023,7 +1023,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { file.deleteOnExit() val cmd = s""" - |#!/bin/bash + |#!/usr/bin/env bash |trap "" SIGTERM |sleep 10 """.stripMargin diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index e21a39a688170..1621432c01c65 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -508,8 +508,11 @@ def main(): else: title = pr["title"] - modified_body = re.sub(re.compile(r"\n?", re.DOTALL), "", pr["body"]).lstrip() - if modified_body != pr["body"]: + body = pr["body"] + if body is None: + body = "" + modified_body = re.sub(re.compile(r"\n?", re.DOTALL), "", body).lstrip() + if modified_body != body: print("=" * 80) print(modified_body) print("=" * 80) @@ -519,13 +522,10 @@ def main(): body = modified_body print("Using modified body:") else: - body = pr["body"] print("Using original body:") print("=" * 80) print(body) print("=" * 80) - else: - body = pr["body"] target_ref = pr["base"]["ref"] user_login = pr["user"]["login"] base_ref = pr["head"]["ref"] diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index 9acfa5f2db4c1..d736ff8f83f3d 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -329,7 +329,7 @@ Data skew can severely downgrade the performance of join queries. This feature d spark.sql.adaptive.skewJoin.skewedPartitionFactor - 5 + 5.0 A partition is considered as skewed if its size is larger than this factor multiplying the median partition size and also larger than spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes. diff --git a/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java b/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java index 978466cd77ccd..013dde2766f49 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java +++ b/launcher/src/main/java/org/apache/spark/launcher/JavaModuleOptions.java @@ -40,7 +40,8 @@ public class JavaModuleOptions { "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", "--add-opens=java.base/sun.security.action=ALL-UNNAMED", "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", - "--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED"}; + "--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED", + "-Djdk.reflect.useDirectMethodHandle=false"}; /** * Returns the default Java options related to `--add-opens' and diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala index 8f78bcc15347f..8016258f054a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala @@ -20,8 +20,10 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.Utils /** * Machine learning specific Pair RDD functions. 
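 * For example, `topByKey` (reworked in the next hunk) keeps the `num` largest values per key.
 * A minimal usage sketch, assuming `sc` is an active SparkContext and the implicit conversion
 * from the companion object is in scope:
 * {{{
 *   import org.apache.spark.mllib.rdd.MLPairRDDFunctions._
 *
 *   val pairs = sc.parallelize(Seq(("a", 1), ("a", 5), ("a", 3), ("b", 2)))
 *   pairs.topByKey(2).collect()
 *   // e.g. Array(("a", Array(5, 3)), ("b", Array(2))) -- values per key in descending order
 * }}}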
@@ -37,14 +39,30 @@ class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Se * @return an RDD that contains the top k values for each key */ def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = { - self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))( - seqOp = (queue, item) => { - queue += item - }, - combOp = (queue1, queue2) => { - queue1 ++= queue2 - } - ).mapValues(_.toArray.sorted(ord.reverse)) // This is a min-heap, so we reverse the order. + val createCombiner = (v: V) => new BoundedPriorityQueue[V](num)(ord) += v + val mergeValue = (c: BoundedPriorityQueue[V], v: V) => c += v + val mergeCombiners = (c1: BoundedPriorityQueue[V], c2: BoundedPriorityQueue[V]) => c1 ++= c2 + + val aggregator = new Aggregator[K, V, BoundedPriorityQueue[V]]( + self.context.clean(createCombiner), + self.context.clean(mergeValue), + self.context.clean(mergeCombiners)) + + self.mapPartitions(iter => { + val context = TaskContext.get() + new InterruptibleIterator( + context, + aggregator + .combineValuesByKey(iter, context) + .map { case (k, v) => (k, v.toArray.sorted(ord.reverse)) } + ) + }, preservesPartitioning = true + ).reduceByKey { (array1, array2) => + val size = math.min(num, array1.length + array2.length) + val array = Array.ofDim[V](size) + Utils.mergeOrdered[V](Seq(array1, array2))(ord.reverse).copyToArray(array, 0, size) + array + } } } diff --git a/pom.xml b/pom.xml index ca354d16e6242..cab9929954be5 100644 --- a/pom.xml +++ b/pom.xml @@ -101,6 +101,7 @@ connector/kafka-0-10-sql connector/avro connector/connect + connector/protobuf @@ -312,6 +313,7 @@ --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED + -Djdk.reflect.useDirectMethodHandle=false @@ -3084,6 +3086,9 @@ spark-warehouse + + dist + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d3609a44b5e40..03970bb862cec 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "protobuf" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = @@ -59,7 +59,7 @@ object BuildCommons { ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "tags", "sketch", "kvstore" - ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) ++ Seq(protobuf) val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn, sparkGangliaLgpl, streamingKinesisAsl, @@ -390,7 +390,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn, - unsafe, tags, tokenProviderKafka010, sqlKafka010, connect + unsafe, tags, tokenProviderKafka010, sqlKafka010, connect, protobuf ).contains(x) } @@ -433,6 +433,9 @@ object SparkBuild extends 
PomBuild { enable(SparkConnect.settings)(connect) + /* Connector/proto settings */ + enable(SparkProtobuf.settings)(protobuf) + // SPARK-14738 - Remove docker tests from main Spark build // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) @@ -699,6 +702,48 @@ object SparkConnect { ) } +object SparkProtobuf { + + import BuildCommons.protoVersion + + private val shadePrefix = "org.sparkproject.spark-protobuf" + val shadeJar = taskKey[Unit]("Shade the Jars") + + lazy val settings = Seq( + // Setting version for the protobuf compiler. This has to be propagated to every sub-project + // even if the project is not using it. + PB.protocVersion := BuildCommons.protoVersion, + + // For some reason the resolution from the imported Maven build does not work for some + // of these dependendencies that we need to shade later on. + libraryDependencies ++= Seq( + "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf" + ), + + dependencyOverrides ++= Seq( + "com.google.protobuf" % "protobuf-java" % protoVersion + ), + + (Compile / PB.targets) := Seq( + PB.gens.java -> (Compile / sourceManaged).value, + ), + + (assembly / test) := false, + + (assembly / logLevel) := Level.Info, + + (assembly / assemblyShadeRules) := Seq( + ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.spark-protobuf.protobuf.@1").inAll, + ), + + (assembly / assemblyMergeStrategy) := { + case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard + // Drop all proto files that are not needed as artifacts of the build. + case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard + case _ => MergeStrategy.first + }, + ) +} object Unsafe { lazy val settings = Seq( // This option is needed to suppress warnings from sun.misc.Unsafe usage @@ -1143,10 +1188,10 @@ object Unidoc { (ScalaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connect), + yarn, tags, streamingKafka010, sqlKafka010, connect, protobuf), (JavaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connect), + yarn, tags, streamingKafka010, sqlKafka010, connect, protobuf), (ScalaUnidoc / unidoc / unidocAllClasspaths) := { ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value) @@ -1232,6 +1277,7 @@ object CopyDependencies { // produce the shaded Jar which happens automatically in the case of Maven. // Later, when the dependencies are copied, we manually copy the shaded Jar only. 
val fid = (LocalProject("connect") / assembly).value + val fidProtobuf = (LocalProject("protobuf")/assembly).value (Compile / dependencyClasspath).value.map(_.data) .filter { jar => jar.isFile() } @@ -1244,6 +1290,9 @@ object CopyDependencies { if (jar.getName.contains("spark-connect") && !SbtPomKeys.profiles.value.contains("noshade-connect")) { Files.copy(fid.toPath, destJar.toPath) + } else if (jar.getName.contains("spark-protobuf") && + !SbtPomKeys.profiles.value.contains("noshade-protobuf")) { + Files.copy(fid.toPath, destJar.toPath) } else { Files.copy(jar.toPath(), destJar.toPath()) } @@ -1313,7 +1362,8 @@ object TestSettings { "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", "--add-opens=java.base/sun.security.action=ALL-UNNAMED", - "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED").mkString(" ") + "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", + "-Djdk.reflect.useDirectMethodHandle=false").mkString(" ") s"-Xmx4g -Xss4m -XX:MaxMetaspaceSize=$metaspaceSize -XX:ReservedCodeCacheSize=128m -Dfile.encoding=UTF-8 $extraTestJavaArgs" .split(" ").toSeq }, diff --git a/project/plugins.sbt b/project/plugins.sbt index 24023cc7ad0a7..f6b4fb32a2022 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -22,8 +22,8 @@ addSbtPlugin("com.etsy" % "sbt-checkstyle-plugin" % "3.1.1") // please check pom.xml in the root of the source tree too. libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "9.3" -// checkstyle uses guava 23.0. -libraryDependencies += "com.google.guava" % "guava" % "23.0" +// checkstyle uses guava 31.0.1-jre. +libraryDependencies += "com.google.guava" % "guava" % "31.0.1-jre" addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0") diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 1aa94623ac31f..835c13d6fdd50 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1747,7 +1747,6 @@ def corrwith( sdf = combined._internal.spark_frame index_col_name = verify_temp_column_name(sdf, "__corrwith_index_temp_column__") - tuple_col_name = verify_temp_column_name(sdf, "__corrwith_tuple_temp_column__") this_numeric_column_labels: List[Label] = [] for column_label in this._internal.column_labels: @@ -1797,15 +1796,7 @@ def corrwith( ) if len(pair_scols) > 0: - sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tuple_col_name)).select( - F.col(f"{tuple_col_name}.{index_col_name}").alias(index_col_name), - F.col(f"{tuple_col_name}.{CORRELATION_VALUE_1_COLUMN}").alias( - CORRELATION_VALUE_1_COLUMN - ), - F.col(f"{tuple_col_name}.{CORRELATION_VALUE_2_COLUMN}").alias( - CORRELATION_VALUE_2_COLUMN - ), - ) + sdf = sdf.select(F.inline(F.array(*pair_scols))) sdf = compute(sdf=sdf, groupKeys=[index_col_name], method=method).select( index_col_name, CORRELATION_CORR_OUTPUT_COLUMN diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index b2525ce9a60ad..c5dbcb79710a5 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -468,21 +468,10 @@ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fr if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def first(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.first(col, ignorenulls=True)) - - else: - - def first(col: Column) -> Column: - return F.first(col, ignorenulls=True) - return 
self._reduce_for_stat_function( - first, + lambda col: F.first(col, ignorenulls=True), accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike: @@ -549,21 +538,10 @@ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fra if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def last(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.last(col, ignorenulls=True)) - - else: - - def last(col: Column) -> Column: - return F.last(col, ignorenulls=True) - return self._reduce_for_stat_function( - last, + lambda col: F.last(col, ignorenulls=True), accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike: @@ -624,20 +602,10 @@ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def max(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.max(col)) - - else: - - def max(col: Column) -> Column: - return F.max(col) - return self._reduce_for_stat_function( - max, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None + F.max, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def mean(self, numeric_only: Optional[bool] = True) -> FrameLike: @@ -802,20 +770,10 @@ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def min(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.min(col)) - - else: - - def min(col: Column) -> Column: - return F.min(col) - return self._reduce_for_stat_function( - min, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None + F.min, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) # TODO: sync the doc. @@ -944,20 +902,11 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL f"numeric_only=False, skip unsupported columns: {unsupported}" ) - if min_count > 0: - - def sum(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.sum(col)) - - else: - - def sum(col: Column) -> Column: - return F.sum(col) - return self._reduce_for_stat_function( - sum, accepted_spark_types=(NumericType,), bool_to_numeric=True + F.sum, + accepted_spark_types=(NumericType, BooleanType), + bool_to_numeric=True, + min_count=min_count, ) # TODO: sync the doc. 
@@ -1324,22 +1273,11 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame self._validate_agg_columns(numeric_only=numeric_only, function_name="prod") - if min_count > 0: - - def prod(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(SF.product(col, True)) - - else: - - def prod(col: Column) -> Column: - return SF.product(col, True) - return self._reduce_for_stat_function( - prod, + lambda col: SF.product(col, True), accepted_spark_types=(NumericType, BooleanType), bool_to_numeric=True, + min_count=min_count, ) def all(self, skipna: bool = True) -> FrameLike: @@ -3596,6 +3534,7 @@ def _reduce_for_stat_function( sfun: Callable[[Column], Column], accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None, bool_to_numeric: bool = False, + **kwargs: Any, ) -> FrameLike: """Apply an aggregate function `sfun` per column and reduce to a FrameLike. @@ -3615,14 +3554,19 @@ def _reduce_for_stat_function( psdf: DataFrame = DataFrame(internal) if len(psdf._internal.column_labels) > 0: + min_count = kwargs.get("min_count", 0) stat_exprs = [] for label in psdf._internal.column_labels: psser = psdf._psser_for(label) - stat_exprs.append( - sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias( - psser._internal.data_spark_column_names[0] + input_scol = psser._dtype_op.nan_to_null(psser).spark.column + output_scol = sfun(input_scol) + + if min_count > 0: + output_scol = F.when( + F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol ) - ) + + stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0])) sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs) else: sdf = sdf.select(*groupkey_names).distinct() diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index bc19b66a8d14c..9b2f689a72548 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -3254,7 +3254,7 @@ def add_suffix(self, suffix: str) -> "Series": DataFrame(internal.with_new_sdf(sdf, index_fields=([None] * internal.index_level))) ) - def autocorr(self, periods: int = 1) -> float: + def autocorr(self, lag: int = 1) -> float: """ Compute the lag-N autocorrelation. @@ -3270,7 +3270,7 @@ def autocorr(self, periods: int = 1) -> float: Parameters ---------- - periods : int, default 1 + lag : int, default 1 Number of lags to apply before performing autocorrelation. 
Returns @@ -3312,15 +3312,15 @@ def autocorr(self, periods: int = 1) -> float: """ # This implementation is suboptimal because it moves all data to a single partition, # global sort should be used instead of window, but it should be a start - if not isinstance(periods, int): - raise TypeError("periods should be an int; however, got [%s]" % type(periods).__name__) + if not isinstance(lag, int): + raise TypeError("lag should be an int; however, got [%s]" % type(lag).__name__) sdf = self._internal.spark_frame scol = self.spark.column - if periods == 0: + if lag == 0: corr = sdf.select(F.corr(scol, scol)).head()[0] else: - lag_scol = F.lag(scol, periods).over(Window.orderBy(NATURAL_ORDER_COLUMN_NAME)) + lag_scol = F.lag(scol, lag).over(Window.orderBy(NATURAL_ORDER_COLUMN_NAME)) lag_col_name = verify_temp_column_name(sdf, "__autocorr_lag_tmp_col__") corr = ( sdf.withColumn(lag_col_name, lag_scol) diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index af6ecd6152b6e..e47f716ecfc26 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -3240,7 +3240,7 @@ def test_autocorr(self): self._test_autocorr(pdf) psser = ps.from_pandas(pdf["s1"]) - with self.assertRaisesRegex(TypeError, r"periods should be an int; however, got"): + with self.assertRaisesRegex(TypeError, r"lag should be an int; however, got"): psser.autocorr(1.0) def _test_autocorr(self, pdf): diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 5691011795dcf..780cfdfba8e9b 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. 
See the NOTICE file distributed with diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/CharVarcharCodegenUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/CharVarcharCodegenUtils.java index 581f4bb6d259f..582b697c92a2b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/CharVarcharCodegenUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/CharVarcharCodegenUtils.java @@ -52,4 +52,15 @@ public static UTF8String varcharTypeWriteSideCheck(UTF8String inputStr, int limi return trimTrailingSpaces(inputStr, numChars, limit); } } + + public static UTF8String readSidePadding(UTF8String inputStr, int limit) { + int numChars = inputStr.numChars(); + if (numChars == limit) { + return inputStr; + } else if (numChars < limit) { + return inputStr.rpad(limit, SPACE); + } else { + return inputStr; + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java index 09b26d8f793f7..6c9e5ac577a7b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsAtomicPartitionManagement.java @@ -45,6 +45,7 @@ @Experimental public interface SupportsAtomicPartitionManagement extends SupportsPartitionManagement { + @SuppressWarnings("unchecked") @Override default void createPartition( InternalRow ident, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index b32958d13daf1..fe16174586bad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -47,7 +47,7 @@ public class V2ExpressionSQLBuilder { public String build(Expression expr) { if (expr instanceof Literal) { - return visitLiteral((Literal) expr); + return visitLiteral((Literal) expr); } else if (expr instanceof NamedReference) { return visitNamedReference((NamedReference) expr); } else if (expr instanceof Cast) { @@ -213,7 +213,7 @@ public String build(Expression expr) { } } - protected String visitLiteral(Literal literal) { + protected String visitLiteral(Literal literal) { return literal.toString(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 7acd27759a1ba..844734ff7ccb7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -21,6 +21,7 @@ import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering; import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** @@ -68,6 +69,19 @@ default String description() { * be returned by the scan, even if a filter can narrow the set of changes to a single file * in the partition. 
Similarly, a data source that can swap individual files must produce all * rows from files where at least one record must be changed, not just rows that must be changed. + *
+ * <p>
+ * Data sources that replace groups of data (e.g. files, partitions) may prune entire groups + * using provided data source filters when building a scan for this row-level operation. + * However, such data skipping is limited as not all expressions can be converted into data source + * filters and some can only be evaluated by Spark (e.g. subqueries). Since rewriting groups is + * expensive, Spark allows group-based data sources to filter groups at runtime. The runtime + * filtering enables data sources to narrow down the scope of rewriting to only groups that must + * be rewritten. If the row-level operation scan implements {@link SupportsRuntimeV2Filtering}, + * Spark will execute a query at runtime to find which records match the row-level condition. + * The runtime group filter subquery will leverage a regular batch scan, which isn't required to + * produce all rows in a group if any are returned. The information about matching records will + * be passed back into the row-level operation scan, allowing data sources to discard groups + * that don't have to be rewritten. */ ScanBuilder newScanBuilder(CaseInsensitiveStringMap options); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java index 444263f31113e..283258ecb0a55 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.Random; @@ -53,20 +54,20 @@ public class NumericHistogram { * * @since 3.3.0 */ - public static class Coord implements Comparable { + public static class Coord implements Comparable { public double x; public double y; @Override - public int compareTo(Object other) { - return Double.compare(x, ((Coord) other).x); + public int compareTo(Coord other) { + return Double.compare(x, other.x); } } // Class variables private int nbins; private int nusedbins; - private ArrayList bins; + private List bins; private Random prng; /** @@ -146,7 +147,7 @@ public void addBin(double x, double y, int b) { */ public void allocate(int num_bins) { nbins = num_bins; - bins = new ArrayList(); + bins = new ArrayList<>(); nusedbins = 0; } @@ -163,7 +164,7 @@ public void merge(NumericHistogram other) { // by deserializing the ArrayList of (x,y) pairs into an array of Coord objects nbins = other.nbins; nusedbins = other.nusedbins; - bins = new ArrayList(nusedbins); + bins = new ArrayList<>(nusedbins); for (int i = 0; i < other.nusedbins; i += 1) { Coord bin = new Coord(); bin.x = other.getBin(i).x; @@ -174,7 +175,7 @@ public void merge(NumericHistogram other) { // The aggregation buffer already contains a partial histogram. Therefore, we need // to merge histograms using Algorithm #2 from the Ben-Haim and Tom-Tov paper. 
- ArrayList tmp_bins = new ArrayList(nusedbins + other.nusedbins); + List tmp_bins = new ArrayList<>(nusedbins + other.nusedbins); // Copy all the histogram bins from us and 'other' into an overstuffed histogram for (int i = 0; i < nusedbins; i++) { Coord bin = new Coord(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a577bc005be84..62d930dcd2076 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -56,7 +56,7 @@ import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils import org.apache.spark.util.collection.{Utils => CUtils} @@ -322,8 +322,6 @@ class Analyzer(override val catalogManager: CatalogManager) Seq(ResolveWithCTE) ++ extendedResolutionRules : _*), Batch("Remove TempResolvedColumn", Once, RemoveTempResolvedColumn), - Batch("Apply Char Padding", Once, - ApplyCharTypePadding), Batch("Post-Hoc Resolution", Once, Seq(ResolveCommandsWithIfExists) ++ postHocResolutionRules: _*), @@ -4290,126 +4288,6 @@ object UpdateOuterReferences extends Rule[LogicalPlan] { } } -/** - * This rule performs string padding for char type comparison. - * - * When comparing char type column/field with string literal or char type column/field, - * right-pad the shorter one to the longer length. - */ -object ApplyCharTypePadding extends Rule[LogicalPlan] { - - object AttrOrOuterRef { - def unapply(e: Expression): Option[Attribute] = e match { - case a: Attribute => Some(a) - case OuterReference(a: Attribute) => Some(a) - case _ => None - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (SQLConf.get.charVarcharAsString) { - return plan - } - plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) { - case operator => operator.transformExpressionsUpWithPruning( - _.containsAnyPattern(BINARY_COMPARISON, IN)) { - case e if !e.childrenResolved => e - - // String literal is treated as char type when it's compared to a char type column. - // We should pad the shorter one to the longer length. 
- case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => - b.withNewChildren(newChildren) - }.getOrElse(b) - - case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => - b.withNewChildren(newChildren.reverse) - }.getOrElse(b) - - case i @ In(e @ AttrOrOuterRef(attr), list) - if attr.dataType == StringType && list.forall(_.foldable) => - CharVarcharUtils.getRawType(attr.metadata).flatMap { - case CharType(length) => - val (nulls, literalChars) = - list.map(_.eval().asInstanceOf[UTF8String]).partition(_ == null) - val literalCharLengths = literalChars.map(_.numChars()) - val targetLen = (length +: literalCharLengths).max - Some(i.copy( - value = addPadding(e, length, targetLen), - list = list.zip(literalCharLengths).map { - case (lit, charLength) => addPadding(lit, charLength, targetLen) - } ++ nulls.map(Literal.create(_, StringType)))) - case _ => None - }.getOrElse(i) - - // For char type column or inner field comparison, pad the shorter one to the longer length. - case b @ BinaryComparison(e1 @ AttrOrOuterRef(left), e2 @ AttrOrOuterRef(right)) - // For the same attribute, they must be the same length and no padding is needed. - if !left.semanticEquals(right) => - val outerRefs = (e1, e2) match { - case (_: OuterReference, _: OuterReference) => Seq(left, right) - case (_: OuterReference, _) => Seq(left) - case (_, _: OuterReference) => Seq(right) - case _ => Nil - } - val newChildren = CharVarcharUtils.addPaddingInStringComparison(Seq(left, right)) - if (outerRefs.nonEmpty) { - b.withNewChildren(newChildren.map(_.transform { - case a: Attribute if outerRefs.exists(_.semanticEquals(a)) => OuterReference(a) - })) - } else { - b.withNewChildren(newChildren) - } - - case i @ In(e @ AttrOrOuterRef(attr), list) if list.forall(_.isInstanceOf[Attribute]) => - val newChildren = CharVarcharUtils.addPaddingInStringComparison( - attr +: list.map(_.asInstanceOf[Attribute])) - if (e.isInstanceOf[OuterReference]) { - i.copy( - value = newChildren.head.transform { - case a: Attribute if a.semanticEquals(attr) => OuterReference(a) - }, - list = newChildren.tail) - } else { - i.copy(value = newChildren.head, list = newChildren.tail) - } - } - } - } - - private def padAttrLitCmp( - expr: Expression, - metadata: Metadata, - lit: Expression): Option[Seq[Expression]] = { - if (expr.dataType == StringType) { - CharVarcharUtils.getRawType(metadata).flatMap { - case CharType(length) => - val str = lit.eval().asInstanceOf[UTF8String] - if (str == null) { - None - } else { - val stringLitLen = str.numChars() - if (length < stringLitLen) { - Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) - } else if (length > stringLitLen) { - Some(Seq(expr, StringRPad(lit, Literal(length)))) - } else { - None - } - } - case _ => None - } - } else { - None - } - } - - private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { - if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr - } -} - /** * The rule `ResolveAggregationFunctions` in the main resolution batch creates * [[TempResolvedColumn]] in filter conditions and sort expressions to hold the temporarily resolved diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 0ea2cc6308b34..5143b29af377e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -509,10 +509,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { if (!dataTypesAreCompatibleFn(dt1, dt2)) { val errorMessage = s""" - |${operator.nodeName} can only be performed on tables with the compatible + |${operator.nodeName} can only be performed on tables with compatible |column types. The ${ordinalNumber(ci)} column of the |${ordinalNumber(ti + 1)} table is ${dt1.catalogString} type which is not - |compatible with ${dt2.catalogString} at same column of first table + |compatible with ${dt2.catalogString} at the same column of the first table """.stripMargin.replace("\n", " ").trim() failAnalysis(errorMessage + extraHintForAnsiTypeCoercionPlan(operator)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 72622da7eecdf..3f33a2f06ac46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,6 +21,8 @@ import scala.math.{max, min} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.SQLQueryContext @@ -1198,12 +1200,17 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName requires at least two arguments") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_PARAMS", + messageParameters = Map("actualNum" -> children.length.toString)) } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got LEAST(${children.map(_.dataType.catalogString).mkString(", ")}).") + DataTypeMismatch( + errorSubClass = "DATA_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> children.map(_.dataType).map(toSQLType).mkString("[", ", ", "]") + ) + ) } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } @@ -1281,12 +1288,17 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress override def checkInputDataTypes(): TypeCheckResult = { if (children.length <= 1) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName requires at least two arguments") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_PARAMS", + messageParameters = Map("actualNum" -> children.length.toString)) } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { - TypeCheckResult.TypeCheckFailure( - s"The expressions should all have the same type," + - s" got GREATEST(${children.map(_.dataType.catalogString).mkString(", ")}).") + DataTypeMismatch( + errorSubClass = "DATA_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> 
children.map(_.dataType).map(toSQLType).mkString("[", ", ", "]") + ) + ) } else { TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 3380b52832bd8..1e4364b3f4a9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -128,7 +128,11 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { val ctePlan = DeduplicateRelations( Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1) val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) => - Alias(srcAttr, tgtAttr.name)(exprId = tgtAttr.exprId) + if (srcAttr.semanticEquals(tgtAttr)) { + tgtAttr + } else { + Alias(srcAttr, tgtAttr.name)(exprId = tgtAttr.exprId) + } } Project(projectList, ctePlan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index 2b3566d4b1580..75af56bdee828 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -181,27 +181,39 @@ object CharVarcharUtils extends Logging { } def stringLengthCheck(expr: Expression, dt: DataType): Expression = { + processStringForCharVarchar( + expr, + dt, + charFuncName = Some("charTypeWriteSideCheck"), + varcharFuncName = Some("varcharTypeWriteSideCheck")) + } + + private def processStringForCharVarchar( + expr: Expression, + dt: DataType, + charFuncName: Option[String], + varcharFuncName: Option[String]): Expression = { dt match { - case CharType(length) => + case CharType(length) if charFuncName.isDefined => StaticInvoke( classOf[CharVarcharCodegenUtils], StringType, - "charTypeWriteSideCheck", + charFuncName.get, expr :: Literal(length) :: Nil, returnNullable = false) - case VarcharType(length) => + case VarcharType(length) if varcharFuncName.isDefined => StaticInvoke( classOf[CharVarcharCodegenUtils], StringType, - "varcharTypeWriteSideCheck", + varcharFuncName.get, expr :: Literal(length) :: Nil, returnNullable = false) case StructType(fields) => val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => - Seq(Literal(f.name), - stringLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType)) + Seq(Literal(f.name), processStringForCharVarchar( + GetStructField(expr, i, Some(f.name)), f.dataType, charFuncName, varcharFuncName)) }) if (expr.nullable) { If(IsNull(expr), Literal(null, struct.dataType), struct) @@ -209,24 +221,40 @@ object CharVarcharUtils extends Logging { struct } - case ArrayType(et, containsNull) => stringLengthCheckInArray(expr, et, containsNull) + case ArrayType(et, containsNull) => + processStringForCharVarcharInArray(expr, et, containsNull, charFuncName, varcharFuncName) case MapType(kt, vt, valueContainsNull) => - val newKeys = stringLengthCheckInArray(MapKeys(expr), kt, containsNull = false) - val newValues = stringLengthCheckInArray(MapValues(expr), vt, valueContainsNull) + val newKeys = processStringForCharVarcharInArray( + MapKeys(expr), kt, containsNull = false, charFuncName, varcharFuncName) + val newValues = processStringForCharVarcharInArray( + 
MapValues(expr), vt, valueContainsNull, charFuncName, varcharFuncName) MapFromArrays(newKeys, newValues) case _ => expr } } - private def stringLengthCheckInArray( - arr: Expression, et: DataType, containsNull: Boolean): Expression = { + private def processStringForCharVarcharInArray( + arr: Expression, + et: DataType, + containsNull: Boolean, + charFuncName: Option[String], + varcharFuncName: Option[String]): Expression = { val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) - val func = LambdaFunction(stringLengthCheck(param, et), Seq(param)) + val func = LambdaFunction( + processStringForCharVarchar(param, et, charFuncName, varcharFuncName), + Seq(param)) ArrayTransform(arr, func) } + def addPaddingForScan(attr: Attribute): Expression = { + getRawType(attr.metadata).map { rawType => + processStringForCharVarchar( + attr, rawType, charFuncName = Some("readSidePadding"), varcharFuncName = None) + }.getOrElse(attr) + } + /** * Return expressions to apply char type padding for the string comparison between the given * attributes. When comparing two char type columns/fields, we need to pad the shorter one to diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 9eb7bb13a8719..0c6aeedfc4acb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2850,14 +2850,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { messageParameters = Map("ident" -> ident.quoted)) } - def unexpectedTypeOfRelationError(relation: LogicalPlan, tableName: String): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1304", - messageParameters = Map( - "className" -> relation.getClass.getCanonicalName, - "tableName" -> tableName)) - } - def unsupportedTableChangeInJDBCCatalogError(change: TableChange): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1305", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index a3e1b980d1fa7..392b9bf6c727c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.errors import java.io.{FileNotFoundException, IOException} import java.lang.reflect.InvocationTargetException import java.net.{URISyntaxException, URL} -import java.sql.{SQLException, SQLFeatureNotSupportedException} +import java.sql.{SQLFeatureNotSupportedException} import java.time.{DateTimeException, LocalDate} import java.time.temporal.ChronoField import java.util.ConcurrentModificationException @@ -967,39 +967,55 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { def fileLengthExceedsMaxLengthError(status: FileStatus, maxLength: Int): Throwable = { new SparkException( - s"The length of ${status.getPath} is ${status.getLen}, " + - s"which exceeds the max length allowed: ${maxLength}.") + errorClass = "_LEGACY_ERROR_TEMP_2076", + messageParameters = Map( + "path" -> status.getPath.toString(), + "len" -> status.getLen.toString(), + "maxLength" -> maxLength.toString()), + cause = null) } - def unsupportedFieldNameError(fieldName: String): 
Throwable = { - new RuntimeException(s"Unsupported field name: ${fieldName}") + def unsupportedFieldNameError(fieldName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2077", + messageParameters = Map("fieldName" -> fieldName)) } def cannotSpecifyBothJdbcTableNameAndQueryError( - jdbcTableName: String, jdbcQueryString: String): Throwable = { - new IllegalArgumentException( - s"Both '$jdbcTableName' and '$jdbcQueryString' can not be specified at the same time.") + jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2078", + messageParameters = Map( + "jdbcTableName" -> jdbcTableName, + "jdbcQueryString" -> jdbcQueryString)) } def missingJdbcTableNameAndQueryError( - jdbcTableName: String, jdbcQueryString: String): Throwable = { - new IllegalArgumentException( - s"Option '$jdbcTableName' or '$jdbcQueryString' is required." - ) + jdbcTableName: String, jdbcQueryString: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2079", + messageParameters = Map( + "jdbcTableName" -> jdbcTableName, + "jdbcQueryString" -> jdbcQueryString)) } - def emptyOptionError(optionName: String): Throwable = { - new IllegalArgumentException(s"Option `$optionName` can not be empty.") + def emptyOptionError(optionName: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2080", + messageParameters = Map("optionName" -> optionName)) } - def invalidJdbcTxnIsolationLevelError(jdbcTxnIsolationLevel: String, value: String): Throwable = { - new IllegalArgumentException( - s"Invalid value `$value` for parameter `$jdbcTxnIsolationLevel`. 
This can be " + - "`NONE`, `READ_UNCOMMITTED`, `READ_COMMITTED`, `REPEATABLE_READ` or `SERIALIZABLE`.") + def invalidJdbcTxnIsolationLevelError( + jdbcTxnIsolationLevel: String, value: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2081", + messageParameters = Map("value" -> value, "jdbcTxnIsolationLevel" -> jdbcTxnIsolationLevel)) } - def cannotGetJdbcTypeError(dt: DataType): Throwable = { - new IllegalArgumentException(s"Can't get JDBC type for ${dt.catalogString}") + def cannotGetJdbcTypeError(dt: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2082", + messageParameters = Map("catalogString" -> dt.catalogString)) } def unrecognizedSqlTypeError(sqlType: Int): Throwable = { @@ -1008,27 +1024,35 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Map("typeName" -> sqlType.toString)) } - def unsupportedJdbcTypeError(content: String): Throwable = { - new SQLException(s"Unsupported type $content") + def unsupportedJdbcTypeError(content: String): SparkSQLException = { + new SparkSQLException( + errorClass = "_LEGACY_ERROR_TEMP_2083", + messageParameters = Map("content" -> content)) } - def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): Throwable = { - new IllegalArgumentException(s"Unsupported array element " + - s"type ${dt.catalogString} based on binary") + def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2084", + messageParameters = Map("catalogString" -> dt.catalogString)) } - def nestedArraysUnsupportedError(): Throwable = { - new IllegalArgumentException("Nested arrays unsupported") + def nestedArraysUnsupportedError(): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2085", + messageParameters = Map.empty) } - def cannotTranslateNonNullValueForFieldError(pos: Int): Throwable = { - new IllegalArgumentException(s"Can't translate non-null value for field $pos") + def cannotTranslateNonNullValueForFieldError(pos: Int): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2086", + messageParameters = Map("pos" -> pos.toString())) } - def invalidJdbcNumPartitionsError(n: Int, jdbcNumPartitions: String): Throwable = { - new IllegalArgumentException( - s"Invalid value `$n` for parameter `$jdbcNumPartitions` in table writing " + - "via JDBC. 
The minimum value is 1.") + def invalidJdbcNumPartitionsError( + n: Int, jdbcNumPartitions: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2087", + messageParameters = Map("n" -> n.toString(), "jdbcNumPartitions" -> jdbcNumPartitions)) } def transactionUnsupportedByJdbcServerError(): Throwable = { @@ -1037,204 +1061,314 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Map.empty[String, String]) } - def dataTypeUnsupportedYetError(dataType: DataType): Throwable = { - new UnsupportedOperationException(s"$dataType is not supported yet.") + def dataTypeUnsupportedYetError(dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2088", + messageParameters = Map("dataType" -> dataType.toString())) } - def unsupportedOperationForDataTypeError(dataType: DataType): Throwable = { - new UnsupportedOperationException(s"DataType: ${dataType.catalogString}") + def unsupportedOperationForDataTypeError( + dataType: DataType): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2089", + messageParameters = Map("catalogString" -> dataType.catalogString)) } def inputFilterNotFullyConvertibleError(owner: String): Throwable = { - new SparkException(s"The input filter of $owner should be fully convertible.") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2090", + messageParameters = Map("owner" -> owner), + cause = null) } def cannotReadFooterForFileError(file: Path, e: IOException): Throwable = { - new SparkException(s"Could not read footer for file: $file", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2091", + messageParameters = Map("file" -> file.toString()), + cause = e) } def cannotReadFooterForFileError(file: FileStatus, e: RuntimeException): Throwable = { - new IOException(s"Could not read footer for file: $file", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2092", + messageParameters = Map("file" -> file.toString()), + cause = e) } def foundDuplicateFieldInCaseInsensitiveModeError( - requiredFieldName: String, matchedOrcFields: String): Throwable = { - new RuntimeException( - s""" - |Found duplicate field(s) "$requiredFieldName": $matchedOrcFields - |in case-insensitive mode - """.stripMargin.replaceAll("\n", " ")) + requiredFieldName: String, matchedOrcFields: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2093", + messageParameters = Map( + "requiredFieldName" -> requiredFieldName, + "matchedOrcFields" -> matchedOrcFields)) } def foundDuplicateFieldInFieldIdLookupModeError( - requiredId: Int, matchedFields: String): Throwable = { - new RuntimeException( - s""" - |Found duplicate field(s) "$requiredId": $matchedFields - |in id mapping mode - """.stripMargin.replaceAll("\n", " ")) + requiredId: Int, matchedFields: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2094", + messageParameters = Map( + "requiredId" -> requiredId.toString(), + "matchedFields" -> matchedFields)) } def failedToMergeIncompatibleSchemasError( left: StructType, right: StructType, e: Throwable): Throwable = { - new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2095", + messageParameters = Map("left" -> left.toString(), "right" -> right.toString()), + cause = 
e) } - def ddlUnsupportedTemporarilyError(ddl: String): Throwable = { - new UnsupportedOperationException(s"$ddl is not supported temporarily.") + def ddlUnsupportedTemporarilyError(ddl: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2096", + messageParameters = Map("ddl" -> ddl)) } def executeBroadcastTimeoutError(timeout: Long, ex: Option[TimeoutException]): Throwable = { new SparkException( - s""" - |Could not execute broadcast in $timeout secs. You can increase the timeout - |for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or disable broadcast join - |by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 - """.stripMargin.replaceAll("\n", " "), ex.orNull) + errorClass = "_LEGACY_ERROR_TEMP_2097", + messageParameters = Map( + "timeout" -> timeout.toString(), + "broadcastTimeout" -> toSQLConf(SQLConf.BROADCAST_TIMEOUT.key), + "autoBroadcastJoinThreshold" -> toSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key)), + cause = ex.orNull) } - def cannotCompareCostWithTargetCostError(cost: String): Throwable = { - new IllegalArgumentException(s"Could not compare cost with $cost") + def cannotCompareCostWithTargetCostError(cost: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2098", + messageParameters = Map("cost" -> cost)) } - def unsupportedDataTypeError(dt: String): Throwable = { - new UnsupportedOperationException(s"Unsupported data type: ${dt}") + def unsupportedDataTypeError(dt: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2099", + messageParameters = Map("dt" -> dt)) } def notSupportTypeError(dataType: DataType): Throwable = { - new Exception(s"not support type: $dataType") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2100", + messageParameters = Map("dataType" -> dataType.toString()), + cause = null) } - def notSupportNonPrimitiveTypeError(): Throwable = { - new RuntimeException("Not support non-primitive type now") + def notSupportNonPrimitiveTypeError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2101", + messageParameters = Map.empty) } def unsupportedTypeError(dataType: DataType): Throwable = { - new Exception(s"Unsupported type: ${dataType.catalogString}") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2102", + messageParameters = Map("catalogString" -> dataType.catalogString), + cause = null) } def useDictionaryEncodingWhenDictionaryOverflowError(): Throwable = { - new IllegalStateException( - "Dictionary encoding should not be used because of dictionary overflow.") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2103", + messageParameters = Map.empty, + cause = null) } def endOfIteratorError(): Throwable = { - new NoSuchElementException("End of the iterator") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2104", + messageParameters = Map.empty, + cause = null) } def cannotAllocateMemoryToGrowBytesToBytesMapError(): Throwable = { - new IOException("Could not allocate memory to grow BytesToBytesMap") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2105", + messageParameters = Map.empty, + cause = null) } def cannotAcquireMemoryToBuildLongHashedRelationError(size: Long, got: Long): Throwable = { - new SparkException(s"Can't acquire $size bytes memory to build hash relation, " + - s"got $got bytes") + new SparkException( + errorClass = 
"_LEGACY_ERROR_TEMP_2106", + messageParameters = Map("size" -> size.toString(), "got" -> got.toString()), + cause = null) } def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = { - new SparkOutOfMemoryError("There is not enough memory to build hash map") + new SparkOutOfMemoryError( + "_LEGACY_ERROR_TEMP_2107") } - def rowLargerThan256MUnsupportedError(): Throwable = { - new UnsupportedOperationException("Does not support row that is larger than 256M") + def rowLargerThan256MUnsupportedError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2108", + messageParameters = Map.empty) } - def cannotBuildHashedRelationWithUniqueKeysExceededError(): Throwable = { - new UnsupportedOperationException( - "Cannot build HashedRelation with more than 1/3 billions unique keys") + def cannotBuildHashedRelationWithUniqueKeysExceededError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2109", + messageParameters = Map.empty) } - def cannotBuildHashedRelationLargerThan8GError(): Throwable = { - new UnsupportedOperationException( - "Can not build a HashedRelation that is larger than 8G") + def cannotBuildHashedRelationLargerThan8GError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2110", + messageParameters = Map.empty) } def failedToPushRowIntoRowQueueError(rowQueue: String): Throwable = { - new SparkException(s"failed to push a row into $rowQueue") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2111", + messageParameters = Map("rowQueue" -> rowQueue), + cause = null) } - def unexpectedWindowFunctionFrameError(frame: String): Throwable = { - new RuntimeException(s"Unexpected window function frame $frame.") + def unexpectedWindowFunctionFrameError(frame: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2112", + messageParameters = Map("frame" -> frame)) } def cannotParseStatisticAsPercentileError( - stats: String, e: NumberFormatException): Throwable = { - new IllegalArgumentException(s"Unable to parse $stats as a percentile", e) + stats: String, e: NumberFormatException): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2113", + messageParameters = Map("stats" -> stats)) } - def statisticNotRecognizedError(stats: String): Throwable = { - new IllegalArgumentException(s"$stats is not a recognised statistic") + def statisticNotRecognizedError(stats: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2114", + messageParameters = Map("stats" -> stats)) } - def unknownColumnError(unknownColumn: String): Throwable = { - new IllegalArgumentException(s"Unknown column: $unknownColumn") + def unknownColumnError(unknownColumn: String): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2115", + messageParameters = Map("unknownColumn" -> unknownColumn.toString())) } - def unexpectedAccumulableUpdateValueError(o: Any): Throwable = { - new IllegalArgumentException(s"Unexpected: $o") + def unexpectedAccumulableUpdateValueError(o: Any): SparkIllegalArgumentException = { + new SparkIllegalArgumentException( + errorClass = "_LEGACY_ERROR_TEMP_2116", + messageParameters = Map("o" -> o.toString())) } - def unscaledValueTooLargeForPrecisionError(): Throwable = { - new 
ArithmeticException("Unscaled value too large for precision. " + - s"If necessary set ${SQLConf.ANSI_ENABLED.key} to false to bypass this error.") + def unscaledValueTooLargeForPrecisionError(): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "_LEGACY_ERROR_TEMP_2117", + messageParameters = Map("ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), + context = Array.empty, + summary = "") } - def decimalPrecisionExceedsMaxPrecisionError(precision: Int, maxPrecision: Int): Throwable = { - new ArithmeticException( - s"Decimal precision $precision exceeds max precision $maxPrecision") + def decimalPrecisionExceedsMaxPrecisionError( + precision: Int, maxPrecision: Int): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "_LEGACY_ERROR_TEMP_2118", + messageParameters = Map( + "precision" -> precision.toString(), + "maxPrecision" -> maxPrecision.toString()), + context = Array.empty, + summary = "") } - def outOfDecimalTypeRangeError(str: UTF8String): Throwable = { - new ArithmeticException(s"out of decimal type range: $str") + def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = { + new SparkArithmeticException( + errorClass = "_LEGACY_ERROR_TEMP_2119", + messageParameters = Map("str" -> str.toString()), + context = Array.empty, + summary = "") } - def unsupportedArrayTypeError(clazz: Class[_]): Throwable = { - new RuntimeException(s"Do not support array of type $clazz.") + def unsupportedArrayTypeError(clazz: Class[_]): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2120", + messageParameters = Map("clazz" -> clazz.toString())) } - def unsupportedJavaTypeError(clazz: Class[_]): Throwable = { - new RuntimeException(s"Do not support type $clazz.") + def unsupportedJavaTypeError(clazz: Class[_]): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2121", + messageParameters = Map("clazz" -> clazz.toString())) } - def failedParsingStructTypeError(raw: String): Throwable = { - new RuntimeException(s"Failed parsing ${StructType.simpleString}: $raw") + def failedParsingStructTypeError(raw: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2122", + messageParameters = Map("simpleString" -> StructType.simpleString, "raw" -> raw)) } def failedMergingFieldsError(leftName: String, rightName: String, e: Throwable): Throwable = { - new SparkException(s"Failed to merge fields '$leftName' and '$rightName'. 
${e.getMessage}") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2123", + messageParameters = Map( + "leftName" -> leftName, + "rightName" -> rightName, + "message" -> e.getMessage), + cause = null) } def cannotMergeDecimalTypesWithIncompatibleScaleError( leftScale: Int, rightScale: Int): Throwable = { - new SparkException("Failed to merge decimal types with incompatible " + - s"scale $leftScale and $rightScale") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2124", + messageParameters = Map( + "leftScale" -> leftScale.toString(), + "rightScale" -> rightScale.toString()), + cause = null) } def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { - new SparkException(s"Failed to merge incompatible data types ${left.catalogString}" + - s" and ${right.catalogString}") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2125", + messageParameters = Map( + "leftCatalogString" -> left.catalogString, + "rightCatalogString" -> right.catalogString), + cause = null) } - def exceedMapSizeLimitError(size: Int): Throwable = { - new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " + - s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + def exceedMapSizeLimitError(size: Int): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2126", + messageParameters = Map( + "size" -> size.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString())) } - def duplicateMapKeyFoundError(key: Any): Throwable = { - new RuntimeException(s"Duplicate map key $key was found, please check the input " + - "data. If you want to remove the duplicated keys, you can set " + - s"${SQLConf.MAP_KEY_DEDUP_POLICY.key} to ${SQLConf.MapKeyDedupPolicy.LAST_WIN} so that " + - "the key inserted at last takes precedence.") + def duplicateMapKeyFoundError(key: Any): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2127", + messageParameters = Map( + "key" -> key.toString(), + "mapKeyDedupPolicy" -> toSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key), + "lastWin" -> toSQLConf(SQLConf.MapKeyDedupPolicy.LAST_WIN.toString()))) } - def mapDataKeyArrayLengthDiffersFromValueArrayLengthError(): Throwable = { - new RuntimeException("The key array and value array of MapData must have the same length.") + def mapDataKeyArrayLengthDiffersFromValueArrayLengthError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2128", + messageParameters = Map.empty) } def fieldDiffersFromDerivedLocalDateError( - field: ChronoField, actual: Int, expected: Int, candidate: LocalDate): Throwable = { - new DateTimeException(s"Conflict found: Field $field $actual differs from" + - s" $field $expected derived from $candidate") + field: ChronoField, + actual: Int, + expected: Int, + candidate: LocalDate): SparkDateTimeException = { + new SparkDateTimeException( + errorClass = "_LEGACY_ERROR_TEMP_2129", + messageParameters = Map( + "field" -> field.toString(), + "actual" -> actual.toString(), + "expected" -> expected.toString(), + "candidate" -> candidate.toString()), + context = Array.empty, + summary = "") } def failToParseDateTimeInNewParserError(s: String, e: Throwable): Throwable = { @@ -1255,15 +1389,18 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { e) } - def failToRecognizePatternError(pattern: String, e: Throwable): Throwable = { - new RuntimeException(s"Fail to recognize '$pattern' 
pattern in the" + - " DateTimeFormatter. You can form a valid datetime pattern" + - " with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html", - e) + def failToRecognizePatternError(pattern: String, e: Throwable): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2130", + messageParameters = Map("pattern" -> pattern), + cause = e) } def registeringStreamingQueryListenerError(e: Exception): Throwable = { - new SparkException("Exception when registering StreamingQueryListener", e) + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2131", + messageParameters = Map.empty, + cause = e) } def concurrentQueryInstanceError(): Throwable = { @@ -1272,105 +1409,149 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { messageParameters = Map.empty[String, String]) } - def cannotParseJsonArraysAsStructsError(): Throwable = { - new RuntimeException("Parsing JSON arrays as structs is forbidden.") + def cannotParseJsonArraysAsStructsError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2132", + messageParameters = Map.empty) } def cannotParseStringAsDataTypeError(parser: JsonParser, token: JsonToken, dataType: DataType) - : Throwable = { - new RuntimeException( - s"Cannot parse field name ${parser.getCurrentName}, " + - s"field value ${parser.getText}, " + - s"[$token] as target spark data type [$dataType].") + : SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2133", + messageParameters = Map( + "fieldName" -> parser.getCurrentName, + "fieldValue" -> parser.getText, + "token" -> token.toString(), + "dataType" -> dataType.toString())) } def cannotParseStringAsDataTypeError(pattern: String, value: String, dataType: DataType) - : Throwable = { - new RuntimeException( - s"Cannot parse field value ${toSQLValue(value, StringType)} " + - s"for pattern ${toSQLValue(pattern, StringType)} " + - s"as target spark data type [$dataType].") + : SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2134", + messageParameters = Map( + "value" -> toSQLValue(value, StringType), + "pattern" -> toSQLValue(pattern, StringType), + "dataType" -> dataType.toString())) } - def failToParseEmptyStringForDataTypeError(dataType: DataType): Throwable = { - new RuntimeException( - s"Failed to parse an empty string for data type ${dataType.catalogString}") + def failToParseEmptyStringForDataTypeError(dataType: DataType): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2135", + messageParameters = Map( + "dataType" -> dataType.catalogString)) } def failToParseValueForDataTypeError(parser: JsonParser, token: JsonToken, dataType: DataType) - : Throwable = { - new RuntimeException( - s"Failed to parse field name ${parser.getCurrentName}, " + - s"field value ${parser.getText}, " + - s"[$token] to target spark data type [$dataType].") + : SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2136", + messageParameters = Map( + "fieldName" -> parser.getCurrentName.toString(), + "fieldValue" -> parser.getText.toString(), + "token" -> token.toString(), + "dataType" -> dataType.toString())) } - def rootConverterReturnNullError(): Throwable = { - new RuntimeException("Root converter returned null") + def rootConverterReturnNullError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2137", + messageParameters = 
Map.empty) } - def cannotHaveCircularReferencesInBeanClassError(clazz: Class[_]): Throwable = { - new UnsupportedOperationException( - "Cannot have circular references in bean class, but got the circular reference " + - s"of class $clazz") + def cannotHaveCircularReferencesInBeanClassError( + clazz: Class[_]): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2138", + messageParameters = Map("clazz" -> clazz.toString())) } - def cannotHaveCircularReferencesInClassError(t: String): Throwable = { - new UnsupportedOperationException( - s"cannot have circular references in class, but got the circular reference of class $t") + def cannotHaveCircularReferencesInClassError(t: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2139", + messageParameters = Map("t" -> t)) } def cannotUseInvalidJavaIdentifierAsFieldNameError( - fieldName: String, walkedTypePath: WalkedTypePath): Throwable = { - new UnsupportedOperationException(s"`$fieldName` is not a valid identifier of " + - s"Java and cannot be used as field name\n$walkedTypePath") + fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2140", + messageParameters = Map( + "fieldName" -> fieldName, + "walkedTypePath" -> walkedTypePath.toString())) } def cannotFindEncoderForTypeError( - tpe: String, walkedTypePath: WalkedTypePath): Throwable = { - new UnsupportedOperationException(s"No Encoder found for $tpe\n$walkedTypePath") + tpe: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2141", + messageParameters = Map( + "tpe" -> tpe, + "walkedTypePath" -> walkedTypePath.toString())) } - def attributesForTypeUnsupportedError(schema: Schema): Throwable = { - new UnsupportedOperationException(s"Attributes for type $schema is not supported") + def attributesForTypeUnsupportedError(schema: Schema): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2142", + messageParameters = Map( + "schema" -> schema.toString())) } - def schemaForTypeUnsupportedError(tpe: String): Throwable = { - new UnsupportedOperationException(s"Schema for type $tpe is not supported") + def schemaForTypeUnsupportedError(tpe: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2143", + messageParameters = Map( + "tpe" -> tpe)) } - def cannotFindConstructorForTypeError(tpe: String): Throwable = { - new UnsupportedOperationException( - s""" - |Unable to find constructor for $tpe. - |This could happen if $tpe is an interface, or a trait without companion object - |constructor. 
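The QueryExecutionErrors rewrites above and below all follow the same shape: a plain JVM exception built from an interpolated string becomes a Spark error-class exception that carries an errorClass identifier plus a messageParameters map, so the message template is resolved from the error-class registry (error-classes.json) instead of being hard-coded at the call site. A minimal sketch of the pattern, using an illustrative placeholder error-class name rather than one added by this change:

  // Before: message text is interpolated directly at the call site.
  def oldStyleError(fieldName: String): Throwable =
    new RuntimeException(s"Unsupported field name: $fieldName")

  // After: only the error class and its named parameters are supplied here;
  // the message template lives in the error-class registry.
  def newStyleError(fieldName: String): SparkRuntimeException =
    new SparkRuntimeException(
      errorClass = "_LEGACY_ERROR_TEMP_XXXX",   // illustrative placeholder
      messageParameters = Map("fieldName" -> fieldName))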
- """.stripMargin.replaceAll("\n", " ")) + def cannotFindConstructorForTypeError(tpe: String): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2144", + messageParameters = Map( + "tpe" -> tpe)) } - def paramExceedOneCharError(paramName: String): Throwable = { - new RuntimeException(s"$paramName cannot be more than one character") + def paramExceedOneCharError(paramName: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2145", + messageParameters = Map( + "paramName" -> paramName)) } - def paramIsNotIntegerError(paramName: String, value: String): Throwable = { - new RuntimeException(s"$paramName should be an integer. Found ${toSQLValue(value, StringType)}") + def paramIsNotIntegerError(paramName: String, value: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2146", + messageParameters = Map( + "paramName" -> paramName, + "value" -> value)) } def paramIsNotBooleanValueError(paramName: String): Throwable = { - new Exception(s"$paramName flag can be true or false") + new SparkException( + errorClass = "_LEGACY_ERROR_TEMP_2147", + messageParameters = Map( + "paramName" -> paramName), + cause = null) } - def foundNullValueForNotNullableFieldError(name: String): Throwable = { - new RuntimeException(s"null value found but field $name is not nullable.") + def foundNullValueForNotNullableFieldError(name: String): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2148", + messageParameters = Map( + "name" -> name)) } - def malformedCSVRecordError(): Throwable = { - new RuntimeException("Malformed CSV record") + def malformedCSVRecordError(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "_LEGACY_ERROR_TEMP_2149", + messageParameters = Map.empty) } - def elementsOfTupleExceedLimitError(): Throwable = { - new UnsupportedOperationException("Due to Scala's limited support of tuple, " + - "tuple with more than 22 elements are not supported.") + def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = { + new SparkUnsupportedOperationException( + errorClass = "_LEGACY_ERROR_TEMP_2150", + messageParameters = Map.empty) } def expressionDecodingError(e: Exception, expressions: Seq[Expression]): Throwable = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 376bcece3c61c..2f96209222b2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -412,6 +412,21 @@ object SQLConf { .longConf .createWithDefault(67108864L) + val RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED = + buildConf("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled") + .doc("Enables runtime group filtering for group-based row-level operations. " + + "Data sources that replace groups of data (e.g. files, partitions) may prune entire " + + "groups using provided data source filters when planning a row-level operation scan. " + + "However, such filtering is limited as not all expressions can be converted into data " + + "source filters and some expressions can only be evaluated by Spark (e.g. subqueries). " + + "Since rewriting groups is expensive, Spark can execute a query at runtime to find what " + + "records match the condition of the row-level operation. 
The information about matching " + + "records will be passed back to the row-level operation scan, allowing data sources to " + + "discard groups that don't have to be rewritten.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val PLANNED_WRITE_ENABLED = buildConf("spark.sql.optimizer.plannedWrite.enabled") .internal() .doc("When set to true, Spark optimizer will add logical sort operators to V1 write commands " + @@ -708,9 +723,9 @@ object SQLConf { "multiplying the median partition size and also larger than " + "'spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes'") .version("3.0.0") - .intConf + .doubleConf .checkValue(_ >= 0, "The skew factor cannot be negative.") - .createWithDefault(5) + .createWithDefault(5.0) val SKEW_JOIN_SKEWED_PARTITION_THRESHOLD = buildConf("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes") @@ -2909,7 +2924,15 @@ object SQLConf { .booleanConf .createWithDefault(sys.env.get("SPARK_ANSI_SQL_MODE").contains("true")) - val DOUBLE_QUOTED_IDENTIFIERS = buildConf("spark.sql.ansi.double_quoted_identifiers") + val ENFORCE_RESERVED_KEYWORDS = buildConf("spark.sql.ansi.enforceReservedKeywords") + .doc(s"When true and '${ANSI_ENABLED.key}' is true, the Spark SQL parser enforces the ANSI " + + "reserved keywords and forbids SQL queries that use reserved keywords as alias names " + + "and/or identifiers for table, view, function, etc.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + + val DOUBLE_QUOTED_IDENTIFIERS = buildConf("spark.sql.ansi.doubleQuotedIdentifiers") .doc("When true, Spark SQL reads literals enclosed in double quoted (\") as identifiers. " + "When false they are read as string literals.") .version("3.4.0") @@ -2964,14 +2987,6 @@ object SQLConf { .booleanConf .createWithDefault(false) - val ENFORCE_RESERVED_KEYWORDS = buildConf("spark.sql.ansi.enforceReservedKeywords") - .doc(s"When true and '${ANSI_ENABLED.key}' is true, the Spark SQL parser enforces the ANSI " + - "reserved keywords and forbids SQL queries that use reserved keywords as alias names " + - "and/or identifiers for table, view, function, etc.") - .version("3.3.0") - .booleanConf - .createWithDefault(false) - val SORT_BEFORE_REPARTITION = buildConf("spark.sql.execution.sortBeforeRepartition") .internal() @@ -3805,6 +3820,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val READ_SIDE_CHAR_PADDING = buildConf("spark.sql.readSideCharPadding") + .doc("When true, Spark applies string padding when reading CHAR type columns/fields, " + + "in addition to the write-side padding. 
This config is true by default to better enforce " + + "CHAR type semantic in cases such as external tables.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val CLI_PRINT_HEADER = buildConf("spark.sql.cli.print.header") .doc("When set to true, spark-sql CLI prints the names of the columns in query output.") @@ -4091,6 +4114,9 @@ class SQLConf extends Serializable with Logging { def runtimeFilterCreationSideThreshold: Long = getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD) + def runtimeRowLevelOperationGroupFilterEnabled: Boolean = + getConf(RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) @@ -4592,7 +4618,7 @@ class SQLConf extends Serializable with Logging { def enforceReservedKeywords: Boolean = ansiEnabled && getConf(ENFORCE_RESERVED_KEYWORDS) - def doubleQuotedIdentifiers: Boolean = getConf(DOUBLE_QUOTED_IDENTIFIERS) + def doubleQuotedIdentifiers: Boolean = ansiEnabled && getConf(DOUBLE_QUOTED_IDENTIFIERS) def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match { case "TIMESTAMP_LTZ" => @@ -4695,6 +4721,8 @@ class SQLConf extends Serializable with Logging { def charVarcharAsString: Boolean = getConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING) + def readSideCharPadding: Boolean = getConf(SQLConf.READ_SIDE_CHAR_PADDING) + def cliPrintHeader: Boolean = getConf(SQLConf.CLI_PRINT_HEADER) def legacyIntervalEnabled: Boolean = getConf(LEGACY_INTERVAL_ENABLED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d6a63dd0d0bf5..ac823183ce9da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -398,32 +398,32 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "union with incompatible column types", testRelation.union(nestedRelation), - "union" :: "the compatible column types" :: Nil) + "union" :: "compatible column types" :: Nil) errorTest( "union with a incompatible column type and compatible column types", testRelation3.union(testRelation4), - "union" :: "the compatible column types" :: "map" :: "decimal" :: Nil) + "union" :: "compatible column types" :: "map" :: "decimal" :: Nil) errorTest( "intersect with incompatible column types", testRelation.intersect(nestedRelation, isAll = false), - "intersect" :: "the compatible column types" :: Nil) + "intersect" :: "compatible column types" :: Nil) errorTest( "intersect with a incompatible column type and compatible column types", testRelation3.intersect(testRelation4, isAll = false), - "intersect" :: "the compatible column types" :: "map" :: "decimal" :: Nil) + "intersect" :: "compatible column types" :: "map" :: "decimal" :: Nil) errorTest( "except with incompatible column types", testRelation.except(nestedRelation, isAll = false), - "except" :: "the compatible column types" :: Nil) + "except" :: "compatible column types" :: Nil) errorTest( "except with a incompatible column type and compatible column types", testRelation3.except(testRelation4, isAll = false), - "except" :: "the compatible column types" :: "map" :: "decimal" :: Nil) + "except" :: "compatible column types" :: "map" :: "decimal" :: Nil) errorClassTest( "SPARK-9955: correct error message for 
aggregate", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 095eca6499256..f74cdab55443a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -926,29 +926,29 @@ class AnalysisSuite extends AnalysisTest with Matchers { val r5 = Intersect(firstTable, secondTable, isAll = false) assertAnalysisError(r1, - Seq("Union can only be performed on tables with the compatible column types. " + + Seq("Union can only be performed on tables with compatible column types. " + "The second column of the second table is timestamp type which is not compatible " + - "with double at same column of first table")) + "with double at the same column of the first table")) assertAnalysisError(r2, - Seq("Union can only be performed on tables with the compatible column types. " + + Seq("Union can only be performed on tables with compatible column types. " + "The third column of the second table is timestamp type which is not compatible " + - "with int at same column of first table")) + "with int at the same column of the first table")) assertAnalysisError(r3, - Seq("Union can only be performed on tables with the compatible column types. " + + Seq("Union can only be performed on tables with compatible column types. " + "The 4th column of the second table is timestamp type which is not compatible " + - "with float at same column of first table")) + "with float at the same column of the first table")) assertAnalysisError(r4, - Seq("Except can only be performed on tables with the compatible column types. " + + Seq("Except can only be performed on tables with compatible column types. " + "The second column of the second table is timestamp type which is not compatible " + - "with double at same column of first table")) + "with double at the same column of the first table")) assertAnalysisError(r5, - Seq("Intersect can only be performed on tables with the compatible column types. " + + Seq("Intersect can only be performed on tables with compatible column types. 
" + "The second column of the second table is timestamp type which is not compatible " + - "with double at same column of first table")) + "with double at the same column of the first table")) } test("SPARK-31975: Throw user facing error when use WindowFunction directly") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 633f9a648157d..f3b102d1feb4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -83,6 +83,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer parameters = messageParameters) } + def assertErrorForWrongNumParameters( + expr: Expression, messageParameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] { + assertSuccess(expr) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + parameters = messageParameters) + } + def assertForWrongType(expr: Expression, messageParameters: Map[String, String]): Unit = { checkError( exception = intercept[AnalysisException] { @@ -526,9 +536,24 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer test("check types for Greatest/Least") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - assertError(operator(Seq($"booleanField")), "requires at least two arguments") - assertError(operator(Seq($"intField", $"stringField")), - "should all have the same type") + val expr1 = operator(Seq($"booleanField")) + assertErrorForWrongNumParameters( + expr = expr1, + messageParameters = Map( + "sqlExpr" -> toSQLExpr(expr1), + "actualNum" -> "1") + ) + + val expr2 = operator(Seq($"intField", $"stringField")) + assertErrorForDataDifferingTypes( + expr = expr2, + messageParameters = Map( + "sqlExpr" -> toSQLExpr(expr2), + "functionName" -> toSQLId(expr2.prettyName), + "dataType" -> "[\"INT\", \"STRING\"]" + ) + ) + val expr3 = operator(Seq($"mapField", $"mapField")) assertErrorForOrderingTypes( expr = expr3, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 63862ee3553ec..e21793ab506c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -23,7 +23,7 @@ import java.time.temporal.ChronoUnit import org.apache.spark.{SparkArithmeticException, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin @@ -503,10 +503,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) // Type checking error - assert( - Least(Seq(Literal(1), Literal("1"))).checkInputDataTypes() == - TypeCheckFailure("The expressions should all have the 
same type, " + - "got LEAST(int, string).")) + Least(Seq(Literal(1), Literal("1"))).checkInputDataTypes() match { + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "DATA_DIFF_TYPES") + assert(messageParameters === Map( + "functionName" -> "`least`", + "dataType" -> "[\"INT\", \"STRING\"]")) + } DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) @@ -561,10 +564,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) // Type checking error - assert( - Greatest(Seq(Literal(1), Literal("1"))).checkInputDataTypes() == - TypeCheckFailure("The expressions should all have the same type, " + - "got GREATEST(int, string).")) + Greatest(Seq(Literal(1), Literal("1"))).checkInputDataTypes() match { + case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => + assert(errorSubClass == "DATA_DIFF_TYPES") + assert(messageParameters === Map( + "functionName" -> "`greatest`", + "dataType" -> "[\"INT\", \"STRING\"]")) + } DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index b99fd089f0ae4..b48a950d9d557 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -16,8 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.SparkThrowable import org.apache.spark.sql.catalyst.analysis.AnalysisTest -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * Test various parser errors. @@ -27,198 +27,262 @@ class ErrorParserSuite extends AnalysisTest { import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ - private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - assert(parsePlan(sqlCommand) == plan) - } - - private def interceptImpl(sql: String, messages: String*)( - line: Option[Int] = None, - startPosition: Option[Int] = None, - stopPosition: Option[Int] = None, - errorClass: Option[String] = None): Unit = { - val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) - - // Check messages. - val error = e.getMessage - messages.foreach { message => - assert(error.contains(message)) - } - - // Check position. - if (line.isDefined) { - assert(line.isDefined && startPosition.isDefined && stopPosition.isDefined) - assert(e.line.isDefined) - assert(e.line.get === line.get) - assert(e.startPosition.isDefined) - assert(e.startPosition.get === startPosition.get) - assert(e.stop.startPosition.isDefined) - assert(e.stop.startPosition.get === stopPosition.get) - } - - // Check error class. 
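The hand-rolled interceptImpl helper being removed here matched raw message text and character positions; the rewritten ErrorParserSuite cases below assert on the structured SparkThrowable instead, via checkError with an error class, its parameters, and, where relevant, an ExpectedContext. A minimal sketch of that shape, with a hypothetical statement and expected values chosen for illustration:

  // Illustrative only: an unquoted hyphenated identifier should surface
  // as the _LEGACY_ERROR_TEMP_0040 error class with the offending ident.
  checkError(
    exception = parseException("USE some-database"),
    errorClass = "_LEGACY_ERROR_TEMP_0040",
    parameters = Map("ident" -> "some-database"))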
- if (errorClass.isDefined) { - assert(e.getErrorClass == errorClass.get) - } - } - - def intercept(sqlCommand: String, errorClass: Option[String], messages: String*): Unit = { - interceptImpl(sqlCommand, messages: _*)(errorClass = errorClass) - } - - def intercept( - sql: String, line: Int, startPosition: Int, stopPosition: Int, messages: String*): Unit = { - interceptImpl(sql, messages: _*)(Some(line), Some(startPosition), Some(stopPosition)) - } - - def intercept(sql: String, errorClass: String, line: Int, startPosition: Int, stopPosition: Int, - messages: String*): Unit = { - interceptImpl(sql, messages: _*)( - Some(line), Some(startPosition), Some(stopPosition), Some(errorClass)) + def parseException(sql: String): SparkThrowable = { + intercept[ParseException](CatalystSqlParser.parsePlan(sql)) } test("semantic errors") { - intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, 11, - "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", - "^^^") + checkError( + exception = parseException("select *\nfrom r\norder by q\ncluster by q"), + errorClass = "_LEGACY_ERROR_TEMP_0011", + parameters = Map.empty, + context = ExpectedContext(fragment = "order by q\ncluster by q", start = 16, stop = 38)) } test("hyphen in identifier - DDL tests") { - val msg = "unquoted identifier" - intercept("USE test-test", 1, 8, 9, msg + " test-test") - intercept("CREATE DATABASE IF NOT EXISTS my-database", 1, 32, 33, msg + " my-database") - intercept( + checkError( + exception = parseException("USE test-test"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-test")) + checkError( + exception = parseException("CREATE DATABASE IF NOT EXISTS my-database"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "my-database")) + checkError( + exception = parseException( """ |ALTER DATABASE my-database - |SET DBPROPERTIES ('p1'='v1')""".stripMargin, 2, 17, 18, msg + " my-database") - intercept("DROP DATABASE my-database", 1, 16, 17, msg + " my-database") - intercept( - """ - |ALTER TABLE t - |CHANGE COLUMN - |test-col TYPE BIGINT - """.stripMargin, 4, 4, 5, msg + " test-col") - intercept( - """ - |ALTER TABLE t - |RENAME COLUMN - |test-col TO test - """.stripMargin, 4, 4, 5, msg + " test-col") - intercept( - """ - |ALTER TABLE t - |RENAME COLUMN - |test TO test-col - """.stripMargin, 4, 12, 13, msg + " test-col") - intercept( - """ - |ALTER TABLE t - |DROP COLUMN - |test-col, test - """.stripMargin, 4, 4, 5, msg + " test-col") - intercept("CREATE TABLE test (attri-bute INT)", 1, 24, 25, msg + " attri-bute") - intercept("CREATE FUNCTION test-func as org.test.func", 1, 20, 21, msg + " test-func") - intercept("DROP FUNCTION test-func as org.test.func", 1, 18, 19, msg + " test-func") - intercept("SHOW FUNCTIONS LIKE test-func", 1, 24, 25, msg + " test-func") - intercept( - """ - |CREATE TABLE IF NOT EXISTS mydb.page-view - |USING parquet - |COMMENT 'This is the staging page view table' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin, 2, 36, 37, msg + " page-view") - intercept( - """ - |CREATE TABLE IF NOT EXISTS tab - |USING test-provider - |AS SELECT * FROM src""".stripMargin, 3, 10, 11, msg + " test-provider") - intercept("SHOW TABLES IN hyphen-database", 1, 21, 22, msg + " hyphen-database") - intercept("SHOW TABLE EXTENDED IN hyphen-db LIKE \"str\"", 1, 29, 30, msg + " hyphen-db") - intercept("SHOW COLUMNS IN t FROM test-db", 1, 27, 28, msg + " test-db") - intercept("DESC 
SCHEMA EXTENDED test-db", 1, 25, 26, msg + " test-db") - intercept("ANALYZE TABLE test-table PARTITION (part1)", 1, 18, 19, msg + " test-table") - intercept("LOAD DATA INPATH \"path\" INTO TABLE my-tab", 1, 37, 38, msg + " my-tab") + |SET DBPROPERTIES ('p1'='v1')""".stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "my-database")) + checkError( + exception = parseException("DROP DATABASE my-database"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "my-database")) + checkError( + exception = parseException( + """ + |ALTER TABLE t + |CHANGE COLUMN + |test-col TYPE BIGINT + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-col")) + checkError( + exception = parseException( + """ + |ALTER TABLE t + |RENAME COLUMN + |test-col TO test + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-col")) + checkError( + exception = parseException( + """ + |ALTER TABLE t + |RENAME COLUMN + |test TO test-col + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-col")) + checkError( + exception = parseException( + """ + |ALTER TABLE t + |DROP COLUMN + |test-col, test + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-col")) + checkError( + exception = parseException("CREATE TABLE test (attri-bute INT)"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "attri-bute")) + checkError( + exception = parseException("CREATE FUNCTION test-func as org.test.func"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-func")) + checkError( + exception = parseException("DROP FUNCTION test-func as org.test.func"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-func")) + checkError( + exception = parseException("SHOW FUNCTIONS LIKE test-func"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-func")) + checkError( + exception = parseException( + """ + |CREATE TABLE IF NOT EXISTS mydb.page-view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "page-view")) + checkError( + exception = parseException( + """ + |CREATE TABLE IF NOT EXISTS tab + |USING test-provider + |AS SELECT * FROM src""".stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-provider")) + checkError( + exception = parseException("SHOW TABLES IN hyphen-database"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "hyphen-database")) + checkError( + exception = parseException("SHOW TABLE EXTENDED IN hyphen-db LIKE \"str\""), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "hyphen-db")) + checkError( + exception = parseException("SHOW COLUMNS IN t FROM test-db"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-db")) + checkError( + exception = parseException("DESC SCHEMA EXTENDED test-db"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-db")) + checkError( + exception = parseException("ANALYZE TABLE test-table PARTITION (part1)"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-table")) + checkError( + exception = parseException("LOAD DATA 
INPATH \"path\" INTO TABLE my-tab"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "my-tab")) } test("hyphen in identifier - DML tests") { - val msg = "unquoted identifier" // dml tests - intercept("SELECT * FROM table-with-hyphen", 1, 19, 25, msg + " table-with-hyphen") + checkError( + exception = parseException("SELECT * FROM table-with-hyphen"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "table-with-hyphen")) // special test case: minus in expression shouldn't be treated as hyphen in identifiers - intercept("SELECT a-b FROM table-with-hyphen", 1, 21, 27, msg + " table-with-hyphen") - intercept("SELECT a-b AS a-b FROM t", 1, 15, 16, msg + " a-b") - intercept("SELECT a-b FROM table-hyphen WHERE a-b = 0", 1, 21, 22, msg + " table-hyphen") - intercept("SELECT (a - test_func(b-c)) FROM test-table", 1, 37, 38, msg + " test-table") - intercept("WITH a-b AS (SELECT 1 FROM s) SELECT * FROM s;", 1, 6, 7, msg + " a-b") - intercept( - """ - |SELECT a, b - |FROM t1 JOIN t2 - |USING (a, b, at-tr) - """.stripMargin, 4, 15, 16, msg + " at-tr" - ) - intercept( - """ - |SELECT product, category, dense_rank() - |OVER (PARTITION BY category ORDER BY revenue DESC) as hyphen-rank - |FROM productRevenue - """.stripMargin, 3, 60, 61, msg + " hyphen-rank" - ) - intercept( - """ - |SELECT a, b - |FROM grammar-breaker - |WHERE a-b > 10 - |GROUP BY fake-breaker - |ORDER BY c - """.stripMargin, 3, 12, 13, msg + " grammar-breaker") - assertEqual( + checkError( + exception = parseException("SELECT a-b FROM table-with-hyphen"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "table-with-hyphen")) + checkError( + exception = parseException("SELECT a-b AS a-b FROM t"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "a-b")) + checkError( + exception = parseException("SELECT a-b FROM table-hyphen WHERE a-b = 0"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "table-hyphen")) + checkError( + exception = parseException("SELECT (a - test_func(b-c)) FROM test-table"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-table")) + checkError( + exception = parseException("WITH a-b AS (SELECT 1 FROM s) SELECT * FROM s;"), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "a-b")) + checkError( + exception = parseException( + """ + |SELECT a, b + |FROM t1 JOIN t2 + |USING (a, b, at-tr) + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "at-tr")) + checkError( + exception = parseException( + """ + |SELECT product, category, dense_rank() + |OVER (PARTITION BY category ORDER BY revenue DESC) as hyphen-rank + |FROM productRevenue + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "hyphen-rank")) + checkError( + exception = parseException( + """ + |SELECT a, b + |FROM grammar-breaker + |WHERE a-b > 10 + |GROUP BY fake-breaker + |ORDER BY c + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "grammar-breaker")) + assert(parsePlan( """ |SELECT a, b |FROM t |WHERE a-b > 10 |GROUP BY fake-breaker |ORDER BY c - """.stripMargin, + """.stripMargin) === table("t") .where($"a" - $"b" > 10) .groupBy($"fake" - $"breaker")($"a", $"b") .orderBy($"c".asc)) - intercept( - """ - |SELECT * FROM tab - |WINDOW hyphen-window AS - | (PARTITION BY a, b ORDER BY c rows BETWEEN 1 PRECEDING AND 1 FOLLOWING) - """.stripMargin, 3, 13, 14, msg + " hyphen-window") - 
intercept( - """ - |SELECT * FROM tab - |WINDOW window_ref AS window-ref - """.stripMargin, 3, 27, 28, msg + " window-ref") - intercept( - """ - |SELECT tb.* - |FROM t-a INNER JOIN tb - |ON ta.a = tb.a AND ta.tag = tb.tag - """.stripMargin, 3, 6, 7, msg + " t-a") - intercept( - """ - |FROM test-table - |SELECT a - |SELECT b - """.stripMargin, 2, 9, 10, msg + " test-table") + checkError( + exception = parseException( + """ + |SELECT * FROM tab + |WINDOW hyphen-window AS + | (PARTITION BY a, b ORDER BY c rows BETWEEN 1 PRECEDING AND 1 FOLLOWING) + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "hyphen-window")) + checkError( + exception = parseException( + """ + |SELECT * FROM tab + |WINDOW window_ref AS window-ref + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "window-ref")) + checkError( + exception = parseException( + """ + |SELECT tb.* + |FROM t-a INNER JOIN tb + |ON ta.a = tb.a AND ta.tag = tb.tag + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "t-a")) + checkError( + exception = parseException( + """ + |FROM test-table + |SELECT a + |SELECT b + """.stripMargin), + errorClass = "_LEGACY_ERROR_TEMP_0040", + parameters = Map("ident" -> "test-table")) } test("datatype not supported") { // general bad types - intercept("SELECT cast(1 as badtype)", 1, 17, 17, "DataType badtype is not supported.") - + checkError( + exception = parseException("SELECT cast(1 as badtype)"), + errorClass = "_LEGACY_ERROR_TEMP_0030", + parameters = Map("dataType" -> "badtype"), + context = ExpectedContext(fragment = "badtype", start = 17, stop = 23)) // special handling on char and varchar - intercept("SELECT cast('a' as CHAR)", "PARSE_CHAR_MISSING_LENGTH", 1, 19, 19, - "DataType \"CHAR\" requires a length parameter") - intercept("SELECT cast('a' as Varchar)", "PARSE_CHAR_MISSING_LENGTH", 1, 19, 19, - "DataType \"VARCHAR\" requires a length parameter") - intercept("SELECT cast('a' as Character)", "PARSE_CHAR_MISSING_LENGTH", 1, 19, 19, - "DataType \"CHARACTER\" requires a length parameter") + checkError( + exception = parseException("SELECT cast('a' as CHAR)"), + errorClass = "PARSE_CHAR_MISSING_LENGTH", + parameters = Map("type" -> "\"CHAR\""), + context = ExpectedContext(fragment = "CHAR", start = 19, stop = 22)) + checkError( + exception = parseException("SELECT cast('a' as Varchar)"), + errorClass = "PARSE_CHAR_MISSING_LENGTH", + parameters = Map("type" -> "\"VARCHAR\""), + context = ExpectedContext(fragment = "Varchar", start = 19, stop = 25)) + checkError( + exception = parseException("SELECT cast('a' as Character)"), + errorClass = "PARSE_CHAR_MISSING_LENGTH", + parameters = Map("type" -> "\"CHARACTER\""), + context = ExpectedContext(fragment = "Character", start = 19, stop = 27)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index cb061602ec151..08c22a02b8555 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -34,6 +34,9 @@ class InMemoryRowLevelOperationTable( properties: util.Map[String, String]) extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations { + // used in row-level 
operation tests to verify replaced partitions + var replacedPartitions: Seq[Seq[Any]] = Seq.empty + override def newRowLevelOperationBuilder( info: RowLevelOperationInfo): RowLevelOperationBuilder = { () => PartitionBasedOperation(info.command) @@ -88,8 +91,9 @@ class InMemoryRowLevelOperationTable( override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val newData = messages.map(_.asInstanceOf[BufferedRows]) val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows) - val readPartitions = readRows.map(r => getKey(r, schema)) + val readPartitions = readRows.map(r => getKey(r, schema)).distinct dataMap --= readPartitions + replacedPartitions = readPartitions withData(newData, schema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 72bdab409a9e6..017d1f937c34c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes} -import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} +import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( @@ -50,7 +50,8 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, - PartitionPruning) :+ + PartitionPruning, + RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+ Batch("InjectRuntimeFilter", FixedPoint(1), InjectRuntimeFilter) :+ Batch("MergeScalarSubqueries", Once, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index d4a173bb9cceb..37cdea084d8a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -64,7 +64,7 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) */ def getSkewThreshold(medianSize: Long): Long = { conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD).max( - medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR)) + (medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR)).toLong) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala index 9a780c11eefab..21bc55110fe80 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelati case class PlanAdaptiveDynamicPruningFilters( rootPlan: AdaptiveSparkPlanExec) extends Rule[SparkPlan] with AdaptiveSparkPlanHelper { def apply(plan: SparkPlan): SparkPlan = { - if (!conf.dynamicPartitionPruningEnabled) { + if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 41f60bfa2ff03..6883f93523bc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -22,7 +22,7 @@ import java.net.URI import scala.collection.mutable import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -113,13 +113,12 @@ object CommandUtils extends Logging { // countFileSize to count the table size. val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - def getPathSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) + def getPathSize(fs: FileSystem, fileStatus: FileStatus): Long = { val size = if (fileStatus.isDirectory) { - fs.listStatus(path) + fs.listStatus(fileStatus.getPath) .map { status => if (isDataPath(status.getPath, stagingDir)) { - getPathSize(fs, status.getPath) + getPathSize(fs, status) } else { 0L } @@ -136,7 +135,7 @@ object CommandUtils extends Logging { val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) - getPathSize(fs, path) + getPathSize(fs, fs.getFileStatus(path)) } catch { case NonFatal(e) => logWarning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala new file mode 100644 index 0000000000000..b5bf337a5a2e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{CharType, Metadata, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * This rule performs string padding for char type. + * + * When reading values from column/field of type CHAR(N), right-pad the values to length N, if the + * read-side padding config is turned on. + * + * When comparing char type column/field with string literal or char type column/field, + * right-pad the shorter one to the longer length. + */ +object ApplyCharTypePadding extends Rule[LogicalPlan] { + + object AttrOrOuterRef { + def unapply(e: Expression): Option[Attribute] = e match { + case a: Attribute => Some(a) + case OuterReference(a: Attribute) => Some(a) + case _ => None + } + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (conf.charVarcharAsString) { + return plan + } + + if (conf.readSideCharPadding) { + val newPlan = plan.resolveOperatorsUpWithNewOutput { + case r: LogicalRelation => + readSidePadding(r, () => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))) + case r: DataSourceV2Relation => + readSidePadding(r, () => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))) + case r: HiveTableRelation => + readSidePadding(r, () => { + val cleanedDataCols = r.dataCols.map(CharVarcharUtils.cleanAttrMetadata) + val cleanedPartCols = r.partitionCols.map(CharVarcharUtils.cleanAttrMetadata) + r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols) + }) + } + paddingForStringComparison(newPlan) + } else { + paddingForStringComparison(plan) + } + } + + private def readSidePadding( + relation: LogicalPlan, + cleanedRelation: () => LogicalPlan) + : (LogicalPlan, Seq[(Attribute, Attribute)]) = { + val projectList = relation.output.map { attr => + CharVarcharUtils.addPaddingForScan(attr) match { + case ne: NamedExpression => ne + case other => Alias(other, attr.name)(explicitMetadata = Some(attr.metadata)) + } + } + if (projectList == relation.output) { + relation -> Nil + } else { + val newPlan = Project(projectList, cleanedRelation()) + newPlan -> relation.output.zip(newPlan.output) + } + } + + private def paddingForStringComparison(plan: LogicalPlan): LogicalPlan = { + plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) { + case operator => operator.transformExpressionsUpWithPruning( + _.containsAnyPattern(BINARY_COMPARISON, IN)) { + case e if !e.childrenResolved => e + + // String literal is treated as char type when it's compared to a char type column. + // We should pad the shorter one to the longer length. 
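+ // For example (illustrative): for a column c of type CHAR(5), c = 'abc' becomes + // c = rpad('abc', 5), while c = 'abcdef' becomes rpad(c, 6) = 'abcdef', so both sides are + // compared at the same padded length.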
+ case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable => + padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + b.withNewChildren(newChildren) + }.getOrElse(b) + + case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable => + padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + b.withNewChildren(newChildren.reverse) + }.getOrElse(b) + + case i @ In(e @ AttrOrOuterRef(attr), list) + if attr.dataType == StringType && list.forall(_.foldable) => + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val (nulls, literalChars) = + list.map(_.eval().asInstanceOf[UTF8String]).partition(_ == null) + val literalCharLengths = literalChars.map(_.numChars()) + val targetLen = (length +: literalCharLengths).max + Some(i.copy( + value = addPadding(e, length, targetLen), + list = list.zip(literalCharLengths).map { + case (lit, charLength) => addPadding(lit, charLength, targetLen) + } ++ nulls.map(Literal.create(_, StringType)))) + case _ => None + }.getOrElse(i) + + // For char type column or inner field comparison, pad the shorter one to the longer length. + case b @ BinaryComparison(e1 @ AttrOrOuterRef(left), e2 @ AttrOrOuterRef(right)) + // For the same attribute, they must be the same length and no padding is needed. + if !left.semanticEquals(right) => + val outerRefs = (e1, e2) match { + case (_: OuterReference, _: OuterReference) => Seq(left, right) + case (_: OuterReference, _) => Seq(left) + case (_, _: OuterReference) => Seq(right) + case _ => Nil + } + val newChildren = CharVarcharUtils.addPaddingInStringComparison(Seq(left, right)) + if (outerRefs.nonEmpty) { + b.withNewChildren(newChildren.map(_.transform { + case a: Attribute if outerRefs.exists(_.semanticEquals(a)) => OuterReference(a) + })) + } else { + b.withNewChildren(newChildren) + } + + case i @ In(e @ AttrOrOuterRef(attr), list) if list.forall(_.isInstanceOf[Attribute]) => + val newChildren = CharVarcharUtils.addPaddingInStringComparison( + attr +: list.map(_.asInstanceOf[Attribute])) + if (e.isInstanceOf[OuterReference]) { + i.copy( + value = newChildren.head.transform { + case a: Attribute if a.semanticEquals(attr) => OuterReference(a) + }, + list = newChildren.tail) + } else { + i.copy(value = newChildren.head, list = newChildren.tail) + } + } + } + } + + private def padAttrLitCmp( + expr: Expression, + metadata: Metadata, + lit: Expression): Option[Seq[Expression]] = { + if (expr.dataType == StringType) { + CharVarcharUtils.getRawType(metadata).flatMap { + case CharType(length) => + val str = lit.eval().asInstanceOf[UTF8String] + if (str == null) { + None + } else { + val stringLitLen = str.numChars() + if (length < stringLitLen) { + Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) + } else if (length > stringLitLen) { + Some(Seq(expr, StringRPad(lit, Literal(length)))) + } else { + None + } + } + case _ => None + } + } else { + None + } + } + + private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { + if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index c9ff28eb0459f..df5e3ea13652d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -45,7 +45,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp } override def apply(plan: SparkPlan): SparkPlan = { - if (!conf.dynamicPartitionPruningEnabled) { + if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala new file mode 100644 index 0000000000000..232c320bcd454 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.dynamicpruning + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruningSubquery, Expression, PredicateHelper, V2ExpressionUtils} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation} + +/** + * A rule that assigns a subquery to filter groups in row-level operations at runtime. + * + * Data skipping during job planning for row-level operations is limited to expressions that can be + * converted to data source filters. Since not all expressions can be pushed down that way and + * rewriting groups is expensive, Spark allows data sources to filter groups at runtime. + * If the primary scan in a group-based row-level operation supports runtime filtering, this rule + * will inject a subquery to find all rows that match the condition so that data sources know + * exactly which groups must be rewritten. + * + * Note this rule only applies to group-based row-level operations. 
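+ * + * As an illustrative example, for a group-based DELETE whose primary scan supports runtime + * filtering, this rule wraps the scan relation in a filter with dynamic pruning subqueries built + * from the rows matching the command's condition, so the data source can skip rewriting groups + * (e.g. files or partitions) that contain no matches.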
+ */ +case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan]) + extends Rule[LogicalPlan] with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + // apply special dynamic filtering only for group-based row-level operations + case GroupBasedRowLevelOperation(replaceData, cond, + DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) + if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral => + + // use reference equality on scan to find required scan relations + val newQuery = replaceData.query transformUp { + case r: DataSourceV2ScanRelation if r.scan eq scan => + // use the original table instance that was loaded for this row-level operation + // in order to leverage a regular batch scan in the group filter query + val originalTable = r.relation.table.asRowLevelOperationTable.table + val relation = r.relation.copy(table = originalTable) + val matchingRowsPlan = buildMatchingRowsPlan(relation, cond) + + val filterAttrs = scan.filterAttributes + val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan) + val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r) + val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys) + + Filter(dynamicPruningCond, r) + } + + // optimize subqueries to rewrite them as joins and trigger job planning + replaceData.copy(query = optimizeSubqueries(newQuery)) + } + + private def buildMatchingRowsPlan( + relation: DataSourceV2Relation, + cond: Expression): LogicalPlan = { + + val matchingRowsPlan = Filter(cond, relation) + + // clone the relation and assign new expr IDs to avoid conflicts + matchingRowsPlan transformUpWithNewOutput { + case r: DataSourceV2Relation if r eq relation => + val oldOutput = r.output + val newOutput = oldOutput.map(_.newInstance()) + r.copy(output = newOutput) -> oldOutput.zip(newOutput) + } + } + + private def buildDynamicPruningCond( + matchingRowsPlan: LogicalPlan, + buildKeys: Seq[Attribute], + pruningKeys: Seq[Attribute]): Expression = { + + val buildQuery = Project(buildKeys, matchingRowsPlan) + val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) => + DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false) + } + dynamicPruningSubqueries.reduce(And) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2271990741d7c..b12a86c08d18b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -196,6 +196,7 @@ abstract class BaseSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis(this) +: + ApplyCharTypePadding +: ReplaceCharWithVarchar +: customPostHocResolutionRules diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index e3a303d4c0a67..335e52fee18df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdenti import 
org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, SubqueryAlias, TableSpec, View} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, LogicalPlan, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, TableSpec, View} import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, FunctionCatalog, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -801,29 +801,24 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // Temporary and global temporary views are not supposed to be put into the relation cache // since they are tracked separately. V1 and V2 plans are cache invalidated accordingly. - relation match { - case SubqueryAlias(_, v: View) if !v.isTempView => - sessionCatalog.invalidateCachedTable(v.desc.identifier) - case SubqueryAlias(_, r: LogicalRelation) => + def invalidateCache(plan: LogicalPlan): Unit = plan match { + case v: View => + if (!v.isTempView) sessionCatalog.invalidateCachedTable(v.desc.identifier) + case r: LogicalRelation => sessionCatalog.invalidateCachedTable(r.catalogTable.get.identifier) - case SubqueryAlias(_, h: HiveTableRelation) => + case h: HiveTableRelation => sessionCatalog.invalidateCachedTable(h.tableMeta.identifier) - case SubqueryAlias(_, r: DataSourceV2Relation) => + case r: DataSourceV2Relation => r.catalog.get.asTableCatalog.invalidateTable(r.identifier.get) - case SubqueryAlias(_, v: View) if v.isTempView => - case _ => - throw QueryCompilationErrors.unexpectedTypeOfRelationError(relation, tableName) + case _ => plan.children.foreach(invalidateCache) } + invalidateCache(relation) + // Re-caches the logical plan of the relation. // Note this is a no-op for the relation itself if it's not cached, but will clear all // caches referencing this relation. If this relation is cached as an InMemoryRelation, // this will clear the relation cache and caches of all its dependents. 
- relation match { - case SubqueryAlias(_, relationPlan) => - sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relationPlan) - case _ => - throw QueryCompilationErrors.unexpectedTypeOfRelationError(relation, tableName) - } + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, relation) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/double-quoted-identifiers-disabled.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/double-quoted-identifiers-disabled.sql new file mode 100644 index 0000000000000..b8ff8cdb81376 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/double-quoted-identifiers-disabled.sql @@ -0,0 +1,2 @@ +--SET spark.sql.ansi.doubleQuotedIdentifiers=false +--IMPORT double-quoted-identifiers.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/double-quoted-identifiers-enabled.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/double-quoted-identifiers-enabled.sql new file mode 100644 index 0000000000000..9547d011c76ea --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/double-quoted-identifiers-enabled.sql @@ -0,0 +1,3 @@ +--SET spark.sql.ansi.doubleQuotedIdentifiers=true +--IMPORT double-quoted-identifiers.sql + diff --git a/sql/core/src/test/resources/sql-tests/inputs/double-quoted-identifiers.sql b/sql/core/src/test/resources/sql-tests/inputs/double-quoted-identifiers.sql index 7fe35e5a410ba..ffb52b403346f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/double-quoted-identifiers.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/double-quoted-identifiers.sql @@ -1,8 +1,3 @@ --- test cases for spark.sql.ansi.double_quoted_identifiers - --- Base line -SET spark.sql.ansi.double_quoted_identifiers = false; - -- All these should error out in the parser SELECT 1 FROM "not_exist"; @@ -45,51 +40,6 @@ DROP VIEW v; SELECT INTERVAL "1" YEAR; --- Now turn on the config. 
-SET spark.sql.ansi.double_quoted_identifiers = true; - --- All these should error out in analysis now -SELECT 1 FROM "not_exist"; - -USE SCHEMA "not_exist"; - -ALTER TABLE "not_exist" ADD COLUMN not_exist int; - -ALTER TABLE not_exist ADD COLUMN "not_exist" int; - -SELECT 1 AS "not_exist" FROM not_exist; - -SELECT 1 FROM not_exist AS X("hello"); - -SELECT "not_exist"(); - -SELECT "not_exist".not_exist(); - -SELECT "hello"; - --- Back ticks still work -SELECT 1 FROM `hello`; - -USE SCHEMA `not_exist`; - -ALTER TABLE `not_exist` ADD COLUMN not_exist int; - -ALTER TABLE not_exist ADD COLUMN `not_exist` int; - -SELECT 1 AS `not_exist` FROM `not_exist`; - -SELECT 1 FROM not_exist AS X(`hello`); - -SELECT `not_exist`(); - -SELECT `not_exist`.not_exist(); - --- These fail in the parser now -CREATE TEMPORARY VIEW v(c1 COMMENT "hello") AS SELECT 1; -DROP VIEW v; - -SELECT INTERVAL "1" YEAR; - -- Single ticks still work SELECT 'hello'; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/double-quoted-identifiers-disabled.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/double-quoted-identifiers-disabled.sql.out new file mode 100644 index 0000000000000..57fad89d57c94 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/double-quoted-identifiers-disabled.sql.out @@ -0,0 +1,369 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT 1 FROM "not_exist" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +USE SCHEMA "not_exist" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +ALTER TABLE "not_exist" ADD COLUMN not_exist int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +ALTER TABLE not_exist ADD COLUMN "not_exist" int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +SELECT 1 AS "not_exist" FROM not_exist +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +SELECT 1 FROM not_exist AS X("hello") +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"hello\"'", + "hint" : "" + } +} + + +-- !query +SELECT "not_exist"() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +SELECT "not_exist".not_exist() +-- !query schema +struct<> +-- !query output 
+org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"not_exist\"'", + "hint" : "" + } +} + + +-- !query +SELECT 1 FROM `hello` +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: hello; line 1 pos 14 + + +-- !query +USE SCHEMA `not_exist` +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException +Database 'not_exist' not found + + +-- !query +ALTER TABLE `not_exist` ADD COLUMN not_exist int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table not found: not_exist; line 1 pos 12 + + +-- !query +ALTER TABLE not_exist ADD COLUMN `not_exist` int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table not found: not_exist; line 1 pos 12 + + +-- !query +SELECT 1 AS `not_exist` FROM `not_exist` +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 29 + + +-- !query +SELECT 1 FROM not_exist AS X(`hello`) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 14 + + +-- !query +SELECT `not_exist`() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1242", + "messageParameters" : { + "fullName" : "spark_catalog.default.not_exist", + "rawName" : "not_exist" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "`not_exist`()" + } ] +} + + +-- !query +SELECT `not_exist`.not_exist() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1243", + "messageParameters" : { + "rawName" : "not_exist.not_exist" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "`not_exist`.not_exist()" + } ] +} + + +-- !query +SELECT "hello" +-- !query schema +struct +-- !query output +hello + + +-- !query +CREATE TEMPORARY VIEW v(c1 COMMENT "hello") AS SELECT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW v +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT INTERVAL "1" YEAR +-- !query schema +struct +-- !query output +1-0 + + +-- !query +SELECT 'hello' +-- !query schema +struct +-- !query output +hello + + +-- !query +CREATE TEMPORARY VIEW v(c1 COMMENT 'hello') AS SELECT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW v +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT INTERVAL '1' YEAR +-- !query schema +struct +-- !query output +1-0 + + +-- !query +CREATE SCHEMA "myschema" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"myschema\"'", + "hint" : "" + } +} + + +-- !query +CREATE TEMPORARY VIEW "myview"("c1") AS + WITH "v"("a") AS (SELECT 1) SELECT "a" FROM "v" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"myview\"'", + "hint" : "" + } +} + + +-- !query +SELECT "a1" AS "a2" FROM 
"myview" AS "atab"("a1") +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"a2\"'", + "hint" : "" + } +} + + +-- !query +DROP TABLE "myview" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"myview\"'", + "hint" : "" + } +} + + +-- !query +DROP SCHEMA "myschema" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"myschema\"'", + "hint" : "" + } +} diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/double-quoted-identifiers-enabled.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/double-quoted-identifiers-enabled.sql.out new file mode 100644 index 0000000000000..fb34e9a16197a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/double-quoted-identifiers-enabled.sql.out @@ -0,0 +1,334 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT 1 FROM "not_exist" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 14 + + +-- !query +USE SCHEMA "not_exist" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException +Database 'not_exist' not found + + +-- !query +ALTER TABLE "not_exist" ADD COLUMN not_exist int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table not found: not_exist; line 1 pos 12 + + +-- !query +ALTER TABLE not_exist ADD COLUMN "not_exist" int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table not found: not_exist; line 1 pos 12 + + +-- !query +SELECT 1 AS "not_exist" FROM not_exist +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 29 + + +-- !query +SELECT 1 FROM not_exist AS X("hello") +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 14 + + +-- !query +SELECT "not_exist"() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1242", + "messageParameters" : { + "fullName" : "spark_catalog.default.not_exist", + "rawName" : "not_exist" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "\"not_exist\"()" + } ] +} + + +-- !query +SELECT "not_exist".not_exist() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1243", + "messageParameters" : { + "rawName" : "not_exist.not_exist" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "\"not_exist\".not_exist()" + } ] +} + + +-- !query +SELECT 1 FROM `hello` +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: hello; line 1 pos 14 + + +-- !query +USE SCHEMA `not_exist` +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException +Database 'not_exist' 
not found + + +-- !query +ALTER TABLE `not_exist` ADD COLUMN not_exist int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table not found: not_exist; line 1 pos 12 + + +-- !query +ALTER TABLE not_exist ADD COLUMN `not_exist` int +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table not found: not_exist; line 1 pos 12 + + +-- !query +SELECT 1 AS `not_exist` FROM `not_exist` +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 29 + + +-- !query +SELECT 1 FROM not_exist AS X(`hello`) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +Table or view not found: not_exist; line 1 pos 14 + + +-- !query +SELECT `not_exist`() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1242", + "messageParameters" : { + "fullName" : "spark_catalog.default.not_exist", + "rawName" : "not_exist" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 20, + "fragment" : "`not_exist`()" + } ] +} + + +-- !query +SELECT `not_exist`.not_exist() +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1243", + "messageParameters" : { + "rawName" : "not_exist.not_exist" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "`not_exist`.not_exist()" + } ] +} + + +-- !query +SELECT "hello" +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "sqlState" : "42000", + "messageParameters" : { + "objectName" : "`hello`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 14, + "fragment" : "\"hello\"" + } ] +} + + +-- !query +CREATE TEMPORARY VIEW v(c1 COMMENT "hello") AS SELECT 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"hello\"'", + "hint" : "" + } +} + + +-- !query +DROP VIEW v +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1115", + "messageParameters" : { + "msg" : "Table spark_catalog.default.v not found" + } +} + + +-- !query +SELECT INTERVAL "1" YEAR +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", + "messageParameters" : { + "error" : "'\"1\"'", + "hint" : "" + } +} + + +-- !query +SELECT 'hello' +-- !query schema +struct +-- !query output +hello + + +-- !query +CREATE TEMPORARY VIEW v(c1 COMMENT 'hello') AS SELECT 1 +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP VIEW v +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT INTERVAL '1' YEAR +-- !query schema +struct +-- !query output +1-0 + + +-- !query +CREATE SCHEMA "myschema" +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TEMPORARY VIEW "myview"("c1") AS + WITH "v"("a") AS (SELECT 1) SELECT "a" FROM "v" +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT "a1" AS "a2" FROM "myview" AS "atab"("a1") +-- !query 
schema +struct +-- !query output +1 + + +-- !query +DROP TABLE "myview" +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP SCHEMA "myschema" +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/double-quoted-identifiers.sql.out b/sql/core/src/test/resources/sql-tests/results/double-quoted-identifiers.sql.out index a67a5cffd31ca..57fad89d57c94 100644 --- a/sql/core/src/test/resources/sql-tests/results/double-quoted-identifiers.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/double-quoted-identifiers.sql.out @@ -1,12 +1,4 @@ -- Automatically generated by SQLQueryTestSuite --- !query -SET spark.sql.ansi.double_quoted_identifiers = false --- !query schema -struct --- !query output -spark.sql.ansi.double_quoted_identifiers false - - -- !query SELECT 1 FROM "not_exist" -- !query schema @@ -265,231 +257,72 @@ struct -- !query -SET spark.sql.ansi.double_quoted_identifiers = true --- !query schema -struct --- !query output -spark.sql.ansi.double_quoted_identifiers true - - --- !query -SELECT 1 FROM "not_exist" --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table or view not found: not_exist; line 1 pos 14 - - --- !query -USE SCHEMA "not_exist" --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException -Database 'not_exist' not found - - --- !query -ALTER TABLE "not_exist" ADD COLUMN not_exist int +SELECT 'hello' -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -Table not found: not_exist; line 1 pos 12 +hello -- !query -ALTER TABLE not_exist ADD COLUMN "not_exist" int +CREATE TEMPORARY VIEW v(c1 COMMENT 'hello') AS SELECT 1 -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -Table not found: not_exist; line 1 pos 12 - --- !query -SELECT 1 AS "not_exist" FROM not_exist --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table or view not found: not_exist; line 1 pos 29 -- !query -SELECT 1 FROM not_exist AS X("hello") +DROP VIEW v -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -Table or view not found: not_exist; line 1 pos 14 --- !query -SELECT "not_exist"() --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "_LEGACY_ERROR_TEMP_1242", - "messageParameters" : { - "fullName" : "spark_catalog.default.not_exist", - "rawName" : "not_exist" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 20, - "fragment" : "\"not_exist\"()" - } ] -} - -- !query -SELECT "not_exist".not_exist() +SELECT INTERVAL '1' YEAR -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "_LEGACY_ERROR_TEMP_1243", - "messageParameters" : { - "rawName" : "not_exist.not_exist" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 30, - "fragment" : "\"not_exist\".not_exist()" - } ] -} +1-0 -- !query -SELECT "hello" +CREATE SCHEMA "myschema" -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", + "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42000", "messageParameters" : { - "objectName" : "`hello`" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : 
"", - "startIndex" : 8, - "stopIndex" : 14, - "fragment" : "\"hello\"" - } ] -} - - --- !query -SELECT 1 FROM `hello` --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table or view not found: hello; line 1 pos 14 - - --- !query -USE SCHEMA `not_exist` --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException -Database 'not_exist' not found - - --- !query -ALTER TABLE `not_exist` ADD COLUMN not_exist int --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table not found: not_exist; line 1 pos 12 - - --- !query -ALTER TABLE not_exist ADD COLUMN `not_exist` int --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table not found: not_exist; line 1 pos 12 - - --- !query -SELECT 1 AS `not_exist` FROM `not_exist` --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table or view not found: not_exist; line 1 pos 29 - - --- !query -SELECT 1 FROM not_exist AS X(`hello`) --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -Table or view not found: not_exist; line 1 pos 14 - - --- !query -SELECT `not_exist`() --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "_LEGACY_ERROR_TEMP_1242", - "messageParameters" : { - "fullName" : "spark_catalog.default.not_exist", - "rawName" : "not_exist" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 20, - "fragment" : "`not_exist`()" - } ] + "error" : "'\"myschema\"'", + "hint" : "" + } } -- !query -SELECT `not_exist`.not_exist() +CREATE TEMPORARY VIEW "myview"("c1") AS + WITH "v"("a") AS (SELECT 1) SELECT "a" FROM "v" -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException +org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_1243", + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", "messageParameters" : { - "rawName" : "not_exist.not_exist" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 30, - "fragment" : "`not_exist`.not_exist()" - } ] + "error" : "'\"myview\"'", + "hint" : "" + } } -- !query -CREATE TEMPORARY VIEW v(c1 COMMENT "hello") AS SELECT 1 +SELECT "a1" AS "a2" FROM "myview" AS "atab"("a1") -- !query schema struct<> -- !query output @@ -498,28 +331,30 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42000", "messageParameters" : { - "error" : "'\"hello\"'", + "error" : "'\"a2\"'", "hint" : "" } } -- !query -DROP VIEW v +DROP TABLE "myview" -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.analysis.NoSuchTableException +org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_1115", + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42000", "messageParameters" : { - "msg" : "Table spark_catalog.default.v not found" + "error" : "'\"myview\"'", + "hint" : "" } } -- !query -SELECT INTERVAL "1" YEAR +DROP SCHEMA "myschema" -- !query schema struct<> -- !query output @@ -528,80 +363,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "PARSE_SYNTAX_ERROR", "sqlState" : "42000", "messageParameters" : { - "error" : "'\"1\"'", + "error" : "'\"myschema\"'", "hint" : "" } } - - --- !query -SELECT 'hello' --- !query schema -struct --- !query output -hello - - --- 
!query -CREATE TEMPORARY VIEW v(c1 COMMENT 'hello') AS SELECT 1 --- !query schema -struct<> --- !query output - - - --- !query -DROP VIEW v --- !query schema -struct<> --- !query output - - - --- !query -SELECT INTERVAL '1' YEAR --- !query schema -struct --- !query output -1-0 - - --- !query -CREATE SCHEMA "myschema" --- !query schema -struct<> --- !query output - - - --- !query -CREATE TEMPORARY VIEW "myview"("c1") AS - WITH "v"("a") AS (SELECT 1) SELECT "a" FROM "v" --- !query schema -struct<> --- !query output - - - --- !query -SELECT "a1" AS "a2" FROM "myview" AS "atab"("a1") --- !query schema -struct --- !query output -1 - - --- !query -DROP TABLE "myview" --- !query schema -struct<> --- !query output - - - --- !query -DROP SCHEMA "myschema" --- !query schema -struct<> --- !query output - diff --git a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out index c510ad1d8314d..a6902d06cc272 100644 --- a/sql/core/src/test/resources/sql-tests/results/except-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/except-all.sql.out @@ -138,7 +138,7 @@ SELECT array(1) struct<> -- !query output org.apache.spark.sql.AnalysisException -ExceptAll can only be performed on tables with the compatible column types. The first column of the second table is array type which is not compatible with int at same column of first table +ExceptAll can only be performed on tables with compatible column types. The first column of the second table is array type which is not compatible with int at the same column of the first table -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out index 062c3761d2513..b439f79562a6b 100644 --- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out @@ -95,7 +95,7 @@ SELECT array(1), 2 struct<> -- !query output org.apache.spark.sql.AnalysisException -IntersectAll can only be performed on tables with the compatible column types. The first column of the second table is array type which is not compatible with int at same column of first table +IntersectAll can only be performed on tables with compatible column types. The first column of the second table is array type which is not compatible with int at the same column of the first table -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index a373c1f513aa8..5e04562a648d3 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -890,8 +890,13 @@ select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') -- !query schema struct<> -- !query output -java.lang.RuntimeException -Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. 
You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2130", + "messageParameters" : { + "pattern" : "yyyy-MM-dd GGGGG" + } +} -- !query @@ -899,8 +904,13 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') -- !query schema struct<> -- !query output -java.lang.RuntimeException -Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2130", + "messageParameters" : { + "pattern" : "dd MM yyyy EEEEEE" + } +} -- !query @@ -908,8 +918,13 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') -- !query schema struct<> -- !query output -java.lang.RuntimeException -Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2130", + "messageParameters" : { + "pattern" : "dd MM yyyy EEEEE" + } +} -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out index c831b75c68116..8d98209e6254d 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out @@ -867,8 +867,13 @@ select to_timestamp('2019-10-06 A', 'yyyy-MM-dd GGGGG') -- !query schema struct<> -- !query output -java.lang.RuntimeException -Fail to recognize 'yyyy-MM-dd GGGGG' pattern in the DateTimeFormatter. You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2130", + "messageParameters" : { + "pattern" : "yyyy-MM-dd GGGGG" + } +} -- !query @@ -876,8 +881,13 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEEE') -- !query schema struct<> -- !query output -java.lang.RuntimeException -Fail to recognize 'dd MM yyyy EEEEEE' pattern in the DateTimeFormatter. You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2130", + "messageParameters" : { + "pattern" : "dd MM yyyy EEEEEE" + } +} -- !query @@ -885,8 +895,13 @@ select to_timestamp('22 05 2020 Friday', 'dd MM yyyy EEEEE') -- !query schema struct<> -- !query output -java.lang.RuntimeException -Fail to recognize 'dd MM yyyy EEEEE' pattern in the DateTimeFormatter. 
You can form a valid datetime pattern with the guide from https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_2130", + "messageParameters" : { + "pattern" : "dd MM yyyy EEEEE" + } +} -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out index f830797212a76..34c46c1a2c088 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/widenSetOperationTypes.sql.out @@ -85,7 +85,7 @@ SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with tinyint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with tinyint at the same column of the first table -- !query @@ -94,7 +94,7 @@ SELECT cast(1 as tinyint) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with tinyint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with tinyint at the same column of the first table -- !query @@ -103,7 +103,7 @@ SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as ti struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with tinyint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with tinyint at the same column of the first table -- !query @@ -112,7 +112,7 @@ SELECT cast(1 as tinyint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with tinyint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with tinyint at the same column of the first table -- !query @@ -193,7 +193,7 @@ SELECT cast(1 as smallint) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with smallint at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is binary type which is not compatible with smallint at the same column of the first table -- !query @@ -202,7 +202,7 @@ SELECT cast(1 as smallint) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with smallint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with smallint at the same column of the first table -- !query @@ -211,7 +211,7 @@ SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with smallint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with smallint at the same column of the first table -- !query @@ -220,7 +220,7 @@ SELECT cast(1 as smallint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as dat struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with smallint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with smallint at the same column of the first table -- !query @@ -301,7 +301,7 @@ SELECT cast(1 as int) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with int at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with int at the same column of the first table -- !query @@ -310,7 +310,7 @@ SELECT cast(1 as int) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with int at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with int at the same column of the first table -- !query @@ -319,7 +319,7 @@ SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as timest struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with int at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is timestamp type which is not compatible with int at the same column of the first table -- !query @@ -328,7 +328,7 @@ SELECT cast(1 as int) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) FR struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with int at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with int at the same column of the first table -- !query @@ -409,7 +409,7 @@ SELECT cast(1 as bigint) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with bigint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with bigint at the same column of the first table -- !query @@ -418,7 +418,7 @@ SELECT cast(1 as bigint) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with bigint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with bigint at the same column of the first table -- !query @@ -427,7 +427,7 @@ SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as tim struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with bigint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with bigint at the same column of the first table -- !query @@ -436,7 +436,7 @@ SELECT cast(1 as bigint) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with bigint at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with bigint at the same column of the first table -- !query @@ -517,7 +517,7 @@ SELECT cast(1 as float) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with float at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is binary type which is not compatible with float at the same column of the first table -- !query @@ -526,7 +526,7 @@ SELECT cast(1 as float) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with float at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with float at the same column of the first table -- !query @@ -535,7 +535,7 @@ SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as time struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with float at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with float at the same column of the first table -- !query @@ -544,7 +544,7 @@ SELECT cast(1 as float) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with float at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with float at the same column of the first table -- !query @@ -625,7 +625,7 @@ SELECT cast(1 as double) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with double at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with double at the same column of the first table -- !query @@ -634,7 +634,7 @@ SELECT cast(1 as double) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with double at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with double at the same column of the first table -- !query @@ -643,7 +643,7 @@ SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as tim struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with double at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is timestamp type which is not compatible with double at the same column of the first table -- !query @@ -652,7 +652,7 @@ SELECT cast(1 as double) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date) struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with double at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with double at the same column of the first table -- !query @@ -733,7 +733,7 @@ SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with decimal(10,0) at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with decimal(10,0) at the same column of the first table -- !query @@ -742,7 +742,7 @@ SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with decimal(10,0) at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with decimal(10,0) at the same column of the first table -- !query @@ -751,7 +751,7 @@ SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00.0 struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with decimal(10,0) at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with decimal(10,0) at the same column of the first table -- !query @@ -760,7 +760,7 @@ SELECT cast(1 as decimal(10, 0)) FROM t UNION SELECT cast('2017-12-11 09:30:00' struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with decimal(10,0) at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with decimal(10,0) at the same column of the first table -- !query @@ -841,7 +841,7 @@ SELECT cast(1 as string) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with string at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is binary type which is not compatible with string at the same column of the first table -- !query @@ -850,7 +850,7 @@ SELECT cast(1 as string) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with string at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with string at the same column of the first table -- !query @@ -877,7 +877,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as tinyint) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is tinyint type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is tinyint type which is not compatible with binary at the same column of the first table -- !query @@ -886,7 +886,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as smallint) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is smallint type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is smallint type which is not compatible with binary at the same column of the first table -- !query @@ -895,7 +895,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as int) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is int type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is int type which is not compatible with binary at the same column of the first table -- !query @@ -904,7 +904,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as bigint) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is bigint type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is bigint type which is not compatible with binary at the same column of the first table -- !query @@ -913,7 +913,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as float) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is float type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is float type which is not compatible with binary at the same column of the first table -- !query @@ -922,7 +922,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as double) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is double type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is double type which is not compatible with binary at the same column of the first table -- !query @@ -931,7 +931,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with binary at the same column of the first table -- !query @@ -940,7 +940,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as string) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is string type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is string type which is not compatible with binary at the same column of the first table -- !query @@ -958,7 +958,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast(2 as boolean) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with binary at the same column of the first table -- !query @@ -967,7 +967,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with binary at the same column of the first table -- !query @@ -976,7 +976,7 @@ SELECT cast('1' as binary) FROM t UNION SELECT cast('2017-12-11 09:30:00' as dat struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with binary at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is date type which is not compatible with binary at the same column of the first table -- !query @@ -985,7 +985,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as tinyint) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is tinyint type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is tinyint type which is not compatible with boolean at the same column of the first table -- !query @@ -994,7 +994,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as smallint) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is smallint type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is smallint type which is not compatible with boolean at the same column of the first table -- !query @@ -1003,7 +1003,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as int) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is int type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is int type which is not compatible with boolean at the same column of the first table -- !query @@ -1012,7 +1012,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as bigint) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is bigint type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is bigint type which is not compatible with boolean at the same column of the first table -- !query @@ -1021,7 +1021,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as float) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is float type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is float type which is not compatible with boolean at the same column of the first table -- !query @@ -1030,7 +1030,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as double) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is double type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is double type which is not compatible with boolean at the same column of the first table -- !query @@ -1039,7 +1039,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as decimal(10, 0)) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with boolean at the same column of the first table -- !query @@ -1048,7 +1048,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast(2 as string) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is string type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is string type which is not compatible with boolean at the same column of the first table -- !query @@ -1057,7 +1057,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast('2' as binary) FROM t struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with boolean at the same column of the first table -- !query @@ -1074,7 +1074,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00.0' as ti struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is timestamp type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is timestamp type which is not compatible with boolean at the same column of the first table -- !query @@ -1083,7 +1083,7 @@ SELECT cast(1 as boolean) FROM t UNION SELECT cast('2017-12-11 09:30:00' as date struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is date type which is not compatible with boolean at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is date type which is not compatible with boolean at the same column of the first table -- !query @@ -1092,7 +1092,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is tinyint type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is tinyint type which is not compatible with timestamp at the same column of the first table -- !query @@ -1101,7 +1101,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is smallint type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is smallint type which is not compatible with timestamp at the same column of the first table -- !query @@ -1110,7 +1110,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is int type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is int type which is not compatible with timestamp at the same column of the first table -- !query @@ -1119,7 +1119,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is bigint type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is bigint type which is not compatible with timestamp at the same column of the first table -- !query @@ -1128,7 +1128,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is float type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is float type which is not compatible with timestamp at the same column of the first table -- !query @@ -1137,7 +1137,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is double type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is double type which is not compatible with timestamp at the same column of the first table -- !query @@ -1146,7 +1146,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is decimal(10,0) type which is not compatible with timestamp at the same column of the first table -- !query @@ -1164,7 +1164,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast('2' a struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with timestamp at the same column of the first table -- !query @@ -1173,7 +1173,7 @@ SELECT cast('2017-12-12 09:30:00.0' as timestamp) FROM t UNION SELECT cast(2 as struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with timestamp at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with timestamp at the same column of the first table -- !query @@ -1200,7 +1200,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as tinyint struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is tinyint type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is tinyint type which is not compatible with date at the same column of the first table -- !query @@ -1209,7 +1209,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as smallin struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is smallint type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is smallint type which is not compatible with date at the same column of the first table -- !query @@ -1218,7 +1218,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as int) FR struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is int type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is int type which is not compatible with date at the same column of the first table -- !query @@ -1227,7 +1227,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as bigint) struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is bigint type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. 
The first column of the second table is bigint type which is not compatible with date at the same column of the first table -- !query @@ -1236,7 +1236,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as float) struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is float type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is float type which is not compatible with date at the same column of the first table -- !query @@ -1245,7 +1245,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as double) struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is double type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is double type which is not compatible with date at the same column of the first table -- !query @@ -1254,7 +1254,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as decimal struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is decimal(10,0) type which is not compatible with date at the same column of the first table -- !query @@ -1272,7 +1272,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast('2' as binar struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is binary type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is binary type which is not compatible with date at the same column of the first table -- !query @@ -1281,7 +1281,7 @@ SELECT cast('2017-12-12 09:30:00' as date) FROM t UNION SELECT cast(2 as boolean struct<> -- !query output org.apache.spark.sql.AnalysisException -Union can only be performed on tables with the compatible column types. The first column of the second table is boolean type which is not compatible with date at same column of first table +Union can only be performed on tables with compatible column types. The first column of the second table is boolean type which is not compatible with date at the same column of the first table -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out index 997308bdbf67a..cb125a648b73b 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except-all.sql.out @@ -138,7 +138,7 @@ SELECT array(1) struct<> -- !query output org.apache.spark.sql.AnalysisException -ExceptAll can only be performed on tables with the compatible column types. 
The first column of the second table is array<int> type which is not compatible with int at same column of first table +ExceptAll can only be performed on tables with compatible column types. The first column of the second table is array<int> type which is not compatible with int at the same column of the first table -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out index 29febc747ea13..68a9e11bd23eb 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-intersect-all.sql.out @@ -95,7 +95,7 @@ SELECT array(1), udf(2) struct<> -- !query output org.apache.spark.sql.AnalysisException -IntersectAll can only be performed on tables with the compatible column types. The first column of the second table is array<int> type which is not compatible with int at same column of first table +IntersectAll can only be performed on tables with compatible column types. The first column of the second table is array<int> type which is not compatible with int at the same column of the first table -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 35a62d687de90..27a630c169be0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -806,13 +806,11 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } test("create table w/ location and fit length values") { - Seq("char", "varchar").foreach { typ => - withTempPath { dir => - withTable("t") { - sql("SELECT '12' as col").write.format(format).save(dir.toString) - sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'") - checkAnswer(sql("select * from t"), Row("12")) - } + withTempPath { dir => + withTable("t") { + sql("SELECT '12' as col1, '12' as col2").write.format(format).save(dir.toString) + sql(s"CREATE TABLE t (col1 char(3), col2 varchar(3)) using $format LOCATION '$dir'") + checkAnswer(sql("select * from t"), Row("12 ", "12")) } } } @@ -830,14 +828,12 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } test("alter table set location w/ fit length values") { - Seq("char", "varchar").foreach { typ => - withTempPath { dir => - withTable("t") { - sql("SELECT '12' as col").write.format(format).save(dir.toString) - sql(s"CREATE TABLE t (col $typ(2)) using $format") - sql(s"ALTER TABLE t SET LOCATION '$dir'") - checkAnswer(spark.table("t"), Row("12")) - } + withTempPath { dir => + withTable("t") { + sql("SELECT '12' as col1, '12' as col2").write.format(format).save(dir.toString) + sql(s"CREATE TABLE t (col1 char(3), col2 varchar(3)) using $format") + sql(s"ALTER TABLE t SET LOCATION '$dir'") + checkAnswer(spark.table("t"), Row("12 ", "12")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 41af747a83e41..74ec36988d799 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -4229,15 +4229,57 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(errMsg.contains(s"input to function $name requires at 
least one argument")) } - val funcsMustHaveAtLeastTwoArgs = - ("greatest", (df: DataFrame) => df.select(greatest())) :: - ("greatest", (df: DataFrame) => df.selectExpr("greatest()")) :: - ("least", (df: DataFrame) => df.select(least())) :: - ("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil - funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) => - val errMsg = intercept[AnalysisException] { func(df) }.getMessage - assert(errMsg.contains(s"input to function $name requires at least two arguments")) - } + checkError( + exception = intercept[AnalysisException] { + df.select(greatest()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"greatest()\"", + "actualNum" -> "0") + ) + + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("greatest()") + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"greatest()\"", + "actualNum" -> "0"), + context = ExpectedContext( + fragment = "greatest()", + start = 0, + stop = 9) + ) + + checkError( + exception = intercept[AnalysisException] { + df.select(least()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"least()\"", + "actualNum" -> "0") + ) + + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("least()") + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"least()\"", + "actualNum" -> "0"), + context = ExpectedContext( + fragment = "least()", + start = 0, + stop = 6) + ) } test("SPARK-24734: Fix containsNull of Concat for array type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index ca04adf642e15..0acb3842b039e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -1001,10 +1001,10 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { df1.unionByName(df2) }.getMessage assert(errMsg.contains("Union can only be performed on tables with" + - " the compatible column types." + + " compatible column types." 
+ " The third column of the second table is struct>" + - " type which is not compatible with struct> at same" + - " column of first table")) + " type which is not compatible with struct> at the same" + + " column of the first table")) // diff Case sensitive attributes names and diff sequence scenario for unionByName df1 = Seq((1, 2, UnionClass1d(1, 2, Struct3(1)))).toDF("a", "b", "c") @@ -1084,7 +1084,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { val err = intercept[AnalysisException](df7.union(df8).collect()) assert(err.message - .contains("Union can only be performed on tables with the compatible column types")) + .contains("Union can only be performed on tables with compatible column types")) } test("SPARK-36546: Add unionByName support to arrays of structs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index ace1756a14d82..643dcc20c65d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -256,16 +256,19 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { protected def testQuery(tpcdsGroup: String, query: String, suffix: String = ""): Unit = { val queryString = resourceToString(s"$tpcdsGroup/$query.sql", classLoader = Thread.currentThread().getContextClassLoader) - val qe = sql(queryString).queryExecution - val plan = qe.executedPlan - val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode))) - - assert(ValidateRequirements.validate(plan)) - - if (regenerateGoldenFiles) { - generateGoldenFile(plan, query + suffix, explain) - } else { - checkWithApproved(plan, query + suffix, explain) + // Disable char/varchar read-side handling for better performance. 
+ withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") { + val qe = sql(queryString).queryExecution + val plan = qe.executedPlan + val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode))) + + assert(ValidateRequirements.validate(plan)) + + if (regenerateGoldenFiles) { + generateGoldenFile(plan, query + suffix, explain) + } else { + checkWithApproved(plan, query + suffix, explain) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ab6914f4fe836..6a6f1ba989f92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2674,7 +2674,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val m2 = intercept[AnalysisException] { sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") }.message - assert(m2.contains("Except can only be performed on tables with the compatible column types")) + assert(m2.contains("Except can only be performed on tables with compatible column types")) withTable("t", "S") { sql("CREATE TABLE t(c struct) USING parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 80b9fc767d56b..c693021f387cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -758,7 +758,7 @@ class SubquerySuite extends QueryTest Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t4") // Simplest case - intercept[AnalysisException] { + val exception1 = intercept[AnalysisException] { sql( """ | select t1.c1 @@ -767,9 +767,22 @@ class SubquerySuite extends QueryTest | from t2 | where t1.c2 >= t2.c2)""".stripMargin).collect() } + checkErrorMatchPVals( + exception1, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = + """select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2""".stripMargin, + start = 44, + stop = 128)) // Add a HAVING on top and augmented within an OR predicate - intercept[AnalysisException] { + val exception2 = intercept[AnalysisException] { sql( """ | select t1.c1 @@ -780,9 +793,23 @@ class SubquerySuite extends QueryTest | having count(*) > 0 ) | or t1.c2 >= 0""".stripMargin).collect() } + checkErrorMatchPVals( + exception2, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = + """select max(t2.c1) + | from t2 + | where t1.c2 >= t2.c2 + | having count(*) > 0""".stripMargin, + start = 44, + stop = 166)) // Add a HAVING on top and augmented within an OR predicate - intercept[AnalysisException] { + val exception3 = intercept[AnalysisException] { sql( """ | select t1.c1 @@ -794,10 +821,24 @@ class SubquerySuite extends QueryTest | or t3.c2 = t2.c2) | )""".stripMargin).collect() } + checkErrorMatchPVals( + exception3, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ + "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = + """select max(t2.c1) + | from t2 + | where t1.c2 = t2.c2 + | or t3.c2 = t2.c2""".stripMargin, + start = 77, + stop = 205)) // In Window expression: changing the data set to // demonstrate if this query ran, it would return incorrect result. - intercept[AnalysisException] { + val exception4 = intercept[AnalysisException] { sql( """ | select c1 @@ -806,6 +847,19 @@ class SubquerySuite extends QueryTest | from t4 | where t3.c2 >= t4.c2)""".stripMargin).collect() } + checkErrorMatchPVals( + exception4, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = + """select max(t4.c1) over () + | from t4 + | where t3.c2 >= t4.c2""".stripMargin, + start = 38, + stop = 123)) } } // This restriction applies to @@ -820,7 +874,7 @@ class SubquerySuite extends QueryTest Seq(1).toDF("c1").createOrReplaceTempView("t3") // Left outer join (LOJ) in IN subquery context - intercept[AnalysisException] { + val exception1 = intercept[AnalysisException] { sql( """ | select t1.c1 @@ -830,8 +884,19 @@ class SubquerySuite extends QueryTest | (select c1 from t2 where t1.c1 = 2) t2 | on t2.c1 = t3.c1)""".stripMargin).collect() } + checkErrorMatchPVals( + exception1, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = "(select c1 from t2 where t1.c1 = 2) t2", + start = 110, + stop = 147)) + // Right outer join (ROJ) in EXISTS subquery context - intercept[AnalysisException] { + val exception2 = intercept[AnalysisException] { sql( """ | select t1.c1 @@ -841,8 +906,19 @@ class SubquerySuite extends QueryTest | right outer join t3 | on t2.c1 = t3.c1)""".stripMargin).collect() } + checkErrorMatchPVals( + exception2, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = "(select c1 from t2 where t1.c1 = 2) t2", + start = 74, + stop = 111)) + // SPARK-18578: Full outer join (FOJ) in scalar subquery context - intercept[AnalysisException] { + val exception3 = intercept[AnalysisException] { sql( """ | select (select max(1) @@ -851,6 +927,18 @@ class SubquerySuite extends QueryTest | on t2.c1=t3.c1) | from t1""".stripMargin).collect() } + checkErrorMatchPVals( + exception3, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = + """full join t3 + | on t2.c1=t3.c1""".stripMargin, + start = 112, + stop = 154)) } } @@ -880,6 +968,15 @@ class SubquerySuite extends QueryTest | WHERE t1.c1 = t2.c1) """.stripMargin) } + checkErrorMatchPVals( + exception1, + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_CORRELATED_REFERENCE", + parameters = Map("treeNode" -> "(?s).*"), + sqlState = None, + context = ExpectedContext( + fragment = "LATERAL VIEW explode(t2.arr_c2) q AS c2", + start = 68, + stop = 106)) assert(exception1.getMessage.contains( "Expressions referencing the outer query are not supported outside of WHERE/HAVING")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index 8c4d25a7eb988..ffd15eb46a48e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -29,6 +29,10 @@ import org.apache.spark.tags.ExtendedSQLTest @ExtendedSQLTest class TPCDSQuerySuite extends BenchmarkQueryTest with TPCDSBase { + override protected def sparkConf: SparkConf = + // Disable read-side char padding so that the generated code is less than 8000. + super.sparkConf.set(SQLConf.READ_SIDE_CHAR_PADDING, false) + // q72 is skipped due to GitHub Actions' memory limit. tpcdsQueries.filterNot(sys.env.contains("GITHUB_ACTIONS") && _ == "q72").foreach { name => val queryString = resourceToString(s"tpcds/$name.sql", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index a2cfdde2671f6..d9a12b47ec269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -22,7 +22,7 @@ import java.util.Collections import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -46,15 +46,19 @@ abstract class DeleteFromTableSuiteBase spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") } - private val namespace = Array("ns1") - private val ident = Identifier.of(namespace, "test_table") - private val tableNameAsString = "cat." + ident.toString + protected val namespace: Array[String] = Array("ns1") + protected val ident: Identifier = Identifier.of(namespace, "test_table") + protected val tableNameAsString: String = "cat." 
+ ident.toString - private def catalog: InMemoryRowLevelOperationTableCatalog = { + protected def catalog: InMemoryRowLevelOperationTableCatalog = { val catalog = spark.sessionState.catalogManager.catalog("cat") catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] } + protected def table: InMemoryRowLevelOperationTable = { + catalog.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] + } + test("EXPLAIN only delete") { createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""") @@ -553,13 +557,13 @@ abstract class DeleteFromTableSuiteBase } } - private def createTable(schemaString: String): Unit = { + protected def createTable(schemaString: String): Unit = { val schema = StructType.fromDDL(schemaString) val tableProps = Collections.emptyMap[String, String] catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps) } - private def createAndInitTable(schemaString: String, jsonData: String): Unit = { + protected def createAndInitTable(schemaString: String, jsonData: String): Unit = { createTable(schemaString) append(schemaString, jsonData) } @@ -606,7 +610,7 @@ abstract class DeleteFromTableSuiteBase } // executes an operation and keeps the executed plan - private def executeAndKeepPlan(func: => Unit): SparkPlan = { + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { var executedPlan: SparkPlan = null val listener = new QueryExecutionListener { @@ -625,5 +629,3 @@ abstract class DeleteFromTableSuiteBase stripAQEPlan(executedPlan) } } - -class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala new file mode 100644 index 0000000000000..36905027cb0cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.execution.InSubqueryExec +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { + + import testImplicits._ + + test("delete with IN predicate and runtime group filtering") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "salary INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + + test("delete with subqueries and runtime group filtering") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |{ "id": 4, "salary": 150, "dep": 'software' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(2), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq(Some("software"), None).toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + executeDeleteAndCheckScans( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (SELECT * FROM deleted_id) + | AND + | dep IN (SELECT * FROM deleted_dep) + |""".stripMargin, + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil) + + checkReplacedPartitions(Seq("software")) + } + } + + test("delete runtime group filtering (DPP enabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (DPP disabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE enabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE disabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + private def checkDeleteRuntimeGroupFiltering(): Unit = { + withTempView("deleted_id") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(1), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + 
groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + } + + private def executeDeleteAndCheckScans( + query: String, + primaryScanSchema: String, + groupFilterScanSchema: String): Unit = { + + val executedPlan = executeAndKeepPlan { + sql(query) + } + + val primaryScan = collect(executedPlan) { + case s: BatchScanExec => s + }.head + assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema))) + + primaryScan.runtimeFilters match { + case Seq(DynamicPruningExpression(child: InSubqueryExec)) => + val groupFilterScan = collect(child.plan) { + case s: BatchScanExec => s + }.head + assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema))) + + case _ => + fail("could not find group filter scan") + } + } + + private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { + val actualPartitions = table.replacedPartitions.map { + case Seq(partValue: UTF8String) => partValue.toString + case Seq(partValue) => partValue + case other => fail(s"expected only one partition value: $other" ) + } + assert(actualPartitions == expectedPartitions, "replaced partitions must match") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 6c87178f267c4..75f427e478a0b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.SparkConf +import org.apache.spark.SparkException import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Kryo._ import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager} @@ -534,7 +535,7 @@ class HashedRelationSuite extends SharedSparkSession { buffer.append(keyIterator.next().getLong(0)) } // attempt an illegal next() call - val caught = intercept[NoSuchElementException] { + val caught = intercept[SparkException] { keyIterator.next() } assert(caught.getLocalizedMessage === "End of the iterator") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 8f263f042cf9f..eac77c2938207 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -141,7 +141,7 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", - s"char_$i", + s"char_$i".padTo(18, ' '), Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), diff --git a/sql/create-docs.sh b/sql/create-docs.sh index 8721df874ee73..c5a36e0474eb0 100755 --- a/sql/create-docs.sh +++ b/sql/create-docs.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Licensed to the Apache Software Foundation (ASF) under one or more diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java index 79426e0e3de18..8ee606be314c2 100644 --- 
a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java @@ -17,6 +17,7 @@ package org.apache.hive.service.cli.operation; import java.io.CharArrayWriter; +import java.io.Serializable; import java.util.Map; import java.util.regex.Pattern; @@ -265,7 +266,7 @@ private static StringLayout initLayout(OperationLog.LoggingLevel loggingMode) { Map appenders = root.getAppenders(); for (Appender ap : appenders.values()) { if (ap.getClass().equals(ConsoleAppender.class)) { - Layout l = ap.getLayout(); + Layout<? extends Serializable> l = ap.getLayout(); if (l instanceof StringLayout) { layout = (StringLayout) l; break; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java index a261a54581828..6ee48186e7ea8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/OperationManager.java @@ -39,7 +39,7 @@ import org.apache.hive.service.cli.session.HiveSession; import org.apache.hive.service.rpc.thrift.TRowSet; import org.apache.hive.service.rpc.thrift.TTableSchema; -import org.apache.logging.log4j.core.appender.AbstractWriterAppender; +import org.apache.logging.log4j.core.Appender; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -82,7 +82,7 @@ public synchronized void stop() { private void initOperationLogCapture(String loggingMode) { // Register another Appender (with the same layout) that talks to us. - AbstractWriterAppender ap = new LogDivertAppender(this, OperationLog.getLoggingLevel(loggingMode)); + Appender ap = new LogDivertAppender(this, OperationLog.getLoggingLevel(loggingMode)); ((org.apache.logging.log4j.core.Logger)org.apache.logging.log4j.LogManager.getRootLogger()).addAppender(ap); ap.start(); } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index b554958572ad3..c4e0057ae952d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -101,6 +101,7 @@ class HiveSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis(this) +: + ApplyCharTypePadding +: HiveAnalysis +: ReplaceCharWithVarchar +: customPostHocResolutionRules
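Note (illustrative sketch, not part of the patch): the ApplyCharTypePadding rule added to the Hive session state builder's rule list above provides read-side padding for CHAR(n) columns, the same behaviour the TableScanSuite change reflects by expecting s"char_$i".padTo(18, ' '). A minimal Scala sketch of the observable effect, assuming a local SparkSession and a hypothetical table named chars_demo:

// Hypothetical, self-contained example; table and column names are made up.
import org.apache.spark.sql.SparkSession

object CharPaddingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("char-padding").getOrCreate()
    // CHAR(5) declares a fixed-length string column.
    spark.sql("CREATE TABLE chars_demo (c CHAR(5)) USING parquet")
    spark.sql("INSERT INTO chars_demo VALUES ('ab')")
    // With char-type padding applied, the value reads back space-padded to the
    // declared length ("ab   ", length 5), so comparisons against padded values
    // behave consistently across data sources.
    spark.sql("SELECT c, length(c) FROM chars_demo").show()
    spark.stop()
  }
}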