机器学习知识点(七)决策树学习算法Java实现
生活随笔
收集整理的這篇文章主要介紹了
机器学习知识点(七)决策树学习算法Java实现
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
為理解機器學習第四章節的決策樹學習算法,這裡參考了網上找到的一份現成代碼。該代碼主要實現了最優劃分屬性選擇和決策樹構造:最優劃分屬性選擇采用信息增益準則,決策樹構造采用遞歸實現。代碼如下:
package sk.ml;import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set;public class DicisionTree {public static void main(String[] args) throws Exception {String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT","CREDIT_RATING" }; // 讀取樣本集Map<Object, List<Sample>> samples = readSamples(attrNames);// 生成決策樹Object decisionTree = generateDecisionTree(samples, attrNames);// 輸出決策樹outputDecisionTree(decisionTree, 0, null);}/*** 讀取已分類的樣本集,返回Map:分類 -> 屬于該分類的樣本的列表*/static Map<Object, List<Sample>> readSamples(String[] attrNames) {// 樣本屬性及其所屬分類(數組中的最后一個元素為樣本所屬分類)Object[][] rawData = new Object[][] {{ "<30 ", "High ", "No ", "Fair ", "0" },{ "<30 ", "High ", "No ", "Excellent", "0" },{ "30-40", "High ", "No ", "Fair ", "1" },{ ">40 ", "Medium", "No ", "Fair ", "1" },{ ">40 ", "Low ", "Yes", "Fair ", "1" },{ ">40 ", "Low ", "Yes", "Excellent", "0" },{ "30-40", "Low ", "Yes", "Excellent", "1" },{ "<30 ", "Medium", "No ", "Fair ", "0" },{ "<30 ", "Low ", "Yes", "Fair ", "1" },{ ">40 ", "Medium", "Yes", "Fair ", "1" },{ "<30 ", "Medium", "Yes", "Excellent", "1" },{ "30-40", "Medium", "No ", "Excellent", "1" },{ "30-40", "High ", "Yes", "Fair ", "1" },{ ">40 ", "Medium", "No ", "Excellent", "0" } };// 讀取樣本屬性及其所屬分類,構造表示樣本的Sample對象,并按分類劃分樣本集Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();for (Object[] row : rawData) {Sample sample = new Sample();int i = 0;for (int n = row.length - 1; i < n; i++)sample.setAttribute(attrNames[i], row[i]);sample.setCategory(row[i]);List<Sample> samples = ret.get(row[i]);if (samples == null) {samples = new LinkedList<Sample>();ret.put(row[i], samples);}samples.add(sample);}return ret;}/*** 構造決策樹*/static Object generateDecisionTree(Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {// 如果只有一個樣本,將該樣本所屬分類作為新樣本的分類if (categoryToSamples.size() == 1)return categoryToSamples.keySet().iterator().next();// 
如果沒有供決策的屬性,則將樣本集中具有最多樣本的分類作為新樣本的分類,即投票選舉出分類if (attrNames.length == 0) {int max = 0;Object maxCategory = null;for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {int cur = entry.getValue().size();if (cur > max) {max = cur;maxCategory = entry.getKey();}}return maxCategory;}// 選取測試屬性Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);// 決策樹根結點,分支屬性為選取的測試屬性Tree tree = new Tree(attrNames[(Integer) rst[0]]);// 已用過的測試屬性不應再次被選為測試屬性String[] subA = new String[attrNames.length - 1];for (int i = 0, j = 0; i < attrNames.length; i++)if (i != (Integer) rst[0])subA[j++] = attrNames[i];// 根據分支屬性生成分支@SuppressWarnings("unchecked")Map<Object, Map<Object, List<Sample>>> splits =/* NEW LINE */(Map<Object, Map<Object, List<Sample>>>) rst[2];for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {Object attrValue = entry.getKey();Map<Object, List<Sample>> split = entry.getValue();Object child = generateDecisionTree(split, subA);tree.setChild(attrValue, child);}return tree;}/*** 選取最優測試屬性。最優是指如果根據選取的測試屬性分支,則從各分支確定新樣本* 的分類需要的信息量之和最小,這等價于確定新樣本的測試屬性獲得的信息增益最大* 返回數組:選取的屬性下標、信息量之和、Map(屬性值->(分類->樣本列表))*/static Object[] chooseBestTestAttribute(Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {int minIndex = -1; // 最優屬性下標double minValue = Double.MAX_VALUE; // 最小信息量Map<Object, Map<Object, List<Sample>>> minSplits = null; // 最優分支方案// 對每一個屬性,計算將其作為測試屬性的情況下在各分支確定新樣本的分類需要的信息量之和,選取最小為最優for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {int allCount = 0; // 統計樣本總數的計數器// 按當前屬性構建Map:屬性值->(分類->樣本列表)Map<Object, Map<Object, List<Sample>>> curSplits =/* NEW LINE */new HashMap<Object, Map<Object, List<Sample>>>();for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {Object category = entry.getKey();List<Sample> samples = entry.getValue();for (Sample sample : samples) {Object attrValue = sample.getAttribute(attrNames[attrIndex]);Map<Object, List<Sample>> split = curSplits.get(attrValue);if (split == null) 
{split = new HashMap<Object, List<Sample>>();curSplits.put(attrValue, split);}List<Sample> splitSamples = split.get(category);if (splitSamples == null) {splitSamples = new LinkedList<Sample>();split.put(category, splitSamples);}splitSamples.add(sample);}allCount += samples.size();}// 計算將當前屬性作為測試屬性的情況下在各分支確定新樣本的分類需要的信息量之和double curValue = 0.0; // 計數器:累加各分支for (Map<Object, List<Sample>> splits : curSplits.values()) {double perSplitCount = 0;for (List<Sample> list : splits.values())perSplitCount += list.size(); // 累計當前分支樣本數double perSplitValue = 0.0; // 計數器:當前分支for (List<Sample> list : splits.values()) {double p = list.size() / perSplitCount;perSplitValue -= p * (Math.log(p) / Math.log(2));}curValue += (perSplitCount / allCount) * perSplitValue;}// 選取最小為最優if (minValue > curValue) {minIndex = attrIndex;minValue = curValue;minSplits = curSplits;}}return new Object[] { minIndex, minValue, minSplits };}/*** 將決策樹輸出到標準輸出*/static void outputDecisionTree(Object obj, int level, Object from) {for (int i = 0; i < level; i++)System.out.print("|-----");if (from != null)System.out.printf("(%s):", from);if (obj instanceof Tree) {Tree tree = (Tree) obj;String attrName = tree.getAttribute();System.out.printf("[%s = ?]\n", attrName);for (Object attrValue : tree.getAttributeValues()) {Object child = tree.getChild(attrValue);outputDecisionTree(child, level + 1, attrName + " = "+ attrValue);}} else {System.out.printf("[CATEGORY = %s]\n", obj);}}/*** 樣本,包含多個屬性和一個指明樣本所屬分類的分類值*/static class Sample {private Map<String, Object> attributes = new HashMap<String, Object>();private Object category;public Object getAttribute(String name) {return attributes.get(name);}public void setAttribute(String name, Object value) {attributes.put(name, value);}public Object getCategory() {return category;}public void setCategory(Object category) {this.category = category;}public String toString() {return attributes.toString();}}/*** 決策樹(非葉結點),決策樹中的每個非葉結點都引導了一棵決策樹* 
每個非葉結點包含一個分支屬性和多個分支,分支屬性的每個值對應一個分支,該分支引導了一棵子決策樹*/static class Tree {private String attribute;private Map<Object, Object> children = new HashMap<Object, Object>();public Tree(String attribute) {this.attribute = attribute;}public String getAttribute() {return attribute;}public Object getChild(Object attrValue) {return children.get(attrValue);}public void setChild(Object attrValue, Object child) {children.put(attrValue, child);}public Set<Object> getAttributeValues() {return children.keySet();}} } 執行結果如下: [AGE = ?] |-----(AGE = >40 ):[CREDIT_RATING = ?] |-----|-----(CREDIT_RATING = Excellent):[CATEGORY = 0] |-----|-----(CREDIT_RATING = Fair ):[CATEGORY = 1] |-----(AGE = <30 ):[STUDENT = ?] |-----|-----(STUDENT = Yes):[CATEGORY = 1] |-----|-----(STUDENT = No ):[CATEGORY = 0] |-----(AGE = 30-40):[CATEGORY = 1]總結
以上是生活随笔為你收集整理的机器学习知识点(七)决策树学习算法Java实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 机器学习笔记(四)决策树
- 下一篇: 机器学习知识点(八)感知机模型Java实现