Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.{abs, sqrt}
import scala.util.Random
import scala.util.Sorting
import scala.util.{Random, Sorting}
import scala.util.hashing.byteswap32

import org.jblas.{DoubleMatrix, SimpleBlas, Solve}

import org.apache.spark.{HashPartitioner, Logging, Partitioner}
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.{Logging, HashPartitioner, Partitioner}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
import org.apache.spark.mllib.optimization.NNLS

/**
* Out-link information for a user or product block. This includes the original user/product IDs
Expand Down Expand Up @@ -325,6 +325,11 @@ class ALS private (
new MatrixFactorizationModel(rank, usersOut, productsOut)
}

/**
* Java-friendly version of [[ALS.run]].
*/
def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd)

/**
* Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
* for each user (or product), in a distributed fashion.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

package org.apache.spark.mllib.recommendation

import java.lang.{Integer => JavaInteger}

import org.jblas.DoubleMatrix

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.api.python.SerDe
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD

/**
* Model representing the result of matrix factorization.
Expand Down Expand Up @@ -65,6 +65,13 @@ class MatrixFactorizationModel private[mllib] (
}
}

/**
* Java-friendly version of [[MatrixFactorizationModel.predict]].
*/
def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
}

/**
* Recommends products to a user.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
import scala.Tuple2;
import scala.Tuple3;

import com.google.common.collect.Lists;
import org.jblas.DoubleMatrix;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

Expand All @@ -47,61 +48,48 @@ public void tearDown() {
sc = null;
}

static void validatePrediction(
void validatePrediction(
MatrixFactorizationModel model,
int users,
int products,
int features,
DoubleMatrix trueRatings,
double matchThreshold,
boolean implicitPrefs,
DoubleMatrix truePrefs) {
DoubleMatrix predictedU = new DoubleMatrix(users, features);
List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
for (Tuple2<Object, double[]> userFeature : userFeatures) {
predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
}
}
DoubleMatrix predictedP = new DoubleMatrix(products, features);

List<Tuple2<Object, double[]>> productFeatures =
model.productFeatures().toJavaRDD().collect();
for (int i = 0; i < features; ++i) {
for (Tuple2<Object, double[]> productFeature : productFeatures) {
predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
List<Tuple2<Integer, Integer>> localUsersProducts =
Lists.newArrayListWithCapacity(users * products);
for (int u=0; u < users; ++u) {
for (int p=0; p < products; ++p) {
localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
}
}

DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());

JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
List<Rating> predictedRatings = model.predict(usersProducts).collect();
Assert.assertEquals(users * products, predictedRatings.size());
if (!implicitPrefs) {
for (int u = 0; u < users; ++u) {
for (int p = 0; p < products; ++p) {
double prediction = predictedRatings.get(u, p);
double correct = trueRatings.get(u, p);
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
}
for (Rating r: predictedRatings) {
double prediction = r.rating();
double correct = trueRatings.get(r.user(), r.product());
Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
}
} else {
// For implicit prefs we use the confidence-weighted RMSE to test
// (ref Mahout's implicit ALS tests)
double sqErr = 0.0;
double denom = 0.0;
for (int u = 0; u < users; ++u) {
for (int p = 0; p < products; ++p) {
double prediction = predictedRatings.get(u, p);
double truePref = truePrefs.get(u, p);
double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
double err = confidence * (truePref - prediction) * (truePref - prediction);
sqErr += err;
denom += confidence;
}
for (Rating r: predictedRatings) {
double prediction = r.rating();
double truePref = truePrefs.get(r.user(), r.product());
double confidence = 1.0 +
/* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product()));
double err = confidence * (truePref - prediction) * (truePref - prediction);
sqErr += err;
denom += confidence;
}
double rmse = Math.sqrt(sqErr / denom);
Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
rmse, matchThreshold), rmse < matchThreshold);
rmse, matchThreshold), rmse < matchThreshold);
}
}

Expand All @@ -116,7 +104,7 @@ public void runALSUsingStaticMethods() {

JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}

@Test
Expand All @@ -132,8 +120,8 @@ public void runALSUsingConstructor() {

MatrixFactorizationModel model = new ALS().setRank(features)
.setIterations(iterations)
.run(data.rdd());
validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
.run(data);
validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
}

@Test
Expand All @@ -147,7 +135,7 @@ public void runImplicitALSUsingStaticMethods() {

JavaRDD<Rating> data = sc.parallelize(testData._1());
MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}

@Test
Expand All @@ -165,7 +153,7 @@ public void runImplicitALSUsingConstructor() {
.setIterations(iterations)
.setImplicitPrefs(true)
.run(data.rdd());
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}

@Test
Expand All @@ -183,7 +171,7 @@ public void runImplicitALSWithNegativeWeight() {
.setImplicitPrefs(true)
.setSeed(8675309L)
.run(data.rdd());
validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
}

@Test
Expand Down