
Commit e7f9016

jkbradley authored and mengxr committed
[SPARK-11769][ML] Add save, load to all basic Transformers
This excludes Estimators and ones which include Vector and other non-basic types for Params or data.

This adds:
* Bucketizer
* DCT
* HashingTF
* Interaction
* NGram
* Normalizer
* OneHotEncoder
* PolynomialExpansion
* QuantileDiscretizer
* RFormula
* SQLTransformer
* StopWordsRemover
* StringIndexer
* Tokenizer
* VectorAssembler
* VectorSlicer

CC: mengxr

Author: Joseph K. Bradley <[email protected]>

Closes #9755 from jkbradley/transformer-io.

(cherry picked from commit d98d1cb)
Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 88431fb commit e7f9016

32 files changed: +453 -84 lines changed
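
For context, a minimal sketch of the save/load round trip these diffs enable for one of the listed transformers. It assumes the 1.6-era save(path) method on the object returned by write (not shown in this commit) together with the companion-object load added below; the column names, splits, and path are illustrative only, not part of this patch.

import org.apache.spark.ml.feature.Bucketizer

// Illustrative configuration; column names, splits, and path are hypothetical.
val bucketizer = new Bucketizer()
  .setInputCol("rawFeature")
  .setOutputCol("bucketedFeature")
  .setSplits(Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity))

// write comes from the Writable trait added here; DefaultParamsWriter persists the params.
bucketizer.write.save("/tmp/bucketizer-example")

// load comes from the new Readable companion object, backed by DefaultParamsReader.
val restored: Bucketizer = Bucketizer.load("/tmp/bucketizer-example")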

mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala

Lines changed: 7 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
@@ -87,10 +87,16 @@ final class Binarizer(override val uid: String)
 
   override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
 
+  @Since("1.6.0")
   override def write: Writer = new DefaultParamsWriter(this)
 }
 
+@Since("1.6.0")
 object Binarizer extends Readable[Binarizer] {
 
+  @Since("1.6.0")
   override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
+
+  @Since("1.6.0")
+  override def load(path: String): Binarizer = read.load(path)
 }

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 16 additions & 6 deletions
@@ -20,12 +20,12 @@ package org.apache.spark.ml.feature
 import java.{util => ju}
 
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
  */
 @Experimental
 final class Bucketizer(override val uid: String)
-  extends Model[Bucketizer] with HasInputCol with HasOutputCol {
+  extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable {
 
   def this() = this(Identifiable.randomUID("bucketizer"))
 
@@ -93,11 +93,15 @@ final class Bucketizer(override val uid: String)
   override def copy(extra: ParamMap): Bucketizer = {
     defaultCopy[Bucketizer](extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
 }
 
-private[feature] object Bucketizer {
+object Bucketizer extends Readable[Bucketizer] {
+
   /** We require splits to be of length >= 3 and to be in strictly increasing order. */
-  def checkSplits(splits: Array[Double]): Boolean = {
+  private[feature] def checkSplits(splits: Array[Double]): Boolean = {
     if (splits.length < 3) {
       false
     } else {
@@ -115,7 +119,7 @@ private[feature] object Bucketizer {
    * Binary searching in several buckets to place each data point.
    * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     if (feature == splits.last) {
       splits.length - 2
     } else {
@@ -134,4 +138,10 @@ private[feature] object Bucketizer {
       }
     }
   }
+
+  @Since("1.6.0")
+  override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer]
+
+  @Since("1.6.0")
+  override def load(path: String): Bucketizer = read.load(path)
 }

mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala

Lines changed: 16 additions & 3 deletions
@@ -19,10 +19,10 @@ package org.apache.spark.ml.feature
 
 import edu.emory.mathcs.jtransforms.dct._
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.BooleanParam
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
 import org.apache.spark.sql.types.DataType
 
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType
  */
 @Experimental
 class DCT(override val uid: String)
-  extends UnaryTransformer[Vector, Vector, DCT] {
+  extends UnaryTransformer[Vector, Vector, DCT] with Writable {
 
   def this() = this(Identifiable.randomUID("dct"))
 
@@ -69,4 +69,17 @@ class DCT(override val uid: String)
   }
 
   override protected def outputDataType: DataType = new VectorUDT
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object DCT extends Readable[DCT] {
+
+  @Since("1.6.0")
+  override def read: Reader[DCT] = new DefaultParamsReader[DCT]
+
+  @Since("1.6.0")
+  override def load(path: String): DCT = read.load(path)
 }

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 17 additions & 3 deletions
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.{col, udf}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
  * Maps a sequence of terms to their term frequencies using the hashing trick.
  */
 @Experimental
-class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
+class HashingTF(override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol with Writable {
 
   def this() = this(Identifiable.randomUID("hashingTF"))
 
@@ -76,4 +77,17 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
   }
 
   override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object HashingTF extends Readable[HashingTF] {
+
+  @Since("1.6.0")
+  override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF]
+
+  @Since("1.6.0")
+  override def load(path: String): HashingTF = read.load(path)
 }

mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala

Lines changed: 25 additions & 4 deletions
@@ -20,11 +20,11 @@ package org.apache.spark.ml.feature
 import scala.collection.mutable.ArrayBuilder
 
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.Transformer
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
 import org.apache.spark.sql.{DataFrame, Row}
@@ -42,24 +42,30 @@ import org.apache.spark.sql.types._
  * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal
 * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`.
  */
+@Since("1.6.0")
 @Experimental
-class Interaction(override val uid: String) extends Transformer
-  with HasInputCols with HasOutputCol {
+class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
+  with HasInputCols with HasOutputCol with Writable {
 
+  @Since("1.6.0")
   def this() = this(Identifiable.randomUID("interaction"))
 
   /** @group setParam */
+  @Since("1.6.0")
   def setInputCols(values: Array[String]): this.type = set(inputCols, values)
 
   /** @group setParam */
+  @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   // optimistic schema; does not contain any ML attributes
+  @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
     validateParams()
     StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
   }
 
+  @Since("1.6.0")
   override def transform(dataset: DataFrame): DataFrame = {
     validateParams()
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
@@ -208,14 +214,29 @@ class Interaction(override val uid: String) extends Transformer
     }
   }
 
+  @Since("1.6.0")
   override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
 
+  @Since("1.6.0")
   override def validateParams(): Unit = {
     require(get(inputCols).isDefined, "Input cols must be defined first.")
     require(get(outputCol).isDefined, "Output col must be defined first.")
     require($(inputCols).length > 0, "Input cols must have non-zero length.")
     require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object Interaction extends Readable[Interaction] {
+
+  @Since("1.6.0")
+  override def read: Reader[Interaction] = new DefaultParamsReader[Interaction]
+
+  @Since("1.6.0")
+  override def load(path: String): Interaction = read.load(path)
 }
 
 /**

mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala

Lines changed: 16 additions & 3 deletions
@@ -17,10 +17,10 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
 
 /**
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
  */
 @Experimental
 class NGram(override val uid: String)
-  extends UnaryTransformer[Seq[String], Seq[String], NGram] {
+  extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable {
 
   def this() = this(Identifiable.randomUID("ngram"))
 
@@ -66,4 +66,17 @@ class NGram(override val uid: String)
   }
 
   override protected def outputDataType: DataType = new ArrayType(StringType, false)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object NGram extends Readable[NGram] {
+
+  @Since("1.6.0")
+  override def read: Reader[NGram] = new DefaultParamsReader[NGram]
+
+  @Since("1.6.0")
+  override def load(path: String): NGram = read.load(path)
 }

mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

Lines changed: 17 additions & 3 deletions
@@ -17,10 +17,10 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.types.DataType
@@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType
 * Normalize a vector to have unit norm using the given p-norm.
 */
 @Experimental
-class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] {
+class Normalizer(override val uid: String)
+  extends UnaryTransformer[Vector, Vector, Normalizer] with Writable {
 
   def this() = this(Identifiable.randomUID("normalizer"))
 
@@ -55,4 +56,17 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect
   }
 
   override protected def outputDataType: DataType = new VectorUDT()
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object Normalizer extends Readable[Normalizer] {
+
+  @Since("1.6.0")
+  override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer]
+
+  @Since("1.6.0")
+  override def load(path: String): Normalizer = read.load(path)
 }

mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala

Lines changed: 16 additions & 3 deletions
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.{col, udf}
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 */
 @Experimental
 class OneHotEncoder(override val uid: String) extends Transformer
-  with HasInputCol with HasOutputCol {
+  with HasInputCol with HasOutputCol with Writable {
 
   def this() = this(Identifiable.randomUID("oneHot"))
 
@@ -165,4 +165,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
   }
 
   override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object OneHotEncoder extends Readable[OneHotEncoder] {
+
+  @Since("1.6.0")
+  override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder]
+
+  @Since("1.6.0")
+  override def load(path: String): OneHotEncoder = read.load(path)
 }
