基于Spark实现随机森林代码

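本文实例为大家分享了基于 Spark MLlib 实现随机森林分类的具体代码，供大家参考。运行前需要把 Spark 发行包中的 data/mllib/sample_libsvm_data.txt 复制到测试 classpath 的根目录下（代码通过 getResource("/") 定位该文件）。具体内容如下：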

public class RandomForestClassficationTest extends TestCase implements Serializable

{

/**

*

*/

private static final long serialVersionUID = 7802523720751354318L;

class PredictResult implements Serializable{

/**

*

*/

private static final long serialVersionUID = -168308887976477219L;

double label;

double prediction;

public PredictResult(double label,double prediction){

this.label = label;

this.prediction = prediction;

}

@Override

public String toString(){

return this.label + " : " + this.prediction ;

}

}

public void test_randomForest() throws JAXBException{

SparkConf sparkConf = new SparkConf();

sparkConf.setAppName("RandomForest");

sparkConf.setMaster("local");

SparkContext sc = new SparkContext(sparkConf);

String dataPath = RandomForestClassficationTest.class.getResource("/").getPath() + "/sample_libsvm_data.txt";

RDD dataSet = MLUtils.loadLibSVMFile(sc, dataPath);

RDD[] rddList = dataSet.randomSplit(new double[]{0.7,0.3},1);

RDD trainingData = rddList[0];

RDD testData = rddList[1];

ClassTag labelPointClassTag = trainingData.elementClassTag();

JavaRDD trainingJavaData = new JavaRDD(trainingData,labelPointClassTag);

int numClasses = 2;

Map categoricalFeatureInfos = new HashMap();

int numTrees = 3;

String featureSubsetStrategy = "auto";

String impurity = "gini";

int maxDepth = 4;

int maxBins = 32;

/**

* 1 numClasses分类个数为2

* 2 numTrees 表示的是随机森林中树的个数

* 3 featureSubsetStrategy

* 4

*/

final RandomForestModel model = RandomForest.trainClassifier(trainingJavaData,

numClasses,

categoricalFeatureInfos,

numTrees,

featureSubsetStrategy,

impurity,

maxDepth,

maxBins,

1);

JavaRDD testJavaData = new JavaRDD(testData,testData.elementClassTag());

JavaRDD predictRddResult = testJavaData.map(new Function(){

/**

*

*/

private static final long serialVersionUID = 1L;

public PredictResult call(LabeledPoint point) throws Exception {

// TODO Auto-generated method stub

double pointLabel = point.label();

double prediction = model.predict(point.features());

PredictResult result = new PredictResult(pointLabel,prediction);

return result;

}

});

List predictResultList = predictRddResult.collect();

for(PredictResult result:predictResultList){

System.out.println(result.toString());

}

System.out.println(model.toDebugString());

}

}

程序最后打印的随机森林模型（即 model.toDebugString() 的输出）如下：

TreeEnsembleModel classifier with 3 trees
  Tree 0:
    If (feature 435 <= 0.0)
     If (feature 516 <= 0.0)
      Predict: 0.0
     Else (feature 516 > 0.0)
      Predict: 1.0
    Else (feature 435 > 0.0)
     Predict: 1.0
  Tree 1:
    If (feature 512 <= 0.0)
     Predict: 1.0
    Else (feature 512 > 0.0)
     Predict: 0.0
  Tree 2:
    If (feature 377 <= 1.0)
     Predict: 0.0
    Else (feature 377 > 1.0)
     If (feature 455 <= 0.0)
      Predict: 1.0
     Else (feature 455 > 0.0)
      Predict: 0.0

以上是“基于Spark实现随机森林代码”的全部内容。来源链接：utcz.com/z/339356.html

回到顶部