File tree Expand file tree Collapse file tree 2 files changed +44
-1
lines changed
main/scala/com/high-performance-spark-examples/ml
test/scala/com/high-performance-spark-examples/ml Expand file tree Collapse file tree 2 files changed +44
-1
lines changed Original file line number Diff line number Diff line change @@ -122,7 +122,10 @@ class SimpleIndexer(override val uid: String)
122122 import dataset .sparkSession .implicits ._
123123 val words = dataset.select(dataset($(inputCol)).as[String ]).distinct
124124 .collect()
125- new SimpleIndexerModel (uid, words)
125+ val model = new SimpleIndexerModel (uid, words)
126+ model.set(inputCol, $(inputCol))
127+ model.set(outputCol, $(outputCol))
128+ model
126129 }
127130}
128131
Original file line number Diff line number Diff line change 1+ /**
2+ * Simple tests for our CustomPipeline demo pipeline stage
3+ */
4+ package com .highperformancespark .examples .ml
5+
6+ import com .holdenkarau .spark .testing .DataFrameSuiteBase
7+ import org .apache .spark .sql .Dataset
8+ import org .scalatest .FunSuite
9+
/** Test fixture row: a numeric id paired with the string column the indexer consumes. */
case class TestRow(id: Int, inputColumn: String)
11+
/**
 * End-to-end checks for the SimpleIndexer custom pipeline stage, plus a
 * sanity test that the shared SparkContext from DataFrameSuiteBase works.
 */
class CustomPipelineSuite extends FunSuite with DataFrameSuiteBase {
  // Fixture: six rows spanning three distinct categories ("a", "b", "c").
  val d = List(
    TestRow(0, "a"),
    TestRow(1, "b"),
    TestRow(2, "c"),
    TestRow(3, "a"),
    TestRow(4, "a"),
    TestRow(5, "c"))

  test("test spark context") {
    // Trivial job to confirm the test harness actually stood up a context.
    val session = spark
    val numbers = session.sparkContext.parallelize(1 to 10)
    assert(numbers.sum === 55)
  }

  test("simple indexer test") {
    val session = spark
    import session.implicits._
    // Fit the indexer on the fixture data, transform, and confirm that the
    // configured output column was appended to the schema.
    val input: Dataset[TestRow] = session.createDataset(d)
    val indexer = new SimpleIndexer()
    indexer.setInputCol("inputColumn")
    indexer.setOutputCol("categoryIndex")
    val fitted = indexer.fit(input)
    val transformed = fitted.transform(input)
    assert(transformed.columns.contains("categoryIndex"))
    transformed.show()
  }
}
You can’t perform that action at this time.
0 commit comments