@@ -125,9 +125,9 @@ class Analyzer(
maxIterations: Int)
extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {

private val catalog: SessionCatalog = catalogManager.v1SessionCatalog
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog

override def isView(nameParts: Seq[String]): Boolean = catalog.isView(nameParts)
override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)

// Only for tests.
def this(catalog: SessionCatalog, conf: SQLConf) = {
@@ -225,7 +225,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
ResolveHigherOrderFunctions(catalog) ::
ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
@@ -721,7 +721,7 @@ class Analyzer(
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
if catalog.isTemporaryTable(ident) =>
if v1SessionCatalog.isTemporaryTable(ident) =>
resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase))

case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
@@ -778,7 +778,7 @@ class Analyzer(
val tableIdentWithDb = tableIdentifier.copy(
database = tableIdentifier.database.orElse(defaultDatabase))
try {
catalog.lookupRelation(tableIdentWithDb)
v1SessionCatalog.lookupRelation(tableIdentWithDb)
} catch {
case _: NoSuchTableException | _: NoSuchDatabaseException =>
u
@@ -792,8 +792,9 @@
// Note that we are testing (!db_exists || !table_exists) because the catalog throws
// an exception from tableExists if the database does not exist.
private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
(!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
table.database.isDefined && conf.runSQLonFile && !v1SessionCatalog.isTemporaryTable(table) &&
(!v1SessionCatalog.databaseExists(table.database.get)
|| !v1SessionCatalog.tableExists(table))
}
}

@@ -1511,13 +1512,14 @@ class Analyzer(
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
throw new NoSuchFunctionException(
f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
}
}
@@ -1532,7 +1534,7 @@

val databaseName = name.database match {
case Some(a) => formatDatabaseName(a)
case None => catalog.getCurrentDatabase
case None => v1SessionCatalog.getCurrentDatabase
}

FunctionIdentifier(funcName, Some(databaseName))
@@ -1557,7 +1559,7 @@
}
case u @ UnresolvedGenerator(name, children) =>
withPosition(u) {
catalog.lookupFunction(name, children) match {
v1SessionCatalog.lookupFunction(name, children) match {
case generator: Generator => generator
case other =>
failAnalysis(s"$name is expected to be a generator. However, " +
@@ -1566,7 +1568,7 @@
}
case u @ UnresolvedFunction(funcId, children, isDistinct) =>
withPosition(u) {
catalog.lookupFunction(funcId, children) match {
v1SessionCatalog.lookupFunction(funcId, children) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@@ -2765,17 +2767,17 @@ class Analyzer(
private def lookupV2RelationAndCatalog(
identifier: Seq[String]): Option[(DataSourceV2Relation, CatalogPlugin, Identifier)] =
identifier match {
case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
CatalogV2Util.loadTable(v2Catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), v2Catalog, ident))
case AsTemporaryViewIdentifier(ti) if v1SessionCatalog.isTemporaryTable(ti) => None
case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(table) => Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case CatalogObjectIdentifier(None, ident) =>
CatalogV2Util.loadTable(catalogManager.v2SessionCatalog, ident) match {
case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
CatalogV2Util.loadTable(catalog, ident) match {
case Some(_: V1Table) => None
case Some(table) =>
Some((DataSourceV2Relation.create(table), catalogManager.v2SessionCatalog, ident))
Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case _ => None
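For context on the isRunningDirectlyOnFiles check above, here is a minimal, hypothetical sketch of the "run SQL directly on files" path it guards. The path and object name are made up, and spark.sql.runSQLOnFiles is assumed to be the config key behind conf.runSQLonFile.

import org.apache.spark.sql.SparkSession

object DirectFileQueryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("direct-file-query")
      // Assumed key for the conf.runSQLonFile flag checked in isRunningDirectlyOnFiles.
      .config("spark.sql.runSQLOnFiles", "true")
      .getOrCreate()

    // Write some data so the path exists.
    spark.range(10).write.mode("overwrite").parquet("/tmp/direct_file_query")

    // "parquet" is not an existing database and the identifier is not a temporary view,
    // so isRunningDirectlyOnFiles returns true and the relation is read straight from the
    // files instead of being looked up in the session catalog.
    spark.sql("SELECT * FROM parquet.`/tmp/direct_file_query`").show()

    spark.stop()
  }
}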
@@ -177,9 +177,8 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
case ShowTablesStatement(Some(NonSessionCatalog(catalog, nameParts)), pattern) =>
ShowTables(catalog.asTableCatalog, nameParts, pattern)

// TODO (SPARK-29014): we should check if the current catalog is not session catalog here.
case ShowTablesStatement(None, pattern) if defaultCatalog.isDefined =>
ShowTables(defaultCatalog.get.asTableCatalog, catalogManager.currentNamespace, pattern)
case ShowTablesStatement(None, pattern) if !isSessionCatalog(currentCatalog) =>
ShowTables(currentCatalog.asTableCatalog, catalogManager.currentNamespace, pattern)

case UseStatement(isNamespaceSet, nameParts) =>
if (isNamespaceSet) {
@@ -53,7 +53,7 @@ class CatalogManager(
}
}

def defaultCatalog: Option[CatalogPlugin] = {
private def defaultCatalog: Option[CatalogPlugin] = {
conf.defaultV2Catalog.flatMap { catalogName =>
try {
Some(catalog(catalogName))
@@ -74,9 +74,16 @@
}
}

// If the V2_SESSION_CATALOG_IMPLEMENTATION config is specified, we try to instantiate the
// user-specified v2 session catalog. Otherwise, return the default session catalog.
def v2SessionCatalog: CatalogPlugin = {
/**
* If the V2_SESSION_CATALOG_IMPLEMENTATION config is specified, we try to instantiate the
* user-specified v2 session catalog. Otherwise, return the default session catalog.
*
* This catalog is a v2 catalog that delegates to the v1 session catalog. It is used when the
* session catalog is responsible for an identifier, but the source requires the v2 catalog API.
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList.
*/
private def v2SessionCatalog: CatalogPlugin = {
conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).map { customV2SessionCatalog =>
try {
catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog())
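As a usage note for the doc comment above, here is a hedged sketch of how named v2 catalogs and the session catalog fit together from the user side. The catalog name, implementation class, and table names are hypothetical; only the spark.sql.catalog.<name> registration pattern and the useV1SourceList fallback mentioned above are assumed.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("catalog-manager-example")
  // Registers a v2 catalog named "prod"; CatalogManager.catalog("prod") instantiates this class.
  .config("spark.sql.catalog.prod", "com.example.ProdCatalog")
  // Sources listed here keep the v1 write path even if they implement the v2 TableProvider API;
  // other v2-capable sources go through the v2 session catalog described above.
  .config("spark.sql.sources.write.useV1SourceList", "csv")
  .getOrCreate()

// A catalog-qualified identifier resolves against the named catalog.
spark.sql("SELECT * FROM prod.db.events").show()

// An unqualified identifier resolves against the current catalog, which defaults to the
// session catalog (backed by the v2 session catalog when the v2 API is required).
spark.sql("SELECT * FROM local_events").show()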
@@ -27,29 +27,11 @@ private[sql] trait LookupCatalog extends Logging {

protected val catalogManager: CatalogManager

/**
* Returns the default catalog. When set, this catalog is used for all identifiers that do not
* set a specific catalog. When this is None, the session catalog is responsible for the
* identifier.
*
* If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will
* be used.
*/
def defaultCatalog: Option[CatalogPlugin] = catalogManager.defaultCatalog

/**
* Returns the current catalog set.
*/
def currentCatalog: CatalogPlugin = catalogManager.currentCatalog

/**
* This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
* session catalog is responsible for an identifier, but the source requires the v2 catalog API.
* This happens when the source implementation extends the v2 TableProvider API and is not listed
* in the fallback configuration, spark.sql.sources.write.useV1SourceList
*/
def sessionCatalog: CatalogPlugin = catalogManager.v2SessionCatalog

/**
* Extract catalog plugin and remaining identifier names.
*
@@ -69,16 +51,14 @@ private[sql] trait LookupCatalog extends Logging {
}
}

type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)

/**
* Extract catalog and identifier from a multi-part identifier with the default catalog if needed.
* Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
*/
object CatalogObjectIdentifier {
def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
case CatalogAndIdentifier(maybeCatalog, nameParts) =>
Some((
maybeCatalog.orElse(defaultCatalog),
maybeCatalog.getOrElse(currentCatalog),
Identifier.of(nameParts.init.toArray, nameParts.last)
))
}
@@ -108,7 +88,7 @@ private[sql] trait LookupCatalog extends Logging {
*/
object AsTableIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
case CatalogAndIdentifier(None, names) if defaultCatalog.isEmpty =>
case CatalogAndIdentifier(None, names) if CatalogV2Util.isSessionCatalog(currentCatalog) =>
names match {
case Seq(name) =>
Some(TableIdentifier(name))
@@ -146,8 +126,7 @@ private[sql] trait LookupCatalog extends Logging {
Some((catalogManager.catalog(nameParts.head), nameParts.tail))
} catch {
case _: CatalogNotFoundException =>
// TODO (SPARK-29014): use current catalog here.
Some((defaultCatalog.getOrElse(sessionCatalog), nameParts))
Some((currentCatalog, nameParts))
}
}
}
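A minimal sketch of how the updated extractors read after this change, mirroring the pattern matches in the Analyzer and DataFrameWriter diffs. LookupCatalog is a private[sql] trait, so a class like this could only live inside Spark's own sql packages; the class name, package placement, and import paths here are assumptions.

package org.apache.spark.sql.example

import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, LookupCatalog}

class ExampleLookup(override protected val catalogManager: CatalogManager) extends LookupCatalog {

  // CatalogObjectIdentifier now always yields a concrete CatalogPlugin: the explicitly named
  // catalog, or catalogManager.currentCatalog when the identifier has no catalog part.
  def owningCatalog(nameParts: Seq[String]): String = nameParts match {
    case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
      s"session catalog owns $ident"
    case CatalogObjectIdentifier(catalog, ident) =>
      s"catalog ${catalog.name()} owns $ident"
  }
}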
@@ -24,6 +24,7 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -36,29 +37,30 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
import CatalystSqlParser._

private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
private val sessionCatalog = FakeV2SessionCatalog

override val catalogManager: CatalogManager = {
val manager = mock(classOf[CatalogManager])
when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.defaultCatalog).thenReturn(None)
when(manager.currentCatalog).thenReturn(sessionCatalog)
manager
}

test("catalog object identifier") {
Seq(
("tbl", None, Seq.empty, "tbl"),
("db.tbl", None, Seq("db"), "tbl"),
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", None, Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
("tbl", sessionCatalog, Seq.empty, "tbl"),
("db.tbl", sessionCatalog, Seq("db"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", sessionCatalog, Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", sessionCatalog, Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", sessionCatalog,
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
@@ -135,22 +137,22 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.defaultCatalog).thenReturn(catalogs.get("prod"))
when(manager.currentCatalog).thenReturn(catalogs("prod"))
manager
}

test("catalog object identifier") {
Seq(
("tbl", catalogs.get("prod"), Seq.empty, "tbl"),
("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"),
("tbl", catalogs("prod"), Seq.empty, "tbl"),
("db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("prod.func", catalogs("prod"), Seq.empty, "func"),
("ns1.ns2.tbl", catalogs("prod"), Seq("ns1", "ns2"), "tbl"),
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
("`db.tbl`", catalogs("prod"), Seq.empty, "db.tbl"),
("parquet.`file:/tmp/db.tbl`", catalogs("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs("prod"),
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
18 changes: 10 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -341,6 +341,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def insertInto(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._

assertNotBucketed("insertInto")

@@ -354,14 +355,14 @@

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
insertInto(catalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
insertInto(sessionCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
insertInto(catalog, ident)

case AsTableIdentifier(tableIdentifier) =>
insertInto(tableIdentifier)
@@ -480,17 +481,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def saveAsTable(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case CatalogObjectIdentifier(Some(catalog), ident) =>
case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
saveAsTable(catalog.asTableCatalog, ident)

case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(sessionCatalog.asTableCatalog, ident)
case CatalogObjectIdentifier(catalog, ident)
if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
saveAsTable(catalog.asTableCatalog, ident)
Contributor Author:
This may not be correct if the current catalog is a v2 session catalog that doesn't delegate to the v1 session catalog? If you look at the previous behavior, it always used the v1 session catalog.

Contributor:
It's a known problem that if the v2 session catalog doesn't delegate to the v1 session catalog, many things can break.

I think the previous version was wrong. It always used the default v2 session catalog even if users set a custom v2 session catalog.


case AsTableIdentifier(tableIdentifier) =>
saveAsTable(tableIdentifier)
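A short, hedged usage sketch of the saveAsTable routing above, assuming an active SparkSession named spark. The catalog and table names are hypothetical; "prod" is assumed to be registered via spark.sql.catalog.prod and to not be the session catalog.

val df = spark.range(10).toDF("id")

// Catalog-qualified identifier: not the session catalog, so the v2 saveAsTable path against
// the "prod" catalog is taken.
df.write.saveAsTable("prod.db.events")

// Unqualified identifier: the session catalog owns it. If the source is v2-capable and not in
// spark.sql.sources.write.useV1SourceList, the v2 session catalog branch is used; otherwise it
// falls back to the v1 AsTableIdentifier branch.
df.write.format("parquet").saveAsTable("events")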
@@ -51,9 +51,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)

private val (catalog, identifier) = {
val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
val catalog = maybeCatalog.getOrElse(catalogManager.currentCatalog).asTableCatalog
(catalog, identifier)
val CatalogObjectIdentifier(catalog, identifier) = tableName
(catalog.asTableCatalog, identifier)
}

private val logicalPlan = df.queryExecution.logical
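For completeness, a brief usage sketch of DataFrameWriterV2, whose constructor destructuring changes above. The multi-part name is resolved through CatalogObjectIdentifier, so an explicit catalog part selects that catalog and an unqualified name falls back to the current catalog. Names are hypothetical and an active SparkSession named spark is assumed.

spark.range(10).toDF("id")
  .writeTo("prod.db.events")   // "prod" is resolved to the registered catalog plugin
  .createOrReplace()

spark.range(10).toDF("id")
  .writeTo("events")           // no catalog part: resolved against the current catalog
  .append()                    // append assumes the table already exists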
@@ -262,8 +262,7 @@ class ResolveSessionCatalog(
}
ShowTablesCommand(Some(nameParts.head), pattern)

// TODO (SPARK-29014): we should check if the current catalog is session catalog here.
case ShowTablesStatement(None, pattern) if defaultCatalog.isEmpty =>
case ShowTablesStatement(None, pattern) if isSessionCatalog(currentCatalog) =>
ShowTablesCommand(None, pattern)
}

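To tie the two SHOW TABLES changes together (ResolveCatalogs handles the non-session-catalog case, ResolveSessionCatalog above handles the session-catalog case), here is a hedged sketch of the observable behavior. The catalog and namespace names are hypothetical, an active SparkSession named spark is assumed, and USE is assumed to switch the current catalog when its first name part is a registered catalog.

// Current catalog is the session catalog: ShowTablesStatement becomes the v1 ShowTablesCommand.
spark.sql("SHOW TABLES").show()

// After switching to a registered v2 catalog, the same statement is resolved by ResolveCatalogs
// into the v2 ShowTables plan against the current catalog and namespace.
spark.sql("USE prod.db")
spark.sql("SHOW TABLES").show()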
@@ -84,7 +84,7 @@ class DataSourceV2DataFrameSessionCatalogSuite
val t1 = "prop_table"
withTable(t1) {
spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1)
val cat = spark.sessionState.catalogManager.v2SessionCatalog.asInstanceOf[TableCatalog]
val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableInfo = cat.loadTable(Identifier.of(Array.empty, t1))
assert(tableInfo.properties().get("location") === "abc")
assert(tableInfo.properties().get("provider") === v2Format)