libsvm java版本使用心得(转)

java

http://blog.csdn.net/u010340854/article/details/19159883

https://github.com/cjlin1/libsvm

项目中要用到svm分类器,自己实现的话太费时间,于是寻找开源实现,找到了libsvm。

Java版本是一个jar包,引入到工程中即可使用。

需要注意的是,java版本充满了c++风格(类名小写,命名使用下划线_分隔等等),使用者需要稍微适应一下。

核心类是svm类,最常用的几个方法如下(都是static方法):

svm.svm_load_model(String),望文生义即可知是加载已训练好的svm模型,参数是模型文件名。

svm.svm_save_model(String,svm_model),按指定的名称保存模型。

svm.svm_train(svm_problem,svm_parameter),训练模型,该方法有两个参数svm_problem,保存了训练数据,包括数据数,特征数组,类别数组。参数svm_parameter用户设置svm的一些参数,例如svm_type设置svm类型,kernel_type设置核函数类型等。训练时需要注意的是,如果你的训练数据比较多,训练时间可能很长。

svm.svm_predict(svm_model,svm_node[])和svm.svm_p

redict_probability(svm_model,svm_node[],double[]),都用于预测类别,不同的是后一个方法同时包含了预测类别的概率。

下面给出完整的demo:

[java] view plain copy

 

 

  1. public class Test_svm_predict {  
  2.   
  3.     public static void main(String[] args) {  

  4.         svm_problem sp = new svm_problem();  

  5.         svm_node[][] x = new svm_node[4][2];  

  6.         for (int i = 0; i < 4; i++) {  

  7.             for (int j = 0; j < 2; j++) {  

  8.                 x[i][j] = new svm_node();  

  9.             }  
  10.         }  
  11.         x[0][0].index = 1;  

  12.         x[0][0].value = 0;  

  13.         x[0][1].index = 2;  

  14.         x[0][1].value = 0;  

  15.   
  16.         x[1][0].index = 1;  

  17.         x[1][0].value = 1;  

  18.         x[1][1].index = 2;  

  19.         x[1][1].value = 1;  

  20.   
  21.         x[2][0].index = 1;  

  22.         x[2][0].value = 0;  

  23.         x[2][1].index = 2;  

  24.         x[2][1].value = 1;  

  25.   
  26.         x[3][0].index = 1;  

  27.         x[3][0].value = 1;  

  28.         x[3][1].value = 0;  

  29.         x[3][1].index = 2;  

  30.   
  31.   
  32.         double[] labels = new double[]{-1,-1,1,1};  

  33.         sp.x = x;  
  34.         sp.y = labels;  
  35.         sp.l = 4;  

  36.         svm_parameter prm = new svm_parameter();  

  37.         prm.svm_type = svm_parameter.C_SVC;  
  38.         prm.kernel_type = svm_parameter.RBF;  
  39.         prm.C = 1000;  

  40.         prm.eps = 0.0000001;  

  41.         prm.gamma = 10;  

  42.         prm.probability = 1;  

  43.         prm.cache_size=1024;  

  44.         /* 

  45.          * svm_check_parameter 
  46.          * 参数可行返回null,否则返回错误信息 
  47.          */  
  48.         System.out.println("Param Check " + (svm.svm_check_parameter(sp, prm)==null));  

  49.         svm_model model = svm.svm_train(sp, prm);           //训练分类  

  50.         try {  

  51.             svm.svm_save_model("svm_model_file", model);  

  52.         } catch (IOException e) {  

  53.             e.printStackTrace();  
  54.         }  
  55.           
  56.         try {  

  57.             svm.svm_load_model("svm_model_file");  

  58.         } catch (IOException e) {  

  59.             e.printStackTrace();  
  60.         }  
  61.         svm_node[] test = new svm_node[]{new svm_node(), new svm_node()};  

  62.         test[0].index = 1;  

  63.         test[0].value = 0;  

  64.         test[1].index = 2;  

  65.         test[1].value = 0;  

  66.         double[] l = new double[2];   

  67.         double result_prob = svm.svm_predict_probability(model, test,l);        //测试1,带预测概率的分类测试  

  68.         double result_normal = svm.svm_predict(model, test);    //测试2 不带概率的分类测试  

  69.         System.out.println("Result with prob " + result_prob);  

  70.         System.out.println("Result normal " + result_normal);  

  71.         System.out.println("Probability " + l[0] + "\t" + l[1]);  

  72.     }  
  73. }  

http://www.oschina.net/code/snippet_1246663_35454

1. [代码][Java]代码     

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

importjava.io.BufferedReader;

importjava.io.File;

importjava.io.FileReader;

importjava.util.ArrayList;

importjava.util.List;

 

importlibsvm.svm;

importlibsvm.svm_model;

importlibsvm.svm_node;

importlibsvm.svm_parameter;

importlibsvm.svm_problem;

 

publicclassSVM {

    publicstaticvoidmain(String[] args) {

        // 定义训练集点a{10.0, 10.0} 和 点b{-10.0, -10.0},对应lable为{1.0, -1.0}

        List<Double> label = newArrayList<Double>();

        List<svm_node[]> nodeSet = newArrayList<svm_node[]>();

        getData(nodeSet, label, "file/train.txt");

         

        intdataRange=nodeSet.get(0).length;

        svm_node[][] datas = newsvm_node[nodeSet.size()][dataRange]; // 训练集的向量表

        for(inti = 0; i < datas.length; i++) {

            for(intj = 0; j < dataRange; j++) {

                datas[i][j] = nodeSet.get(i)[j];

            }

        }

        double[] lables = newdouble[label.size()]; // a,b 对应的lable

        for(inti = 0; i < lables.length; i++) {

            lables[i] = label.get(i);

        }

 

        // 定义svm_problem对象

        svm_problem problem = newsvm_problem();

        problem.l = nodeSet.size(); // 向量个数

        problem.x = datas; // 训练集向量表

        problem.y = lables; // 对应的lable数组

 

        // 定义svm_parameter对象

        svm_parameter param = newsvm_parameter();

        param.svm_type = svm_parameter.EPSILON_SVR;

        param.kernel_type = svm_parameter.LINEAR;

        param.cache_size = 100;

        param.eps = 0.00001;

        param.C = 1.9;

        // 训练SVM分类模型

        System.out.println(svm.svm_check_parameter(problem, param));

        // 如果参数没有问题,则svm.svm_check_parameter()函数返回null,否则返回error描述。

        svm_model model = svm.svm_train(problem, param);

        // svm.svm_train()训练出SVM分类模型

 

        // 获取测试数据

        List<Double> testlabel = newArrayList<Double>();

        List<svm_node[]> testnodeSet = newArrayList<svm_node[]>();

        getData(testnodeSet, testlabel, "file/test.txt");

 

        svm_node[][] testdatas = newsvm_node[testnodeSet.size()][dataRange]; // 训练集的向量表

        for(inti = 0; i < testdatas.length; i++) {

            for(intj = 0; j < dataRange; j++) {

                testdatas[i][j] = testnodeSet.get(i)[j];

            }

        }

        double[] testlables = newdouble[testlabel.size()]; // a,b 对应的lable

        for(inti = 0; i < testlables.length; i++) {

            testlables[i] = testlabel.get(i);

        }

 

        // 预测测试数据的lable

        doubleerr = 0.0;

        for(inti = 0; i < testdatas.length; i++) {

            doubletruevalue = testlables[i];

            System.out.print(truevalue + " ");

            doublepredictValue = svm.svm_predict(model, testdatas[i]);

            System.out.println(predictValue);

            err += Math.abs(predictValue - truevalue);

        }

        System.out.println("err="+ err / datas.length);

    }

 

    publicstaticvoidgetData(List<svm_node[]> nodeSet, List<Double> label,

            String filename) {

        try{

 

            FileReader fr = newFileReader(newFile(filename));

            BufferedReader br = newBufferedReader(fr);

            String line = null;

            while((line = br.readLine()) != null) {

                String[] datas = line.split(",");

                svm_node[] vector = newsvm_node[datas.length - 1];

                for(inti = 0; i < datas.length - 1; i++) {

                    svm_node node = newsvm_node();

                    node.index = i + 1;

                    node.value = Double.parseDouble(datas[i]);

                    vector[i] = node;

                }

                nodeSet.add(vector);

                doublelablevalue = Double.parseDouble(datas[datas.length - 1]);

                label.add(lablevalue);

            }

        } catch(Exception e) {

            e.printStackTrace();

        }

 

    }

}

2. [代码]训练数据,最后一列为目标值     

?

1

2

3

4

5

6

7

8

9

10

11

12

17.6,17.7,17.7,17.7,17.8

17.7,17.7,17.7,17.8,17.8

17.7,17.7,17.8,17.8,17.9

17.7,17.8,17.8,17.9,18

17.8,17.8,17.9,18,18.1

17.8,17.9,18,18.1,18.2

17.9,18,18.1,18.2,18.4

18,18.1,18.2,18.4,18.6

18.1,18.2,18.4,18.6,18.7

18.2,18.4,18.6,18.7,18.9

18.4,18.6,18.7,18.9,19.1

18.6,18.7,18.9,19.1,19.3

3. [代码]测试数据     

?

1

2

3

4

5

6

7

18.7,18.9,19.1,19.3,19.6

18.9,19.1,19.3,19.6,19.9

19.1,19.3,19.6,19.9,20.2

19.3,19.6,19.9,20.2,20.6

19.6,19.9,20.2,20.6,21

19.9,20.2,20.6,21,21.5

20.2,20.6,21,21.5,22

4. [图片] QQ截图20140503213839.png    

以上是 libsvm java版本使用心得(转) 的全部内容, 来源链接: utcz.com/z/393769.html

回到顶部