Skip to content

Commit 357d82d

Browse files
hhbyyh authored and MLnick committed
[SPARK-13629][ML] Add binary toggle Param to CountVectorizer
## What changes were proposed in this pull request?

It would be handy to add a binary toggle Param to CountVectorizer, as in scikit-learn's CountVectorizer: http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html

If set, all non-zero counts are set to 1.

## How was this patch tested?

Unit tests.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes apache#11536 from hhbyyh/cvToggle.
1 parent 204c9de commit 357d82d

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,27 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
206206
/**
* Sets the [[minTF]] param to `value`; the `this.type` return lets setter calls be chained.
* @group setParam
*/
207207
def setMinTF(value: Double): this.type = set(minTF, value)
208208

209+
/**
210+
* Binary toggle to control the output vector values.
211+
* If true, all nonzero counts (after the minTF filter is applied) are set to 1. This is useful for
212+
* discrete probabilistic models that model binary events rather than integer counts.
213+
*
214+
* Default: false
215+
* @group param
216+
*/
217+
val binary: BooleanParam =
218+
new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
219+
"This is useful for discrete probabilistic models that model binary events rather " +
220+
"than integer counts")
221+
222+
/**
* Returns the current value of the [[binary]] param, read through `$(binary)`.
* @group getParam
*/
223+
def getBinary: Boolean = $(binary)
224+
225+
/**
* Sets the [[binary]] param to `value`; the `this.type` return lets setter calls be chained.
* @group setParam
*/
226+
def setBinary(value: Boolean): this.type = set(binary, value)
227+
228+
// Register `false` as the default so term counts are emitted unchanged unless binary is enabled.
setDefault(binary -> false)
229+
209230
/**
* Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]].
* Starts out as `None`. NOTE(review): presumably filled in by the first transform() call; the
* assignment is not visible in this hunk, so confirm.
*/
210231
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
211232

@@ -232,7 +253,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
232253
} else {
233254
tokenCount * minTf
234255
}
235-
Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq)
256+
val effectiveCounts = if ($(binary)) {
257+
termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
258+
}
259+
else {
260+
termCounts.filter(_._2 >= effectiveMinTF).toSeq
261+
}
262+
Vectors.sparse(dictBr.value.size, effectiveCounts)
236263
}
237264
dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
238265
}

mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
157157
(3, split("e e e e e"), Vectors.sparse(4, Seq())))
158158
).toDF("id", "words", "expected")
159159

160-
// minTF: count
160+
// minTF: set frequency
161161
val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
162162
.setInputCol("words")
163163
.setOutputCol("features")
@@ -168,6 +168,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
168168
}
169169
}
170170

171+
// Checks that a model with setBinary(true) emits 1.0 for every vocabulary term present in a
// document, regardless of how many times the term occurs.
test("CountVectorizerModel with binary") {
172+
// Each row is (id, tokens, expected vector). Rows 0 and 1 repeat tokens ("a a a", "b b",
// "c c c"), yet every expected entry is 1.0 rather than the raw count.
// NOTE(review): `split` is a helper defined elsewhere in this suite; presumably it tokenizes
// the string on whitespace. Confirm against its definition.
val df = sqlContext.createDataFrame(Seq(
173+
(0, split("a a a b b c"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
174+
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
175+
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0))))
176+
)).toDF("id", "words", "expected")
177+
178+
// Four-term vocabulary, matching the size-4 expected vectors. "d" never occurs in the input,
// so index 3 is absent from every expected vector.
val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
179+
.setInputCol("words")
180+
.setOutputCol("features")
181+
.setBinary(true)
182+
// Compare the transformed output with the expected vector row by row, within an absolute
// tolerance of 1e-14.
cv.transform(df).select("features", "expected").collect().foreach {
183+
case Row(features: Vector, expected: Vector) =>
184+
assert(features ~== expected absTol 1e-14)
185+
}
186+
}
187+
171188
test("CountVectorizer read/write") {
172189
val t = new CountVectorizer()
173190
.setInputCol("myInputCol")

0 commit comments

Comments
 (0)