@@ -169,7 +169,7 @@ private[spark] object GradientBoostedTrees extends Logging {
169169 * @param loss evaluation metric.
170170 * @return Measure of model error on data
171171 */
172- def computeError (
172+ def computeWeightedError (
173173 data : RDD [Instance ],
174174 trees : Array [DecisionTreeRegressionModel ],
175175 treeWeights : Array [Double ],
@@ -179,7 +179,7 @@ private[spark] object GradientBoostedTrees extends Logging {
179179 updatePrediction(features, acc, model, weight)
180180 }
181181 (loss.computeError(predicted, label) * weight, weight)
182- }.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
182+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
183183 (err1 + err2, weight1 + weight2)
184184 }
185185 errSum / weightSum
@@ -191,13 +191,13 @@ private[spark] object GradientBoostedTrees extends Logging {
191191 * @param predError Prediction and error.
192192 * @return Measure of model error on data
193193 */
194- def computeError (
194+ def computeWeightedError (
195195 data : RDD [Instance ],
196196 predError : RDD [(Double , Double )]): Double = {
197197 val (errSum, weightSum) = data.zip(predError).map {
198198 case (Instance (_, weight, _), (_, err)) =>
199199 (err * weight, weight)
200- }.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
200+ }.treeReduce { case ((err1, weight1), (err2, weight2)) =>
201201 (err1 + err2, weight1 + weight2)
202202 }
203203 errSum / weightSum
@@ -220,24 +220,18 @@ private[spark] object GradientBoostedTrees extends Logging {
220220 treeWeights : Array [Double ],
221221 loss : OldLoss ,
222222 algo : OldAlgo .Value ): Array [Double ] = {
223-
224- val sc = data.sparkContext
225223 val remappedData = algo match {
226224 case OldAlgo .Classification =>
227225 data.map(x => Instance ((x.label * 2 ) - 1 , x.weight, x.features))
228226 case _ => data
229227 }
230228
231- val broadcastTrees = sc.broadcast(trees)
232- val localTreeWeights = treeWeights
233229 val numTrees = trees.length
234-
235230 val (errSum, weightSum) = remappedData.mapPartitions { iter =>
236- val trees = broadcastTrees.value
237231 iter.map { case Instance (label, weight, features) =>
238232 val pred = Array .tabulate(numTrees) { i =>
239233 trees(i).rootNode.predictImpl(features)
240- .prediction * localTreeWeights (i)
234+ .prediction * treeWeights (i)
241235 }
242236 val err = pred.scanLeft(0.0 )(_ + _).drop(1 )
243237 .map(p => loss.computeError(p, label) * weight)
@@ -248,7 +242,6 @@ private[spark] object GradientBoostedTrees extends Logging {
248242 (err1, weight1 + weight2)
249243 }
250244
251- broadcastTrees.destroy()
252245 errSum.map(_ / weightSum)
253246 }
254247
@@ -298,8 +291,10 @@ private[spark] object GradientBoostedTrees extends Logging {
298291 }
299292
300293 // Prepare periodic checkpointers
294+ // Note: this is checkpointing the unweighted training error
301295 val predErrorCheckpointer = new PeriodicRDDCheckpointer [(Double , Double )](
302296 treeStrategy.getCheckpointInterval, input.sparkContext)
297+ // Note: this is checkpointing the unweighted validation error
303298 val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer [(Double , Double )](
304299 treeStrategy.getCheckpointInterval, input.sparkContext)
305300
@@ -319,15 +314,19 @@ private[spark] object GradientBoostedTrees extends Logging {
319314
320315 var predError = computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
321316 predErrorCheckpointer.update(predError)
322- logDebug(" error of gbt = " + computeError (input, predError))
317+ logDebug(" error of gbt = " + computeWeightedError (input, predError))
323318
324319 // Note: A model of type regression is used since we require raw prediction
325320 timer.stop(" building tree 0" )
326321
327322 var validatePredError =
328323 computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
329324 if (validate) validatePredErrorCheckpointer.update(validatePredError)
330- var bestValidateError = if (validate) computeError(validationInput, validatePredError) else 0.0
325+ var bestValidateError = if (validate) {
326+ computeWeightedError(validationInput, validatePredError)
327+ } else {
328+ 0.0
329+ }
331330 var bestM = 1
332331
333332 var m = 1
@@ -356,7 +355,7 @@ private[spark] object GradientBoostedTrees extends Logging {
356355 predError = updatePredictionError(
357356 input, predError, baseLearnerWeights(m), baseLearners(m), loss)
358357 predErrorCheckpointer.update(predError)
359- logDebug(" error of gbt = " + computeError (input, predError))
358+ logDebug(" error of gbt = " + computeWeightedError (input, predError))
360359
361360 if (validate) {
362361 // Stop training early if
@@ -367,7 +366,7 @@ private[spark] object GradientBoostedTrees extends Logging {
367366 validatePredError = updatePredictionError(
368367 validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
369368 validatePredErrorCheckpointer.update(validatePredError)
370- val currentValidateError = computeError (validationInput, validatePredError)
369+ val currentValidateError = computeWeightedError (validationInput, validatePredError)
371370 if (bestValidateError - currentValidateError < validationTol * Math .max(
372371 currentValidateError, 0.01 )) {
373372 doneLearning = true
0 commit comments