mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,7 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
@@ -87,10 +87,16 @@ final class Binarizer(override val uid: String)

override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object Binarizer extends Readable[Binarizer] {

@Since("1.6.0")
override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]

@Since("1.6.0")
override def load(path: String): Binarizer = read.load(path)
}
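
The pattern added here — a `write` backed by `DefaultParamsWriter` plus a companion object extending `Readable` — is repeated for every transformer in this diff. A minimal round-trip sketch, assuming the `Writer` returned by `write` exposes `save(path)` as elsewhere in org.apache.spark.ml.util, and using an illustrative path and column names:

import org.apache.spark.ml.feature.Binarizer

// Params-only transformer: the writer persists just the uid and param values.
val binarizer = new Binarizer()
  .setInputCol("feature")            // illustrative column names
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)

binarizer.write.save("/tmp/binarizer-params")   // illustrative path

// Reload through the companion object's Reader.
val restored = Binarizer.load("/tmp/binarizer-params")
assert(restored.getThreshold == 0.5)
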
22 changes: 16 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -20,12 +20,12 @@ package org.apache.spark.ml.feature
import java.{util => ju}

import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
*/
@Experimental
final class Bucketizer(override val uid: String)
extends Model[Bucketizer] with HasInputCol with HasOutputCol {
extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable {

def this() = this(Identifiable.randomUID("bucketizer"))

@@ -93,11 +93,15 @@ final class Bucketizer(override val uid: String)
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

private[feature] object Bucketizer {
object Bucketizer extends Readable[Bucketizer] {

/** We require splits to be of length >= 3 and to be in strictly increasing order. */
def checkSplits(splits: Array[Double]): Boolean = {
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
} else {
@@ -115,7 +119,7 @@ private[feature] object Bucketizer {
* Binary searching in several buckets to place each data point.
* @throws SparkException if a feature is < splits.head or > splits.last
*/
def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
if (feature == splits.last) {
splits.length - 2
} else {
@@ -134,4 +138,10 @@
}
}
}

@Since("1.6.0")
override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer]

@Since("1.6.0")
override def load(path: String): Bucketizer = read.load(path)
}
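
For context, the split rules enforced by `checkSplits` (length >= 3, strictly increasing) and the bucket assignment done by `binarySearchForBuckets` play out as in this sketch; the column names, sample values, and the in-scope `sqlContext` are illustrative assumptions:

import org.apache.spark.ml.feature.Bucketizer

// n splits define n - 1 buckets; unbounded end splits avoid the SparkException
// thrown when a feature falls outside [splits.head, splits.last].
val splits = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)

val data = sqlContext.createDataFrame(
  Seq(-3.5, 4.2, 25.0).map(Tuple1.apply)).toDF("rawFeature")

val bucketizer = new Bucketizer()
  .setInputCol("rawFeature")
  .setOutputCol("bucketedFeature")
  .setSplits(splits)

// -3.5 -> bucket 0.0, 4.2 -> bucket 1.0, 25.0 -> bucket 2.0
bucketizer.transform(data).show()
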
19 changes: 16 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -19,10 +19,10 @@ package org.apache.spark.ml.feature

import edu.emory.mathcs.jtransforms.dct._

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.BooleanParam
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.types.DataType

@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class DCT(override val uid: String)
extends UnaryTransformer[Vector, Vector, DCT] {
extends UnaryTransformer[Vector, Vector, DCT] with Writable {

def this() = this(Identifiable.randomUID("dct"))

@@ -69,4 +69,17 @@ class DCT(override val uid: String)
}

override protected def outputDataType: DataType = new VectorUDT

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object DCT extends Readable[DCT] {

@Since("1.6.0")
override def read: Reader[DCT] = new DefaultParamsReader[DCT]

@Since("1.6.0")
override def load(path: String): DCT = read.load(path)
}
20 changes: 17 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -17,12 +17,12 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
* Maps a sequence of terms to their term frequencies using the hashing trick.
*/
@Experimental
class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
class HashingTF(override val uid: String)
extends Transformer with HasInputCol with HasOutputCol with Writable {

def this() = this(Identifiable.randomUID("hashingTF"))

@@ -76,4 +77,17 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
}

override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object HashingTF extends Readable[HashingTF] {

@Since("1.6.0")
override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF]

@Since("1.6.0")
override def load(path: String): HashingTF = read.load(path)
}
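
As a usage sketch of the transformer itself (independent of the persistence change), hashing a term sequence into a fixed-width term-frequency vector looks roughly like this; the data, column names, feature count, and in-scope `sqlContext` are assumptions for illustration:

import org.apache.spark.ml.feature.HashingTF

val docs = sqlContext.createDataFrame(Seq(
  (0, Seq("spark", "ml", "spark"))
)).toDF("id", "words")

// Each term is hashed to an index in [0, numFeatures); the value at that index
// is the term's frequency, so repeated terms ("spark") accumulate.
val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("tf")
  .setNumFeatures(1 << 18)

hashingTF.transform(docs).show(false)
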
29 changes: 25 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -20,11 +20,11 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder

import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.ml.Transformer
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
@@ -42,24 +42,30 @@ import org.apache.spark.sql.types._
* `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal
* with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`.
*/
@Since("1.6.0")
@Experimental
class Interaction(override val uid: String) extends Transformer
with HasInputCols with HasOutputCol {
class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
with HasInputCols with HasOutputCol with Writable {

@Since("1.6.0")
def this() = this(Identifiable.randomUID("interaction"))

/** @group setParam */
@Since("1.6.0")
def setInputCols(values: Array[String]): this.type = set(inputCols, values)

/** @group setParam */
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

// optimistic schema; does not contain any ML attributes
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}

@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
validateParams()
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
@@ -208,14 +214,29 @@ class Interaction(override val uid: String) extends Transformer
}
}

@Since("1.6.0")
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)

@Since("1.6.0")
override def validateParams(): Unit = {
require(get(inputCols).isDefined, "Input cols must be defined first.")
require(get(outputCol).isDefined, "Output col must be defined first.")
require($(inputCols).length > 0, "Input cols must have non-zero length.")
require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
}

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object Interaction extends Readable[Interaction] {

@Since("1.6.0")
override def read: Reader[Interaction] = new DefaultParamsReader[Interaction]

@Since("1.6.0")
override def load(path: String): Interaction = read.load(path)
}

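A short usage sketch of the new `Interaction` transformer with an input pair that reproduces the `Vector(6, 8)` output mentioned in the scaladoc; the DataFrame construction and column names are illustrative and assume an in-scope `sqlContext`:

import org.apache.spark.ml.feature.Interaction
import org.apache.spark.mllib.linalg.Vectors

// One numeric column and one vector column; the interaction is every cross-pair
// product, so 2.0 x [3.0, 4.0] produces [6.0, 8.0].
val df = sqlContext.createDataFrame(Seq(
  (2.0, Vectors.dense(3.0, 4.0))
)).toDF("x", "vec")

val interaction = new Interaction()
  .setInputCols(Array("x", "vec"))
  .setOutputCol("interacted")

interaction.transform(df).show()
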
19 changes: 16 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -17,10 +17,10 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

/**
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
*/
@Experimental
class NGram(override val uid: String)
extends UnaryTransformer[Seq[String], Seq[String], NGram] {
extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable {

def this() = this(Identifiable.randomUID("ngram"))

Expand Down Expand Up @@ -66,4 +66,17 @@ class NGram(override val uid: String)
}

override protected def outputDataType: DataType = new ArrayType(StringType, false)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object NGram extends Readable[NGram] {

@Since("1.6.0")
override def read: Reader[NGram] = new DefaultParamsReader[NGram]

@Since("1.6.0")
override def load(path: String): NGram = read.load(path)
}
20 changes: 17 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -17,10 +17,10 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
@@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType
* Normalize a vector to have unit norm using the given p-norm.
*/
@Experimental
class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] {
class Normalizer(override val uid: String)
extends UnaryTransformer[Vector, Vector, Normalizer] with Writable {

def this() = this(Identifiable.randomUID("normalizer"))

Expand All @@ -55,4 +56,17 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect
}

override protected def outputDataType: DataType = new VectorUDT()

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object Normalizer extends Readable[Normalizer] {

@Since("1.6.0")
override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer]

@Since("1.6.0")
override def load(path: String): Normalizer = read.load(path)
}
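
To make the p-norm parameter concrete, a small sketch (sample vector, column names, and the in-scope `sqlContext` are illustrative):

import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

val vectors = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(3.0, -4.0))
)).toDF("features")

// p defaults to 2; with p = 1 each vector is divided by its L1 norm,
// so [3.0, -4.0] (L1 norm 7.0) becomes [3.0/7.0, -4.0/7.0].
val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(1.0)

normalizer.transform(vectors).show()
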
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -17,12 +17,12 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
*/
@Experimental
class OneHotEncoder(override val uid: String) extends Transformer
with HasInputCol with HasOutputCol {
with HasInputCol with HasOutputCol with Writable {

def this() = this(Identifiable.randomUID("oneHot"))

@@ -165,4 +165,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
}

override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object OneHotEncoder extends Readable[OneHotEncoder] {

@Since("1.6.0")
override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder]

@Since("1.6.0")
override def load(path: String): OneHotEncoder = read.load(path)
}