Ich möchte Trainingsdaten während des Trainings CNN
in TF
vorladen und meine einfache Implementierung ist wie folgt. Ich finde jedoch ein seltsames Phänomen. Es scheint ein synchroner Prozess zu sein. Der Zeitaufwand für das Laden von Chargendaten ist nahezu gleich, unabhängig davon, ob PRE_FETCH
True
oder False
ist.Daten in TF vorladen
class Demo(object):
def __init__(self):
self._name = 'demo'
def load_batch(self):
...
def prefetch(self, func):
while True:
data = func()
self.queue.put(data)
def train(self):
input_data = tf.placeholder(tf.float32, shape=[B, H, W, C])
optim_op = build_model(input_data)
if PRE_FETCH:
self.queue = Queue(30)
self.process = Process(target=self.prefetch, args=(self.load_batch))
self.process.start()
def cleanup():
self.process.terminate()
self.process.join()
import atexit
atexit.register(cleanup)
sess = tf.Session()
i = 1
while i < MAX_ITER_SIZE:
if PRE_FETCH:
start = time.time()
tmp = self.queue.get()
end = time.time()
print 'load data time: ', (end - start)
else:
start = time.time()
tmp = self.load_batch()
end = time.time()
print 'load data time: ', (end - start)
sess.run(optim_op, feed_dict={input_data: tmp}