Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add weights to dt
  • Loading branch information
sethah committed Jan 27, 2017
commit 2d86cea640634a205e378bddee0b01780d019ea2
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,21 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
Expand All @@ -45,7 +48,7 @@ import org.apache.spark.sql.Dataset
class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeClassifierParams with DefaultParamsWritable {
with DecisionTreeClassifierParams with HasWeightCol with DefaultParamsWritable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was there a specific reason not to put this on the DecisionTreeClassifierParams.
it looks like the other classifiers that have this are:
LinearSVC
LogisticRegression
NaïveBayes
and regressors:
GeneralizedLinearRegression
IsotonicRegression
LinearRegression
and all have it on the params, not on the class.
However, I do agree with you that it really makes no sense for the model to have this settable, although it may be useful for users to get the information on the model.


@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtc"))
Expand All @@ -65,6 +68,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("2.2.0")
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

Expand Down Expand Up @@ -96,6 +102,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = set(seed, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like by removing this method call you are removing some valuable validation logic (that exists in the base class).
specifically, this is the logic:
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")

require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Actually this problem exists elsewhere (LogisticRegression, e.g.) What to do you think about adding it back manually here and then addressing the larger issue in a separate JIRA?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that's fine if it was only in one place, but I also see this pattern in DecisionTreeRegressor.scala, it seems like we should be able to refactor this part out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For regressors, extractLabeledPoints doesn't do any extra checking. The larger issue is that we are manually "extracting instances" but we have convenience methods for labeled points. Since correcting it now, in this PR, likely means implementing the framework to correct it everywhere - which is a larger and orthogonal change, I think we could just add the check manually to the classifier, then create a JIRA that addresses consolidating these, probably by adding extractInstances methods analogous their labeled point counterparts. This PR is large enough as is, without having to think about adding that method, then implementing it in all the other algos that manually extract instances, IMO.

Copy link
Contributor

@imatiach-msft imatiach-msft Feb 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds reasonable, thanks for the explanation.

*/
@Since("2.2.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
Expand All @@ -106,14 +122,18 @@ class DecisionTreeClassifier @Since("1.4.0") (
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val strategy = getOldStrategy(categoricalFeatures, numClasses)

val instr = Instrumentation.create(this, oldDataset)
val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)

val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))

val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
Expand All @@ -124,11 +144,12 @@ class DecisionTreeClassifier @Since("1.4.0") (
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)

val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))
val instances = data.map {lp => Instance(lp.label, 1.0, lp.features)}
val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))

val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
instr.logSuccess(m)
Expand Down Expand Up @@ -176,6 +197,7 @@ class DecisionTreeClassificationModel private[ml] (

/**
* Construct a decision tree classification model.
*
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DoubleType

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
Expand Down Expand Up @@ -126,20 +127,22 @@ class RandomForestClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above, it looks like some validation logic is missing here

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like we aren't getting the weight column here; not sure why this file needed to be changed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to pass in RDD[Instance] to RandomForest.run. I changed this back to use extractLabeledPoints

}
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

val instr = Instrumentation.create(this, oldDataset)
val instr = Instrumentation.create(this, instances)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])

val numFeatures = oldDataset.first().features.size
val numFeatures = instances.first().features.size
val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
instr.logSuccess(m)
m
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ import org.apache.spark.ml.linalg.Vector
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
private[ml] case class Instance(label: Double, weight: Double, features: Vector)
private[spark] case class Instance(label: Double, weight: Double, features: Vector)
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,11 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
override def toString: String = {
s"($label,$features)"
}

private[spark] def toInstance: Instance = toInstance(1.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is kind of a nit pick, and optional, but I would usually refactor out magic numbers like 1.0 as something like "defaultWeight" and reuse it elsewhere, but it's not really necessary in this case since it probably won't ever change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I'd prefer to remove the no arg function and be explicit everywhere. That way there is no ambiguity or unintended effects if someone changes the default value. Sound ok?


private[spark] def toInstance(weight: Double): Instance = {
Instance(label, weight, features)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
Expand All @@ -33,8 +34,10 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType


/**
Expand All @@ -45,7 +48,7 @@ import org.apache.spark.sql.functions._
@Since("1.4.0")
class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
with DecisionTreeRegressorParams with DefaultParamsWritable {
with DecisionTreeRegressorParams with DefaultParamsWritable with HasWeightCol {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtr"))
Expand All @@ -64,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("2.2.0")
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

Expand Down Expand Up @@ -99,16 +105,31 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("2.0.0")
def setVarianceCol(value: String): this.type = set(varianceCol, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("2.2.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code above looks the same as the classifier, can we refactor somehow:

val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) 
val instances = 
  dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { 
    case Row(label: Double, weight: Double, features: Vector) => 
      Instance(label, weight, features) 

Copy link
Contributor

@imatiach-msft imatiach-msft Feb 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update: it sounds like you are going to create a separate JIRA for refactoring this code, that is reasonable to me.

val strategy = getOldStrategy(categoricalFeatures)

val instr = Instrumentation.create(this, oldDataset)
val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)

val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))

val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
Expand All @@ -122,8 +143,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)

val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val instances = data.map {lp => Instance(lp.label, 1.0, lp.features)}
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid))

val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
instr.logSuccess(m)
Expand Down Expand Up @@ -153,6 +175,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
* Decision tree (Wikipedia)</a> model for regression.
* It supports both continuous and categorical features.
*
* @param rootNode Root of the decision tree
*/
@Since("1.4.0")
Expand All @@ -171,6 +194,7 @@ class DecisionTreeRegressionModel private[ml] (

/**
* Construct a decision tree regression model.
*
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
Expand All @@ -31,10 +31,9 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DoubleType

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
Expand Down Expand Up @@ -65,7 +64,6 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

Expand Down Expand Up @@ -117,20 +115,22 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)

val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
}
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

val instr = Instrumentation.create(this, oldDataset)
val instr = Instrumentation.create(this, instances)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = oldDataset.first().features.size
val numFeatures = instances.first().features.size
val m = new RandomForestRegressionModel(trees, numFeatures)
instr.logSuccess(m)
m
Expand Down
Loading