Skip to content

Commit 894345e

Browse files
author
Alessandro Gagliardi
committed
Fix SimpleIndexer fit method to set inputCol and outputCol correctly
1 parent 3b83a93 commit 894345e

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

src/main/scala/com/high-performance-spark-examples/ml/CustomPipeline.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ class SimpleIndexer(override val uid: String)
122122
import dataset.sparkSession.implicits._
123123
val words = dataset.select(dataset($(inputCol)).as[String]).distinct
124124
.collect()
125-
new SimpleIndexerModel(uid, words)
125+
val model = new SimpleIndexerModel(uid, words)
126+
model.set(inputCol, $(inputCol))
127+
model.set(outputCol, $(outputCol))
128+
model
126129
}
127130
}
128131

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/**
2+
* Simple tests for our CustomPipeline demo pipeline stage
3+
*/
4+
package com.highperformancespark.examples.ml
5+
6+
import com.holdenkarau.spark.testing.DataFrameSuiteBase
7+
import org.apache.spark.sql.Dataset
8+
import org.scalatest.FunSuite
9+
10+
case class TestRow(id: Int, inputColumn: String)
11+
12+
class CustomPipelineSuite extends FunSuite with DataFrameSuiteBase {
13+
val d = List(
14+
TestRow(0, "a"),
15+
TestRow(1, "b"),
16+
TestRow(2, "c"),
17+
TestRow(3, "a"),
18+
TestRow(4, "a"),
19+
TestRow(5, "c")
20+
)
21+
22+
test("test spark context") {
23+
val session = spark
24+
val rdd = session.sparkContext.parallelize(1 to 10)
25+
assert(rdd.sum === 55)
26+
}
27+
28+
test("simple indexer test") {
29+
val session = spark
30+
import session.implicits._
31+
val ds: Dataset[TestRow] = session.createDataset(d)
32+
val indexer = new SimpleIndexer()
33+
indexer.setInputCol("inputColumn")
34+
indexer.setOutputCol("categoryIndex")
35+
val model = indexer.fit(ds)
36+
val predicted = model.transform(ds)
37+
assert(predicted.columns.contains("categoryIndex"))
38+
predicted.show()
39+
}
40+
}

0 commit comments

Comments
 (0)