2016-12-01 2 views
2

Ich habe eine grundlegende Frage über die Aktualisierung der Werte von Tensoren über die Tensorflow Python API.Aktualisieren von Variablenwerten in Tensorflow

Betrachten Sie den Code-Schnipsel:

x = tf.placeholder(shape=(None,10), ...) 
y = tf.placeholder(shape=(None,), ...) 
W = tf.Variable(randn(10,10), dtype=tf.float32) 
yhat = tf.matmul(x, W) 

Nun nehmen wir an, ich irgendeine Art von Algorithmus implementieren möchten, die iterativ den Wert von W aktualisiert (z einige Optimierung algo). Dies beinhaltet die Schritte wie:

for i in range(max_its): 
    resid = y_hat - y 
    W = f(W , resid) # some update 

das Problem hier ist, dass W auf der LHS ist ein neuer Tensor, nicht die W, die in yhat = tf.matmul(x, W) verwendet wird! Das heißt, eine neue Variable wird erstellt und der in meinem "Modell" verwendete Wert W wird nicht aktualisiert.

nun einen Weg, um dies würde

for i in range(max_its): 
    resid = y_hat - y 
    W = f(W , resid) # some update 
    yhat = tf.matmul(x, W) 

, die für jede Iteration meiner Schleife in der Schaffung eines neuen „Modell“ führt sein!

Gibt es eine bessere Möglichkeit, dies zu implementieren (in Python), ohne eine ganze Reihe neuer Modelle für jede Iteration der Schleife zu erstellen - aber stattdessen den ursprünglichen Tensor W sozusagen "in-place" zu aktualisieren?

Antwort

2

Variablen haben eine Assign-Methode. Versuchen Sie: W.assign(f(W,resid))

+0

Das scheint zu funktionieren! Vielen Dank. – firdaus

0

@ aarbelle die knappe Antwort ist richtig, ich werde es ein wenig erweitern, falls jemand mehr Informationen benötigt. Die letzten 2 Zeilen darunter werden zur Aktualisierung von W verwendet.

x = tf.placeholder(shape=(None,10), ...) 
y = tf.placeholder(shape=(None,), ...) 
W = tf.Variable(randn(10,10), dtype=tf.float32) 
yhat = tf.matmul(x, W) 

... 

for i in range(max_its): 
    resid = y_hat - y 
    update = W.assign(f(W , resid)) # do not forget to initialize tf variables. 
    # "update" above is just a tf op, you need to run the op to update W. 
    sess.run(update) 
Verwandte Themen