Ich habe Probleme mit der effektiven Verwendung von variablen Bereichen. Ich möchte einige Variablen für Gewichte, Verzerrungen und den inneren Zustand eines einfachen wiederkehrenden Netzwerks definieren. Ich rufe get_saver()
einmal nach dem Definieren des Standarddiagramms auf. Ich iteriere dann über eine Charge von Proben mit tf.scan
.Variable Bereiche in Tensorflow
import tensorflow as tf
import math
import numpy as np
INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3
def batch_vm2(m, x):
[input_size, output_size] = m.get_shape().as_list()
input_shape = tf.shape(x)
batch_rank = input_shape.get_shape()[0].value - 1
batch_shape = input_shape[:batch_rank]
output_shape = tf.concat(0, [batch_shape, [output_size]])
x = tf.reshape(x, [-1, input_size])
y = tf.matmul(x, m)
y = tf.reshape(y, output_shape)
return y
def get_saver():
with tf.variable_scope('h1') as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
saver = tf.train.Saver([weights, biases, state])
return saver
def load(sess, saver, checkpoint_dir = None):
print("loading a session")
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("no checkpoint found")
return
def iterate_state(prev_state_tuple, input):
with tf.variable_scope('h1') as scope:
scope.reuse_variables()
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
print("input: ",input.get_shape())
matmuladd = batch_vm2(weights, input) + biases
matmulpri = tf.Print(matmuladd,[matmuladd], message=" malmul -> ")
#matmulvec = tf.reshape(matmuladd, [HIDDEN_1])
#state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
print("prev state: ",prev_state_tuple.get_shape())
unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
prev_state = unpacked_state
state = state.assign(4.2*(0.9* prev_state + 0.1*matmuladd))
#output = tf.nn.relu(state)
output = tf.nn.tanh(state)
state = tf.Print(state, [state], message=" state -> ")
output = tf.Print(output, [output], message=" output -> ")
#output = matmulpri
print(" state: ", state.get_shape())
print(" output: ", output.get_shape())
concat_result = tf.concat(0,[state, output])
print (" concat return: ", concat_result.get_shape())
return concat_result
def data_iter():
while True:
idxs = np.random.rand(BATCH_SIZE, INPUTS)
yield idxs
with tf.Graph().as_default():
inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
saver = get_saver()
initial_state = tf.zeros([HIDDEN_1],
name='initial_state')
initial_out = tf.zeros([HIDDEN_1],
name='initial_out')
#concat_tensor = tf.concat(0,[initial_state, initial_out])
concat_tensor = tf.concat(0,[initial_state, initial_out])
print(" init state: ",initial_state.get_shape())
print(" init out: ",initial_out.get_shape())
print(" concat: ",concat_tensor.get_shape())
scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
print ("scanout shape: ", scanout.get_shape())
state, output = tf.split(1,2,scanout, name='split_scan_output')
print(" end state: ",state.get_shape())
print(" end out: ",output.get_shape())
#output,state,diagnostic = create_graph(inputs, state, prev_state)
sess = tf.Session()
# Run the Op to initialize the variables.
sess.run(tf.initialize_all_variables())
if False:
load(sess, saver)
iter_ = data_iter()
for i in xrange(0, 5):
print ("iteration: ",i)
input_data = iter_.next()
out,st,so = sess.run([output,state,scanout], feed_dict={ inputs: input_data})
saver.save(sess, 'my-model', global_step=1+i)
print("input vec: ", input_data)
print("state vec: ", st)
print("output vec: ", out)
print(" end state (runtime): ",st.shape)
print(" end out (runtime): ",out.shape)
print(" end scanout (runtime): ",so.shape)
Meine Hoffnung wäre die Variablen aus get_variable
innerhalb des scan
op abgerufen hat das gleiche zu sein wie in dem get_saver
Anruf definiert. Wenn ich jedoch diesen Beispielcode ausführen, erhalte ich die folgende Ausgabe mit Fehlern:
(' init state: ', TensorShape([Dimension(2)]))
(' init out: ', TensorShape([Dimension(2)]))
(' concat: ', TensorShape([Dimension(4)]))
Traceback (most recent call last):
File "cycles_in_graphs_with_scan.py", line 88, in <module>
scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 345, in scan
back_prop=back_prop, swap_memory=swap_memory)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1873, in while_loop
result = context.BuildLoop(cond, body, loop_vars)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1749, in BuildLoop
body_result = body(*vars_for_body_with_tensor_arrays)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 339, in compute
a = fn(a, elems_ta.read(i))
File "cycles_in_graphs_with_scan.py", line 47, in iterate_state
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS))))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 732, in get_variable
partitioner=partitioner, validate_shape=validate_shape)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 596, in get_variable
partitioner=partitioner, validate_shape=validate_shape)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 161, in get_variable
caching_device=caching_device, validate_shape=validate_shape)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 454, in _get_single_variable
" Did you mean to set reuse=None in VarScope?" % name)
ValueError: Variable state_scan/h1/W does not exist, disallowed. Did you mean to set reuse=None in VarScope?
eine Idee, was ich in diesem Beispiel falsch mache?