资源描述:
《Opencv2.4.9源码分析——Gradient Boosted Trees》由会员上传分享,免费在线阅读,更多相关内容在教育资源-天天文库。
1、Opencv2.4.9源码分析——GradientBoostedTrees一、原理梯度提升树(GBT,GradientBoostedTrees,或称为梯度提升决策树)算法是由Friedman于1999年首次完整的提出,该算法可以实现回归、分类和排序。GBT的优点是特征属性无需进行归一化处理,预测速度快,可以应用不同的损失函数等。从它的名字就可以看出,GBT包括三个机器学习的优化算法:决策树方法、提升方法和梯度下降法。前两种算法在我以前的文章中都有详细的介绍,在这里我只做简单描述。决策树是一个由根节点、中间节点、叶节点和分支构成的树状模型,分支代表着数据的
2、走向,中间节点包含着训练时产生的分叉决策准则,叶节点代表着最终的数据分类结果或回归值,在预测的过程中,数据从根节点出发,沿着分支在到达中间节点时,根据该节点的决策准则实现分叉,最终到达叶节点,完成分类或回归。提升算法是由一系列“弱学习器”构成,这些弱学习器通过某种线性组合实现一个强学习器,虽然这些弱学习器的分类或回归效果可能仅仅比随机分类或回归要好一点,但最终的强学习器却可以得到一个很好的预测结果。二、源码分析下面介绍OpenCV的GBT源码。首先给出GBT算法所需参数的结构体CvGBTreesParams:[cpp]viewplaincopy在CODE
3、上查看代码片派生到我的代码片CvGBTreesParams::CvGBTreesParams(int_loss_function_type,int_weak_count,float_shrinkage,float_subsample_portion,int_max_depth,bool_use_surrogates):CvDTreeParams(3,10,0,false,10,0,false,false,0){loss_function_type=_loss_function_type;weak_count=_weak_count;shrinkage=_
4、shrinkage;subsample_portion=_subsample_portion;max_depth=_max_depth;use_surrogates=_use_surrogates;}loss_function_type表示损失函数的类型,CvGBTrees::SQUARED_LOSS为平方损失函数,CvGBTrees::ABSOLUTE_LOSS为绝对值损失函数,CvGBTrees::HUBER_LOSS为Huber损失函数,CvGBTrees::DEVIANCE_LOSS为偏差损失函数,前三种用于回归问题,后一种用于分类问题weak_
5、count表示GBT的优化迭代次数,对于回归问题来说,weak_count也就是决策树的数量,对于分类问题来说,weak_count×K为决策树的数量,K表示类别数量shrinkage表示收缩因子vsubsample_portion表示训练样本占全部样本的比例,为不大于1的正数max_depth表示决策树的最大深度use_surrogates表示是否使用替代分叉节点,为true,表示使用替代分叉节点CvDTreeParams结构详见我的关于决策树的文章CvGBTrees类的一个构造函数:[cpp]viewplaincopy在CODE上查看代码片派生到我的
6、代码片CvGBTrees::CvGBTrees(constcv::Mat&trainData,inttflag,constcv::Mat&responses,constcv::Mat&varIdx,constcv::Mat&sampleIdx,constcv::Mat&varType,constcv::Mat&missingDataMask,CvGBTreesParams_params){data=0;//表示样本数据集合weak=0;//表示一个弱学习器default_model_name="my_boost_tree";//orig_response
7、表示样本的响应值,sum_response表示拟合函数Fm(x),sum_response_tmp表示Fm+1(x)orig_response=sum_response=sum_response_tmp=0;//subsample_train和subsample_test分别表示训练样本集和测试样本集subsample_train=subsample_test=0;//missing表示缺失的特征属性,sample_idx表示真正用到的样本的索引missing=sample_idx=0;class_labels=0;//表示类别标签class_count
8、=1;//表示类别的数量delta=0.0f;//表示Huber损失函数中的参数