@@ -32,6 +32,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.mllib.linalg._
@@ -50,6 +51,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel,
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils

@@ -923,6 +925,14 @@ private[python] class PythonMLLibAPI extends Serializable {
RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s)
}

  /**
   * Java stub for the constructor of Python mllib RankingMetrics
   */
  def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = {
    // Each Row carries two array columns: the predicted ranking and the ground
    // truth set; unpack both to Array[Any] for the Scala-side RankingMetrics.
    new RankingMetrics(predictionAndLabels.map(
      r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
  }


}

78 changes: 76 additions & 2 deletions python/pyspark/mllib/evaluation.py
@@ -15,9 +15,12 @@
# limitations under the License.
#

from pyspark.mllib.common import JavaModelWrapper
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.sql import SQLContext
from pyspark.sql.types import StructField, StructType, DoubleType
from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType

__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
           'MulticlassMetrics', 'RankingMetrics']


class BinaryClassificationMetrics(JavaModelWrapper):
@@ -270,6 +273,77 @@ def weightedFMeasure(self, beta=None):
return self.call("weightedFMeasure", beta)


class RankingMetrics(JavaModelWrapper):
    """
    Evaluator for ranking algorithms.

    >>> predictionAndLabels = sc.parallelize([
    ...     ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
    ...     ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    ...     ([1, 2, 3, 4, 5], [])])
    >>> metrics = RankingMetrics(predictionAndLabels)
    >>> metrics.precisionAt(1)
    0.33...
    >>> metrics.precisionAt(5)
    0.26...
    >>> metrics.precisionAt(15)
    0.17...
    >>> metrics.meanAveragePrecision
    0.35...
    >>> metrics.ndcgAt(3)
    0.33...
    >>> metrics.ndcgAt(10)
    0.48...

    """

    def __init__(self, predictionAndLabels):
        """
        :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
        """
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        # Convert the RDD of (list, list) pairs to a DataFrame so it can be
        # serialized to the JVM-side stub, which rebuilds the pairs from Rows.
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_model = callMLlibFunc("newRankingMetrics", df._jdf)
        super(RankingMetrics, self).__init__(java_model)

    def precisionAt(self, k):
        """
        Compute the average precision of all the queries, truncated at ranking position k.

        If, for a query, the ranking algorithm returns n (n < k) results, the precision
        value will be computed as #(relevant items retrieved) / k. This formula also
        applies when the size of the ground truth set is less than k.

        If a query has an empty ground truth set, zero will be used as precision together
        with a log warning.
        """
        return self.call("precisionAt", int(k))

    @property
    def meanAveragePrecision(self):
        """
        Returns the mean average precision (MAP) of all the queries.
        If a query has an empty ground truth set, the average precision will be zero
        and a log warning is generated.
        """
        return self.call("meanAveragePrecision")

    def ndcgAt(self, k):
        """
        Compute the average NDCG value of all the queries, truncated at ranking position k.
        The discounted cumulative gain at position k is computed as
        sum_{i=1}^{k} (2^{relevance of i-th item} - 1) / log(i + 1),
        and the NDCG is obtained by dividing the DCG value by the DCG of the ground
        truth set (the ideal DCG). In the current implementation, the relevance value
        is binary.

        If a query has an empty ground truth set, zero will be used as NDCG together
        with a log warning.
        """
        return self.call("ndcgAt", int(k))


def _test():
    import doctest
    from pyspark import SparkContext
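For reference, a minimal pure-Python sketch of precision@k (no Spark required; binary relevance assumed, as in the Scala implementation above, and the helper name precision_at is ours for illustration) that reproduces the precisionAt values from the doctest:

def precision_at(pred, labels, k):
    # Precision@k for one query: #(relevant items in the top k) / k.
    relevant = set(labels)
    return sum(1 for p in pred[:k] if p in relevant) / float(k)

data = [
    ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
    ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
    ([1, 2, 3, 4, 5], []),
]

for k in (1, 5, 15):
    avg = sum(precision_at(p, l, k) for p, l in data) / len(data)
    print("precisionAt(%d) = %.4f" % (k, avg))
# Matches the doctest values: 0.3333 (0.33...), 0.2667 (0.26...), 0.1778 (0.17...)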