Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
include all numeric types
  • Loading branch information
Wayne Zhang committed May 4, 2017
commit d149350d73fcc4f576880730661fd688fed8fc07
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._


class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

import testImplicits._
Expand Down Expand Up @@ -166,7 +165,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testDefaultReadWrite(t)
}

test("Bucket non-double numeric features") {
test("Bucket numeric features") {
val splits = Array(-3.0, 0.0, 3.0)
val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0)
val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0)
Expand All @@ -177,12 +176,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setOutputCol("result")
.setSplits(splits)

val types = Seq(ShortType, IntegerType, LongType, FloatType)
val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
ByteType, DecimalType(10, 0))
for (mType <- types) {
val df = dataFrame.withColumn("feature", col("feature").cast(mType))
bucketizer.transform(df).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y, "The feature value is not correct after bucketing in type " +
assert(x === y, "The result is not correct after bucketing in type " +
mType.toString + ". " + s"Expected $y but found $x.")
}
}
Expand Down