@@ -121,7 +121,7 @@ object SVMModel extends Loader[SVMModel] {
121121 * regularization is used, which can be changed via [[SVMWithSGD.optimizer ]].
122122 * NOTE: Labels used in SVM should be {0, 1}.
123123 */
124- class SVMWithSGD (
124+ class SVMWithSGD private (
125125 private var stepSize : Double ,
126126 private var numIterations : Int ,
127127 private var regParam : Double ,
@@ -152,6 +152,36 @@ class SVMWithSGD (
152152 * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}.
153153 */
154154object SVMWithSGD {
155+
156+ /**
157+ * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
158+ * of iterations of gradient descent using the specified step size. Each iteration uses
159+ * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
160+ * gradient descent are initialized using the initial weights provided.
161+ *
162+ * NOTE: Labels used in SVM should be {0, 1}.
163+ *
164+ * @param input RDD of (label, array of features) pairs.
165+ * @param numIterations Number of iterations of gradient descent to run.
166+ * @param stepSize Step size to be used for each iteration of gradient descent.
167+ * @param regParam Regularization parameter.
168+ * @param miniBatchFraction Fraction of data to be used per iteration.
169+ * @param initialWeights Initial set of weights to be used. Array should be equal in size to
170+ * the number of features in the data.
171+ * @param useFeatureScaling Set if the algorithm should use feature scaling to improve the convergence during optimization.
172+ */
173+ def train (
174+ input : RDD [LabeledPoint ],
175+ numIterations : Int ,
176+ stepSize : Double ,
177+ regParam : Double ,
178+ miniBatchFraction : Double ,
179+ initialWeights : Vector ,
180+ useFeatureScaling : Boolean ): SVMModel = {
181+ new SVMWithSGD (stepSize, numIterations, regParam, miniBatchFraction).setFeatureScaling(useFeatureScaling)
182+ .run(input, initialWeights)
183+ }
184+
155185
156186 /**
157187 * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
0 commit comments