
Commit c13c10f

[SPARK-53362][ML][CONNECT] Fix IDFModel local loader bug

### What changes were proposed in this pull request?

Fix an IDFModel local loader bug: https://github.com/apache/spark/blob/50a2ebe87f5ca4dbcc732cf3a543d6ebaea856ad/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala#L234

This line of code shouldn't be executed when the model is being loaded from local disk; otherwise, a file format error might be raised.

### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated test in mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52111 from WeichenXu123/SPARK-53362.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
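
As a rough, self-contained sketch of the pattern this fix applies (read Parquet only in the branch whose on-disk payload actually is Parquet, so the local-loader path never touches the Parquet reader), the snippet below uses simplified, hypothetical names (`LegacyOrLocalLoaderSketch`, `loadFromLocalObjectStore`, a trimmed-down `Data`) rather than Spark's real classes; the actual change is in the diff further down.

```scala
import org.apache.spark.sql.{Row, SparkSession}

// Minimal sketch, not the real IDFModel loader: the point is that the
// Parquet read happens only inside the legacy branch, so a local,
// non-Parquet payload is never parsed by spark.read.parquet.
object LegacyOrLocalLoaderSketch {
  final case class Data(idf: Seq[Double])

  // Hypothetical stand-in for ReadWriteUtils.loadObject in the real code.
  def loadFromLocalObjectStore(path: String): Data = Data(Seq.empty)

  def load(spark: SparkSession, dataPath: String, majorVersion: Int): Data = {
    if (majorVersion >= 3) {
      // New format: the payload may live on local disk in a non-Parquet
      // encoding, so it must not go through the Parquet reader.
      loadFromLocalObjectStore(dataPath)
    } else {
      // Legacy format: the data was always written as Parquet, so reading
      // it here is safe.
      val row: Row = spark.read.parquet(dataPath).select("idf").head()
      Data(row.getAs[Seq[Double]](0))
    }
  }
}
```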
1 parent ef9322f · commit c13c10f

3 files changed, +11 −5 lines

mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,6 @@ object IDFModel extends MLReadable[IDFModel] {
     override def load(path: String): IDFModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
 
       val model = if (majorVersion(metadata.sparkVersion) >= 3) {
         val data = ReadWriteUtils.loadObject[Data](dataPath, sparkSession, deserializeData)
@@ -240,6 +239,7 @@ object IDFModel extends MLReadable[IDFModel] {
           new feature.IDFModel(OldVectors.fromML(data.idf), data.docFreq, data.numDocs)
         )
       } else {
+        val data = sparkSession.read.parquet(dataPath)
         val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
           .select("idf")
           .head()

mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala

Lines changed: 7 additions & 4 deletions
@@ -120,9 +120,12 @@ class IDFSuite extends MLTest with DefaultReadWriteTest {
       new OldIDFModel(Vectors.dense(1.0, 2.0), Array(1, 2), 2))
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
-    val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.idf === instance.idf)
-    assert(newInstance.docFreq === instance.docFreq)
-    assert(newInstance.numDocs === instance.numDocs)
+
+    for (testSaveToLocal <- Seq(false, true)) {
+      val newInstance = testDefaultReadWrite(instance, testSaveToLocal = testSaveToLocal)
+      assert(newInstance.idf === instance.idf)
+      assert(newInstance.docFreq === instance.docFreq)
+      assert(newInstance.numDocs === instance.numDocs)
+    }
   }
 }

python/pyspark/ml/util.py

Lines changed: 3 additions & 0 deletions
@@ -339,6 +339,9 @@ def remote_call() -> Any:
                 session.client.execute_command(create_summary_command)  # type: ignore
 
                 return remote_call()
+
+            # for other unexpected error, re-raise it.
+            raise
         else:
             return f(self, name, *args)
 
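
The hunk above only adds a bare `raise` so that errors the exception handler does not recognize propagate instead of being silently swallowed. A minimal sketch of the same guard, kept in Scala to match the other example in this commit (illustrative names only, not the pyspark code):

```scala
// Minimal sketch: handle the one failure you expect and re-raise anything
// else rather than letting it fall through unhandled.
object RethrowSketch {
  def callWithRetry[A](primary: () => A, recover: () => A): A = {
    try {
      primary()
    } catch {
      case _: UnsupportedOperationException =>
        // Expected case: recreate the missing state and retry.
        recover()
      case other: Throwable =>
        // For other unexpected errors, re-raise them.
        throw other
    }
  }
}
```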
