spark mllib源码分析之随机森林(Random Forest)
Spark在mllib中實現了tree相關的算法,決策樹DT(DecisionTree),隨機森林RF(RandomForest),GBDT(Gradient Boosting Decision Tree),其基礎都是RF,DT是RF一棵樹時的情況,而GBDT則是循環構建DT,GBDT與DT的代碼是非常簡單明了的,本文將分成五部分分別對RF的源碼進行分析,介紹spark在實現過程中使用的一些技巧。
1. 決策樹與隨機森林
首先對決策樹和隨機森林進行簡單的回顧。
1.1. 決策樹
在決策樹的訓練中,如上圖所示,就是從根節點開始,不斷的分裂,直到觸發截止條件,在節點的分裂過程中要解決的問題其實就2個
分裂點:一般就是遍歷所有特征的所有特征值,選取impurity最大的分成左右孩子節點,impurity的選取有信息熵(分類),最小均方差(回歸)等方法
預測值:一般取當前最多的class(分類)或者取均值(回歸)
1.2. 隨機森林
隨機森林就是構建多棵決策樹投票,在構建多棵樹過程中,引入隨機性,一般體現在兩個方面,一是每棵樹使用的樣本進行隨機抽樣,分為有放回和無放回抽樣。二是對每棵樹使用的特征集進行抽樣,使用部分特征訓練。?
在訓練過程中,如果單機內存能放下所有樣本,可以用多線程同時訓練多棵樹,樹之間的訓練互不影響。
2. spark RF優化策略
spark在實現RF時,使用了一些優化技巧,提高訓練效率。
2.1. 逐層訓練
當樣本量過大,單機無法容納時,只能采用分布式的訓練方法,數據是在集群中的多臺機器存放,如果按照單機的方法,每棵樹完全獨立訪問樣本數據,則樣本數據的訪問次數為數的個數k*每棵樹的節點數N,相當于深度遍歷。在spark的實現中,因為數據存放在不同的機器上,頻繁的訪問數據效率非常低,因此采用廣度遍歷的方法,每次構造所有樹的一層,例如如果要訓練10棵樹,第一次構造所有樹的第一層根節點,第二次構造所有深度為2的節點,以此類推,這樣訪問數據的次數降為樹的最大深度,大大減少了機器之間的通信,提高訓練效率。
2.2. 樣本抽樣
當樣本存在連續特征時,其可能的取值可能是無限的,存儲其可能出現的值占用較大空間,因此spark對樣本進行了抽樣,抽樣數量
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
最少抽樣1萬條,當然這樣會降低模型精度。
2.3. 特征裝箱
其實沒什么神秘的,每個離散特征值(對于連續特征,先離散化)稱為一個Split,上下限[lowSplit, highSplit]組成一個bin,也就是特征裝箱,默認的maxBins是32。對于連續特征,離散化時的bin的個數就是maxBins,采用等頻離散化;對于有序的離散特征,bin的個數是特征值個數+1;對于無序離散特征,bin的個數是2^(M-1)-1,M是特征值個數
3. 源碼分析
我們從官方給出的分類demo開始,逐層分析其實現
3.1. 訓練數據的解析
主要是LabelPoint的構造,官方demo中要求訓練數據是LibSVM格式的
parsed.map { case (label, indices, values) =>
? ? ? LabeledPoint(label, Vectors.sparse(d, indices, values))
? ? }
可以看到LabelPoint有兩個成員,第一個是樣本label,第二個是稀疏向量SparseVector,d是其size,在這里其實是特征數,indices是實際非0特征的index,values里面是實際的特征值,這里需要注意的是,SVN格式的特征index是從0開始的,這里進行了-1,變成從0開始了。
3.2. demo中訓練參數說明
官方demo中只設置了部分參數
val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
? ? ? numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
categoricalFeaturesInfo:Map[Int, Int],key是特征的index,value為特征值的個數(或者說幾種),這里值得注意的是,因為LabelPoint中進行了index-1的變換,這個里面的key也需要-1(參見后面metadata的numBins的計算)。例如性別這個特征在樣本中的index為1,特征值男/女兩種,則0->2
featureSubsetStrategy:特征子集的抽取方法,支持”auto”, “all”, “sqrt”, “log2”, “onethird”
impurity:不純度,其實就是節點分裂時的衡量準則,例如信息熵,均方差等,這里支持三種,gini(基尼指數),entripy(信息熵),variance(均方差)
maxDepth:樹的最大深度
maxBins:最大裝箱數,或者說是特征的最大可能切分數+1。這個值必須大于等于最大的離散特征值數
3.3. 參數封裝
spark根據用戶提供的參數值,進行實際訓練參數的計算,并且將這些參數封裝成類,方便傳遞。
3.3.1. Strategy
class Strategy @Since("1.3.0") (
? ? @Since("1.0.0") @BeanProperty var algo: Algo,
? ? @Since("1.0.0") @BeanProperty var impurity: Impurity,
? ? @Since("1.0.0") @BeanProperty var maxDepth: Int,
? ? @Since("1.2.0") @BeanProperty var numClasses: Int = 2,
? ? @Since("1.0.0") @BeanProperty var maxBins: Int = 32,
? ? @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
? ? @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
? ? @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
? ? @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
? ? @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
? ? @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
? ? @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
? ? @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10)
algo:classification/regression
quantileCalculationStrategy:分位點(Split)策略,目前只支持Sort,對于連續型特征值,先把特征值進行排序,然后按次序取分位點。從代碼中可以看到原來可能打算實現的MinMax和ApproxHist目前沒有實現。
minInstancesPerNode:每個樹節點中最小的樣本數,低于將不再對節點進行分裂,默認為1,可作為提前截止條件
minInfoGain:最小增益,節點分裂后的增益如果小于它,將不再進行分裂,可作為提前截止條件
subsamplingRate:樣本抽樣率,默認為1,每棵樹都使用全部樣本
isMulticlassClassification:是否是多分類,判斷條件為Classification 并且類別>2
isMulticlassWithCategoricalFeatures:是否是帶類別特征的多分類,判斷條件再上面的基礎上加categoricalFeaturesInfo的size大于0
3.3.2. metadata
在buildMetadata中根據strategy計算得到DecisionTreeMetadata的參數。
class DecisionTreeMetadata(
? ? val numFeatures: Int,
? ? val numExamples: Long,
? ? val numClasses: Int,
? ? val maxBins: Int,
? ? val featureArity: Map[Int, Int],
? ? val unorderedFeatures: Set[Int],
? ? val numBins: Array[Int],
? ? val impurity: Impurity,
? ? val quantileStrategy: QuantileStrategy,
? ? val maxDepth: Int,
? ? val minInstancesPerNode: Int,
? ? val minInfoGain: Double,
? ? val numTrees: Int,
? ? val numFeaturesPerNode: Int)
部分參數同Strategy,對額外參數和區別說明
numClasses:如為Regression,設為0
maxPossibleBins:取maxBins和樣本數量中較小的;必須大于categoricalFeaturesInfo中的最大的離散特征值數
numBins:所有特征及其特征值數,Int數組,維數是特征數,默認大小是maxPossibleBins。對于連續特征,其值就是默認值maxPossibleBins。對于離散特征,如為二分類或回歸,此處將categoricalFeaturesInfo中的key特征index作為數組index,value特征個數寫入數組中(這里有疑問,SVM格式的index是從1開始的,因此對numBins的index應該是categoricalFeaturesInfo的key-1,這里沒有-1,當最大值等于maxBins的時候訪問數組會拋異常);如果是多分類,先計算其當做當UnorderedFeature(無序的離散特征)的bin,如果個數小于等于maxPossibleBins,會被當成UnorderedFeature,否則被當成orderedFeatures(為了防止計算指數溢出,實際是把maxPossibleBins取log與特征數比較),因為UnorderedFeature的bin是比較大,這里限制了其特征值不能太多,這里僅僅根據特征值的特殊決定是否是ordered,不太好。每個split要將所有特征值分成2部分,bin的數量也就是2*split,因此bin的個數是2*(2^(M-1)-1)
numFeaturesPerNode:由featureSubsetStrategy決定,如果為“auto”,且為單棵樹,則使用全部特征;如為多棵樹,分類則是sqrt,回歸為1/3;也可以自己指定,支持”all”, “sqrt”, “log2”, “onethird”。?
如果僅對RF的使用感興趣,了解上述訓練參數也就可以了,后面的文章將對其訓練代碼進行分析。
?
4. 特征處理
這部分主要在DecisionTree.scala的findSplitsBins函數,將所有特征封裝成Split,然后裝箱Bin。首先對split和bin的結構進行說明
4.1. 數據結構
4.1.1. Split
class Split(
? ? @Since("1.0.0") feature: Int,
? ? @Since("1.0.0") threshold: Double,
? ? @Since("1.0.0") featureType: FeatureType,
? ? @Since("1.0.0") categories: List[Double])
feature:特征id
threshold:閾值
featureType:連續特征(Continuous)/離散特征(Categorical)
categories:離散特征值數組,離散特征使用。放著此split中所有特征值
4.1.2. Bin
class Bin(
? ? lowSplit: Split,?
? ? highSplit: Split,?
? ? featureType: FeatureType,?
? ? category: Double)
lowSplit/highSplit:上下界
featureType:連續特征(Continuous)/離散特征(Categorical)
category:離散特征的特征值
4.2. 連續特征處理
4.2.1. 抽樣
val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
val sampledInput = if (continuousFeatures.nonEmpty) {
? ? ? // Calculate the number of samples for approximate quantile calculation.
? ? ? val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
? ? ? val fraction = if (requiredSamples < metadata.numExamples) {
? ? ? ? requiredSamples.toDouble / metadata.numExamples
? ? ? } else {
? ? ? ? 1.0
? ? ? }
? ? ? logDebug("fraction of data used for calculating quantiles = " + fraction)
? ? ? input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
? ? } else {
? ? ? input.sparkContext.emptyRDD[LabeledPoint]
? ? }
首先篩選出連續特征集,然后計算抽樣數量,抽樣比例,然后無放回樣本抽樣;如果沒有連續特征,則為空RDD
4.2.2. 計算Split
metadata.quantileStrategy match {
? ? ? case Sort =>
? ? ? ? findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
? ? ? case MinMax =>
? ? ? ? throw new UnsupportedOperationException("minmax not supported yet.")
? ? ? case ApproxHist =>
? ? ? ? throw new UnsupportedOperationException("approximate histogram not supported yet.")
? ? }
分位點策略,這里只實現了Sort這一種,前文有說明,下面的計算在findSplitsBinsBySorting函數中,入參是抽樣樣本集,metadata和連續特征集(里面是特征id,從0開始,見LabelPoint的構造)
val continuousSplits = {
? ? // reduce the parallelism for split computations when there are less
? ? // continuous features than input partitions. this prevents tasks from
? ? // being spun up that will definitely do no work.
? ? val numPartitions = math.min(continuousFeatures.length,input.partitions.length)
? ? input.flatMap(point => continuousFeatures.map(idx => ?(idx,point.features(idx))))
? ? ? ? ?.groupByKey(numPartitions)
? ? ? ? ?.map { case (k, v) => findSplits(k, v) }
? ? ? ? ?.collectAsMap()
? ? }
特征id為key,value是樣本對應的該特征下的所有特征值,傳給findSplits函數,其中又調用了findSplitsForContinuousFeature函數獲得連續特征的Split,入參為樣本,metadata和特征id
def findSplitsForContinuousFeature(
? ? ? featureSamples: Array[Double],?
? ? ? metadata: DecisionTreeMetadata,
? ? ? featureIndex: Int): Array[Double] = {
? ? require(metadata.isContinuous(featureIndex),
? ? ? "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
? ? val splits = {
? ? //連續特征的split是numBins-1
? ? ? val numSplits = metadata.numSplits(featureIndex)
? ? //統計所有特征值其出現的次數
? ? ? // get count for each distinct value
? ? ? val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
? ? ? ? m + ((x, m.getOrElse(x, 0) + 1))
? ? ? }
? ? ? //按特征值排序
? ? ? // sort distinct values
? ? ? val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
? ? ? // if possible splits is not enough or just enough, just return all possible splits
? ? ? val possibleSplits = valueCounts.length
? ? ? if (possibleSplits <= numSplits) {
? ? ? ? valueCounts.map(_._1)
? ? ? } else {
? ? ? //等頻離散化
? ? ? ? // stride between splits
? ? ? ? val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
? ? ? ? logDebug("stride = " + stride)
? ? ? ? // iterate `valueCount` to find splits
? ? ? ? val splitsBuilder = Array.newBuilder[Double]
? ? ? ? var index = 1
? ? ? ? // currentCount: sum of counts of values that have been visited
? ? ? ? var currentCount = valueCounts(0)._2
? ? ? ? // targetCount: target value for `currentCount`.
? ? ? ? // If `currentCount` is closest value to `targetCount`,
? ? ? ? // then current value is a split threshold.
? ? ? ? // After finding a split threshold, `targetCount` is added by stride.
? ? ? ? var targetCount = stride
? ? ? ? while (index < valueCounts.length) {
? ? ? ? ? val previousCount = currentCount
? ? ? ? ? currentCount += valueCounts(index)._2
? ? ? ? ? val previousGap = math.abs(previousCount - targetCount)
? ? ? ? ? val currentGap = math.abs(currentCount - targetCount)
? ? ? ? ? // If adding count of current value to currentCount
? ? ? ? ? // makes the gap between currentCount and targetCount smaller,
? ? ? ? ? // previous value is a split threshold.
? ? ? ? ? //每次步進targetCount個樣本,取上一個特征值與下一個特征值gap較小的
? ? ? ? ? if (previousGap < currentGap) {
? ? ? ? ? ? splitsBuilder += valueCounts(index - 1)._1
? ? ? ? ? ? targetCount += stride
? ? ? ? ? }
? ? ? ? ? index += 1
? ? ? ? }
? ? ? ? splitsBuilder.result()
? ? ? }
? ? }
? ? // TODO: Do not fail; just ignore the useless feature.
? ? assert(splits.length > 0,
? ? ? s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
? ? ? ? " ?Please remove this feature and then try again.")
? ? // the split metadata must be updated on the driver
? ? splits
? }
在構造split的過程中,如果統計到的值的個數possibleSplits 還不如你設置的numSplits多,那么所有的值都作為分割點;否則,用等頻分隔法,首先計算分隔步長stride,然后再循環中每次累加到targetCount中,作為理想分割點,但是理想分割點可能會包含的特征值過多,想取一個里理想分割點盡量近的特征值,例如,理想分割點是100,落在特征值fcfc里,但是當前特征值里面有30個樣本,而前一個特征值fpfp只有5個樣本,因此我們如果取fcfc作為split,則當前區間實際多25個樣本,如果取fpfp,則少5個樣本,顯然取fpfp更為合理。?
具體到代碼實現,在if判斷里步進stride個樣本,累加在targetCount中。while循環逐次把每個特征值的個數加到currentCount里,計算前一次previousCount和這次currentCount到targetCount的距離,有3種情況,一種是pre和cur都在target左邊,肯定是cur小,繼續循環,進入第二種情況;第二種一左一右,如果pre小,肯定是pre是最好的分割點,如果cur還是小,繼續循環步進,進入第三種情況;第三種就是都在右邊,顯然是pre小。因此if的判斷條件pre<curpre<cur,只要滿足肯定就是split。整體下來的效果就能找到離target最近的一個特征值。?
findSplits函數使用本函數得到的離散化點作為threshold,構造Split
val splits = {
? ? val featureSplits = findSplitsForContinuousFeature(
? ? ? ? ? featureSamples.toArray,
? ? ? ? ? metadata,
? ? ? ? ? featureIndex)
? ? logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
? ? featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
}
這樣就得到了連續特征所有的Split?
4.2.3. 計算bin?
得到splits后,即可類似滑窗得到bin的上下界,構造bins
val bins = {
? ? val lowSplit = new DummyLowSplit(featureIndex, Continuous)
? ? val highSplit = new DummyHighSplit(featureIndex, Continuous)
? ? // tack the dummy splits on either side of the computed splits
? ? val allSplits = lowSplit +: splits.toSeq :+ highSplit
? ? // slide across the split points pairwise to allocate the bins
? ? allSplits.sliding(2).map {
? ? ? ? ?case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
? ? }.toArray
}
在計算splits的時候,個數是bin的個數減1,這里加上第一個DummyLowSplit(threshold是Double.MinValue),和最后一個DummyHighSplit(threshold是Double.MaxValue)構造的bin,恰好個數是numBins中的個數
4.3. 離散特征
bin的主要作用其實就是用來做連續特征離散化,離散特征是用不著的。?
對有序離散特征而言,其split直接用特征值表征,因此這里的splits和bins都是空的Array。?
對于無序離散特征而言,其split是特征值的組合,不是簡單的上下界比較關系,bin是空Array,而split需要計算。
4.3.1. split
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
? ? val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
? ? new Split(i, Double.MinValue, Categorical, categories)
}
featureArity來自參數categoricalFeaturesInfo中設置的離散特征的特征值數。?
metadata.numSplits是吧numBins中的數量/2,相當于返回了2^(M-1)-1,M是特征值數。?
調用extractMultiClassCategories函數,入參是1到2^(M-1)和特征數M。
/**
? ?* Nested method to extract list of eligible categories given an index. It extracts the
? ?* position of ones in a binary representation of the input. If binary
? ?* representation of an number is 01101 (13), the output list should (3.0, 2.0,
? ?* 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones.
? ?*/
def extractMultiClassCategories(
? ? ?input: Int,
? ? ?maxFeatureValue: Int): List[Double] = {
? ? var categories = List[Double]()
? ? var j = 0
? ? var bitShiftedInput = input
? ? while (j < maxFeatureValue) {
? ? ? if (bitShiftedInput % 2 != 0) {
? ? ? ? // updating the list of categories.
? ? ? ? categories = j.toDouble :: categories
? ? ? }
? ? ? // Right shift by one
? ? ? bitShiftedInput = bitShiftedInput >> 1
? ? ? j += 1
? ? }
? ? categories
}
如注釋所述,這個函數返回給定的input的二進制表示中1的index,這里實際返回的是特征的組合,之前文章介紹過的《組合數》。
5. 樣本處理
將輸入樣本LabelPoint與上述特征進一步封裝,方便后面進行分區統計。
5.1. TreePoint
構造TreePoint的過程,是一系列函數的調用鏈,我們逐層分析。
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
RandomForest.scala中將輸入轉化成TreePoint的rdd,調用convertToTreeRDD函數
def convertToTreeRDD(
? ? input: RDD[LabeledPoint],
? ? bins: Array[Array[Bin]],
? ? metadata: DecisionTreeMetadata): RDD[TreePoint] = {
? ? // Construct arrays for featureArity for efficiency in the inner loop.
? ? val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
? ? var featureIndex = 0
? ? while (featureIndex < metadata.numFeatures) {
? ? ? featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
? ? ? featureIndex += 1
? ? }
? ? input.map { x =>
? ? ? TreePoint.labeledPointToTreePoint(x, bins, featureArity)
? ? }
? }
convertToTreeRDD函數的入參input是所有樣本,bins是二維數組,第一維是特征,第二維是特征的Bin數組。函數首先計算每個特征的特征數量,放在featureArity中,如果是連續特征,設為0。對每個樣本調用labeledPointToTreePoint函數,構造TreePoint。
private def labeledPointToTreePoint(
? ? ? labeledPoint: LabeledPoint,
? ? ? bins: Array[Array[Bin]],
? ? ? featureArity: Array[Int]): TreePoint = {
? ? val numFeatures = labeledPoint.features.size
? ? val arr = new Array[Int](numFeatures)
? ? var featureIndex = 0
? ? while (featureIndex < numFeatures) {
? ? ? arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
? ? ? ? bins)
? ? ? featureIndex += 1
? ? }
? ? new TreePoint(labeledPoint.label, arr)
? }
labeledPointToTreePoint計算每個樣本的所有特征對應的特征值屬于哪個bin,放在在arr數組中;如果是連續特征,存放的實際是binIndex,或者說是第幾個bin;如果是離散特征,直接featureValue.toInt,這其實暗示著,對有序離散值,其編碼只能是[0,featureArity - 1],閉區間,其后的部分邏輯也依賴于這個假設。這部分是在findBin函數中完成的,這里不再贅述。?
我們在這里把TreePoint的成員再羅列一下,方便查閱
class TreePoint(val label: Double, val binnedFeatures: Array[Int])
這里是把每個樣本從LabelPoint轉換成TreePoint,label就是樣本label,binnedFeatures就是上述的arr數組。
5.2. BaggedPoint
同理構造BaggedPoint的過程,也是一系列函數的調用鏈,我們逐層分析。
val withReplacement = if (numTrees > 1) true else false
val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
? ? ? ? ? strategy.subsamplingRate, numTrees,
? ? ? ? ? withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
這里同時對樣本進行了抽樣,如果樹個數大于1,就有放回抽樣,否則無放回抽樣,調用convertToTreeRDD函數將TreePoint轉化成BaggedPoint的rdd
/**
? ?* Convert an input dataset into its BaggedPoint representation,
? ?* choosing subsamplingRate counts for each instance.
? ?* Each subsamplingRate has the same number of instances as the original dataset,
? ?* and is created by subsampling without replacement.
? ?* @param input Input dataset.
? ?* @param subsamplingRate Fraction of the training data used for learning decision tree.
? ?* @param numSubsamples Number of subsamples of this RDD to take.
? ?* @param withReplacement Sampling with/without replacement.
? ?* @param seed Random seed.
? ?* @return BaggedPoint dataset representation.
? ?*/
? def convertToBaggedRDD[Datum] (
? ? ? input: RDD[Datum],
? ? ? subsamplingRate: Double,
? ? ? numSubsamples: Int,
? ? ? withReplacement: Boolean,
? ? ? seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
? ? if (withReplacement) {
? ? ? convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
? ? } else {
? ? ? if (numSubsamples == 1 && subsamplingRate == 1.0) {
? ? ? ? convertToBaggedRDDWithoutSampling(input)
? ? ? } else {
? ? ? ? convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
? ? ? }
? ? }
? }
根據有放回還是無放回,或者不抽樣分別調用相應函數。無放回抽樣
def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
? ? ? input: RDD[Datum],
? ? ? subsamplingRate: Double,
? ? ? numSubsamples: Int,
? ? ? seed: Long): RDD[BaggedPoint[Datum]] = {
? ? //對每個partition獨立抽樣
? ? input.mapPartitionsWithIndex { (partitionIndex, instances) =>
? ? ? // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
? ? ? val rng = new XORShiftRandom
? ? ? rng.setSeed(seed + partitionIndex + 1)
? ? ? instances.map { instance =>
? ? ? //對每條樣本進行numSubsamples(實際是樹的個數)次抽樣,
? ? ? //一次將本條樣本在所有樹中是否會被抽取都獲得,犧牲空間減少訪問數據次數
? ? ? ? val subsampleWeights = new Array[Double](numSubsamples)
? ? ? ? var subsampleIndex = 0
? ? ? ? while (subsampleIndex < numSubsamples) {
? ? ? ? ? val x = rng.nextDouble()
? ? ? ? ? //無放回抽樣,只需要決定本樣本是否被抽取,被抽取就是1,沒有就是0
? ? ? ? ? subsampleWeights(subsampleIndex) = {
? ? ? ? ? ? if (x < subsamplingRate) 1.0 else 0.0
? ? ? ? ? }
? ? ? ? ? subsampleIndex += 1
? ? ? ? }
? ? ? ? new BaggedPoint(instance, subsampleWeights)
? ? ? }
? ? }
? }
有放回抽樣
def convertToBaggedRDDSamplingWithReplacement[Datum] (
? ? ? input: RDD[Datum],
? ? ? subsample: Double,
? ? ? numSubsamples: Int,
? ? ? seed: Long): RDD[BaggedPoint[Datum]] = {
? ? input.mapPartitionsWithIndex { (partitionIndex, instances) =>
? ? ? // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
? ? ? val poisson = new PoissonDistribution(subsample)
? ? ? poisson.reseedRandomGenerator(seed + partitionIndex + 1)
? ? ? instances.map { instance =>
? ? ? ? val subsampleWeights = new Array[Double](numSubsamples)
? ? ? ? var subsampleIndex = 0
? ? ? ? while (subsampleIndex < numSubsamples) {
? ? ? ? //與無放回抽樣對比,這里用泊松抽樣返回的是樣本被抽取的次數,
? ? ? ? //可能大于1,而無放回是0/1,也可認為是被抽取的次數
? ? ? ? ? subsampleWeights(subsampleIndex) = poisson.sample()
? ? ? ? ? subsampleIndex += 1
? ? ? ? }
? ? ? ? new BaggedPoint(instance, subsampleWeights)
? ? ? }
? ? }
? }
不抽樣,或者說抽樣率為1
def convertToBaggedRDDWithoutSampling[Datum] (
? ? ? input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
? ? input.map(datum => new BaggedPoint(datum, Array(1.0)))
? }
這里再啰嗦的羅列下BaggedPoint
class BaggedPoint[Datum](
? ? val datum: Datum,?
? ? val subsampleWeights: Array[Double])
datum是TreePoint,subsampleWeights是數組,維數等于numberTrees,每個值是樣本在每棵樹中被抽取的次數
至此,Random Forest的初始化工作已經完成
timer.stop("init")
?
6. 隨機森林訓練
6.1. 數據結構
6.1.1. Node
樹中的每個節點是一個Node結構
class Node @Since("1.2.0") (
? ? @Since("1.0.0") val id: Int,
? ? @Since("1.0.0") var predict: Predict,
? ? @Since("1.2.0") var impurity: Double,
? ? @Since("1.0.0") var isLeaf: Boolean,
? ? @Since("1.0.0") var split: Option[Split],
? ? @Since("1.0.0") var leftNode: Option[Node],
? ? @Since("1.0.0") var rightNode: Option[Node],
? ? @Since("1.0.0") var stats: Option[InformationGainStats])
emptyNode,只初始化nodeIndex,其他都是默認值
def emptyNode(nodeIndex: Int): Node =?
? ? new Node(nodeIndex, new Predict(Double.MinValue),
? ? -1.0, false, None, None, None, None)
根據node的id,計算孩子節點的id
? ?* Return the index of the left child of this node.
? ?*/
? def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
? /**
? ?* Return the index of the right child of this node.
? ?*/
? def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
左孩子節點就是當前id * 2,右孩子是id * 2+1。
6.1.2. Entropy
6.1.2.1. Entropy
Entropy是個Object,里面最重要的是calculate函數
/**
? ?* :: DeveloperApi ::
? ?* information calculation for multiclass classification
? ?* @param counts Array[Double] with counts for each label
? ?* @param totalCount sum of counts for all labels
? ?* @return information value, or 0 if totalCount = 0
? ?*/
? @Since("1.1.0")
? @DeveloperApi
? override def calculate(counts: Array[Double], totalCount: Double): Double = {
? ? if (totalCount == 0) {
? ? ? return 0
? ? }
? ? val numClasses = counts.length
? ? var impurity = 0.0
? ? var classIndex = 0
? ? while (classIndex < numClasses) {
? ? ? val classCount = counts(classIndex)
? ? ? if (classCount != 0) {
? ? ? ? val freq = classCount / totalCount
? ? ? ? impurity -= freq * log2(freq)
? ? ? }
? ? ? classIndex += 1
? ? }
? ? impurity
? }
熵的計算公式?
H=E[?logpi]=?∑i=1n?pilogpi
H=E[?logpi]=?∑i=1n?pilogpi
因此這里的入參count是各class的出現的次數,先計算出現概率,然后取log累加。
6.1.2.2. EntropyAggregator
class EntropyAggregator(numClasses: Int)
? extends ImpurityAggregator(numClasses)
只有一個成員變量class的個數,關鍵是update函數
/**
? ?* Update stats for one (node, feature, bin) with the given label.
? ?* @param allStats ?Flat stats array, with stats for this (node, feature, bin) contiguous.
? ?* @param offset ? ?Start index of stats for this (node, feature, bin).
? ?*/
? def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
? ? if (label >= statsSize) {
? ? ? throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
? ? ? ? s" but requires label < numClasses (= $statsSize).")
? ? }
? ? if (label < 0) {
? ? ? throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
? ? ? ? s"but requires label is non-negative.")
? ? }
? ? allStats(offset + label.toInt) += instanceWeight
? }
offset是特征值偏移,加上label就是該class在allStats里的位置,累加出現的次數
/**
? ?* Get an [[ImpurityCalculator]] for a (node, feature, bin).
? ?* @param allStats ?Flat stats array, with stats for this (node, feature, bin) contiguous.
? ?* @param offset ? ?Start index of stats for this (node, feature, bin).
? ?*/
? def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
? ? new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
? }
截取allStats中屬于該特征的split的部分數組,長度是statSize,也就是class數
6.1.2.3. EntropyCalculator
/**
? ?* Calculate the impurity from the stored sufficient statistics.
? ?*/
? def calculate(): Double = Entropy.calculate(stats, stats.sum)
結合上面的函數可以看到,計算entropy的路徑是調用Entropy的getCalculator函數,里面截取allStats中屬于該split的部分,然后實際調用Entropy的calculate函數計算熵。?
這里還重載了prob函數,主要是返回label的概率,例如0的統計有3個,1的統計7個,則label 0的概率就是0.3.
6.1.3. DTStatsAggregator
這里啰嗦下node分裂時需要怎樣統計,這與DTStatsAggregator的設計是相關的。以使用信息熵為例,node分裂時,迭代每個特征的每個split,這個split會把樣本集分成兩部分,要計算entropy,需要分別統計左/右部分class的分布情況,然后計算概率,進而計算entropy,因此aggregator中statsSize等于numberclasses,同時allStats里記錄了所有的統計值,實際這個統計值就是class的分布情況
class DTStatsAggregator(
? ? val metadata: DecisionTreeMetadata,
? ? featureSubset: Option[Array[Int]]) extends Serializable {
? /**
? ?* [[ImpurityAggregator]] instance specifying the impurity type.
? ?*/
? val impurityAggregator: ImpurityAggregator = metadata.impurity match {
? ? case Gini => new GiniAggregator(metadata.numClasses)
? ? case Entropy => new EntropyAggregator(metadata.numClasses)
? ? case Variance => new VarianceAggregator()
? ? case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
? }
? /**
? ?* Number of elements (Double values) used for the sufficient statistics of each bin.
? ?*/
? private val statsSize: Int = impurityAggregator.statsSize
? /**
? ?* Number of bins for each feature. ?This is indexed by the feature index.
? ?*/
? private val numBins: Array[Int] = {
? ? if (featureSubset.isDefined) {
? ? ? featureSubset.get.map(metadata.numBins(_))
? ? } else {
? ? ? metadata.numBins
? ? }
? }
? /**
? ?* Offset for each feature for calculating indices into the [[allStats]] array.
? ?*/
? private val featureOffsets: Array[Int] = {
? ? numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
? }
? /**
? ?* Total number of elements stored in this aggregator
? ?*/
? private val allStatsSize: Int = featureOffsets.last
? /**
? ?* Flat array of elements.
? ?* Index for start of stats for a (feature, bin) is:
? ?* ? index = featureOffsets(featureIndex) + binIndex * statsSize
? ?* Note: For unordered features,
? ?* ? ? ? the left child stats have binIndex in [0, numBins(featureIndex) / 2))
? ?* ? ? ? and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
? ?*/
? private val allStats: Array[Double] = new Array[Double](allStatsSize)
每個node有一個DTStatsAggregator,構造函數接受2個參數,metadata和node使用的特征子集。其他的類成員?
- impurityAggregator:目前支持Gini,Entropy和Variance,后面我們以Entropy為例,其他類似?
- statsSize:每個bin需要的統計數,分類時等于numClasses,因為于每個class都需要單獨統計;回歸等于3,分別存著特征值個數,特征值sum,特征值平方和,為計算variance?
- numBins:node所用特征對應的numBins數組元素?
- featureOffsets:計算特征在allStats中的index,與每個特征的bin個數和statsSize有關,例如我們有3個特征,其bins分別為3,2,2,statsSize為2,則第一個特征需要的bin的個數是3 * 2=6,2 * 2=4,2 * 2=4,則featureOffsets為0,6,10,14,是從左到右的累計值?
- allStatsSize:需要的桶的個數?
- allStats:存儲統計值的桶?
f0,f1,f2是3個特征,f0有3個特征值(其實是binIndex)0/1/2,f1有2個0/1,f2有2個0/1,每個特征值都有statsSize個狀態桶,因此共14個,個數allStatsSize=14, 比如我們想在f1的v1的c1的index,就是從featureOffsets中取得f1的特征偏移量featureOffsets(1)=6,v1的binIndex相當于是1,statsSize是2,其label是1,則桶的index=6+1*2+1=9,恰好是圖中f1v1的c1的桶的index
我們對其中的關鍵函數進行說明
/**
? ?* Update the stats for a given (feature, bin) for ordered features, using the given label.
? ?*/
? def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
? //第一部分是特征偏移
? //binIndex相當于特征內特征值的偏移,每個特征有statsSize個桶,因此兩者相加就是這個特征值對應的桶
? //例如Entropy的update函數,里面再加上label.toInt就是這個label的桶
? //從這里特征偏移的計算可以看出ordered特征其特征值最好是連續的,中間無間斷,并且必須從0開始
? //當然如果有間斷,這里相當于浪費部分空間
? ? val i = featureOffsets(featureIndex) + binIndex * statsSize
? ? impurityAggregator.update(allStats, i, label, instanceWeight)
? }
? /**
? ?* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
? ?* @param featureOffset ?For ordered features, this is a pre-computed (node, feature) offset
? ?* ? ? ? ? ? ? ? ? ? ? ? ? ? from [[getFeatureOffset]].
? ?* ? ? ? ? ? ? ? ? ? ? ? ? ? For unordered features, this is a pre-computed
? ?* ? ? ? ? ? ? ? ? ? ? ? ? ? (node, feature, left/right child) offset from
? ?* ? ? ? ? ? ? ? ? ? ? ? ? ? [[getLeftRightFeatureOffsets]].
? ?*/
? def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
? //偏移的計算同上,不過這里特征偏移是入參給出的,不需要再計算
? ? impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
? }
6.2. 訓練初始化
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
? ? Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
構造了numTrees個Node,賦默認值emptyNode,這些node將作為每棵樹的root node,參與后面的訓練。將這些node與treeIndex封裝加入到隊列nodeQueue中,后面會將所有待split的node都加入到這個隊列中,依次split,直到所有node觸發截止條件,也就是后面的while循環中隊列為空了。
6.3. 選擇待分裂node
這部分邏輯在selectNodesToSplit中,主要是從nodeQueue中取出本輪需要分裂的node,并計算node的參數。
/**
? ?* Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
? ?* This tracks the memory usage for aggregates and stops adding nodes when too much memory
? ?* will be needed; this allows an adaptive number of nodes since different nodes may require
? ?* different amounts of memory (if featureSubsetStrategy is not "all").
? ?*
? ?* @param nodeQueue ?Queue of nodes to split.
? ?* @param maxMemoryUsage ?Bound on size of aggregate statistics.
? ?* @return ?(nodesForGroup, treeToNodeToIndexInfo).
? ?* ? ? ? ? ?nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
? ?*
? ?* ? ? ? ? ?treeToNodeToIndexInfo holds indices selected features for each node:
? ?* ? ? ? ? ? ?treeIndex --> (global) node index --> (node index in group, feature indices).
? ?* ? ? ? ? ?The (global) node index is the index in the tree; the node index in group is the
? ?* ? ? ? ? ? index in [0, numNodesInGroup) of the node in this group.
? ?* ? ? ? ? ?The feature indices are None if not subsampling features.
? ?*/
? private[tree] def selectNodesToSplit(
? ? ? nodeQueue: mutable.Queue[(Int, Node)],
? ? ? maxMemoryUsage: Long,
? ? ? metadata: DecisionTreeMetadata,
? ? ? rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
? ? // Collect some nodes to split:
? ? // ?nodesForGroup(treeIndex) = nodes to split
? ? val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
? ? val mutableTreeToNodeToIndexInfo =
? ? ? new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
? ? var memUsage: Long = 0L
? ? var numNodesInGroup = 0
? ? while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
? ? ? val (treeIndex, node) = nodeQueue.head
? ? ? //用蓄水池抽樣(之前的文章有介紹)對node使用的特征集抽樣
? ? ? // Choose subset of features for node (if subsampling).
? ? ? val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
? ? ? ? Some(SamplingUtils.reservoirSampleAndCount(Range(0,
? ? ? ? ? metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
? ? ? } else {
? ? ? ? None
? ? ? }
? ? ? // Check if enough memory remains to add this node to the group.
? ? ? val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
? ? ? if (memUsage + nodeMemUsage <= maxMemoryUsage) {
? ? ? ? nodeQueue.dequeue()
? ? ? ? mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
? ? ? ? mutableTreeToNodeToIndexInfo
? ? ? ? ? .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
? ? ? ? ? = new NodeIndexInfo(numNodesInGroup, featureSubset)
? ? ? }
? ? ? numNodesInGroup += 1
? ? ? memUsage += nodeMemUsage
? ? }
? ? // Convert mutable maps to immutable ones.
? ? val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
? ? val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
? ? (nodesForGroup, treeToNodeToIndexInfo)
? }
代碼比較簡單明確,受限于內存,將本次能夠處理的node從nodeQueue中取出,放入nodesForGroup和treeToNodeToIndexInfo中。?
是否對特征集進行抽樣的條件是metadata的 numFeatures是否等于numFeaturesPerNode,這兩個參數是metadata的入參,在buildMetadata時,根據featureSubsetStrateg確定,參見前文。?
nodesForGroup是Map[Int, Array[Node]],其key是treeIndex,value是Node數組,其中放著該tree本次要分裂的node。?
treeToNodeToIndexInfo的類型是Map[Int, Map[Int, NodeIndexInfo]],key為treeIndex,value中Map的key是node.id,這個id來自Node初始化時的第一個參數,第一輪時node的id都是1。其value為NodeIndexInfo結構,
class NodeIndexInfo(
? ? ? val nodeIndexInGroup: Int,
? ? ? val featureSubset: Option[Array[Int]])
第一個成員是此node在本次node選擇的while循環中的index,稱為groupIndex,第二個成員是特征子集。
?
6.4. node分裂
邏輯主要在DecisionTree.findBestSplits函數中,是RF訓練最核心的部分
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
? ? ? ? treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
6.4.1. 數據統計
數據統計分成兩部分,先在各個partition上分別統計,再累積各partition成全局統計。
6.4.1.1. 取出node的特征子集
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
取出各node的特征子集,如果不需要抽樣則為None;否則返回Map[Int, Array[Int]],其實就是將之前treeToNodeToIndexInfo中的NodeIndexInfo轉換為map結構,將其作為廣播變量nodeToFeaturesBc。
6.4.1.2. 分區統計
一系列函數的調用鏈,我們逐層分析
val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
? ? ? input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
? ? ? ? // Construct a nodeStatsAggregators array to hold node aggregate stats,
? ? ? ? // each node will have a nodeStatsAggregator
? ? ? ? val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
? ? ? ? ? val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
? ? ? ? ? ? Some(nodeToFeatures(nodeIndex))
? ? ? ? ? }
? ? ? ? ? new DTStatsAggregator(metadata, featuresForNode)
? ? ? ? }
? ? ? ? // iterator all instances in current partition and update aggregate stats
? ? ? ? points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
? ? ? ? // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
? ? ? ? // which can be combined with other partition using `reduceByKey`
? ? ? ? nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
? ? ? }
? ? } else {
? ? ? input.mapPartitions { points =>
? ? ? ? // Construct a nodeStatsAggregators array to hold node aggregate stats,
? ? ? ? // each node will have a nodeStatsAggregator
? ? ? ? val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
? ? ? ? ? val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
? ? ? ? ? ? Some(nodeToFeatures(nodeIndex))
? ? ? ? ? }
? ? ? ? ? new DTStatsAggregator(metadata, featuresForNode)
? ? ? ? }
? ? ? ? // iterator all instances in current partition and update aggregate stats
? ? ? ? points.foreach(binSeqOp(nodeStatsAggregators, _))
? ? ? ? // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
? ? ? ? // which can be combined with other partition using `reduceByKey`
? ? ? ? nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
? ? ? }
? ? }
首先對每個partition構造一個DTStatsAggregator數組,長度是node的個數,注意這里實際使用的是數組,node怎樣與自己的aggregator的對應?前面我們提到NodeIndexInfo的第一個成員是groupIndex,其值就是node的次序,和這里aggregator數組index其實是對應的,也就是說可以從NodeIndexInfo中取得groupIndex,然后作為數組index取得對應node的agg。DTStatsAggregator的入參是metadata和每個node的特征子集。然后將每個點統計到DTStatsAggregator中,其中調用了內部函數binSeqOp,
?/**
? ? ?* Performs a sequential aggregation over a partition.
? ? ?*
? ? ?* Each data point contributes to one node. For each feature,
? ? ?* the aggregate sufficient statistics are updated for the relevant bins.
? ? ?*
? ? ?* @param agg ?Array storing aggregate calculation, with a set of sufficient statistics for
? ? ?* ? ? ? ? ? ? each (node, feature, bin).
? ? ?* @param baggedPoint ? Data point being aggregated.
? ? ?* @return ?agg
? ? ?*/
? ? def binSeqOp(
? ? ? ? agg: Array[DTStatsAggregator],
? ? ? ? baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
? ? //對每個node
? ? ? treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
? ? ? ? val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
? ? ? ? ? bins, metadata.unorderedFeatures)
? ? ? ? nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
? ? ? }
? ? ? agg
? ? }
首先調用函數predictNodeIndex計算nodeIndex,如果是首輪或者葉子節點,直接返回node.id;如果不是首輪,因為傳入的是每棵樹的root node,就從root node開始,逐漸往下判斷該point應該是屬于哪個node的,因為我們已經對node進行了分裂,這里其實實現了樣本的劃分。舉個栗子,當前node如果是root的左孩子節點,而point預測節點應該屬于右孩子,則調用nodeBinSepOp時就直接返回了,不會將這個point統計進去,用不大的時間換取樣本集劃分的空間,還是比較巧妙的。
/**
? ?* Get the node index corresponding to this data point.
? ?* This function mimics prediction, passing an example from the root node down to a leaf
? ?* or unsplit node; that node's index is returned.
? ?*
? ?* @param node ?Node in tree from which to classify the given data point.
? ?* @param binnedFeatures ?Binned feature vector for data point.
? ?* @param bins possible bins for all features, indexed (numFeatures)(numBins)
? ?* @param unorderedFeatures ?Set of indices of unordered features.
? ?* @return ?Leaf index if the data point reaches a leaf.
? ?* ? ? ? ? ?Otherwise, last node reachable in tree matching this example.
? ?* ? ? ? ? ?Note: This is the global node index, i.e., the index used in the tree.
? ?* ? ? ? ? ? ? ? ?This index is different from the index used during training a particular
? ?* ? ? ? ? ? ? ? ?group of nodes on one call to [[findBestSplits()]].
? ?*/
? private def predictNodeIndex(
? ? ? node: Node,
? ? ? binnedFeatures: Array[Int],
? ? ? bins: Array[Array[Bin]],
? ? ? unorderedFeatures: Set[Int]): Int = {
? ? if (node.isLeaf || node.split.isEmpty) {
? ? ? // Node is either leaf, or has not yet been split.
? ? ? node.id
? ? } else {
? ? //判斷point屬于當前node的左孩子還是右孩子
? ? ? val featureIndex = node.split.get.feature
? ? ? val splitLeft = node.split.get.featureType match {
? ? ? ? case Continuous => {
? ? ? ? ? val binIndex = binnedFeatures(featureIndex)
? ? ? ? ? val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
? ? ? ? ? // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
? ? ? ? ? // We do not need to check lowSplit since bins are separated by splits.
? ? ? ? ? featureValueUpperBound <= node.split.get.threshold
? ? ? ? }
? ? ? ? case Categorical => {
? ? ? ? ? val featureValue = binnedFeatures(featureIndex)
? ? ? ? ? node.split.get.categories.contains(featureValue)
? ? ? ? }
? ? ? ? case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
? ? ? }
? ? ? if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
? ? ? //下面還有完整的左右孩子node,遞歸判斷
? ? ? ? // Return index from next layer of nodes to train
? ? ? ? if (splitLeft) {
? ? ? ? ? Node.leftChildIndex(node.id)
? ? ? ? } else {
? ? ? ? ? Node.rightChildIndex(node.id)
? ? ? ? }
? ? ? } else {
? ? ? ? if (splitLeft) {
? ? ? ? ? predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
? ? ? ? } else {
? ? ? ? ? predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
? ? ? ? }
? ? ? }
? ? }
? }
然后調用nodeBinSeqOp函數
/**
? ? ?* Performs a sequential aggregation over a partition for a particular tree and node.
? ? ?*
? ? ?* For each feature, the aggregate sufficient statistics are updated for the relevant
? ? ?* bins.
? ? ?*
? ? ?* @param treeIndex Index of the tree that we want to perform aggregation for.
? ? ?* @param nodeInfo The node info for the tree node.
? ? ?* @param agg Array storing aggregate calculation, with a set of sufficient statistics
? ? ?* ? ? ? ? ? ?for each (node, feature, bin).
? ? ?* @param baggedPoint Data point being aggregated.
? ? ?*/
? ? def nodeBinSeqOp(
? ? ? ? treeIndex: Int,
? ? ? ? nodeInfo: RandomForest.NodeIndexInfo,
? ? ? ? agg: Array[DTStatsAggregator],
? ? ? ? baggedPoint: BaggedPoint[TreePoint]): Unit = {
? ? ? if (nodeInfo != null) {
? ? ? //node的groupIndex,見前文
? ? ? ? val aggNodeIndex = nodeInfo.nodeIndexInGroup
? ? ? ? //node使用的特征子集
? ? ? ? val featuresForNode = nodeInfo.featureSubset
? ? ? ? //取樣本在這棵樹中出現的次數 0/1/k
? ? ? ? val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
? ? ? ? if (metadata.unorderedFeatures.isEmpty) {
? ? ? ? ? orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
? ? ? ? } else {
? ? ? ? ? mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
? ? ? ? ? ? metadata.unorderedFeatures, instanceWeight, featuresForNode)
? ? ? ? }
? ? ? }
? ? }
函數的入參是treeIndex,該node的NodeIndexInfo結構,所有node的累加器數組,樣本。本函數是針對單個node的操作,這里可以看到取node對應的aggregator就是通過NodeIndexInfo的第一個成員nodeIndexInGroup作為agg數組的index。?
如果不包含無序特征,調用orderedBinSeqOp函數
?/**
? ?* Helper for binSeqOp, for regression and for classification with only ordered features.
? ?*
? ?* For each feature, the sufficient statistics of one bin are updated.
? ?*
? ?* @param agg ?Array storing aggregate calculation, with a set of sufficient statistics for
? ?* ? ? ? ? ? ? each (feature, bin).
? ?* @param treePoint ?Data point being aggregated.
? ?* @param instanceWeight ?Weight (importance) of instance in dataset.
? ?*/
? private def orderedBinSeqOp(
? ? ? agg: DTStatsAggregator, //node的agg
? ? ? treePoint: TreePoint,
? ? ? instanceWeight: Double,
? ? ? featuresForNode: Option[Array[Int]]): Unit = {
? ? val label = treePoint.label
? ? // Iterate over features.
? ? if (featuresForNode.nonEmpty) {
? ? ? // Use subsampled features
? ? ? var featureIndexIdx = 0
? ? ? while (featureIndexIdx < featuresForNode.get.size) {
? ? ? //連續特征:離散化后的index
? ? ? //離散特征:featureValue.toInt
? ? ? ? val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
? ? ? ? agg.update(featureIndexIdx, binIndex, label, instanceWeight)
? ? ? ? featureIndexIdx += 1
? ? ? }
? ? } else {
? ? ? // Use all features
? ? ? val numFeatures = agg.metadata.numFeatures
? ? ? var featureIndex = 0
? ? ? while (featureIndex < numFeatures) {
? ? ? ? val binIndex = treePoint.binnedFeatures(featureIndex)
? ? ? ? agg.update(featureIndex, binIndex, label, instanceWeight)
? ? ? ? featureIndex += 1
? ? ? }
? ? }
? }
函數中區分了是否使用了全部特征,區別僅在于如果使用了部分特征(特征抽樣),需要先在featuresForNode中取得特征的實際index。?
函數其實就是取出樣本的使用特征,特征值,label和weight,更新到aggregator中,更新邏輯我們在前文已經說明過了。?
包含了無序離散特征,則使用mixedBinSeqOp,只有無序離散特征處理方法不同于orderedBinSeqOp函數
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
//找到特征值對應的allStats中的范圍
//特征起始位置從featureOffsets中取得,長度是bins的個數乘以分類個數,2*(2^(M-1)-1)*statsSize,
//每一個split將樣本集分成2部分,allStats中左邊部分連續存放,右半部分連續存放,而不是左右一起存放。
//因此,左邊的起始位置直接可以從featureOffsets中獲取,右邊起始位置是(2^(M-1)-1)*statsSize
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
? ? //split中的categories中包含左半邊特征值組合,splitIndex相當于其離散化后的特征index
? ? if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
? ? agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
? ? ? ? ? ? ? instanceWeight)
? ? } else {
? ? ? ? agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
? ? ? ? ? ? ? instanceWeight)
? ? }
? ? ? ? ? splitIndex += 1
}
6.4.1.3. 全局統計
partitionAggregates.reduceByKey((a, b) => a.merge(b))
1
就是將所有存在allStats中的分區統計結果逐個對應相加得到全局統計結果。
6.4.2. bestSplits
獲得所有的統計后,就可以遍歷所有的特征,計算impurity gain,確定最佳的split。
val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
? ? .map { case (nodeIndex, aggStats) =>
? ? ? ? ?val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
? ? ? ? nodeToFeatures(nodeIndex)
? ? ? ? ?}
? ? ? ? ?// find best split for each node
? ? ? ? val (split: Split, stats: InformationGainStats, predict: Predict) =
? ? ? ? ? ? binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
? ? ? ? ? (nodeIndex, (split, stats, predict))
? ? ? ? }.collectAsMap()
對每個node其中調用了binsToBestSplit函數,下面進行詳細說明。
6.4.2.1 init
函數首先獲取node在樹的第幾層,樹結構如圖?
?
樹的id如圖所示,判斷node在第幾層只需要判斷id的二進制表示的最高位的1在第幾位即可,比如6的二進制表示是110,最高位的1是在第3位,則其在第3層。?
然后獲取當前node的預測值和impurity
// calculate predict and impurity if current node is top node
? ? val level = Node.indexToLevel(node.id)
? ? var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
? ? ? None
? ? } else {
? ? ? Some((node.predict, node.impurity))
? ? }
6.4.2.2 連續特征
對于連續特征而言,當取其某個特征值為best split后,node的樣本會被分成大于該特征值和小于等于該特征值兩部分,需要分別統計兩部分的class分布情況;另一方面,我們查找best,因此要遍歷所有特征值的情況,一種巧妙的方法是,從左邊開始逐次累積統計數據,需要從某個特征值作為split時,當前累計值就是左邊小于等于的情況,用最右的值減去左邊就是右邊的情況。?
例如上圖中的情況,是某特征6個特征值分布情況,第一行是左累計,第二行是原始分布,當以v2作為split時,左邊的分布就是c0:8,c1:5,右邊是v6的分布減去v2,c0:19-8=11,c1:14-5=9。
if (binAggregates.metadata.isContinuous(featureIndex)) {
? ? ? ? // Cumulative sum (scanLeft) of bin statistics.
? ? ? ? // Afterwards, binAggregates for a bin is the sum of aggregates for
? ? ? ? // that bin + all preceding bins.
? ? ? ? //如上所述,累計
? ? ? ? val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
? ? ? ? var splitIndex = 0
? ? ? ? while (splitIndex < numSplits) {
? ? ? ? ? ? binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
? ? ? ? ? ? splitIndex += 1
? ? ? ? }
? ? ? ? // Find best split.
? ? ? ? val (bestFeatureSplitIndex, bestFeatureGainStats) =
? ? ? ? ? ? Range(0, numSplits).map { case splitIdx =>
? ? ? ? ? ? ? val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
? ? ? ? ? ? ? val rightChildStats =
? ? ? ? ? ? ? ? binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
? ? ? ? ? ? ? rightChildStats.subtract(leftChildStats)
? ? ? ? ? ? ? //獲得node的impurity,level==0時,需要根據當前class的分布計算
? ? ? ? ? ? ? predictWithImpurity = Some(predictWithImpurity.getOrElse(
? ? ? ? ? ? ? ? calculatePredictImpurity(leftChildStats, rightChildStats)))
? ? ? ? ? ? ? val gainStats = calculateGainForSplit(leftChildStats,
? ? ? ? ? ? ? ? rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
? ? ? ? ? ? ? (splitIdx, gainStats)
? ? ? ? ? ? }.maxBy(_._2.gain)
? ? ? ? (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
計算split分裂node的impurity增益時,調用了calculateGainForSplit函數,其中分別計算了左右的增益,然后概率合并,并計算了左右的預測值,代碼比較簡單,這里不再贅述。
6.4.2.3. Unordered categorical feature
只有獲取左右class的統計情況方法不一致,其他是一樣的。
6.4.2.4. Ordered categorical feature
對于連續特征,特征值或者是binIndex是有序的,或者說其數值可以排序,因此如果某個特征值被當做split,分隔的就是左右兩部分;對于無序離散特征,其被split分隔后特征值屬于哪個bin是確定的;對于有序離散特征,其特征值代表一定次序關系,但是不具有絕對大小的含義,其處理方法可以近似按照連續特征的方法處理,但是spark這里處理了下,可能更優點。?
spark首先會確定一個centroid,然后特征會按這個排序,這個相當于連續特征的binIndex。例如centroid如果取每個特征值中class1的個數,假設有特征值0,1,2,3,class1的個數分別為4,2,1,3,其中如果按照連續特征的處理方法,假設用1作為node的分裂點,計算impurity gain的時候分成0,1和2,3兩部分統計。如果按照centroid的方法,其特征值排序次序應該是2,1,3,0,以1作為分裂點,會被分成2,1和3,0兩部分。
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
* splits are considered. ?(With K categories, we consider K - 1 possible splits.)
*
* centroidForCategories is a list: (category, centroid)
*/
val centroidForCategories = Range(0, numBins).map { case featureValue =>
? ? val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
? ? val centroid = if (categoryStats.count != 0) {
? ? ? ? if (binAggregates.metadata.isMulticlass) {
? ? ? ? // For categorical variables in multiclass classification,
? ? ? ? // the bins are ordered by the impurity of their corresponding labels.
? ? ? ? ? ? categoryStats.calculate()
? ? ? ? } else if (binAggregates.metadata.isClassification) {
? ? ? ? // For categorical variables in binary classification,
? ? ? ? // the bins are ordered by the count of class 1.
? ? ? ? ? ? categoryStats.stats(1)
? ? ? ? } else {
? ? ? ? ? ? // For categorical variables in regression,
? ? ? ? ? ? // the bins are ordered by the prediction.
? ? ? ? ? ? categoryStats.predict
? ? ? ? }
? ? } else {
? ? ? ? Double.MaxValue
? ? }
? ? (featureValue, centroid)
}
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
上面的代碼為不同的情況設置不同的centroid的選取方法。如果是多分類,使用impurity;如果是二分類,使用class1的個數;如果是回歸,使用預測值(實際是均值)。然后將特征值按centroid重排序。?
下面的處理基本與連續特征類似,先按排序次序累計,然后計算左右的impurity,計算impurity gain。由于要返回split,之前離散特征的split返回的是空Array,這里構造了split,第四個參數中加入了實際的特征值,類比unordered的情況。
計算完完所有的特征的gain,就可以選取最大增益時的split,最后collectAsMap,key是nodeIndex,value是split, InfomationGainStats,predict的三元組。
6.4.3. node分裂
計算完節點的best split,就要根據這個split進行node的分裂,包括當前節點的一些屬性完善,左右孩子節點的構造等。
// Iterate over all nodes in this group.
? ? nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
? ? ? nodesForTree.foreach { node =>
? ? ? ? val nodeIndex = node.id
? ? ? ? val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
? ? ? ? val aggNodeIndex = nodeInfo.nodeIndexInGroup
? ? ? ? //從剛剛計算的best split中獲取相關數據
? ? ? ? val (split: Split, stats: InformationGainStats, predict: Predict) =
? ? ? ? ? nodeToBestSplits(aggNodeIndex)
? ? ? ? logDebug("best split = " + split)
? ? ? ? // Extract info for this node. ?Create children if not leaf.
? ? ? ? //截止條件
? ? ? ? val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
? ? ? ? assert(node.id == nodeIndex)
? ? ? ? node.predict = predict
? ? ? ? node.isLeaf = isLeaf
? ? ? ? node.stats = Some(stats)
? ? ? ? node.impurity = stats.impurity
? ? ? ? logDebug("Node = " + node)
? ? ? ? //如果不是葉子節點,需要構造左右孩子節點
? ? ? ? if (!isLeaf) {
? ? ? ? ? node.split = Some(split)
? ? ? ? ? //葉子節點的depth,當前level+1
? ? ? ? ? val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
? ? ? ? ? //左右孩子節點是否是葉子節點
? ? ? ? ? val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
? ? ? ? ? val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
? ? ? ? ? //構造左右孩子節點
? ? ? ? ? node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
? ? ? ? ? ? stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
? ? ? ? ? node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
? ? ? ? ? ? stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
? ? ? ? ? if (nodeIdCache.nonEmpty) {
? ? ? ? ? ? val nodeIndexUpdater = NodeIndexUpdater(
? ? ? ? ? ? ? split = split,
? ? ? ? ? ? ? nodeIndex = nodeIndex)
? ? ? ? ? ? nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
? ? ? ? ? }
? ? ? ? //如果不是葉子節點,加入到nodeQueue待分裂隊列中
? ? ? ? ? // enqueue left child and right child if they are not leaves
? ? ? ? ? if (!leftChildIsLeaf) {
? ? ? ? ? ? nodeQueue.enqueue((treeIndex, node.leftNode.get))
? ? ? ? ? }
? ? ? ? ? if (!rightChildIsLeaf) {
? ? ? ? ? ? nodeQueue.enqueue((treeIndex, node.rightNode.get))
? ? ? ? ? }
? ? ? ? ? logDebug("leftChildIndex = " + node.leftNode.get.id +
? ? ? ? ? ? ", impurity = " + stats.leftImpurity)
? ? ? ? ? logDebug("rightChildIndex = " + node.rightNode.get.id +
? ? ? ? ? ? ", impurity = " + stats.rightImpurity)
? ? ? ? }
? ? ? }
? ? }
這里將當前節點的左右孩子節點繼續加入nodeQueue中,這里面放的是需要繼續分裂的節點,至此本輪的findBestSplits就完成了。
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput,
? ? metadata, topNodes, nodesForGroup,
? ? treeToNodeToIndexInfo, splits, bins, nodeQueue,
? ? timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
6.5. 循環訓練
上節我們說到最后待分裂的節點會加入到nodeQueue中,回到RandomForest.run函數中
while (nodeQueue.nonEmpty) {
? ? ? // Collect some nodes to split, and choose features for each node (if subsampling).
? ? ? // Each group of nodes may come from one or multiple trees, and at multiple levels.
? ? ? val (nodesForGroup, treeToNodeToIndexInfo) =
? ? ? ? RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
? ? ? // Sanity check (should never occur):
? ? ? assert(nodesForGroup.size > 0,
? ? ? ? s"RandomForest selected empty nodesForGroup. ?Error for unknown reason.")
? ? ? // Choose node splits, and enqueue new nodes as needed.
? ? ? timer.start("findBestSplits")
? ? ? DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
? ? ? ? treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
? ? ? timer.stop("findBestSplits")
? ? }
當有非葉子節點不斷加入nodeQueue中,這里不斷分裂出節點,直到所有節點觸發截止條件。
?
7. 構造隨機森林
在上面的訓練過程可以看到,從根節點topNode中不斷向下分裂一直到觸發截止條件就構造了一棵樹所有的node,因此構造整個森林也是非常簡單
//構造
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
//返回rf模型
new RandomForestModel(strategy.algo, trees)
8. 隨機森林模型
8.1. TreeEnsembleModel
隨機森林RandomForestModel繼承自樹集合模型TreeEnsembleModel
class TreeEnsembleModel(
? ? protected val algo: Algo,
? ? protected val trees: Array[DecisionTreeModel],
? ? protected val treeWeights: Array[Double],
? ? protected val combiningStrategy: EnsembleCombiningStrategy)
algo:Regression/Classification
trees:樹數組
treeWeights:每棵樹的權重,在RF中每棵樹的權重是相同的,在Adaboost可能是不同的
combiningStrategy:樹合并時的策略,Sum/Average/Vote,分類的話應該是Vote,RF應該是Average,GBDT應該是Sum。
sumWeights:成員變量,不在參數表中,是treeWeights的sum
預測函數
/**
? ?* Predicts for a single data point using the weighted sum of ensemble predictions.
? ?*
? ?* @param features array representing a single data point
? ?* @return predicted category from the trained model
? ?*/
? private def predictBySumming(features: Vector): Double = {
? ? val treePredictions = trees.map(_.predict(features))
? ? blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
? }
將每棵樹的預測結果與各自的weight向量相乘
/**
? ?* Classifies a single data point based on (weighted) majority votes.
? ?*/
? private def predictByVoting(features: Vector): Double = {
? ? val votes = mutable.Map.empty[Int, Double]
? ? trees.view.zip(treeWeights).foreach { case (tree, weight) =>
? ? ? val prediction = tree.predict(features).toInt
? ? ? votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
? ? }
? ? votes.maxBy(_._2)._1
? }
將每棵樹的預測class為key,將樹的weight累加到Map中作為value,最后取權重和最大對應的class
8.2. RandomForestModel
RandomForestModel @Since("1.2.0") (
? ? @Since("1.2.0") override val algo: Algo,
? ? @Since("1.2.0") override val trees: Array[DecisionTreeModel])
? extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
? ? combiningStrategy = if (algo == Classification) Vote else Average)
對于隨機森林,其weight都是1,樹合并策略如果是分類就是Vote,回歸是Average。?
模型生成后,如果要應用到線上,需要將訓練后的模型保存下來,自己寫代碼解析模型文件,進行預測,因此要了解模型的保存和加載。
8.2.1. 模型保存
分為兩部分,第一部分是metadata,保存了一些配置,包括模型名,模型版本,模型的algo是classification/regression,合并策略,每棵樹的權重。
implicit val format = DefaultFormats
val ensembleMetadata = Metadata(model.algo.toString,
? ? model.trees(0).algo.toString,
? ? model.combiningStrategy.toString,?
? ? model.treeWeights)
val metadata = compact(render(
? ? ("class" -> className) ~ ("version" -> thisFormatVersion) ~
? ? ("metadata" -> Extraction.decompose(ensembleMetadata))))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
第二部分是隨機森林的每棵樹的保存
// Create Parquet data.
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
? ? tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
}.toDF()
dataRDD.write.parquet(Loader.dataPath(path))
其中首先調用node的subtreeIterator函數,返回所有node的Iterator,然后轉成DataFrame結構,再寫成parquet格式的文件。我們來看subtreeIterator函數
/** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */
? private[tree] def subtreeIterator: Iterator[Node] = {
? ? Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
? ? ? rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
? }
其實就是用前序遍歷的方式返回了樹中的所有node的Iterrator。?
我們再來看NodeData,看看每個node保存了什么數據
def apply(treeId: Int, n: Node): NodeData = {
? ? NodeData(treeId, n.id, PredictData(n.predict), n.impurity,
? ? n.isLeaf, n.split.map(SplitData.apply), n.leftNode.map(_.id),?
? ? n.rightNode.map(_.id), n.stats.map(_.gain))
}
保存了node的預測值,impurity,是否是否葉子節點,split,左右孩子節點的id,gain。其中split中包含了特征id,特征閾值,特征類型,離散特征數組(其實就是Split結構)。
8.2.2. 模型加載
metadata的加載就是解析json,主要是樹的重建
val trees = TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc,?
? ? path, metadata.treeAlgo)
new RandomForestModel(Algo.fromString(metadata.algo), trees)
其中調用了loadTrees函數
/**
?* Load trees for an ensemble, and return them in order.
?* @param path path to load the model from
?* @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's
?* ? ? ? ? ? ? ? ? algorithm).
?*/
def loadTrees(
? ? ? ? sc: SparkContext,
? ? ? ? path: String,
? ? ? ? treeAlgo: String): Array[DecisionTreeModel] = {
? ? val datapath = Loader.dataPath(path)
? ? val sqlContext = SQLContext.getOrCreate(sc)
? ? val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply)
? ? val trees = constructTrees(nodes)
? ? trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
}
先是讀取數據文件,讀成NodeData格式,然后調用constructTrees重建樹結構
? ? def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
? ? ? val trees = nodes
? ? ? ? .groupBy(_.treeId)
? ? ? ? .mapValues(_.toArray)
? ? ? ? .collect()
? ? ? ? .map { case (treeId, data) =>
? ? ? ? ? (treeId, constructTree(data))
? ? ? ? }.sortBy(_._1)
? ? ? val numTrees = trees.size
? ? ? val treeIndices = trees.map(_._1).toSeq
? ? ? assert(treeIndices == (0 until numTrees),
? ? ? ? s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
? ? ? trees.map(_._2)
? ? }
主要功能按樹的id分組后,調用constructTree重建樹
? ? /**
? ? ?* Given a list of nodes from a tree, construct the tree.
? ? ?* @param data array of all node data in a tree.
? ? ?*/
? ? def constructTree(data: Array[NodeData]): Node = {
? ? ? val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
? ? ? assert(dataMap.contains(1),
? ? ? ? s"DecisionTree missing root node (id = 1).")
? ? ? constructNode(1, dataMap, mutable.Map.empty)
? ? }
? ? /**
? ? ?* Builds a node from the node data map and adds new nodes to the input nodes map.
? ? ?*/
? ? private def constructNode(
? ? ? id: Int,
? ? ? dataMap: Map[Int, NodeData],
? ? ? nodes: mutable.Map[Int, Node]): Node = {
? ? ? if (nodes.contains(id)) {
? ? ? ? return nodes(id)
? ? ? }
? ? ? val data = dataMap(id)
? ? ? val node =
? ? ? ? if (data.isLeaf) {
? ? ? ? ? Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
? ? ? ? } else {
? ? ? ? ? val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
? ? ? ? ? val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
? ? ? ? ? val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
? ? ? ? ? ? rightNode.impurity, leftNode.predict, rightNode.predict)
? ? ? ? ? new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
? ? ? ? ? ? data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
? ? ? ? }
? ? ? nodes += node.id -> node
? ? ? node
? ? }
其實就是遞歸的從NodeData中獲取數據,重建node
從上面的分析可以看到,spark保存模型使用了parquet格式,對于我們在別的環境中使用是非常不方便的,訓練完模型后,我們可以參照spark的做法,按照前序遍歷的方法以json的格式保存node,在別的環境下復建樹結構就可以了。
9. 坑
特征id,樣本是libsvm格式的,特征id從1開始,但是設置離散特征數categoricalFeaturesInfo需要從0開始,相當于樣本特征id-1
離散特征值,一旦在categoricalFeaturesInfo中指定了特征值的個數k,spark會認為這個特征是從0開始,連續到k-1。如果其中特征不連續,特征數應該設置成最大特征值+1
對于連續特征,spark使用等頻離散化方法,又對樣本進行了抽樣,效果其實很難保證,不知道作者是否比較過這種方法與等間隔離散化效果孰優孰劣
maxBins的設置需要考慮連續特征離散化效果,連續特征離散化值的個數是maxBins-1,同時maxBins必須大于categoricalFeaturesInfo中最大離散特征值的個數
ordered feature,之前的理解是有誤的,這里的order僅僅是說這種特征是可以經過某種方式排列后變成有序,排序標準根據分類/回歸而不同,在上面的文章有具體介紹。在我們的實踐中,有的離散特征,例如薪資,1代表0-1000元,2代表1000-2000元,3代表2000-3000元,特征值的大小本身就表征了實際意義,這種應該直接按連續特征處理(當然也可以對比下效果決定)。
10. 結語
我們基本上是逐行分析了spark隨機森林的實現,展現了其實現過程中使用的技巧,希望對大家在理解隨機森林和其實現方法有所幫助。
---------------------?
原文:https://blog.csdn.net/snaillup/article/details/72820346?
?
總結
以上是生活随笔為你收集整理的spark mllib源码分析之随机森林(Random Forest)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 梯度迭代树回归(GBDT)算法介绍及Sp
- 下一篇: spark mllib源码分析之Deci