address some review
sethah committed Feb 1, 2017
commit 1db849417179b4cfc688cf9023ff225dac16ecfd
DecisionTreeClassifier.scala
@@ -48,7 +48,7 @@ import org.apache.spark.sql.types.DoubleType
 class DecisionTreeClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
-  with DecisionTreeClassifierParams with HasWeightCol with DefaultParamsWritable {
+  with DecisionTreeClassifierParams with DefaultParamsWritable {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtc"))
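(The HasWeightCol mixin removed here is not lost: this commit moves it into the shared DecisionTreeParams trait, in treeParams.scala below, so every tree learner inherits it once.)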
RandomForestClassifier.scala
@@ -32,9 +32,9 @@ import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.DoubleType

 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -127,9 +127,7 @@ class RandomForestClassifier @Since("1.4.0") (
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }

-    val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
-    }
+    val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance(1.0))

     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

Review thread on the `val instances` line:

imatiach-msft (Contributor), Feb 2, 2017: minor simplification - it looks like `toInstance(1.0)` can just be simplified as `toInstance`.

imatiach-msft (Contributor), later: update: since you removed the overload, this comment is no longer valid.
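For context on the exchange above, here is a minimal, self-contained sketch of the conversion being discussed. These are stand-in case classes, not Spark's actual API (Spark's `Instance` and `LabeledPoint` live in org.apache.spark.ml.feature, are package-private in part, and use ml.linalg.Vector):

object ToInstanceSketch extends App {
  // Stand-in for org.apache.spark.ml.feature.Instance.
  case class Instance(label: Double, weight: Double, features: Seq[Double])

  // Stand-in for org.apache.spark.ml.feature.LabeledPoint.
  case class LabeledPoint(label: Double, features: Seq[Double]) {
    // Weighted conversion kept by this commit; train() calls toInstance(1.0).
    def toInstance(weight: Double): Instance = Instance(label, weight, features)

    // No-arg overload the reviewer suggested relying on; the commit removed
    // it, which is why the comment above became obsolete.
    def toInstance: Instance = toInstance(1.0)
  }

  val lp = LabeledPoint(1.0, Seq(0.5, -0.2))
  assert(lp.toInstance(1.0) == lp.toInstance) // identical results
}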
DecisionTreeRegressor.scala
@@ -48,7 +48,7 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.4.0")
 class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
-  with DecisionTreeRegressorParams with DefaultParamsWritable with HasWeightCol {
+  with DecisionTreeRegressorParams with DefaultParamsWritable {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtr"))
RandomForestRegressor.scala
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
@@ -31,9 +30,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.DoubleType

 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
@@ -117,9 +115,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

-    val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
-      case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
-    }
+    val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0))

     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

Review comment on the `val instances` line (Contributor): simplify to `toInstance` (without the 1.0).
treeParams.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
 private[ml] trait DecisionTreeParams extends PredictorParams
-  with HasCheckpointInterval with HasSeed {
+  with HasCheckpointInterval with HasSeed with HasWeightCol {

 /**
  * Maximum depth of the tree (>= 0).
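This hunk is the other half of the HasWeightCol removals above: with the mixin on the shared params trait, each tree-based estimator inherits a weightCol param once instead of declaring it per class. A toy sketch of the pattern (simplified traits, not Spark's actual Param machinery):

object WeightColSketch extends App {
  // Toy stand-in for org.apache.spark.ml.param.shared.HasWeightCol.
  trait HasWeightCol { def weightCol: String = "weight" }

  // After this commit, the shared trait mixes in HasWeightCol once...
  trait DecisionTreeParams extends HasWeightCol

  // ...so concrete learners no longer repeat "with HasWeightCol".
  class DecisionTreeClassifier extends DecisionTreeParams
  class DecisionTreeRegressor extends DecisionTreeParams

  println(new DecisionTreeClassifier().weightCol) // "weight", inherited
}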
MLTestingUtils.scala
@@ -298,11 +298,11 @@ object MLTestingUtils extends SparkFunSuite {
       model2: M): Unit = {
     val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
     val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
-    val inTol = pred1.zip(pred2).map { case (p1, p2) =>
+    val inTol = pred1.zip(pred2).count { case (p1, p2) =>
       val x = p1.getDouble(0)
       val y = p2.getDouble(0)
       compareFunc(x, y)
     }
-    assert(inTol.count(b => b) / pred1.length.toDouble >= fractionInTol)
+    assert(inTol / pred1.length.toDouble >= fractionInTol)
   }
 }
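The refactor replaces a map-then-count over an intermediate Array[Boolean] with a direct count of matching pairs. A small self-contained sketch of the equivalence (toy data, not the test's actual inputs):

object CountRefactorSketch extends App {
  val pred1 = Array(1.0, 2.0, 3.0)
  val pred2 = Array(1.0, 2.5, 3.0)
  val compareFunc = (x: Double, y: Double) => math.abs(x - y) <= 0.1

  // Before: build an Array[Boolean], then count the trues.
  val before = pred1.zip(pred2).map { case (x, y) => compareFunc(x, y) }.count(b => b)

  // After: count matching pairs directly, skipping the intermediate array.
  val after = pred1.zip(pred2).count { case (x, y) => compareFunc(x, y) }

  assert(before == after) // both are 2 of 3 within tolerance
}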