Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ object ScalaReflection extends ScalaReflection {
// Since we are creating a runtime mirror usign the class loader of current thread,
// we need to use def at here. So, every time we call mirror, it is using the
// class loader of the current thread.
override def mirror: universe.Mirror =
override def mirror: universe.Mirror = ScalaReflectionLock.synchronized {
universe.runtimeMirror(Thread.currentThread().getContextClassLoader)
}

import universe._

Expand Down Expand Up @@ -665,7 +666,7 @@ trait ScalaReflection {
*
* @see SPARK-5281
*/
def localTypeOf[T: TypeTag]: `Type` = {
def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized {
val tag = implicitly[TypeTag[T]]
tag.in(mirror).tpe.normalize
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@ package org.apache.spark.sql.catalyst
import java.math.BigInteger
import java.sql.{Date, Timestamp}

import scala.reflect.runtime.universe.typeOf

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

case class PrimitiveData(
intField: Int,
Expand Down Expand Up @@ -237,4 +241,45 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(anyTypes.forall(!_.isPrimitive))
assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
}

private def testThreadSafetyFor(name: String)(exec: () => Any) = {
test(s"thread safety of ${name}") {
for (_ <- 0 until 100) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(0 until 100).foreach? probably just a matter of taste but it's symmetrical with your inner loop then.
You can import java.net.URLClassLoader.
It doesn't really seem like you need a method here; it took a moment to see there was a test in here.

Maybe it's obvious to you but why do all these classes/methods need to be tested separately?

And is this locking still safe in 2.11?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@srowen Thank you for your comment.

(0 until 100).foreach?

I repeated the test 100 times here because it is for thread-safety. Thread safety problem sometimes happens but sometimes doesn't.

You can import java.net.URLClassLoader.

I'll modify to use import.

It doesn't really seem like you need a method here; it took a moment to see there was a test in here.

I'll modify to move out of the method.

Maybe it's obvious to you but why do all these classes/methods need to be tested separately?

The methods are public, i.e. can be called by multi-thread, so I thought these also need to be tested.
But I'm wondering some of them could be removed?

And is this locking still safe in 2.11?

Yes, reflection in Scala 2.11 is thread-safe.
If we don't support Scala 2.10, these lockings in ScalaReflection would not be needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I understand repeating it 100 times, but your idiom for iterating n times is different on two nearby lines.

I actually meant the opposite -- remove the helper method but run tests in a loop still -- but leave it.

My question was more like, does the extra locking cause trouble in 2.11? I know you said 2.11 doesn't need this change per se.

val loader = new java.net.URLClassLoader(Array.empty, Utils.getContextOrSparkClassLoader)
(0 until 10).par.foreach { _ =>
val cl = Thread.currentThread.getContextClassLoader
try {
Thread.currentThread.setContextClassLoader(loader)
exec()
} finally {
Thread.currentThread.setContextClassLoader(cl)
}
}
}
}
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

for {
(name, exec) <- Seq(
("mirror", () => mirror),
("dataTypeFor", () => dataTypeFor[ComplexData]),
("constructorFor", () => constructorFor[ComplexData]),
("extractorsFor", {
val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false)
() => extractorsFor[ComplexData](inputObject)
}),
("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])),
("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])),
("getClassFromType", () => getClassFromType(typeOfComplexData)),
("schemaFor", () => schemaFor[ComplexData]),
("localTypeOf", () => localTypeOf[ComplexData]),
("getClassNameFromType", () => getClassNameFromType(typeOfComplexData)),
("getParameterTypes", () => getParameterTypes(() => ())),
("getConstructorParameters(tpe)", () => getClassNameFromType(typeOfComplexData)))
} {
testThreadSafetyFor(name)(exec)
}
}