2017-06-20 5 views
0

Ich versuche, eine Vorhersage auf einem Modell laufen auszuführen, die ich mit „Feintuning AlexNet mit TensorFlow“ https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.htmlnicht Tensorflow Vorhersage in Java API

ich mit tf.saved_model.builder.SavedModelBuilder in Python gespeichert das Modell trainiert und loaded das Modell in Java mit SavedModelBundle.load. der Hauptteil des Codes ist:

SavedModelBundle smb = SavedModelBundle.load(path, "serve"); 
    Session s = smb.session(); 
    byte[] imageBytes = readAllBytesOrExit(Paths.get(path)); 
    Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes); 
    Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0); 
    final long[] rshape = result.shape(); 
    if (result.numDimensions() != 2 || rshape[0] != 1) { 
     throw new RuntimeException(
       String.format(
         "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", 
         Arrays.toString(rshape))); 
    } 
    int nlabels = (int) rshape[1]; 
    float [] a = result.copyTo(new float[1][nlabels])[0];` 

ich diese Ausnahme erhalten:

Exception in thread "main" java.lang.IllegalArgumentException: Sie einen Wert für Platzhalter Tensor füttern müssen ' Placeholder_1 'mit dtype float [[Knoten: Placeholder_1 = Placeholder_output_shapes = [[]], dtype = DT_FLOAT, shape = [], _device = "/ job: localhost/replik: 0/task: 0/cpu: 0"]]

Ich sah, dass der Code oben arbeitete für einige Leute, und ich kann nicht herausfinden, was hier fehlt. Beachten Sie, dass das Netz mit den Knoten "input_tensor" und "fc8/fc8" vertraut ist, da es nicht sagt, dass es sie nicht kennt.

Antwort

1

Aus der Fehlermeldung scheint es, dass das von Ihnen verwendete Modell einen anderen Wert erhält (dessen Knotenname in der Grafik Placeholder_1 ist und der erwartete Typ ein float-skalarer Tensor ist).

Es scheint, dass Sie Ihr Modell angepasst haben (im Gegensatz zum Folgen des Artikels, mit dem Sie verbal verbunden sind). Das heißt, der Artikel zeigt mehrere Platzhalter, die gefüttert werden müssen, eine für das Bild und eine weitere, um den Ausfall zu kontrollieren. Definiert in dem Artikel als:

keep_prob = tf.placeholder(tf.float32) 

Und der Wert dieses Platzhalters muss zugeführt werden. Wenn Sie Inferenz machen, dann möchten Sie keep_prob auf 1.0 setzen. Etwas wie:

Tensor keep_prob = Tensor.create(1.0f); 
Tensor result = s.runner() 
    .feed("input_tensor", image) 
    .feed("Placeholder_1", keep_prob) 
    .fetch("fc8/fc8") 
    .run() 
    .get(0); 

Hoffe, dass hilft.

Verwandte Themen