Kurze Antwort: Sie möchten wahrscheinlich checkpoint files (permalink).
Lange Antwort:
Lassen Sie uns über das Setup hier klar sein. Ich nehme an, dass Sie zwei Geräte haben, A und B, und Sie trainieren auf A und führen Inferenz auf B. Periodisch möchten Sie die Parameter auf dem Gerät aktualisieren, die Inferenz mit neuen Parametern während des Trainings auf der andere. Das oben verlinkte Tutorial ist ein guter Anfang. Es zeigt Ihnen, wie tf.train.Saver
Objekte funktionieren, und Sie sollten hier nichts Komplizierteres brauchen. Hier
ein Beispiel:
import tensorflow as tf
def build_net(graph, device):
with graph.as_default():
with graph.device(device):
# Input placeholders
inputs = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float32, [None, 10])
# Initialization
w0 = tf.get_variable('w0', shape=[784,256], initializer=tf.contrib.layers.xavier_initializer())
w1 = tf.get_variable('w1', shape=[256,256], initializer=tf.contrib.layers.xavier_initializer())
w2 = tf.get_variable('w2', shape=[256,10], initializer=tf.contrib.layers.xavier_initializer())
b0 = tf.Variable(tf.zeros([256]))
b1 = tf.Variable(tf.zeros([256]))
b2 = tf.Variable(tf.zeros([10]))
# Inference network
h1 = tf.nn.relu(tf.matmul(inputs, w0)+b0)
h2 = tf.nn.relu(tf.matmul(h1,w1)+b1)
output = tf.nn.softmax(tf.matmul(h2,w2)+b2)
# Training network
cross_entropy = tf.reduce_mean(-tf.reduce_sum(labels * tf.log(output), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Your checkpoint function
saver = tf.train.Saver()
return tf.initialize_all_variables(), inputs, labels, output, optimizer, saver
Der Code für das Trainingsprogramm:
def programA_main():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build training network on device A
graphA = tf.Graph()
init, inputs, labels, _, training_net, saver = build_net(graphA, '/cpu:0')
with tf.Session(graph=graphA) as sess:
sess.run(init)
for step in xrange(1,10000):
batch = mnist.train.next_batch(50)
sess.run(training_net, feed_dict={inputs: batch[0], labels: batch[1]})
if step%100==0:
saver.save(sess, '/tmp/graph.checkpoint')
print 'saved checkpoint'
... und Code für ein Inferenz-Programm:
def programB_main():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build inference network on device B
graphB = tf.Graph()
init, inputs, _, inference_net, _, saver = build_net(graphB, '/cpu:0')
with tf.Session(graph=graphB) as sess:
batch = mnist.test.next_batch(50)
saver.restore(sess, '/tmp/graph.checkpoint')
print 'loaded checkpoint'
out = sess.run(inference_net, feed_dict={inputs: batch[0]})
print out[0]
import time; time.sleep(2)
saver.restore(sess, '/tmp/graph.checkpoint')
print 'loaded checkpoint'
out = sess.run(inference_net, feed_dict={inputs: batch[0]})
print out[1]
Wenn Sie Feuer Nach dem Trainingsprogramm und dem Inferenzprogramm sehen Sie, dass das Inferenzprogramm zwei verschiedene Ausgänge erzeugt (aus der gleichen Input-Charge). Dies ist ein Ergebnis davon, dass es die Parameter aufnimmt, die das Trainingsprogramm überprüft hat.
Jetzt ist dieses Programm offensichtlich nicht Ihr Endpunkt. Wir führen keine echte Synchronisation durch, und Sie müssen entscheiden, was "periodisch" in Bezug auf das Checkpointing bedeutet. Aber das sollte Ihnen eine Vorstellung davon geben, wie man Parameter von einem Netzwerk zu einem anderen synchronisiert.
Eine letzte Warnung: Dies bedeutet nicht bedeuten, dass die beiden Netzwerke unbedingt deterministisch sind. Es gibt bekannte nicht-deterministische Elemente in TensorFlow (z. B. this), also seien Sie vorsichtig, wenn Sie genau die gleiche Antwort benötigen: genau.Aber das ist die harte Wahrheit über das Laufen auf mehreren Geräten.
Viel Glück!
Warum erstellen Sie nicht mehrere Graphen parallel, anstatt eine existierende zu replizieren? –
Diese Frage ist ziemlich zweideutig. Fragen Sie nach einer Aktualisierung einer TensorFlow 'Graph' Datenstruktur in situ [(hard)] (http://stackoverflow.com/questions/37610757/how-to-remove-nodes-from-tensorflow-graph/37620231#37620231) ? Oder fragen Sie, wie man die Parameter in einem Graphen von einem anderen Graphen aktualisieren kann (nicht so schlecht) (https://www.tensorflow.org/versions/master/how_tos/variables/index.html#saving-and-restoring)) ohne die Struktur zu ändern? Oder hängt das mit der Versionskontrolle in neuronalen Netzen zusammen (was ein Software-Engineering-Problem ist)? – rdadolf
@rdadolf der zweite. Ich muss nur eine Kopie der gleichen Modelle auf verschiedenen Maschinen behalten und die Parameter von Zeit zu Zeit synchronisieren. – MBZ