Ich experimentiere deeplearning4j
, indem Sie einige ihrer Tutorials folgen. Der Iris-Datensatz ist gut bekannt, und von Weka (entweder mit einem Random oder ein Mehrlagiges Perzeptron), kann ich leicht die F-Maßnahme erhält fast 1 (0,97 im folgende Beispiel) zu erreichen:Schreckliche Leistung auf Iris-Datensatz mit deeplearning4j
TP Rate FP Rate Precision Recall F-Measure MCC ROC Area PRC Area Class
1.000 0.000 1.000 1.000 1.000 1.000 1.000 1.000 Iris-setosa
0.960 0.020 0.960 0.960 0.960 0.940 0.996 0.993 Iris-versicolor
0.960 0.020 0.960 0.960 0.960 0.940 0.996 0.993 Iris-virginica Weighted
Avg. 0.973 0.013 0.973 0.973 0.973 0.960 0.998 0.995
I‘ m ist so erfolgreich, nicht mit deeplearning4j
:
Examples labeled as 0 classified by model as 0: 4 times
Examples labeled as 1 classified by model as 0: 12 times
Examples labeled as 2 classified by model as 0: 14 times
Warning: class 1 was never predicted by the model. This class was excluded from the average precision
Warning: class 2 was never predicted by the model. This class was excluded from the average precision
Accuracy: 0.1333 Precision: 0.1333 Recall: 0.3333 F1 Score: 0.1905
Hier ist der Code (in Scala) ich verwende:
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator
import org.deeplearning4j.eval.Evaluation
import org.deeplearning4j.nn.api.{Layer, OptimizationAlgorithm}
import org.deeplearning4j.nn.conf.{Updater, NeuralNetConfiguration}
import org.deeplearning4j.nn.conf.layers.{OutputLayer, RBM}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.ui.weights.HistogramIterationListener
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.lossfunctions.LossFunctions
object Main extends App {
Nd4j.MAX_SLICES_TO_PRINT = -1
Nd4j.MAX_ELEMENTS_PER_SLICE = -1
Nd4j.ENFORCE_NUMERICAL_STABILITY = true
val inputNum = 4
var outputNum = 3
var numSamples = 150
var batchSize = 150
var iterations = 1000
var seed = 321
var listenerFreq = iterations/5
val learningRate = 1e-6
println("Load data....")
val iter = new IrisDataSetIterator(batchSize, numSamples)
val iris = iter.next()
iris.shuffle()
iris.normalizeZeroMeanZeroUnitVariance()
val testAndTrain = iris.splitTestAndTrain(0.80)
val train = testAndTrain.getTrain
val test = testAndTrain.getTest
println("Build model....")
val RMSE_XENT = LossFunctions.LossFunction.RMSE_XENT
val conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(iterations)
.learningRate(learningRate)
.l1(1e-1).regularization(true).l2(2e-4)
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
.useDropConnect(true)
.list(2)
.layer(0, new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN)
.nIn(inputNum).nOut(3).k(1).activation("relu").weightInit(WeightInit.XAVIER).lossFunction(RMSE_XENT)
.updater(Updater.ADAGRAD).dropOut(0.5)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.nIn(3).nOut(outputNum).activation("softmax").build())
.build()
val model = new MultiLayerNetwork(conf)
model.init()
model.setListeners(new HistogramIterationListener(listenerFreq))
println("Train model....")
model.fit(train.getFeatureMatrix)
println("Evaluate model....")
val eval = new Evaluation(outputNum)
val output = model.output(test.getFeatureMatrix, Layer.TrainingMode.TEST)
(0 until output.rows()).foreach { i =>
val actual = train.getLabels.getRow(i).toString.trim()
val predicted = output.getRow(i).toString.trim()
println("actual " + actual + " vs predicted " + predicted)
}
eval.eval(test.getLabels, output)
println(eval.stats())
}
RBM nur gut auf bestimmte Verteilungen der Eingabedaten arbeiten (z.B Bilder und vielleicht Empfehlungen). Auch Ihr RBM ist sehr wenig ausdrucksfähig, Sie haben nur drei Einheiten, was wahrscheinlich nicht ausreicht, um etwas aus der Eingabeverteilung zu erfassen. Hier ist eine interessante Erklärung: http://stackoverflow.com/questions/25641485/gaussian-rbm-fails-on-a-trivial-example Ich würde wahrscheinlich mit voll verbundenen Schichten wie in Weka beginnen und von dort aus gehen. –
Hallo Hugo, dl4j Entwickler schlagen vor, Fragen zu gitter von der offiziellen Website zu stellen. – 404pio