2016-09-24

Tensorflow tf.reshape() seems to behave differently than numpy.reshape()

I am trying to train an LSTM network. It trains successfully one way but throws an error the other way. In the first example I reshape the input array X with numpy reshape; in the other I reshape it with Tensorflow reshape.

Works:

import numpy as np 
import tensorflow as tf 
import tensorflow.contrib.learn as learn 


# Parameters 
learning_rate = 0.1 
training_steps = 3000 
batch_size = 128 

# Network Parameters 
n_input = 4 
n_steps = 10 
n_hidden = 128 
n_classes = 6 

X = np.ones([1770,4]) 
y = np.ones([177]) 

# NUMPY RESHAPE OUTSIDE RNN_MODEL 
X = np.reshape(X, (-1, n_steps, n_input)) 

def rnn_model(X, y): 

    # TENSORFLOW RESHAPE INSIDE RNN_MODEL 
    #X = tf.reshape(X, [-1, n_steps, n_input]) # (batch_size, n_steps, n_input) 

    # # permute n_steps and batch_size 
    X = tf.transpose(X, [1, 0, 2]) 

    # # Reshape to prepare input to hidden activation 
    X = tf.reshape(X, [-1, n_input]) # (n_steps*batch_size, n_input) 
    # # Split data because rnn cell needs a list of inputs for the RNN inner loop 
    X = tf.split(0, n_steps, X) # n_steps * (batch_size, n_input) 

    # Define an LSTM cell with tensorflow 
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden) 
    # Get lstm cell output 
    _, encoding = tf.nn.rnn(lstm_cell, X, dtype=tf.float32) 

    return learn.models.logistic_regression(encoding, y) 


classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes, 
             batch_size=batch_size, 
             steps=training_steps, 
             learning_rate=learning_rate) 

classifier.fit(X,y) 

Doesn't work:

import numpy as np 
import tensorflow as tf 
import tensorflow.contrib.learn as learn 


# Parameters 
learning_rate = 0.1 
training_steps = 3000 
batch_size = 128 

# Network Parameters 
n_input = 4 
n_steps = 10 
n_hidden = 128 
n_classes = 6 

X = np.ones([1770,4]) 
y = np.ones([177]) 

# NUMPY RESHAPE OUTSIDE RNN_MODEL 
#X = np.reshape(X, (-1, n_steps, n_input)) 

def rnn_model(X, y): 

    # TENSORFLOW RESHAPE INSIDE RNN_MODEL 
    X = tf.reshape(X, [-1, n_steps, n_input]) # (batch_size, n_steps, n_input) 

    # # permute n_steps and batch_size 
    X = tf.transpose(X, [1, 0, 2]) 

    # # Reshape to prepare input to hidden activation 
    X = tf.reshape(X, [-1, n_input]) # (n_steps*batch_size, n_input) 
    # # Split data because rnn cell needs a list of inputs for the RNN inner loop 
    X = tf.split(0, n_steps, X) # n_steps * (batch_size, n_input) 

    # Define an LSTM cell with tensorflow 
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden) 
    # Get lstm cell output 
    _, encoding = tf.nn.rnn(lstm_cell, X, dtype=tf.float32) 

    return learn.models.logistic_regression(encoding, y) 


classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes, 
             batch_size=batch_size, 
             steps=training_steps, 
             learning_rate=learning_rate) 

classifier.fit(X,y) 

The latter throws the following error:

WARNING:tensorflow:<tensorflow.python.ops.rnn_cell.BasicLSTMCell object at 0x7f1c67c6f750>: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True. 
Traceback (most recent call last): 
  File "/home/blabla/test.py", line 47, in <module> 
    classifier.fit(X,y) 
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/estimators/base.py", line 160, in fit 
    monitors=monitors) 
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 484, in _train_model 
    monitors=monitors) 
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/graph_actions.py", line 328, in train 
    reraise(*excinfo) 
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/graph_actions.py", line 254, in train 
    feed_dict = feed_fn() if feed_fn is not None else None 
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/io/data_feeder.py", line 366, in _feed_dict_fn 
    out.itemset((i, self.y[sample]), 1.0) 
IndexError: index 974 is out of bounds for axis 0 with size 177 
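The last frame of the traceback looks up `self.y[sample]`, where the sample indices are presumably drawn from the rows of X (974 is a valid row of the flat 1770-row X, but not of the 177-element y). The sketch below is a hypothetical, much simplified stand-in for that step, not the real DataFeeder code: the helper `labels_for_batch` is illustrative, and plain Python lists stand in for the numpy arrays.

```python
def labels_for_batch(X, y, indices):
    """Hypothetical, simplified feeder step (not the real DataFeeder code).

    The indices are checked against X only and then reused, unchanged,
    to look up the matching labels in y.
    """
    if any(not 0 <= i < len(X) for i in indices):
        raise ValueError("index outside X")
    return [y[i] for i in indices]


X_flat = [[1.0] * 4 for _ in range(1770)]               # like np.ones([1770, 4])
X_seq = [X_flat[i:i + 10] for i in range(0, 1770, 10)]  # like shape [177, 10, 4]
y = [1.0] * 177                                         # like np.ones([177])

print(labels_for_batch(X_seq, y, [0, 176]))  # [1.0, 1.0]: every row of X_seq has a label
try:
    labels_for_batch(X_flat, y, [974])       # 974 is a valid row of X_flat ...
except IndexError as err:
    print("IndexError:", err)                # ... but not of y: list index out of range
```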

Please help me, this is driving me crazy. :( – Jbravo

Answer


A few suggestions:
* Pass an input_fn to fit instead of X, y
* Use learn.Estimator instead of learn.TensorFlowEstimator

Since your data is small, the following should work; otherwise you will need to batch your data.

def _my_inputs():
    return tf.constant(np.ones([1770, 4])), tf.constant(np.ones([177]))
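The answer does not show what batching would look like. As a framework-agnostic illustration only (plain Python, not TensorFlow's input_fn machinery; the helper name `iter_batches` is hypothetical), the constraint that matters here is that every slice of X stays aligned with the matching slice of y:

```python
def iter_batches(X, y, batch_size):
    """Hypothetical helper: yield aligned (X, y) minibatches.

    X must hold one entry per label along its first axis; otherwise the
    slices drift out of step and labels get attached to the wrong inputs.
    """
    if len(X) != len(y):
        raise ValueError("X has %d entries but y has %d labels" % (len(X), len(y)))
    for start in range(0, len(X), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]


X_seq = [[[1.0] * 4 for _ in range(10)] for _ in range(177)]  # like np.ones([177, 10, 4])
y = [1.0] * 177
batches = list(iter_batches(X_seq, y, batch_size=8))
print(len(batches), len(batches[-1][0]))  # 23 1
```

Passing the flat 1770-row X together with the 177 labels would raise the ValueError up front instead of failing later with an out-of-bounds index.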


I was able to get this working with a few small changes:

import numpy as np 
import tensorflow as tf 

# Parameters 
learning_rate = 0.1 
training_steps = 10 
batch_size = 8 

# Network Parameters 
n_input = 4 
n_steps = 10 
n_hidden = 128 
n_classes = 6 

X = np.ones([177, 10, 4]) # <---- Use shape [batch_size, n_steps, n_input] here. 
y = np.ones([177]) 

def rnn_model(X, y): 
    X = tf.transpose(X, [1, 0, 2]) #| 
    X = tf.unpack(X)    #| These two lines do the same thing as your code, just a bit simpler ;) 

    # Define a LSTM cell with tensorflow 
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden) 
    # Get lstm cell output 
    outputs, _ = tf.nn.rnn(lstm_cell, X, dtype=tf.float64) # <---- I think you want to use the first return value here. 

    return tf.contrib.learn.models.logistic_regression(outputs[-1], y) # <----uses just the last output for classification, as is typical with RNNs. 


classifier = tf.contrib.learn.TensorFlowEstimator(model_fn=rnn_model, 
                n_classes=n_classes, 
                batch_size=batch_size, 
                steps=training_steps, 
                learning_rate=learning_rate) 

classifier.fit(X,y) 
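The comment in `rnn_model` claims that `tf.transpose` followed by `tf.unpack` does the same thing as the original transpose, reshape and split. Here is a plain-Python check of that claim on nested lists; the function names are illustrative, lists stand in for tensors, and `zip(*X)` plays the role of transpose-then-unpack:

```python
def via_reshape_and_split(X):
    """transpose([1, 0, 2]) -> reshape([-1, n_input]) -> split(0, n_steps, ...)."""
    batch, n_steps = len(X), len(X[0])
    transposed = [[X[b][t] for b in range(batch)] for t in range(n_steps)]  # [n_steps, batch, n_input]
    flat = [row for step in transposed for row in step]                     # [n_steps * batch, n_input]
    return [flat[t * batch:(t + 1) * batch] for t in range(n_steps)]        # n_steps x [batch, n_input]


def via_unpack(X):
    """transpose([1, 0, 2]) -> unpack(): one [batch, n_input] slice per time step."""
    return [list(step) for step in zip(*X)]


# batch=3, n_steps=5, n_input=4, with distinct values so any mix-up would show
X = [[[100 * b + 10 * t + k for k in range(4)] for t in range(5)] for b in range(3)]
steps = via_unpack(X)
print(via_reshape_and_split(X) == steps, len(steps), len(steps[0]))  # True 5 3
```

Both paths yield one [batch, n_input] slice per time step, which is the list of inputs `tf.nn.rnn` expects.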

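I think the core problem you were having is that X must have shape [batch, ...] when it is passed to fit(...). fit() selects samples along the first axis of X and uses the same indices to look up y, so the first dimension of X has to match the number of labels (177). With the flat [1770, 4] array, an index such as 974 is valid for X but not for y, which is exactly the IndexError in your traceback. When you used numpy to reshape X outside the rnn_model() function, it already had that shape, so training worked.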
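Both `np.reshape` and `tf.reshape` use row-major (C) order, so reshaping a [1770, 4] array to (-1, n_steps, n_input) gives the same values either way: every 10 consecutive rows become one sequence. The only difference is when it happens; done in numpy, fit() already sees 177 sequences, one per label. A plain-Python sketch of that grouping (the helper name is hypothetical):

```python
def rows_to_sequences(rows, n_steps):
    """Group consecutive rows into sequences, mirroring what
    np.reshape(X, (-1, n_steps, n_input)) does to a C-ordered [N, n_input] array."""
    if len(rows) % n_steps:
        raise ValueError("row count must be a multiple of n_steps")
    return [rows[i:i + n_steps] for i in range(0, len(rows), n_steps)]


rows = [[float(r)] * 4 for r in range(1770)]  # 1770 rows of 4 features; row r is filled with r
y = [1.0] * 177
seqs = rows_to_sequences(rows, n_steps=10)
print(len(rows), len(seqs), len(y))  # 1770 177 177
```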

I can't vouch for the quality of the model this solution will produce, but at least it runs!
