2017-01-18 4 views
0

Ich versuche, einen Tensor Attns in Tensorflow Seq2seq Code zu drucken. Seq2Seq.pyDrucken Tensorflow Variable Seq2seq

Ich habe versucht:

tf.Print(attns, [attns]) 

aber er druckt nichts.

Ich versuchte

sess = tf.Session() 
sess.run(attns) or attns.eval() 

ich diesen Fall wirft: InvalidArgumentError: Sie müssen einen Wert für Platzhalter-Tensor-Feed

Ich habe auch versucht sess.run mit()

sess = tf.get_default_session() 
aa = sess.run(attns) 

In diesem Fall ist das Sess-Objekt None.

Antwort

1

tf.Print ist keine "klassische" Betriebsanweisung, da diese nicht in symbolischem, graphenbasiertem Code ausgeführt werden. Was stattdessen benötigt wird, ist ein spezifischer Knoten im Berechnungsgraphen, der dann ausgelöst wird, wenn Ihre Berechnungen diesen Knoten "passieren".

Genau das macht tf.Print. Es erstellt einen "Wrapper" -Knoten um einen beliebigen anderen Knoten herum, indem eine Identitätsoperation erstellt wird, die bei Auslösung den Wert einer Liste von Tensoren ausgibt.

Das erste Argument von this print function, input_ (oder attns in diesem Fall) ist der umhüllte Knoten und data (oder [attns] in diesem Fall) ist die Liste der Tensoren gedruckt werden.

Was Sie deshalb tun möchte, ist diese Zeile hinzufügen:

attns = tf.Print(attns, [attns]) 

Hier wird attns einen Druck Wrapper Identitätsoperation auf attns zugewiesen - so der Tensor attns hat genau das gleiche Verhalten, mit der Ausnahme, dass, wenn es berechnet wird, wird auch [attns] ausgedruckt.

+0

Wenn ich dies versuche, wird Wert Fehler bei Seq2seq.py # L560 aufgrund der Dimension nicht übereinstimmen. Ich habe versucht, tf.Print zu entfernen und es funktioniert gut. Ich lade ein trainiertes Modell, falls es darauf ankommt. Es druckt zwar immer noch nichts. –

+0

@ p.j seltsam. Können Sie Ihren Kommentar für den genauen Fehler aktualisieren? (nicht die vollständige Spur, sondern der Name der Tensor-Mismatching, und die Dimensionen gegeben und erwartet) – cleros

Verwandte Themen