2016-09-05 3 views
0

UPDATE: Ich habe versucht, den folgenden Weg zu verwenden, um Konfidenz-Ergebnisse zu generieren, aber es gibt mir eine Ausnahme. Ich verwende den folgenden Code-Schnipsel:So erhalten Sie Konfidenz-Ergebnisse von Spark MLLib Logistische Regression in Java

double point = BLAS.dot(logisticregressionmodel.weights(), datavector); 
double confScore = 1.0/(1.0 + Math.exp(-point)); 

Und die Ausnahme, die ich erhalten:

Caused by: java.lang.IllegalArgumentException: requirement failed: BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes: x.size = 198, y.size = 18 
    at scala.Predef$.require(Predef.scala:233) 
    at org.apache.spark.mllib.linalg.BLAS$.dot(BLAS.scala:99) 
    at org.apache.spark.mllib.linalg.BLAS.dot(BLAS.scala) 

Können Sie mir bitte helfen? Es scheint, als ob der Gewichtungsvektor mehr Elemente (198) als der Datenvektor (Ich erzeuge 18 Merkmale) hat. Sie müssen in der Funktion dot() gleich lang sein.

Ich versuche ein Programm in Java zu implementieren, um aus einem vorhandenen Dataset zu trainieren und einen neuen Dataset mit dem in Spark MLLib (1.5.0) verfügbaren logistischen Regressionsalgorithmus vorherzusagen. Meine Zug- und Vorhersageprogramme sind wie folgt und ich verwende eine Multiklassenimplementierung. Das Problem ist, wenn ich eine model.predict(vector) mache (beachten Sie die lrmodel.predict() im Vorhersageprogramm), bekomme ich die vorhergesagte Bezeichnung. Aber was, wenn ich ein Vertrauensergebnis benötige? Wie bekomme ich das? Ich habe die API durchlaufen und konnte keine bestimmte API finden, die den Konfidenzwert angibt. Kann mir bitte jemand helfen?

Zug Program (erzeugt eine .model- Datei)

public static void main(final String[] args) throws Exception { 
     JavaSparkContext jsc = null; 
     int salesIndex = 1; 

     try { 
      ... 
     SparkConf sparkConf = 
        new SparkConf().setAppName("Hackathon Train").setMaster(
          sparkMaster); 
      jsc = new JavaSparkContext(sparkConf); 
      ... 

      JavaRDD<String> trainRDD = jsc.textFile(basePath + "old-leads.csv").cache(); 

      final String firstRdd = trainRDD.first().trim(); 
      JavaRDD<String> tempRddFilter = 
        trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() { 
         private static final long serialVersionUID = 
           11111111111111111L; 

         public Boolean call(final String arg0) { 
          return !arg0.trim().equalsIgnoreCase(firstRdd); 
         } 
        }); 

      ... 
      JavaRDD<String> featureRDD = 
        tempRddFilter 
          .map(new org.apache.spark.api.java.function.Function() { 
           private static final long serialVersionUID = 
             6948900080648474074L; 

           public Object call(final Object arg0) 
             throws Exception { 
            ... 
            StringBuilder featureSet = 
              new StringBuilder(); 
            ... 
             featureSet.append(i - 2); 
             featureSet.append(COLON); 
             featureSet.append(strVal); 
             featureSet.append(SPACE); 
            } 

            return featureSet.toString().trim(); 
           } 
          }); 

      List<String> featureList = featureRDD.collect(); 
      String featureOutput = StringUtils.join(featureList, NEW_LINE); 
      String filePath = basePath + "lr.arff"; 
      FileUtils.writeStringToFile(new File(filePath), featureOutput, 
        "UTF-8"); 

      JavaRDD<LabeledPoint> trainingData = 
        MLUtils.loadLibSVMFile(jsc.sc(), filePath).toJavaRDD().cache(); 

      final LogisticRegressionModel model = 
        new LogisticRegressionWithLBFGS().setNumClasses(18).run(
          trainingData.rdd()); 
      ByteArrayOutputStream baos = new ByteArrayOutputStream(); 
      ObjectOutputStream oos = new ObjectOutputStream(baos); 
      oos.writeObject(model); 
      oos.flush(); 
      oos.close(); 
      FileUtils.writeByteArrayToFile(new File(basePath + "lr.model"), 
        baos.toByteArray()); 
      baos.close(); 

     } catch (Exception e) { 
      e.printStackTrace(); 
     } finally { 
      if (jsc != null) { 
       jsc.close(); 
      } 
     } 

Programm Predict (mit dem lr.model aus dem Zug-Programm generiert)

public static void main(final String[] args) throws Exception { 
     JavaSparkContext jsc = null; 
     int salesIndex = 1; 
     try { 
      ... 
     SparkConf sparkConf = 
        new SparkConf().setAppName("Hackathon Predict").setMaster(sparkMaster); 
      jsc = new JavaSparkContext(sparkConf); 

      ObjectInputStream objectInputStream = 
        new ObjectInputStream(new FileInputStream(basePath 
          + "lr.model")); 
      LogisticRegressionModel lrmodel = 
        (LogisticRegressionModel) objectInputStream.readObject(); 
      objectInputStream.close(); 

      ... 

      JavaRDD<String> trainRDD = jsc.textFile(basePath + "new-leads.csv").cache(); 

      final String firstRdd = trainRDD.first().trim(); 
      JavaRDD<String> tempRddFilter = 
        trainRDD.filter(new org.apache.spark.api.java.function.Function<String, Boolean>() { 
         private static final long serialVersionUID = 
           11111111111111111L; 

         public Boolean call(final String arg0) { 
          return !arg0.trim().equalsIgnoreCase(firstRdd); 
         } 
        }); 

      ... 
      final Broadcast<LogisticRegressionModel> broadcastModel = 
        jsc.broadcast(lrmodel); 

      JavaRDD<String> featureRDD = 
        tempRddFilter 
          .map(new org.apache.spark.api.java.function.Function() { 
           private static final long serialVersionUID = 
             6948900080648474074L; 

           public Object call(final Object arg0) 
             throws Exception { 
            ... 
            LogisticRegressionModel lrModel = 
              broadcastModel.value(); 
            String row = ((String) arg0); 
            String[] featureSetArray = 
              row.split(CSV_SPLITTER); 
            ... 
            final Vector vector = 
              Vectors.dense(doubleArr); 
            double score = lrModel.predict(vector); 
            ... 
            return csvString; 
           } 
          }); 

      String outputContent = 
        featureRDD 
          .reduce(new org.apache.spark.api.java.function.Function2() { 

           private static final long serialVersionUID = 
             1212970144641935082L; 

           public Object call(Object arg0, Object arg1) 
             throws Exception { 
            ... 
           } 

          }); 
      ... 
      FileUtils.writeStringToFile(new File(basePath 
        + "predicted-sales-data.csv"), sb.toString()); 
     } catch (Exception e) { 
      e.printStackTrace(); 
     } finally { 
      if (jsc != null) { 
       jsc.close(); 
      } 
     } 
    } 
} 

Antwort

0

Nach vielen Versuchen habe ich es schließlich geschafft, eine benutzerdefinierte Funktion zu schreiben, um Konfidenzwerte zu erzeugen. Es ist überhaupt nicht perfekt, aber funktioniert für mich für jetzt!

private static double getConfidenceScore(
      final LogisticRegressionModel lrModel, final Vector vector) { 
     /* Approach to get confidence scores starts */ 
     Vector weights = lrModel.weights(); 
     int numClasses = lrModel.numClasses(); 
     int dataWithBiasSize = weights.size()/(numClasses - 1); 
     boolean withBias = (vector.size() + 1) == dataWithBiasSize; 
     double maxMargin = 0.0; 
     double margin = 0.0; 
     for (int j = 0; j < (numClasses - 1); j++) { 
      margin = 0.0; 
      for (int k = 0; k < vector.size(); k++) { 
       double value = vector.toArray()[k]; 
       if (value != 0.0) { 
        margin += value 
          * weights.toArray()[(j * dataWithBiasSize) + k]; 
       } 
      } 
      if (withBias) { 
       margin += weights.toArray()[(j * dataWithBiasSize) 
         + vector.size()]; 
      } 
      if (margin > maxMargin) { 
       maxMargin = margin; 
      } 
     } 
     double conf = 1.0/(1.0 + Math.exp(-maxMargin)); 
     DecimalFormat twoDForm = new DecimalFormat("#.##"); 
     double confidenceScore = Double.valueOf(twoDForm.format(conf * 100)); 
     /* Approach to get confidence scores ends */ 
     return confidenceScore; 
    } 
0

der Tat tut es nicht scheinen möglich zu sein. Wenn Sie den Quellcode betrachten, können Sie ihn wahrscheinlich erweitern, um diese Wahrscheinlichkeiten zurückzugeben.

https://github.com/apache/spark/blob/branch-1.5/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

if (numClasses == 2) { 
    val margin = dot(weightMatrix, dataMatrix) + intercept 
    val score = 1.0/(1.0 + math.exp(-margin)) 
    threshold match { 
    case Some(t) => if (score > t) 1.0 else 0.0 
    case None => score 
    } 

Ich hoffe, es zu beginnen kann dabei helfen, eine Abhilfe zu finden.

+0

Können Sie bitte ein Beispiel anführen? Ich habe LogisticRegression.java aus der Spark-Dokumentation gelesen und konnte diese Methode nicht finden. – ArinCool

+0

Ich kann die Funktionen ** raw2probabilityInPlace ** und ** raw2prediction ** nicht finden. Können Sie bitte helfen? – ArinCool

+0

Es ist in der Klasse org.apache.spark.ml.classificationLogisticRegressionModel. Wenn es einfacher ist, können Sie auch eine Kopie mit einem anderen Namen erstellen und diese Funktionen veröffentlichen. –

Verwandte Themen