Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
rename LabelParser.apply to LabelParser.parse
use extends for singleton
  • Loading branch information
mengxr committed Apr 7, 2014
commit c2e571c2572dbf1d59fbaa6b6dff177afa0b0f66
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.mllib.util
/** Trait for label parsers. */
trait LabelParser extends Serializable {
/** Parses a string label into a double label. */
- def apply(labelString: String): Double
+ def parse(labelString: String): Double
}

/**
Expand All @@ -32,24 +32,22 @@ class BinaryLabelParser extends LabelParser {
* Parses the input label into positive (1.0) if the value is greater than 0.5,
* or negative (0.0) otherwise.
*/
- override def apply(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
+ override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
}

- object BinaryLabelParser {
- private lazy val instance = new BinaryLabelParser()
+ object BinaryLabelParser extends BinaryLabelParser {
/** Gets the default instance of BinaryLabelParser. */
- def apply(): BinaryLabelParser = instance
+ def getInstance(): BinaryLabelParser = this
}

/**
* Label parser for multiclass labels, which converts the input label to double.
*/
class MulticlassLabelParser extends LabelParser {
- override def apply(labelString: String): Double = labelString.toDouble
+ override def parse(labelString: String): Double = labelString.toDouble
}

- object MulticlassLabelParser {
- private lazy val instance = new MulticlassLabelParser()
+ object MulticlassLabelParser extends MulticlassLabelParser {
/** Gets the default instance of MulticlassLabelParser. */
- def apply(): MulticlassLabelParser = instance
+ def getInstance(): MulticlassLabelParser = this
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ object MLUtils {
}.reduce(math.max)
}
parsed.map { items =>
- val label = labelParser(items.head)
+ val label = labelParser.parse(items.head)
val (indices, values) = items.tail.map { item =>
val indexAndValue = item.split(':')
val index = indexAndValue(0).toInt - 1
Expand All @@ -96,7 +96,7 @@ object MLUtils {
* with number of features determined automatically and the default number of partitions.
*/
def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
- loadLibSVMData(sc, path, BinaryLabelParser(), -1, sc.defaultMinSplits)
+ loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits)

/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString

- val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser(), 6).collect()
+ val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()

for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
Expand All @@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}

- val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser()).collect()
+ val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
assert(multiclassPoints.length === 3)
assert(multiclassPoints(0).label === 1.0)
assert(multiclassPoints(1).label === -1.0)
Expand Down