diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
new file mode 100644
index 000000000000..e48fd8adaef0
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types.{DataType, IntegerType}
+
+/**
+ * Base class for expressions that are converted to v2 partition transforms.
+ *
+ * Subclasses represent abstract transform functions with concrete implementations that are
+ * determined by data source implementations. Because the concrete implementation is not known,
+ * these expressions are [[Unevaluable]].
+ *
+ * These expressions are used to pass transformations from the DataFrame API:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").partitionedBy($"category", days($"timestamp")).create()
+ * }}}
+ */
+abstract class PartitionTransformExpression extends Expression with Unevaluable {
+ override def nullable: Boolean = true
+}
+
+/**
+ * Expression for the v2 partition transform years.
+ */
+case class Years(child: Expression) extends PartitionTransformExpression {
+ override def dataType: DataType = IntegerType
+ override def children: Seq[Expression] = Seq(child)
+}
+
+/**
+ * Expression for the v2 partition transform months.
+ */
+case class Months(child: Expression) extends PartitionTransformExpression {
+ override def dataType: DataType = IntegerType
+ override def children: Seq[Expression] = Seq(child)
+}
+
+/**
+ * Expression for the v2 partition transform days.
+ */
+case class Days(child: Expression) extends PartitionTransformExpression {
+ override def dataType: DataType = IntegerType
+ override def children: Seq[Expression] = Seq(child)
+}
+
+/**
+ * Expression for the v2 partition transform hours.
+ */
+case class Hours(child: Expression) extends PartitionTransformExpression {
+ override def dataType: DataType = IntegerType
+ override def children: Seq[Expression] = Seq(child)
+}
+
+/**
+ * Expression for the v2 partition transform bucket.
+ */
+case class Bucket(numBuckets: Literal, child: Expression) extends PartitionTransformExpression {
+ override def dataType: DataType = IntegerType
+ override def children: Seq[Expression] = Seq(numBuckets, child)
+}
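A minimal sketch of how these expressions are intended to flow, assuming the `partitioning` helpers added to `functions.scala` later in this patch (the table and column names are illustrative; `$`-syntax assumes `spark.implicits._`):

  import org.apache.spark.sql.functions.partitioning.days

  // The Days(...) node built here is never evaluated; DataFrameWriterV2.partitionedBy
  // converts it to a v2 Transform (LogicalExpressions.days("timestamp")) before planning.
  df.writeTo("catalog.db.table")
    .partitionedBy($"category", days($"timestamp"))
    .create()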
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dcb6af6829c3..0cb59411bd95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2506,7 +2506,7 @@ class Analyzer(
*/
object ResolveOutputRelation extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
- case append @ AppendData(table, query, isByName)
+ case append @ AppendData(table, query, _, isByName)
if table.resolved && query.resolved && !append.outputResolved =>
val projection =
TableOutputResolver.resolveOutputColumns(
@@ -2518,7 +2518,7 @@ class Analyzer(
append
}
- case overwrite @ OverwriteByExpression(table, _, query, isByName)
+ case overwrite @ OverwriteByExpression(table, _, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
val projection =
TableOutputResolver.resolveOutputColumns(
@@ -2530,7 +2530,7 @@ class Analyzer(
overwrite
}
- case overwrite @ OverwritePartitionsDynamic(table, query, isByName)
+ case overwrite @ OverwritePartitionsDynamic(table, query, _, isByName)
if table.resolved && query.resolved && !overwrite.outputResolved =>
val projection =
TableOutputResolver.resolveOutputColumns(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 0be61cf14704..6e1825e4997c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -489,7 +489,7 @@ case class ReplaceTableAsSelect(
override def tableSchema: StructType = query.schema
override def children: Seq[LogicalPlan] = Seq(query)
- override lazy val resolved: Boolean = {
+ override lazy val resolved: Boolean = childrenResolved && {
// the table schema is created from the query schema, so the only resolution needed is to check
// that the columns referenced by the table's partitioning exist in the query schema
val references = partitioning.flatMap(_.references).toSet
@@ -507,15 +507,22 @@ case class ReplaceTableAsSelect(
case class AppendData(
table: NamedRelation,
query: LogicalPlan,
+ writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand
object AppendData {
- def byName(table: NamedRelation, df: LogicalPlan): AppendData = {
- new AppendData(table, df, isByName = true)
+ def byName(
+ table: NamedRelation,
+ df: LogicalPlan,
+ writeOptions: Map[String, String] = Map.empty): AppendData = {
+ new AppendData(table, df, writeOptions, isByName = true)
}
- def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = {
- new AppendData(table, query, isByName = false)
+ def byPosition(
+ table: NamedRelation,
+ query: LogicalPlan,
+ writeOptions: Map[String, String] = Map.empty): AppendData = {
+ new AppendData(table, query, writeOptions, isByName = false)
}
}
@@ -526,19 +533,26 @@ case class OverwriteByExpression(
table: NamedRelation,
deleteExpr: Expression,
query: LogicalPlan,
+ writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand {
override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved
}
object OverwriteByExpression {
def byName(
- table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = {
- OverwriteByExpression(table, deleteExpr, df, isByName = true)
+ table: NamedRelation,
+ df: LogicalPlan,
+ deleteExpr: Expression,
+ writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
+ OverwriteByExpression(table, deleteExpr, df, writeOptions, isByName = true)
}
def byPosition(
- table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = {
- OverwriteByExpression(table, deleteExpr, query, isByName = false)
+ table: NamedRelation,
+ query: LogicalPlan,
+ deleteExpr: Expression,
+ writeOptions: Map[String, String] = Map.empty): OverwriteByExpression = {
+ OverwriteByExpression(table, deleteExpr, query, writeOptions, isByName = false)
}
}
@@ -548,15 +562,22 @@ object OverwriteByExpression {
case class OverwritePartitionsDynamic(
table: NamedRelation,
query: LogicalPlan,
+ writeOptions: Map[String, String],
isByName: Boolean) extends V2WriteCommand
object OverwritePartitionsDynamic {
- def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = {
- OverwritePartitionsDynamic(table, df, isByName = true)
+ def byName(
+ table: NamedRelation,
+ df: LogicalPlan,
+ writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
+ OverwritePartitionsDynamic(table, df, writeOptions, isByName = true)
}
- def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = {
- OverwritePartitionsDynamic(table, query, isByName = false)
+ def byPosition(
+ table: NamedRelation,
+ query: LogicalPlan,
+ writeOptions: Map[String, String] = Map.empty): OverwritePartitionsDynamic = {
+ OverwritePartitionsDynamic(table, query, writeOptions, isByName = false)
}
}
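For context, the new `writeOptions` parameter defaults to `Map.empty` in every `byName`/`byPosition` factory, so existing callers keep compiling unchanged. A caller that wants to forward options builds the plan roughly like this (sketch only; the relation and the option key are illustrative):

  val append = AppendData.byName(
    relation,                                      // a resolved NamedRelation for the target table
    df.logicalPlan,
    writeOptions = Map("compression" -> "zstd"))   // hypothetical, source-specific option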
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
index 2d59c42ee868..ab33e8e5ceaf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.sources.v2.{SupportsDelete, SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
object DataSourceV2Implicits {
implicit class TableHelper(table: Table) {
@@ -53,4 +56,10 @@ object DataSourceV2Implicits {
def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports)
}
+
+ implicit class OptionsHelper(options: Map[String, String]) {
+ def asOptions: CaseInsensitiveStringMap = {
+ new CaseInsensitiveStringMap(options.asJava)
+ }
+ }
}
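A small usage sketch of the new implicit conversion (the map contents are illustrative):

  import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
  import org.apache.spark.sql.util.CaseInsensitiveStringMap

  val opts: CaseInsensitiveStringMap = Map("path" -> "/tmp/data").asOptions
  assert(opts.get("PATH") == "/tmp/data")  // CaseInsensitiveStringMap lower-cases keys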
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 0dea1e3a68dc..2dc4f8b680f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -41,8 +41,11 @@ class InMemoryTable(
override val properties: util.Map[String, String])
extends Table with SupportsRead with SupportsWrite with SupportsDelete {
+ private val allowUnsupportedTransforms =
+ properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean
+
partitioning.foreach { t =>
- if (!t.isInstanceOf[IdentityTransform]) {
+ if (!t.isInstanceOf[IdentityTransform] && !allowUnsupportedTransforms) {
throw new IllegalArgumentException(s"Transform $t must be IdentityTransform")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index d0a1d41c70dc..13d38d4ae1e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -271,13 +271,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
modeForDSV2 match {
case SaveMode.Append =>
runCommand(df.sparkSession, "save") {
- AppendData.byName(relation, df.logicalPlan)
+ AppendData.byName(relation, df.logicalPlan, extraOptions.toMap)
}
case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
// truncate the table
runCommand(df.sparkSession, "save") {
- OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true))
+ OverwriteByExpression.byName(
+ relation, df.logicalPlan, Literal(true), extraOptions.toMap)
}
case other =>
@@ -383,7 +384,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val command = modeForDSV2 match {
case SaveMode.Append =>
- AppendData.byPosition(table, df.logicalPlan)
+ AppendData.byPosition(table, df.logicalPlan, extraOptions.toMap)
case SaveMode.Overwrite =>
val conf = df.sparkSession.sessionState.conf
@@ -391,9 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
if (dynamicPartitionOverwrite) {
- OverwritePartitionsDynamic.byPosition(table, df.logicalPlan)
+ OverwritePartitionsDynamic.byPosition(table, df.logicalPlan, extraOptions.toMap)
} else {
- OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true))
+ OverwriteByExpression.byPosition(table, df.logicalPlan, Literal(true), extraOptions.toMap)
}
case other =>
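With this change, options set on the v1 `DataFrameWriter` are forwarded into the v2 write plans instead of being dropped. A sketch against the `save()` path above, assuming a v2 source and a source-specific option key (both names are hypothetical):

  df.write
    .format("org.example.v2source")                  // hypothetical v2 source
    .option("write.target-file-size", "134217728")   // hypothetical option, interpreted by the source
    .mode("append")
    .save()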
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
new file mode 100644
index 000000000000..57b212e6b9fe
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalog.v2.expressions.{LogicalExpressions, Transform}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect}
+import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
+ extends CreateTableWriter[T] {
+
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+ import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._
+ import df.sparkSession.sessionState.analyzer.CatalogObjectIdentifier
+
+ private val df: DataFrame = ds.toDF()
+
+ private val sparkSession = ds.sparkSession
+
+ private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
+
+ private val (catalog, identifier) = {
+ val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
+ val catalog = maybeCatalog.orElse(sparkSession.sessionState.analyzer.sessionCatalog)
+ .getOrElse(throw new AnalysisException(
+ s"No catalog specified for table ${identifier.quoted} and no default v2 catalog is set"))
+ .asTableCatalog
+
+ (catalog, identifier)
+ }
+
+ private val logicalPlan = df.queryExecution.logical
+
+ private var provider: Option[String] = None
+
+ private val options = new mutable.HashMap[String, String]()
+
+ private val properties = new mutable.HashMap[String, String]()
+
+ private var partitioning: Option[Seq[Transform]] = None
+
+ override def using(provider: String): CreateTableWriter[T] = {
+ this.provider = Some(provider)
+ this
+ }
+
+ override def option(key: String, value: String): DataFrameWriterV2[T] = {
+ this.options.put(key, value)
+ this
+ }
+
+ override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = {
+ options.foreach {
+ case (key, value) =>
+ this.options.put(key, value)
+ }
+ this
+ }
+
+ override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = {
+ this.options(options.asScala)
+ this
+ }
+
+ override def tableProperty(property: String, value: String): DataFrameWriterV2[T] = {
+ this.properties.put(property, value)
+ this
+ }
+
+ @scala.annotation.varargs
+ override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = {
+ val asTransforms = (column +: columns).map(_.expr).map {
+ case Years(attr: Attribute) =>
+ LogicalExpressions.years(attr.name)
+ case Months(attr: Attribute) =>
+ LogicalExpressions.months(attr.name)
+ case Days(attr: Attribute) =>
+ LogicalExpressions.days(attr.name)
+ case Hours(attr: Attribute) =>
+ LogicalExpressions.hours(attr.name)
+ case Bucket(Literal(numBuckets: Int, IntegerType), attr: Attribute) =>
+ LogicalExpressions.bucket(numBuckets, attr.name)
+ case attr: Attribute =>
+ LogicalExpressions.identity(attr.name)
+ case expr =>
+ throw new AnalysisException(s"Invalid partition transformation: ${expr.sql}")
+ }
+
+ this.partitioning = Some(asTransforms)
+ this
+ }
+
+ override def create(): Unit = {
+ // create and replace could alternatively create ParsedPlan statements, like
+ // `CreateTableFromDataFrameStatement(UnresolvedRelation(tableName), ...)`, to keep the catalog
+ // resolution logic in the analyzer.
+ runCommand("create") {
+ CreateTableAsSelect(
+ catalog,
+ identifier,
+ partitioning.getOrElse(Seq.empty),
+ logicalPlan,
+ properties = provider.map(p => properties + ("provider" -> p)).getOrElse(properties).toMap,
+ writeOptions = options.toMap,
+ ignoreIfExists = false)
+ }
+ }
+
+ override def replace(): Unit = {
+ internalReplace(orCreate = false)
+ }
+
+ override def createOrReplace(): Unit = {
+ internalReplace(orCreate = true)
+ }
+
+
+ /**
+ * Append the contents of the data frame to the output table.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+ * validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist
+ */
+ @throws(classOf[NoSuchTableException])
+ def append(): Unit = {
+ val append = loadTable(catalog, identifier) match {
+ case Some(t) =>
+ AppendData.byName(DataSourceV2Relation.create(t), logicalPlan, options.toMap)
+ case _ =>
+ throw new NoSuchTableException(identifier)
+ }
+
+ runCommand("append")(append)
+ }
+
+ /**
+ * Overwrite rows matching the given filter condition with the contents of the data frame in
+ * the output table.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]].
+ * The data frame will be validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist
+ */
+ @throws(classOf[NoSuchTableException])
+ def overwrite(condition: Column): Unit = {
+ val overwrite = loadTable(catalog, identifier) match {
+ case Some(t) =>
+ OverwriteByExpression.byName(
+ DataSourceV2Relation.create(t), logicalPlan, condition.expr, options.toMap)
+ case _ =>
+ throw new NoSuchTableException(identifier)
+ }
+
+ runCommand("overwrite")(overwrite)
+ }
+
+ /**
+ * Overwrite all partitions for which the data frame contains at least one row with the contents
+ * of the data frame in the output table.
+ *
+ * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
+ * partitions dynamically depending on the contents of the data frame.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+ * validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist
+ */
+ @throws(classOf[NoSuchTableException])
+ def overwritePartitions(): Unit = {
+ val dynamicOverwrite = loadTable(catalog, identifier) match {
+ case Some(t) =>
+ OverwritePartitionsDynamic.byName(
+ DataSourceV2Relation.create(t), logicalPlan, options.toMap)
+ case _ =>
+ throw new NoSuchTableException(identifier)
+ }
+
+ runCommand("overwritePartitions")(dynamicOverwrite)
+ }
+
+ /**
+ * Wrap an action to track the QueryExecution and time cost, then report to the user-registered
+ * callback functions.
+ */
+ private def runCommand(name: String)(command: LogicalPlan): Unit = {
+ val qe = sparkSession.sessionState.executePlan(command)
+ // call `QueryExecution.toRdd` to trigger the execution of commands.
+ SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd)
+ }
+
+ private def internalReplace(orCreate: Boolean): Unit = {
+ runCommand("replace") {
+ ReplaceTableAsSelect(
+ catalog,
+ identifier,
+ partitioning.getOrElse(Seq.empty),
+ logicalPlan,
+ properties = provider.map(p => properties + ("provider" -> p)).getOrElse(properties).toMap,
+ writeOptions = options.toMap,
+ orCreate = orCreate)
+ }
+ }
+}
+
+/**
+ * Configuration methods common to create/replace operations and insert/overwrite operations.
+ * @tparam R builder type to return
+ */
+trait WriteConfigMethods[R] {
+ /**
+ * Add a write option.
+ *
+ * @since 3.0.0
+ */
+ def option(key: String, value: String): R
+
+ /**
+ * Add a boolean output option.
+ *
+ * @since 3.0.0
+ */
+ def option(key: String, value: Boolean): R = option(key, value.toString)
+
+ /**
+ * Add a long output option.
+ *
+ * @since 3.0.0
+ */
+ def option(key: String, value: Long): R = option(key, value.toString)
+
+ /**
+ * Add a double output option.
+ *
+ * @since 3.0.0
+ */
+ def option(key: String, value: Double): R = option(key, value.toString)
+
+ /**
+ * Add write options from a Scala Map.
+ *
+ * @since 3.0.0
+ */
+ def options(options: scala.collection.Map[String, String]): R
+
+ /**
+ * Add write options from a Java Map.
+ *
+ * @since 3.0.0
+ */
+ def options(options: java.util.Map[String, String]): R
+}
+
+/**
+ * Trait to restrict calls to create and replace operations.
+ */
+trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
+ /**
+ * Create a new table from the contents of the data frame.
+ *
+ * The new table's schema, partition layout, properties, and other configuration will be
+ * based on the configuration set on this writer.
+ *
+ * If the output table exists, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]].
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+ * If the table already exists
+ */
+ @throws(classOf[TableAlreadyExistsException])
+ def create(): Unit
+
+ /**
+ * Replace an existing table with the contents of the data frame.
+ *
+ * The existing table's schema, partition layout, properties, and other configuration will be
+ * replaced with the contents of the data frame and the configuration set on this writer.
+ *
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]].
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
+ * If the table does not exist
+ */
+ @throws(classOf[CannotReplaceMissingTableException])
+ def replace(): Unit
+
+ /**
+ * Create a new table or replace an existing table with the contents of the data frame.
+ *
+ * The output table's schema, partition layout, properties, and other configuration will be based
+ * on the contents of the data frame and the configuration set on this writer. If the table
+ * exists, its configuration and data will be replaced.
+ */
+ def createOrReplace(): Unit
+
+ /**
+ * Partition the output table created by `create`, `createOrReplace`, or `replace` using
+ * the given columns or transforms.
+ *
+ * When specified, the table data will be stored by these values for efficient reads.
+ *
+ * For example, when a table is partitioned by day, it may be stored in a directory layout like:
+ *
+ * - `table/day=2019-06-01/`
+ * - `table/day=2019-06-02/`
+ *
+ *
+ * Partitioning is one of the most widely used techniques to optimize physical data layout.
+ * It provides a coarse-grained index for skipping unnecessary data reads when queries have
+ * predicates on the partitioned columns. In order for partitioning to work well, the number
+ * of distinct values in each column should typically be less than tens of thousands.
+ *
+ * @since 3.0.0
+ */
+ def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]
+
+ /**
+ * Specifies a provider for the underlying output data source. Spark's default catalog supports
+ * "parquet", "json", etc.
+ *
+ * @since 3.0.0
+ */
+ def using(provider: String): CreateTableWriter[T]
+
+ /**
+ * Add a table property.
+ */
+ def tableProperty(property: String, value: String): CreateTableWriter[T]
+}
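An end-to-end usage sketch of the new builder, mirroring the test suite below (catalog, table, property, and option names are illustrative; `$`-syntax assumes `spark.implicits._`):

  val df = spark.table("source")

  // Create a partitioned table, forwarding a table property and a write option.
  df.writeTo("testcat.db.events")
    .using("foo")
    .tableProperty("owner", "data-eng")    // illustrative property
    .option("write-format", "parquet")     // hypothetical, source-specific option
    .partitionedBy($"category")
    .create()

  // Later writes go through the same entry point.
  spark.table("source2").writeTo("testcat.db.events").append()
  spark.table("source2").writeTo("testcat.db.events").overwrite($"category" === "archived")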
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 7c25397e32be..23360df04594 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3190,6 +3190,34 @@ class Dataset[T] private[sql](
new DataFrameWriter[T](this)
}
+ /**
+ * Create a write configuration builder for v2 sources.
+ *
+ * This builder is used to configure and execute write operations. For example, to append to an
+ * existing table, run:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").append()
+ * }}}
+ *
+ * This can also be used to create or replace existing tables:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
+ * }}}
+ *
+ * @group basic
+ * @since 3.0.0
+ */
+ def writeTo(table: String): DataFrameWriterV2[T] = {
+ // TODO: streaming could be adapted to use this interface
+ if (isStreaming) {
+ logicalPlan.failAnalysis(
+ "'writeTo' can not be called on streaming Dataset/DataFrame")
+ }
+ new DataFrameWriterV2[T](table, this)
+ }
+
/**
* Interface for saving the content of the streaming Dataset out into external storage.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index a934c095eee1..b5a573c170a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.datasources.v2
-import java.util.UUID
-
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -34,7 +32,6 @@ import org.apache.spark.sql.sources
import org.apache.spark.sql.sources.v2.TableCapability
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
-import org.apache.spark.sql.sources.v2.writer.V1WriteBuilder
import org.apache.spark.sql.util.CaseInsensitiveStringMap
object DataSourceV2Strategy extends Strategy with PredicateHelper {
@@ -212,15 +209,15 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
orCreate = orCreate) :: Nil
}
- case AppendData(r: DataSourceV2Relation, query, _) =>
+ case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
- AppendDataExecV1(v1, r.options, query) :: Nil
+ AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil
case v2 =>
- AppendDataExec(v2, r.options, planLater(query)) :: Nil
+ AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil
}
- case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) =>
+ case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
// fail if any filter cannot be converted. correctness depends on removing all matching data.
val filters = splitConjunctivePredicates(deleteExpr).map {
filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse(
@@ -228,13 +225,14 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
}.toArray
r.table.asWritable match {
case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
- OverwriteByExpressionExecV1(v1, filters, r.options, query) :: Nil
+ OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil
case v2 =>
- OverwriteByExpressionExec(v2, filters, r.options, planLater(query)) :: Nil
+ OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil
}
- case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) =>
- OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil
+ case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
+ OverwritePartitionsDynamicExec(
+ r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil
case DeleteFromTable(r: DataSourceV2Relation, condition) =>
// fail if any filter cannot be converted. correctness depends on removing all matching data.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala
index 5648d5439ba5..5a093ba5d5d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala
@@ -29,14 +29,14 @@ object V2WriteSupportCheck extends (LogicalPlan => Unit) {
def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)
override def apply(plan: LogicalPlan): Unit = plan foreach {
- case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) =>
+ case AppendData(rel: DataSourceV2Relation, _, _, _) if !rel.table.supports(BATCH_WRITE) =>
failAnalysis(s"Table does not support append in batch mode: ${rel.table}")
- case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _)
+ case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _, _)
if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) =>
failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}")
- case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) =>
+ case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _, _) =>
expr match {
case Literal(true, BooleanType) =>
if (!rel.table.supports(BATCH_WRITE) ||
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6b8127bab1cb..0ece755023ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -69,6 +69,7 @@ import org.apache.spark.util.Utils
* @groupname window_funcs Window functions
* @groupname string_funcs String functions
* @groupname collection_funcs Collection functions
+ * @groupname partition_transforms Partition transform functions
* @groupname Ungrouped Support functions for DataFrames
* @since 1.3.0
*/
@@ -3942,6 +3943,69 @@ object functions {
*/
def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String].asJava)
+ // turn off style check that object names must start with a capital letter
+ // scalastyle:off
+ object partitioning {
+ // scalastyle:on
+
+ /**
+ * A transform for timestamps and dates to partition data into years.
+ *
+ * @group partition_transforms
+ * @since 3.0.0
+ */
+ def years(e: Column): Column = withExpr { Years(e.expr) }
+
+ /**
+ * A transform for timestamps and dates to partition data into months.
+ *
+ * @group partition_transforms
+ * @since 3.0.0
+ */
+ def months(e: Column): Column = withExpr { Months(e.expr) }
+
+ /**
+ * A transform for timestamps and dates to partition data into days.
+ *
+ * @group partition_transforms
+ * @since 3.0.0
+ */
+ def days(e: Column): Column = withExpr { Days(e.expr) }
+
+ /**
+ * A transform for timestamps to partition data into hours.
+ *
+ * @group partition_transforms
+ * @since 3.0.0
+ */
+ def hours(e: Column): Column = withExpr { Hours(e.expr) }
+
+ /**
+ * A transform for any type that partitions by a hash of the input column.
+ *
+ * @group partition_transforms
+ * @since 3.0.0
+ */
+ def bucket(numBuckets: Column, e: Column): Column = withExpr {
+ numBuckets.expr match {
+ case lit @ Literal(_, IntegerType) =>
+ Bucket(lit, e.expr)
+ case _ =>
+ throw new AnalysisException(s"Invalid number of buckets: bucket($numBuckets, $e)")
+ }
+ }
+
+ /**
+ * A transform for any type that partitions by a hash of the input column.
+ *
+ * @group partition_transforms
+ * @since 3.0.0
+ */
+ def bucket(numBuckets: Int, e: Column): Column = withExpr {
+ Bucket(Literal(numBuckets), e.expr)
+ }
+ }
+
// scalastyle:off line.size.limit
// scalastyle:off parameter.number
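A usage sketch for the new transform functions; they are only meaningful inside `DataFrameWriterV2.partitionedBy`, since the underlying expressions are `Unevaluable` (table and column names are illustrative):

  import org.apache.spark.sql.functions.partitioning

  df.writeTo("testcat.db.logs")
    .partitionedBy(partitioning.days($"ts"), partitioning.bucket(16, $"id"))
    .create()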
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala
new file mode 100644
index 000000000000..810a192f331d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataFrameWriterV2Suite.scala
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.InMemoryTableCatalog
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+
+class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter {
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+ import org.apache.spark.sql.functions._
+ import testImplicits._
+
+ private def catalog(name: String): TableCatalog = {
+ spark.sessionState.catalogManager.catalog(name).asTableCatalog
+ }
+
+ before {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+ val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
+ df.createOrReplaceTempView("source")
+ val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data")
+ df2.createOrReplaceTempView("source2")
+ }
+
+ after {
+ spark.sessionState.catalogManager.reset()
+ spark.sessionState.conf.clear()
+ }
+
+ test("Append: basic append") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ spark.table("source").writeTo("testcat.table_name").append()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.table("source2").writeTo("testcat.table_name").append()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+ }
+
+ test("Append: by name not position") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ val exc = intercept[AnalysisException] {
+ spark.table("source").withColumnRenamed("data", "d").writeTo("testcat.table_name").append()
+ }
+
+ assert(exc.getMessage.contains("Cannot find data for output column"))
+ assert(exc.getMessage.contains("'data'"))
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq())
+ }
+
+ test("Append: fail if table does not exist") {
+ val exc = intercept[NoSuchTableException] {
+ spark.table("source").writeTo("testcat.table_name").append()
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+ }
+
+ test("Overwrite: overwrite by expression: true") {
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ spark.table("source").writeTo("testcat.table_name").append()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.table("source2").writeTo("testcat.table_name").overwrite(lit(true))
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+ }
+
+ test("Overwrite: overwrite by expression: id = 3") {
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ spark.table("source").writeTo("testcat.table_name").append()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.table("source2").writeTo("testcat.table_name").overwrite($"id" === 3)
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+ }
+
+ test("Overwrite: by name not position") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ val exc = intercept[AnalysisException] {
+ spark.table("source").withColumnRenamed("data", "d")
+ .writeTo("testcat.table_name").overwrite(lit(true))
+ }
+
+ assert(exc.getMessage.contains("Cannot find data for output column"))
+ assert(exc.getMessage.contains("'data'"))
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq())
+ }
+
+ test("Overwrite: fail if table does not exist") {
+ val exc = intercept[NoSuchTableException] {
+ spark.table("source").writeTo("testcat.table_name").overwrite(lit(true))
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+ }
+
+ test("OverwritePartitions: overwrite conflicting partitions") {
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ spark.table("source").writeTo("testcat.table_name").append()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.table("source2").withColumn("id", $"id" - 2)
+ .writeTo("testcat.table_name").overwritePartitions()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "d"), Row(3L, "e"), Row(4L, "f")))
+ }
+
+ test("OverwritePartitions: overwrite all rows if not partitioned") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ spark.table("source").writeTo("testcat.table_name").append()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ spark.table("source2").writeTo("testcat.table_name").overwritePartitions()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+ }
+
+ test("OverwritePartitions: by name not position") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ checkAnswer(spark.table("testcat.table_name"), Seq.empty)
+
+ val exc = intercept[AnalysisException] {
+ spark.table("source").withColumnRenamed("data", "d")
+ .writeTo("testcat.table_name").overwritePartitions()
+ }
+
+ assert(exc.getMessage.contains("Cannot find data for output column"))
+ assert(exc.getMessage.contains("'data'"))
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq())
+ }
+
+ test("OverwritePartitions: fail if table does not exist") {
+ val exc = intercept[NoSuchTableException] {
+ spark.table("source").writeTo("testcat.table_name").overwritePartitions()
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+ }
+
+ test("Create: basic behavior") {
+ spark.table("source").writeTo("testcat.table_name").create()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning.isEmpty)
+ assert(table.properties.isEmpty)
+ }
+
+ test("Create: with using") {
+ spark.table("source").writeTo("testcat.table_name").using("foo").create()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning.isEmpty)
+ assert(table.properties === Map("provider" -> "foo").asJava)
+ }
+
+ test("Create: with property") {
+ spark.table("source").writeTo("testcat.table_name").tableProperty("prop", "value").create()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning.isEmpty)
+ assert(table.properties === Map("prop" -> "value").asJava)
+ }
+
+ test("Create: identity partitioned table") {
+ spark.table("source").writeTo("testcat.table_name").partitionedBy($"id").create()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning === Seq(IdentityTransform(FieldReference("id"))))
+ assert(table.properties.isEmpty)
+ }
+
+ test("Create: partitioned by years(ts)") {
+ spark.table("source")
+ .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
+ .writeTo("testcat.table_name")
+ .tableProperty("allow-unsupported-transforms", "true")
+ .partitionedBy(partitioning.years($"ts"))
+ .create()
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.partitioning === Seq(YearsTransform(FieldReference("ts"))))
+ }
+
+ test("Create: partitioned by months(ts)") {
+ spark.table("source")
+ .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
+ .writeTo("testcat.table_name")
+ .tableProperty("allow-unsupported-transforms", "true")
+ .partitionedBy(partitioning.months($"ts"))
+ .create()
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.partitioning === Seq(MonthsTransform(FieldReference("ts"))))
+ }
+
+ test("Create: partitioned by days(ts)") {
+ spark.table("source")
+ .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
+ .writeTo("testcat.table_name")
+ .tableProperty("allow-unsupported-transforms", "true")
+ .partitionedBy(partitioning.days($"ts"))
+ .create()
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.partitioning === Seq(DaysTransform(FieldReference("ts"))))
+ }
+
+ test("Create: partitioned by hours(ts)") {
+ spark.table("source")
+ .withColumn("ts", lit("2019-06-01 10:00:00.000000").cast("timestamp"))
+ .writeTo("testcat.table_name")
+ .tableProperty("allow-unsupported-transforms", "true")
+ .partitionedBy(partitioning.hours($"ts"))
+ .create()
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.partitioning === Seq(HoursTransform(FieldReference("ts"))))
+ }
+
+ test("Create: partitioned by bucket(4, id)") {
+ spark.table("source")
+ .writeTo("testcat.table_name")
+ .tableProperty("allow-unsupported-transforms", "true")
+ .partitionedBy(partitioning.bucket(4, $"id"))
+ .create()
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name === "testcat.table_name")
+ assert(table.partitioning ===
+ Seq(BucketTransform(LiteralValue(4, IntegerType), Seq(FieldReference("id")))))
+ }
+
+ test("Create: fail if table already exists") {
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+ val exc = intercept[TableAlreadyExistsException] {
+ spark.table("source").writeTo("testcat.table_name").create()
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // table should not have been changed
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning === Seq(IdentityTransform(FieldReference("id"))))
+ assert(table.properties === Map("provider" -> "foo").asJava)
+ }
+
+ test("Replace: basic behavior") {
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+ spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source")
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the initial table
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning === Seq(IdentityTransform(FieldReference("id"))))
+ assert(table.properties === Map("provider" -> "foo").asJava)
+
+ spark.table("source2")
+ .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd"))
+ .writeTo("testcat.table_name").replace()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even")))
+
+ val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the replacement table
+ assert(replaced.name === "testcat.table_name")
+ assert(replaced.schema === new StructType()
+ .add("id", LongType)
+ .add("data", StringType)
+ .add("even_or_odd", StringType))
+ assert(replaced.partitioning.isEmpty)
+ assert(replaced.properties.isEmpty)
+ }
+
+ test("Replace: partitioned table") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+ spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source")
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the initial table
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning.isEmpty)
+ assert(table.properties === Map("provider" -> "foo").asJava)
+
+ spark.table("source2")
+ .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd"))
+ .writeTo("testcat.table_name").partitionedBy($"id").replace()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even")))
+
+ val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the replacement table
+ assert(replaced.name === "testcat.table_name")
+ assert(replaced.schema === new StructType()
+ .add("id", LongType)
+ .add("data", StringType)
+ .add("even_or_odd", StringType))
+ assert(replaced.partitioning === Seq(IdentityTransform(FieldReference("id"))))
+ assert(replaced.properties.isEmpty)
+ }
+
+ test("Replace: fail if table does not exist") {
+ val exc = intercept[CannotReplaceMissingTableException] {
+ spark.table("source").writeTo("testcat.table_name").replace()
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+ }
+
+ test("CreateOrReplace: table does not exist") {
+ spark.table("source2").writeTo("testcat.table_name").createOrReplace()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+
+ val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the replacement table
+ assert(replaced.name === "testcat.table_name")
+ assert(replaced.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(replaced.partitioning.isEmpty)
+ assert(replaced.properties.isEmpty)
+ }
+
+ test("CreateOrReplace: table exists") {
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+ spark.sql("INSERT INTO TABLE testcat.table_name SELECT * FROM source")
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+
+ val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the initial table
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning === Seq(IdentityTransform(FieldReference("id"))))
+ assert(table.properties === Map("provider" -> "foo").asJava)
+
+ spark.table("source2")
+ .withColumn("even_or_odd", when(($"id" % 2) === 0, "even").otherwise("odd"))
+ .writeTo("testcat.table_name").createOrReplace()
+
+ checkAnswer(
+ spark.table("testcat.table_name"),
+ Seq(Row(4L, "d", "even"), Row(5L, "e", "odd"), Row(6L, "f", "even")))
+
+ val replaced = catalog("testcat").loadTable(Identifier.of(Array(), "table_name"))
+
+ // validate the replacement table
+ assert(replaced.name === "testcat.table_name")
+ assert(replaced.schema === new StructType()
+ .add("id", LongType)
+ .add("data", StringType)
+ .add("even_or_odd", StringType))
+ assert(replaced.partitioning.isEmpty)
+ assert(replaced.properties.isEmpty)
+ }
+}