2 changes: 1 addition & 1 deletion mllib/pom.xml
@@ -57,7 +57,7 @@
<dependency>
<groupId>org.scalanlp</groupId>
<artifactId>breeze_${scala.binary.version}</artifactId>
<version>0.7</version>
<version>0.8.1</version>
<exclusions>
<!-- This is included as a compile-scoped dependency by jtransforms, which is
a dependency of breeze. -->
mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -36,13 +36,13 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
* @param updater Updater to be used to update weights after every iteration.
*/
@DeveloperApi
class LBFGS(private var gradient: Gradient, private var updater: Updater)
class LBFGS(protected var gradient: Gradient, protected var updater: Updater)
extends Optimizer with Logging {

private var numCorrections = 10
private var convergenceTol = 1E-4
private var maxNumIterations = 100
private var regParam = 0.0
protected var numCorrections = 10
protected var convergenceTol = 1E-4
protected var maxNumIterations = 100
protected var regParam = 0.0

/**
* Set the number of corrections used in the LBFGS update. Default 10.
@@ -185,7 +185,7 @@ object LBFGS extends Logging {
* CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
* at a particular point (weights). It's used in Breeze's convex optimization routines.
*/
private class CostFun(
class CostFun(
data: RDD[(Double, Vector)],
gradient: Gradient,
updater: Updater,
135 changes: 135 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/optimization/OWLQN.scala
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.optimization

import scala.collection.mutable.ArrayBuffer

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{OWLQN => BreezeOWLQN, CachedDiffFunction}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vectors, Vector}

/**
* :: DeveloperApi ::
* Class used to solve an optimization problem with both L1 and L2 regularization.
* Spark is used to compute and aggregate the statistics needed for each OWL-QN step.
* The OWLQN class from the Breeze library performs the orthant projections and stepping.
* Reference: [[http://machinelearning.wustl.edu/mlpapers/paper_files/icml2007_AndrewG07.pdf]]
* @param gradient Gradient function to be used.
*/
@DeveloperApi
class OWLQN(gradient: Gradient)
extends LBFGS(gradient, new SquaredL2Updater) {

// This has to be between 0.0 and 1.0:
// 1.0 means pure L1 regularization, 0.0 means pure L2 regularization.
private var alpha = 0.0

def setAlpha(alpha: Double): this.type = {
this.alpha = alpha
this
}

override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
val (weights, _) = OWLQN.runOWLQN(
data,
gradient,
updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
alpha,
initialWeights)
weights
}
}

/**
* :: DeveloperApi ::
* Top-level method to run OWLQN.
*/
@DeveloperApi
object OWLQN extends Logging {
/**
* Run OWL-QN in parallel.
* The cost function used here is exactly the same as for L-BFGS (which can handle L2
* regularization as well). The only difference is that we use Breeze's OWL-QN instead of its
* L-BFGS, and we let the user specify alpha, which splits the regularization weight between
* L1 and L2.
*
* @param data - Input data for OWLQN. RDD of the set of data examples, each of
* the form (label, [feature values]).
* @param gradient - Gradient object (used to compute the gradient of the loss function of
* one single data example)
* @param updater - Updater function to actually perform a gradient step in a given direction.
* @param numCorrections - The number of corrections used in the OWLQN update.
* @param convergenceTol - The convergence tolerance of iterations for OWLQN.
* @param maxNumIterations - Maximum number of iterations that OWLQN can run.
* @param regParam - Regularization parameter.
* @param alpha - Between 0.0 and 1.0. The L1 weight becomes alpha * regParam and the L2 weight
* becomes (1 - alpha) * regParam.
* @param initialWeights - Initial weights to start the optimization process from.
*
* @return A tuple containing two elements. The first element is a column matrix containing
* weights for every feature, and the second element is an array containing the loss
* computed for every iteration.
*/
def runOWLQN(
data: RDD[(Double, Vector)],
gradient: Gradient,
updater: Updater,
numCorrections: Int,
convergenceTol: Double,
maxNumIterations: Int,
regParam: Double,
alpha: Double,
initialWeights: Vector): (Vector, Array[Double]) = {

val lossHistory = new ArrayBuffer[Double](maxNumIterations)

val numExamples = data.count()

val l1RegParam = alpha * regParam
val l2RegParam = (1.0 - alpha) * regParam

// The cost function doesn't change from LBFGS because Breeze's OWLQN implementation handles the L1 term.
val costFun =
new LBFGS.CostFun(data, gradient, updater, l2RegParam, numExamples)

val owlqn = new BreezeOWLQN[BDV[Double]](maxNumIterations, numCorrections, l1RegParam, convergenceTol)

val states =
owlqn.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)

// Consume Breeze's lazy iterator of optimizer states, recording the loss of each iteration;
// the loss of the final state is appended after the loop.
var state = states.next()
while (states.hasNext) {
lossHistory.append(state.adjustedValue)
state = states.next()
}

lossHistory.append(state.adjustedValue)
val weights = Vectors.fromBreeze(state.x)

logInfo("OWLQN.runMiniBatchOWLQN finished. Last 10 losses %s".format(
lossHistory.takeRight(10).mkString(", ")))

(weights, lossHistory.toArray)
}
}
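
A minimal, hypothetical usage sketch of the new optimizer (not part of this patch): only setAlpha and optimize appear in the diff above, and the example assumes the setRegParam and setMaxNumIterations setters inherited from LBFGS remain accessible on the subclass. It uses a tiny in-memory dataset with a constant 1.0 prepended for the intercept, as in OWLQNSuite below.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization.{LogisticGradient, OWLQN}
import org.apache.spark.rdd.RDD

object OWLQNExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "OWLQNExample")

    // Toy training data: (label, features) pairs with a constant 1.0 for the intercept.
    val data: RDD[(Double, Vector)] = sc.parallelize(Seq(
      (1.0, Vectors.dense(1.0, 0.5)),
      (0.0, Vectors.dense(1.0, -0.3))))

    // With regParam = 0.1 and alpha = 0.5, the L1 weight is 0.5 * 0.1 = 0.05
    // and the L2 weight is (1 - 0.5) * 0.1 = 0.05.
    val owlqn = new OWLQN(new LogisticGradient())
      .setAlpha(0.5)
      .setRegParam(0.1)         // inherited from LBFGS (assumed; not shown in this diff)
      .setMaxNumIterations(50)  // inherited from LBFGS (assumed; not shown in this diff)

    val weights = owlqn.optimize(data, Vectors.dense(0.0, 0.0))
    println("Learned weights: " + weights)

    sc.stop()
  }
}
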
202 changes: 202 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/optimization/OWLQNSuite.scala
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.optimization

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext

class OWLQNSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

val nPoints = 10000
val A = 2.0
val B = -1.5

val initialB = -1.0
val initialWeights = Array(initialB)

val gradient = new LogisticGradient()
val numCorrections = 10
val miniBatchFrac = 1.0

val l1Updater = new L1Updater()
val squaredL2Updater = new SquaredL2Updater()

val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
val data = testData.map { case LabeledPoint(label, features) =>
label -> Vectors.dense(1.0, features.toArray: _*)
}

lazy val dataRDD = sc.parallelize(data, 2).cache()

def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
math.abs(x - y) / (math.abs(y) + 1e-15) < tol
}

test("OWLQN loss should be decreasing and match the result of Gradient Descent.") {
val regParam = 0.3
val alpha = 1.0

val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*)
val convergenceTol = 1e-12
val maxNumIterations = 10

val (weights1, loss) = OWLQN.runOWLQN(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
alpha,
initialWeightsWithIntercept)

// Since the cost function is convex, the loss is guaranteed to be monotonically decreasing
// with the OWLQN optimizer.
assert((loss, loss.tail).zipped.forall(_ > _), "loss should be monotonically decreasing.")

val stepSize = 1.0

// Well, GD converges slower, so it requires more iterations!
val numGDIterations = 100
val (weights2, lossGD) = GradientDescent.runMiniBatchSGD(
dataRDD,
gradient,
l1Updater,
stepSize,
numGDIterations,
regParam,
miniBatchFrac,
initialWeightsWithIntercept)

// GD converges much more slowly than OWLQN. To achieve a 1% difference,
// GD needs 90 iterations. No matter how much we increase the number of
// GD iterations here, lossGD will always be larger than the OWLQN loss.
// This is based on observation and is not theoretically guaranteed.
assert(Math.abs((lossGD.last - loss.last) / loss.last) < 0.02,
"OWLQN should match GD result within 2% difference.")
}

test("OWLQN with L2 regularization should get the same result as LBFGS with L2 regularization.") {
val regParam = 0.2

// Prepare other non-zero weights to compare the loss in the first iteration.
val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
val convergenceTol = 1e-12
val maxNumIterations = 10

val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
initialWeightsWithIntercept)

val (weightOWLQN, lossOWLQN) = OWLQN.runOWLQN(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
0.0,
initialWeightsWithIntercept)

assert(compareDouble(lossOWLQN(0), lossLBFGS(0)),
"The first losses of LBFGS and OWLQN should be the same.")

// OWLQN and LBFGS employ different line searches, so the results might be slightly different.
assert(compareDouble(lossOWLQN.last, lossLBFGS.last, 0.02),
"The last losses of LBFGS and OWLQN should be within 2% difference.")

assert(compareDouble(weightLBFGS(0), weightOWLQN(0), 0.02) &&
compareDouble(weightLBFGS(1), weightOWLQN(1), 0.02),
"The weight differences between LBFGS and OWLQN should be within 2%.")
}

test("The convergence criteria should work as expected.") {
val regParam = 0.01
val alpha = 0.5

/**
* For the first run, we set the convergenceTol to 0.0, so that the algorithm will
* run for the full maxNumIterations, which is 8 here.
*/
val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0)
val maxNumIterations = 8
var convergenceTol = 0.0

val (weights1, lossOWLQN1) = OWLQN.runOWLQN(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
alpha,
initialWeightsWithIntercept)

// Note that the first loss is computed with the initial weights,
// so the total number of losses will be the number of iterations + 1.
assert(lossOWLQN1.length == 9)

convergenceTol = 0.1
val (_, lossOWLQN2) = OWLQN.runOWLQN(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
alpha,
initialWeightsWithIntercept)

// Based on observation, lossOWLQN2 converges in 3 iterations; this is not theoretically guaranteed.
assert(lossOWLQN2.length == 4)
assert((lossOWLQN2(2) - lossOWLQN2(3)) / lossOWLQN2(2) < convergenceTol)

convergenceTol = 0.01
val (_, lossOWLQN3) = OWLQN.runOWLQN(
dataRDD,
gradient,
squaredL2Updater,
numCorrections,
convergenceTol,
maxNumIterations,
regParam,
alpha,
initialWeightsWithIntercept)

// With smaller convergenceTol, it takes more steps.
assert(lossOWLQN3.length > lossOWLQN2.length)

// Based on observation, lossOWLQN3 converges in 6 iterations; this is not theoretically guaranteed.
assert(lossOWLQN3.length == 7)
assert((lossOWLQN3(4) - lossOWLQN3(5)) / lossOWLQN3(4) < convergenceTol)
}
}