2017-03-15 1 views
0

Ich versuche, einige tiefe neuronale Netzwerk mit Tensorflow zu implementieren. Aber ich habe schon ein Problem bei den ersten Schritten.tensorflow: Seltsames Ergebnis von Faltung im Vergleich zu theano (nicht umdrehen, aber)

Wenn ich geben Sie Folgendes mit theano.tensor.nnet.conv2d, erhalte ich das erwartete Ergebnis:

import theano.tensor as T 
import theano 
import numpy as np 
# Theano expects input of shape (batch_size, channels, height, width) 
# and filters of shape (out_channel, in_channel, height, width) 
x = T.tensor4() 
w = T.tensor4() 
c = T.nnet.conv2d(x, w, filter_flip=False) 
f = theano.function([x, w], [c], allow_input_downcast=True) 
base = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]).T 
i = base[np.newaxis, np.newaxis, :, :] 
print f(i, i) # -> results in 3 as expected because np.sum(i*i) = 3 

Allerdings, wenn ich die presumingly gleiche Sache in tf.nn.conv2d tun, mein Ergebnis ist anders:

import tensorflow as tf 
import numpy as np 
# TF expects input of (batch_size, height, width, channels) 
# and filters of shape (height, width, in_channel, out_channel) 
x = tf.placeholder(tf.float32, shape=(1, 4, 3, 1), name="input") 
w = tf.placeholder(tf.float32, shape=(4, 3, 1, 1), name="weights") 
c = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='VALID') 
with tf.Session() as sess: 
    base = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]).T 
    i = base[np.newaxis, :, :, np.newaxis] 
    weights = base[:, :, np.newaxis, np.newaxis] 
    res = sess.run(c, feed_dict={x: i, w: weights}) 
    print res # -> results in -5.31794233e+37 

Die Das Layout der Faltungsoperation in Tensorflow unterscheidet sich ein wenig von dem von theano, weshalb die Eingabe etwas anders aussieht. Da jedoch die Schritte in Theano auf (1,1,1,1) voreingestellt sind und auch eine gültige Faltung der Standard ist, sollte dies genau die gleiche Eingabe sein.

Darüber hinaus dreht Tensorflow den Kernel nicht (implementiert Kreuzkorrelation).

Haben Sie eine Idee, warum dies nicht das gleiche Ergebnis liefert?

Vielen Dank im Voraus,

Roman

Antwort

0

Okay, fand ich eine Lösung, auch wenn es nicht wirklich ist, weil ich es selbst nicht verstehen. Zuerst scheint es, dass für die Aufgabe, die ich zu lösen versuchte, Theano und Tensorflow verschiedene Windungen verwenden. Die Aufgabe ist eine "1,5 D-Faltung", was bedeutet, dass ein Kernel nur in einer Richtung über den Eingang geschoben wird (hier DNA-Sequenzen).

In Theano löste ich dies mit der Conv2D-Operation, die die gleiche Anzahl von Zeilen wie die Kernel hatte und es funktionierte gut.

Tensorflow (wahrscheinlich richtig) will ich conv1d dafür verwenden, die Zeilen als Kanäle interpretierend.

So sollte Folgendes funktionieren, aber tat am Anfang nicht:

import tensorflow as tf 
import numpy as np 

# TF expects input of (batch_size, height, width, channels) 
# and filters of shape (height, width, in_channel, out_channel) 
x = tf.placeholder(tf.float32, shape=(1, 4, 3, 1), name="input") 
w = tf.placeholder(tf.float32, shape=(4, 3, 1, 1), name="weights") 

x_star = tf.reshape(x, [1, 4, 3]) 
w_star = tf.reshape(w, [4, 3, 1]) 
c = tf.nn.conv1d(x_star, w_star, stride=1, padding='VALID') 
with tf.Session() as sess: 
    base = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]]).T 
    i = base[np.newaxis, :, :, np.newaxis] 
    weights = base[:, :, np.newaxis, np.newaxis] 
    res = sess.run(c, feed_dict={x: i, w: weights}) 
    print res # -> produces 3 after updating tensorflow 

Dieser Code NaN produziert, bis ich Tensorflow auf Version 1.0.1 aktualisiert und seitdem produziert es die erwartete Ausgabe.

Zusammenfassend wurde mein Problem teilweise durch die Verwendung von 1D Faltung anstelle von 2D Faltung gelöst, aber immer noch die Aktualisierung des Frameworks benötigt. Für den zweiten Teil habe ich überhaupt keine Ahnung, was zu falschem Verhalten geführt haben könnte.

EDIT: Der Code, den ich in meiner ursprünglichen Frage geschrieben habe, funktioniert jetzt auch gut. Das unterschiedliche Verhalten kam also nur von einer alten (vielleicht korrupten) Version von TF.

Verwandte Themen