Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Make testTransformer (the per-row check) a special case of
testTransformerByGlobalCheckFunc.
  • Loading branch information
MrBago committed Dec 28, 2017
commit de345dcf2ba67121b43cb82fa51394c415055975
26 changes: 9 additions & 17 deletions mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit)
(globalCheckFunction: Seq[Row] => Unit): Unit = {

val columnNames = dataframe.schema.fieldNames
Expand All @@ -71,7 +70,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
.select(firstResultCol, otherResultCols: _*)
testStream(streamOutput) (
AddData(stream, data: _*),
CheckAnswer(checkFunction, globalCheckFunction)
CheckAnswer(globalCheckFunction)
)
}

Expand All @@ -80,18 +79,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit)
(globalCheckFunction: Seq[Row] => Unit): Unit = {
val dfOutput = transformer.transform(dataframe)
val outputs = dfOutput.select(firstResultCol, otherResultCols: _*).collect()
if (checkFunction != null) {
outputs.foreach { row =>
checkFunction(row)
}
}
if (globalCheckFunction != null) {
globalCheckFunction(outputs)
}
globalCheckFunction(outputs)
}

def testTransformer[A : Encoder](
Expand All @@ -100,10 +91,11 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit): Unit = {
testTransformerOnStreamData(dataframe, transformer, firstResultCol,
otherResultCols: _*)(checkFunction)(null)
testTransformerOnDF(dataframe, transformer, firstResultCol,
otherResultCols: _*)(checkFunction)(null)
testTransformerByGlobalCheckFunc(
dataframe,
transformer,
firstResultCol,
otherResultCols: _*) { rows: Seq[Row] => rows.foreach(checkFunction(_)) }
}

def testTransformerByGlobalCheckFunc[A : Encoder](
Expand All @@ -113,8 +105,8 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
otherResultCols: String*)
(globalCheckFunction: Seq[Row] => Unit): Unit = {
testTransformerOnStreamData(dataframe, transformer, firstResultCol,
otherResultCols: _*)(null)(globalCheckFunction)
otherResultCols: _*)(globalCheckFunction)
testTransformerOnDF(dataframe, transformer, firstResultCol,
otherResultCols: _*)(null)(globalCheckFunction)
otherResultCols: _*)(globalCheckFunction)
}
}
10 changes: 5 additions & 5 deletions mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.ml.util

import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row

Expand All @@ -32,21 +31,22 @@ class MLTestSuite extends MLTest {
val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
.setInputCol("label").setOutputCol("indexed")
val indexerModel = indexer.fit(data)
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id === indexed.toInt)
} { rows: Seq[Row] =>
}
testTransformerByGlobalCheckFunc[(Int, String)] (data, indexerModel, "id", "indexed") { rows =>
assert(rows.map(_.getDouble(1)).max === 5.0)
}

intercept[Exception] {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id != indexed.toInt)
} (null)
}
}
intercept[Exception] {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") (null) {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
rows: Seq[Row] =>
assert(rows.map(_.getDouble(1)).max === 1.0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be

def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)

def apply(checkFunction: Row => Unit,
globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(checkFunction, globalCheckFunction, false)
def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(globalCheckFunction, false)
}

/**
Expand All @@ -162,9 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be

def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)

def apply(checkFunction: Row => Unit,
globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(checkFunction, globalCheckFunction, true)
def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(globalCheckFunction, true)
}

case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
Expand All @@ -180,7 +178,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}

case class CheckAnswerRowsByFunc(
checkFunction: Row => Unit,
globalCheckFunction: Seq[Row] => Unit,
lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
override def toString: String = s"$operatorName"
Expand Down Expand Up @@ -643,23 +640,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
error => failTest(error)
}

case CheckAnswerRowsByFunc(checkFunction, globalCheckFunction, lastOnly) =>
case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
if (checkFunction != null) {
sparkAnswer.foreach { row =>
try {
checkFunction(row)
} catch {
case e: Throwable => failTest(e.toString)
}
}
}
if (globalCheckFunction != null) {
try {
globalCheckFunction(sparkAnswer)
} catch {
case e: Throwable => failTest(e.toString)
}
try {
globalCheckFunction(sparkAnswer)
} catch {
case e: Throwable => failTest(e.toString)
}
}
pos += 1
Expand Down