2012-05-29 8 views
13

Ich versuche, die Java-Bindings für LIBSVM zu verwenden:LIBSVM Java-Implementierung

http://www.csie.ntu.edu.tw/~cjlin/libsvm/ 

ich eine 'trivial' Beispiel implementiert haben, die leicht linear trennbar in y ist. Die Daten sind definiert als:

double[][] train = new double[1000][]; 
double[][] test = new double[10][]; 

for (int i = 0; i < train.length; i++){ 
    if (i+1 > (train.length/2)){  // 50% positive 
     double[] vals = {1,0,i+i}; 
     train[i] = vals; 
    } else { 
     double[] vals = {0,0,i-i-i-2}; // 50% negative 
     train[i] = vals; 
    }   
} 

Wo das erste 'Feature' ist die Klasse und das Training-Set ist ähnlich definiert.

Um das Modell zu trainieren:

private svm_model svmTrain() { 
    svm_problem prob = new svm_problem(); 
    int dataCount = train.length; 
    prob.y = new double[dataCount]; 
    prob.l = dataCount; 
    prob.x = new svm_node[dataCount][];  

    for (int i = 0; i < dataCount; i++){    
     double[] features = train[i]; 
     prob.x[i] = new svm_node[features.length-1]; 
     for (int j = 1; j < features.length; j++){ 
      svm_node node = new svm_node(); 
      node.index = j; 
      node.value = features[j]; 
      prob.x[i][j-1] = node; 
     }   
     prob.y[i] = features[0]; 
    }    

    svm_parameter param = new svm_parameter(); 
    param.probability = 1; 
    param.gamma = 0.5; 
    param.nu = 0.5; 
    param.C = 1; 
    param.svm_type = svm_parameter.C_SVC; 
    param.kernel_type = svm_parameter.LINEAR;  
    param.cache_size = 20000; 
    param.eps = 0.001;  

    svm_model model = svm.svm_train(prob, param); 

    return model; 
} 

dann das Modell, das ich verwende, um zu bewerten:

public int evaluate(double[] features) { 
    svm_node node = new svm_node(); 
    for (int i = 1; i < features.length; i++){ 
     node.index = i; 
     node.value = features[i]; 
    } 
    svm_node[] nodes = new svm_node[1]; 
    nodes[0] = node; 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(_model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(_model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return (int)v; 
} 

Wenn das übergebene Array ein Punkt aus dem Test-Set ist.

Die Ergebnisse sind immer wiederkehrende Klasse 0. Mit dem genauen Wesen Ergebnisse:

(0:0.9882998314585194)(1:0.011700168541480586)(Actual:0.0 Prediction:0.0) 
(0:0.9883952943701599)(1:0.011604705629839989)(Actual:0.0 Prediction:0.0) 
(0:0.9884899803606306)(1:0.011510019639369528)(Actual:0.0 Prediction:0.0) 
(0:0.9885838957058696)(1:0.011416104294130458)(Actual:0.0 Prediction:0.0) 
(0:0.9886770466322342)(1:0.011322953367765776)(Actual:0.0 Prediction:0.0) 
(0:0.9870913229268679)(1:0.012908677073132284)(Actual:1.0 Prediction:0.0) 
(0:0.9868781382588805)(1:0.013121861741119505)(Actual:1.0 Prediction:0.0) 
(0:0.986661444476744)(1:0.013338555523255982)(Actual:1.0 Prediction:0.0) 
(0:0.9864411843906802)(1:0.013558815609319848)(Actual:1.0 Prediction:0.0) 
(0:0.9862172999068877)(1:0.013782700093112332)(Actual:1.0 Prediction:0.0) 

Kann jemand erklären, warum dieser Klassifikator nicht funktioniert? Gibt es einen Schritt, den ich vermasselt habe, oder einen Schritt, den ich vermisse?

Dank

Antwort

13

Es scheint mir, dass Ihre Evaluierungsmethode falsch ist. Sollte so etwas wie diese:

public double evaluate(double[] features, svm_model model) 
{ 
    svm_node[] nodes = new svm_node[features.length-1]; 
    for (int i = 1; i < features.length; i++) 
    { 
     svm_node node = new svm_node(); 
     node.index = i; 
     node.value = features[i]; 

     nodes[i-1] = node; 
    } 

    int totalClasses = 2;  
    int[] labels = new int[totalClasses]; 
    svm.svm_get_labels(model,labels); 

    double[] prob_estimates = new double[totalClasses]; 
    double v = svm.svm_predict_probability(model, nodes, prob_estimates); 

    for (int i = 0; i < totalClasses; i++){ 
     System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); 
    } 
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");    

    return v; 
} 
+4

Können Sie erklären, was der Fehler im Fragencode ist? Ich habe Probleme beim Aufspüren des Fehlers! :( – Daniel

1

habe ich eine etwas Refactoring-Version von Java-Implementierung des LibSVM, die Sie einfacher zu bedienen sein können: https://github.com/syeedibnfaiz/libsvm-java-kernel. Werfen Sie einen Blick auf die Demo.java-Klasse, um zu sehen, wie Sie sie verwenden können.

2

Hier ist eine Überarbeitung des obigen Beispiels, die ich getestet hat Daten aus dem folgenden R Code verwendet: http://cbio.ensmp.fr/~jvert/svn/tutorials/practical/svmbasic/svmbasic_notes.pdf

import libsvm.*; 

public class libsvmTest { 

    public static void main(String [] args) { 

     double[][] xtrain = ... 
     double[][] xtest = ... 
     double[][] ytrain = ... 
     double[][] ytest = ... 

     svm_model m = svmTrain(xtrain,ytrain); 

     double[] ypred = svmPredict(xtest, m); 

     for (int i = 0; i < xtest.length; i++){ 
      System.out.println("(Actual:" + ytest[i][0] + " Prediction:" + ypred[i] + ")"); 
     } 

    } 

    static svm_model svmTrain(double[][] xtrain, double[][] ytrain) { 
     svm_problem prob = new svm_problem(); 
     int recordCount = xtrain.length; 
     int featureCount = xtrain[0].length; 
     prob.y = new double[recordCount]; 
     prob.l = recordCount; 
     prob.x = new svm_node[recordCount][featureCount];  

     for (int i = 0; i < recordCount; i++){    
      double[] features = xtrain[i]; 
      prob.x[i] = new svm_node[features.length]; 
      for (int j = 0; j < features.length; j++){ 
       svm_node node = new svm_node(); 
       node.index = j; 
       node.value = features[j]; 
       prob.x[i][j] = node; 
      }   
      prob.y[i] = ytrain[i][0]; 
     }    

     svm_parameter param = new svm_parameter(); 
     param.probability = 1; 
     param.gamma = 0.5; 
     param.nu = 0.5; 
     param.C = 100; 
     param.svm_type = svm_parameter.C_SVC; 
     param.kernel_type = svm_parameter.LINEAR;  
     param.cache_size = 20000; 
     param.eps = 0.001;  

     svm_model model = svm.svm_train(prob, param); 

     return model; 
    } 

    static double[] svmPredict(double[][] xtest, svm_model model) 
    { 

     double[] yPred = new double[xtest.length]; 

     for(int k = 0; k < xtest.length; k++){ 

     double[] fVector = xtest[k]; 

     svm_node[] nodes = new svm_node[fVector.length]; 
     for (int i = 0; i < fVector.length; i++) 
     { 
      svm_node node = new svm_node(); 
      node.index = i; 
      node.value = fVector[i]; 
      nodes[i] = node; 
     } 

     int totalClasses = 2;  
     int[] labels = new int[totalClasses]; 
     svm.svm_get_labels(model,labels); 

     double[] prob_estimates = new double[totalClasses]; 
     yPred[k] = svm.svm_predict_probability(model, nodes, prob_estimates); 

     } 

     return yPred; 
    } 


} 

Hier ist die Ausgabe:

(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
(Actual:-1.0 Prediction:-1.0) 
+0

Vielen Dank für den hilfreichen Code. Warum hast du param.probability = 1 ;? Und zweitens, weißt du, wie man das Gewicht einstellen kann, wenn man unausgeglichene Klassen habe? Ich meine das Gewicht, mit dem der C-Parameter – machinery

+0

Verliert nicht den Umfang, wenn Sie svm.svm_predict_probability() aufrufen? – user1040535

+0

Dies ist einfach ein Beitrag, der Ihnen hilft, mit LIBSVM zu beginnen, und von dort aus den Benutzer zu bestimmen, was in Übereinstimmung mit dem Problem funktioniert. Für Fragen dazu, schlage ich vor, Sie besuchen die Website der Betreuer dieses Pakets: https://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#/Q06:_Probability_outputs –

Verwandte Themen