Make CatalogManager's defaultCatalog and *sessionCatalog private.
imback82 committed Oct 16, 2019
commit 46d2fa35982b4ca0ef2fc237acfdc2292b0bfe80
@@ -55,13 +55,20 @@ object SimpleAnalyzer extends Analyzer(
new CatalogManager(
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true),
FakeV2SessionCatalog,
+ SimpleAnalyzerHelper.createFakeV1SessionCatalog),
+ SimpleAnalyzerHelper.createFakeV1SessionCatalog,
+ new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))
+
+ object SimpleAnalyzerHelper {
+ def createFakeV1SessionCatalog: SessionCatalog = {
new SessionCatalog(
new InMemoryCatalog,
EmptyFunctionRegistry,
new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) {
override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
- }),
- new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))
+ }
+ }
+ }

object FakeV2SessionCatalog extends TableCatalog {
private def fail() = throw new UnsupportedOperationException
@@ -122,24 +129,24 @@ object AnalysisContext {
*/
class Analyzer(
override val catalogManager: CatalogManager,
+ v1SessionCatalog: SessionCatalog,
Contributor Author (imback82):

I renamed catalog to v1SessionCatalog in Analyzer to be explicit. Please let me know if this is not desired.

Contributor Author (imback82, Oct 16, 2019):

One downside of this approach (passing SessionCatalog as a separate parameter) is that the SessionCatalog instance can be different from the one stored in CatalogManager. Since CatalogManager updates the current database of the session catalog, the two can go out of sync. Please let me know if this approach is fine.
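A minimal sketch of the out-of-sync scenario described above (hypothetical, self-contained setup; not code from this PR):

    import java.net.URI
    import org.apache.spark.sql.catalyst.analysis.EmptyFunctionRegistry
    import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
    import org.apache.spark.sql.internal.SQLConf

    val conf = new SQLConf
    // The instance held by CatalogManager...
    val managed = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
    // ...and a second, separately constructed instance passed to Analyzer.
    val separate = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)

    managed.createDatabase(
      CatalogDatabase("db1", "", new URI("file:/tmp/db1"), Map.empty), ignoreIfExists = true)
    managed.setCurrentDatabase("db1") // what CatalogManager does when the namespace changes
    separate.getCurrentDatabase       // still "default": the two instances have diverged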

Contributor:

hmm, what's the upside of doing this?

Contributor Author (imback82):

v1SessionCatalog is now private in CatalogManager. Since Analyzer uses v1SessionCatalog, we need to pass it separately.

Contributor:

We don't need to make v1SessionCatalog private in CatalogManager. We only need to make sessionCatalog and defaultCatalog private. cc @rdblue

conf: SQLConf,
maxIterations: Int)
extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {

- private val catalog: SessionCatalog = catalogManager.v1SessionCatalog

- override def isView(nameParts: Seq[String]): Boolean = catalog.isView(nameParts)
+ override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)

// Only for tests.
def this(catalog: SessionCatalog, conf: SQLConf) = {
this(
new CatalogManager(conf, FakeV2SessionCatalog, catalog),
+ catalog,
conf,
conf.optimizerMaxIterations)
}

- def this(catalogManager: CatalogManager, conf: SQLConf) = {
- this(catalogManager, conf, conf.optimizerMaxIterations)
+ def this(catalogManager: CatalogManager, catalog: SessionCatalog, conf: SQLConf) = {
+ this(catalogManager, catalog, conf, conf.optimizerMaxIterations)
}

def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
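With the new signature, the call sites later in this diff (BaseSessionStateBuilder, PlanResolutionSuite) hand the same SessionCatalog instance to both the CatalogManager and the Analyzer. A condensed sketch of that wiring (assuming catalyst-internal visibility of these classes):

    import org.apache.spark.sql.catalyst.analysis.{Analyzer, FakeV2SessionCatalog}
    import org.apache.spark.sql.catalyst.catalog.SessionCatalog
    import org.apache.spark.sql.connector.catalog.CatalogManager
    import org.apache.spark.sql.internal.SQLConf

    def buildAnalyzer(conf: SQLConf, v1SessionCatalog: SessionCatalog): Analyzer = {
      // One SessionCatalog instance, shared by both, keeps the current database in sync.
      val catalogManager = new CatalogManager(conf, FakeV2SessionCatalog, v1SessionCatalog)
      new Analyzer(catalogManager, v1SessionCatalog, conf)
    }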
@@ -226,7 +233,7 @@ class Analyzer(
ResolveAggregateFunctions ::
TimeWindowing ::
ResolveInlineTables(conf) ::
- ResolveHigherOrderFunctions(catalog) ::
+ ResolveHigherOrderFunctions(v1SessionCatalog) ::
ResolveLambdaVariables(conf) ::
ResolveTimeZone(conf) ::
ResolveRandomSeed ::
@@ -722,7 +729,7 @@ class Analyzer(
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
- if catalog.isTemporaryTable(ident) =>
+ if v1SessionCatalog.isTemporaryTable(ident) =>
resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase))

case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
@@ -779,7 +786,7 @@ class Analyzer(
val tableIdentWithDb = tableIdentifier.copy(
database = tableIdentifier.database.orElse(defaultDatabase))
try {
- catalog.lookupRelation(tableIdentWithDb)
+ v1SessionCatalog.lookupRelation(tableIdentWithDb)
} catch {
case _: NoSuchTableException | _: NoSuchDatabaseException =>
u
@@ -793,8 +800,9 @@
// Note that we are testing (!db_exists || !table_exists) because the catalog throws
// an exception from tableExists if the database does not exist.
private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
- table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
- (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
+ table.database.isDefined && conf.runSQLonFile && !v1SessionCatalog.isTemporaryTable(table) &&
+ (!v1SessionCatalog.databaseExists(table.database.get)
+ || !v1SessionCatalog.tableExists(table))
}
}
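For context, conf.runSQLonFile backs the "run SQL directly on files" feature this guard serves: when the database part of a qualified name is neither an existing database nor a temporary table, the identifier is read as a data source format plus a path. An illustrative query (path is hypothetical):

    // Treated as format "parquet" plus a path, since no database named "parquet" exists.
    spark.sql("SELECT * FROM parquet.`/tmp/people.parquet`").show()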

@@ -1512,13 +1520,14 @@ class Analyzer(
plan.resolveExpressions {
case f: UnresolvedFunction
if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
- case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
- case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
+ case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
+ case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
externalFunctionNameSet.add(normalizeFuncName(f.name))
f
case f: UnresolvedFunction =>
withPosition(f) {
- throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
+ throw new NoSuchFunctionException(
+ f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
f.name.funcName)
}
}
@@ -1533,7 +1542,7 @@

val databaseName = name.database match {
case Some(a) => formatDatabaseName(a)
- case None => catalog.getCurrentDatabase
+ case None => v1SessionCatalog.getCurrentDatabase
}

FunctionIdentifier(funcName, Some(databaseName))
@@ -1558,7 +1567,7 @@
}
case u @ UnresolvedGenerator(name, children) =>
withPosition(u) {
- catalog.lookupFunction(name, children) match {
+ v1SessionCatalog.lookupFunction(name, children) match {
case generator: Generator => generator
case other =>
failAnalysis(s"$name is expected to be a generator. However, " +
@@ -1567,7 +1576,7 @@
}
case u @ UnresolvedFunction(funcId, children, isDistinct) =>
withPosition(u) {
- catalog.lookupFunction(funcId, children) match {
+ v1SessionCatalog.lookupFunction(funcId, children) match {
// AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
// the context of a Window clause. They do not need to be wrapped in an
// AggregateExpression.
@@ -2768,17 +2777,17 @@ class Analyzer(
private def lookupV2RelationAndCatalog(
identifier: Seq[String]): Option[(DataSourceV2Relation, CatalogPlugin, Identifier)] =
identifier match {
- case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) => None
- case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
- CatalogV2Util.loadTable(v2Catalog, ident) match {
- case Some(table) => Some((DataSourceV2Relation.create(table), v2Catalog, ident))
+ case AsTemporaryViewIdentifier(ti) if v1SessionCatalog.isTemporaryTable(ti) => None
+ case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
+ CatalogV2Util.loadTable(catalog, ident) match {
+ case Some(table) => Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
- case CatalogObjectIdentifier(None, ident) =>
- CatalogV2Util.loadTable(catalogManager.v2SessionCatalog, ident) match {
+ case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
+ CatalogV2Util.loadTable(catalog, ident) match {
case Some(_: V1Table) => None
case Some(table) =>
- Some((DataSourceV2Relation.create(table), catalogManager.v2SessionCatalog, ident))
+ Some((DataSourceV2Relation.create(table), catalog, ident))
case None => None
}
case _ => None
@@ -34,13 +34,11 @@ import org.apache.spark.sql.internal.SQLConf
* namespace in both `SessionCatalog` and `CatalogManager`, we let `CatalogManager` set/get
* current database of `SessionCatalog` when the current catalog is the session catalog.
*/
- // TODO: all commands should look up table from the current catalog. The `SessionCatalog` doesn't
- // need to track current database at all.
private[sql]
class CatalogManager(
conf: SQLConf,
defaultSessionCatalog: CatalogPlugin,
- val v1SessionCatalog: SessionCatalog) extends Logging {
+ v1SessionCatalog: SessionCatalog) extends Logging {
import CatalogManager.SESSION_CATALOG_NAME

private val catalogs = mutable.HashMap.empty[String, CatalogPlugin]
@@ -53,7 +51,7 @@ class CatalogManager(
}
}

- def defaultCatalog: Option[CatalogPlugin] = {
+ private def defaultCatalog: Option[CatalogPlugin] = {
conf.defaultV2Catalog.flatMap { catalogName =>
try {
Some(catalog(catalogName))
@@ -74,9 +72,16 @@
}
}

- // If the V2_SESSION_CATALOG config is specified, we try to instantiate the user-specified v2
- // session catalog. Otherwise, return the default session catalog.
- def v2SessionCatalog: CatalogPlugin = {
+ /**
+ * If the V2_SESSION_CATALOG config is specified, we try to instantiate the user-specified v2
+ * session catalog. Otherwise, return the default session catalog.
+ *
+ * This catalog is a v2 catalog that delegates to the v1 session catalog. It is used when the
+ * session catalog is responsible for an identifier, but the source requires the v2 catalog API.
+ * This happens when the source implementation extends the v2 TableProvider API and is not listed
+ * in the fallback configuration, spark.sql.sources.write.useV1SourceList.
+ */
+ private def v2SessionCatalog: CatalogPlugin = {
conf.getConf(SQLConf.V2_SESSION_CATALOG).map { customV2SessionCatalog =>
try {
catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog())
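Restated as a single expression, the lookup order of the method above (a sketch of its shape only; the full body in this diff also guards instantiation of the user-specified catalog with a try):

    private def v2SessionCatalog: CatalogPlugin =
      conf.getConf(SQLConf.V2_SESSION_CATALOG)
        .map(_ => catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog()))
        .getOrElse(defaultSessionCatalog)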
@@ -27,29 +27,11 @@ private[sql] trait LookupCatalog extends Logging {

protected val catalogManager: CatalogManager

- /**
- * Returns the default catalog. When set, this catalog is used for all identifiers that do not
- * set a specific catalog. When this is None, the session catalog is responsible for the
- * identifier.
- *
- * If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will
- * be used.
- */
- def defaultCatalog: Option[CatalogPlugin] = catalogManager.defaultCatalog
-
/**
* Returns the current catalog set.
*/
def currentCatalog: CatalogPlugin = catalogManager.currentCatalog

- /**
- * This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
- * session catalog is responsible for an identifier, but the source requires the v2 catalog API.
- * This happens when the source implementation extends the v2 TableProvider API and is not listed
- * in the fallback configuration, spark.sql.sources.write.useV1SourceList
- */
- def sessionCatalog: CatalogPlugin = catalogManager.v2SessionCatalog
-
/**
* Extract catalog plugin and remaining identifier names.
*
@@ -69,16 +51,14 @@ private[sql] trait LookupCatalog extends Logging {
}
}

- type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)
-
/**
- * Extract catalog and identifier from a multi-part identifier with the default catalog if needed.
+ * Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
*/
object CatalogObjectIdentifier {
- def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
+ def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
case CatalogAndIdentifier(maybeCatalog, nameParts) =>
Some((
- maybeCatalog.orElse(defaultCatalog),
+ maybeCatalog.getOrElse(currentCatalog),
Identifier.of(nameParts.init.toArray, nameParts.last)
))
}
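With the tuple now carrying a concrete CatalogPlugin, a match on the extractor no longer needs an Option. A usage sketch (hypothetical names; assumes this trait is mixed in and a v2 catalog is registered under "prod"):

    Seq("prod", "db", "tbl") match {
      case CatalogObjectIdentifier(catalog, ident) =>
        // catalog is the "prod" plugin; ident is Identifier.of(Array("db"), "tbl")
        println(s"${catalog.name}: $ident")
    }
    Seq("tbl") match {
      case CatalogObjectIdentifier(catalog, ident) =>
        // No catalog part: catalog is now currentCatalog (previously None / defaultCatalog).
        println(s"${catalog.name}: $ident")
    }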
@@ -108,7 +88,7 @@ private[sql] trait LookupCatalog extends Logging {
*/
object AsTableIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
- case CatalogAndIdentifier(None, names) if defaultCatalog.isEmpty =>
+ case CatalogAndIdentifier(None, names) if CatalogV2Util.isSessionCatalog(currentCatalog) =>
names match {
case Seq(name) =>
Some(TableIdentifier(name))
@@ -43,7 +43,6 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
- when(manager.defaultCatalog).thenReturn(None)
manager
}

Expand Down Expand Up @@ -135,7 +134,6 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
- when(manager.defaultCatalog).thenReturn(catalogs.get("prod"))
manager
}

Expand Down
18 changes: 10 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -342,6 +342,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def insertInto(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ import org.apache.spark.sql.connector.catalog.CatalogV2Util._

assertNotBucketed("insertInto")

@@ -355,14 +356,14 @@

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
- val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
- case CatalogObjectIdentifier(Some(catalog), ident) =>
+ case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
insertInto(catalog, ident)

- case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
- insertInto(sessionCatalog, ident)
+ case CatalogObjectIdentifier(catalog, ident)
+ if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
+ insertInto(catalog, ident)

case AsTableIdentifier(tableIdentifier) =>
insertInto(tableIdentifier)
@@ -481,17 +482,18 @@
def saveAsTable(tableName: String): Unit = {
import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+ import org.apache.spark.sql.connector.catalog.CatalogV2Util._

val session = df.sparkSession
val canUseV2 = lookupV2Provider().isDefined
- val sessionCatalog = session.sessionState.analyzer.sessionCatalog

session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
- case CatalogObjectIdentifier(Some(catalog), ident) =>
+ case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
saveAsTable(catalog.asTableCatalog, ident)

- case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
- saveAsTable(sessionCatalog.asTableCatalog, ident)
+ case CatalogObjectIdentifier(catalog, ident)
+ if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
+ saveAsTable(catalog.asTableCatalog, ident)
Contributor Author (imback82):

This may not be correct if the current catalog is a v2 session catalog that doesn't delegate to the v1 session catalog? If you look at the previous behavior, it always used the v1 session catalog.

Contributor:

It's a known problem that if the v2 session catalog doesn't delegate to the v1 session catalog, many things can break.

I think the previous version was wrong: it always used the default v2 session catalog even if users set a custom v2 session catalog.
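In user-facing terms, the routing after this change looks like this (hypothetical table and source names; "testcat" is assumed to be registered as a non-session v2 catalog):

    // Explicitly qualified with a non-session catalog: always the v2 saveAsTable path.
    df.write.saveAsTable("testcat.ns.tbl")

    // Session-catalog identifier with a v2 source that is not on the
    // spark.sql.sources.write.useV1SourceList fallback list: v2 path via the session catalog.
    df.write.format("com.example.MyV2Source").saveAsTable("tbl")

    // v1 source: falls through to the AsTableIdentifier branch below.
    df.write.format("parquet").saveAsTable("tbl")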


case AsTableIdentifier(tableIdentifier) =>
saveAsTable(tableIdentifier)
@@ -51,9 +51,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)

private val (catalog, identifier) = {
- val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
- val catalog = maybeCatalog.getOrElse(catalogManager.currentCatalog).asTableCatalog
- (catalog, identifier)
+ val CatalogObjectIdentifier(catalog, identifier) = tableName
+ (catalog.asTableCatalog, identifier)
}

private val logicalPlan = df.queryExecution.logical
@@ -169,7 +169,7 @@ abstract class BaseSessionStateBuilder(
*
* Note: this depends on the `conf` and `catalog` fields.
*/
- protected def analyzer: Analyzer = new Analyzer(catalogManager, conf) {
+ protected def analyzer: Analyzer = new Analyzer(catalogManager, catalog, conf) {
override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
@@ -83,7 +83,7 @@ class DataSourceV2DataFrameSessionCatalogSuite
val t1 = "prop_table"
withTable(t1) {
spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1)
- val cat = spark.sessionState.catalogManager.v2SessionCatalog.asInstanceOf[TableCatalog]
+ val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
val tableInfo = cat.loadTable(Identifier.of(Array.empty, t1))
assert(tableInfo.properties().get("location") === "abc")
assert(tableInfo.properties().get("provider") === v2Format)
@@ -44,7 +44,7 @@ class DataSourceV2SQLSessionCatalogSuite
}

override def getTableMetadata(tableName: String): Table = {
- val v2Catalog = spark.sessionState.catalogManager.v2SessionCatalog
+ val v2Catalog = spark.sessionState.catalogManager.currentCatalog
val nameParts = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName)
v2Catalog.asInstanceOf[TableCatalog]
.loadTable(Identifier.of(Array.empty, nameParts.last))
@@ -97,9 +97,6 @@ class PlanResolutionSuite extends AnalysisTest {
}
})
when(manager.currentCatalog).thenReturn(testCat)
- when(manager.defaultCatalog).thenReturn(Some(testCat))
- when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
- when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog)
manager
}

@@ -114,9 +111,6 @@
}
})
when(manager.currentCatalog).thenReturn(v2SessionCatalog)
- when(manager.defaultCatalog).thenReturn(None)
- when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
- when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog)
manager
}

@@ -126,7 +120,7 @@
} else {
catalogManagerWithoutDefault
}
- val analyzer = new Analyzer(catalogManager, conf)
+ val analyzer = new Analyzer(catalogManager, v1SessionCatalog, conf)
val rules = Seq(
new ResolveCatalogs(catalogManager),
new ResolveSessionCatalog(catalogManager, conf, _ == Seq("v")),