Deeplearning4j with Spark: SparkDl4jMultiLayer evaluation with JavaRDD
I am new to Spark and am currently trying to build a neural network with the deeplearning4j API. Training works fine, but I run into problems during evaluation. I get the following error message:
18:16:16,206 ERROR ~ Exception in task 0.0 in stage 14.0 (TID 19)
java.lang.IllegalStateException: Network did not have same number of parameters as the broadcasted set parameters
    at org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction.call(EvaluateFlatMapFunction.java:75)
I cannot seem to find the reason for this problem, and information on Spark together with deeplearning4j is sparse. I essentially adopted this structure from this example: https://github.com/deeplearning4j/dl4j-spark-cdh5-examples/blob/2de0324076fb422e2bdb926a095adb97c6d0e0ca/src/main/java/org/deeplearning4j/examples/mlp/IrisLocal.java
This is my code:
import java.util.Collections;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

public class DeepBeliefNetwork {

    private JavaRDD<DataSet> trainSet;
    private JavaRDD<DataSet> testSet;

    private int inputSize;
    private int numLab;
    private int batchSize;
    private int iterations;
    private int seed;
    private int listenerFreq;

    MultiLayerConfiguration conf;
    MultiLayerNetwork model;
    SparkDl4jMultiLayer sparkmodel;
    JavaSparkContext sc;

    MLLibUtil mllibUtil = new MLLibUtil();

    public DeepBeliefNetwork(JavaSparkContext sc, JavaRDD<DataSet> trainSet, JavaRDD<DataSet> testSet, int numLab,
            int batchSize, int iterations, int seed, int listenerFreq) {
        this.trainSet = trainSet;
        this.testSet = testSet;
        this.numLab = numLab;
        this.batchSize = batchSize;
        this.iterations = iterations;
        this.seed = seed;
        this.listenerFreq = listenerFreq;
        this.inputSize = testSet.first().numInputs();
        this.sc = sc;
    }

    public void build() {
        System.out.println("input Size: " + inputSize);
        System.out.println(trainSet.first().toString());
        System.out.println(testSet.first().toString());

        // Three stacked RBM layers followed by a softmax output layer
        conf = new NeuralNetConfiguration.Builder().seed(seed)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0).iterations(iterations).momentum(0.5)
                .momentumAfter(Collections.singletonMap(3, 0.9))
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list(4)
                .layer(0,
                        new RBM.Builder().nIn(inputSize).nOut(500).weightInit(WeightInit.XAVIER)
                                .lossFunction(LossFunction.RMSE_XENT).visibleUnit(RBM.VisibleUnit.BINARY)
                                .hiddenUnit(RBM.HiddenUnit.BINARY).build())
                .layer(1,
                        new RBM.Builder().nIn(500).nOut(250).weightInit(WeightInit.XAVIER)
                                .lossFunction(LossFunction.RMSE_XENT).visibleUnit(RBM.VisibleUnit.BINARY)
                                .hiddenUnit(RBM.HiddenUnit.BINARY).build())
                .layer(2,
                        new RBM.Builder().nIn(250).nOut(200).weightInit(WeightInit.XAVIER)
                                .lossFunction(LossFunction.RMSE_XENT).visibleUnit(RBM.VisibleUnit.BINARY)
                                .hiddenUnit(RBM.HiddenUnit.BINARY).build())
                .layer(3, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).activation("softmax").nIn(200)
                        .nOut(numLab).build())
                .pretrain(true).backprop(false).build();
    }

    public void trainModel() {
        model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

        // Create Spark multi layer network from configuration
        sparkmodel = new SparkDl4jMultiLayer(sc.sc(), model);
        sparkmodel.fitDataSet(trainSet);

        // Evaluation
        Evaluation evaluation = sparkmodel.evaluate(testSet);
        System.out.println(evaluation.stats());
    }
}
Does anyone have tips on how to handle my JavaRDD<DataSet>? I believe that is where the problem lies.
Thanks a lot!
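One workaround I have been considering (untested, just a sketch) is to skip the distributed evaluate() call and evaluate on the driver with the trained network instead. I am assuming here that getNetwork() returns the network with the averaged parameters after fitDataSet(), and that output(), eval() and getFeatureMatrix() have the signatures shown below in this version (INDArray is org.nd4j.linalg.api.ndarray.INDArray), so please correct me if that is wrong:

// Sketch of a driver-side evaluation as a workaround for the failing
// distributed evaluate(). Assumes getNetwork() gives back the trained
// (parameter-averaged) network and that the test set fits in driver memory.
MultiLayerNetwork trained = sparkmodel.getNetwork();

Evaluation eval = new Evaluation(numLab);
for (DataSet ds : testSet.collect()) {                                // pulls the whole test RDD to the driver
    INDArray output = trained.output(ds.getFeatureMatrix(), false);   // false = inference mode
    eval.eval(ds.getLabels(), output);
}
System.out.println(eval.stats());

This only works when the test set is small enough to collect, and it merely sidesteps EvaluateFlatMapFunction rather than fixing the parameter mismatch itself.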
EDIT1
I am using deeplearning4j version 0.4-rc.10 and Spark 1.5.0. Here is the stack trace:
11:03:53,088 ERROR ~ Exception in task 0.0 in stage 16.0 (TID 21)
java.lang.IllegalStateException: Network did not have same number of parameters as the broadcasted set parameters
at org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction.call(EvaluateFlatMapFunction.java:75)
at org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction.call(EvaluateFlatMapFunction.java:41)
at org.apache.spark.api.java.JavaRDDLike$$anonfun$fn$4$1.apply(JavaRDDLike.scala:156)
at org.apache.spark.api.java.JavaRDDLike$$anonfun$fn$4$1.apply(JavaRDDLike.scala:156)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$17.apply(RDD.scala:706)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$17.apply(RDD.scala:706)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:297)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:264)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1280)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1268)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1267)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1267)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:697)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:697)
at scala.Option.foreach(Option.scala:236)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:697)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1493)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1455)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1444)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:567)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1813)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:1933)
at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1003)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:147)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:108)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:306)
at org.apache.spark.rdd.RDD.reduce(RDD.scala:985)
at org.apache.spark.api.java.JavaRDDLike$class.reduce(JavaRDDLike.scala:375)
at org.apache.spark.api.java.AbstractJavaRDDLike.reduce(JavaRDDLike.scala:47)
at org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer.evaluate(SparkDl4jMultiLayer.java:629)
at org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer.evaluate(SparkDl4jMultiLayer.java:607)
at org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer.evaluate(SparkDl4jMultiLayer.java:597)
at deep.deepbeliefclassifier.DeepBeliefNetwork.trainModel(DeepBeliefNetwork.java:117)
at deep.deepbeliefclassifier.DataInput.main(DataInput.java:105)
Can you post the stack trace the exception comes from? – SpamBot
Thanks for the reply, I posted it in the edit. – graffo
Could you first try using the latest version? It is 0.5.0 –
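For reference, a rough sketch of what the suggested upgrade would involve (not from this thread, and only hedged guesswork based on the 0.5.0 examples): in 0.5.0 the dl4j-spark API is built around a TrainingMaster, so fitDataSet(...) and the evaluation call change. The averagingFrequency and workerPrefetchNumBatches values below are placeholders, and the Builder argument is the number of examples per DataSet object in the RDD, which is not necessarily the mini-batch size; check the official dl4j-spark examples for the exact setup.

// 0.5.0-style sketch: the TrainingMaster controls how parameter averaging
// is performed across the Spark workers.
// TrainingMaster: org.deeplearning4j.spark.api.TrainingMaster
// ParameterAveragingTrainingMaster: org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSize) // examples per DataSet in the RDD (assumption)
        .batchSizePerWorker(batchSize)
        .averagingFrequency(5)            // placeholder value
        .workerPrefetchNumBatches(2)      // placeholder value
        .build();

SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
sparkNet.fit(trainSet);                   // replaces fitDataSet(trainSet)

Evaluation eval = sparkNet.evaluate(testSet);
System.out.println(eval.stats());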