@@ -24,6 +24,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -36,28 +37,30 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
   import CatalystSqlParser._
 
   private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
+  private val sessionCatalog = FakeV2SessionCatalog
 
   override val catalogManager: CatalogManager = {
     val manager = mock(classOf[CatalogManager])
     when(manager.catalog(any())).thenAnswer((invocation: InvocationOnMock) => {
       val name = invocation.getArgument[String](0)
       catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
     })
+    when(manager.currentCatalog).thenReturn(sessionCatalog)
     manager
   }
 
   test("catalog object identifier") {
     Seq(
-      ("tbl", None, Seq.empty, "tbl"),
-      ("db.tbl", None, Seq("db"), "tbl"),
-      ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
-      ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
-      ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
-      ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
-      ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
-      ("`db.tbl`", None, Seq.empty, "db.tbl"),
-      ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
-      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
+      ("tbl", sessionCatalog, Seq.empty, "tbl"),
+      ("db.tbl", sessionCatalog, Seq("db"), "tbl"),
+      ("prod.func", catalogs("prod"), Seq.empty, "func"),
+      ("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
+      ("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
+      ("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
+      ("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+      ("`db.tbl`", sessionCatalog, Seq.empty, "db.tbl"),
+      ("parquet.`file:/tmp/db.tbl`", sessionCatalog, Seq("parquet"), "file:/tmp/db.tbl"),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", sessionCatalog,
         Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
       case (sql, expectedCatalog, namespace, name) =>
         inside(parseMultipartIdentifier(sql)) {
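
The new expectations in this hunk encode the resolution rule under test: a multipart identifier whose first part does not name a registered catalog now falls back to CatalogManager.currentCatalog (here the stubbed FakeV2SessionCatalog) instead of resolving to None. A minimal sketch of that rule, reusing the suite's catalogs and catalogManager fixtures; the helper name resolveCatalog and its tuple result shape are illustrative assumptions, not the trait's actual API:

    // Illustrative only: mirrors the rule the expected tuples above encode.
    def resolveCatalog(parts: Seq[String]): (CatalogPlugin, Seq[String], String) =
      parts match {
        // First part names a registered catalog: consume it; the rest is namespace + name.
        case head +: rest if catalogs.contains(head) && rest.nonEmpty =>
          (catalogs(head), rest.init, rest.last)
        // No catalog prefix: fall back to the current (session) catalog.
        case _ =>
          (catalogManager.currentCatalog, parts.init, parts.last)
      }

    resolveCatalog(Seq("tbl"))                // (sessionCatalog, Seq(), "tbl")
    resolveCatalog(Seq("prod", "db", "tbl"))  // (catalogs("prod"), Seq("db"), "tbl")

Note that a quoted part such as `db.tbl` is parsed as a single element of parts, which is why it stays with the fallback catalog in the cases above.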
@@ -134,21 +137,22 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog with Inside {
       val name = invocation.getArgument[String](0)
       catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
     })
+    when(manager.currentCatalog).thenReturn(catalogs("prod"))
     manager
   }
 
   test("catalog object identifier") {
     Seq(
-      ("tbl", catalogs.get("prod"), Seq.empty, "tbl"),
-      ("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
-      ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
-      ("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"),
-      ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
-      ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
-      ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
-      ("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"),
-      ("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
-      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"),
+      ("tbl", catalogs("prod"), Seq.empty, "tbl"),
+      ("db.tbl", catalogs("prod"), Seq("db"), "tbl"),
+      ("prod.func", catalogs("prod"), Seq.empty, "func"),
+      ("ns1.ns2.tbl", catalogs("prod"), Seq("ns1", "ns2"), "tbl"),
+      ("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
+      ("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
+      ("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+      ("`db.tbl`", catalogs("prod"), Seq.empty, "db.tbl"),
+      ("parquet.`file:/tmp/db.tbl`", catalogs("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs("prod"),
         Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
       case (sql, expectedCatalog, namespace, name) =>
         inside(parseMultipartIdentifier(sql)) {
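
Both hunks end on the opening line of the inside(...) block, so its body is not shown. Reconstructed as an assumption from the tuples, it has roughly this shape; the CatalogObjectIdentifier extractor and the Identifier accessors are inferred from the test name and the expected values, not from the diff itself:

    // Sketch: after this change the extractor yields a plain CatalogPlugin
    // rather than an Option, so the expected values are plugins, not Some(...)/None.
    inside(parseMultipartIdentifier(sql)) {
      case CatalogObjectIdentifier(catalog, ident) =>
        catalog shouldEqual expectedCatalog
        ident.namespace shouldEqual namespace.toArray
        ident.name shouldEqual name
    }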