Source Code Analysis of Decision Trees in Spark

1. Example

This example uses the decision tree classifier API in Spark MLlib to train a decision tree model. It is written in Python.

    """Decision Tree Classification Example."""
    from __future__ import print_function

    from pyspark import SparkContext
    from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
    from pyspark.mllib.util import MLUtils

    if __name__ == "__main__":
        sc = SparkContext(appName="PythonDecisionTreeClassificationExample")

        # Load and parse the data file into an RDD
        dataPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/sample_libsvm_data.txt"
        print(dataPath)
        data = MLUtils.loadLibSVMFile(sc, dataPath)

        # Split the data set into a training set and a test set
        (trainingData, testData) = data.randomSplit([0.7, 0.3])
        print("train data count: " + str(trainingData.count()))
        print("test data count: " + str(testData.count()))

        # Train the decision tree classifier.
        # An empty categoricalFeaturesInfo means all features are continuous.
        model = DecisionTree.trainClassifier(trainingData, numClasses=2,
                                             categoricalFeaturesInfo={},
                                             impurity='gini', maxDepth=5, maxBins=32)

        # Predict on the test set
        predictions = model.predict(testData.map(lambda x: x.features))
        # Pair each true label with its prediction
        labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
        # Compute the fraction of misclassified test samples.
        # The (label, prediction) pair is indexed explicitly, because tuple
        # parameters such as lambda (v, p) are not valid in Python 3.
        testErr = labelsAndPredictions.filter(
            lambda vp: vp[0] != vp[1]).count() / float(testData.count())
        print('Decision Tree Test Error = %5.3f%%' % (testErr * 100))
        print("Decision Tree Learned classification tree model:")
        print(model.toDebugString())

        # Save and load the trained model
        modelPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/myDecisionTreeClassificationModel"
        model.save(sc, modelPath)
        sameModel = DecisionTreeModel.load(sc, modelPath)

2. Decision tree source code analysis

The decision tree classifier API is DecisionTree.trainClassifier, so the source analysis starts there. The source file is spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala.

    @Since("1.1.0")
    def trainClassifier(
        input: RDD[LabeledPoint],
        numClasses: Int,
        categoricalFeaturesInfo: Map[Int, Int],
        impurity: String,
        maxDepth: Int,
        maxBins: Int): DecisionTreeModel = {
      val impurityType = Impurities.fromString(impurity)
      train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
        categoricalFeaturesInfo)
    }

trainClassifier trains a classifier: it converts the impurity name into an Impurity object and then calls the train method.

    @Since("1.0.0")
    def train(
        input: RDD[LabeledPoint],
        algo: Algo,
        impurity: Impurity,
        maxDepth: Int,
        numClasses: Int,
        maxBins: Int,
        qua