Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql

import java.lang.Double.longBitsToDouble
import java.lang.Float.intBitsToFloat
import java.math.MathContext

import scala.collection.mutable
Expand Down Expand Up @@ -69,6 +67,28 @@ object RandomDataGenerator {
Some(f)
}

/**
* A wrapper of Float.intBitsToFloat to use a unique NaN value for all NaN values.
* This prevents `checkEvaluationWithUnsafeProjection` from failing due to
* the difference between `UnsafeRow` binary presentation for NaN.
* This is visible for testing.
*/
def intBitsToFloat(bits: Int): Float = {
val value = java.lang.Float.intBitsToFloat(bits)
if (value.isNaN) Float.NaN else value
}

/**
* A wrapper of Double.longBitsToDouble to use a unique NaN value for all NaN values.
* This prevents `checkEvaluationWithUnsafeProjection` from failing due to
* the difference between `UnsafeRow` binary presentation for NaN.
* This is visible for testing.
*/
def longBitsToDouble(bits: Long): Double = {
val value = java.lang.Double.longBitsToDouble(bits)
if (value.isNaN) Double.NaN else value
}

/**
* Returns a randomly generated schema, based on the given accepted types.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.sql

import java.nio.ByteBuffer
import java.util.Arrays

import scala.util.Random

import org.apache.spark.SparkFunSuite
Expand Down Expand Up @@ -106,4 +109,32 @@ class RandomDataGeneratorSuite extends SparkFunSuite {
assert(deviation.toDouble / expectedTotalElements < 2e-1)
}
}

test("Use Float.NaN for all NaN values") {
val bits = -6966608
val nan1 = java.lang.Float.intBitsToFloat(bits)
val nan2 = RandomDataGenerator.intBitsToFloat(bits)
assert(nan1.isNaN)
assert(nan2.isNaN)

val arrayExpected = ByteBuffer.allocate(4).putFloat(Float.NaN).array
val array1 = ByteBuffer.allocate(4).putFloat(nan1).array
val array2 = ByteBuffer.allocate(4).putFloat(nan2).array
assert(!Arrays.equals(array1, arrayExpected))
assert(Arrays.equals(array2, arrayExpected))
}

test("Use Double.NaN for all NaN values") {
val bits = -6966608
val nan1 = java.lang.Double.longBitsToDouble(bits)
val nan2 = RandomDataGenerator.longBitsToDouble(bits)
assert(nan1.isNaN)
assert(nan2.isNaN)

val arrayExpected = ByteBuffer.allocate(8).putDouble(Double.NaN).array
val array1 = ByteBuffer.allocate(8).putDouble(nan1).array
val array2 = ByteBuffer.allocate(8).putDouble(nan2).array
assert(!Arrays.equals(array1, arrayExpected))
assert(Arrays.equals(array2, arrayExpected))
}
}