2017-02-12 3 views
0

soll mir ein einfaches Netzwerk in tensorflow ist die Umsetzung und für pädagogische Zwecke, ich versuche, dass die lineare Transformation zu zeigen:Tensorflow Lernen XOR mit linearer Funktion, obwohl es nicht

yhat = w(Wx + c) + b 

nicht XOR lernen kann. Aber das Problem ist jetzt, dass es mit meiner aktuellen Implementierung es tut! Dies deutet auf einen Fehler im Code hin. Bitte erläutern?

############################################################ 
''' 
    dummy data 
''' 
x_data = [[0.,0.],[0.,1.],[1.,0.],[1.,1.]] 
y_data = [[0],[1],[1],[0]] 

############################################################ 
''' 
    Input and output 
''' 
X = tf.placeholder(tf.float32, shape = [4,2], name = 'x') 
Y = tf.placeholder(tf.float32, shape = [4,1], name = 'y') 

''' 
    Network parameters 
''' 
W = tf.Variable(tf.random_uniform([2,2],-1,1), name = 'W') 
c = tf.Variable(tf.zeros([2])    , name = 'c') 
w = tf.Variable(tf.random_uniform([2,1],-1,1), name = 'w') 
b = tf.Variable(tf.zeros([1])    , name = 'b') 

############################################################ 
''' 
    Network 1: 

    function: Yhat = (w (x'W + c) + b) 
    loss : \sum_i Y * log Yhat 
''' 
H1 = tf.matmul(X, W) + c 
Yhat1 = tf.matmul(H1, w) + b 


cross_entropy1 = -tf.reduce_sum(
       Y*tf.log(
         tf.clip_by_value(Yhat1,1e-10,1.0) 
         ) 
       ) 

step1 = tf.train.AdamOptimizer(0.01).minimize(cross_entropy1) 

''' 
    Train 
''' 

writer = tf.train.SummaryWriter("./logs/xor_logs.graph_def") 
graph1 = tf.initialize_all_variables() 
sess1 = tf.Session() 
sess1.run(graph1) 

for i in range(100): 
    sess1.run(step1, feed_dict={X: x_data, Y: y_data}) 


''' 
    Evaluation 
''' 
corrects = tf.equal(tf.argmax(Y,1), tf.argmax(Yhat1,1)) 
accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32)) 
r  = sess1.run(accuracy, feed_dict={X: x_data, Y: y_data}) 
print ('accuracy: ' + str(r * 100) + '%') 

Gerade jetzt Genauigkeit ist bei 100%, auch wenn es bei 75% sein sollte.

Antwort

1

tf.argmax (Y, 1) liefert [0,0,0,0]. Das ist nicht was du willst.

+0

Können Sie bitte eine Lösung vorschlagen? – chibro2

+0

Und warum gibt mir argmax (Y, 1) [0,0,0,0]? – chibro2

+0

argmax gibt Ihnen den Index des größten Werts entlang einer Achse. Wir verwenden es also, wenn unser Ziel beispielsweise 1 heiß ist (z. B. 1 von 10 Ziffern). Ihre Y-Werte sind nicht eins-heiß, sie sind nur Ziele der Länge 1, so dass argmax nur 0 zurückgibt. Also, machen Sie sie entweder einheiß (dh [[1,0], [0,1], [0,1], [1,0]]) oder korrigieren Sie Ihre Korrektheitsfunktion, um nur gegen Y und nicht gegen Argmax zu testen (Y, 1) – MMN

Verwandte Themen