diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 8ae3ff5043e6..d361e6248e2f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql
 
-import java.lang.Double.longBitsToDouble
-import java.lang.Float.intBitsToFloat
 import java.math.MathContext
 
 import scala.collection.mutable
@@ -69,6 +67,28 @@ object RandomDataGenerator {
     Some(f)
   }
 
+  /**
+   * A wrapper of Float.intBitsToFloat to use a unique NaN value for all NaN values.
+   * This prevents `checkEvaluationWithUnsafeProjection` from failing due to
+   * the difference between `UnsafeRow` binary presentation for NaN.
+   * This is visible for testing.
+   */
+  def intBitsToFloat(bits: Int): Float = {
+    val value = java.lang.Float.intBitsToFloat(bits)
+    if (value.isNaN) Float.NaN else value
+  }
+
+  /**
+   * A wrapper of Double.longBitsToDouble to use a unique NaN value for all NaN values.
+   * This prevents `checkEvaluationWithUnsafeProjection` from failing due to
+   * the difference between `UnsafeRow` binary presentation for NaN.
+   * This is visible for testing.
+   */
+  def longBitsToDouble(bits: Long): Double = {
+    val value = java.lang.Double.longBitsToDouble(bits)
+    if (value.isNaN) Double.NaN else value
+  }
+
   /**
    * Returns a randomly generated schema, based on the given accepted types.
    *
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
index 3c2f8a28875f..3e62ca069e9e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql
 
+import java.nio.ByteBuffer
+import java.util.Arrays
+
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
@@ -106,4 +109,32 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
       assert(deviation.toDouble / expectedTotalElements < 2e-1)
     }
   }
+
+  test("Use Float.NaN for all NaN values") {
+    val bits = -6966608
+    val nan1 = java.lang.Float.intBitsToFloat(bits)
+    val nan2 = RandomDataGenerator.intBitsToFloat(bits)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+
+    val arrayExpected = ByteBuffer.allocate(4).putFloat(Float.NaN).array
+    val array1 = ByteBuffer.allocate(4).putFloat(nan1).array
+    val array2 = ByteBuffer.allocate(4).putFloat(nan2).array
+    assert(!Arrays.equals(array1, arrayExpected))
+    assert(Arrays.equals(array2, arrayExpected))
+  }
+
+  test("Use Double.NaN for all NaN values") {
+    val bits = -6966608
+    val nan1 = java.lang.Double.longBitsToDouble(bits)
+    val nan2 = RandomDataGenerator.longBitsToDouble(bits)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+
+    val arrayExpected = ByteBuffer.allocate(8).putDouble(Double.NaN).array
+    val array1 = ByteBuffer.allocate(8).putDouble(nan1).array
+    val array2 = ByteBuffer.allocate(8).putDouble(nan2).array
+    assert(!Arrays.equals(array1, arrayExpected))
+    assert(Arrays.equals(array2, arrayExpected))
+  }
 }
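
Why the wrapper matters, as a minimal standalone sketch (not part of the patch; the object name NaNBytesDemo and method canonicalize below are illustrative): floats and doubles have many distinct NaN bit patterns, and on typical JVMs ByteBuffer.putFloat/putDouble keeps the raw bits, so two values that are both NaN can still serialize to different bytes, which is exactly what the new tests pin down. Collapsing every NaN to Float.NaN / Double.NaN makes the byte-level comparison behind checkEvaluationWithUnsafeProjection deterministic.

import java.nio.ByteBuffer
import java.util.Arrays

// Illustrative sketch only; mirrors the RandomDataGenerator.intBitsToFloat wrapper above.
object NaNBytesDemo {
  // Collapse every NaN bit pattern to the canonical Float.NaN (same idea as the patch).
  def canonicalize(bits: Int): Float = {
    val value = java.lang.Float.intBitsToFloat(bits)
    if (value.isNaN) Float.NaN else value
  }

  def main(args: Array[String]): Unit = {
    val bits = -6966608  // a NaN bit pattern whose bits differ from Float.NaN's
    val rawNaN = java.lang.Float.intBitsToFloat(bits)
    val normNaN = canonicalize(bits)

    val expected = ByteBuffer.allocate(4).putFloat(Float.NaN).array
    val rawBytes = ByteBuffer.allocate(4).putFloat(rawNaN).array
    val normBytes = ByteBuffer.allocate(4).putFloat(normNaN).array

    println(rawNaN.isNaN && normNaN.isNaN)       // true: both values are NaN
    println(Arrays.equals(rawBytes, expected))   // false: raw bits differ from Float.NaN
    println(Arrays.equals(normBytes, expected))  // true: canonical NaN matches byte-for-byte
  }
}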