Skip to content

Commit 14e2700

Browse files
yu-iskwmengxr
authored andcommitted
[SPARK-12874][ML] ML StringIndexer does not protect itself from column name duplication
## What changes were proposed in this pull request? ML StringIndexer does not protect itself from column name duplication. We should still improve a way to validate a schema of `StringIndexer` and `StringIndexerModel`. However, it would be great to fix at another issue. ## How was this patch tested? unit test Author: Yu ISHIKAWA <[email protected]> Closes apache#11370 from yu-iskw/SPARK-12874.
1 parent fb8bb04 commit 14e2700

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ class StringIndexerModel (
150150
"Skip StringIndexerModel.")
151151
return dataset
152152
}
153+
validateAndTransformSchema(dataset.schema)
153154

154155
val indexer = udf { label: String =>
155156
if (labelToIndex.contains(label)) {

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,17 @@ class StringIndexerSuite
118118
assert(indexerModel.transform(df).eq(df))
119119
}
120120

121+
test("StringIndexerModel can't overwrite output column") {
122+
val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
123+
val indexer = new StringIndexer()
124+
.setInputCol("input")
125+
.setOutputCol("output")
126+
.fit(df)
127+
intercept[IllegalArgumentException] {
128+
indexer.transform(df)
129+
}
130+
}
131+
121132
test("StringIndexer read/write") {
122133
val t = new StringIndexer()
123134
.setInputCol("myInputCol")

0 commit comments

Comments
 (0)