2017-12-09 5 views
1

Ich versuche, eine Methode zu erstellen, die eine FIFOQueue in Tensorflow implementieren konnte. Daher ist es bei jeder Iteration der Zweck, eine placeholder eine bestimmte Nummer zuzuweisen, dann speichern Sie es in einem Variable mit dem Namen: Puffer. Nach jeder Zuweisung erhöhe ich einen Index. Die Puffergröße ist [5], so dass der Index zwischen 0 und 4 liegen sollte. Schließlich, nachdem der Puffer voll ist, würde ich buffer[0:4] auf buffer[1:5] setzen und dann den neuen Wert zu buffer[4] hinzufügen. So, hier ist meineProbleme beim manuellen Implementieren einer FIFOQueue in Tensorflow

Code:

import tensorflow as tf 
import numpy as np 
import random 

dim = 30 

lst = [] 
for i in range(dim): 
    lst.append(random.randint(1, 10)) 

data = np.reshape(lst, [dim, 1]) 
print(lst) 

# create a buffer: 
buffer_input = tf.placeholder(tf.int32, shape=[1]) 

buffer = tf.Variable(tf.zeros([5], tf.int32)) 

index = tf.Variable(tf.constant(0)) 

def fillBufferBeforeFilled(): 
    update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input) 
    index_assign_add = tf.assign_add(index, 1) 
    return update_op1, index_assign_add 

def fillBufferAfterFilled(): 
    tmp = tf.slice(buffer, begin=[0], size=[4]) 
    update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp) 
    update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input) 
    return update_op2, update_op3 

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferBeforeFilled(), lambda: fillBufferAfterFilled()) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    for i in range(dim): 
     cond_ = sess.run(cond, feed_dict={buffer_input: data[i]}) 
     buf = sess.run(buffer, feed_dict={buffer_input: data[i]}) 
     print('buf: ', buf) 

Problem: Die index Variable nicht nach jedem Aufruf erhöht wird, während das erste Element des buffer auf den Wert auf den Platzhalter übergeben zugeordnet wird .

Ich würde gerne wissen, warum ich dieses Verhalten bekomme und was die Lösung für dieses Problem ist.

jede Hilfe wird sehr geschätzt !!

Antwort

0

ist die Lösung:

import tensorflow as tf 
import numpy as np 
import random 

dim = 30 

lst = [] 
for i in range(dim): 
    lst.append(random.randint(1, 10)) 

data = np.reshape(lst, [dim, 1]) 
print(lst) 

# create a buffer: 
buffer_input = tf.placeholder(tf.int32, shape=[1]) 

buffer = tf.Variable(tf.zeros([5], tf.int32)) 

index = tf.Variable(-1, tf.int32) 

def fillBufferBeforeFilled(): 
    index_assign_add = tf.assign_add(index, 1) 
    with tf.control_dependencies([index_assign_add]): 
     update_op1 = tf.scatter_update(buffer, indices=[index], updates=buffer_input) 

    return update_op1, index_assign_add 

def fillBufferAfterFilled(): 
    tmp = tf.slice(buffer, begin=[1], size=[4]) 
    update_op2 = tf.scatter_update(buffer, indices=[0, 1, 2, 3], updates=tmp) 
    with tf.control_dependencies([update_op2]): 
     update_op3 = tf.scatter_update(buffer, indices=[index], updates=buffer_input) 

    return update_op2, update_op3 

cond = tf.cond(tf.equal(index, 4), lambda: fillBufferAfterFilled(), lambda: fillBufferBeforeFilled()) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    for i in range(dim): 
     cond_ = sess.run(cond, feed_dict={buffer_input: data[i]}) 
     buf = sess.run(buffer, feed_dict={buffer_input: data[i]}) 
     print('buf: ', buf) 
0

Sie haben die Reihenfolge der Bedingungen in tf.cond verwechselt; Es sollte

sein

Ich kann Ihren Code ausführen und es funktioniert meistens, aber die Updates sind nicht ganz richtig; Ich vermute, du musst einige tf.control_dependencies Aufrufe hinzufügen, um Dinge in der richtigen Reihenfolge zu erzwingen. Hier

Verwandte Themen