 
 package org.apache.spark.sql.hive.client
 
-import java.io.PrintStream
+import java.io.{OutputStream, PrintStream}
 import java.lang.{Iterable => JIterable}
 import java.lang.reflect.InvocationTargetException
 import java.nio.charset.StandardCharsets.UTF_8
@@ -28,6 +28,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 
+import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hive.common.StatsSetupConst
@@ -44,7 +45,7 @@ import org.apache.hadoop.hive.serde2.MetadataTypedColumnsetSerDe
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 import org.apache.hadoop.security.UserGroupInformation
 
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, SparkThrowable}
 import org.apache.spark.deploy.SparkHadoopUtil.SOURCE_SPARK
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.internal.LogKeys._
@@ -121,6 +122,7 @@ private[hive] class HiveClientImpl(
     case hive.v2_3 => new Shim_v2_3()
     case hive.v3_0 => new Shim_v3_0()
     case hive.v3_1 => new Shim_v3_1()
+    case hive.v4_0 => new Shim_v4_0()
   }
 
   // Create an internal session state for this HiveClientImpl.
@@ -177,8 +179,10 @@ private[hive] class HiveClientImpl(
     // got changed. We reset it to clientLoader.ClassLoader here.
     state.getConf.setClassLoader(clientLoader.classLoader)
     shim.setCurrentSessionState(state)
-    state.out = new PrintStream(outputBuffer, true, UTF_8.name())
-    state.err = new PrintStream(outputBuffer, true, UTF_8.name())
+    val clz = state.getClass.getField("out").getType.asInstanceOf[Class[_ <: PrintStream]]
+    val ctor = clz.getConstructor(classOf[OutputStream], classOf[Boolean], classOf[String])
+    state.getClass.getField("out").set(state, ctor.newInstance(outputBuffer, true, UTF_8.name()))
+    state.getClass.getField("err").set(state, ctor.newInstance(outputBuffer, true, UTF_8.name()))
     state
   }
 
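Hive 4 narrows the declared type of SessionState's out and err fields to a PrintStream subclass, so a direct assignment of a plain PrintStream no longer binds against both versions. The hunk above therefore looks up each field's declared type at runtime and invokes its (OutputStream, boolean, String) constructor. A minimal, self-contained sketch of the pattern, using the declared type of the public field System.out as a stand-in since Hive's SessionState is not assumed to be on the classpath:

import java.io.{ByteArrayOutputStream, OutputStream, PrintStream}
import java.nio.charset.StandardCharsets.UTF_8

object ReflectiveStreamSketch {
  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream
    // Declared type of the public field System.out: plain PrintStream here,
    // but a PrintStream subclass on Hive 4's SessionState, which is the point
    // of going through the field's type instead of hard-coding PrintStream.
    val clz = classOf[System].getField("out").getType.asInstanceOf[Class[_ <: PrintStream]]
    // Bind the (OutputStream, boolean, String) constructor of whatever type that is.
    val ctor = clz.getConstructor(classOf[OutputStream], classOf[Boolean], classOf[String])
    val stream = ctor.newInstance(buffer, java.lang.Boolean.TRUE, UTF_8.name())
    stream.println("hello")
    assert(buffer.toString(UTF_8.name()).trim == "hello")
  }
}
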
@@ -307,15 +311,27 @@ private[hive] class HiveClientImpl(
   }
 
   def setOut(stream: PrintStream): Unit = withHiveState {
-    state.out = stream
+    val ctor = state.getClass.getField("out")
+      .getType
+      .asInstanceOf[Class[_ <: PrintStream]]
+      .getConstructor(classOf[OutputStream])
+    state.getClass.getField("out").set(state, ctor.newInstance(stream))
   }
 
   def setInfo(stream: PrintStream): Unit = withHiveState {
-    state.info = stream
+    val ctor = state.getClass.getField("info")
+      .getType
+      .asInstanceOf[Class[_ <: PrintStream]]
+      .getConstructor(classOf[OutputStream])
+    state.getClass.getField("info").set(state, ctor.newInstance(stream))
   }
 
   def setError(stream: PrintStream): Unit = withHiveState {
-    state.err = stream
+    val ctor = state.getClass.getField("err")
+      .getType
+      .asInstanceOf[Class [_ <: PrintStream]]
+      .getConstructor(classOf[OutputStream])
+    state.getClass.getField("err").set(state, ctor.newInstance(stream))
   }
 
   private def setCurrentDatabaseRaw(db: String): Unit = {
@@ -629,21 +645,22 @@ private[hive] class HiveClientImpl(
   }
 
   override def createPartitions(
-      db: String,
-      table: String,
+      table: CatalogTable,
       parts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit = withHiveState {
     def replaceExistException(e: Throwable): Unit = e match {
       case _: HiveException if e.getCause.isInstanceOf[AlreadyExistsException] =>
-        val hiveTable = client.getTable(db, table)
+        val db = table.identifier.database.getOrElse(state.getCurrentDatabase)
+        val tableName = table.identifier.table
+        val hiveTable = client.getTable(db, tableName)
         val existingParts = parts.filter { p =>
           shim.getPartitions(client, hiveTable, p.spec.asJava).nonEmpty
         }
-        throw new PartitionsAlreadyExistException(db, table, existingParts.map(_.spec))
+        throw new PartitionsAlreadyExistException(db, tableName, existingParts.map(_.spec))
       case _ => throw e
     }
     try {
-      shim.createPartitions(client, db, table, parts, ignoreIfExists)
+      shim.createPartitions(client, toHiveTable(table), parts, ignoreIfExists)
     } catch {
       case e: InvocationTargetException => replaceExistException(e.getCause)
       case e: Throwable => replaceExistException(e)
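The reworked createPartitions takes the full CatalogTable instead of raw database and table names, resolving the database from the identifier (falling back to the current session database) only when it needs to report a conflict. Its catch block first unwraps InvocationTargetException, which is how failures surface from the shim's reflective calls, before translating Hive's AlreadyExistsException into Spark's PartitionsAlreadyExistException. A minimal sketch of that unwrap-and-translate pattern, with simplified stand-in exception types rather than the real Hive or Spark classes:

import java.lang.reflect.InvocationTargetException

// Simplified stand-ins for the Hive and Spark exception types in the hunk.
class AlreadyExists(msg: String) extends Exception(msg)
class WrappedHiveError(cause: Throwable) extends Exception(cause)
class PartitionsExistError(msg: String) extends Exception(msg)

object ExceptionTranslationSketch {
  // Map a low-level failure to a domain-specific one; rethrow anything else.
  private def translate(e: Throwable): Nothing = e match {
    case w: WrappedHiveError if w.getCause.isInstanceOf[AlreadyExists] =>
      throw new PartitionsExistError("partitions already exist")
    case other => throw other
  }

  def run(body: => Unit): Unit =
    try body
    catch {
      // Reflective calls surface failures as InvocationTargetException; unwrap first.
      case e: InvocationTargetException => translate(e.getCause)
      case e: Throwable => translate(e)
    }

  def main(args: Array[String]): Unit = {
    val failure = new InvocationTargetException(new WrappedHiveError(new AlreadyExists("p=1")))
    try run(throw failure) catch { case _: PartitionsExistError => () }
  }
}
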
@@ -861,11 +878,22 @@ private[hive] class HiveClientImpl(
       // Since HIVE-18238 (Hive 3.0.0), the Driver.close function's return type changed
       // and the CommandProcessorFactory.clean function was removed.
       driver.getClass.getMethod("close").invoke(driver)
-      if (version != hive.v3_0 && version != hive.v3_1) {
+      if (version != hive.v3_0 && version != hive.v3_1 && version != hive.v4_0) {
         CommandProcessorFactory.clean(conf)
       }
     }
 
+    def getResponseCode(response: CommandProcessorResponse): Int = {
+      if (version < hive.v4_0) {
+        response.getResponseCode
+      } else {
+        // Since Hive 4.0, the response code is removed from CommandProcessorResponse.
+        // Here we simply return 0 for the successful cases, as the error cases
+        // throw exceptions early.
+        0
+      }
+    }
+
     // Hive query needs to start SessionState.
     SessionState.start(state)
     logDebug(s"Running hiveql '$cmd'")
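The getResponseCode helper exists because CommandProcessorResponse.getResponseCode was removed in Hive 4.0, where failures are signaled by throwing instead of returning a non-zero code. The diff gates on the client version; the same tolerance could also be expressed by probing for the method reflectively. A minimal sketch of that alternative, with a stand-in response class rather than Hive's real type:

object ResponseCodeSketch {
  // Stand-in for a response type that may or may not expose getResponseCode.
  class LegacyResponse { def getResponseCode: Int = 42 }

  // Return the response code if the method exists on this object's class,
  // else 0, mirroring the diff's assumption that newer versions throw on
  // error instead of returning a non-zero code.
  def responseCode(response: AnyRef): Int =
    try {
      response.getClass.getMethod("getResponseCode").invoke(response).asInstanceOf[Int]
    } catch {
      case _: NoSuchMethodException => 0
    }

  def main(args: Array[String]): Unit = {
    assert(responseCode(new LegacyResponse) == 42) // method present: real code
    assert(responseCode(new Object) == 0)          // method absent: assume success
  }
}
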
@@ -878,30 +906,44 @@ private[hive] class HiveClientImpl(
       val proc = shim.getCommandProcessor(tokens(0), conf)
       proc match {
         case driver: Driver =>
-          val response: CommandProcessorResponse = driver.run(cmd)
-          // Throw an exception if there is an error in query processing.
-          if (response.getResponseCode != 0) {
+          try {
+            val response: CommandProcessorResponse = driver.run(cmd)
+            if (getResponseCode(response) != 0) {
+              // Throw an exception if there is an error in query processing.
+              // This works for Hive 3.x and earlier versions.
+              throw new QueryExecutionException(response.getErrorMessage)
+            }
+            driver.setMaxRows(maxRows)
+            val results = shim.getDriverResults(driver)
+            results
+          } catch {
+            case e @ (_: QueryExecutionException | _: SparkThrowable) =>
+              throw e
+            case e: Exception =>
+              // Wrap the original Hive error with QueryExecutionException and throw it
+              // if there is an error in query processing.
+              // This works for Hive 4.x and later versions.
+              throw new QueryExecutionException(ExceptionUtils.getStackTrace(e))
+          } finally {
             closeDriver(driver)
-            throw new QueryExecutionException(response.getErrorMessage)
           }
-          driver.setMaxRows(maxRows)
-
-          val results = shim.getDriverResults(driver)
-          closeDriver(driver)
-          results
 
         case _ =>
-          if (state.out != null) {
+          val out = state.getClass.getField("out").get(state)
+          if (out != null) {
             // scalastyle:off println
-            state.out.println(tokens(0) + " " + cmd_1)
+            out.asInstanceOf[PrintStream].println(tokens(0) + " " + cmd_1)
             // scalastyle:on println
           }
           val response: CommandProcessorResponse = proc.run(cmd_1)
-          // Throw an exception if there is an error in query processing.
-          if (response.getResponseCode != 0) {
+          val responseCode = getResponseCode(response)
+          if (responseCode != 0) {
+            // Throw an exception if there is an error in query processing.
+            // This works for Hive 3.x and earlier versions. For 4.x and later
+            // versions, it will go to the catch block directly.
             throw new QueryExecutionException(response.getErrorMessage)
           }
-          Seq(response.getResponseCode.toString)
+          Seq(responseCode.toString)
       }
     } catch {
       case e: Exception =>
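The restructured Driver branch replaces the two explicit closeDriver call sites with a single finally block, so the driver is released on success, on a non-zero response code, and on a thrown Hive 4 error alike; QueryExecutionException and SparkThrowable pass through untouched while everything else is wrapped into one failure type. A minimal sketch of that run-then-always-close shape, with stand-in types:

object RunAndAlwaysCloseSketch {
  // Stand-ins for the resource and error types used in the diff.
  class FakeDriver { var closed = false; def close(): Unit = closed = true }
  class KnownError(msg: String) extends Exception(msg)

  def run(driver: FakeDriver)(body: => Seq[String]): Seq[String] =
    try {
      body
    } catch {
      // Known error types pass through with their original message.
      case e: KnownError => throw e
      // Anything else (e.g. a newer version throwing instead of returning a
      // non-zero code) is wrapped so callers see a single failure type.
      case e: Exception => throw new KnownError(s"wrapped: ${e.getMessage}")
    } finally {
      driver.close() // runs on success and on every failure path
    }

  def main(args: Array[String]): Unit = {
    val d = new FakeDriver
    try run(d)(throw new IllegalStateException("boom")) catch { case _: KnownError => () }
    assert(d.closed)
  }
}
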
@@ -971,7 +1013,7 @@ private[hive] class HiveClientImpl(
       partSpec,
       replace,
       numDP,
-      listBucketingEnabled = hiveTable.isStoredAsSubDirectories)
+      hiveTable)
   }
 
   override def createFunction(db: String, func: CatalogFunction): Unit = withHiveState {