75 changes: 70 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
@@ -182,7 +185,7 @@ class ALSModel private[ml] (
val rank: Int,
@transient val userFactors: DataFrame,
@transient val itemFactors: DataFrame)
extends Model[ALSModel] with ALSModelParams {
extends Model[ALSModel] with ALSModelParams with Writable {

/** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
@@ -220,8 +223,60 @@ class ALSModel private[ml] (
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
copyValues(copied, extra).setParent(parent)
}

@Since("1.6.0")
override def write: Writer = new ALSModel.ALSModelWriter(this)
}

@Since("1.6.0")
object ALSModel extends Readable[ALSModel] {

@Since("1.6.0")
override def read: Reader[ALSModel] = new ALSModelReader

@Since("1.6.0")
override def load(path: String): ALSModel = read.load(path)

private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer {

override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val userPath = new Path(path, "userFactors").toString
instance.userFactors.write.format("parquet").save(userPath)
val itemPath = new Path(path, "itemFactors").toString
instance.itemFactors.write.format("parquet").save(itemPath)
}
}

private[recommendation] class ALSModelReader extends Reader[ALSModel] {

/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.recommendation.ALSModel"

override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
implicit val format = DefaultFormats
val rank: Int = metadata.extraMetadata match {
case Some(m: JValue) =>
(m \ "rank").extract[Int]
case None =>
throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
s" ${metadata.metadataStr}")
}

val userPath = new Path(path, "userFactors").toString
val userFactors = sqlContext.read.format("parquet").load(userPath)
val itemPath = new Path(path, "itemFactors").toString
val itemFactors = sqlContext.read.format("parquet").load(itemPath)

val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)

DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}

/**
* :: Experimental ::
@@ -254,7 +309,7 @@ class ALSModel private[ml] (
* preferences rather than explicit ratings given to items.
*/
@Experimental
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable {

import org.apache.spark.ml.recommendation.ALS.Rating

@@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
}

override def copy(extra: ParamMap): ALS = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}


/**
* :: DeveloperApi ::
* An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
@@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
* than 2 billion.
*/
@DeveloperApi
object ALS extends Logging {
object ALS extends Readable[ALS] with Logging {

/**
* :: DeveloperApi ::
@@ -356,6 +415,12 @@ object ALS extends Logging {
@DeveloperApi
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)

@Since("1.6.0")
override def read: Reader[ALS] = new DefaultParamsReader[ALS]

@Since("1.6.0")
override def load(path: String): ALS = read.load(path)

/** Trait for least squares solvers applied to the normal equation. */
private[recommendation] trait LeastSquaresNESolver extends Serializable {
/** Solves a least squares problem with regularization (possibly with other constraints). */
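For context, here is a minimal round-trip sketch (not part of the diff) showing how the persistence hooks added above are meant to be used. It assumes the spark.ml Writer/Readable API this patch builds on, where write returns a Writer exposing save(path); the paths, param values, and ratings DataFrame are placeholders.

// Hypothetical end-to-end save/load round trip for ALS and ALSModel.
import org.apache.spark.ml.recommendation.{ALS, ALSModel}

val als = new ALS()
  .setMaxIter(5)
  .setRank(10)
val model = als.fit(ratings)                  // ratings: DataFrame of (user, item, rating) rows

als.write.save("/tmp/als-estimator")          // params only, via DefaultParamsWriter
model.write.save("/tmp/als-model")            // params, rank metadata, and factor DataFrames as parquet

val als2 = ALS.load("/tmp/als-estimator")     // DefaultParamsReader[ALS]
val model2 = ALSModel.load("/tmp/als-model")  // ALSModelReader rebuilds the model from metadata + parquet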
14 changes: 11 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -205,7 +209,8 @@
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
("paramMap" -> jsonParams) ~
("extraMetadata" -> extraMetadata)
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
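As a rough illustration (not from the patch itself), the single-line JSON written under <path>/metadata/ would now look approximately like this for an ALSModel; the uid, timestamp, and paramMap values are invented, and the nested "extraMetadata" object is the part this change adds:

{
  "class": "org.apache.spark.ml.recommendation.ALSModel",
  "timestamp": 1447000000000,
  "sparkVersion": "1.6.0",
  "uid": "als_0123abcd",
  "paramMap": {"userCol": "user", "itemCol": "item", "predictionCol": "prediction"},
  "extraMetadata": {"rank": 10}
}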
@@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
* @param params paramMap, as a [[JValue]]
* @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
* @param metadataStr Full metadata file String (for debugging)
*/
case class Metadata(
@@ -244,6 +250,7 @@
timestamp: Long,
sparkVersion: String,
params: JValue,
extraMetadata: Option[JValue],
metadataStr: String)

/**
@@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader {
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}

Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
}

/**
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,7 +17,6 @@

package org.apache.spark.ml.recommendation

import java.io.File
import java.util.Random

import scala.collection.mutable
@@ -26,28 +25,26 @@ import scala.language.existentials

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.Utils
import org.apache.spark.sql.{DataFrame, Row}

class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {

private var tempDir: File = _
class ALSSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {

override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Utils.createTempDir()
sc.setCheckpointDir(tempDir.getAbsolutePath)
}

override def afterAll(): Unit = {
Utils.deleteRecursively(tempDir)
super.afterAll()
}

@@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
var i = 0
while (i < compressed.srcIds.size) {
while (i < compressed.srcIds.length) {
var j = compressed.dstPtrs(i)
while (j < compressed.dstPtrs(i + 1)) {
val dstEncodedIndex = compressed.dstEncodedIndices(j)
@@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2,
implicitPrefs = true, seed = 0)
}

test("read/write") {
import ALSSuite._
val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
val als = new ALS()
allEstimatorParamSettings.foreach { case (p, v) =>
als.set(als.getParam(p), v)
}
val sqlContext = this.sqlContext
import sqlContext.implicits._
val model = als.fit(ratings.toDF())

// Test Estimator save/load
val als2 = testDefaultReadWrite(als)
allEstimatorParamSettings.foreach { case (p, v) =>
val param = als.getParam(p)
assert(als.get(param).get === als2.get(param).get)
}

// Test Model save/load
val model2 = testDefaultReadWrite(model)
allModelParamSettings.foreach { case (p, v) =>
val param = model.getParam(p)
assert(model.get(param).get === model2.get(param).get)
}
assert(model.rank === model2.rank)
def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
df.select("id", "features").collect().map { case r =>
(r.getInt(0), r.getAs[Array[Float]](1))
}.toSet
}
assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
}
}

object ALSSuite {

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allModelParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPredictionCol"
)

/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map(
"maxIter" -> 1,
"rank" -> 1,
"regParam" -> 0.01,
"numUserBlocks" -> 2,
"numItemBlocks" -> 2,
"implicitPrefs" -> true,
"alpha" -> 0.9,
"nonnegative" -> true,
"checkpointInterval" -> 20
)
}