2016-06-07 12 views
2

Ich habe Probleme mit der effektiven Verwendung von Variablenbereichen (variable scopes). Ich möchte einige Variablen für die Gewichte, die Biases und den inneren Zustand eines einfachen rekurrenten Netzwerks definieren. Ich rufe get_saver() einmal auf, nachdem der Standardgraph angelegt wurde. Anschließend iteriere ich mit tf.scan über einen Batch von Samples.

Variablenbereiche (variable scopes) in TensorFlow

import tensorflow as tf 
import math 
import numpy as np 

INPUTS = 10 
HIDDEN_1 = 2 
BATCH_SIZE = 3 

def batch_vm2(m, x):
    """Apply the matrix ``m`` to the last axis of ``x``.

    ``x`` is flattened to rows of length ``input_size`` (the first static
    dimension of ``m``), multiplied by ``m`` and reshaped so that all
    leading dimensions of ``x`` are kept and the last one becomes
    ``output_size`` (the second static dimension of ``m``).

    Args:
        m: 2-D tensor whose static shape is [input_size, output_size].
        x: tensor with statically known rank; its elements are regrouped
            into rows of length input_size.

    Returns:
        Tensor with the leading dimensions of ``x`` followed by output_size.
    """
    in_dim, out_dim = m.get_shape().as_list()

    dyn_shape = tf.shape(x)
    # The shape vector has one entry per axis of x, so its static length is
    # the rank of x; everything except the last axis is a leading dimension.
    n_leading = dyn_shape.get_shape()[0].value - 1
    leading_dims = dyn_shape[:n_leading]
    result_shape = tf.concat(0, [leading_dims, [out_dim]])

    rows = tf.reshape(x, [-1, in_dim])
    product = tf.matmul(rows, m)

    return tf.reshape(product, result_shape)

def get_saver():
    """Define the 'h1' layer variables and return a Saver for them.

    Inside the variable scope 'h1' three variables are requested:
    'W' with shape [INPUTS, HIDDEN_1] and a truncated-normal initializer
    (stddev 1/sqrt(INPUTS)), 'bias' with shape [HIDDEN_1] initialised to
    0.0, and the non-trainable 'state' with shape [HIDDEN_1] initialised
    to 0.0.

    Returns:
        A tf.train.Saver constructed from exactly those three variables.
    """
    w_init = tf.truncated_normal_initializer(
        stddev=1.0 / math.sqrt(float(INPUTS)))
    bias_init = tf.constant_initializer(0.0)
    state_init = tf.constant_initializer(0.0)

    with tf.variable_scope('h1'):
        layer_vars = [
            tf.get_variable('W', shape=[INPUTS, HIDDEN_1],
                            initializer=w_init),
            tf.get_variable('bias', shape=[HIDDEN_1],
                            initializer=bias_init),
            tf.get_variable('state', shape=[HIDDEN_1],
                            initializer=state_init, trainable=False),
        ]
        # The Saver is constructed while the scope is still open, exactly
        # where the original code constructed it.
        return tf.train.Saver(layer_vars)


def load(sess, saver, checkpoint_dir=None):
    """Restore ``sess`` from the checkpoint recorded for ``checkpoint_dir``.

    Args:
        sess: session to restore into.
        saver: object whose ``restore(sess, path)`` performs the restore.
        checkpoint_dir: directory handed to tf.train.get_checkpoint_state.

    Raises:
        Exception: if no checkpoint state, or no checkpoint path in it,
            is available.
    """
    print("loading a session")
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    # Guard clause: bail out unless both the state and its path are present.
    if not (checkpoint and checkpoint.model_checkpoint_path):
        raise Exception("no checkpoint found")
    saver.restore(sess, checkpoint.model_checkpoint_path)

def iterate_state(prev_state_tuple, input): 
    """Step function handed to tf.scan: advance the layer state by one input.

    Args:
        prev_state_tuple: accumulator from the previous step; the caller
            builds it as the concatenation [state, output], each part of
            length HIDDEN_1.
        input: the current element of the scanned tensor (presumably one
            row of length INPUTS of the placeholder -- confirm against
            the caller).

    Returns:
        The concatenation [new_state, tanh(new_state)] along dimension 0,
        i.e. the same layout as ``prev_state_tuple``.

    NOTE(review): the scope 'h1' is opened by name and switched to reuse
    mode, so every get_variable below must find an already existing
    variable. In the run reported with this code the lookup resolved to
    'state_scan/h1/W' (prefixed with the name given to tf.scan) and raised
    ValueError, because get_saver() created the variable as 'h1/W'.
    """
    with tf.variable_scope('h1') as scope: 
     # Reuse mode: get_variable may only return existing variables.
     scope.reuse_variables() 
     # Same names, shapes and initializers as requested in get_saver().
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False) 
     # Graph-construction-time debug output (static shape, not values).
     print("input: ",input.get_shape()) 
     # Affine part of the update: input * W + bias.
     matmuladd = batch_vm2(weights, input) + biases 
     # Debug print op; it is only referenced by the commented-out line
     # near the end of this function, so nothing consumes it.
     matmulpri = tf.Print(matmuladd,[matmuladd], message=" malmul -> ") 
     #matmulvec = tf.reshape(matmuladd, [HIDDEN_1]) 
     #state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     print("prev state: ",prev_state_tuple.get_shape()) 
     # Split the accumulator along dimension 0 into two pieces; only the
     # first (previous state) is used, the previous output is ignored.
     unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple) 
     prev_state = unpacked_state 
     # Write the new value into the 'state' variable:
     # 4.2 * (0.9 * previous state + 0.1 * affine input).
     state = state.assign(4.2*(0.9* prev_state + 0.1*matmuladd)) 
     #output = tf.nn.relu(state) 
     output = tf.nn.tanh(state) 
     # Wrap both results in tf.Print ops tagged with a message prefix.
     state = tf.Print(state, [state], message=" state -> ") 
     output = tf.Print(output, [output], message=" output -> ") 
     #output = matmulpri 
     print(" state: ", state.get_shape()) 
     print(" output: ", output.get_shape()) 
     # Pack state and output back into one tensor for the next scan step.
     concat_result = tf.concat(0,[state, output]) 
     print (" concat return: ", concat_result.get_shape()) 
     return concat_result 

def data_iter():
    """Generate an endless stream of random input batches.

    Every item is a new array returned by
    np.random.rand(BATCH_SIZE, INPUTS), so it has shape
    (BATCH_SIZE, INPUTS). The generator never terminates on its own.
    """
    while True:
        yield np.random.rand(BATCH_SIZE, INPUTS)

# Driver: build the graph in a fresh tf.Graph, then run five steps on random
# batches, saving a checkpoint after each step.
with tf.Graph().as_default(): 
    # One batch of inputs, fed at run time.
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS)) 

    # Creates the variables in scope 'h1' and a Saver for them.
    saver = get_saver() 
    # Initial scan accumulator: zero state followed by zero output.
    initial_state = tf.zeros([HIDDEN_1], 
          name='initial_state') 
    initial_out = tf.zeros([HIDDEN_1], 
          name='initial_out') 
    concat_tensor = tf.concat(0,[initial_state, initial_out]) 
    print(" init state: ",initial_state.get_shape()) 
    print(" init out: ",initial_out.get_shape()) 
    print(" concat: ",concat_tensor.get_shape()) 
    # Apply iterate_state to each element of `inputs`, threading the
    # [state, output] accumulator through the steps.
    # NOTE(review): in the reported run this call raises
    # "ValueError: Variable state_scan/h1/W does not exist"; the name passed
    # here shows up as a prefix of the variable looked up in iterate_state.
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan') 
    print ("scanout shape: ", scanout.get_shape()) 
    # Split along dimension 1 into two pieces: states and outputs
    # (argument order here is split_dim, num_split, value).
    state, output = tf.split(1,2,scanout, name='split_scan_output') 
    print(" end state: ",state.get_shape()) 
    print(" end out: ",output.get_shape()) 

    # NOTE(review): commented-out call to create_graph, which is not defined
    # in the visible code.
    #output,state,diagnostic = create_graph(inputs, state, prev_state) 

    # NOTE(review): the session is never closed explicitly.
    sess = tf.Session() 
    # Run the Op to initialize the variables. 
    sess.run(tf.initialize_all_variables()) 
    # Restoring from a checkpoint is switched off: the condition is the
    # literal False, so load() is never called.
    if False: 
     load(sess, saver) 
    iter_ = data_iter() 
    # xrange and iter_.next() are Python 2 spellings.
    for i in xrange(0, 5): 
     print ("iteration: ",i) 
     input_data = iter_.next() 
     out,st,so = sess.run([output,state,scanout], feed_dict={ inputs: input_data}) 
     # Checkpoint after every step, using 1+i as the global step.
     saver.save(sess, 'my-model', global_step=1+i) 
     print("input vec: ", input_data) 
     print("state vec: ", st) 
     print("output vec: ", out) 
     print(" end state (runtime): ",st.shape) 
     print(" end out (runtime): ",out.shape) 
     print(" end scanout (runtime): ",so.shape) 

Ich hatte gehofft, dass die Variablen, die innerhalb der scan-Operation per get_variable abgerufen werden, dieselben sind wie die im get_saver-Aufruf definierten. Wenn ich diesen Beispielcode jedoch ausführe, erhalte ich die folgende Ausgabe mit Fehler:

(' init state: ', TensorShape([Dimension(2)])) 
(' init out: ', TensorShape([Dimension(2)])) 
(' concat: ', TensorShape([Dimension(4)])) 
Traceback (most recent call last): 
    File "cycles_in_graphs_with_scan.py", line 88, in <module> 
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan') 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 345, in scan 
    back_prop=back_prop, swap_memory=swap_memory) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1873, in while_loop 
    result = context.BuildLoop(cond, body, loop_vars) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1749, in BuildLoop 
    body_result = body(*vars_for_body_with_tensor_arrays) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 339, in compute 
    a = fn(a, elems_ta.read(i)) 
    File "cycles_in_graphs_with_scan.py", line 47, in iterate_state 
    weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 732, in get_variable 
    partitioner=partitioner, validate_shape=validate_shape) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 596, in get_variable 
    partitioner=partitioner, validate_shape=validate_shape) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 161, in get_variable 
    caching_device=caching_device, validate_shape=validate_shape) 
    File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 454, in _get_single_variable 
    " Did you mean to set reuse=None in VarScope?" % name) 
ValueError: Variable state_scan/h1/W does not exist, disallowed. Did you mean to set reuse=None in VarScope? 

Hat jemand eine Idee, was ich in diesem Beispiel falsch mache?

Antwort

0
if False:  # literal False: the load() call below is never executed
    load(sess, saver) 

Diese zwei Zeilen führen zu nicht initialisierten Variablen.

Verwandte Themen