Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.optimization

import breeze.linalg.{Vector => BV, DenseVector => BDV}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

/**
 * [[Optimizer]] implementation that hands the whole problem to
 * `OrdinaryLeastSquares.fit` in a single call.
 */
class OrdinaryLeastSquares extends Optimizer with Logging {

  /**
   * Computes least squares weights for `data`.
   *
   * @param data (label, features) pairs
   * @param initialWeights only its size is read (it becomes the `rank` argument of `fit`);
   *                       the values it holds are not used
   * @return the vector returned by `OrdinaryLeastSquares.fit`
   */
  @DeveloperApi
  def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
    val rank = initialWeights.size
    OrdinaryLeastSquares.fit(data, rank)
  }
}

/**
 * Least squares solver based on the singular value decomposition of the feature matrix.
 */
@DeveloperApi
object OrdinaryLeastSquares extends Logging {

  /**
   * Solves the least squares problem for `data` through an SVD of the feature rows.
   *
   * With the decomposition A = U * diag(s) * V^T^ returned by `RowMatrix.computeSVD`, the
   * result is V * z where z(i) = (U^T^ y)(i) / s(i).
   *
   * @param data (label, features) pairs; the features form the rows of A and the labels form y
   * @param rank number of singular values requested from `computeSVD`
   * @return V * z converted to an MLlib [[Vector]]
   */
  def fit(data: RDD[(Double, Vector)], rank: Int): Vector = {
    // TODO: Compute and return other statistics:
    // (R-squared, Adjusted R-squared, Std. Error of weights, t-statistics, p-value)
    val features = data.map { case (_, row) => row }
    // Deliberately not persisted: the labels are read by exactly one action (the reduce
    // below), and a cache taken here would never be released.
    val labels = data.map { case (label, _) => label }

    val svd = new RowMatrix(features).computeSVD(rank, computeU = true)

    // U^T y, accumulated as the sum over rows of (label_i * U_i).
    // NOTE(review): pairing labels with rows of U assumes svd.U.rows keeps the partitioning
    // and row order of `data` -- confirm against RowMatrix.computeSVD.
    val scaledRows = labels.zipPartitions(svd.U.rows, preservesPartitioning = true) {
      (labelIter, uRowIter) =>
        labelIter.zip(uRowIter).map { case (label, uRow) =>
          val scaled: BV[Double] = uRow.toBreeze.map(_ * label)
          scaled
        }
    }
    val projectedLabels = scaledRows.reduce(_ + _)

    // z(i) = (U^T y)(i) / s(i)
    val singularValues = svd.s.toArray
    val z = Array.tabulate(singularValues.length) { i =>
      projectedLabels(i) / singularValues(i)
    }

    val vMat = svd.V.toBreeze
    val solution = vMat * BDV(z)
    Vectors.fromBreeze(solution)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,35 @@ object LinearRegressionWithSGD {
train(input, numIterations, 1.0, 1.0)
}
}

/**
 * Train a linear regression model using an orthogonal decomposition method; the optimizer
 * used is [[OrdinaryLeastSquares]].
 * This solves the least squares regression formulation
 *              f(weights) = 1/n ||A weights-y||^2
 * (which is the mean squared error).
 * Here the data matrix A has n rows, and the input RDD holds the set of rows of A, each with
 * its corresponding right hand side label y.
 * See also the documentation for the precise formulation.
 */
class LinearRegressionWithOLS
  extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {

  /** Optimizer that produces the weights for this algorithm. */
  override val optimizer = new OrdinaryLeastSquares

  /** Wraps the fitted weights and intercept in a [[LinearRegressionModel]]. */
  override protected def createModel(
      weights: Vector,
      intercept: Double): LinearRegressionModel = {
    new LinearRegressionModel(weights, intercept)
  }
}

/**
 * Convenience entry points that build a default [[LinearRegressionWithOLS]] and run it.
 */
object LinearRegressionWithOLS {

  /**
   * Trains a model on `input`, forwarding `initialWeight` to `run`.
   *
   * @param input training data as labeled points
   * @param initialWeight weight vector handed to `run` as the initial weights
   * @return the trained [[LinearRegressionModel]]
   */
  def train(input: RDD[LabeledPoint], initialWeight: Vector): LinearRegressionModel = {
    val algorithm = new LinearRegressionWithOLS()
    algorithm.run(input, initialWeight)
  }

  /**
   * Trains a model on `input` using the single-argument `run`.
   *
   * @param input training data as labeled points
   * @return the trained [[LinearRegressionModel]]
   */
  def train(input: RDD[LabeledPoint]): LinearRegressionModel = {
    val algorithm = new LinearRegressionWithOLS()
    algorithm.run(input)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,31 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Test if we can correctly fit Y = 3 + 10*X1 + 10*X2
test("OLS linear regression") {
  // Training points generated for intercept 3.0 and weights (10.0, 10.0).
  val trainingPoints = LinearDataGenerator.generateLinearInput(
    3.0, Array(10.0, 10.0), 100, 42)
  val trainingRDD = sc.parallelize(trainingPoints, 2).cache()

  val algorithm = new LinearRegressionWithOLS().setIntercept(true)
  val model = algorithm.run(trainingRDD)

  // The fitted intercept must land within 0.5 of 3.0.
  val fittedIntercept = model.intercept
  assert(fittedIntercept >= 2.5 && fittedIntercept <= 3.5)

  // Both fitted weights must land within 1.0 of 10.0.
  val fittedWeights = model.weights
  assert(fittedWeights.size === 2)
  Seq(0, 1).foreach { i =>
    assert(fittedWeights(i) >= 9.0 && fittedWeights(i) <= 11.0)
  }

  val validationData = LinearDataGenerator.generateLinearInput(
    3.0, Array(10.0, 10.0), 100, 17)
  val validationRDD = sc.parallelize(validationData, 2).cache()

  // Test prediction on RDD.
  val rddPredictions = model.predict(validationRDD.map(_.features)).collect()
  validatePrediction(rddPredictions, validationData)

  // Test prediction on Array.
  val localPredictions = validationData.map(point => model.predict(point.features))
  validatePrediction(localPredictions, validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X2
test("linear regression without intercept") {
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
Expand Down Expand Up @@ -86,6 +111,32 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X2
test("OLS linear regression without intercept") {
  // Training points generated for intercept 0.0 and weights (10.0, 10.0).
  val trainingPoints = LinearDataGenerator.generateLinearInput(
    0.0, Array(10.0, 10.0), 100, 42)
  val trainingRDD = sc.parallelize(trainingPoints, 2).cache()

  val algorithm = new LinearRegressionWithOLS().setIntercept(false)
  val model = algorithm.run(trainingRDD)

  // With the intercept disabled the model must report exactly 0.0.
  assert(model.intercept === 0.0)

  // Both fitted weights must land within 1.0 of 10.0.
  val fittedWeights = model.weights
  assert(fittedWeights.size === 2)
  Seq(0, 1).foreach { i =>
    assert(fittedWeights(i) >= 9.0 && fittedWeights(i) <= 11.0)
  }

  val validationData = LinearDataGenerator.generateLinearInput(
    0.0, Array(10.0, 10.0), 100, 17)
  val validationRDD = sc.parallelize(validationData, 2).cache()

  // Test prediction on RDD.
  val rddPredictions = model.predict(validationRDD.map(_.features)).collect()
  validatePrediction(rddPredictions, validationData)

  // Test prediction on Array.
  val localPredictions = validationData.map(point => model.predict(point.features))
  validatePrediction(localPredictions, validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X10000
test("sparse linear regression without intercept") {
val denseRDD = sc.parallelize(
Expand Down