Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP
  • Loading branch information
BigCrunsh committed Sep 1, 2014
commit 140f09c694ceb64cb246e1aed5b1fc7035d91a18
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.GeneralizedLinearModel
import org.apache.spark.api.java.JavaRDD
import scala.deprecated

/**
* Represents a classification model that predicts to which of a set of categories an example
Expand All @@ -33,7 +32,8 @@ class BinaryClassificationModel (
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {

protected var threshold: Double = 0.0
@deprecated

// this is only used to ensure prior behaviour of deprecated `predict``
protected var useThreshold: Boolean = true

/**
Expand Down Expand Up @@ -61,7 +61,7 @@ class BinaryClassificationModel (
}

/**
* :: Deprecated ::
* DEPRECATED: Use predictScore(...) or predictClass(...) instead
* Clears the threshold so that `predict` will output raw prediction scores.
*/
@Deprecated
Expand All @@ -71,38 +71,50 @@ class BinaryClassificationModel (
}

/**
* :: Deprecated ::
* DEPRECATED: Use predictScore(...) or predictClass(...) instead
*/
@Deprecated
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double) = {
if (useThreshold) predictClass(dataMatrix)
else predictScore(dataMatrix)
}

/**
* DEPRECATED: Use predictScore(...) or predictClass(...) instead
* Predict values for the given data set using the model trained.
*
* @param testData RDD representing data points to be predicted
* @return an RDD[Double] where each entry contains the corresponding prediction
*/
@deprecated
@Deprecated
override def predict(testData: RDD[Vector]): RDD[Double] = {
if (useThreshold) predictClass(testData)
else predictScore(testData)
}

/**
* :: Deprecated ::
* DEPRECATED: Use predictScore(...) or predictClass(...) instead
* Predict values for a single data point using the model trained.
*
* @param testData array representing a single data point
* @return predicted category from the trained model
*/
@deprecated
def predict(testData: Vector): Double = {
@Deprecated
override def predict(testData: Vector): Double = {
if (useThreshold) predictClass(testData)
else predictScore(testData)
}

/**
* :: Deprecated ::
* DEPRECATED: Use predictScore(...) or predictClass(...) instead
* Predict values for examples stored in a JavaRDD.
* @param testData JavaRDD representing data points to be predicted
* @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
*/
@deprecated
@Deprecated
def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class LogisticRegressionModel (
computeProbability(predictScore(testData))
}

@deprecated
/**
* DEPRECATED: Use predictProbability(...) or predictClass(...) instead
*/
@Deprecated
override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
if (useThreshold) predictClass(dataMatrix)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,15 @@ class NaiveBayesModel private[mllib] (
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}

@deprecated
/**
* DEPRECATED: Use predictClass(...) instead
*/
@Deprecated
def predict(testData: RDD[Vector]): RDD[Double] = predictClass(testData)

/**
* DEPRECATED: Use predictClass(...) instead
*/
@deprecated
def predict(testData: Vector): Double = predictClass(testData)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@ class SVMModel (
override val weights: Vector,
override val intercept: Double)
extends BinaryClassificationModel(weights, intercept) {

@deprecated
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double) = {
if (useThreshold) predictClass(dataMatrix)
else predictScore(dataMatrix)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
* @param weightMatrix Column vector containing the weights of the model
* @param intercept Intercept of the model.
*/
@deprecated
@Deprecated
protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double

/**
Expand All @@ -96,7 +96,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
* @param testData RDD representing data points to be predicted
* @return RDD[Double] where each entry contains the corresponding prediction
*/
@deprecated
@Deprecated
def predict(testData: RDD[Vector]): RDD[Double] = {
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
Expand All @@ -116,7 +116,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
* @param testData array representing a single data point
* @return Double prediction from the trained model
*/
@deprecated
@Deprecated
def predict(testData: Vector): Double = {
predictPoint(testData, weights, intercept)
}
Expand Down
16 changes: 10 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@ class LassoModel (
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override protected def computeScore(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double = {
weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
/**
* DEPRECATED: Use predictScore(...) instead
*/
@Deprecated
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double): Double =
predictScore(dataMatrix)

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ class LinearRegressionModel (
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {

@deprecated

/**
* DEPRECATED: Use predictScore(...) instead
*/
@Deprecated
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ class RidgeRegressionModel (
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

@deprecated
/**
* DEPRECATED: Use predictScore(...) instead
*/
@Deprecated
override protected def predictPoint(
dataMatrix: Vector,
weightMatrix: Vector,
intercept: Double) =
intercept: Double): Double =
predictScore(dataMatrix)

}
Expand Down