[SPARK-19591][ML][MLLIB] Add sample weights to decision trees
This updates PR apache#16722 to the latest master.

## What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.
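
For illustration, a weighted tree can then be trained roughly as follows (a minimal sketch: the DataFrame and column names are made up, and the setWeightCol / setMinWeightFractionPerNode calls assume the setters added in this patch):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// "training" is assumed to be a DataFrame with label, features and a
// per-row weight column; all column names here are illustrative only.
val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight")              // new in this patch
  .setMinWeightFractionPerNode(0.05)   // each node must carry >= 5% of the total instance weight

val model = dt.fit(training)
```
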
## How was this patch tested?

The algorithms are tested to ensure that:
    1. Arbitrary scaling of constant weights has no effect
    2. Outliers with small weights do not affect the learned model
    3. Oversampling and weighting are equivalent

Unit tests are also added to test other smaller components.
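
As a rough sketch of what the first property looks like as a check (not the actual test code from this PR; df, the constant 42.0, and the exact-string comparison are simplifications):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.sql.functions.lit

// Sketch: multiplying every weight by the same positive constant
// should not change the learned tree.
val unitWeights   = df.withColumn("weight", lit(1.0))
val scaledWeights = df.withColumn("weight", lit(42.0))

val estimator = new DecisionTreeClassifier().setWeightCol("weight")
val m1 = estimator.fit(unitWeights)
val m2 = estimator.fit(scaledWeights)

assert(m1.toDebugString == m2.toDebugString)
```
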
## Summary of changes

   - Impurity aggregators now store weighted sufficient statistics. They also keep the raw (unweighted) instance count, since that is needed to enforce minInstancesPerNode.

   - This patch maintains the meaning of minInstancesPerNode: the parameter still corresponds to raw, unweighted counts. It also adds a new parameter, minWeightFractionPerNode, which requires that each node contain at least minWeightFractionPerNode * weightedNumExamples total weight.
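
Put differently, a candidate split has to clear both thresholds; roughly (illustrative pseudologic, not the code from this patch):

```scala
// Hypothetical helper mirroring the description above: a split is valid
// only if both children satisfy both the count and the weight constraint.
def splitIsValid(
    leftCount: Long, leftWeight: Double,
    rightCount: Long, rightWeight: Double,
    minInstancesPerNode: Int,
    minWeightFractionPerNode: Double,
    weightedNumExamples: Double): Boolean = {
  val minWeightPerNode = minWeightFractionPerNode * weightedNumExamples
  leftCount >= minInstancesPerNode && rightCount >= minInstancesPerNode &&
    leftWeight >= minWeightPerNode && rightWeight >= minWeightPerNode
}
```
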

   - This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added.
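
The idea, roughly, is that candidate thresholds are now placed at approximately equal strides of cumulative weight rather than equal row counts (a simplified sketch of the approach, not the actual implementation):

```scala
// Simplified sketch: choose up to numSplits thresholds so that each bin
// holds roughly the same total weight instead of the same number of rows.
def weightedSplits(valueWeights: Seq[(Double, Double)], numSplits: Int): Seq[Double] = {
  val sorted = valueWeights.sortBy(_._1)   // (feature value, total weight at that value)
  val totalWeight = sorted.map(_._2).sum
  val stride = totalWeight / (numSplits + 1)
  val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
  var cumWeight = 0.0
  var target = stride
  for ((value, weight) <- sorted if splits.size < numSplits) {
    cumWeight += weight
    if (cumWeight >= target) {
      splits += value
      target += stride
    }
  }
  splits.toSeq
}
```
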

   - TreePoint is modified to hold a sample weight.

   - BaggedPoint is modified from:
``` Scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable
```
to
``` Scala
private[spark] class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable
```
We do not simply multiply the counts by the weight and store that, because we need both the raw counts and the weight in order to enforce minInstancesPerNode as well as the minimum weight per node implied by minWeightFractionPerNode.
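
Concretely, the two fields feed different checks during training (an illustrative sketch using the fields above; baggedPoint is an assumed instance and the surrounding aggregation code is elided):

```scala
// For a given tree index, the raw replica count feeds minInstancesPerNode,
// while count * sampleWeight is what the weighted impurity statistics accumulate.
val treeIdx = 0
val rawCount: Int = baggedPoint.subsampleCounts(treeIdx)         // e.g. 2
val weightedCount: Double = rawCount * baggedPoint.sampleWeight  // e.g. 2 * 0.3 = 0.6
```
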

**Note**: many of the changed files are touched simply because they now use Instance instead of LabeledPoint.

Closes apache#21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
imatiach-msft authored and srowen committed Jan 25, 2019
1 parent 3699763 commit b2d36f6
Showing 31 changed files with 743 additions and 280 deletions.
@@ -52,7 +52,7 @@ object TestingUtils {
/**
* Private helper function for comparing two values using absolute tolerance.
*/
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
// Special case for NaNs
if (x.isNaN && y.isNaN) {
return true
@@ -77,17 +77,37 @@ abstract class Classifier[
* @note Throws `SparkException` if any label is a non-integer or is negative
*/
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
validateNumClasses(numClasses)
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
validateLabel(label, numClasses)
LabeledPoint(label, features)
}
}

/**
* Validates that number of classes is greater than zero.
*
* @param numClasses Number of classes label can take.
*/
protected def validateNumClasses(numClasses: Int): Unit = {
require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
s" $numClasses, but requires numClasses > 0.")
}

/**
* Validates the label on the classifier is a valid integer in the range [0, numClasses).
*
* @param label The label to validate.
* @param numClasses Number of classes label can take. Labels must be integers in the range
* [0, numClasses).
*/
protected def validateLabel(label: Double, numClasses: Int): Unit = {
require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
s" dataset with invalid label $label. Labels must be integers in range" +
s" [0, $numClasses).")
}

/**
* Get the number of classes. This looks in column metadata first, and if that is missing,
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
@@ -22,19 +22,22 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset

import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("3.0.0")
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

@@ -97,29 +103,44 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
def setSeed(value: Long): this.type = set(seed, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
instr.logNumClasses(numClasses)

if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
validateNumClasses(numClasses)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
validateLabel(label, numClasses)
Instance(label, weight, features)
}
val strategy = getOldStrategy(categoricalFeatures, numClasses)

instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)

val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
val instances = data.map(_.toInstance)
instr.logPipelineStage(this)
instr.logDataset(data)
instr.logDataset(instances)
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)

val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, instr = Some(instr), parentUID = Some(uid))
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
@@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (

/**
* Construct a decision tree classification model.
*
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
@@ -21,20 +21,21 @@ import org.json4s.JsonDSL._
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.functions.{col, udf}

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a> learning algorithm for
@@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

@@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])

val numFeatures = trees.head.numFeatures
@@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
override def toString: String = {
s"($label,$features)"
}

private[spark] def toInstance(weight: Double): Instance = {
Instance(label, weight, features)
}

private[spark] def toInstance: Instance = {
Instance(label, 1.0, features)
}

}
@@ -23,9 +23,10 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
@@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType


/**
@@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

/** @group setParam */
@Since("3.0.0")
def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)

@@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("2.0.0")
def setVarianceCol(value: String): this.type = set(varianceCol, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val strategy = getOldStrategy(categoricalFeatures)

instr.logPipelineStage(this)
instr.logDataset(oldDataset)
instr.logDataset(instances)
instr.logParams(this, params: _*)

val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
instr.logDataset(data)
instr.logParams(this, params: _*)

val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val instances = data.map(_.toInstance)
val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))

trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
@@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
* <a href="http://en.wikipedia.org/wiki/Decision_tree_learning">
* Decision tree (Wikipedia)</a> model for regression.
* It supports both continuous and categorical features.
*
* @param rootNode Root of the decision tree
*/
@Since("1.4.0")
@@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (

/**
* Construct a decision tree regression model.
*
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.functions.{col, udf}

/**
* <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forest</a>
@@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)

val instances = extractLabeledPoints(dataset).map(_.toInstance)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logDataset(instances)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])

val numFeatures = trees.head.numFeatures