
Commit ef437cb

Update SVM.scala
Provide an interface in object SVMWithSGD to set useFeatureScaling.
Parent commit: 249d36a


mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala

Lines changed: 31 additions & 1 deletion
@@ -121,7 +121,7 @@ object SVMModel extends Loader[SVMModel] {
  * regularization is used, which can be changed via [[SVMWithSGD.optimizer]].
  * NOTE: Labels used in SVM should be {0, 1}.
  */
-class SVMWithSGD (
+class SVMWithSGD private (
     private var stepSize: Double,
     private var numIterations: Int,
     private var regParam: Double,
@@ -152,6 +152,36 @@ class SVMWithSGD (
  * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}.
  */
 object SVMWithSGD {
+
+  /**
+   * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
+   * of iterations of gradient descent using the specified step size. Each iteration uses
+   * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
+   * gradient descent are initialized using the initial weights provided.
+   *
+   * NOTE: Labels used in SVM should be {0, 1}.
+   *
+   * @param input RDD of (label, array of features) pairs.
+   * @param numIterations Number of iterations of gradient descent to run.
+   * @param stepSize Step size to be used for each iteration of gradient descent.
+   * @param regParam Regularization parameter.
+   * @param miniBatchFraction Fraction of data to be used per iteration.
+   * @param initialWeights Initial set of weights to be used. Array should be equal in size to
+   *        the number of features in the data.
+   * @param useFeatureScaling Whether the algorithm should use feature scaling to improve convergence during optimization.
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      numIterations: Int,
+      stepSize: Double,
+      regParam: Double,
+      miniBatchFraction: Double,
+      initialWeights: Vector,
+      useFeatureScaling: Boolean): SVMModel = {
+    new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).setFeatureScaling(useFeatureScaling)
+      .run(input, initialWeights)
+  }
+
 
   /**
    * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
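As a quick illustration of the new overload (not part of the commit), here is a minimal usage sketch; the SparkContext `sc`, the LIBSVM-formatted data path, and the all-zero initial weights are assumptions made for the example.

    import org.apache.spark.mllib.classification.SVMWithSGD
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.util.MLUtils

    // Load training data as RDD[LabeledPoint]; labels must be {0, 1}.
    val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    val numFeatures = training.first().features.size

    // Call the new overload, passing the added useFeatureScaling flag last.
    val model = SVMWithSGD.train(
      training,
      100,                                            // numIterations
      1.0,                                            // stepSize
      0.01,                                           // regParam
      1.0,                                            // miniBatchFraction
      Vectors.dense(new Array[Double](numFeatures)),  // initialWeights (all zeros)
      useFeatureScaling = true)                       // enable feature scaling before optimization

The existing train overloads are untouched by this commit, so callers that omit the flag keep the previous behaviour.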
