soll mir ein einfaches Netzwerk in tensorflow ist die Umsetzung und für pädagogische Zwecke, ich versuche, dass die lineare Transformation zu zeigen:Tensorflow Lernen XOR mit linearer Funktion, obwohl es nicht
yhat = w(Wx + c) + b
nicht XOR lernen kann. Aber das Problem ist jetzt, dass es mit meiner aktuellen Implementierung es tut! Dies deutet auf einen Fehler im Code hin. Bitte erläutern?
############################################################
'''
dummy data
'''
x_data = [[0.,0.],[0.,1.],[1.,0.],[1.,1.]]
y_data = [[0],[1],[1],[0]]
############################################################
'''
Input and output
'''
X = tf.placeholder(tf.float32, shape = [4,2], name = 'x')
Y = tf.placeholder(tf.float32, shape = [4,1], name = 'y')
'''
Network parameters
'''
W = tf.Variable(tf.random_uniform([2,2],-1,1), name = 'W')
c = tf.Variable(tf.zeros([2]) , name = 'c')
w = tf.Variable(tf.random_uniform([2,1],-1,1), name = 'w')
b = tf.Variable(tf.zeros([1]) , name = 'b')
############################################################
'''
Network 1:
function: Yhat = (w (x'W + c) + b)
loss : \sum_i Y * log Yhat
'''
H1 = tf.matmul(X, W) + c
Yhat1 = tf.matmul(H1, w) + b
cross_entropy1 = -tf.reduce_sum(
Y*tf.log(
tf.clip_by_value(Yhat1,1e-10,1.0)
)
)
step1 = tf.train.AdamOptimizer(0.01).minimize(cross_entropy1)
'''
Train
'''
writer = tf.train.SummaryWriter("./logs/xor_logs.graph_def")
graph1 = tf.initialize_all_variables()
sess1 = tf.Session()
sess1.run(graph1)
for i in range(100):
sess1.run(step1, feed_dict={X: x_data, Y: y_data})
'''
Evaluation
'''
corrects = tf.equal(tf.argmax(Y,1), tf.argmax(Yhat1,1))
accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32))
r = sess1.run(accuracy, feed_dict={X: x_data, Y: y_data})
print ('accuracy: ' + str(r * 100) + '%')
Gerade jetzt Genauigkeit ist bei 100%
, auch wenn es bei 75%
sein sollte.
Können Sie bitte eine Lösung vorschlagen? – chibro2
Und warum gibt mir argmax (Y, 1) [0,0,0,0]? – chibro2
argmax gibt Ihnen den Index des größten Werts entlang einer Achse. Wir verwenden es also, wenn unser Ziel beispielsweise 1 heiß ist (z. B. 1 von 10 Ziffern). Ihre Y-Werte sind nicht eins-heiß, sie sind nur Ziele der Länge 1, so dass argmax nur 0 zurückgibt. Also, machen Sie sie entweder einheiß (dh [[1,0], [0,1], [0,1], [1,0]]) oder korrigieren Sie Ihre Korrektheitsfunktion, um nur gegen Y und nicht gegen Argmax zu testen (Y, 1) – MMN