diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 8496cbda261be..a8eff6be7ddde 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Collections, Locale, UUID} import scala.collection.JavaConverters._ @@ -359,6 +359,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def schema(): StructType = KafkaOffsetReader.kafkaSchema + override def capabilities(): ju.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java deleted file mode 100644 index ea7c5d2b108f0..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. - *
- * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} - * that builds {@link Scan} with {@link Scan#toBatch()} implemented. - *
- */ -@Evolving -public interface SupportsBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java deleted file mode 100644 index 09e23f84fd6bf..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.WriteBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. - *
- * If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)} must return a - * {@link WriteBuilder} with {@link WriteBuilder#buildForBatch()} implemented. - *
- */ -@Evolving -public interface SupportsBatchWrite extends SupportsWrite {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java index 14990effeda37..67fc72e070dc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -26,7 +26,7 @@ * {@link #newScanBuilder(CaseInsensitiveStringMap)} that is used to create a scan for batch, * micro-batch, or continuous processing. */ -interface SupportsRead extends Table { +public interface SupportsRead extends Table { /** * Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java index f0d8e44f15287..b215963868217 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java @@ -26,7 +26,7 @@ * {@link #newWriteBuilder(CaseInsensitiveStringMap)} that is used to create a write * for batch or streaming. */ -interface SupportsWrite extends Table { +public interface SupportsWrite extends Table { /** * Returns a {@link WriteBuilder} which can be used to create {@link BatchWrite}. Spark will call diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 08664859b8de2..78f979a2a9a44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -20,16 +20,15 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; +import java.util.Set; + /** * An interface representing a logical structured data set of a data source. For example, the * implementation can be a directory on the file system, a topic of Kafka, or a table in the * catalog, etc. *
- * This interface can mixin the following interfaces to support different operations: - *
- * + * This interface can mixin the following interfaces to support different operations, like + * {@code SupportsRead}. */ @Evolving public interface Table { @@ -45,4 +44,9 @@ public interface Table { * empty schema can be returned here. */ StructType schema(); + + /** + * Returns the set of capabilities for this table. + */ + Set capabilities(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java new file mode 100644 index 0000000000000..8d3fdcd694e2c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Experimental; + +/** + * Capabilities that can be provided by a {@link Table} implementation. + *
+ * Tables use {@link Table#capabilities()} to return a set of capabilities. Each capability signals + * to Spark that the table supports a feature identified by the capability. For example, returning + * {@code BATCH_READ} allows Spark to read from the table using a batch scan. + */ +@Experimental +public enum TableCapability { + /** + * Signals that the table supports reads in batch execution mode. + */ + BATCH_READ, + + /** + * Signals that the table supports append writes in batch execution mode. + *
+ * Tables that return this capability must support appending data and may also support additional + * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and + * {@link #OVERWRITE_DYNAMIC}. + */ + BATCH_WRITE, + + /** + * Signals that the table can be truncated in a write operation. + *
+ * Truncating a table removes all existing rows. + *
+ * See {@link org.apache.spark.sql.sources.v2.writer.SupportsTruncate}. + */ + TRUNCATE, + + /** + * Signals that the table can replace existing data that matches a filter with appended data in + * a write operation. + *
+ * See {@link org.apache.spark.sql.sources.v2.writer.SupportsOverwrite}. + */ + OVERWRITE_BY_FILTER, + + /** + * Signals that the table can dynamically replace existing data partitions with appended data in + * a write operation. + *
+ * See {@link org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}. + */ + OVERWRITE_DYNAMIC +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index 25ab06eee42e0..e97d0548c66ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -21,7 +21,6 @@ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousStream; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.sources.v2.SupportsBatchRead; import org.apache.spark.sql.sources.v2.SupportsContinuousRead; import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; import org.apache.spark.sql.sources.v2.Table; @@ -33,8 +32,8 @@ * This logical representation is shared between batch scan, micro-batch streaming scan and * continuous streaming scan. Data sources must implement the corresponding methods in this * interface, to match what the table promises to support. For example, {@link #toBatch()} must be - * implemented, if the {@link Table} that creates this {@link Scan} implements - * {@link SupportsBatchRead}. + * implemented, if the {@link Table} that creates this {@link Scan} returns BATCH_READ support in + * its {@link Table#capabilities()}. *
*/ @Evolving @@ -62,7 +61,7 @@ default String description() { /** * Returns the physical representation of this scan for batch query. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this scan implements {@link SupportsBatchRead}. + * {@link Table} that creates this returns batch read support in its {@link Table#capabilities()}. * * @throws UnsupportedOperationException */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index 07529fe1dee91..e08d34fbf453e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.SupportsBatchWrite; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; @@ -58,7 +57,8 @@ default WriteBuilder withInputDataSchema(StructType schema) { /** * Returns a {@link BatchWrite} to write data to batch source. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this scan implements {@link SupportsBatchWrite}. + * {@link Table} that creates this write returns BATCH_WRITE support in its + * {@link Table#capabilities()}. * * Note that, the returned {@link BatchWrite} can be null if the implementation supports SaveMode, * to indicate that no writing is needed. 
We can clean it up after removing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index dfba12a5856ef..e057d33a6a148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,8 +37,9 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, FileTable} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2} import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -220,8 +221,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) } + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { - case _: SupportsBatchRead => + case _: SupportsRead if table.supports(BATCH_READ) => Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions)) case _ => loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3c51edd8ab603..b439a82e52cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, Logi import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -264,8 +265,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { - case table: SupportsBatchWrite => + case table: SupportsWrite if table.supports(BATCH_WRITE) => lazy val relation = DataSourceV2Relation.create(table, dsOptions) mode match { case SaveMode.Append => @@ -273,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { AppendData.byName(relation, df.logicalPlan) } - case SaveMode.Overwrite => + case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => // truncate the table runCommand(df.sparkSession, "save") { OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index aa2a5e9a06fbd..96a78d3a0da20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.datasources.noop +import java.util + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister @@ -35,10 +39,11 @@ class NoopDataSource extends TableProvider with DataSourceRegister { override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } -private[noop] object NoopTable extends Table with SupportsBatchWrite with SupportsStreamingWrite { +private[noop] object NoopTable extends Table with SupportsWrite with SupportsStreamingWrite { override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() + override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava } private[noop] object NoopWriteBuilder extends WriteBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index 2081af35ce2d1..eed69cdc8cac6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -18,26 +18,30 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} object DataSourceV2Implicits { implicit class TableHelper(table: Table) { - def asBatchReadable: SupportsBatchRead = { + def asReadable: SupportsRead = { table match { - case support: SupportsBatchRead => + case support: SupportsRead => support case _ => - throw new AnalysisException(s"Table does not support batch reads: ${table.name}") + throw new AnalysisException(s"Table does not support reads: ${table.name}") } } - def asBatchWritable: SupportsBatchWrite = { + def asWritable: SupportsWrite = { table match { - case support: SupportsBatchWrite => + case support: SupportsWrite => support case _ => - throw new AnalysisException(s"Table does not support batch writes: ${table.name}") + throw new AnalysisException(s"Table does not support writes: ${table.name}") } } + + def supports(capability: TableCapability): Boolean = table.capabilities.contains(capability) + + def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 17407827d0564..411995718603c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -49,7 +49,7 @@ case class 
DataSourceV2Relation( } def newScanBuilder(): ScanBuilder = { - table.asBatchReadable.newScanBuilder(options) + table.asReadable.newScanBuilder(options) } override def computeStats(): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 424fbed6fc1e6..f8c7e2c826a36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -148,7 +148,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - AppendDataExec(r.table.asBatchWritable, r.options, planLater(query)) :: Nil + AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. @@ -158,10 +158,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { }.toArray OverwriteByExpressionExec( - r.table.asBatchWritable, filters, r.options, planLater(query)) :: Nil + r.table.asWritable, filters, r.options, planLater(query)) :: Nil case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => - OverwritePartitionsDynamicExec(r.table.asBatchWritable, r.options, planLater(query)) :: Nil + OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 9423fe95fb97f..5944a20dd1efa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -22,7 +22,8 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.util.SchemaUtils @@ -32,7 +33,7 @@ abstract class FileTable( options: CaseInsensitiveStringMap, paths: Seq[String], userSpecifiedSchema: Option[StructType]) - extends Table with SupportsBatchRead with SupportsBatchWrite { + extends Table with SupportsRead with SupportsWrite { lazy val fileIndex: PartitioningAwareFileIndex = { val scalaMap = options.asScala.toMap @@ -62,6 +63,8 @@ abstract class FileTable( partitionSchema, caseSensitive)._1 } + override def capabilities(): java.util.Set[TableCapability] = FileTable.CAPABILITIES + /** * When possible, this method should return the schema of the given `files`. When the format * does not support inference, or no valid files are given should return None. 
In these cases @@ -69,3 +72,7 @@ abstract class FileTable( */ def inferSchema(files: Seq[FileStatus]): Option[StructType] } + +object FileTable { + private val CAPABILITIES = Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala new file mode 100644 index 0000000000000..cf77998c122f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.types.BooleanType + +object V2WriteSupportCheck extends (LogicalPlan => Unit) { + import DataSourceV2Implicits._ + + def failAnalysis(msg: String): Unit = throw new AnalysisException(msg) + + override def apply(plan: LogicalPlan): Unit = plan foreach { + case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) => + failAnalysis(s"Table does not support append in batch mode: ${rel.table}") + + case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _) + if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) => + failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}") + + case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) => + expr match { + case Literal(true, BooleanType) => + if (!rel.table.supports(BATCH_WRITE) || + !rel.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) { + failAnalysis( + s"Table does not support truncate in batch mode: ${rel.table}") + } + case _ => + if (!rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_BY_FILTER)) { + failAnalysis(s"Table does not support overwrite expression ${expr.sql} " + + s"in batch mode: ${rel.table}") + } + } + + case _ => // OK + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 51606abdb563a..607f2fa0f82c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -31,7 +31,7 @@ import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} -import org.apache.spark.sql.sources.v2.SupportsBatchWrite +import org.apache.spark.sql.sources.v2.SupportsWrite import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -53,7 +53,7 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) * Rows in the output data set are appended. */ case class AppendDataExec( - table: SupportsBatchWrite, + table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -80,7 +80,7 @@ case class AppendDataExec( * AlwaysTrue to delete all rows. */ case class OverwriteByExpressionExec( - table: SupportsBatchWrite, + table: SupportsWrite, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -101,7 +101,7 @@ case class OverwriteByExpressionExec( builder.overwrite(deleteWhere).buildForBatch() case _ => - throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + throw new SparkException(s"Table does not support overwrite by expression: $table") } doWrite(batchWrite) @@ -118,7 +118,7 @@ case class OverwriteByExpressionExec( * are not modified. */ case class OverwritePartitionsDynamicExec( - table: SupportsBatchWrite, + table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -153,7 +153,7 @@ case class WriteToDataSourceV2Exec( * Helper for physical plans that build batch writes. 
*/ trait BatchWriteHelper { - def table: SupportsBatchWrite + def table: SupportsWrite def query: SparkPlan def writeOptions: CaseInsensitiveStringMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index dbdfcf8085604..884b92ae9421c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.util +import java.util.Collections + import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} @@ -63,6 +66,8 @@ object ConsoleTable extends Table with SupportsStreamingWrite { override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index df7990c6a652e..bfa9c09985503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util +import java.util.Collections import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy @@ -97,6 +99,8 @@ class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table override def schema(): StructType = stream.fullSchema() + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MemoryStreamScanBuilder(stream) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 44516bbb2a5a1..807e0b12c6278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util +import java.util.Collections + import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -45,6 +48,8 @@ case class ForeachWriterTable[T]( override def schema(): 
StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3d8a90e99b85a..08aea75de2b5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util +import java.util.Collections + import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream @@ -84,6 +87,8 @@ class RateStreamTable( override def schema(): StructType = RateStreamProvider.SCHEMA + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = RateStreamProvider.SCHEMA diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index 0adbf1d9b3689..c0292acdf1044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.streaming.sources import java.text.SimpleDateFormat -import java.util.Locale +import java.util +import java.util.{Collections, Locale} import scala.util.{Failure, Success, Try} @@ -78,6 +79,8 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest } } + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = schema() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 22adceba930fb..8eb5de0f640a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util +import java.util.Collections import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -31,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.SupportsStreamingWrite +import 
org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, TableCapability} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -47,6 +49,8 @@ class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Loggi override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var needTruncate: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a605dc640dc96..f05aa5113e03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.v2.V2WriteSupportCheck import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -172,6 +173,7 @@ abstract class BaseSessionStateBuilder( PreWriteCheck +: PreReadCheck +: HiveOnlyCheck +: + V2WriteSupportCheck +: customCheckRules } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java index cb5954d5a6211..9b0eb610a206f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java @@ -18,15 +18,23 @@ package test.org.apache.spark.sql.sources.v2; import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.SupportsRead; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableCapability; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -abstract class JavaSimpleBatchTable implements Table, SupportsBatchRead { +abstract class JavaSimpleBatchTable implements Table, SupportsRead { + private static final Set CAPABILITIES = new HashSet<>(Arrays.asList( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.TRUNCATE)); @Override public StructType schema() { @@ -37,6 +45,11 @@ public StructType schema() { public String name() { return this.getClass().toString(); } + + @Override + public Set capabilities() { + return CAPABILITIES; + } } abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 705559d099bec..587cfa9bd6647 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.sources.v2 import java.io.File +import java.util import java.util.OptionalLong +import scala.collection.JavaConverters._ + import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -30,6 +33,7 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext @@ -411,11 +415,13 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } -abstract class SimpleBatchTable extends Table with SupportsBatchRead { +abstract class SimpleBatchTable extends Table with SupportsRead { override def schema(): StructType = new StructType().add("i", "int").add("j", "int") override def name(): String = this.getClass.toString + + override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava } abstract class SimpleScanBuilder extends ScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index f9f9db35ac2dd..e019dbfe3f512 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.sources.v2 +import scala.collection.JavaConverters._ + import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -38,7 +40,7 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { } } -class DummyReadOnlyFileTable extends Table with SupportsBatchRead { +class DummyReadOnlyFileTable extends Table with SupportsRead { override def name(): String = "dummy" override def schema(): StructType = StructType(Nil) @@ -46,6 +48,9 @@ class DummyReadOnlyFileTable extends Table with SupportsBatchRead { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { throw new AnalysisException("Dummy file reader") } + + override def capabilities(): java.util.Set[TableCapability] = + Set(TableCapability.BATCH_READ).asJava } class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -59,13 +64,16 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { } } -class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite { +class DummyWriteOnlyFileTable extends Table with SupportsWrite { override def name(): String = "dummy" override def schema(): StructType = StructType(Nil) override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = throw new AnalysisException("Dummy file writer") + + override def capabilities(): java.util.Set[TableCapability] = + Set(TableCapability.BATCH_WRITE).asJava } class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 160354520e432..edebb0b62b29c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} +import java.util import scala.collection.JavaConverters._ @@ -27,6 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType @@ -142,7 +144,7 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { } class MyTable(options: CaseInsensitiveStringMap) - extends SimpleBatchTable with SupportsBatchWrite { + extends SimpleBatchTable with SupportsWrite { private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration @@ -156,6 +158,9 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new MyWriteBuilder(path) } + + override def capabilities(): util.Set[TableCapability] = + Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala new file mode 100644 index 0000000000000..1d76ee34a0e0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, V2WriteSupportCheck} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V2WriteSupportCheckSuite extends AnalysisTest { + + test("AppendData: check missing capabilities") { + val plan = AppendData.byName( + DataSourceV2Relation.create(CapabilityTable(), CaseInsensitiveStringMap.empty), TestRelation) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support append in batch mode")) + } + + test("AppendData: check correct capabilities") { + val plan = AppendData.byName( + DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty), + TestRelation) + + V2WriteSupportCheck.apply(plan) + } + + test("Truncate: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(TRUNCATE), + CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + Literal(true)) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support truncate in batch mode")) + } + } + + test("Truncate: check correct capabilities") { + Seq(CapabilityTable(BATCH_WRITE, TRUNCATE), + CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + Literal(true)) + + V2WriteSupportCheck.apply(plan) + } + } + + test("OverwriteByExpression: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + EqualTo(AttributeReference("x", LongType)(), Literal(5))) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains( + "does not support overwrite expression (`x` = 5) in batch mode")) + } + } + + test("OverwriteByExpression: check correct capabilities") { + val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER) + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + EqualTo(AttributeReference("x", LongType)(), Literal(5))) + + V2WriteSupportCheck.apply(plan) + } + + test("OverwritePartitionsDynamic: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(OVERWRITE_DYNAMIC)).foreach { table => + + val plan = OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation) + + val exc = intercept[AnalysisException] { + 
V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support dynamic overwrite in batch mode")) + } + } + + test("OverwritePartitionsDynamic: check correct capabilities") { + val table = CapabilityTable(BATCH_WRITE, OVERWRITE_DYNAMIC) + val plan = OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation) + + V2WriteSupportCheck.apply(plan) + } +} + +private object V2WriteSupportCheckSuite { + val schema: StructType = new StructType().add("id", LongType).add("data", StringType) +} + +private case object TestRelation extends LeafNode with NamedRelation { + override def name: String = "source_relation" + override def output: Seq[AttributeReference] = V2WriteSupportCheckSuite.schema.toAttributes +} + +private case class CapabilityTable(_capabilities: TableCapability*) extends Table { + override def name(): String = "capability_test_table" + override def schema(): StructType = V2WriteSupportCheckSuite.schema + override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 13bb686fbd3b9..f022edea275e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.streaming.sources +import java.util +import java.util.Collections + import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} @@ -77,18 +80,21 @@ class FakeWriteBuilder extends WriteBuilder with StreamingWrite { trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder } trait FakeContinuousReadTable extends Table with SupportsContinuousRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder } trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new FakeWriteBuilder } @@ -137,6 +143,7 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() } } } @@ -164,6 +171,7 @@ class FakeNoWrite extends DataSourceRegister with TableProvider { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = 
Collections.emptySet() } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 132b0e4db0d71..68f4b2ddbac0b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.v2.V2WriteSupportCheck import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} @@ -86,6 +87,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val extendedCheckRules: Seq[LogicalPlan => Unit] = PreWriteCheck +: PreReadCheck +: + V2WriteSupportCheck +: customCheckRules }
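
For connector authors, the net effect of this patch is that a Table no longer advertises batch support by mixing in the removed SupportsBatchRead/SupportsBatchWrite markers; it mixes in the now-public SupportsRead/SupportsWrite and reports features through capabilities(). A minimal sketch against the API in this diff follows; ExampleTable, its trivial schema, and the stubbed builders are illustrative only, not part of the patch.

```scala
import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.sources.v2.reader.ScanBuilder
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical connector table: declares what it can do via capabilities()
// instead of mixing in the removed SupportsBatchRead/SupportsBatchWrite markers.
class ExampleTable extends Table with SupportsRead with SupportsWrite {
  override def name(): String = "example"
  override def schema(): StructType = new StructType().add("id", "long")

  // Spark consults this set (e.g. in DataFrameReader/Writer and V2WriteSupportCheck)
  // before planning a batch scan, append, truncate, or overwrite.
  override def capabilities(): util.Set[TableCapability] =
    Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava

  // A real connector returns its own ScanBuilder/WriteBuilder here.
  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = ???
  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = ???
}
```

Planner-side code then gates on the declared set with the new DataSourceV2Implicits helpers, e.g. table.supports(BATCH_WRITE) or table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER), which is exactly what DataFrameWriter and V2WriteSupportCheck do in the diff above.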