2017-06-05 6 views
0

UPDATE: Ich glaube fest, dass der Fehler den init_state verwendet ist, wie in den tf.nn.dynamic_rnn (...) als Argument erstellt und zugeführt. Also stellt sich die Frage, was ist die richtige Form oder Art, einen Anfangszustand für eine gestapelte RNN zu konstruieren?Tensorflow 1.1 MultiRNNCell Formfehler (Init_State bezogen)

Ich versuche, eine MultiRNNCell-Definition in TensorFlow 1.1 zu arbeiten.

Die Graphdefinition mit der Hilfsfunktion zum Definieren einer GRU-Zelle folgt unten. Die Grundidee besteht darin, einen Platzhalter x als eine lange Reihe von Proben numerischer Daten zu definieren. Diese Daten werden über Shaping in Frames gleicher Länge aufgeteilt, und zu jedem Zeitschritt wird ein Frame angezeigt. Ich würde es dann gerne über einen Stapel von zwei (vorerst) Zellen von GRUs verarbeiten.

def gru_cell(state_size): 
    cell = tf.contrib.rnn.GRUCell(state_size) 
    return cell 

graph = tf.Graph() 
with graph.as_default(): 

    x = tf.placeholder(tf.float32, [batch_size, num_samples], name="Input_Placeholder") 
    y = tf.placeholder(tf.int32, [batch_size, num_frames], name="Labels_Placeholder") 

    init_state = tf.zeros([batch_size, state_size], name="Initial_State_Placeholder") 

    rnn_inputs = tf.reshape(x, (batch_size, num_frames, frame_length)) 
    cell = tf.contrib.rnn.MultiRNNCell([gru_cell(state_size) for _ in range(2)], state_is_tuple=False) 
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state) 

Die Grafik Definition geht von dort mit Verlustfunktionen auf, Optimizern usw. Aber das ist der Ort, wo er mit dem nach langen Fehler bricht.

Es wird im letzten Teil des Fehlers relevant geworden, dass batch_size 10 ist, und frame_length und state_size sind beide 80.

ValueError        Traceback (most recent call last) 
<ipython-input-30-4c48b596e055> in <module>() 
    14  print(rnn_inputs) 
    15  cell = tf.contrib.rnn.MultiRNNCell([gru_cell(state_size) for _ in range(2)], state_is_tuple=False) 
---> 16  rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state) 
    17 
    18  with tf.variable_scope('softmax'): 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.pyc in dynamic_rnn(cell, inputs, sequence_length, initial_state, dtype, parallel_iterations, swap_memory, time_major, scope) 
    551   swap_memory=swap_memory, 
    552   sequence_length=sequence_length, 
--> 553   dtype=dtype) 
    554 
    555  # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/rnn.pyc in _dynamic_rnn_loop(cell, inputs, initial_state, parallel_iterations, swap_memory, sequence_length, dtype) 
    718  loop_vars=(time, output_ta, state), 
    719  parallel_iterations=parallel_iterations, 
--> 720  swap_memory=swap_memory) 
    721 
    722 # Unpack final output if not using output tuples. 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name) 
    2621  context = WhileContext(parallel_iterations, back_prop, swap_memory, name) 
    2622  ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context) 
-> 2623  result = context.BuildLoop(cond, body, loop_vars, shape_invariants) 
    2624  return result 
    2625 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in BuildLoop(self, pred, body, loop_vars, shape_invariants) 
    2454  self.Enter() 
    2455  original_body_result, exit_vars = self._BuildLoop(
-> 2456   pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2457  finally: 
    2458  self.Exit() 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in _BuildLoop(self, pred, body, original_loop_vars, loop_vars, shape_invariants) 
    2435  for m_var, n_var in zip(merge_vars, next_vars): 
    2436  if isinstance(m_var, ops.Tensor): 
-> 2437   _EnforceShapeInvariant(m_var, n_var) 
    2438 
    2439  # Exit the loop. 

/home/novak/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.pyc in _EnforceShapeInvariant(merge_var, next_var) 
    565   "Provide shape invariants using either the `shape_invariants` " 
    566   "argument of tf.while_loop or set_shape() on the loop variables." 
--> 567   % (merge_var.name, m_shape, n_shape)) 
    568 else: 
    569  if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)): 

ValueError: The shape for rnn/while/Merge_2:0 is not an invariant for the loop. It enters the loop with shape (10, 80), but has shape (10, 160) after one iteration. Provide shape invariants using either the `shape_invariants` argument of tf.while_loop or set_shape() on the loop variables. 

, das fast wie das Netz der 80er Jahre beginnt als 2-Stapel sieht und irgendwie wird ein 1-Stack von 160 konvertiert. Irgendwelche Hilfe, um das zu beheben? Missverstehe ich die Verwendung der MultiRNNCell?

+0

Sollte das nicht 'init_state = tf.zeros ([batch_size, 2 * state_size] ...'? –

Antwort

0

Basierend auf Allen Lavoie Kommentar über die korrigierte Code ist:

def gru_cell(state_size): 
    cell = tf.contrib.rnn.GRUCell(state_size) 
    return cell 

num_layers = 2 # <--------- 
graph = tf.Graph() 
with graph.as_default(): 

    x = tf.placeholder(tf.float32, [batch_size, num_samples], name="Input_Placeholder") 
    y = tf.placeholder(tf.int32, [batch_size, num_frames], name="Labels_Placeholder") 

    init_state = tf.zeros([batch_size, num_layer * state_size], name="Initial_State_Placeholder") # <--------- 

    rnn_inputs = tf.reshape(x, (batch_size, num_frames, frame_length)) 
    cell = tf.contrib.rnn.MultiRNNCell([gru_cell(state_size) for _ in range(num_layer)], state_is_tuple=False) # <--------- 
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state) 

Hinweis drei Änderungen oben. Beachten Sie auch, dass diese Änderungen an allen Stellen, an denen init_state fließt, kräuseln müssen, insbesondere, wenn Sie sie in ein feed_dict eingeben.