2017-05-08 4 views
2

Ich versuche, ein sehr einfaches Beispiel für die Kombination von TensorArray und while_loop zu produzieren:Wie TensorArray und while_loop im Tensorflow zusammenarbeiten?

# 1000 sequence in the length of 100 
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix") 
matrix_rows = tf.shape(matrix)[0] 
ta = tf.TensorArray(tf.float32, size=matrix_rows) 
ta = ta.unstack(matrix) 

init_state = (0, ta) 
condition = lambda i, _: i < n 
body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2)) 

# run the graph 
with tf.Session() as sess: 
    (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))}) 
    print (ta_final.stack()) 

Aber ich die folgende Fehlermeldung erhalten:

ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32). 

Wer auf Ahnung hat, was ist das Problem?

+0

Um die endgültige zu erhalten 'TensorArray' Sie' session.run benötigen (ta.stack()) 'stattdessen die Schleife direkt laufen, die, da Sie nicht können nicht 'session.run (TensorArray)'. – sirfz

+0

Sorry, aber ich habe nicht verstanden, was du meinst. Würden Sie bitte das richtige Formular schreiben? –

Antwort

3

Es gibt verschiedene Dinge in Ihrem Code, auf die Sie hinweisen sollten. Zuerst müssen Sie die Matrix nicht in die TensorArray entstapeln, um sie innerhalb der Schleife zu verwenden. Sie können die Matrix Tensor innerhalb des Körpers sicher referenzieren und sie unter Verwendung der Notation matrix[i] indexieren. Ein anderes Problem ist der unterschiedliche Datentyp zwischen deiner Matrix (tf.int32) und der TensorArray(), basierend auf deinem Code multiplizierst du die Matrix-Inte mit 2 und schreibst das Ergebnis in das Array, so dass es int32 sein sollte. Wenn Sie schließlich das Endergebnis der Schleife lesen möchten, ist die korrekte Operation TensorArray.stack(), die Sie in Ihrem session.run-Aufruf ausführen müssen.

Hier ist ein funktionierendes Beispiel:

import numpy as np 
import tensorflow as tf  

# 1000 sequence in the length of 100 
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix") 
matrix_rows = tf.shape(matrix)[0] 
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows) 

init_state = (0, ta) 
condition = lambda i, _: i < matrix_rows 
body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2)) 
n, ta_final = tf.while_loop(condition, body, init_state) 
# get the final result 
ta_final_result = ta_final.stack() 

# run the graph 
with tf.Session() as sess: 
    # print the output of ta_final_result 
    print sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100,1000), dtype=np.int32)}) 
+0

Großartig! Vielen Dank! es funktioniert jetzt. –

+0

Froh, es funktioniert @ E.Asgari, bitte markieren Sie die Antwort als akzeptiert. – sirfz

+0

In diesem kann ich die Eingabe ohne Verwendung Feed-Wörterbuch, wie in, wenn ich dies zwischen einem Computer-Diagramm, wie würde ich angeben, dass die Tensor-Array hängt von einigen Tensor? – Rahul

Verwandte Themen