Deeplearning4j-XOR-Beispiel: Ich versuche, ein XOR-Netzwerk mit deeplearning4j zu trainieren, aber ich glaube, ich habe nicht wirklich verstanden, wie man den Datensatz benutzt.
Ich wollte ein NN mit zwei Eingängen, zwei versteckten Neuronen und einem Ausgabe-Neuron erstellen.
Hier ist, was ich bisher habe:
package org.deeplearning4j.examples.xor;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
public class XorExample {
public static void main(String[] args) {
INDArray input = Nd4j.zeros(4, 2);
INDArray labels = Nd4j.zeros(4, 1);
input.putScalar(new int[] { 0, 0 }, 0);
input.putScalar(new int[] { 0, 1 }, 0);
input.putScalar(new int[] { 1, 0 }, 1);
input.putScalar(new int[] { 1, 1 }, 0);
input.putScalar(new int[] { 2, 0 }, 0);
input.putScalar(new int[] { 2, 1 }, 1);
input.putScalar(new int[] { 3, 0 }, 1);
input.putScalar(new int[] { 3, 1 }, 1);
labels.putScalar(new int[] { 0, 0 }, 0);
labels.putScalar(new int[] { 1, 0 }, 1);
labels.putScalar(new int[] { 2, 0 }, 1);
labels.putScalar(new int[] { 3, 0 }, 0);
DataSet ds = new DataSet(input,labels);
//Set up network configuration:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
.learningRate(0.1)
.list(2)
.layer(0, new GravesLSTM.Builder().nIn(2).nOut(2)
.updater(Updater.RMSPROP)
.activation("tanh").weightInit(WeightInit.DISTRIBUTION)
.dist(new UniformDistribution(-0.08, 0.08)).build())
.layer(1, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation("softmax") //MCXENT + softmax for classification
.updater(Updater.RMSPROP)
.nIn(2).nOut(1).weightInit(WeightInit.DISTRIBUTION)
.dist(new UniformDistribution(-0.08, 0.08)).build())
.pretrain(false).backprop(true)
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(1));
//Print the number of parameters in the network (and for each layer)
Layer[] layers = net.getLayers();
int totalNumParams = 0;
for(int i=0; i<layers.length; i++){
int nParams = layers[i].numParams();
System.out.println("Number of parameters in layer " + i + ": " + nParams);
totalNumParams += nParams;
}
System.out.println("Total number of network parameters: " + totalNumParams);
net.fit(ds);
Evaluation eval = new Evaluation(3);
INDArray output = net.output(ds.getFeatureMatrix());
eval.eval(ds.getLabels(), output);
System.out.println(eval.stats());
}
}
Die Ausgabe sieht so aus:
Mär 20, 2016 7:03:06 PM com.github.fommil.jni.JniLoader liberalLoad
INFORMATION: successfully loaded C:\Users\LuckyPC\AppData\Local\Temp\jniloader5209513403648831212netlib-native_system-win-x86_64.dll
Number of parameters in layer 0: 46
Number of parameters in layer 1: 3
Total number of network parameters: 49
o.d.o.s.BaseOptimizer - Objective function automatically set to minimize. Set stepFunction in neural net configuration to change default settings.
o.d.o.l.ScoreIterationListener - Score at iteration 0 is 0.6931495070457458
Exception in thread "main" java.lang.IllegalArgumentException: Unable to getFloat row of non 2d matrix
at org.nd4j.linalg.api.ndarray.BaseNDArray.getRow(BaseNDArray.java:3640)
at org.deeplearning4j.eval.Evaluation.eval(Evaluation.java:107)
at org.deeplearning4j.examples.xor.XorExample.main(XorExample.java:80)
Das 'XorExample' ist [hier](https://github.com/deeplearning4j/dl4j-examples/blob/c2b0a53fe1b911bb94c8acfa03e60c5dd3b2a1c5/dl4j-examples/src/main/java/org/deeplearning4j/examples/feedforward/xor/XorExample.java) zu finden. – beatngu13