2017-02-06 3 views
1

Gemäß der Dokumentation und zahlreiche SO Beiträge zu diesem API muss der Sparer-Objekt erstellt werdenPython/Tensorflow: Saver für Netzwerk

saver = tf.train.Saver(...variables...) 

Ich wollte wissen, ob es eine Möglichkeit, automatisch die (...variables...) aufzufüllen ohne explizit alle in meinem Netzwerk verwendeten Variablen und Ops aufzulisten.

Im Moment ist mein Netzwerk nur zwei Schichten, so dass es keine große Hektik ist, aber es fühlt sich geradezu steinzeitlich an, alle Variablen manuell aufzulisten.

Antwort

1

Die default initializer for tf.train.Saver erstellt eine Instanz, die alle speicherbaren Objekte in Ihrem Diagramm speichert/wiederherstellt, das normalerweise alle Modellvariablen enthält. Daher sollten Sie in der Lage sein zu schreiben:

saver = tf.train.Saver() 

... und erhalten Sie den gewünschten Effekt ohne zu viel Mühe.

+0

Danke, ich versuche es jetzt mit dem folgenden Code wiederherzustellen: 'mit tf.Session() als sess: \ new_saver = tf.train.import_meta_graph (" model-path.meta ") \ new_saver.restore (sess, "model-path") \ print (tf.get_collection ('vars')) ' Aber tf.get_collection ('vars') gibt ein leeres Array zurück. Stimmt etwas nicht damit, wie ich es wiederherstellen will? – dant

+0

Beachten Sie, dass 'tf.get_collection ('vars')' wahrscheinlich eine leere Liste zurückgibt, es sei denn, Sie haben absichtlich einige Variablen zu einer Sammlung mit diesem Namen hinzugefügt. Haben Sie vielleicht 'tf.get_collection (tf.GraphKeys.GLOBAL_VARIABLES)' gemeint (was ist die Standard-Sammlung für die meisten Variablen, außer Sie geben etwas anderes an)? – mrry

Verwandte Themen