Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
package org.apache.spark.sql.application

import java.io.{PipedInputStream, PipedOutputStream}
import java.nio.file.Paths
import java.util.concurrent.{Executors, Semaphore, TimeUnit}

import org.apache.commons.io.output.ByteArrayOutputStream
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}

class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {

Expand Down Expand Up @@ -151,4 +152,32 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
assertContains("Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)", output)
}

test("Client-side JAR") {
// scalastyle:off classforname line.size.limit
val sparkHome = IntegrationTestUtils.sparkHome
val testJar = Paths
.get(s"$sparkHome/connector/connect/client/jvm/src/test/resources/TestHelloV2.jar")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks fine, but just to be doubly sure: does it need a Scala 2.13 jar too, @LuciferYang?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should wait for #41852.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#41852 has been merged.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I've updated the code to account for the Scala version.

.toFile

assert(testJar.exists(), "Missing TestHelloV2 jar!")
val input = s"""
|import java.nio.file.Paths
|def classLoadingTest(x: Int): Int = {
| val classloader =
| Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
| val cls = Class.forName("com.example.Hello$$", true, classloader)
| val module = cls.getField("MODULE$$").get(null)
| cls.getMethod("test").invoke(module).asInstanceOf[Int]
|}
|val classLoaderUdf = udf(classLoadingTest _)
|
|val jarPath = Paths.get("$sparkHome/connector/connect/client/jvm/src/test/resources/TestHelloV2.jar").toUri
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like there must be a better way to pass in the JAR path here. Open to suggestions!

|spark.addArtifact(jarPath)
|
|spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
""".stripMargin
val output = runCommandsInShell(input)
assertContains("Array[Int] = Array(2, 2, 2, 2, 2)", output)
// scalastyle:on classforname line.size.limit
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
jarsList.add(target)
jarsURI.add(artifactURI + "/" + target.toString)
jarsURI.add(artifactURI + "/" + remoteRelativePath.toString)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
val stringRemotePath = remoteRelativePath.toString
Expand Down