Ich mache ein paar Experimente mit TensorFlow und habe einen Haken gefunden. Ich versuche, TF zu verwenden, um eine Änderung in einem Modell zu bewerten, dann das Modell basierend auf der resultierenden Änderung der Verlustfunktion beizubehalten oder zurückzusetzen. Ich habe den harten Teil (bedingte Kontrolle) herausgefunden, aber ich stecke auf etwas fest, das ziemlich einfach sein sollte: Ich kann nicht scheinen, tf.trainable_variables
für eine Iteration zu speichern, dann stelle es bei Bedarf wieder her.Wie kann ich Tensors auf einen früheren Wert zurücksetzen, ohne den Wert auf der Festplatte zu speichern?
Sagen wir, ein Build einen Op:
...
store_trainable_vars = []
for v in tf.trainable_variables():
store_trainable_vars.append(v)
...
Dann später, ich will tf.trainable_variables
auf den Wert wiederherzustellen es hatte, als diese Op letzten Lauf war. Ich würde wollen, wie etwas tun:
def reject_move():
revert_state = []
for (v, s) in zip(tf.trainable_variables(), store_trainable_vars):
revert_state.append(tf.assign(v, s, name="revert_state"))
return(revert_state)
Offensichtlich wird dies neu bewerten store_trainable_vars
, die wiederum Links auf den aktuellen Wert von tf.trainable_variables()
, so dass keine revert_state
Op. Ich brauche eine Möglichkeit, den Wert von Tensors zu speichern und abzurufen, ohne auf den aktuellen Wert dieser Tensoren zurückzugreifen. So etwas wie
...
store_trainable_vars = []
for v in tf.trainable_variables():
store_trainable_vars.append(v.value_right_now())
...
wo v.value_right_now()
eine Konstante zurückgibt, die bis überschrieben wird sich nicht ändern.
Ich weiß, ich könnte Saver verwenden, aber diese Lösung schreibt auf die Festplatte, die für diese Anwendung nicht akzeptabel ist, da es in einer Trainingsschleife ausgeführt wird.
Ich vermisse wahrscheinlich etwas offensichtlich - jede Anleitung wäre willkommen.
Ich sollte klarstellen: Als ich sagte, ich wollte nicht auf die Festplatte schreiben, war es nicht, weil ich über Raum besorgt war. Diese Speicherung und Wiederherstellung findet im schlimmsten Fall bei jeder Iteration statt. Es ist die Laufzeitstrafe für das Zurückgreifen auf die Festplatte, die ich vermeiden möchte. Können Sie Ihre Antwort bearbeiten, um stattdessen eine triviale Verwendung von 'tf.group' für die Wiederherstellung von Graphen zu demonstrieren? (oder verlinke einfach auf ein solches Beispiel) –