Decision Tree Implementation in C++ (CART)
For an introduction to decision trees, see: https://blog.csdn.net/fengbingchun/article/details/78880934
For a Python implementation of the CART decision-tree algorithm, see: https://blog.csdn.net/fengbingchun/article/details/78881143
Following the original Python implementation from https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/, this post implements the CART decision-tree algorithm in C++. The test set is the Banknote Dataset; for an introduction to it, see: https://blog.csdn.net/fengbingchun/article/details/78624358
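CART builds a binary tree by greedily choosing, at every node, the attribute and value whose split minimizes the weighted Gini index of the two resulting groups. For reference, the quantity computed by gini_index in the code below is

$$Gini = \sum_{i=1}^{2}\left(1 - \sum_{c} p_{i,c}^{2}\right)\frac{|G_i|}{N}$$

where $G_1$ and $G_2$ are the left/right groups produced by a candidate split, $N$ is the total number of samples at the split point, and $p_{i,c}$ is the proportion of samples in group $G_i$ belonging to class $c$.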
The contents of decision_tree.hpp are as follows:
#ifndef FBC_NN_DECISION_TREE_HPP_
#define FBC_NN_DECISION_TREE_HPP_

#include <vector>
#include <tuple>
#include <fstream>

namespace ANN {

// reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
template<typename T>
class DecisionTree { // CART (Classification and Regression Trees)
public:
    DecisionTree() = default;
    ~DecisionTree() { delete_tree(); }

    int init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes);
    void set_max_depth(int max_depth) { this->max_depth = max_depth; }
    int get_max_depth() const { return max_depth; }
    void set_min_size(int min_size) { this->min_size = min_size; }
    int get_min_size() const { return min_size; }
    void train();
    int save_model(const char* name) const;
    int load_model(const char* name);
    T predict(const std::vector<T>& data) const;

protected:
    // index of attribute, value of attribute, groups of data
    typedef std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> dictionary;
    // flag, index, value, class_value_left, class_value_right
    typedef std::tuple<int, int, T, T, T> row_element;

    typedef struct binary_tree {
        dictionary dict;
        T class_value_left = (T)-1.f;
        T class_value_right = (T)-1.f;
        binary_tree* left = nullptr;
        binary_tree* right = nullptr;
    } binary_tree;

    // Calculate the Gini index for a split dataset
    T gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const;
    // Select the best split point for a dataset
    dictionary get_split(const std::vector<std::vector<T>>& dataset) const;
    // Split a dataset based on an attribute and an attribute value
    std::vector<std::vector<std::vector<T>>> test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const;
    // Create a terminal node value
    T to_terminal(const std::vector<std::vector<T>>& group) const;
    // Create child splits for a node or make terminal
    void split(binary_tree* node, int depth);
    // Build a decision tree
    void build_tree(const std::vector<std::vector<T>>& train);
    // Print a decision tree
    void print_tree(const binary_tree* node, int depth = 0) const;
    // Make a prediction with a decision tree
    T predict(binary_tree* node, const std::vector<T>& data) const;
    // Calculate accuracy percentage on the training data
    double accuracy_metric() const;
    void delete_tree();
    void delete_node(binary_tree* node);
    void write_node(const binary_tree* node, std::ofstream& file) const;
    void node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const;
    int height_of_tree(const binary_tree* node) const;
    void row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos);

private:
    std::vector<std::vector<T>> src_data;
    binary_tree* tree = nullptr;
    int samples_num = 0;
    int feature_length = 0;
    int classes_num = 0;
    int max_depth = 10; // maximum tree depth
    int min_size = 10; // minimum node records
    int max_nodes = -1; // (1 << max_depth) - 1, size of the serialized level-order array
};

} // namespace ANN

#endif // FBC_NN_DECISION_TREE_HPP_
The contents of decision_tree.cpp are as follows:
#include "decision_tree.hpp"
#include <set>
#include <algorithm>
#include <typeinfo>
#include <iterator>
#include "common.hpp"namespace ANN {template<typename T>
int DecisionTree<T>::init(const std::vector<std::vector<T>>& data, const std::vector<T>& classes)
{CHECK(data.size() != 0 && classes.size() != 0 && data[0].size() != 0);this->samples_num = data.size();this->classes_num = classes.size();this->feature_length = data[0].size() -1;for (int i = 0; i < this->samples_num; ++i) {this->src_data.emplace_back(data[i]);}return 0;
}template<typename T>
T DecisionTree<T>::gini_index(const std::vector<std::vector<std::vector<T>>>& groups, const std::vector<T>& classes) const
{// Gini calculation for a group// proportion = count(class_value) / count(rows)// gini_index = (1.0 - sum(proportion * proportion)) * (group_size/total_samples)// count all samples at split pointint instances = 0;int group_num = groups.size();for (int i = 0; i < group_num; ++i) {instances += groups[i].size();}// sum weighted Gini index for each groupT gini = (T)0.;for (int i = 0; i < group_num; ++i) {int size = groups[i].size();// avoid divide by zeroif (size == 0) continue;T score = (T)0.;// score the group based on the score for each classT p = (T)0.;for (int c = 0; c < classes.size(); ++c) {int count = 0;for (int t = 0; t < size; ++t) {if (groups[i][t][this->feature_length] == classes[c]) ++count;}T p = (float)count / size;score += p * p;}// weight the group score by its relative sizegini += (1. - score) * (float)size / instances;}return gini;
}template<typename T>
template<typename T>
std::vector<std::vector<std::vector<T>>> DecisionTree<T>::test_split(int index, T value, const std::vector<std::vector<T>>& dataset) const
{
    std::vector<std::vector<std::vector<T>>> groups(2); // 0: left, 1: right

    for (int row = 0; row < dataset.size(); ++row) {
        if (dataset[row][index] < value) {
            groups[0].emplace_back(dataset[row]);
        } else {
            groups[1].emplace_back(dataset[row]);
        }
    }

    return groups;
}

template<typename T>
std::tuple<int, T, std::vector<std::vector<std::vector<T>>>> DecisionTree<T>::get_split(const std::vector<std::vector<T>>& dataset) const
{
    // collect the distinct class values
    std::vector<T> values;
    for (int i = 0; i < dataset.size(); ++i) {
        values.emplace_back(dataset[i][this->feature_length]);
    }
    std::set<T> vals(values.cbegin(), values.cend());
    std::vector<T> class_values(vals.cbegin(), vals.cend());

    int b_index = 999;
    T b_value = (T)999.;
    T b_score = (T)999.;
    std::vector<std::vector<std::vector<T>>> b_groups(2);

    // exhaustively evaluate every (attribute, value) candidate split
    for (int index = 0; index < this->feature_length; ++index) {
        for (int row = 0; row < dataset.size(); ++row) {
            std::vector<std::vector<std::vector<T>>> groups = test_split(index, dataset[row][index], dataset);
            T gini = gini_index(groups, class_values);
            if (gini < b_score) {
                b_index = index;
                b_value = dataset[row][index];
                b_score = gini;
                b_groups = groups;
            }
        }
    }

    // a new node: the index of the chosen attribute, the value of that attribute by which to split,
    // and the two groups of data split by the chosen split point
    return std::make_tuple(b_index, b_value, b_groups);
}

template<typename T>
T DecisionTree<T>::to_terminal(const std::vector<std::vector<T>>& group) const
{
    std::vector<T> values;
    for (int i = 0; i < group.size(); ++i) {
        values.emplace_back(group[i][this->feature_length]);
    }
    std::set<T> vals(values.cbegin(), values.cend());

    // return the most common class value in the group
    int max_count = -1, index = -1;
    for (int i = 0; i < vals.size(); ++i) {
        int count = std::count(values.cbegin(), values.cend(), *std::next(vals.cbegin(), i));
        if (max_count < count) {
            max_count = count;
            index = i;
        }
    }

    return *std::next(vals.cbegin(), index);
}
template<typename T>
void DecisionTree<T>::split(binary_tree* node, int depth)
{
    std::vector<std::vector<T>> left = std::get<2>(node->dict)[0];
    std::vector<std::vector<T>> right = std::get<2>(node->dict)[1];
    std::get<2>(node->dict).clear();

    // check for a no split
    if (left.size() == 0 || right.size() == 0) {
        for (int i = 0; i < right.size(); ++i) {
            left.emplace_back(right[i]);
        }
        node->class_value_left = node->class_value_right = to_terminal(left);
        return;
    }

    // check for max depth
    if (depth >= max_depth) {
        node->class_value_left = to_terminal(left);
        node->class_value_right = to_terminal(right);
        return;
    }

    // process left child
    if ((int)left.size() <= min_size) {
        node->class_value_left = to_terminal(left);
    } else {
        dictionary dict = get_split(left);
        node->left = new binary_tree;
        node->left->dict = dict;
        split(node->left, depth + 1);
    }

    // process right child
    if ((int)right.size() <= min_size) {
        node->class_value_right = to_terminal(right);
    } else {
        dictionary dict = get_split(right);
        node->right = new binary_tree;
        node->right->dict = dict;
        split(node->right, depth + 1);
    }
}

template<typename T>
void DecisionTree<T>::build_tree(const std::vector<std::vector<T>>& train)
{
    // create the root node, then split it recursively
    dictionary root = get_split(train);
    binary_tree* node = new binary_tree;
    node->dict = root;
    tree = node;
    split(node, 1);
}

template<typename T>
void DecisionTree<T>::train()
{
    this->max_nodes = (1 << max_depth) - 1;
    build_tree(src_data);
    accuracy_metric();
    //print_tree(tree);
}
template<typename T>
T DecisionTree<T>::predict(const std::vector<T>& data) const
{
    if (!tree) {
        fprintf(stderr, "Error: tree is null\n");
        return (T)-1111.f;
    }

    return predict(tree, data);
}

template<typename T>
T DecisionTree<T>::predict(binary_tree* node, const std::vector<T>& data) const
{
    if (data[std::get<0>(node->dict)] < std::get<1>(node->dict)) {
        if (node->left) return predict(node->left, data);
        else return node->class_value_left;
    } else {
        if (node->right) return predict(node->right, data);
        else return node->class_value_right;
    }
}

template<typename T>
int DecisionTree<T>::save_model(const char* name) const
{
    std::ofstream file(name, std::ios::out);
    if (!file.is_open()) {
        fprintf(stderr, "open file fail: %s\n", name);
        return -1;
    }

    file << max_depth << "," << min_size << std::endl;

    binary_tree* tmp = tree;
    int depth = height_of_tree(tmp);
    CHECK(max_depth == depth);

    tmp = tree;
    write_node(tmp, file);

    file.close();
    return 0;
}
template<typename T>
void DecisionTree<T>::write_node(const binary_tree* node, std::ofstream& file) const
{
    // serialize the tree as a complete binary tree in level order: one row_element
    // (flag, index, value, class_value_left, class_value_right) per possible node,
    // with flag -1 marking an absent node
    std::vector<row_element> vec(this->max_nodes, std::make_tuple(-1, -1, (T)-1.f, (T)-1.f, (T)-1.f));
    binary_tree* tmp = const_cast<binary_tree*>(node);
    node_to_row_element(tmp, vec, 0);

    for (const auto& row : vec) {
        file << std::get<0>(row) << "," << std::get<1>(row) << "," << std::get<2>(row) << ","
            << std::get<3>(row) << "," << std::get<4>(row) << std::endl;
    }
}

template<typename T>
void DecisionTree<T>::node_to_row_element(binary_tree* node, std::vector<row_element>& rows, int pos) const
{
    if (!node) return;

    // flag 0: node exists, -1: no node
    rows[pos] = std::make_tuple(0, std::get<0>(node->dict), std::get<1>(node->dict), node->class_value_left, node->class_value_right);

    if (node->left) node_to_row_element(node->left, rows, 2 * pos + 1);
    if (node->right) node_to_row_element(node->right, rows, 2 * pos + 2);
}

template<typename T>
int DecisionTree<T>::height_of_tree(const binary_tree* node) const
{
    if (!node) return 0;
    return std::max(height_of_tree(node->left), height_of_tree(node->right)) + 1;
}
template<typename T>
int DecisionTree<T>::load_model(const char* name)
{
    std::ifstream file(name, std::ios::in);
    if (!file.is_open()) {
        fprintf(stderr, "open file fail: %s\n", name);
        return -1;
    }

    // first line: max_depth, min_size
    std::string line, cell;
    std::getline(file, line);
    std::stringstream line_stream(line);
    std::vector<int> vec;
    while (std::getline(line_stream, cell, ',')) {
        vec.emplace_back(std::stoi(cell));
    }
    CHECK(vec.size() == 2);
    max_depth = vec[0];
    min_size = vec[1];
    max_nodes = (1 << max_depth) - 1;

    // remaining lines: one row_element per slot of the level-order node array
    std::vector<row_element> rows(max_nodes);
    int count = 0;
    if (typeid(T) == typeid(float)) {
        while (std::getline(file, line)) {
            std::stringstream line_stream2(line);
            std::vector<T> vec2;
            while (std::getline(line_stream2, cell, ',')) {
                vec2.emplace_back(std::stof(cell));
            }
            CHECK(vec2.size() == 5);
            rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec2[4]);
            ++count;
        }
    } else { // double
        while (std::getline(file, line)) {
            std::stringstream line_stream2(line);
            std::vector<T> vec2;
            while (std::getline(line_stream2, cell, ',')) {
                vec2.emplace_back(std::stod(cell));
            }
            CHECK(vec2.size() == 5);
            rows[count] = std::make_tuple((int)vec2[0], (int)vec2[1], vec2[2], vec2[3], vec2[4]);
            ++count;
        }
    }
    CHECK(max_nodes == count);
    CHECK(std::get<0>(rows[0]) != -1);

    // rebuild the tree from the level-order array, starting at the root
    binary_tree* tmp = new binary_tree;
    std::vector<std::vector<std::vector<T>>> dump;
    tmp->dict = std::make_tuple(std::get<1>(rows[0]), std::get<2>(rows[0]), dump);
    tmp->class_value_left = std::get<3>(rows[0]);
    tmp->class_value_right = std::get<4>(rows[0]);
    tree = tmp;
    row_element_to_node(tmp, rows, max_nodes, 0);

    file.close();
    return 0;
}

template<typename T>
void DecisionTree<T>::row_element_to_node(binary_tree* node, const std::vector<row_element>& rows, int n, int pos)
{
    if (!node || n == 0) return;

    int new_pos = 2 * pos + 1; // left child in the level-order layout
    if (new_pos < n && std::get<0>(rows[new_pos]) != -1) {
        node->left = new binary_tree;
        std::vector<std::vector<std::vector<T>>> dump;
        node->left->dict = std::make_tuple(std::get<1>(rows[new_pos]), std::get<2>(rows[new_pos]), dump);
        node->left->class_value_left = std::get<3>(rows[new_pos]);
        node->left->class_value_right = std::get<4>(rows[new_pos]);
        row_element_to_node(node->left, rows, n, new_pos);
    }

    new_pos = 2 * pos + 2; // right child in the level-order layout
    if (new_pos < n && std::get<0>(rows[new_pos]) != -1) {
        node->right = new binary_tree;
        std::vector<std::vector<std::vector<T>>> dump;
        node->right->dict = std::make_tuple(std::get<1>(rows[new_pos]), std::get<2>(rows[new_pos]), dump);
        node->right->class_value_left = std::get<3>(rows[new_pos]);
        node->right->class_value_right = std::get<4>(rows[new_pos]);
        row_element_to_node(node->right, rows, n, new_pos);
    }
}
template<typename T>
void DecisionTree<T>::delete_tree()
{
    delete_node(tree);
    tree = nullptr;
}

template<typename T>
void DecisionTree<T>::delete_node(binary_tree* node)
{
    if (!node) return; // guard against an empty tree
    if (node->left) delete_node(node->left);
    if (node->right) delete_node(node->right);
    delete node;
}

template<typename T>
double DecisionTree<T>::accuracy_metric() const
{
    int correct = 0;
    for (int i = 0; i < this->samples_num; ++i) {
        T predicted = predict(tree, src_data[i]);
        if (predicted == src_data[i][this->feature_length]) ++correct;
    }

    double accuracy = correct / (double)samples_num * 100.;
    fprintf(stdout, "train accuracy: %f\n", accuracy);

    return accuracy;
}
template<typename T>
void DecisionTree<T>::print_tree(const binary_tree* node, int depth) const
{
    if (!node) return;

    std::string blank(depth + 1, ' '); // indent one space per level

    fprintf(stdout, "%s[X%d < %.3f]\n", blank.c_str(), std::get<0>(node->dict) + 1, std::get<1>(node->dict));

    if (!node->left || !node->right) blank += " ";
    if (!node->left)
        fprintf(stdout, "%s[%.1f]\n", blank.c_str(), node->class_value_left);
    else
        print_tree(node->left, depth + 1);

    if (!node->right)
        fprintf(stdout, "%s[%.1f]\n", blank.c_str(), node->class_value_right);
    else
        print_tree(node->right, depth + 1);
}

template class DecisionTree<float>;
template class DecisionTree<double>;

} // namespace ANN
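Note how save_model/load_model serialize the tree: rather than writing a pointer structure, write_node flattens it into a complete binary tree stored in level order, where the node at array position pos keeps its children at 2*pos+1 and 2*pos+2, and absent nodes occupy placeholder rows with flag -1. This is also why max_nodes is fixed at (1 << max_depth) - 1: the array reserves a slot for every possible node up to max_depth. A minimal standalone sketch of this heap-style indexing (toy integer values, not the library's types):

#include <cstdio>
#include <vector>

int main()
{
    // A depth-3 complete layout (2^3 - 1 = 7 slots); -1 marks an absent node,
    // mirroring the flag column of the model file. The stored tree is:
    //   10 -> children 20, 30; 20 -> right child 40 only.
    std::vector<int> rows{ 10, 20, 30, -1, 40, -1, -1 };

    for (int pos = 0; pos < (int)rows.size(); ++pos) {
        if (rows[pos] == -1) continue;
        int l = 2 * pos + 1, r = 2 * pos + 2; // heap-style child indices
        std::printf("node %d: left = %d, right = %d\n", rows[pos],
            l < (int)rows.size() ? rows[l] : -1,
            r < (int)rows.size() ? rows[r] : -1);
    }
    return 0;
}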
Two interfaces are exposed: test_decision_tree_train for training and test_decision_tree_predict for prediction. Their code is as follows:
// =============================== decision tree ==============================
int test_decision_tree_train()
{
    // small dataset test
    /*
    const std::vector<std::vector<float>> data{ { 2.771244718f, 1.784783929f, 0.f },
        { 1.728571309f, 1.169761413f, 0.f },
        { 3.678319846f, 2.81281357f, 0.f },
        { 3.961043357f, 2.61995032f, 0.f },
        { 2.999208922f, 2.209014212f, 0.f },
        { 7.497545867f, 3.162953546f, 1.f },
        { 9.00220326f, 3.339047188f, 1.f },
        { 7.444542326f, 0.476683375f, 1.f },
        { 10.12493903f, 3.234550982f, 1.f },
        { 6.642287351f, 3.319983761f, 1.f } };
    const std::vector<float> classes{ 0.f, 1.f };

    ANN::DecisionTree<float> dt;
    dt.init(data, classes);
    dt.set_max_depth(3);
    dt.set_min_size(1);
    dt.train();

#ifdef _MSC_VER
    const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
    const char* model_name = "data/decision_tree.model";
#endif
    dt.save_model(model_name);

    ANN::DecisionTree<float> dt2;
    dt2.load_model(model_name);
    const std::vector<std::vector<float>> test{ { 0.6f, 1.9f, 0.f }, { 9.7f, 4.3f, 1.f } };
    for (const auto& row : test) {
        float ret = dt2.predict(row);
        fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[2]);
    }
    */

    // banknote authentication dataset
#ifdef _MSC_VER
    const char* file_name = "E:/GitCode/NN_Test/data/database/BacknoteDataset/data_banknote_authentication.txt";
#else
    const char* file_name = "data/database/BacknoteDataset/data_banknote_authentication.txt";
#endif
    std::vector<std::vector<float>> data;
    int ret = read_txt_file<float>(file_name, data, ',', 1372, 5);
    if (ret != 0) {
        fprintf(stderr, "parse txt file fail: %s\n", file_name);
        return -1;
    }

    const std::vector<float> classes{ 0.f, 1.f };

    ANN::DecisionTree<float> dt;
    dt.init(data, classes);
    dt.set_max_depth(6);
    dt.set_min_size(10);
    dt.train();

#ifdef _MSC_VER
    const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
    const char* model_name = "data/decision_tree.model";
#endif
    dt.save_model(model_name);

    return 0;
}

int test_decision_tree_predict()
{
#ifdef _MSC_VER
    const char* model_name = "E:/GitCode/NN_Test/data/decision_tree.model";
#else
    const char* model_name = "data/decision_tree.model";
#endif
    ANN::DecisionTree<float> dt;
    dt.load_model(model_name);

    int max_depth = dt.get_max_depth();
    int min_size = dt.get_min_size();
    fprintf(stdout, "max_depth: %d, min_size: %d\n", max_depth, min_size);

    // each row: four features followed by the actual class label
    std::vector<std::vector<float>> test{
        { -2.5526f, -7.3625f, 6.9255f, -0.66811f, 1.f },
        { -4.5531f, -12.5854f, 15.4417f, -1.4983f, 1.f },
        { 4.0948f, -2.9674f, 2.3689f, 0.75429f, 0.f },
        { -1.0401f, 9.3987f, 0.85998f, -5.3336f, 0.f },
        { 1.0637f, 3.6957f, -4.1594f, -1.9379f, 1.f } };

    for (const auto& row : test) {
        float ret = dt.predict(row);
        fprintf(stdout, "predict result: %.1f, actual value: %.1f\n", ret, row[4]);
    }

    return 0;
}
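A minimal driver, assuming the two test functions are declared in a shared test header (the header name here is only an assumption), simply runs training before prediction so the model file exists when it is loaded:

// Hypothetical driver: train first (writes decision_tree.model),
// then predict (reads the model back). Assumes both test functions
// are declared in a header such as "funset.hpp" (name is an assumption).
#include "funset.hpp"

int main()
{
    if (test_decision_tree_train() != 0) return -1;
    if (test_decision_tree_predict() != 0) return -1;
    return 0;
}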
Running the training interface prints the accuracy on the training set; running the test interface prints the loaded max_depth and min_size followed by the prediction for each test sample.
The decision_tree.model file produced by training is shown below. The first line stores max_depth and min_size; each following line is one slot of the level-order node array in the form flag,index,value,class_value_left,class_value_right, where flag -1 marks an empty slot:
6,10
0,0,0.3223,-1,-1
0,1,7.6274,-1,-1
0,2,-4.3839,-1,-1
0,0,-0.39816,-1,-1
0,0,-4.2859,-1,-1
0,0,4.2164,-1,0
0,0,1.594,-1,-1
0,2,6.2204,-1,-1
0,1,5.8974,-1,-1
0,0,-5.4901,-1,1
0,0,-1.5768,-1,-1
0,0,0.47368,1,-1
-1,-1,-1,-1,-1
0,2,-2.2718,-1,-1
0,0,2.0421,-1,-1
0,1,7.3273,-1,1
0,1,-4.6062,-1,-1
0,2,3.1143,-1,-1
0,0,0.049175,0,0
0,0,-6.2003,1,1
-1,-1,-1,-1,-1
0,0,-2.7419,0,-1
0,0,-1.5768,0,0
-1,-1,-1,-1,-1
0,0,0.47368,1,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,1,7.6377,-1,0
0,3,0.097399,-1,-1
0,2,-2.3386,1,-1
0,0,3.6216,-1,-1
0,0,-1.3971,1,1
-1,-1,-1,-1,-1
0,0,-1.6677,1,1
0,0,-1.7781,0,0
0,0,-0.36506,1,1
0,3,1.547,0,1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,-2.7419,0,0
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
-1,-1,-1,-1,-1
0,0,1.0552,1,1
-1,-1,-1,-1,-1
0,0,0.4339,0,0
0,2,2.0013,1,0
-1,-1,-1,-1,-1
0,0,1.8993,0,0
0,0,3.4566,0,0
0,0,3.6216,0,0
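For example, the first node row, 0,0,0.3223,-1,-1, is the root: flag 0 means the node exists, and it splits on attribute X0 at value 0.3223. Both of its terminal class values are still the default -1 because both children are present (the rows at positions 1 and 2 also carry flag 0).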
GitHub: https://github.com/fengbingchun/NN_Test