Skip to content

Commit c8693d8

Browse files
committed
update
1 parent 00b91fd commit c8693d8

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ class LDA @Since("1.6.0") (
892892
val instr = Instrumentation.create(this, dataset)
893893
instr.logParams(featuresCol, topicDistributionCol, k, maxIter, subsamplingRate,
894894
checkpointInterval, keepLastCheckpoint, optimizeDocConcentration, topicConcentration,
895-
learningDecay, optimizer, learningOffset, seed)
895+
docConcentration, learningDecay, optimizer, learningOffset, seed)
896896

897897
val oldLDA = new OldLDA()
898898
.setK($(k))
@@ -912,6 +912,7 @@ class LDA @Since("1.6.0") (
912912
new DistributedLDAModel(uid, m.vocabSize, m, dataset.sparkSession, None)
913913
}
914914

915+
instr.logNumFeatures(newModel.vocabSize)
915916
val model = copyValues(newModel).setParent(this)
916917
instr.logSuccess(model)
917918
model

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
229229

230230
val instr = Instrumentation.create(this, dataset)
231231
instr.logParams(labelCol, featuresCol, censorCol, predictionCol, quantilesCol,
232-
fitIntercept, maxIter, tol, aggregationDepth)
232+
quantileProbabilities, fitIntercept, maxIter, tol, aggregationDepth)
233233
instr.logNumFeatures(numFeatures)
234234

235235
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
230230
summaryModel,
231231
model.diagInvAtWA.toArray,
232232
model.objectiveHistory)
233+
234+
lrModel.setSummary(Some(trainingSummary))
233235
instr.logSuccess(lrModel)
234-
return lrModel.setSummary(Some(trainingSummary))
236+
return lrModel
235237
}
236238

237239
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -284,8 +286,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
284286
model,
285287
Array(0D),
286288
Array(0D))
289+
290+
model.setSummary(Some(trainingSummary))
287291
instr.logSuccess(model)
288-
return model.setSummary(Some(trainingSummary))
292+
return model
289293
} else {
290294
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
291295
"Model cannot be regularized.")
@@ -407,8 +411,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
407411
model,
408412
Array(0D),
409413
objectiveHistory)
410-
instr.logSuccess(model)
414+
411415
model.setSummary(Some(trainingSummary))
416+
instr.logSuccess(model)
417+
model
412418
}
413419

414420
@Since("1.4.0")

0 commit comments

Comments
 (0)