
Commit c6ab716

mob-ai (zhanjf) authored and srowen committed
[SPARK-29224][ML] Implement Factorization Machines as a ml-pipeline component
### What changes were proposed in this pull request?
Implement Factorization Machines as an ml-pipeline component.
1. loss functions supported: logloss, mse
2. optimizers: GD, adamW

### Why are the changes needed?
Factorization Machines are widely used in advertising and recommender systems to estimate CTR (click-through rate). These systems usually have a lot of data, so we need Spark to estimate the CTR, and Factorization Machines are a common ML model for estimating CTR.

References:
1. S. Rendle, "Factorization machines," in Proceedings of the IEEE International Conference on Data Mining (ICDM), pp. 995-1000, 2010. https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Ran unit tests.

Closes apache#26124 from mob-ai/ml/fm.

Authored-by: zhanjf <zhanjf@mob.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
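The summary above names two optimizers, plain gradient descent and AdamW. As background only (this is a generic NumPy sketch of the AdamW update rule, not Spark's implementation), a single AdamW step combines Adam's bias-corrected moment estimates with decoupled weight decay:

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999,
               eps=1e-8, weight_decay=1e-2):
    # One AdamW update on parameter vector theta, given gradient grad,
    # running moments (m, v), and 1-based step counter t.
    m = b1 * m + (1 - b1) * grad            # first-moment estimate
    v = b2 * v + (1 - b2) * grad ** 2       # second-moment estimate
    m_hat = m / (1 - b1 ** t)               # bias correction
    v_hat = v / (1 - b2 ** t)
    # Weight decay is applied directly to theta, decoupled from the gradient.
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v
```

The decoupled decay term (`weight_decay * theta` outside the adaptive scaling) is what distinguishes AdamW from plain Adam with L2 regularization.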
1 parent 640dcc4 commit c6ab716

File tree

13 files changed (+2590 −2 lines)


docs/ml-classification-regression.md

Lines changed: 107 additions & 0 deletions
@@ -530,6 +530,42 @@ Refer to the [R API docs](api/R/spark.naiveBayes.html) for more details.
</div>


## Factorization machines classifier

For more background and more details about the implementation of factorization machines,
refer to the [Factorization Machines section](ml-classification-regression.html#factorization-machines).

**Examples**

The following examples load a dataset in LibSVM format, split it into training and test sets,
train on the training set, and then evaluate on the held-out test set.
We scale features to be between 0 and 1 to prevent the exploding gradient problem.

<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.FMClassifier) for more details.

{% include_example scala/org/apache/spark/examples/ml/FMClassifierExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/FMClassifier.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaFMClassifierExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.FMClassifier) for more details.

{% include_example python/ml/fm_classifier_example.py %}
</div>

</div>

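The feature scaling mentioned in the examples can be illustrated outside Spark. A minimal NumPy sketch of min-max rescaling to [0, 1] (illustrative only; the convention of mapping constant features to 0.5 is an assumption here, not a claim about Spark's `MinMaxScaler`):

```python
import numpy as np

def fit_min_max(X):
    # Learn per-feature minima and maxima from the training data.
    return X.min(axis=0), X.max(axis=0)

def transform_min_max(X, mins, maxs):
    # Rescale each feature linearly into [0, 1]. Constant features
    # (max == min) are mapped to 0.5 here, one reasonable convention.
    span = maxs - mins
    safe = np.where(span == 0, 1.0, span)
    return np.where(span == 0, 0.5, (X - mins) / safe)
```

Fitting on the training data and reusing the learned `mins`/`maxs` on the test set avoids leaking test statistics into training.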
# Regression
534570

535571
## Linear regression
@@ -1015,6 +1051,43 @@ Refer to the [`IsotonicRegression` R API docs](api/R/spark.isoreg.html) for more

</div>

## Factorization machines regressor

For more background and more details about the implementation of factorization machines,
refer to the [Factorization Machines section](ml-classification-regression.html#factorization-machines).

**Examples**

The following examples load a dataset in LibSVM format, split it into training and test sets,
train on the training set, and then evaluate on the held-out test set.
We scale features to be between 0 and 1 to prevent the exploding gradient problem.

<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.FMRegressor) for more details.

{% include_example scala/org/apache/spark/examples/ml/FMRegressorExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/FMRegressor.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaFMRegressorExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.FMRegressor) for more details.

{% include_example python/ml/fm_regressor_example.py %}
</div>

</div>

# Linear methods
10191092

10201093
We implement popular linear methods such as logistic
@@ -1044,6 +1117,40 @@ regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model.
We implement Pipelines API for both linear regression and logistic
regression with elastic net regularization.

# Factorization Machines

[Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf) are able to estimate interactions
between features even in problems with huge sparsity (like advertising and recommender systems).
The `spark.ml` implementation supports factorization machines for binary classification and for regression.

The factorization machines model equation is:

$$
\hat{y} = w_0 + \sum\limits^n_{i=1} w_i x_i +
  \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j
$$

The first two terms denote the intercept and the linear term (same as in linear regression),
and the last term denotes the pairwise interactions term. $$v_i$$ describes the i-th variable
with k factors.

FM can be used for regression, where the optimization criterion is mean square error. FM can also be used for
binary classification through the sigmoid function, where the optimization criterion is logistic loss.

The pairwise interactions can be reformulated:

$$
\sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j
 = \frac{1}{2}\sum\limits^k_{f=1}
 \left(\left( \sum\limits^n_{i=1}v_{i,f}x_i \right)^2 -
  \sum\limits^n_{i=1}v_{i,f}^2x_i^2 \right)
$$

This equation has only linear complexity in both k and n, i.e. its computation is in $$O(kn)$$.

In general, to prevent the exploding gradient problem, it is best to scale continuous features to be between 0 and 1,
or bin the continuous features and one-hot encode them.
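The pairwise-interaction reformulation is easy to check numerically. A minimal NumPy sketch (illustrative only, independent of the Spark code) compares the naive O(kn²) double sum with the O(kn) form, plus the full model equation built on top of it:

```python
import numpy as np

def pairwise_naive(x, V):
    # Direct double sum over all feature pairs: O(k * n^2).
    n = len(x)
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            total += np.dot(V[i], V[j]) * x[i] * x[j]
    return total

def pairwise_linear(x, V):
    # Rendle's reformulation:
    # 0.5 * sum_f [(sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2],
    # which costs only O(k * n).
    s = V.T @ x                                      # shape (k,)
    return 0.5 * float(np.sum(s ** 2) - np.sum((V ** 2).T @ (x ** 2)))

def fm_predict(x, w0, w, V):
    # Full FM model equation: intercept + linear term + pairwise interactions.
    return w0 + float(np.dot(w, x)) + pairwise_linear(x, V)

rng = np.random.default_rng(0)
n, k = 20, 4
x = rng.random(n)
V = rng.normal(size=(n, k))                          # n variables, k factors each
assert np.isclose(pairwise_naive(x, V), pairwise_linear(x, V))
```

Both forms agree to floating-point precision; the linear-time form is what makes FM training tractable on wide, sparse data.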
# Decision trees
10481155

10491156
[Decision trees](http://en.wikipedia.org/wiki/Decision_tree_learning)
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

// $example on$
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.FMClassificationModel;
import org.apache.spark.ml.classification.FMClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// $example off$

public class JavaFMClassifierExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
        .builder()
        .appName("JavaFMClassifierExample")
        .getOrCreate();

    // $example on$
    // Load and parse the data file, converting it to a DataFrame.
    Dataset<Row> data = spark
        .read()
        .format("libsvm")
        .load("data/mllib/sample_libsvm_data.txt");

    // Index labels, adding metadata to the label column.
    // Fit on the whole dataset to include all labels in the index.
    StringIndexerModel labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data);
    // Scale features.
    MinMaxScalerModel featureScaler = new MinMaxScaler()
        .setInputCol("features")
        .setOutputCol("scaledFeatures")
        .fit(data);

    // Split the data into training and test sets (30% held out for testing).
    Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    // Train an FM model.
    FMClassifier fm = new FMClassifier()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("scaledFeatures")
        .setStepSize(0.001);

    // Convert indexed labels back to original labels.
    IndexToString labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labelsArray()[0]);

    // Create a Pipeline.
    Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[] {labelIndexer, featureScaler, fm, labelConverter});

    // Train model.
    PipelineModel model = pipeline.fit(trainingData);

    // Make predictions.
    Dataset<Row> predictions = model.transform(testData);

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5);

    // Select (prediction, true label) and compute test accuracy.
    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy");
    double accuracy = evaluator.evaluate(predictions);
    System.out.println("Test Accuracy = " + accuracy);

    FMClassificationModel fmModel = (FMClassificationModel) (model.stages()[2]);
    System.out.println("Factors: " + fmModel.factors());
    System.out.println("Linear: " + fmModel.linear());
    System.out.println("Intercept: " + fmModel.intercept());
    // $example off$

    spark.stop();
  }
}
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

// $example on$
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.MinMaxScaler;
import org.apache.spark.ml.feature.MinMaxScalerModel;
import org.apache.spark.ml.regression.FMRegressionModel;
import org.apache.spark.ml.regression.FMRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
// $example off$

public class JavaFMRegressorExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
        .builder()
        .appName("JavaFMRegressorExample")
        .getOrCreate();

    // $example on$
    // Load and parse the data file, converting it to a DataFrame.
    Dataset<Row> data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

    // Scale features.
    MinMaxScalerModel featureScaler = new MinMaxScaler()
        .setInputCol("features")
        .setOutputCol("scaledFeatures")
        .fit(data);

    // Split the data into training and test sets (30% held out for testing).
    Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    // Train an FM model.
    FMRegressor fm = new FMRegressor()
        .setLabelCol("label")
        .setFeaturesCol("scaledFeatures")
        .setStepSize(0.001);

    // Create a Pipeline.
    Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureScaler, fm});

    // Train model.
    PipelineModel model = pipeline.fit(trainingData);

    // Make predictions.
    Dataset<Row> predictions = model.transform(testData);

    // Select example rows to display.
    predictions.select("prediction", "label", "features").show(5);

    // Select (prediction, true label) and compute test error.
    RegressionEvaluator evaluator = new RegressionEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("rmse");
    double rmse = evaluator.evaluate(predictions);
    System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

    FMRegressionModel fmModel = (FMRegressionModel) (model.stages()[1]);
    System.out.println("Factors: " + fmModel.factors());
    System.out.println("Linear: " + fmModel.linear());
    System.out.println("Intercept: " + fmModel.intercept());
    // $example off$

    spark.stop();
  }
}
