2016-09-27 3 views
0

Ich brauche den Zwischenwert eines Tensors in tf.while_loop(), aber es gibt mir nur den letzten zurückgegebenen Wert.TensorFlow: Wie erhält man den Zwischenwert einer Variablen in tf.while_loop()?

Zum Beispiel habe ich eine Variable x, die 3 Seiten hat und ihre Dimension ist 3 * 2 * 4. Jetzt möchte ich jede Seite einmal abrufen und die Gesamtsumme, die Seitensumme, den Mittelwert, den Höchst- und den Mindestwert jeder Seite berechnen. Dann definiere ich die Bedingung und die Körperfunktion und möchte tf.while_loop() verwenden, um die benötigten Ergebnisse zu berechnen. Der Quellcode ist wie folgt.

import tensorflow as tf  
x = tf.constant([[[41, 8, 48, 82], 
         [9, 56, 67, 23]], 
        [[95, 89, 44, 54], 
         [11, 33, 29, 1]], 
        [[34, 9, 5, 70], 
         [14, 35, 18, 17]]], dtype=tf.int32) 

def cond(out, count, x): 
    return count < 3 

def body(out, count, x): 
    outTemp = tf.slice(x, [count, 0, 0], [1, -1, -1]) 
    count += 1 
    outPack = tf.unpack(out) 
    outPack[0] += tf.reduce_sum(outTemp) 
    outPack[1] = tf.reduce_sum(outTemp) 
    outPack[2] = tf.reduce_mean(outTemp) 
    outPack[3] = tf.reduce_max(outTemp) 
    outPack[4] = tf.reduce_min(outTemp) 
    out = tf.pack(outPack) 
    return out, count, x 

out = tf.Variable(tf.constant([0, 0, 0, 0, 0])) # total sum, page sum, mean, max, min 
count = tf.Variable(tf.constant(0)) 
result = tf.while_loop(cond, body, [out, count, x]) 

init = tf.initialize_all_variables() 
with tf.Session() as sess: 
    sess.run(init) 
    print(sess.run(x)) 
    print(sess.run(result)[0]) 

Wenn ich das Programm laufen, es gibt mir nur den Rückgabewert der letzten Zeit und ich kann nur die Ergebnisse der letzten Seite bekommen. Die Frage ist also, wie bekomme ich die Ergebnisse jeder Seite und wie bekomme ich den Zwischenwert von tf.while_loop()?

Vielen Dank.

Antwort

0

Um den "Zwischenwert" einer Variablen zu erhalten, können Sie einfach die Funktion tf.Print verwenden, die wirklich eine Identitätsoperation ist mit dem Nebeneffekt, eine relevante Nachricht bei der Auswertung der oben genannten Variablen zu drucken.

Als Beispiel

x = tf.Print(x, [x], "Value of x is: ") 

Kann in jeder Linie platziert werden, wo Sie den Wert wollen gemeldet werden.

+0

Vielen Dank für Ihre Antwort. Ich möchte wissen, wo wirst du die tf.Print() ablegen? In der Tat, wenn ich es in die Funktion body() einfügen, funktioniert es nicht und jedes Ergebnis wird gedruckt. Da die Funktion tf.while_loop() auf besondere Weise ausgeführt wurde. – Kongsea

+0

Das ist seltsam. Kannst du deine Körperfunktion hier einfügen? – cberkay

+0

Hallo, mein Code wurde eingefügt, als ich die Frage gestellt habe und Sie können sich die Fragebereiche ansehen. Danke nochmal. – Kongsea

Verwandte Themen