[SPARK-19591][ML][MLlib] Add sample weights to decision trees #16722
**DecisionTreeClassifier.scala**
```diff
@@ -22,18 +22,21 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree._
+import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -45,7 +48,7 @@ import org.apache.spark.sql.Dataset
 class DecisionTreeClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
-  with DecisionTreeClassifierParams with DefaultParamsWritable {
+  with DecisionTreeClassifierParams with HasWeightCol with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtc"))
@@ -65,6 +68,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
   override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -96,6 +102,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   override def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
```

**Contributor:** it looks like by removing this method call you are removing some valuable validation logic (that exists in the base class): `require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + ...`

**Contributor (author):** Good catch. Actually this problem exists elsewhere (LogisticRegression, e.g.). What do you think about adding it back manually here and then addressing the larger issue in a separate JIRA?

**Contributor:** I would say that's fine if it was only in one place, but I also see this pattern in DecisionTreeRegressor.scala; it seems like we should be able to refactor this part out.

**Contributor (author):** For regressors, …

**Contributor:** sounds reasonable, thanks for the explanation.
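For reference, the validation the reviewer refers to lives in the `Classifier` base class and runs when `extractLabeledPoints(dataset, numClasses)` is called. A minimal sketch of what re-adding it inside `train()` would look like; the message text is approximate, not quoted from the source:

```scala
// Sketch of the per-row label check done by Classifier.extractLabeledPoints,
// which the new select-based extraction bypasses (message text approximate).
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
  case Row(label: Double, features: Vector) =>
    require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
      s" dataset with invalid label $label. Labels must be integers in range" +
      s" [0, $numClasses).")
    LabeledPoint(label, features)
}
```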
```diff
@@ -106,14 +122,18 @@ class DecisionTreeClassifier @Since("1.4.0") (
         ".train() called with non-matching numClasses and thresholds.length." +
         s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures, numClasses)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
```
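With `weightCol` wired through `train()`, a caller could fit a weighted tree along these lines. A minimal sketch: the DataFrame `df` and the column names `label`, `features`, and `sampleWeight` are assumptions for illustration, not part of this PR:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Assumed input: "df" is a DataFrame with "label", "features", and a
// numeric per-row "sampleWeight" column (all names are illustrative).
val dtc = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("sampleWeight")          // rows with larger weight count for more
  .setMinWeightFractionPerNode(0.05)     // each child must hold >= 5% of total weight
val model = dtc.fit(df)
```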
```diff
@@ -124,11 +144,12 @@ class DecisionTreeClassifier @Since("1.4.0") (
   /** (private[ml]) Train a decision tree on an RDD */
   private[ml] def train(data: RDD[LabeledPoint],
       oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
-    val instr = Instrumentation.create(this, data)
-    instr.logParams(params: _*)
-
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = 0L, instr = Some(instr), parentUID = Some(uid))
+    val instances = data.map {lp => Instance(lp.label, 1.0, lp.features)}
+    val instr = Instrumentation.create(this, instances)
+    instr.logParams(params: _*)
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
     instr.logSuccess(m)
```
```diff
@@ -176,6 +197,7 @@ class DecisionTreeClassificationModel private[ml] (
   /**
    * Construct a decision tree classification model.
+   *
    * @param rootNode  Root node of tree, with other nodes attached.
    */
   private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
```
**RandomForestClassifier.scala**
```diff
@@ -21,19 +21,20 @@ import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.DoubleType
 
 /**
  * <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
```
```diff
@@ -126,20 +127,22 @@ class RandomForestClassifier @Since("1.4.0") (
       s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+    val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
+    }
     val strategy =
       super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
       impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
       minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
 
     val trees = RandomForest
-      .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+      .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeClassificationModel])
 
-    val numFeatures = oldDataset.first().features.size
+    val numFeatures = instances.first().features.size
     val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
     instr.logSuccess(m)
     m
```

**Contributor:** same as above, it looks like some validation logic is missing here.
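Note that `RandomForestClassifier` still hard-codes a weight of 1.0 per instance; forests do not pick up `weightCol` in this commit. If that support were added later, the same extraction pattern from `DecisionTreeClassifier.train` would presumably carry over. A sketch of that hypothetical follow-up, assuming the class mixed in `HasWeightCol`:

```scala
// Hypothetical extension, not part of this commit: honor weightCol in
// RandomForestClassifier.train once the class mixes in HasWeightCol.
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
  dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }
```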
**LabeledPoint.scala**
```diff
@@ -35,4 +35,11 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: Vector)
   override def toString: String = {
     s"($label,$features)"
   }
+
+  private[spark] def toInstance: Instance = toInstance(1.0)
+
+  private[spark] def toInstance(weight: Double): Instance = {
+    Instance(label, weight, features)
+  }
+
 }
```
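A quick illustration of the new helpers. Since they are `private[spark]`, this only compiles inside Spark's own packages; the point values are arbitrary:

```scala
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vectors

val lp = LabeledPoint(1.0, Vectors.dense(0.5, -0.3))
val unweighted: Instance = lp.toInstance      // weight defaults to 1.0
val weighted: Instance = lp.toInstance(2.5)   // explicit sample weight
```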
**DecisionTreeRegressor.scala**
```diff
@@ -23,8 +23,9 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
@@ -33,8 +34,10 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 
 
 /**
```
```diff
@@ -45,7 +48,7 @@ import org.apache.spark.sql.functions._
 @Since("1.4.0")
 class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
-  with DecisionTreeRegressorParams with DefaultParamsWritable {
+  with DecisionTreeRegressorParams with DefaultParamsWritable with HasWeightCol {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtr"))
@@ -64,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
+  /** @group setParam */
+  @Since("2.2.0")
+  def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
   @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
```
```diff
@@ -99,16 +105,31 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setVarianceCol(value: String): this.type = set(varianceCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
     val strategy = getOldStrategy(categoricalFeatures)
 
-    val instr = Instrumentation.create(this, oldDataset)
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+    val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
       seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
```

**Contributor:** the code above looks the same as the classifier; can we refactor somehow?

**Contributor:** update: it sounds like you are going to create a separate JIRA for refactoring this code, that is reasonable to me.
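One way that duplication could be factored out: a helper on a shared trait that has access to `labelCol`, `featuresCol`, and `weightCol`, which both tree estimators could call. A sketch only; the name `extractInstances` and its placement are assumptions, not part of this PR:

```scala
// Hypothetical shared helper (name and location are illustrative, not from
// this PR): both tree estimators could delegate row-to-Instance logic here.
private[ml] def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
  // Default to weight 1.0 when no weight column is set.
  val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
  dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }
}
```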
```diff
@@ -122,8 +143,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
     val instr = Instrumentation.create(this, data)
     instr.logParams(params: _*)
 
-    val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
-      seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+    val instances = data.map {lp => Instance(lp.label, 1.0, lp.features)}
+    val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+      featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid))
 
     val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
     instr.logSuccess(m)
```
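As a side note, the inline conversion in this hunk could also be written with the `toInstance` helper added to `LabeledPoint` in this PR; a minor stylistic alternative:

```scala
// Equivalent to data.map {lp => Instance(lp.label, 1.0, lp.features)},
// since toInstance defaults the weight to 1.0.
val instances = data.map(_.toInstance)
```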
```diff
@@ -153,6 +175,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor]
  * <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
  * Decision tree (Wikipedia)</a> model for regression.
  * It supports both continuous and categorical features.
+ *
  * @param rootNode  Root of the decision tree
  */
 @Since("1.4.0")
@@ -171,6 +194,7 @@ class DecisionTreeRegressionModel private[ml] (
   /**
    * Construct a decision tree regression model.
+   *
    * @param rootNode  Root node of tree, with other nodes attached.
    */
   private[ml] def this(rootNode: Node, numFeatures: Int) =
```
**Contributor:** was there a specific reason not to put this on `DecisionTreeClassifierParams`? It looks like the other classifiers that have this are:

- LinearSVC
- LogisticRegression
- NaiveBayes

and the regressors:

- GeneralizedLinearRegression
- IsotonicRegression
- LinearRegression

and all have it on the params trait, not on the class. However, I do agree with you that it really makes no sense for the model to have this settable, although it may be useful for users to get the information on the model.
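For concreteness, the arrangement the reviewer is describing would look roughly like the sketch below. This is an illustration of the pattern, not the PR's code; the exact parents of `DecisionTreeClassifierParams` are assumed:

```scala
// Hypothetical arrangement (not this PR): the weightCol param lives on the
// shared params trait, mirroring LinearSVC / LogisticRegression / NaiveBayes.
private[ml] trait DecisionTreeClassifierParams
  extends DecisionTreeParams with TreeClassifierParams with HasWeightCol

class DecisionTreeClassifier(override val uid: String)
  extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
  with DecisionTreeClassifierParams with DefaultParamsWritable {

  // The concrete estimator then only adds the setter; the model class
  // inherits the param for inspection without exposing a setter.
  /** @group setParam */
  def setWeightCol(value: String): this.type = set(weightCol, value)
}
```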