Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
112 commits
Select commit Hold shift + click to select a range
e6e483c
[SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover
holdenk Sep 1, 2015
3f63bd6
[SPARK-10398] [DOCS] Migrate Spark download page to use new lua mirro…
srowen Sep 1, 2015
ec01280
[SPARK-4223] [CORE] Support * in acls.
Sep 1, 2015
bf550a4
[SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe f…
0x0FFF Sep 1, 2015
00d9af5
[SPARK-10392] [SQL] Pyspark - Wrong DateType support on JDBC connection
0x0FFF Sep 1, 2015
c3b881a
[SPARK-7336] [HISTORYSERVER] Fix bug that applications status incorre…
ArcherShao Sep 2, 2015
56c4c17
[SPARK-10034] [SQL] add regression test for Sort on Aggregate
cloud-fan Sep 2, 2015
fc48307
[SPARK-10389] [SQL] support order by non-attribute grouping expressio…
cloud-fan Sep 2, 2015
2da3a9e
[SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle…
Sep 2, 2015
6cd98c1
[SPARK-10417] [SQL] Iterating through Column results in infinite loop
0x0FFF Sep 2, 2015
03f3e91
[SPARK-10422] [SQL] String column in InMemoryColumnarCache needs to o…
yhuai Sep 3, 2015
44948a2
[SPARK-9723] [ML] params getordefault should throw more useful error
holdenk Sep 3, 2015
4bd85d0
[SPARK-5945] Spark should not retry a stage infinitely on a FetchFail…
Sep 3, 2015
0985d2c
[SPARK-8707] RDD#toDebugString fails if any cached RDD has invalid pa…
navis Sep 3, 2015
f6c447f
Removed code duplication in ShuffleBlockFetcherIterator
eracah Sep 3, 2015
3ddb9b3
[SPARK-10247] [CORE] improve readability of a test case in DAGSchedul…
squito Sep 3, 2015
62b4690
[SPARK-10379] preserve first page in UnsafeShuffleExternalSorter
Sep 3, 2015
0349b5b
[SPARK-10411] [SQL] Move visualization above explain output and hide …
zsxwing Sep 3, 2015
67580f1
[SPARK-10332] [CORE] Fix yarn spark executor validation
holdenk Sep 3, 2015
3abc0d5
[SPARK-9596] [SQL] treat hadoop classes as shared one in IsolatedClie…
WangTaoTheTonic Sep 3, 2015
af0e312
[SPARK-8951] [SPARKR] support Unicode characters in collect()
Sep 3, 2015
49aff7b
[SPARK-10432] spark.port.maxRetries documentation is unclear
Sep 3, 2015
d911c68
[SPARK-10431] [CORE] Fix intermittent test failure. Wait for event qu…
Sep 3, 2015
754f853
[SPARK-9869] [STREAMING] Wait for all event notifications before asse…
Sep 3, 2015
e62f4a4
[SPARK-9672] [MESOS] Don’t include SPARK_ENV_LOADED when passing env …
pashields Sep 3, 2015
11ef32c
[SPARK-10430] [CORE] Added hashCode methods in AccumulableInfo and RD…
Sep 3, 2015
db4c130
[SPARK-9591] [CORE] Job may fail for exception during getting remote …
jeanlyn Sep 3, 2015
08b0750
[SPARK-10435] Spark submit should fail fast for Mesos cluster mode wi…
Sep 3, 2015
208fbca
[SPARK-10421] [BUILD] Exclude curator artifacts from tachyon dependen…
Sep 3, 2015
cf42138
[SPARK-10003] Improve readability of DAGScheduler
Sep 4, 2015
143e521
[MINOR] Minor style fix in SparkR
shivaram Sep 4, 2015
804a012
MAINTENANCE: Automated closing of pull requests.
marmbrus Sep 4, 2015
c3c0e43
[SPARK-10176] [SQL] Show partially analyzed plans when checkAnswer fa…
cloud-fan Sep 4, 2015
3339e6f
[SPARK-10450] [SQL] Minor improvements to readability / style / typos…
Sep 4, 2015
b087d23
[SPARK-9669] [MESOS] Support PySpark on Mesos cluster mode.
tnachen Sep 4, 2015
2e1c175
[SPARK-10454] [SPARK CORE] wait for empty event queue
Sep 4, 2015
eafe372
[SPARK-10311] [STREAMING] Reload appId and attemptId when app starts …
XuTingjun Sep 4, 2015
22eab70
[SPARK-10402] [DOCS] [ML] Add defaults to the scaladoc for params in ml/
holdenk Sep 5, 2015
47058ca
[SPARK-9925] [SQL] [TESTS] Set SQLConf.SHUFFLE_PARTITIONS.key correct…
yhuai Sep 5, 2015
6c75194
[HOTFIX] [SQL] Fixes compilation error
liancheng Sep 5, 2015
7a4f326
[SPARK-10440] [STREAMING] [DOCS] Update python API stuff in the progr…
tdas Sep 5, 2015
bca8c07
[SPARK-10434] [SQL] Fixes Parquet schema of arrays that may contain null
liancheng Sep 5, 2015
871764c
[SPARK-10013] [ML] [JAVA] [TEST] remove java assert from java unit tests
holdenk Sep 5, 2015
5ffe752
[SPARK-9767] Remove ConnectionManager.
rxin Sep 7, 2015
9d8e838
[DOC] Added R to the list of languages with "high-level API" support …
Sep 8, 2015
6ceed85
Docs small fixes
jaceklaskowski Sep 8, 2015
990c9f7
[SPARK-9170] [SQL] Use OrcStructInspector to be case preserving when …
viirya Sep 8, 2015
5b2192e
[SPARK-10480] [ML] Fix ML.LinearRegressionModel.copy()
yanboliang Sep 8, 2015
5fd5795
[SPARK-10316] [SQL] respect nondeterministic expressions in PhysicalO…
cloud-fan Sep 8, 2015
f7b55db
[SPARK-10470] [ML] ml.IsotonicRegressionModel.copy should set parent
yanboliang Sep 8, 2015
7a9dcbc
[SPARK-10441] [SQL] Save data correctly to json.
yhuai Sep 8, 2015
e6f8d36
[SPARK-10468] [ MLLIB ] Verify schema before Dataframe select API call
Sep 8, 2015
52b24a6
[SPARK-10492] [STREAMING] [DOCUMENTATION] Update Streaming documentat…
tdas Sep 8, 2015
d637a66
[SPARK-10327] [SQL] Cache Table is not working while subquery has ali…
chenghao-intel Sep 8, 2015
2143d59
[HOTFIX] Fix build break caused by #8494
marmbrus Sep 8, 2015
ae74c3f
[RELEASE] Add more contributors & only show names in release notes.
rxin Sep 9, 2015
820913f
[SPARK-10071] [STREAMING] Output a warning when writing QueueInputDSt…
zsxwing Sep 9, 2015
52fe32f
[SPARK-9834] [MLLIB] implement weighted least squares via normal equa…
mengxr Sep 9, 2015
a157348
[SPARK-10464] [MLLIB] Add WeibullGenerator for RandomDataGenerator
yanboliang Sep 9, 2015
3a11e50
[SPARK-10373] [PYSPARK] move @since into pyspark from sql
Sep 9, 2015
0e2f216
[SPARK-10094] Pyspark ML Feature transformers marked as experimental
noel-smith Sep 9, 2015
2f6fd52
[SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark
holdenk Sep 9, 2015
91a577d
[SPARK-10249] [ML] [DOC] Add Python Code Example to StopWordsRemover …
hhbyyh Sep 9, 2015
c1bc4f4
[SPARK-10227] fatal warnings with sbt on Scala 2.11
Sep 9, 2015
2ddeb63
[SPARK-10117] [MLLIB] Implement SQL data source API for reading LIBSV…
Lewuathe Sep 9, 2015
c0052d8
[SPARK-10481] [YARN] SPARK_PREPEND_CLASSES make spark-yarn related ja…
zjffdu Sep 9, 2015
71da163
[SPARK-10461] [SQL] make sure `input.primitive` is always variable na…
cloud-fan Sep 9, 2015
45de518
[SPARK-9730] [SQL] Add Full Outer Join support for SortMergeJoin
viirya Sep 9, 2015
56a0fe5
[SPARK-9772] [PYSPARK] [ML] Add Python API for ml.feature.VectorSlicer
yanboliang Sep 10, 2015
1dc7548
[MINOR] [MLLIB] [ML] [DOC] fixed typo: label for negative result shou…
sparadiso Sep 10, 2015
48817cc
[SPARK-10497] [BUILD] [TRIVIAL] Handle both locations for JIRAError w…
holdenk Sep 10, 2015
4f1daa1
[SPARK-10065] [SQL] avoid the extra copy when generate unsafe array
cloud-fan Sep 10, 2015
f892d92
[SPARK-7142] [SQL] Minor enhancement to BooleanSimplification Optimiz…
Sep 10, 2015
49da38e
[SPARK-10301] [SPARK-10428] [SQL] Addresses comments of PR #8583 and …
liancheng Sep 10, 2015
e048111
[SPARK-10466] [SQL] UnsafeRow SerDe exception with data spill
chenghao-intel Sep 10, 2015
a76bde9
[SPARK-10469] [DOC] Try and document the three options
holdenk Sep 10, 2015
af3bc59
[SPARK-8167] Make tasks that fail from YARN preemption not fail job
mccheah Sep 10, 2015
f0562e8
[SPARK-6350] [MESOS] Fine-grained mode scheduler respects mesosExecut…
dragos Sep 10, 2015
a5ef2d0
[SPARK-10514] [MESOS] waiting for min no of total cores acquired by S…
SleepyThread Sep 10, 2015
d88abb7
[SPARK-9990] [SQL] Create local hash join operator
zsxwing Sep 10, 2015
45e3be5
[SPARK-10049] [SPARKR] Support collecting data of ArraryType in DataF…
Sep 10, 2015
3db7255
[SPARK-10443] [SQL] Refactor SortMergeOuterJoin to reduce duplication
Sep 10, 2015
4204757
Add 1.5 to master branch EC2 scripts
shivaram Sep 10, 2015
89562a1
[SPARK-7544] [SQL] [PySpark] pyspark.sql.types.Row implements __getit…
yanboliang Sep 10, 2015
0eabea8
[SPARK-9043] Serialize key, value and combiner classes in ShuffleDepe…
massie Sep 11, 2015
339a527
[SPARK-10023] [ML] [PySpark] Unified DecisionTreeParams checkpointInt…
yanboliang Sep 11, 2015
a140dd7
[SPARK-10027] [ML] [PySpark] Add Python API missing methods for ml.fe…
yanboliang Sep 11, 2015
e1d7f64
[SPARK-10472] [SQL] Fixes DataType.typeName for UDT
liancheng Sep 11, 2015
9bbe33f
[SPARK-10556] Remove explicit Scala version for sbt project build files
ahirreddy Sep 11, 2015
c268ca4
[SPARK-10518] [DOCS] Update code examples in spark.ml user guide to u…
y-shimizu Sep 11, 2015
b656e61
[SPARK-10026] [ML] [PySpark] Implement some common Params for regress…
yanboliang Sep 11, 2015
b01b262
[SPARK-9773] [ML] [PySpark] Add Python API for MultilayerPerceptronCl…
yanboliang Sep 11, 2015
960d2d0
[SPARK-10537] [ML] document LIBSVM source options in public API doc a…
mengxr Sep 11, 2015
2e3a280
[MINOR] [MLLIB] [ML] [DOC] Minor doc fixes for StringIndexer and Meta…
jkbradley Sep 11, 2015
6ce0886
[SPARK-10540] [SQL] Ignore HadoopFsRelationTest's "test all data type…
yhuai Sep 11, 2015
5f46444
[SPARK-8530] [ML] add python API for MinMaxScaler
hhbyyh Sep 11, 2015
b231ab8
[SPARK-10546] Check partitionId's range in ExternalSorter#spill()
tedyu Sep 11, 2015
c373866
[PYTHON] Fixed typo in exception message
icaromedeiros Sep 11, 2015
d5d6473
[SPARK-10442] [SQL] fix string to boolean cast
cloud-fan Sep 11, 2015
1eede3b
[SPARK-7142] [SQL] Minor enhancement to BooleanSimplification Optimiz…
Sep 11, 2015
e626ac5
[SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK…
zsxwing Sep 11, 2015
c2af42b
[SPARK-9990] [SQL] Local hash join follow-ups
Sep 11, 2015
d74c6a1
[SPARK-10564] ThreadingSuite: assertion failures in threads don't fai…
Sep 11, 2015
c34fc19
[SPARK-9014] [SQL] Allow Python spark API to use built-in exponential…
0x0FFF Sep 11, 2015
6d83678
[SPARK-10566] [CORE] SnappyCompressionCodec init exception handling m…
dimfeld Sep 12, 2015
8285e3b
[SPARK-10554] [CORE] Fix NPE with ShutdownHook
Sep 12, 2015
22730ad
[SPARK-10547] [TEST] Streamline / improve style of Java API tests
srowen Sep 12, 2015
f4a2280
[SPARK-6548] Adding stddev to DataFrame functions
JihongMA Sep 12, 2015
b3a7480
[SPARK-10330] Add Scalastyle rule to require use of SparkHadoopUtil J…
JoshRosen Sep 12, 2015
1dc614b
[SPARK-10222] [GRAPHX] [DOCS] More thoroughly deprecate Bagel in favo…
srowen Sep 13, 2015
7d94924
Deprecates SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC
liancheng Sep 1, 2015
85bbfde
Removes instead of deprecates the old option
liancheng Sep 2, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[SPARK-9834] [MLLIB] implement weighted least squares via normal equa…
…tion

The goal of this PR is to have a weighted least squares implementation that takes the normal equation approach, and hence to be able to provide R-like summary statistics and support IRLS (used by GLMs). The tests match R's lm and glmnet.

There are couple TODOs that can be addressed in future PRs:
* consolidate summary statistics aggregators
* move `dspr` to `BLAS`
* etc

It would be nice to have this merged first because it blocks couple other features.

dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes apache#8588 from mengxr/SPARK-9834.
  • Loading branch information
mengxr committed Sep 9, 2015
commit 52fe32f6ac7a04fa9b4478fda1307c5b0c61c4a2
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.optim

import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW

import org.apache.spark.Logging
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.rdd.RDD

/**
* Model fitted by [[WeightedLeastSquares]].
* @param coefficients model coefficients
* @param intercept model intercept
*/
private[ml] class WeightedLeastSquaresModel(
val coefficients: DenseVector,
val intercept: Double) extends Serializable

/**
* Weighted least squares solver via normal equation.
* Given weighted observations (w,,i,,, a,,i,,, b,,i,,), we use the following weighted least squares
* formulation:
*
* min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w_i
* + 1/2 lambda / delta sum,,j,, (sigma,,j,, x,,j,,)^2^,
*
* where lambda is the regularization parameter, and delta and sigma,,j,, are controlled by
* [[standardizeLabel]] and [[standardizeFeatures]], respectively.
*
* Set [[regParam]] to 0.0 and turn off both [[standardizeFeatures]] and [[standardizeLabel]] to
* match R's `lm`.
* Turn on [[standardizeLabel]] to match R's `glmnet`.
*
* @param fitIntercept whether to fit intercept. If false, z is 0.0.
* @param regParam L2 regularization parameter (lambda)
* @param standardizeFeatures whether to standardize features. If true, sigma_,,j,, is the
* population standard deviation of the j-th column of A. Otherwise,
* sigma,,j,, is 1.0.
* @param standardizeLabel whether to standardize label. If true, delta is the population standard
* deviation of the label column b. Otherwise, delta is 1.0.
*/
private[ml] class WeightedLeastSquares(
val fitIntercept: Boolean,
val regParam: Double,
val standardizeFeatures: Boolean,
val standardizeLabel: Boolean) extends Logging with Serializable {
import WeightedLeastSquares._

require(regParam >= 0.0, s"regParam cannot be negative: $regParam")
if (regParam == 0.0) {
logWarning("regParam is zero, which might cause numerical instability and overfitting.")
}

/**
* Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.
*/
def fit(instances: RDD[Instance]): WeightedLeastSquaresModel = {
val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))
summary.validate()
logInfo(s"Number of instances: ${summary.count}.")
val triK = summary.triK
val bBar = summary.bBar
val bStd = summary.bStd
val aBar = summary.aBar
val aVar = summary.aVar
val abBar = summary.abBar
val aaBar = summary.aaBar
val aaValues = aaBar.values

if (fitIntercept) {
// shift centers
// A^T A - aBar aBar^T
RowMatrix.dspr(-1.0, aBar, aaValues)
// A^T b - bBar aBar
BLAS.axpy(-bBar, aBar, abBar)
}

// add regularization to diagonals
var i = 0
var j = 2
while (i < triK) {
var lambda = regParam
if (standardizeFeatures) {
lambda *= aVar(j - 2)
}
if (standardizeLabel) {
// TODO: handle the case when bStd = 0
lambda /= bStd
}
aaValues(i) += lambda
i += j
j += 1
}

val x = choleskySolve(aaBar.values, abBar)

// compute intercept
val intercept = if (fitIntercept) {
bBar - BLAS.dot(aBar, x)
} else {
0.0
}

new WeightedLeastSquaresModel(x, intercept)
}

/**
* Solves a symmetric positive definite linear system via Cholesky factorization.
* The input arguments are modified in-place to store the factorization and the solution.
* @param A the upper triangular part of A
* @param bx right-hand side
* @return the solution vector
*/
// TODO: SPARK-10490 - consolidate this and the Cholesky solver in ALS
private def choleskySolve(A: Array[Double], bx: DenseVector): DenseVector = {
val k = bx.size
val info = new intW(0)
lapack.dppsv("U", k, 1, A, bx.values, k, info)
val code = info.`val`
assert(code == 0, s"lapack.dpotrs returned $code.")
bx
}
}

private[ml] object WeightedLeastSquares {

/**
* Case class for weighted observations.
* @param w weight, must be positive
* @param a features
* @param b label
*/
case class Instance(w: Double, a: Vector, b: Double) {
require(w >= 0.0, s"Weight cannot be negative: $w.")
}

/**
* Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
*/
// TODO: consolidate aggregates for summary statistics
private class Aggregator extends Serializable {
var initialized: Boolean = false
var k: Int = _
var count: Long = _
var triK: Int = _
private var wSum: Double = _
private var wwSum: Double = _
private var bSum: Double = _
private var bbSum: Double = _
private var aSum: DenseVector = _
private var abSum: DenseVector = _
private var aaSum: DenseVector = _

private def init(k: Int): Unit = {
require(k <= 4096, "In order to take the normal equation approach efficiently, " +
s"we set the max number of features to 4096 but got $k.")
this.k = k
triK = k * (k + 1) / 2
count = 0L
wSum = 0.0
wwSum = 0.0
bSum = 0.0
bbSum = 0.0
aSum = new DenseVector(Array.ofDim(k))
abSum = new DenseVector(Array.ofDim(k))
aaSum = new DenseVector(Array.ofDim(triK))
initialized = true
}

/**
* Adds an instance.
*/
def add(instance: Instance): this.type = {
val Instance(w, a, b) = instance
val ak = a.size
if (!initialized) {
init(ak)
initialized = true
}
assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
count += 1L
wSum += w
wwSum += w * w
bSum += w * b
bbSum += w * b * b
BLAS.axpy(w, a, aSum)
BLAS.axpy(w * b, a, abSum)
RowMatrix.dspr(w, a, aaSum.values)
this
}

/**
* Merges another [[Aggregator]].
*/
def merge(other: Aggregator): this.type = {
if (!other.initialized) {
this
} else {
if (!initialized) {
init(other.k)
}
assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
count += other.count
wSum += other.wSum
wwSum += other.wwSum
bSum += other.bSum
bbSum += other.bbSum
BLAS.axpy(1.0, other.aSum, aSum)
BLAS.axpy(1.0, other.abSum, abSum)
BLAS.axpy(1.0, other.aaSum, aaSum)
this
}
}

/**
* Validates that we have seen observations.
*/
def validate(): Unit = {
assert(initialized, "Training dataset is empty.")
assert(wSum > 0.0, "Sum of weights cannot be zero.")
}

/**
* Weighted mean of features.
*/
def aBar: DenseVector = {
val output = aSum.copy
BLAS.scal(1.0 / wSum, output)
output
}

/**
* Weighted mean of labels.
*/
def bBar: Double = bSum / wSum

/**
* Weighted population standard deviation of labels.
*/
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)

/**
* Weighted mean of (label * features).
*/
def abBar: DenseVector = {
val output = abSum.copy
BLAS.scal(1.0 / wSum, output)
output
}

/**
* Weighted mean of (features * features^T^).
*/
def aaBar: DenseVector = {
val output = aaSum.copy
BLAS.scal(1.0 / wSum, output)
output
}

/**
* Weighted population variance of features.
*/
def aVar: DenseVector = {
val variance = Array.ofDim[Double](k)
var i = 0
var j = 2
val aaValues = aaSum.values
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
variance(l) = aaValues(i) / wSum - aw * aw
i += j
j += 1
}
new DenseVector(variance)
}
}
}
7 changes: 7 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ private[spark] object BLAS extends Serializable with Logging {
}
}

/** Y += a * x */
private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
f2jBLAS.daxpy(X.numRows * X.numCols, a, X.values, 1, Y.values, 1)
}

/**
* dot(x, y)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,8 @@ object RowMatrix {
*
* @param U the upper triangular part of the matrix packed in an array (column major)
*/
private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
// TODO: SPARK-10491 - move this method to linalg.BLAS
private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
// TODO: Find a better home (breeze?) for this method.
val n = v.size
v match {
Expand Down
Loading