2016-08-13

I train and save a model from a CSV file. Everything is fine in this first step. After saving the model, I try to load the saved model and use it with new data, but it does not work: the reloaded Spark model does not seem to work.

What is the problem?

Training Java file:

SparkConf sconf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("Test")
    .set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classifications1.csv").javaRDD()
    .map(new Function<String, Cobj>() {
        @Override
        public Cobj call(String line) throws Exception {
            String[] parts = line.split(",");
            Cobj c = new Cobj();
            c.setClassName(parts[1].trim());
            c.setProductName(parts[0].trim());
            return c;
        }
    });

Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);

// StringIndexer
StringIndexer classIndexer = new StringIndexer()
    .setHandleInvalid("skip")
    .setInputCol("className")
    .setOutputCol("label");
StringIndexerModel classIndexerModel = classIndexer.fit(mainDataset);

// Tokenizer
Tokenizer tokenizer = new Tokenizer()
    .setInputCol("productName")
    .setOutputCol("words");

// HashingTF
HashingTF hashingTF = new HashingTF()
    .setInputCol(tokenizer.getOutputCol())
    .setOutputCol("features");

DecisionTreeClassifier decisionClassifier = new DecisionTreeClassifier()
    .setLabelCol("label")
    .setFeaturesCol("features");

Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {classIndexer, tokenizer, hashingTF, decisionClassifier});

Dataset<Row>[] splits = mainDataset.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];

PipelineModel pipelineModel = pipeline.fit(train);

Dataset<Row> result = pipelineModel.transform(test);
pipelineModel.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

IndexToString labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("PredictedClassName")
    .setLabels(classIndexerModel.labels());
result = labelConverter.transform(result);
result.show(num, false);

Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
    .setMetricName("accuracy");
System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));

Output:

+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|className                 |productName                                  |label|words                                                 |features                                                                                         |rawPrediction        |probability          |prediction|PredictedClassName        |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
|Apple iPhone 6S 16GB      |Apple IPHONE 6S 16GB SGAY Telefon            |2.0  |[apple, iphone, 6s, 16gb, sgay, telefon]              |(262144,[27536,56559,169565,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0])                     |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB      |Apple iPhone 6S 16 GB Space Gray MKQJ2TU/A   |2.0  |[apple, iphone, 6s, 16, gb, space, gray, mkqj2tu/a]   |(262144,[10879,56559,95900,139131,175329,175778,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S 16GB      |iPhone 6s 16GB                               |2.0  |[iphone, 6s, 16gb]                                    |(262144,[27536,56559,210029],[1.0,1.0,1.0])                                                      |[0.0,0.0,6.0,0.0,0.0]|[0.0,0.0,1.0,0.0,0.0]|2.0       |Apple iPhone 6S Plus 64GB |
|Apple iPhone 6S Plus 128GB|Apple IPHONE 6S PLUS 128GB SG Telefon        |4.0  |[apple, iphone, 6s, plus, 128gb, sg, telefon]         |(262144,[56559,99916,137263,175839,200223,210029,242621],[1.0,1.0,1.0,1.0,1.0,1.0,1.0])          |[0.0,0.0,0.0,0.0,2.0]|[0.0,0.0,0.0,0.0,1.0]|4.0       |Apple iPhone 6S Plus 128GB|
|Apple iPhone 6S Plus 16GB |Iphone 6S Plus 16GB SpaceGray - Apple Türkiye|1.0  |[iphone, 6s, plus, 16gb, spacegray, -, apple, türkiye]|(262144,[27536,45531,46750,56559,59104,99916,200223,210029],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])   |[0.0,5.0,0.0,0.0,0.0]|[0.0,1.0,0.0,0.0,0.0]|1.0       |Apple iPhone 6S Plus 16GB |
+--------------------------+---------------------------------------------+-----+------------------------------------------------------+-------------------------------------------------------------------------------------------------+---------------------+---------------------+----------+--------------------------+
 
Accuracy = 1.0

Loading Java file:

SparkConf sconf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("Test")
    .set("spark.sql.warehouse.dir", "D:/Temp/wh");
SparkSession spark = SparkSession.builder().appName("Java Spark").config(sconf).getOrCreate();

JavaRDD<Cobj> cRDD = spark.read().textFile("file:///C:/Temp/classificationsTest.csv").javaRDD()
    .map(new Function<String, Cobj>() {
        @Override
        public Cobj call(String line) throws Exception {
            String[] parts = line.split(",");
            Cobj c = new Cobj();
            c.setClassName("?");
            c.setProductName(parts[0].trim());
            return c;
        }
    });

Dataset<Row> mainDataset = spark.createDataFrame(cRDD, Cobj.class);
mainDataset.show(100, false);

PipelineModel pipelineModel = PipelineModel.load(savePath + "DecisionTreeClassificationModel");

Dataset<Row> result = pipelineModel.transform(mainDataset);
result.show(100, false);

Output:

+---------+-----------+-----+-----+--------+-------------+-----------+----------+
|className|productName|label|words|features|rawPrediction|probability|prediction|
+---------+-----------+-----+-----+--------+-------------+-----------+----------+
+---------+-----------+-----+-----+--------+-------------+-----------+----------+

Answer

I removed the StringIndexer from the pipeline and saved it separately as "StringIndexer". In the second file, after loading the pipeline, I loaded the StringIndexer and used it to convert the numeric predictions back to class names. The empty result above was caused by the StringIndexer stage inside the saved pipeline: it is configured with setHandleInvalid("skip"), and because every row of the new data carries the placeholder className "?", a label the indexer never saw during training, every row is filtered out during transform.
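A minimal sketch of that fix, assuming the same `Cobj` schema, the `savePath` variable, and the feature stages (`classIndexer`, `tokenizer`, `hashingTF`, `decisionClassifier`) from the question; the variable names `labeled`, `loadedIndexer`, and `loadedPipeline` are mine:

```java
// Training file: fit and persist the StringIndexerModel on its own,
// and keep it OUT of the pipeline stages, so the reloaded pipeline
// never drops rows whose className was not seen during training.
StringIndexerModel classIndexerModel = classIndexer.fit(mainDataset);
classIndexerModel.write().overwrite().save(savePath + "StringIndexer");

Dataset<Row> labeled = classIndexerModel.transform(train); // adds the "label" column
Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {tokenizer, hashingTF, decisionClassifier});
PipelineModel pipelineModel = pipeline.fit(labeled);
pipelineModel.write().overwrite().save(savePath + "DecisionTreeClassificationModel");

// Loading file: load both models, score the new data, then map the
// numeric predictions back to class names with the saved indexer's labels.
StringIndexerModel loadedIndexer =
    StringIndexerModel.load(savePath + "StringIndexer");
PipelineModel loadedPipeline =
    PipelineModel.load(savePath + "DecisionTreeClassificationModel");

Dataset<Row> result = loadedPipeline.transform(mainDataset);
IndexToString labelConverter = new IndexToString()
    .setInputCol("prediction")
    .setOutputCol("PredictedClassName")
    .setLabels(loadedIndexer.labels());
labelConverter.transform(result).show(false);
```

Because the indexer no longer sits in the pipeline, the dummy "?" class name in the test file is never indexed, so no rows are skipped and the prediction output is no longer empty.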
