3

I wie unten ein neuronales Netzwerk mit encog Bibliothek implementiert haben,Encog Neural-Netzwerk-Validierung/Testing

MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL); 

    final Propagation train = new Backpropagation(network, trainingSet); 
    int epoch = 1; 
    do { 
     train.iteration(); 
     System.out.println("Epoch #" + epoch + 
       " Error:" + train.getError()); 
       epoch++; 

    } while (train.getError() < 0.009); 

    double e = network.calculateError(trainingSet); 
    System.out.println("Network trained to error :" + e); 
    System.out.println("Saving Network"); 


    EncogDirectoryPersistence.saveObject(new File(FILENAME), network); 
} 


public void loadAndEvaluate(){ 
    System.out.println("Loading Network"); 
    BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(new File(FILENAME)); 

    BasicMLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT,XOR_IDEAL); 

    double e = network.calculateError(trainingSet); 

    System.out.println("Loaded network's error is (should be the same as above):" + e); 

} 

Das den Fehler ausgibt. Aber ich möchte dies mit benutzerdefinierten Daten testen und überprüfen, ob die Ausgabe für eine Reihe von Daten ist ein

Antwort

0

Ich sehe, dass Sie eine der Persistenz Beispiel folgen. Um Ausgaben für einige Eingaben zu erhalten, verwenden Sie die Funktion "compute". Als Beispiel:

double[] output = new double[1]; 
    network.compute(new double[]{1.0, 1.0}, output); 
    System.out.println("Network output: " + output[0] + " (should be close to 0.0)"); 

Here's die Java-Benutzerhandbuch. Es ist sehr hilfreich.

+0

Ich habe die folgenden Daten verwendet, um das neuronale Netzwerk zu trainieren und zu testen, aber die Ausgabe ist nicht konstant. public static Doppel train_INPUT [] [] = {{0.0, 0.0}, { 1,0, 0,0}, { 0,0, 1,0}, { 1,0, 1,0} \t \t \t}; \t öffentlicher statischer Doppeltester [] = {1.0, 0.0} ;; \t öffentlich statisch double train_IDEAL [] [] = {{0.0}, {1.0}, {1.0}, {0.0}}; – jee1tha

+0

Ich merke gerade, dass Ihre Schleifenbedingung train.getError() <0.009 ist. Sollte das nicht train.getError()> 0.009 sein? Ich habe ein 2-3-1-Netzwerk zum Testen verwendet und konnte einen Fehler von 0,008 erreichen. (Siehe https://gist.github.com/frankibem/94e588cb2d8ccda2af675f9bde3e25fa und hier: https://gist.github.com/frankibem/eeaa066595e6ba791dfc6cea558f92ca –