Skip to content

Commit 9f1e574

Browse files
committed
fix tests
1 parent 4131274 commit 9f1e574

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.scalatest.Matchers._
2424

2525
import org.apache.spark.SparkFunSuite
2626
import org.apache.spark.sql.catalyst.TableIdentifier
27+
import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
2728
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
2829
import org.apache.spark.sql.util.CaseInsensitiveStringMap
2930

@@ -36,28 +37,30 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
3637
import CatalystSqlParser._
3738

3839
private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
40+
private val sessionCatalog = FakeV2SessionCatalog
3941

4042
override val catalogManager: CatalogManager = {
4143
val manager = mock(classOf[CatalogManager])
4244
when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
4345
val name = invocation.getArgument[String](0)
4446
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
4547
})
48+
when(manager.currentCatalog).thenReturn(sessionCatalog)
4649
manager
4750
}
4851

4952
test("catalog object identifier") {
5053
Seq(
51-
("tbl", None, Seq.empty, "tbl"),
52-
("db.tbl", None, Seq("db"), "tbl"),
53-
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
54-
("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
55-
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
56-
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
57-
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
58-
("`db.tbl`", None, Seq.empty, "db.tbl"),
59-
("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
60-
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
54+
("tbl", sessionCatalog, Seq.empty, "tbl"),
55+
("db.tbl", sessionCatalog, Seq("db"), "tbl"),
56+
("prod.func", catalogs("prod"), Seq.empty, "func"),
57+
("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
58+
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
59+
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
60+
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
61+
("`db.tbl`", sessionCatalog, Seq.empty, "db.tbl"),
62+
("parquet.`file:/tmp/db.tbl`", sessionCatalog, Seq("parquet"), "file:/tmp/db.tbl"),
63+
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", sessionCatalog,
6164
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
6265
case (sql, expectedCatalog, namespace, name) =>
6366
inside(parseMultipartIdentifier(sql)) {
@@ -134,21 +137,22 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
134137
val name = invocation.getArgument[String](0)
135138
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
136139
})
140+
when(manager.currentCatalog).thenReturn(catalogs("prod"))
137141
manager
138142
}
139143

140144
test("catalog object identifier") {
141145
Seq(
142-
("tbl", catalogs.get("prod"), Seq.empty, "tbl"),
143-
("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
144-
("prod.func", catalogs.get("prod"), Seq.empty, "func"),
145-
("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"),
146-
("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
147-
("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
148-
("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
149-
("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"),
150-
("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
151-
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"),
146+
("tbl", catalogs("prod"), Seq.empty, "tbl"),
147+
("db.tbl", catalogs("prod"), Seq("db"), "tbl"),
148+
("prod.func", catalogs("prod"), Seq.empty, "func"),
149+
("ns1.ns2.tbl", catalogs("prod"), Seq("ns1", "ns2"), "tbl"),
150+
("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
151+
("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
152+
("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
153+
("`db.tbl`", catalogs("prod"), Seq.empty, "db.tbl"),
154+
("parquet.`file:/tmp/db.tbl`", catalogs("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
155+
("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs("prod"),
152156
Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
153157
case (sql, expectedCatalog, namespace, name) =>
154158
inside(parseMultipartIdentifier(sql)) {

0 commit comments

Comments
 (0)