forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-19591][ML][MLLIB] Add sample weights to decision trees
This is updated PR apache#16722 to latest master ## What changes were proposed in this pull request? This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier. Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow up pr. ## How was this patch tested? The algorithms are tested to ensure that: 1. Arbitrary scaling of constant weights has no effect 2. Outliers with small weights do not affect the learned model 3. Oversampling and weighting are equivalent Unit tests are also added to test other smaller components. ## Summary of changes - Impurity aggregators now store weighted sufficient statistics. They also store a raw count, however, since this is needed to use minInstancesPerNode. - Impurity aggregators now also hold the raw count. - This patch maintains the meaning of minInstancesPerNode, in that the parameter still corresponds to raw, unweighted counts. It also adds a new parameter minWeightFractionPerNode which requires that nodes must contain at least minWeightFractionPerNode * weightedNumExamples total weight. - This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added. - TreePoint is modified to hold a sample weight - BaggedPoint is modified from: ``` Scala private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable ``` to ``` Scala private[spark] class BaggedPoint[Datum]( val datum: Datum, val subsampleCounts: Array[Int], val sampleWeight: Double) extends Serializable ``` We do not simply multiply the counts by the weight and store that because we need the raw counts and the weight in order to use both minInstancesPerNode and minWeightPerNode **Note**: many of the changed files are due simply to using Instance instead of LabeledPoint Closes apache#21632 from imatiach-msft/ilmat/sample-weights. Authored-by: Ilya Matiach <[email protected]> Signed-off-by: Sean Owen <[email protected]>
- Loading branch information
1 parent
3699763
commit b2d36f6
Showing
31 changed files
with
743 additions
and
280 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.