1
gegeben

Hier Code (from here):tensorflow: tf.split ist seltsam Parameter

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
from tensorflow.python.ops import rnn, rnn_cell 
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True) 

hm_epochs = 3 
n_classes = 10 
batch_size = 128 
chunk_size = 28 
n_chunks = 28 
rnn_size = 128 


x = tf.placeholder('float', [None, n_chunks,chunk_size]) 
y = tf.placeholder('float') 
def recurrent_neural_network(x): 
    layer = {'weights':tf.Variable(tf.random_normal([rnn_size,n_classes])), 
      'biases':tf.Variable(tf.random_normal([n_classes]))} 

    x = tf.transpose(x, [1,0,2]) 
    x = tf.reshape(x, [-1, chunk_size]) 
    x = tf.split(0, n_chunks, x) 

    lstm_cell = rnn_cell.BasicLSTMCell(rnn_size,state_is_tuple=True) 
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32) 

    output = tf.matmul(outputs[-1],layer['weights']) + layer['biases'] 

    return output 
def train_neural_network(x): 
    prediction = recurrent_neural_network(x) 
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(prediction,y)) 
    optimizer = tf.train.AdamOptimizer().minimize(cost) 


    with tf.Session() as sess: 
     sess.run(tf.initialize_all_variables()) 

     for epoch in range(hm_epochs): 
      epoch_loss = 0 
      for _ in range(int(mnist.train.num_examples/batch_size)): 
       epoch_x, epoch_y = mnist.train.next_batch(batch_size) 
       epoch_x = epoch_x.reshape((batch_size,n_chunks,chunk_size)) 

       _, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y}) 
       epoch_loss += c 

      print('Epoch', epoch, 'completed out of',hm_epochs,'loss:',epoch_loss) 

     correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1)) 

     accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
     print('Accuracy:',accuracy.eval({x:mnist.test.images.reshape((-1, n_chunks, chunk_size)), y:mnist.test.labels})) 

train_neural_network(x) 

Ich habe Ausgabe x = tf.split(0, n_chunks, x) Verständnis, mehr specificaly dritte Parameter (x -Eingang). Von documenation sollte dies Achse sein ... aber das kann nicht sein, oder? Ist nicht x eindimensional? Ich entschuldige mich, wenn es trivial ist, ich bin Anfänger und kann nicht sem, um es zu bekommen. Vielleicht ist es nur Formsache, aber wenn ich es nicht verstehe, ist, wie es funktioniert ...

Antwort

1

von documenation sollte dies Achse ... aber das kann nicht sein, nicht wahr?

Von tensorflow 1,0 ab, das erste Argument von tf.split ist nicht die Achse, aber ich nehme an, dass der Code geschrieben wurde eine ältere Version verwenden, wo das erste Argument in der Tat die Achse ist.

Ist x nicht eindimensional?

x ist nicht eindimensional. Kurz vor dem Aufruf von tf.split wird x mit dieser Aussage von 3 auf 2 Dimensionen neu gestaltet:

x = tf.reshape(x, [-1, chunk_size]) 

Die Aussage umformt x in einen Tensor mit zwei Dimensionen: die Größe der zweiten Dimension ist chunk_size und die Größe der erste Dimension wird abgeleitet (das ist, was die -1 hier bezeichnet).

+0

Danke! Aber in 'tf.split (0, n_chunks, x)' erstes Argument ist '0' (was Daten sein sollte) und drittens ist' x', die Achse sein sollte ??? – econ

+0

Dies hängt von der Tensorflow-Version ab. Die Reihenfolge der Argumente von "tf.split" wurde von Version 0.12 auf Version 1.0 geändert. Siehe https://www.tensorflow.org/install/migration. – GeertH

+0

Der Vollständigkeit halber [tf.split() Dokumentation für v0.12] (https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/slicing_and_joining#split) –

Verwandte Themen