2017-02-03 3 views
1

Ich habe this post, insbesondere Teil II, verfolgt, um Keras als Schnittstelle zu TensorFlow zu verwenden.Keras set_learning_phase für Dropout beim Speichern von TensorFlow Sitzung

Als Beispiel habe ich ein CNN mit dem MNIST-Datensatz trainiert. Mein Ziel ist es, ein Modell in einer TF-Sitzung zu trainieren und auszuwerten und dann die Sitzung unter Verwendung von tf.train.Saver() zu speichern, damit ich das Modell auf CloudML bereitstellen kann.

Ich bin in der Lage, dies für ein Modell zu tun, das Dropout nicht verwendet, aber wenn ich Dropout-Schichten in Keras einschließen, müssen Sie die learning_phase angeben (training = 1, testing = 0), dies erfolgt über das feed_dict (siehe Code unten).

Vor Ort bin ich in der Lage, dies zu kontrollieren, indem jedoch so etwas wie

test_accuracy = accuracy.eval(feed_dict={images: mnist_data.test.images, labels: mnist_data.test.labels, K.learning_phase(): 0}) 

tun, wenn ich mein Modell zu CloudML laden und ich versuche, die folgende Fehlermeldung zu testen

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'keras_learning_phase' with dtype bool 
    [[Node: keras_learning_phase = Placeholder[dtype=DT_BOOL, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

Ich weiß, es ist wegen der Zeile im feed_dict aber ich habe keine Ahnung wie ich das umgehen soll. In dem Blogpostabschnitt IV wird dieses Problem im Zusammenhang mit TensorFlow behandelt, wo das Modell geladen und erneut gespeichert wird. Ich konnte das nicht für meinen Ansatz nutzen, da ich den Session Export und Export.meta exportieren muss, nicht das Keras Modell.

# Make a session in tf 
sess = tf.Session() 
# sess = tf.InteractiveSession() 

# Register the tf session with Keras 
K.set_session(sess) 

# Generate placeholders for the images and labels and mark as input. 
images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1)) 
keys_placeholder = tf.placeholder(tf.int64, shape=(None,)) 
labels = tf.placeholder(tf.float32, shape=(None, 10)) 
inputs = {'key': keys_placeholder.name, 'image': images.name} 
tf.add_to_collection('inputs', json.dumps(inputs)) 

# To be able to extract the id, we need to add the identity function. 
keys = tf.identity(keys_placeholder) 

# Define a simple network 
# Two fully-connected layer with 128 units and ReLU activation 
model = Sequential() 
model.add(Convolution2D(32, 5, 5, activation='relu', input_shape=(28, 28, 1))) 
model.add(MaxPooling2D(pool_size=(2,2))) 
model.add(Convolution2D(64, 5, 5, activation='relu')) 
model.add(MaxPooling2D(pool_size=(2,2))) 
model.add(Dropout(0.25)) 
model.add(Flatten()) 
model.add(Dense(1024, activation='relu')) 
model.add(Dropout(0.50)) 
model.add(Dense(10, activation='softmax')) 
preds = model(images) # Output 

# Define some Ops 
prediction = tf.argmax(preds ,1) 
scores = tf.nn.softmax(preds) 

# Use the Keras caterforical crossentropy_function and the tf reduce mean 
loss = tf.reduce_mean(categorical_crossentropy(labels, preds)) 
# Define the optimizer 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss) 
# Initialization op 
init_op = tf.initialize_all_variables() 
# Saver op 
saver = tf.train.Saver() 

# Mark the outputs. 
outputs = {'key': keys.name, 
      'prediction': prediction.name, 
      'scores': scores.name} 
tf.add_to_collection('outputs', json.dumps(outputs)) 

# Get the data 
mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True, reshape=False) 

# Open session 
with sess.as_default(): 
    sess.run(init_op) 
    # print keras_learning_phase.eval() 

    for i in range(100): 
     batch = mnist_data.train.next_batch(50) 
     train_step.run(feed_dict={images: batch[0], 
            labels: batch[1], 
            K.learning_phase(): 1}) 
    saver.save(sess, 'test/export') 

Antwort

1

Da dies ein sehr Keras orientierte Programmierung Problem ist, wäre es am besten, diese Frage zu stellen direkt an ihren GitHub Issue Tracker.

Sie können auch feststellen, dass this same issue & in ihrem Issue Tracker bereits many times gemeldet wurde. Die Lösung für Ihr Problem könnte auch in der Keras documentation behandelt werden.

Verwandte Themen