Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implementation for MsSqlServer
  • Loading branch information
EnricoMi committed Apr 28, 2025
commit e93239e7ba72536acee6494823c121ee7b5110b5
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.tags.DockerTest
* }}}
*/
@DockerTest
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with UpsertTests {
override val db = new MsSQLServerDatabaseOnDocker

override def dataPreparation(conn: Connection): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.jdbc

import java.sql.SQLException
import java.sql.{SQLException, Statement}
import java.util.Locale

import scala.util.control.NonFatal
Expand All @@ -28,13 +28,13 @@ import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.expressions.{Expression, NullOrdering, SortDirection}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.MsSqlServerDialect.{GEOGRAPHY, GEOMETRY}
import org.apache.spark.sql.types._


private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCError {
private case class MsSqlServerDialect() extends JdbcDialect with MergeByTempTable with NoLegacyJDBCError {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver")

Expand Down Expand Up @@ -270,6 +270,23 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
new MsSqlServerSQLQueryBuilder(this, options)

override def supportsLimit: Boolean = true

override def createTempTableName(): String = "##" + super.createTempTableName()

override def createTempTable(
statement: Statement,
tableName: String,
strSchema: String,
options: JdbcOptionsInWrite): Unit = {
// MsSqlServer does not have a temp table specific syntax
super.createTable(statement, tableName, strSchema, options)
}

override def getCreatePrimaryIndex(tableName: String, columns: Array[String]): String = {
val indexColumns = columns.map(quoteIdentifier).mkString(", ")
s"ALTER TABLE $tableName ADD PRIMARY KEY CLUSTERED ($indexColumns)"
}

}

private object MsSqlServerDialect {
Expand Down
18 changes: 18 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1345,11 +1345,29 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
assert(tmp2 !== tmp1)
}

test("MergeByTempTable: Create temp table name - MsSqlServer") {
val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver")
assert(msSqlServerDialect.isInstanceOf[MergeByTempTable])
val upsert = msSqlServerDialect.asInstanceOf[MergeByTempTable]

val tmp = upsert.createTempTableName()
assert(tmp.startsWith("##"))
}

test("MergeByTempTable: Create primary index") {
val sql = testMergeByTempTableDialect.getCreatePrimaryIndex("test", Array("id", "ts"))
assert(sql === """ALTER TABLE test ADD PRIMARY KEY ("id", "ts")""")
}

test("MergeByTempTable: Create primary index - MsSqlServer") {
val msSqlServerDialect = JdbcDialects.get("jdbc:sqlserver")
assert(msSqlServerDialect.isInstanceOf[MergeByTempTable])
val upsert = msSqlServerDialect.asInstanceOf[MergeByTempTable]

val sql = upsert.getCreatePrimaryIndex("test", Array("id", "ts"))
assert(sql === """ALTER TABLE test ADD PRIMARY KEY CLUSTERED ("id", "ts")""")
}

test("MergeByTempTable: MERGE table into table") {
val columns = Array("id", "ts", "v1", "v2")
.map(testMergeByTempTableDialect.quoteIdentifier)
Expand Down