2017-09-01 1 views
0

Ich bekomme seltsame Ergebnisse, wenn ich Folgendes mache (ich möchte nur innerhalb einer Funktion auf eine Graph-Variable zugreifen). Thema: Eine TensorFlow-Variable in einer Funktion ändern

import tensorflow as tf 


def mul(x):
    """Return ``x * x`` plus a variable requested under the name ``'vara'``.

    NOTE(review): ``tf.get_variable`` presumably only shares variables that
    were themselves created through ``tf.get_variable``. A variable built
    with ``tf.Variable(..., name='vara')`` is therefore likely not the one
    returned here; a separate, default-initialised scalar would be created
    instead, which would explain the unexpected outputs recorded in this
    script. Confirm against the TensorFlow variable-sharing documentation.

    Args:
        x: Scalar tensor (or placeholder) to square.

    Returns:
        A tensor holding ``x * x + p`` where ``p`` is the looked-up variable.
    """
    offset = tf.get_variable('vara', shape=())
    squared = x * x
    return squared + offset

# Build the graph: a scalar float placeholder and a variable named 'vara'.
x = tf.placeholder(dtype=tf.float32, shape=())
v = tf.Variable(1.0, name='vara')

out = mul(x)


with tf.Session() as sess:

    # Variables must be initialised before they can be read.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    feed = {x: 5.0}

    # Expected 26.0; the recorded run printed 23.9039 instead.
    print(26.0, sess.run(out, feed_dict=feed))

    # Expected 1.0; recorded 1.0.
    print(1.0, sess.run(v))

    # v += 2
    increment_op = v.assign_add(2.0)
    sess.run(increment_op)
    # Expected 3.0; recorded 3.0.
    print(3.0, sess.run(v))

    # Expected 28.0; the recorded run printed 25.5912 instead.
    print(28.0, sess.run(out, feed_dict=feed))

Antwort

0

Ich habe es herausgefunden:

import tensorflow as tf 

# create a new var using scope
def new_var(scope_name, var, shape=None):
    """Create a float variable initialised to 1.0 inside a variable scope.

    The variable is registered through ``tf.get_variable`` so that it can be
    fetched again later by opening the same scope with ``reuse=True``.

    Args:
        scope_name: Name of the variable scope to create the variable in.
        var: Name of the variable within that scope.
        shape: Optional shape of the variable; ``None`` (the default) creates
            a scalar.

    Returns:
        The variable created by ``tf.get_variable``.
    """
    var_shape = () if shape is None else shape
    with tf.variable_scope(scope_name):
        # Use an initializer object rather than a constant tensor:
        # tf.get_variable rejects an explicit shape when the initializer is
        # itself a tensor, which made the ``shape`` argument unusable.
        return tf.get_variable(
            var, shape=var_shape, initializer=tf.constant_initializer(1.0))

# get var whenever in an arbitrary place in the graph
def get_var(scope_name, var, shape=None):
    """Fetch an existing variable from a variable scope.

    The scope is opened with ``reuse=True``, so ``tf.get_variable`` is asked
    for an already registered variable rather than a new one.

    Args:
        scope_name: Name of the variable scope holding the variable.
        var: Name of the variable within that scope.
        shape: Optional shape forwarded to ``tf.get_variable``.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    with tf.variable_scope(scope_name, reuse=True):
        existing = tf.get_variable(var, shape=shape)
    return existing

# fx we want to access a random part of the graph
def mul(x):
    """Return ``x * x`` plus the shared variable ``'v'`` from scope ``'foo'``.

    Args:
        x: Scalar tensor (or placeholder) to square.

    Returns:
        A tensor holding ``x * x + p`` where ``p`` is the shared variable.
    """
    shared = get_var('foo', 'v')
    squared = x * x
    return squared + shared

# init regular graph: a scalar float placeholder and the shared variable
x = tf.placeholder(dtype=tf.float32, shape=(), name='x')
v = new_var('foo', 'v')

out = mul(x)

with tf.Session() as sess:

    # Variables must be initialised before they can be read.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    feed = {x: 5.0}

    # Expected 26.0; recorded 26.0.
    print(26.0, sess.run(out, feed_dict=feed))

    # Expected 1.0; recorded 1.0.
    print(1.0, sess.run(v))

    # v += 2
    increment_op = v.assign_add(2.0)
    sess.run(increment_op)

    # Expected 3.0; recorded 3.0.
    print(3.0, sess.run(v))

    # Expected 28.0; recorded 28.0.
    print(28.0, sess.run(out, feed_dict=feed))

    # Dump the graph so it can be inspected in TensorBoard.
    writer = tf.summary.FileWriter("/Users/waf/Desktop/logs/", sess.graph)
    writer.close()
Verwandte Themen