2017-07-05 1 views
2

In der API von tf.contrib.rnn.DropoutWrapper, ich versuche variational_recurrent=True einzustellen, in diesem Fall ist input_size obligatorisch. Wie erläutert, ist input_sizeTensorShape Objekte, die die Tiefe (n) der Eingangstensoren enthalten.Verwenden Sie variational_recurrent in tf.contrib.rnn.DropoutWrapper

Tiefe (n) ist verwirrend, was ist es bitte? Ist es nur die Form des Tensors, wie wir es durch tf.shape() bekommen können? Oder die Anzahl der Kanäle für den speziellen Fall von Bildern? Aber mein Eingangstensor ist kein Bild.

Und ich verstehe nicht, warum dtype angefordert wird, wenn variational_recurrent=True.

Danke!

+0

In Bezug auf Ihre erste Frage: Ihre Eingabe in die RNN Zelle einige mehrdimensionale Tensor Form sein '[batch_size, MAX_TIME, ...]'. Die Tiefe bezieht sich auf die Dimensionen "...". Wenn Ihre Eingabe beispielsweise die Form "(20, 35, 100)" hat, ist die Tiefe 100. In diesem Fall wäre "input_size" im Dropout-Wrapper 100. – Lemon

Antwort

0

Inpput_size für tf.TensorShape ([200, None, 300]) ist nur 300

Wiedergabe mit diesem Beispiel.

import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see TF issue #152 
os.environ["CUDA_VISIBLE_DEVICES"]="1" 
import tensorflow as tf 
import numpy as np 


n_steps = 2 
n_inputs = 3 
n_neurons = 5 
keep_prob = 0.5 
learning_rate = 0.001 


X = tf.placeholder(tf.float32, [None, n_steps, n_inputs]) 
X_seqs = tf.unstack(tf.transpose(X, perm=[1, 0, 2])) 

basic_cell = tf.contrib.rnn.BasicLSTMCell(num_units=n_neurons) 
basic_cell_drop = tf.contrib.rnn.DropoutWrapper(basic_cell, input_keep_prob=keep_prob, variational_recurrent=True,dtype=tf.float32,input_size=n_inputs) 

output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell_drop, X_seqs, 
               dtype=tf.float32) 
outputs = tf.transpose(tf.stack(output_seqs), perm=[1, 0, 2]) 

init = tf.global_variables_initializer() 

X_batch = np.array([ 
     # t = 0  t = 1 
     [[0, 1, 2], [9, 8, 7]], # instance 1 
     [[3, 4, 5], [0, 0, 0]], # instance 2 
     [[6, 7, 8], [6, 5, 4]], # instance 3 
     [[9, 0, 1], [3, 2, 1]], # instance 4 
    ]) 

with tf.Session() as sess: 
    init.run() 
    outputs_val = outputs.eval(feed_dict={X: X_batch}) 


print(outputs_val) 

Sehen Sie diese für weitere Informationen: https://github.com/tensorflow/tensorflow/issues/7927