Hier Code (from here):tensorflow: tf.split ist seltsam Parameter
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.ops import rnn, rnn_cell
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)
hm_epochs = 3
n_classes = 10
batch_size = 128
chunk_size = 28
n_chunks = 28
rnn_size = 128
x = tf.placeholder('float', [None, n_chunks,chunk_size])
y = tf.placeholder('float')
def recurrent_neural_network(x):
layer = {'weights':tf.Variable(tf.random_normal([rnn_size,n_classes])),
'biases':tf.Variable(tf.random_normal([n_classes]))}
x = tf.transpose(x, [1,0,2])
x = tf.reshape(x, [-1, chunk_size])
x = tf.split(0, n_chunks, x)
lstm_cell = rnn_cell.BasicLSTMCell(rnn_size,state_is_tuple=True)
outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
output = tf.matmul(outputs[-1],layer['weights']) + layer['biases']
return output
def train_neural_network(x):
prediction = recurrent_neural_network(x)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prediction,y))
optimizer = tf.train.AdamOptimizer().minimize(cost)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
for epoch in range(hm_epochs):
epoch_loss = 0
for _ in range(int(mnist.train.num_examples/batch_size)):
epoch_x, epoch_y = mnist.train.next_batch(batch_size)
epoch_x = epoch_x.reshape((batch_size,n_chunks,chunk_size))
_, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
epoch_loss += c
print('Epoch', epoch, 'completed out of',hm_epochs,'loss:',epoch_loss)
correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
print('Accuracy:',accuracy.eval({x:mnist.test.images.reshape((-1, n_chunks, chunk_size)), y:mnist.test.labels}))
train_neural_network(x)
Ich habe Ausgabe x = tf.split(0, n_chunks, x)
Verständnis, mehr specificaly dritte Parameter (x
-Eingang). Von documenation sollte dies Achse sein ... aber das kann nicht sein, oder? Ist nicht x
eindimensional? Ich entschuldige mich, wenn es trivial ist, ich bin Anfänger und kann nicht sem, um es zu bekommen. Vielleicht ist es nur Formsache, aber wenn ich es nicht verstehe, ist, wie es funktioniert ...
Danke! Aber in 'tf.split (0, n_chunks, x)' erstes Argument ist '0' (was Daten sein sollte) und drittens ist' x', die Achse sein sollte ??? – econ
Dies hängt von der Tensorflow-Version ab. Die Reihenfolge der Argumente von "tf.split" wurde von Version 0.12 auf Version 1.0 geändert. Siehe https://www.tensorflow.org/install/migration. – GeertH
Der Vollständigkeit halber [tf.split() Dokumentation für v0.12] (https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/slicing_and_joining#split) –